Generate Java code for mainline modules to log to statsd.

Adds support for generating code that lets Java mainline modules call
the public StatsLog.writeRaw API. Supports primitives, enums,
attribution chains, and MODE_BYTES. Does not support key-value pairs
or WorkSource methods (WorkSource relies on hidden APIs of the
WorkSource object).

Test: manually tested atom logging from DocumentsUI
Test: existing autogenerated code is not modified
Test: cts will follow
Bug: 126134616

Change-Id: Ia321cf2d9952e3875ed0c7a28db1f4113711513f
diff --git a/tools/stats_log_api_gen/main.cpp b/tools/stats_log_api_gen/main.cpp
index a5b56a4..4e6d073 100644
--- a/tools/stats_log_api_gen/main.cpp
+++ b/tools/stats_log_api_gen/main.cpp
@@ -25,6 +25,11 @@
 const string DEFAULT_MODULE_NAME = "DEFAULT";
 const string DEFAULT_CPP_NAMESPACE = "android,util";
 const string DEFAULT_CPP_HEADER_IMPORT = "statslog.h";
+const string DEFAULT_JAVA_PACKAGE = "android.util";  // Overridden with --javaPackage.
+const string DEFAULT_JAVA_CLASS = "StatsLogInternal";  // Overridden with --javaClass.
+// Bit flags OR-ed into requiredHelpers to select the optional generated Java helpers.
+const int JAVA_MODULE_REQUIRES_FLOAT = 1 << 0;  // copyFloat()
+const int JAVA_MODULE_REQUIRES_ATTRIBUTION = 1 << 1;  // writeAttributionChain()
 
 using android::os::statsd::Atom;
 
@@ -921,11 +926,350 @@
     }
 }
 
+// Emits the private static byte-packing helpers called by the generated write() methods.
+// copyInt/copyLong are always emitted; copyFloat and writeAttributionChain only when the
+// matching JAVA_MODULE_REQUIRES_* bit is set in requiredHelpers.
+static void write_java_helpers_for_module(FILE* out, const AtomDecl &attributionDecl,
+        const int requiredHelpers) {
+    fprintf(out, "    private static void copyInt(byte[] buff, int pos, int val) {\n");
+    fprintf(out, "        buff[pos] = (byte) (val);\n");
+    fprintf(out, "        buff[pos + 1] = (byte) (val >> 8);\n");
+    fprintf(out, "        buff[pos + 2] = (byte) (val >> 16);\n");
+    fprintf(out, "        buff[pos + 3] = (byte) (val >> 24);\n");
+    fprintf(out, "        return;\n");
+    fprintf(out, "    }\n");
+    fprintf(out, "\n");
+
+    fprintf(out, "    private static void copyLong(byte[] buff, int pos, long val) {\n");
+    fprintf(out, "        buff[pos] = (byte) (val);\n");
+    fprintf(out, "        buff[pos + 1] = (byte) (val >> 8);\n");
+    fprintf(out, "        buff[pos + 2] = (byte) (val >> 16);\n");
+    fprintf(out, "        buff[pos + 3] = (byte) (val >> 24);\n");
+    fprintf(out, "        buff[pos + 4] = (byte) (val >> 32);\n");
+    fprintf(out, "        buff[pos + 5] = (byte) (val >> 40);\n");
+    fprintf(out, "        buff[pos + 6] = (byte) (val >> 48);\n");
+    fprintf(out, "        buff[pos + 7] = (byte) (val >> 56);\n");
+    fprintf(out, "        return;\n");
+    fprintf(out, "    }\n");
+    fprintf(out, "\n");
+
+    if (requiredHelpers & JAVA_MODULE_REQUIRES_FLOAT) {
+        fprintf(out, "    private static void copyFloat(byte[] buff, int pos, float val) {\n");
+        fprintf(out, "        copyInt(buff, pos, Float.floatToIntBits(val));\n");
+        fprintf(out, "        return;\n");
+        fprintf(out, "    }\n");
+        fprintf(out, "\n");
+    }
+
+    if (requiredHelpers & JAVA_MODULE_REQUIRES_ATTRIBUTION) {
+        fprintf(out, "    private static void writeAttributionChain(byte[] buff, int pos");
+        for (const auto& chainField : attributionDecl.fields) {
+            fprintf(out, ", %s[] %s",
+                java_type_name(chainField.javaType), chainField.name.c_str());
+        }
+        fprintf(out, ") {\n");
+
+        const char* uidName = attributionDecl.fields.front().name.c_str();
+        const char* tagName = attributionDecl.fields.back().name.c_str();
+
+        // Outer list header: one element per attribution chain node.
+        fprintf(out, "        buff[pos] = LIST_TYPE;\n");
+        fprintf(out, "        buff[pos + 1] = (byte) (%s.length);\n", tagName);
+        fprintf(out, "        pos += LIST_TYPE_OVERHEAD;\n");
+
+        // Iterate through the attribution chain and write the nodes.
+        fprintf(out, "        for (int i = 0; i < %s.length; i++) {\n", tagName);
+        // Per-node list header: one element per attribution field (size_t, hence %zu).
+        fprintf(out, "            buff[pos] = LIST_TYPE;\n");
+        fprintf(out, "            buff[pos + 1] = %zu;\n", attributionDecl.fields.size());
+        fprintf(out, "            pos += LIST_TYPE_OVERHEAD;\n");
+
+        // Write the uid.
+        fprintf(out, "            buff[pos] = INT_TYPE;\n");
+        fprintf(out, "            copyInt(buff, pos + 1, %s[i]);\n", uidName);
+        fprintf(out, "            pos += INT_TYPE_SIZE;\n");
+
+        // Write the tag; a null tag is encoded as the empty string.
+        fprintf(out, "            String %sStr = (%s[i] == null) ? \"\" : %s[i];\n",
+                tagName, tagName, tagName);
+        fprintf(out, "            byte[] %sByte = %sStr.getBytes(UTF_8);\n", tagName, tagName);
+        fprintf(out, "            buff[pos] = STRING_TYPE;\n");
+        fprintf(out, "            copyInt(buff, pos + 1, %sByte.length);\n", tagName);
+        fprintf(out, "            System.arraycopy("
+                "%sByte, 0, buff, pos + STRING_TYPE_OVERHEAD, %sByte.length);\n",
+                tagName, tagName);
+        fprintf(out, "            pos += STRING_TYPE_OVERHEAD + %sByte.length;\n", tagName);
+        fprintf(out, "        }\n");
+        fprintf(out, "    }\n");
+        fprintf(out, "\n");
+    }
+}
+
+// Emits one write_non_chained(int code, ...) overload per non-chained signature used by
+// moduleName; each wraps its first two args (uid, tag) in one-element arrays and forwards to
+// write(). Returns 1 if a signature has an attribution chain or key value pair, 0 otherwise.
+static int write_java_non_chained_method_for_module(FILE* out,
+        const map<vector<java_type_t>, set<string>>& signatures_to_modules,
+        const string& moduleName) {
+    for (const auto& entry : signatures_to_modules) {
+        // Skip if this signature is not needed for the module.
+        if (!signature_needed_for_module(entry.second, moduleName)) {
+            continue;
+        }
+
+        // Print method signature. Bind by reference to avoid copying the signature vector.
+        const vector<java_type_t>& signature = entry.first;
+        fprintf(out, "    public static void write_non_chained(int code");
+        int argIndex = 1;
+        for (vector<java_type_t>::const_iterator arg = signature.begin();
+                arg != signature.end(); arg++) {
+            if (*arg == JAVA_TYPE_ATTRIBUTION_CHAIN) {
+                // Non chained signatures should not have attribution chains.
+                return 1;
+            } else if (*arg == JAVA_TYPE_KEY_VALUE_PAIR) {
+                // Module logging does not yet support key value pair.
+                return 1;
+            } else {
+                fprintf(out, ", %s arg%d", java_type_name(*arg), argIndex);
+            }
+            argIndex++;
+        }
+        fprintf(out, ") {\n");
+
+        fprintf(out, "        write(code");
+        argIndex = 1;
+        for (vector<java_type_t>::const_iterator arg = signature.begin();
+                arg != signature.end(); arg++) {
+            // First two args are uid and tag of attribution chain.
+            if (argIndex == 1) {
+                fprintf(out, ", new int[] {arg%d}", argIndex);
+            } else if (argIndex == 2) {
+                fprintf(out, ", new java.lang.String[] {arg%d}", argIndex);
+            } else {
+                fprintf(out, ", arg%d", argIndex);
+            }
+            argIndex++;
+        }
+        fprintf(out, ");\n");
+        fprintf(out, "    }\n");
+        fprintf(out, "\n");
+    }
+    return 0;
+}
+
+// Emits one public write(int code, ...) overload per signature used by moduleName. Each
+// generated method computes the encoded size, returns early if it exceeds MAX_EVENT_PAYLOAD,
+// then fills a byte buffer and passes it to StatsLog.writeRaw(). Optional helpers the generated
+// code calls (copyFloat, writeAttributionChain) are OR-ed into *requiredHelpers.
+// Returns 1 on an unsupported type (KEY_VALUE_PAIR, OBJECT, DOUBLE), 0 otherwise.
+static int write_java_method_for_module(FILE* out,
+        const map<vector<java_type_t>, set<string>>& signatures_to_modules,
+        const AtomDecl &attributionDecl, const string& moduleName, int* requiredHelpers) {
+
+    for (const auto& entry : signatures_to_modules) {
+        // Skip if this signature is not needed for the module.
+        if (!signature_needed_for_module(entry.second, moduleName)) {
+            continue;
+        }
+
+        // Print method signature.
+        const vector<java_type_t>& signature = entry.first;
+        fprintf(out, "    public static void write(int code");
+        int argIndex = 1;
+        for (vector<java_type_t>::const_iterator arg = signature.begin();
+                arg != signature.end(); arg++) {
+            if (*arg == JAVA_TYPE_ATTRIBUTION_CHAIN) {
+                for (const auto& chainField : attributionDecl.fields) {
+                    fprintf(out, ", %s[] %s",
+                        java_type_name(chainField.javaType), chainField.name.c_str());
+                }
+            } else if (*arg == JAVA_TYPE_KEY_VALUE_PAIR) {
+                // Module logging does not yet support key value pair.
+                return 1;
+            } else {
+                fprintf(out, ", %s arg%d", java_type_name(*arg), argIndex);
+            }
+            argIndex++;
+        }
+        fprintf(out, ") {\n");
+
+        // Calculate the size of the buffer.
+        fprintf(out, "        // Initial overhead of the list, timestamp, and atom tag.\n");
+        fprintf(out, "        int needed = LIST_TYPE_OVERHEAD + LONG_TYPE_SIZE + INT_TYPE_SIZE;\n");
+        argIndex = 1;
+        for (vector<java_type_t>::const_iterator arg = signature.begin();
+                arg != signature.end(); arg++) {
+            switch (*arg) {
+            case JAVA_TYPE_BOOLEAN:
+            case JAVA_TYPE_INT:
+            case JAVA_TYPE_FLOAT:
+            case JAVA_TYPE_ENUM:
+                fprintf(out, "        needed += INT_TYPE_SIZE;\n");
+                break;
+            case JAVA_TYPE_LONG:
+                // Longs take 9 bytes, 1 for the type and 8 for the value.
+                fprintf(out, "        needed += LONG_TYPE_SIZE;\n");
+                break;
+            case JAVA_TYPE_STRING:
+                // Strings take 5 metadata bytes + length of byte encoded string.
+                fprintf(out, "        if (arg%d == null) {\n", argIndex);
+                fprintf(out, "            arg%d = \"\";\n", argIndex);
+                fprintf(out, "        }\n");
+                fprintf(out, "        byte[] arg%dBytes = arg%d.getBytes(UTF_8);\n",
+                        argIndex, argIndex);
+                fprintf(out, "        needed += STRING_TYPE_OVERHEAD + arg%dBytes.length;\n",
+                        argIndex);
+                break;
+            case JAVA_TYPE_BYTE_ARRAY:
+                // Byte arrays take 5 metadata bytes + length of byte array.
+                fprintf(out, "        if (arg%d == null) {\n", argIndex);
+                fprintf(out, "            arg%d = new byte[0];\n", argIndex);
+                fprintf(out, "        }\n");
+                fprintf(out, "        needed += STRING_TYPE_OVERHEAD + arg%d.length;\n", argIndex);
+                break;
+            case JAVA_TYPE_ATTRIBUTION_CHAIN:
+            {
+                const char* uidName = attributionDecl.fields.front().name.c_str();
+                const char* tagName = attributionDecl.fields.back().name.c_str();
+                // Null checks on the params.
+                fprintf(out, "        if (%s == null) {\n", uidName);
+                fprintf(out, "            %s = new %s[0];\n", uidName,
+                        java_type_name(attributionDecl.fields.front().javaType));
+                fprintf(out, "        }\n");
+                fprintf(out, "        if (%s == null) {\n", tagName);
+                fprintf(out, "            %s = new %s[0];\n", tagName,
+                        java_type_name(attributionDecl.fields.back().javaType));
+                fprintf(out, "        }\n");
+
+                // First check that the lengths of the uid and tag arrays are the same.
+                fprintf(out, "        if (%s.length != %s.length) {\n", uidName, tagName);
+                fprintf(out, "            return;\n");
+                fprintf(out, "        }\n");
+                fprintf(out, "        int attrSize = LIST_TYPE_OVERHEAD;\n");
+                fprintf(out, "        for (int i = 0; i < %s.length; i++) {\n", tagName);
+                fprintf(out, "            String str%d = (%s[i] == null) ? \"\" : %s[i];\n",
+                        argIndex, tagName, tagName);
+                fprintf(out, "            int str%dlen = str%d.getBytes(UTF_8).length;\n",
+                        argIndex, argIndex);
+                fprintf(out,
+                        "            attrSize += "
+                        "LIST_TYPE_OVERHEAD + INT_TYPE_SIZE + STRING_TYPE_OVERHEAD + str%dlen;\n",
+                        argIndex);
+                fprintf(out, "        }\n");
+                fprintf(out, "        needed += attrSize;\n");
+                break;
+            }
+            default:
+                // Unsupported types: OBJECT, DOUBLE, KEY_VALUE_PAIR.
+                return 1;
+            }
+            argIndex++;
+        }
+
+        // Now we have the size that is needed. Check for overflow and return if needed.
+        fprintf(out, "        if (needed > MAX_EVENT_PAYLOAD) {\n");
+        fprintf(out, "            return;\n");
+        fprintf(out, "        }\n");
+
+        // Create new buffer, and associated data types.
+        fprintf(out, "        byte[] buff = new byte[needed];\n");
+        fprintf(out, "        int pos = 0;\n");
+
+        // List header; elements are the timestamp, the atom code, then one per signature arg.
+        fprintf(out, "        buff[pos] = LIST_TYPE;\n");
+        fprintf(out, "        buff[pos + 1] = %zu;\n", signature.size() + 2);
+        fprintf(out, "        pos += LIST_TYPE_OVERHEAD;\n");
+
+        // Write timestamp.
+        fprintf(out, "        long elapsedRealtime = SystemClock.elapsedRealtimeNanos();\n");
+        fprintf(out, "        buff[pos] = LONG_TYPE;\n");
+        fprintf(out, "        copyLong(buff, pos + 1, elapsedRealtime);\n");
+        fprintf(out, "        pos += LONG_TYPE_SIZE;\n");
+
+        // Write atom code.
+        fprintf(out, "        buff[pos] = INT_TYPE;\n");
+        fprintf(out, "        copyInt(buff, pos + 1, code);\n");
+        fprintf(out, "        pos += INT_TYPE_SIZE;\n");
+
+        // Write the args.
+        argIndex = 1;
+        for (vector<java_type_t>::const_iterator arg = signature.begin();
+                arg != signature.end(); arg++) {
+            switch (*arg) {
+            case JAVA_TYPE_BOOLEAN:
+                fprintf(out, "        buff[pos] = INT_TYPE;\n");
+                fprintf(out, "        copyInt(buff, pos + 1, arg%d ? 1 : 0);\n", argIndex);
+                fprintf(out, "        pos += INT_TYPE_SIZE;\n");
+                break;
+            case JAVA_TYPE_INT:
+            case JAVA_TYPE_ENUM:
+                fprintf(out, "        buff[pos] = INT_TYPE;\n");
+                fprintf(out, "        copyInt(buff, pos + 1, arg%d);\n", argIndex);
+                fprintf(out, "        pos += INT_TYPE_SIZE;\n");
+                break;
+            case JAVA_TYPE_FLOAT:
+                *requiredHelpers |= JAVA_MODULE_REQUIRES_FLOAT;
+                fprintf(out, "        buff[pos] = FLOAT_TYPE;\n");
+                fprintf(out, "        copyFloat(buff, pos + 1, arg%d);\n", argIndex);
+                fprintf(out, "        pos += FLOAT_TYPE_SIZE;\n");
+                break;
+            case JAVA_TYPE_LONG:
+                fprintf(out, "        buff[pos] = LONG_TYPE;\n");
+                fprintf(out, "        copyLong(buff, pos + 1, arg%d);\n", argIndex);
+                fprintf(out, "        pos += LONG_TYPE_SIZE;\n");
+                break;
+            case JAVA_TYPE_STRING:
+                fprintf(out, "        buff[pos] = STRING_TYPE;\n");
+                fprintf(out, "        copyInt(buff, pos + 1, arg%dBytes.length);\n", argIndex);
+                fprintf(out, "        System.arraycopy("
+                        "arg%dBytes, 0, buff, pos + STRING_TYPE_OVERHEAD, arg%dBytes.length);\n",
+                        argIndex, argIndex);
+                fprintf(out, "        pos += STRING_TYPE_OVERHEAD + arg%dBytes.length;\n",
+                        argIndex);
+                break;
+            case JAVA_TYPE_BYTE_ARRAY:
+                fprintf(out, "        buff[pos] = STRING_TYPE;\n");
+                fprintf(out, "        copyInt(buff, pos + 1, arg%d.length);\n", argIndex);
+                fprintf(out, "        System.arraycopy("
+                        "arg%d, 0, buff, pos + STRING_TYPE_OVERHEAD, arg%d.length);\n",
+                        argIndex, argIndex);
+                fprintf(out, "        pos += STRING_TYPE_OVERHEAD + arg%d.length;\n", argIndex);
+                break;
+            case JAVA_TYPE_ATTRIBUTION_CHAIN:
+            {
+                *requiredHelpers |= JAVA_MODULE_REQUIRES_ATTRIBUTION;
+                const char* uidName = attributionDecl.fields.front().name.c_str();
+                const char* tagName = attributionDecl.fields.back().name.c_str();
+
+                fprintf(out, "        writeAttributionChain(buff, pos, %s, %s);\n",
+                        uidName, tagName);
+                fprintf(out, "        pos += attrSize;\n");
+                break;
+            }
+            default:
+                // Unsupported types: OBJECT, DOUBLE, KEY_VALUE_PAIR.
+                return 1;
+            }
+            argIndex++;
+        }
+
+        fprintf(out, "        StatsLog.writeRaw(buff, pos);\n");
+        fprintf(out, "    }\n");
+        fprintf(out, "\n");
+    }
+    return 0;
+}
+
 static void write_java_work_source_method(FILE* out,
-        const map<vector<java_type_t>, set<string>>& signatures_to_modules) {
+        const map<vector<java_type_t>, set<string>>& signatures_to_modules,
+        const string& moduleName) {
     fprintf(out, "\n    // WorkSource methods.\n");
     for (auto signature_to_modules_it = signatures_to_modules.begin();
             signature_to_modules_it != signatures_to_modules.end(); signature_to_modules_it++) {
+        // Skip if this signature is not needed for the module.
+        if (!signature_needed_for_module(signature_to_modules_it->second, moduleName)) {
+            continue;
+        }
         vector<java_type_t> signature = signature_to_modules_it->first;
         // Determine if there is Attribution in this signature.
         int attributionArg = -1;
@@ -948,7 +1292,9 @@
         }
 
         // Method header (signature)
-        fprintf(out, "    /** @hide */\n");
+        if (moduleName == DEFAULT_MODULE_NAME) {
+            fprintf(out, "    /** @hide */\n");
+        }
         fprintf(out, "    public static void write(int code");
         int argIndex = 1;
         for (vector<java_type_t>::const_iterator arg = signature.begin();
@@ -973,7 +1319,7 @@
             }
         }
         fprintf(out, ");\n");
-        fprintf(out, "        }\n"); // close flor-loop
+        fprintf(out, "        }\n"); // close for-loop
 
         // write() component.
         fprintf(out, "        ArrayList<WorkSource.WorkChain> workChains = ws.getWorkChains();\n");
@@ -994,6 +1340,67 @@
     }
 }
 
+// Writes a documented `public static final int` constant for every atom used by moduleName.
+// The javadoc shows write() usage (plus write_non_chained() usage when a non-chained variant
+// exists) and is marked @hide only for the default module.
+static void write_java_atom_codes(FILE* out, const Atoms& atoms, const string& moduleName) {
+    std::map<int, set<AtomDecl>::const_iterator> nonChainedDeclByCode;
+    build_non_chained_decl_map(atoms, &nonChainedDeclByCode);
+    const bool hideFromDocs = (moduleName == DEFAULT_MODULE_NAME);
+
+    fprintf(out, "    // Constants for atom codes.\n");
+    for (const AtomDecl& decl : atoms.decls) {
+        if (!atom_needed_for_module(decl, moduleName)) {
+            continue;  // Atom is not logged by this module.
+        }
+        const string constantName = make_constant_name(decl.name);
+
+        fprintf(out, "\n    /**\n");
+        fprintf(out, "     * %s %s<br>\n", decl.message.c_str(), decl.name.c_str());
+        write_java_usage(out, "write", constantName, decl);
+        const auto nonChained = nonChainedDeclByCode.find(decl.code);
+        if (nonChained != nonChainedDeclByCode.end()) {
+            write_java_usage(out, "write_non_chained", constantName, *nonChained->second);
+        }
+        if (hideFromDocs) {
+            fprintf(out, "     * @hide\n");
+        }
+        fprintf(out, "     */\n    public static final int %s = %d;\n",
+                constantName.c_str(), decl.code);
+    }
+    fprintf(out, "\n");
+}
+
+// Writes `public static final int MESSAGE__FIELD__VALUE = n;` constants for every enum
+// field of every atom used by moduleName, grouped under a "// Values for" comment per field.
+// Each constant is marked @hide only for the default module.
+static void write_java_enum_values(FILE* out, const Atoms& atoms, const string& moduleName) {
+    const bool hideFromDocs = (moduleName == DEFAULT_MODULE_NAME);
+    fprintf(out, "    // Constants for enum values.\n\n");
+    for (const AtomDecl& decl : atoms.decls) {
+        if (!atom_needed_for_module(decl, moduleName)) {
+            continue;  // Atom is not logged by this module.
+        }
+        for (const AtomField& field : decl.fields) {
+            if (field.javaType != JAVA_TYPE_ENUM) {
+                continue;
+            }
+            fprintf(out, "    // Values for %s.%s\n", decl.message.c_str(), field.name.c_str());
+            const string atomPart = make_constant_name(decl.message);
+            const string fieldPart = make_constant_name(field.name);
+            for (const auto& enumValue : field.enumValues) {
+                if (hideFromDocs) {
+                    fprintf(out, "    /** @hide */\n");
+                }
+                fprintf(out, "    public static final int %s__%s__%s = %d;\n",
+                        atomPart.c_str(), fieldPart.c_str(),
+                        make_constant_name(enumValue.second).c_str(), enumValue.first);
+            }
+            fprintf(out, "\n");
+        }
+    }
+}
+
 static int
 write_stats_log_java(FILE* out, const Atoms& atoms, const AtomDecl &attributionDecl)
 {
@@ -1012,64 +1419,87 @@
     fprintf(out, " * @hide\n");
     fprintf(out, " */\n");
     fprintf(out, "public class StatsLogInternal {\n");
-    fprintf(out, "    // Constants for atom codes.\n");
+    write_java_atom_codes(out, atoms, DEFAULT_MODULE_NAME);
 
-    std::map<int, set<AtomDecl>::const_iterator> atom_code_to_non_chained_decl_map;
-    build_non_chained_decl_map(atoms, &atom_code_to_non_chained_decl_map);
-
-    // Print constants for the atom codes.
-    for (set<AtomDecl>::const_iterator atom = atoms.decls.begin();
-            atom != atoms.decls.end(); atom++) {
-        string constant = make_constant_name(atom->name);
-        fprintf(out, "\n");
-        fprintf(out, "    /**\n");
-        fprintf(out, "     * %s %s<br>\n", atom->message.c_str(), atom->name.c_str());
-        write_java_usage(out, "write", constant, *atom);
-        auto non_chained_decl = atom_code_to_non_chained_decl_map.find(atom->code);
-        if (non_chained_decl != atom_code_to_non_chained_decl_map.end()) {
-            write_java_usage(out, "write_non_chained", constant, *non_chained_decl->second);
-        }
-        fprintf(out, "     * @hide\n");
-        fprintf(out, "     */\n");
-        fprintf(out, "    public static final int %s = %d;\n", constant.c_str(), atom->code);
-    }
-    fprintf(out, "\n");
-
-    // Print constants for the enum values.
-    fprintf(out, "    // Constants for enum values.\n\n");
-    for (set<AtomDecl>::const_iterator atom = atoms.decls.begin();
-        atom != atoms.decls.end(); atom++) {
-        for (vector<AtomField>::const_iterator field = atom->fields.begin();
-            field != atom->fields.end(); field++) {
-            if (field->javaType == JAVA_TYPE_ENUM) {
-                fprintf(out, "    // Values for %s.%s\n", atom->message.c_str(),
-                    field->name.c_str());
-                for (map<int, string>::const_iterator value = field->enumValues.begin();
-                    value != field->enumValues.end(); value++) {
-                    fprintf(out, "    /** @hide */\n");
-                    fprintf(out, "    public static final int %s__%s__%s = %d;\n",
-                        make_constant_name(atom->message).c_str(),
-                        make_constant_name(field->name).c_str(),
-                        make_constant_name(value->second).c_str(),
-                        value->first);
-                }
-                fprintf(out, "\n");
-            }
-        }
-    }
+    write_java_enum_values(out, atoms, DEFAULT_MODULE_NAME);
 
     // Print write methods
     fprintf(out, "    // Write methods\n");
     write_java_method(out, "write", atoms.signatures_to_modules, attributionDecl);
     write_java_method(out, "write_non_chained", atoms.non_chained_signatures_to_modules,
             attributionDecl);
-    write_java_work_source_method(out, atoms.signatures_to_modules);
+    write_java_work_source_method(out, atoms.signatures_to_modules, DEFAULT_MODULE_NAME);
 
     fprintf(out, "}\n");
 
     return 0;
 }
 
+// Writes Java class javaClass in package javaPackage, which logs the atoms used by moduleName
+// via StatsLog.writeRaw(). Returns non-zero if a write-method generator hit an unsupported type.
+// TODO: Merge this with write_stats_log_java so that we can get rid of StatsLogInternal JNI.
+static int
+write_stats_log_java_for_module(FILE* out, const Atoms& atoms, const AtomDecl &attributionDecl,
+                     const string& moduleName, const string& javaClass, const string& javaPackage)
+{
+    // Print prelude
+    fprintf(out, "// This file is autogenerated\n");
+    fprintf(out, "\n");
+    fprintf(out, "package %s;\n", javaPackage.c_str());
+    fprintf(out, "\n");
+    fprintf(out, "import static java.nio.charset.StandardCharsets.UTF_8;\n");
+    fprintf(out, "\n");
+    // No WorkSource methods are generated for modules, so java.util.ArrayList is not imported.
+    fprintf(out, "import android.os.SystemClock;\n");
+    fprintf(out, "import android.util.StatsLog;\n");
+    fprintf(out, "\n");
+    fprintf(out, "/**\n");
+    fprintf(out, " * Utility class for logging statistics events.\n");
+    fprintf(out, " */\n");
+    fprintf(out, "public class %s {\n", javaClass.c_str());
+
+    // TODO: ideally these match with the native values (and automatically change if they change).
+    fprintf(out, "    private static final int LOGGER_ENTRY_MAX_PAYLOAD = 4068;\n");
+    fprintf(out,
+            "    private static final int MAX_EVENT_PAYLOAD = LOGGER_ENTRY_MAX_PAYLOAD - 4;\n");
+    // Value types. Must match with EventLog.java and log.h.
+    fprintf(out, "    private static final byte INT_TYPE = 0;\n");
+    fprintf(out, "    private static final byte LONG_TYPE = 1;\n");
+    fprintf(out, "    private static final byte STRING_TYPE = 2;\n");
+    fprintf(out, "    private static final byte LIST_TYPE = 3;\n");
+    fprintf(out, "    private static final byte FLOAT_TYPE = 4;\n");
+
+    // Size of each value type.
+    // Booleans, ints, floats, and enums take 5 bytes, 1 for the type and 4 for the value.
+    fprintf(out, "    private static final int INT_TYPE_SIZE = 5;\n");
+    fprintf(out, "    private static final int FLOAT_TYPE_SIZE = 5;\n");
+    // Longs take 9 bytes, 1 for the type and 8 for the value.
+    fprintf(out, "    private static final int LONG_TYPE_SIZE = 9;\n");
+    // Strings take 5 metadata bytes: 1 byte is for the type, 4 are for the length.
+    fprintf(out, "    private static final int STRING_TYPE_OVERHEAD = 5;\n");
+    fprintf(out, "    private static final int LIST_TYPE_OVERHEAD = 2;\n");
+
+    write_java_atom_codes(out, atoms, moduleName);
+
+    write_java_enum_values(out, atoms, moduleName);
+
+    int errors = 0;
+    int requiredHelpers = 0;
+    // Print write methods
+    fprintf(out, "    // Write methods\n");
+    errors += write_java_method_for_module(out, atoms.signatures_to_modules, attributionDecl,
+            moduleName, &requiredHelpers);
+    errors += write_java_non_chained_method_for_module(out, atoms.non_chained_signatures_to_modules,
+            moduleName);
+
+    fprintf(out, "    // Helper methods for copying primitives\n");
+    write_java_helpers_for_module(out, attributionDecl, requiredHelpers);
+
+    fprintf(out, "}\n");
+
+    return errors;
+}
+
 static const char*
 jni_type_name(java_type_t type)
 {
@@ -1521,7 +1951,11 @@
     fprintf(stderr, "  --namespace COMMA,SEP,NAMESPACE   required for cpp/header with module\n");
     fprintf(stderr, "                                    comma separated namespace of the files\n");
     fprintf(stderr, "  --importHeader NAME  required for cpp/jni to say which header to import\n");
-}
+    fprintf(stderr, "  --javaPackage PACKAGE             the package for the java file.\n");
+    fprintf(stderr, "                                    required for java with module\n");
+    fprintf(stderr, "  --javaClass CLASS    the class name of the java class.\n");
+    fprintf(stderr, "                       Optional for Java with module.\n");
+    fprintf(stderr, "                       Default is \"StatsLogInternal\"\n");}
 
 /**
  * Do the argument parsing and execute the tasks.
@@ -1537,6 +1971,8 @@
     string moduleName = DEFAULT_MODULE_NAME;
     string cppNamespace = DEFAULT_CPP_NAMESPACE;
     string cppHeaderImport = DEFAULT_CPP_HEADER_IMPORT;
+    string javaPackage = DEFAULT_JAVA_PACKAGE;
+    string javaClass = DEFAULT_JAVA_CLASS;
 
     int index = 1;
     while (index < argc) {
@@ -1592,6 +2028,20 @@
                 return 1;
             }
             cppHeaderImport = argv[index];
+        } else if (0 == strcmp("--javaPackage", argv[index])) {
+            index++;
+            if (index >= argc) {
+                print_usage();
+                return 1;
+            }
+            javaPackage = argv[index];
+        } else if (0 == strcmp("--javaClass", argv[index])) {
+            index++;
+            if (index >= argc) {
+                print_usage();
+                return 1;
+            }
+            javaClass = argv[index];
         }
         index++;
     }
@@ -1661,8 +2111,18 @@
             fprintf(stderr, "Unable to open file for write: %s\n", javaFilename.c_str());
             return 1;
         }
-        errorCount = android::stats_log_api_gen::write_stats_log_java(
-            out, atoms, attributionDecl);
+        // If this is for a specific module, the java package must also be provided.
+        if (moduleName != DEFAULT_MODULE_NAME && javaPackage== DEFAULT_JAVA_PACKAGE) {
+            fprintf(stderr, "Must supply --javaPackage if supplying a specific module\n");
+            return 1;
+        }
+        if (moduleName == DEFAULT_MODULE_NAME) {
+            errorCount = android::stats_log_api_gen::write_stats_log_java(
+                    out, atoms, attributionDecl);
+        } else {
+            errorCount = android::stats_log_api_gen::write_stats_log_java_for_module(
+                    out, atoms, attributionDecl, moduleName, javaClass, javaPackage);
+        }
         fclose(out);
     }
 
@@ -1678,7 +2138,7 @@
         fclose(out);
     }
 
-    return 0;
+    return errorCount;
 }
 
 }