blob: f52cd3b02906e5fe7fd1c0064367440ac66ccd3c [file] [log] [blame]
jamesrencd7a81a2010-04-21 20:39:08 +00001import operator, unittest
Alex Miller4f341702013-03-25 12:39:12 -07002import json
jamesrencd7a81a2010-04-21 20:39:08 +00003from django.test import client
4from autotest_lib.frontend.afe import frontend_test_utils, models as afe_models
5
class ResourceTestCase(unittest.TestCase,
                       frontend_test_utils.FrontendTestMixin):
    """Base class for tests that issue HTTP requests against the frontend.

    setUp() runs the common frontend fixture setup from FrontendTestMixin,
    creates a 'debug_user' in the 'my_acl' ACL group and creates a Django
    test client in self.client. The request helpers check response status
    codes and decode JSON bodies; the check_* helpers compare collection
    resources against expected values.
    """
    URI_PREFIX = None # subclasses may override this to use partial URIs


    def setUp(self):
        super(ResourceTestCase, self).setUp()
        self._frontend_common_setup()
        self._setup_debug_user()
        self.client = client.Client()


    def tearDown(self):
        # Undo setUp() in reverse order: the frontend fixture first, then
        # the base class.
        self._frontend_common_teardown()
        super(ResourceTestCase, self).tearDown()


    def _setup_debug_user(self):
        """Create a 'debug_user' User and add it to the 'my_acl' AclGroup."""
        # NOTE(review): assumes 'my_acl' was created by
        # _frontend_common_setup(); objects.get() raises if it is missing.
        user = afe_models.User.objects.create(login='debug_user')
        acl = afe_models.AclGroup.objects.get(name='my_acl')
        user.aclgroup_set.add(acl)


    def _expected_status(self, method):
        """Return the status code a successful request should produce.

        @param method: HTTP method name, case-insensitive.
        @returns 201 for POST, 204 for DELETE, 200 for anything else.
        """
        method = method.lower()
        if method == 'post':
            return 201
        if method == 'delete':
            return 204
        return 200


    def raw_request(self, method, uri, **kwargs):
        """Issue a request through the Django test client, unchecked.

        @param method: HTTP method name, case-insensitive.
        @param uri: the full URI to request.
        @param kwargs: passed straight to the test client method.
        @returns the test client's response object.
        """
        method = method.lower()
        if method == 'put':
            # the put() implementation in Django's test client is poorly
            # implemented and only supports url-encoded keyvals for the data.
            # the post() implementation is correct, though, so use that, with a
            # trick to override the method.
            method = 'post'
            kwargs['REQUEST_METHOD'] = 'PUT'

        client_method = getattr(self.client, method)
        return client_method(uri, **kwargs)


    def request(self, method, uri, encode_body=True, **kwargs):
        """Issue a request and assert that it succeeded.

        @param method: HTTP method name, case-insensitive.
        @param uri: an absolute 'http://' or 'https://' URI, or a path that
                is appended to self.URI_PREFIX with a '/' separator.
        @param encode_body: if true (the default) and a 'data' kwarg is sent
                with content type 'application/json', serialize the data
                with json.dumps(). Pass False when 'data' is already an
                encoded string.
        @param kwargs: passed to raw_request(). When 'data' is given,
                'content_type' defaults to 'application/json'.
        @returns the decoded body if the response media type is
                application/json, otherwise the raw response content.

        Fails the test if the status code is not the one returned by
        _expected_status(), if a relative URI is used without URI_PREFIX, or
        if a JSON response body cannot be parsed.
        """
        # Normalise once so the expected status agrees with the method that
        # raw_request() actually sends (it lower-cases too).
        method = method.lower()
        expected_status = self._expected_status(method)

        if 'data' in kwargs:
            kwargs.setdefault('content_type', 'application/json')
            if encode_body and kwargs['content_type'] == 'application/json':
                kwargs['data'] = json.dumps(kwargs['data'])

        if uri.startswith(('http://', 'https://')):
            full_uri = uri
        else:
            if not self.URI_PREFIX:
                self.fail('URI_PREFIX must be set to request relative URI %r'
                          % uri)
            full_uri = self.URI_PREFIX + '/' + uri

        response = self.raw_request(method, full_uri, **kwargs)
        self.assertEqual(
                response.status_code, expected_status,
                'Requesting %s\nExpected %s, got %s: %s (headers: %s)'
                % (full_uri, expected_status, response.status_code,
                   response.content, response._headers))

        # Compare only the media type, so parameters such as
        # "; charset=utf-8" do not cause a JSON body to be returned raw.
        media_type = response['content-type'].split(';')[0].strip().lower()
        if media_type != 'application/json':
            return response.content

        try:
            return json.loads(response.content)
        except ValueError:
            self.fail('Invalid response body: %s' % response.content)


    def sorted_by(self, collection, attribute):
        """Return the items of collection as a list sorted by item[attribute]."""
        return sorted(collection, key=operator.itemgetter(attribute))


    def _read_attribute(self, item, attribute_or_list):
        """Read a possibly nested value from a dict-like item.

        @param item: the dict-like object to read from.
        @param attribute_or_list: a single key, or a list of keys applied in
                order, i.e. item[key1][key2]...
        @returns the value found.
        """
        if isinstance(attribute_or_list, basestring):
            attribute_or_list = [attribute_or_list]
        for attribute in attribute_or_list:
            item = item[attribute]
        return item


    def check_collection(self, collection, attribute_or_list, expected_list,
                         length=None, check_number=None):
        """Check the members of a collection resource.

        @param collection: a collection resource (dict) whose 'members' entry
                is an iterable of dicts.
        @param attribute_or_list: an attribute or list of attributes to read
                from each member. The results are sorted and compared with
                expected_list. If a list of attributes is given, they are
                read hierarchically, i.e. item[attribute1][attribute2]...
        @param expected_list: list of expected values, in sorted order.
        @param length: expected number of members. It defaults to
                len(expected_list) and is only needed when check_number is
                given.
        @param check_number: if given, only compare the first check_number
                sorted values with expected_list.
        """
        actual_list = sorted(self._read_attribute(item, attribute_or_list)
                             for item in collection['members'])
        if length is None and check_number is None:
            length = len(expected_list)
        if length is not None:
            self.assertEqual(len(actual_list), length,
                             'Expected %s, got %s: %s'
                             % (length, len(actual_list),
                                ', '.join(str(item) for item in actual_list)))
        if check_number is not None:
            actual_list = actual_list[:check_number]
        self.assertEqual(actual_list, expected_list)


    def check_relationship(self, resource_uri, relationship_name,
                           other_entry_name, field, expected_values,
                           length=None, check_number=None):
        """Check the members of a relationship collection.

        Fetches the base resource, follows the 'href' of its
        relationship_name attribute, and checks
        member[other_entry_name][field] for every member of that collection.

        @param resource_uri: URI of the base resource.
        @param relationship_name: name of the relationship attribute on the
                base resource.
        @param other_entry_name: name of the other entry in the relationship.
        @param field: name of the field to read from the other entry.
        @param expected_values: sorted list of expected values for the given
                field.
        @param length: expected number of members (see check_collection()).
        @param check_number: if given, only check this many sorted values
                (see check_collection()).
        """
        response = self.request('get', resource_uri)
        relationship_uri = response[relationship_name]['href']
        relationships = self.request('get', relationship_uri)
        self.check_collection(relationships, [other_entry_name, field],
                              expected_values, length=length,
                              check_number=check_number)