blob: 6bc7baa565a391c5f285125d5a784e292b780db0 [file] [log] [blame]
jschung25850f02020-06-17 14:38:11 +09001# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Common functions for the stress tester."""
16
17import logging
18import os
19
20from absl import flags
21import stress_test_pb2
22from google.protobuf import text_format
23
# Module-level alias for absl's global flag registry.
FLAGS = flags.FLAGS

# Registers --resource_path; GetResourceContents checks this directory (by
# basename) before falling back to other locations.
flags.DEFINE_string("resource_path", None,
                    "Optional override path where to grab resources from. By "
                    "default, resources are grabbed from "
                    "stress_test_common.RESOURCE_DIR, specifying this flag "
                    "will instead result in first looking in this path before "
                    "the module defined resource directory.")

# Default resource directory. It is a relative path, so it resolves against the
# current working directory when GetResourceContents uses it as its final
# fallback location.
RESOURCE_DIR = "resources/"
34
35
def MakeDirsIfNeeded(path):
  """Creates the directory `path`, including any missing parent directories.

  Does nothing if `path` is already a directory. Unlike a check-then-create
  sequence, this does not fail if another process creates the directory
  between the check and the creation.

  Args:
    path: Directory path to create.

  Raises:
    OSError: If the directory cannot be created, including when `path` exists
      but is not a directory.
  """
  try:
    os.makedirs(path)
  except OSError:
    # The directory may already exist, possibly because another process just
    # created it. That counts as success; any other failure is a real error.
    if not os.path.isdir(path):
      raise
40
41
def _read_binary_file(path):
  """Returns the full contents of `path`, read in binary mode.

  The file is always closed, even if reading fails.
  """
  with open(path, "rb") as resource_file:
    return resource_file.read()


def GetResourceContents(resource_name):
  """Gets the raw contents of the named resource.

  Locations are tried in this order:
    1. <--resource_path>/<basename of resource_name>, if the flag is set and
       that file exists.
    2. resource_name itself, if that path exists.
    3. RESOURCE_DIR/<basename of resource_name>, without an existence check.

  Args:
    resource_name: Name or path of the resource file.

  Returns:
    The file contents read in binary mode (bytes on Python 3).

  Raises:
    IOError: If the final RESOURCE_DIR fallback cannot be opened, for example
      because the resource does not exist in any of the locations.
  """
  basename = os.path.basename(resource_name)
  candidates = []
  if FLAGS.resource_path:
    # Under the override directory only the basename is looked up.
    candidates.append(os.path.join(FLAGS.resource_path, basename))
  candidates.append(resource_name)
  for path in candidates:
    if os.path.exists(path):
      return _read_binary_file(path)
  # The last fallback is opened without an existence check, so a missing
  # resource surfaces to the caller as an IOError.
  return _read_binary_file(os.path.join(RESOURCE_DIR, basename))
56
57
def LoadDeviceConfig(device_type, serial_number):
  """Assembles a DeviceConfig proto from the common, device and serial configs.

  Resources are merged into one DeviceConfig with MergeFrom, in this order:
    1. device_config.common.ascii_proto. This file must exist.
    2. device_config.<device_type>.ascii_proto and everything it includes. A
       failure to read the top-level file is logged.
    3. device_config.<serial_number>.ascii_proto and everything it includes. A
       missing top-level file is silently ignored, but failures to read its
       includes are logged.
  Each file's includes are merged before the file itself, so the including
  file takes precedence. Include cycles are logged and broken rather than
  recursing forever.

  Finally, the repeated fields file_to_watch, file_to_move, event and
  daemon_process are deduplicated by key. Only the last entry for each key is
  kept, which is what lets later files override earlier ones.

  Args:
    device_type: Resource prefix of the device-type config.
    serial_number: Resource prefix of the per-device config.

  Returns:
    The merged stress_test_pb2.DeviceConfig.

  Raises:
    IOError: If the common config resource cannot be read.
  """
  config = stress_test_pb2.DeviceConfig()
  text_format.Merge(GetResourceContents(
      os.path.join(RESOURCE_DIR, "device_config.common.ascii_proto")), config)

  def RecursiveIncludeToConfig(resource_prefix, print_error, include_chain=()):
    """Merges device_config.<resource_prefix>.ascii_proto and its includes.

    Args:
      resource_prefix: Name used to build the resource file name.
      print_error: Whether to log a failure to read this particular file.
        Failures on files it includes are always logged.
      include_chain: Prefixes currently being loaded above this one. Used to
        detect include cycles.
    """
    if resource_prefix in include_chain:
      logging.error("Ignoring cyclic include of %s (include chain: %s)",
                    resource_prefix, include_chain)
      return
    new_config = stress_test_pb2.DeviceConfig()
    try:
      text_format.Merge(GetResourceContents(
          os.path.join(RESOURCE_DIR,
                       "device_config.%s.ascii_proto" % resource_prefix)),
                        new_config)
    except IOError as err:
      if print_error:
        logging.error("%s", err)
      return
    chain = include_chain + (resource_prefix,)
    for include_name in new_config.include:
      # This level loaded fine, so failures in the files it includes are
      # reported even when failures at this level are not.
      RecursiveIncludeToConfig(include_name, print_error=True,
                               include_chain=chain)
    # Merged after the includes so this file's values take precedence.
    config.MergeFrom(new_config)

  RecursiveIncludeToConfig(device_type, print_error=True)
  RecursiveIncludeToConfig(serial_number, print_error=False)

  def TakeOnlyLatestFromRepeatedField(message, field, key):
    """Keeps only the last element of `field` for each distinct `key` value.

    Surviving elements keep their original relative order. This is a single
    linear pass, which assumes the key values are hashable scalars. The call
    sites below pass `source`/`name`, presumably strings (TODO: confirm in
    stress_test.proto).
    """
    elements = list(getattr(message, field))
    seen_keys = set()
    latest_first = []
    # Walk backwards so the first element seen for a key is its last occurrence.
    for element in reversed(elements):
      element_key = getattr(element, key)
      if element_key not in seen_keys:
        seen_keys.add(element_key)
        latest_first.append(element)
    message.ClearField(field)
    getattr(message, field).extend(reversed(latest_first))

  # Keep only the latest entry for each key, so that later configs can
  # override entries from earlier ones.
  TakeOnlyLatestFromRepeatedField(config, "file_to_watch", "source")
  TakeOnlyLatestFromRepeatedField(config, "file_to_move", "source")
  TakeOnlyLatestFromRepeatedField(config, "event", "name")
  TakeOnlyLatestFromRepeatedField(config, "daemon_process", "name")

  return config
104 return config