blob: 07da02da206c516b87decbbf4703d7894ad938d6 [file] [log] [blame]
Tor Norbye3a2425a2013-11-04 10:16:08 -08001import traceback, sys
2from unittest import TestResult
3import datetime
4
5from tcmessages import TeamcityServiceMessages
6
# Major version of the running interpreter (first element of sys.version_info);
# read by smart_str() to choose between the Python 2 and Python 3 string types.
PYTHON_VERSION_MAJOR = sys.version_info[0]
8
def strclass(cls):
    """Return a dotted display name for *cls*.

    The result is ``"<module>.<name>"`` built from ``cls.__module__`` and
    ``cls.__name__``; when ``cls.__name__`` is empty (falsy) only the module
    name is returned.
    """
    class_name = cls.__name__
    if class_name:
        return "%s.%s" % (cls.__module__, class_name)
    return cls.__module__
13
def smart_str(s):
    """Return a native ``str`` representation of *s*.

    On Python 3, ``str`` instances are returned unchanged and every other
    object is converted with ``str()``.

    On Python 2, ``unicode`` instances are encoded to UTF-8 byte strings,
    byte strings are returned unchanged, and other objects are converted with
    ``str()``.  If that conversion raises ``UnicodeEncodeError`` the function
    falls back to joining the converted ``args`` of an exception, or to
    encoding ``unicode(s)`` as UTF-8 for any other object.
    """
    encoding = 'utf-8'
    errors = 'strict'

    # Same value as the module-level PYTHON_VERSION_MAJOR; read directly so
    # the function does not touch the Python 2-only names on Python 3.
    if sys.version_info[0] >= 3:
        # ``unicode`` and ``basestring`` do not exist on Python 3; text is
        # already ``str`` and must be returned as is.
        if isinstance(s, str):
            return s
        return str(s)

    if not isinstance(s, basestring):
        try:
            return str(s)
        except UnicodeEncodeError:
            if isinstance(s, Exception):
                # An Exception subclass containing non-ASCII data that doesn't
                # know how to print itself properly. We shouldn't raise a
                # further exception.
                return ' '.join([smart_str(arg) for arg in s.args])
            return unicode(s).encode(encoding, errors)
    if isinstance(s, unicode):
        return s.encode(encoding, errors)
    return s
35
class TeamcityTestResult(TestResult):
    """``unittest.TestResult`` that reports test events through TeamCity messages.

    Every event (suite start/finish, test start/finish, error, failure, skip)
    is forwarded to a ``TeamcityServiceMessages`` instance that writes to the
    stream given at construction time.  For failures the class additionally
    tries to pull two quoted values out of the assertion message and passes
    them on as ``expected`` / ``actual``.
    """

    def __init__(self, stream=sys.stdout, *args, **kwargs):
        """Create the result object and announce the start of the run.

        :param stream: object handed to ``TeamcityServiceMessages`` as output.
        :param args: accepted for signature compatibility and ignored.
        :param kwargs: every item is set as an attribute on the instance.
        """
        TestResult.__init__(self)
        for arg, value in kwargs.items():
            setattr(self, arg, value)
        self.output = stream
        self.messages = TeamcityServiceMessages(self.output, prepend_linebreak=True)
        self.messages.testMatrixEntered()
        # Name of the suite whose "started" message was emitted last, if any.
        self.current_suite = None

    def _leading_quoted(self, val):
        """Return the prefix of *val* up to the first unescaped closing quote.

        The first character of *val* is taken as the quote character.  An
        occurrence of it that directly follows a backslash is treated as
        escaped and skipped.  When no closing quote is found, the text up to
        and including the last escaped quote (or just the first character) is
        returned.  An empty *val* yields ``""``.
        """
        if not val:
            return ""
        quot = val[0]
        start = 1
        pos = val.find(quot, start)
        while pos != -1 and val[pos - 1] == "\\":
            start = pos + 1
            pos = val.find(quot, start)
        if pos == -1:
            return val[:start]
        return val[:pos + 1]

    def find_first(self, val):
        """Return the leading quoted token of *val* (``""`` for empty input)."""
        return self._leading_quoted(val)

    def find_second(self, val):
        """Return the second quoted token of an assertion message.

        If *val* contains ``"!="`` the token is the leading quoted token of
        the stripped text after the first ``"!="``.  Otherwise the last
        character of *val* is taken as the quote character and the suffix
        starting at its nearest unescaped earlier occurrence is returned (only
        the last character when there is none).  An empty *val* yields ``""``.
        """
        if not val:
            return ""

        val_index = val.find("!=")
        if val_index != -1:
            return self._leading_quoted(val[val_index + 2:].strip())

        quot = val[-1]
        pos = val.rfind(quot, 0, len(val) - 1)
        # ``pos > 0`` bounds the scan: position 0 has no preceding character
        # and -1 means "not found", so neither can be an escaped quote.
        while pos > 0 and val[pos - 1] == "\\":
            pos = val.rfind(quot, 0, pos - 1)
        # pos == -1 deliberately yields just the last character (val[-1:]).
        return val[pos:]

    def formatErr(self, err):
        """Format an ``(exctype, value, tb)`` triple as a traceback string."""
        exctype, value, tb = err
        return ''.join(traceback.format_exception(exctype, value, tb))

    def getTestName(self, test):
        """Return the name under which *test* is reported.

        Objects with a ``_testMethodName`` attribute are reported by that
        name, except for ``"runTest"``, where ``str(test)`` is used instead.
        For any other object the name is ``str(test)`` cut at the first space.
        """
        if hasattr(test, '_testMethodName'):
            if test._testMethodName == "runTest":
                return str(test)
            return test._testMethodName

        test_name = str(test)
        # find() rather than index(): a name without spaces is used whole.
        whitespace_index = test_name.find(" ")
        if whitespace_index != -1:
            test_name = test_name[:whitespace_index]
        return test_name

    def getTestId(self, test):
        """Return the ``id`` attribute of *test*."""
        # NOTE(review): this returns the attribute itself (for a TestCase, the
        # bound ``id`` method) without calling it -- confirm what callers expect.
        return test.id

    def addSuccess(self, test):
        """Record a success in the base ``TestResult``; no message is emitted here."""
        TestResult.addSuccess(self, test)

    def addError(self, test, err):
        """Record an error and report it with the formatted exception as details."""
        TestResult.addError(self, test, err)

        err = self._exc_info_to_string(err, test)

        self.messages.testError(self.getTestName(test),
                                message='Error', details=err)

    def find_error_value(self, err):
        """Extract the asserted expression from the last frame of traceback *err*.

        Returns the stripped text following the last ``assert`` in the source
        line of the innermost frame, or ``""`` when the traceback has no
        frames or the source line is not available.
        """
        frames = traceback.extract_tb(err)
        if not frames:
            return ""
        source_line = frames[-1][-1]
        if not source_line:
            return ""
        return source_line.split('assert')[-1].strip()

    def addFailure(self, test, err):
        """Record a failure and report it, with expected/actual when detectable."""
        TestResult.addFailure(self, test, err)

        error_value = smart_str(err[1])
        if not len(error_value):
            # means it's test function and we have to extract value from traceback
            error_value = self.find_error_value(err[2])

        first_token = self.find_first(error_value)
        second_token = self.find_second(error_value)
        quotes = ["'", '"']
        if (first_token and second_token and
                first_token[0] == first_token[-1] and first_token[0] in quotes and
                second_token[0] == second_token[-1] and second_token[0] in quotes):
            # Both sides are quoted strings: turn the escaped newlines back into
            # real ones so the two values can be shown as a multiline diff.
            first = self._unescape(first_token)
            second = self._unescape(second_token)
        else:
            first = second = ""
        err = self._exc_info_to_string(err, test)

        self.messages.testFailed(self.getTestName(test),
                                 message='Failure', details=err, expected=first, actual=second)

    def addSkip(self, test, reason):
        """Report *test* as ignored, passing *reason* as the message."""
        self.messages.testIgnored(self.getTestName(test), message=reason)

    def __getSuite(self, test):
        """Return ``(suite, location, suite_location)`` for *test*.

        When *test* has a ``suite`` attribute, the values come from
        ``test.suite`` (``location``, ``abs_location``) plus a line number
        taken from ``test.lineno`` or ``test.test.lineno``.  Otherwise they
        are built from the source file of the test's class as
        ``python_uttestid://`` URLs.
        """
        if hasattr(test, "suite"):
            suite = strclass(test.suite)
            suite_location = test.suite.location
            location = test.suite.abs_location
            if hasattr(test, "lineno"):
                location = location + ":" + str(test.lineno)
            else:
                location = location + ":" + str(test.test.lineno)
        else:
            import inspect

            try:
                source_file = inspect.getsourcefile(test.__class__)
                if source_file:
                    # Keep the directory part only (everything before the last "/").
                    source_dir_splitted = source_file.split("/")[:-1]
                    source_dir = "/".join(source_dir_splitted) + "/"
                else:
                    source_dir = ""
            except TypeError:
                source_dir = ""

            suite = strclass(test.__class__)
            suite_location = "python_uttestid://" + source_dir + suite
            location = "python_uttestid://" + source_dir + str(test.id())

        return (suite, location, suite_location)

    def startTest(self, test):
        """Emit suite-change messages if needed, then the test-started message.

        The start time is stored on *test* as ``startTime`` for ``stopTest``.
        """
        suite, location, suite_location = self.__getSuite(test)
        if suite != self.current_suite:
            if self.current_suite:
                self.messages.testSuiteFinished(self.current_suite)
            self.current_suite = suite
            self.messages.testSuiteStarted(self.current_suite, location=suite_location)
        setattr(test, "startTime", datetime.datetime.now())
        self.messages.testStarted(self.getTestName(test), location=location)

    def stopTest(self, test):
        """Emit the test-finished message with the elapsed time in milliseconds."""
        # Falls back to "now" (near-zero duration) if startTest never ran for this test.
        start = getattr(test, "startTime", datetime.datetime.now())
        d = datetime.datetime.now() - start
        duration = d.microseconds / 1000 + d.seconds * 1000 + d.days * 86400000
        self.messages.testFinished(self.getTestName(test), duration=int(duration))

    def endLastSuite(self):
        """Close the currently open suite, if any, and forget it."""
        if self.current_suite:
            self.messages.testSuiteFinished(self.current_suite)
            self.current_suite = None

    def _unescape(self, text):
        """Replace each literal backslash-n sequence in *text* with a newline."""
        # do not use text.decode('string_escape'), it leads to problems with different string encodings given
        return text.replace("\\n", "\n")
190
class TeamcityTestRunner(object):
    """Minimal test runner that collects results in a ``TeamcityTestResult``."""

    def __init__(self, stream=sys.stdout):
        """Remember *stream*; it is handed to every result object created."""
        self.stream = stream

    def _makeResult(self, **kwargs):
        """Build a fresh result object bound to this runner's stream."""
        return TeamcityTestResult(self.stream, **kwargs)

    def run(self, test, **kwargs):
        """Run *test* and return the result object.

        The number of test cases is reported first, then *test* is called
        with the result object, and finally the last open suite is closed.
        Extra keyword arguments are forwarded to ``_makeResult``.
        """
        outcome = self._makeResult(**kwargs)
        outcome.messages.testCount(test.countTestCases())
        test(outcome)
        outcome.endLastSuite()
        return outcome