Add --primary_abi_only to command_sheet

- Support comparing two cts-on-gsi reports even if the bitness of the
  primary ABIs is different.
- Move _ParseSummary from command_sheet to result_utils.

Bug: 79905934
Test: sheet --src= --ref= --dest= --primary_abi_only
Change-Id: Ifd47459556042a62c05ae5a0d20be9507fa2d562
diff --git a/harnesses/host_controller/command_processor/command_sheet.py b/harnesses/host_controller/command_processor/command_sheet.py
index 808ecc1..7937e90 100644
--- a/harnesses/host_controller/command_processor/command_sheet.py
+++ b/harnesses/host_controller/command_processor/command_sheet.py
@@ -38,6 +38,25 @@
 from host_controller.utils.parser import result_utils
 from host_controller.utils.parser import xml_utils
 
+# Attributes shown on spreadsheet
+_RESULT_ATTR_KEYS = [
+    common._SUITE_NAME_ATTR_KEY, common._SUITE_PLAN_ATTR_KEY,
+    common._SUITE_VERSION_ATTR_KEY, common._SUITE_BUILD_NUM_ATTR_KEY,
+    common._START_DISPLAY_TIME_ATTR_KEY,
+    common._END_DISPLAY_TIME_ATTR_KEY
+]
+
+_BUILD_ATTR_KEYS = [
+    common._FINGERPRINT_ATTR_KEY,
+    common._SYSTEM_FINGERPRINT_ATTR_KEY,
+    common._VENDOR_FINGERPRINT_ATTR_KEY
+]
+
+_SUMMARY_ATTR_KEYS = [
+    common._PASSED_ATTR_KEY, common._FAILED_ATTR_KEY,
+    common._MODULES_TOTAL_ATTR_KEY, common._MODULES_DONE_ATTR_KEY
+]
+
 # Texts on spreadsheet
 _TABLE_HEADER = ("BITNESS", "TEST_MODULE", "TEST_CLASS", "TEST_CASE", "RESULT")
 
@@ -92,6 +111,12 @@
             help="Maximum number of results written to the spreadsheet. "
             "If there are too many results, only failing ones are written.")
         self.arg_parser.add_argument(
+            "--primary_abi_only",
+            action="store_true",
+            help="Whether to upload only the test results for primary ABI. If "
+            "ref is also specified, this command loads the primary ABI "
+            "results from ref and compares regardless of bitness.")
+        self.arg_parser.add_argument(
             "--client_secrets",
             default=None,
             help="The path to the client secrets file in JSON format for "
@@ -120,7 +145,8 @@
                 scopes=self._SCOPE)
         client = gspread.authorize(credentials)
 
-        # Load summary_attrs, src_dict, ref_dict, and exceed_max
+        # Load result_attrs, build_attrs, summary_attrs,
+        # src_dict, ref_dict, and exceed_max
         temp_dir = tempfile.mkdtemp()
         try:
             src_path = _GetResultAsXml(src_path, os.path.join(temp_dir, "src"))
@@ -128,18 +154,18 @@
                 return False
 
             with open(src_path, "r") as src_file:
-                summary_attrs = _ParseSummary(src_file)
-                result_cnt = _GetResultCount(summary_attrs)
-                show_pass = result_cnt >= 0 and result_cnt <= args.max
-
+                (result_attrs,
+                 build_attrs,
+                 summary_attrs) = result_utils.LoadTestSummary(src_file)
                 src_file.seek(0)
-                src_dict = _FilterTestResults(
-                    src_file, args.max + 1,
-                    lambda name, result: show_pass or result != "pass")
-
-            exceed_max = len(src_dict) > args.max
-            if src_dict and exceed_max:
-                del src_dict[max(src_dict)]
+                if args.primary_abi_only:
+                    abis = build_attrs.get(
+                        common._ABIS_ATTR_KEY, "").split(",")
+                    src_bitness = str(result_utils.GetAbiBitness(abis[0]))
+                    src_dict, exceed_max = _LoadSrcResults(src_file, args.max,
+                                                           src_bitness)
+                else:
+                    src_dict, exceed_max = _LoadSrcResults(src_file, args.max)
 
             if ref_path:
                 ref_path = _GetResultAsXml(
@@ -147,9 +173,18 @@
                 if not ref_path:
                     return False
                 with open(ref_path, "r") as ref_file:
-                    ref_dict = _FilterTestResults(
-                        ref_file, args.max,
-                        lambda name, result: src_dict.get(name, "") == "fail")
+                    if args.primary_abi_only:
+                        ref_build_attrs = xml_utils.GetAttributes(
+                            ref_file, common._BUILD_TAG,
+                            (common._ABIS_ATTR_KEY, ))
+                        ref_file.seek(0)
+                        abis = ref_build_attrs[
+                            common._ABIS_ATTR_KEY].split(",")
+                        ref_bitness = str(result_utils.GetAbiBitness(abis[0]))
+                        ref_dict = _LoadRefResults(ref_file, src_dict,
+                                                   ref_bitness, src_bitness)
+                    else:
+                        ref_dict = _LoadRefResults(ref_file, src_dict)
         finally:
             shutil.rmtree(temp_dir)
 
@@ -159,7 +194,12 @@
             writer = csv.writer(csv_file, lineterminator="\n")
 
             writer.writerows(row.split(",") for row in args.extra_rows)
-            writer.writerows(summary_attrs)
+
+            for keys, attrs in (
+                    (_RESULT_ATTR_KEYS, result_attrs),
+                    (_BUILD_ATTR_KEYS, build_attrs),
+                    (_SUMMARY_ATTR_KEYS, summary_attrs)):
+                writer.writerows((k, attrs.get(k, "")) for k in keys)
 
             src_list = sorted(src_dict.items())
             if ref_path:
@@ -250,78 +290,8 @@
     return src
 
 
-def _ParseSummary(result_xml):
-    """Gets test summary from an XML.
-
-    Args:
-        result_xml: The input file object in XML format.
-
-    Returns:
-        A list of (attribute_name, value).
-    """
-    result_attr_keys = [
-        common._SUITE_NAME_ATTR_KEY, common._SUITE_PLAN_ATTR_KEY,
-        common._SUITE_VERSION_ATTR_KEY, common._SUITE_BUILD_NUM_ATTR_KEY,
-        common._START_DISPLAY_TIME_ATTR_KEY,
-        common._END_DISPLAY_TIME_ATTR_KEY
-    ]
-    build_attr_keys = [
-        common._FINGERPRINT_ATTR_KEY,
-        common._SYSTEM_FINGERPRINT_ATTR_KEY,
-        common._VENDOR_FINGERPRINT_ATTR_KEY
-    ]
-    summary_attr_keys = [
-        common._PASSED_ATTR_KEY, common._FAILED_ATTR_KEY,
-        common._MODULES_TOTAL_ATTR_KEY, common._MODULES_DONE_ATTR_KEY
-    ]
-    result_xml.seek(0)
-    result_attrs = xml_utils.GetAttributes(
-        result_xml, common._RESULT_TAG, result_attr_keys)
-    result_xml.seek(0)
-    build_attrs = xml_utils.GetAttributes(
-        result_xml, common._BUILD_TAG, build_attr_keys)
-    result_xml.seek(0)
-    summary_attrs = xml_utils.GetAttributes(
-        result_xml, common._SUMMARY_TAG, summary_attr_keys)
-
-    attr_list = []
-
-    for attr_keys, attrs in (
-            (result_attr_keys, result_attrs),
-            (build_attr_keys, build_attrs),
-            (summary_attr_keys, summary_attrs)):
-        for attr_key in attr_keys:
-            attr_list.append((attr_key, attrs.get(attr_key, "")))
-
-    return attr_list
-
-
-def _GetResultCount(attr_list):
-    """Gets total number of results from a test summary.
-
-    Args:
-        attr_list: A list of (attribute_name, value).
-
-    Returns:
-        An integer, number of results.
-        -1 if fails to parse the number.
-    """
-    try:
-        pass_cnt = next(v for k, v in attr_list if
-                        k == common._PASSED_ATTR_KEY)
-        fail_cnt = next(v for k, v in attr_list if
-                        k == common._FAILED_ATTR_KEY)
-    except StopIteration:
-        return -1
-
-    try:
-        return int(pass_cnt) + int(fail_cnt)
-    except ValueError:
-        return -1
-
-
 def _FilterTestResults(xml_file, max_return, filter_func):
-    """Converts a TradeFed report from XML to dictionary.
+    """Loads test results from XML to dictionary with a filter.
 
     Args:
         xml_file: The input file object in XML format.
@@ -345,6 +315,79 @@
     return result_dict
 
 
+def _LoadSrcResults(src_xml, max_return, bitness=""):
+    """Loads test results from XML to dictionary.
+
+    If number of results exceeds max_return, only failures are returned.
+    If number of failures exceeds max_return, the results are truncated.
+
+    Args
+        src_xml: The file object in XML format.
+        max_return: Maximum number of returned results.
+        bitness: A string, the bitness of the returned results.
+
+    Returns:
+        A dict of {name: result} and a boolean which represents whether the
+        results are truncated.
+    """
+    def FilterBitness(name):
+        return not bitness or bitness == name[0]
+
+    results = _FilterTestResults(
+        src_xml, max_return + 1, lambda name, result: FilterBitness(name))
+
+    if len(results) > max_return:
+        src_xml.seek(0)
+        results = _FilterTestResults(
+            src_xml, max_return + 1,
+            lambda name, result: result == "fail" and FilterBitness(name))
+
+    exceed_max = len(results) > max_return
+    if results and exceed_max:
+        del results[max(results)]
+
+    return results, exceed_max
+
+
+def _LoadRefResults(ref_xml, base_results, ref_bitness="", base_bitness=""):
+    """Loads reference results from XML to dictionary.
+
+    A test result in ref_xml is returned if the test fails in base_results.
+
+    Args:
+        ref_xml: The file object in XML format.
+        base_results: A dict of {name: result} containing the test names to be
+                      loaded from ref_xml.
+        ref_bitness: A string, the bitness of the results to be loaded from
+                     ref_xml.
+        base_bitness: A string, the bitness of the returned results. If this
+                      argument is specified, the function ignores bitness when
+                      comparing test names.
+
+    Returns:
+        A dict of {name: result}, the test name in base_results and the result
+        in ref_xml.
+    """
+    ref_results = dict()
+    for module, testcase, test in result_utils.IterateTestResults(ref_xml):
+        if len(ref_results) >= len(base_results):
+            break
+        result = test.attrib.get(common._RESULT_ATTR_KEY, "")
+        name = result_utils.GetTestName(module, testcase, test)
+
+        if ref_bitness and name[0] != ref_bitness:
+            continue
+        if base_bitness:
+            name_in_base = (base_bitness, ) + name[1:]
+        else:
+            name_in_base = name
+
+        if base_results.get(name_in_base, "") == "fail":
+            ref_results[name_in_base] = result
+
+    return ref_results
+
+
 def _WriteResultsToCsv(result_list, writer):
     """Writes a list of test names and results to a CSV file.
 
diff --git a/harnesses/host_controller/command_processor/command_sheet_test.py b/harnesses/host_controller/command_processor/command_sheet_test.py
index 22b5c9f..9338e56 100644
--- a/harnesses/host_controller/command_processor/command_sheet_test.py
+++ b/harnesses/host_controller/command_processor/command_sheet_test.py
@@ -121,6 +121,17 @@
 too many to be displayed
 """
 
+_PRIMARY_ABI_RESULTS_1 = _CSV_HEAD + """\
+pass,1
+failed,3
+modules_total,2
+modules_done,2
+BITNESS,TEST_MODULE,TEST_CLASS,TEST_CASE,RESULT
+64,module2,testcase2,test1,pass
+64,module2,testcase2,test2,fail
+64,module2,testcase2,test3,fail
+"""
+
 _COMPARISON_1_2 = _CSV_HEAD + """\
 pass,1
 failed,3
@@ -132,6 +143,17 @@
 64,module2,testcase2,test3,fail,fail
 """
 
+_PRIMARY_ABI_COMPARISON_1_2 = _CSV_HEAD + """\
+pass,1
+failed,3
+modules_total,2
+modules_done,2
+BITNESS,TEST_MODULE,TEST_CLASS,TEST_CASE,RESULT,REFERENCE_RESULT
+64,module2,testcase2,test1,pass,
+64,module2,testcase2,test2,fail,pass
+64,module2,testcase2,test3,fail,fail
+"""
+
 _COMPARISON_2_1 = _CSV_HEAD + """\
 pass,3
 failed,1
@@ -234,6 +256,18 @@
 
         mock_client.import_csv.assert_called_with("123", _TRUNCATED_RESULTS_1)
 
+    def testPrimaryAbiOnly(self, mock_credentials, mock_gspread):
+        """Tests showing only results for primary ABI."""
+        """Tests showing only failing tests."""
+        mock_client = mock.Mock()
+        mock_gspread.authorize.return_value = mock_client
+
+        self._cmd.Run("--src %s --dest 123 --client_secret /abc "
+                      "--extra_rows %s --primary_abi_only" %
+                      (self._CreateXml(_XML_1), " ".join(_EXTRA_ROWS)))
+
+        mock_client.import_csv.assert_called_with("123", _PRIMARY_ABI_RESULTS_1)
+
     def testCompareLocal(self, mock_credentials, mock_gspread):
         """Tests comparing two local XML files."""
         mock_client = mock.Mock()
@@ -246,6 +280,19 @@
 
         mock_client.import_csv.assert_called_with("123", _COMPARISON_1_2)
 
+    def testComparePrimaryAbi(self, mock_credentials, mock_gspread):
+        """Tests comparing primary ABI only."""
+        mock_client = mock.Mock()
+        mock_gspread.authorize.return_value = mock_client
+
+        self._cmd.Run("--src %s --ref %s --dest 123 --client_secret /abc "
+                      "--extra_rows %s --primary_abi_only" %
+                      (self._CreateXml(_XML_1), self._CreateZip(_XML_2),
+                       " ".join(_EXTRA_ROWS)))
+
+        mock_client.import_csv.assert_called_with("123",
+                                                  _PRIMARY_ABI_COMPARISON_1_2)
+
     @mock.patch("host_controller.command_processor.command_sheet.gcs_utils")
     def testCompareGcs(self, mock_gcs_utils, mock_credentials, mock_gspread):
         """Tests comparing a local XML with a ZIP on GCS."""
diff --git a/harnesses/host_controller/common.py b/harnesses/host_controller/common.py
index ed26c8f..7ea3f8c 100644
--- a/harnesses/host_controller/common.py
+++ b/harnesses/host_controller/common.py
@@ -148,6 +148,8 @@
 _SUITE_NAME_ATTR_KEY = "suite_name"
 
 # The key value for retrieving build fingerprint values from the result xml.
+_ABIS_ATTR_KEY = "build_abis"
+
 _FINGERPRINT_ATTR_KEY = "build_fingerprint"
 
 _SYSTEM_FINGERPRINT_ATTR_KEY = "build_system_fingerprint"
diff --git a/harnesses/host_controller/utils/parser/result_utils.py b/harnesses/host_controller/utils/parser/result_utils.py
index 98ff932..9b9b51e 100644
--- a/harnesses/host_controller/utils/parser/result_utils.py
+++ b/harnesses/host_controller/utils/parser/result_utils.py
@@ -50,6 +50,30 @@
         return os.path.join(output_dir, xml_name)
 
 
+def LoadTestSummary(result_xml):
+    """Gets attributes of <Result>, <Build>, and <Summary>.
+
+    Args:
+        result_xml: A file object of the TradeFed report in XML format.
+
+    Returns:
+        3 dictionaries, the attributes of <Result>, <Build>, and <Summary>.
+    """
+    result_attrib = {}
+    build_attrib = {}
+    summary_attrib = {}
+    for event, elem in ElementTree.iterparse(result_xml, events=("start", )):
+        if all((result_attrib, build_attrib, summary_attrib)):
+            break
+        if elem.tag == common._RESULT_TAG:
+            result_attrib = dict(elem.attrib)
+        elif elem.tag == common._BUILD_TAG:
+            build_attrib = dict(elem.attrib)
+        elif elem.tag == common._SUMMARY_TAG:
+            summary_attrib = dict(elem.attrib)
+    return result_attrib, build_attrib, summary_attrib
+
+
 def IterateTestResults(result_xml):
     """Yields test records in test_result.xml.
 
@@ -74,6 +98,18 @@
             yield module_elem, testcase_elem, elem
 
 
+def GetAbiBitness(abi):
+    """Gets bitness of an ABI.
+
+    Args:
+        abi: A string, the ABI name.
+
+    Returns:
+        32 or 64, the ABI bitness.
+    """
+    return 64 if "arm64" in abi or "x86_64" in abi else 32
+
+
 def GetTestName(module, testcase, test):
     """Gets the bitness and the full test name.
 
@@ -86,7 +122,7 @@
         A tuple of (bitness, module_name, testcase_name, test_name).
     """
     abi = module.attrib.get(common._ABI_ATTR_KEY, "")
-    bitness = "64" if "arm64" in abi or "x86_64" in abi else "32"
+    bitness = str(GetAbiBitness(abi))
     module_name = module.attrib.get(common._NAME_ATTR_KEY, "")
     testcase_name = testcase.attrib.get(common._NAME_ATTR_KEY, "")
     test_name = test.attrib.get(common._NAME_ATTR_KEY, "")