Add Regulatory Compliance Tests to ChannelSweepTest

Test: act.py -c <config> -tc ChannelSweepTest:test_regulatory_compliance
Bug: 46417
Change-Id: Ia066f7b969c3ae7be77d82d157bd1860c1a03fd3
diff --git a/acts/tests/google/fuchsia/wlan/ChannelSweepTest.py b/acts/tests/google/fuchsia/wlan/ChannelSweepTest.py
index 33ad874..e153259 100644
--- a/acts/tests/google/fuchsia/wlan/ChannelSweepTest.py
+++ b/acts/tests/google/fuchsia/wlan/ChannelSweepTest.py
@@ -63,6 +63,7 @@
 IPERF_NO_THROUGHPUT_VALUE = 0
 MAX_2_4_CHANNEL = 14
 TIME_TO_SLEEP_BETWEEN_RETRIES = 1
+TIME_TO_WAIT_FOR_COUNTRY_CODE = 10  # Seconds to poll for a new country code.
 WEP_HEX_STRING_LENGTH = 10
 
 
@@ -72,7 +73,7 @@
 
 
 class ChannelSweepTest(WifiBaseTest):
-    """Tests channel performance.
+    """Tests channel performance and regulatory compliance.
 
     Testbed Requirement:
     * One ACTS compatible device (dut)
@@ -113,6 +114,21 @@
         self.access_point.stop_all_aps()
 
     def setup_test(self):
+        # TODO(fxb/46417): Uncomment when wlanClearCountry is implemented, to
+        # clean up any country code changes (see set_dut_country_code).
+        # for fd in self.fuchsia_devices:
+        #     phy_ids_response = fd.wlan_lib.wlanPhyIdList()
+        #     if phy_ids_response.get('error'):
+        #         raise ConnectionError(
+        #             'Failed to retrieve phy ids from FuchsiaDevice (%s). '
+        #             'Error: %s' % (fd.ip, phy_ids_response['error']))
+        #     for id in phy_ids_response['result']:
+        #         clear_country_response = fd.wlan_lib.wlanClearCountry(id)
+        #         if clear_country_response.get('error'):
+        #             raise EnvironmentError(
+        #                 'Failed to reset country code on FuchsiaDevice (%s). '
+        #                 'Error: %s' % (fd.ip, clear_country_response['error'])
+        #                 )
         self.access_point.stop_all_aps()
         for ad in self.android_devices:
             ad.droid.wakeLockAcquireBright()
@@ -132,6 +148,56 @@
         self.dut.take_bug_report(test_name, begin_time)
         self.dut.get_log(test_name, begin_time)
 
+    def set_dut_country_code(self, country_code):
+        """Set the country code on the DUT, then verify every PHY reports it.
+
+        Args:
+            country_code: string, the 2 character country code to set
+        Raises:
+            ConnectionError: if querying the DUT's PHYs fails.
+            EnvironmentError: if the code cannot be set, no PHYs exist, or not
+                every PHY reports it within TIME_TO_WAIT_FOR_COUNTRY_CODE.
+        """
+        self.log.info('Setting DUT country code to %s' % country_code)
+        country_code_response = self.dut.device.regulatory_region_lib.setRegion(
+            country_code)
+        if country_code_response.get('error'):
+            raise EnvironmentError(
+                'Failed to set country code (%s) on DUT. Error: %s' %
+                (country_code, country_code_response['error']))
+
+        self.log.info('Verifying DUT country code was correctly set to %s.' %
+                      country_code)
+        phy_ids_response = self.dut.device.wlan_lib.wlanPhyIdList()
+        if phy_ids_response.get('error'):
+            raise ConnectionError('Failed to get phy ids from DUT. Error: %s' %
+                                  phy_ids_response['error'])
+        if not phy_ids_response['result']:  # Avoid a vacuous pass.
+            raise EnvironmentError('No WLAN PHYs found on DUT.')
+        end_time = time.time() + TIME_TO_WAIT_FOR_COUNTRY_CODE
+        while time.time() < end_time:
+            for phy_id in phy_ids_response['result']:
+                get_country_response = self.dut.device.wlan_lib.wlanGetCountry(
+                    phy_id)
+                if get_country_response.get('error'):
+                    raise ConnectionError(
+                        'Failed to query PHY ID (%s) for country. Error: %s' %
+                        (phy_id, get_country_response['error']))
+                set_code = ''.join(map(chr, get_country_response['result']))
+                if set_code != country_code:
+                    self.log.debug('PHY (id: %s) has incorrect country code '
+                                   'set. Expected: %s, Got: %s' %
+                                   (phy_id, country_code, set_code))
+                    break
+            else:  # No break: every PHY reports country_code.
+                self.log.info('All PHYs have expected country code (%s)' %
+                              country_code)
+                break
+            time.sleep(TIME_TO_SLEEP_BETWEEN_RETRIES)
+        else:  # Timed out before every PHY reported country_code.
+            raise EnvironmentError('Failed to set DUT country code to %s.' %
+                                   country_code)
+
     def setup_ap(self, channel, channel_bandwidth, security_profile=None):
         """Start network on AP with basic configuration.
 
@@ -380,9 +446,10 @@
         if tx_std_dev > max_std_dev or rx_std_dev > max_std_dev:
             asserts.fail(
                 'With %smhz channel bandwidth, throughput standard '
-                'deviation (tx: %s mb/s, rx: %s mb/s) exceeds max standard deviation'
-                ' (%s mb/s).' % (self.throughput_data['channel_bandwidth'],
-                                 tx_std_dev, rx_std_dev, max_std_dev))
+                'deviation (tx: %s mb/s, rx: %s mb/s) exceeds max standard '
+                'deviation (%s mb/s).' %
+                (self.throughput_data['channel_bandwidth'], tx_std_dev,
+                 rx_std_dev, max_std_dev))
         else:
             asserts.explicit_pass(
                 'Throughput standard deviation (tx: %s mb/s, rx: %s mb/s) '
@@ -412,6 +479,10 @@
                     test (in mb/s).
                 base_test_name (optional): string, test name prefix to use with
                     generated subtests.
+                country_name (optional): string, country name from
+                    hostapd_constants to set on device.
+                country_code (optional): string, two-char country code to set on
+                    the DUT. Takes priority over country_name.
                 test_name (debug tests only): string, the test name for this
                     parent test case from the config file. In explicit tests,
                     this is not necessary.
@@ -435,6 +506,7 @@
                         min_tx_throughput=2,
                         min_rx_throughput=4,
                         max_std_dev=0.75,
+                        country_code='US',
                         base_test_name='test_us'))
         """
         test_channels = settings['test_channels']
@@ -447,6 +519,17 @@
         min_rx_throughput = settings.get('min_rx_throughput',
                                          DEFAULT_MIN_THROUGHPUT)
         max_std_dev = settings.get('max_std_dev', DEFAULT_MAX_STD_DEV)
+        country_code = settings.get('country_code')
+        country_name = settings.get('country_name')
+        country_label = None
+
+        if country_code:
+            country_label = country_code
+            self.set_dut_country_code(country_code)
+        elif country_name:
+            country_label = country_name
+            code = hostapd_constants.COUNTRY_CODE[country_name]['country_code']
+            self.set_dut_country_code(code)
 
         self.throughput_data = {
             'test': test_name,
@@ -455,8 +538,9 @@
         }
         test_list = []
         for channel in test_channels:
-            sub_test_name = '%s_channel_%s_%smhz' % (base_test_name, channel,
-                                                     test_channel_bandwidth)
+            sub_test_name = '%schannel_%s_%smhz_performance' % (
+                '%s_' % country_label if country_label else '', channel,
+                test_channel_bandwidth)
             test_list.append({
                 'test_name': sub_test_name,
                 'channel': int(channel),
@@ -540,12 +624,109 @@
                       'Minimum threshold (tx, rx): (%s mb/s, %s mb/s)' %
                       (tx_throughput, rx_throughput, min_tx_throughput,
                        min_rx_throughput))
-        base_message = 'Actual throughput (on channel: %s, channel bandwidth: %s, security: %s)' % (
-            channel, channel_bandwidth, security)
-        if tx_throughput < min_tx_throughput or rx_throughput < min_rx_throughput:
+        base_message = (  # Parentheses keep both literals in one expression.
+            'Actual throughput (on channel: %s, channel bandwidth: %s, '
+            'security: %s)' % (channel, channel_bandwidth, security))
+        if tx_throughput < min_tx_throughput or rx_throughput < min_rx_throughput:
             asserts.fail('%s below the minimum threshold.' % base_message)
         asserts.explicit_pass('%s above the minimum threshold.' % base_message)
 
+    def verify_regulatory_compliance(self, settings):
+        """Test function for regulatory compliance tests. Verify device complies
+        with provided regulatory requirements.
+
+        Args:
+            settings: dict, containing the following test settings
+                test_channels: dict, mapping channels to a set of the channel
+                    bandwidths to test (see example for using JSON). Defaults
+                    to hostapd_constants.ALL_CHANNELS.
+                allowed_channels: dict, channels to the bandwidths the DUT must
+                    associate on; any other tested pair must be refused.
+                country_code: string, two-char country code to set on device
+                    (prioritized over country_name)
+                country_name: string, country name from hostapd_constants to set
+                    on device.
+                base_test_name (optional): string, test name prefix to use with
+                    generated subtests.
+                test_name: string, the test name for this parent test case from
+                    the config file. In explicit tests, this is not necessary.
+        """
+        country_name = settings.get('country_name')
+        country_code = settings.get('country_code')
+        if not (country_code or country_name):
+            raise ValueError('No country code or name provided.')
+
+        test_channels = settings.get('test_channels',
+                                     hostapd_constants.ALL_CHANNELS)
+        # Channels from JSON configs are string keys; normalize everything to
+        # ints so membership checks work for both JSON and in-code settings.
+        allowed_channels = {
+            int(channel): {int(bandwidth) for bandwidth in bandwidths}
+            for channel, bandwidths in settings['allowed_channels'].items()
+        }
+
+        base_test_name = settings.get('base_test_name', 'test_compliance')
+        if country_code:
+            code = country_code
+        else:
+            code = hostapd_constants.COUNTRY_CODE[country_name]['country_code']
+
+        self.set_dut_country_code(code)
+
+        test_list = []
+        for channel in test_channels:
+            for channel_bandwidth in test_channels[channel]:
+                sub_test_name = '%s_channel_%s_%smhz' % (
+                    base_test_name, channel, channel_bandwidth)
+                allowed_bandwidths = allowed_channels.get(int(channel), set())
+                should_associate = int(channel_bandwidth) in allowed_bandwidths
+                test_list.append({
+                    'country_code': code,
+                    'channel': int(channel),
+                    'channel_bandwidth': int(channel_bandwidth),
+                    'should_associate': should_associate,
+                    'test_name': sub_test_name
+                })
+        self.run_generated_testcases(test_func=self.verify_channel_compliance,
+                                     settings=test_list,
+                                     name_func=get_test_name)
+
+    def verify_channel_compliance(self, settings):
+        """Verify device complies with provided regulatory requirements for a
+        specific channel and channel bandwidth. Run with generated test cases
+        in the verify_regulatory_compliance parent test.
+
+        Args:
+            settings: dict, a test case built by verify_regulatory_compliance.
+        """
+        channel = settings['channel']
+        channel_bandwidth = settings['channel_bandwidth']
+        code = settings['country_code']
+        should_associate = settings['should_associate']
+
+        ssid = self.setup_ap(channel, channel_bandwidth)
+
+        self.log.info(
+            'Attempting to associate with network (%s) on channel %s @ %smhz. '
+            'Expected behavior: %s' %
+            (ssid, channel, channel_bandwidth, 'Device should associate.'
+             if should_associate else 'Device should NOT associate.'))
+
+        associated = wlan_utils.associate(client=self.dut, ssid=ssid)
+        if associated == should_associate:
+            asserts.explicit_pass(
+                'Device complied with %s regulatory requirement for channel %s '
+                'with channel bandwidth %smhz. %s' %
+                (code, channel, channel_bandwidth,
+                 'Associated.' if associated else 'Refused to associate.'))
+        else:
+            asserts.fail(
+                'Device failed compliance with regulatory domain %s for '
+                'channel %s with channel bandwidth %smhz. Expected: %s, Got: %s'
+                % (code, channel, channel_bandwidth, 'Should associate'
+                   if should_associate else 'Should not associate',
+                   'Associated' if associated else 'Did not associate'))
+
     # Helper functions to allow explicit tests throughput and standard deviation
     # thresholds to be passed in via config.
     def _get_min_tx_throughput(self, test_name):
@@ -804,3 +985,36 @@
         self.run_generated_testcases(self.run_channel_performance_tests,
                                      settings=base_tests,
                                      name_func=get_test_name)
+
+    def test_regulatory_compliance(self):
+        """Run regulatory compliance test case from the ACTS config file.
+        Note: only one country_name OR country_code is required.
+
+        Example:
+        "channel_sweep_test_params": {
+            "regulatory_compliance_tests": [
+                {
+                    "test_name": "test_japan_compliance_1_13_36",
+                    "country_name": "JAPAN",
+                    "country_code": "JP",
+                    "test_channels": {
+                        "1": [20, 40], "13": [40], "36": [20, 40, 80]
+                    },
+                    "allowed_channels": {
+                        "1": [20, 40], "36": [20, 40, 80]
+                    },
+                    "base_test_name": "test_japan"
+                },
+                ...
+            ]
+        }
+        """
+        asserts.skip_if(
+            'regulatory_compliance_tests' not in self.user_params.get(
+                'channel_sweep_test_params', {}),
+            'No custom regulatory compliance tests provided in config.')
+        base_tests = self.user_params['channel_sweep_test_params'][
+            'regulatory_compliance_tests']
+        self.run_generated_testcases(self.verify_regulatory_compliance,
+                                     settings=base_tests,
+                                     name_func=get_test_name)