Snapshot idea/138.1503 from git://git.jetbrains.org/idea/community.git

Change-Id: Ie01af1d8710ec0ff51d90301bda1a18b0b5c0faf
diff --git a/python/helpers/pycharm/tcunittest.py b/python/helpers/pycharm/tcunittest.py
index b6950c9..99b3059 100644
--- a/python/helpers/pycharm/tcunittest.py
+++ b/python/helpers/pycharm/tcunittest.py
@@ -6,14 +6,16 @@
 
 PYTHON_VERSION_MAJOR = sys.version_info[0]
 
+
 def strclass(cls):
   if not cls.__name__:
     return cls.__module__
   return "%s.%s" % (cls.__module__, cls.__name__)
 
+
 def smart_str(s):
-  encoding='utf-8'
-  errors='strict'
+  encoding = 'utf-8'
+  errors = 'strict'
   if PYTHON_VERSION_MAJOR < 3:
     is_string = isinstance(s, basestring)
   else:
@@ -33,6 +35,7 @@
   else:
     return s
 
+
 class TeamcityTestResult(TestResult):
   def __init__(self, stream=sys.stdout, *args, **kwargs):
     TestResult.__init__(self)
@@ -41,42 +44,47 @@
     self.output = stream
     self.messages = TeamcityServiceMessages(self.output, prepend_linebreak=True)
     self.messages.testMatrixEntered()
+    self.current_failed = False
     self.current_suite = None
+    self.subtest_suite = None
 
   def find_first(self, val):
     quot = val[0]
     count = 1
     quote_ind = val[count:].find(quot)
-    while quote_ind != -1 and val[count+quote_ind-1] == "\\":
+    while quote_ind != -1 and val[count + quote_ind - 1] == "\\":
       count = count + quote_ind + 1
       quote_ind = val[count:].find(quot)
 
-    return val[0:quote_ind+count+1]
+    return val[0:quote_ind + count + 1]
 
   def find_second(self, val):
     val_index = val.find("!=")
     if val_index != -1:
       count = 1
-      val = val[val_index+2:].strip()
+      val = val[val_index + 2:].strip()
       quot = val[0]
       quote_ind = val[count:].find(quot)
-      while quote_ind != -1 and val[count+quote_ind-1] == "\\":
+      while quote_ind != -1 and val[count + quote_ind - 1] == "\\":
         count = count + quote_ind + 1
         quote_ind = val[count:].find(quot)
-      return val[0:quote_ind+count+1]
+      return val[0:quote_ind + count + 1]
 
     else:
       quot = val[-1]
-      quote_ind = val[:len(val)-1].rfind(quot)
-      while quote_ind != -1 and val[quote_ind-1] == "\\":
-        quote_ind = val[:quote_ind-1].rfind(quot)
+      quote_ind = val[:len(val) - 1].rfind(quot)
+      while quote_ind != -1 and val[quote_ind - 1] == "\\":
+        quote_ind = val[:quote_ind - 1].rfind(quot)
       return val[quote_ind:]
 
   def formatErr(self, err):
     exctype, value, tb = err
     return ''.join(traceback.format_exception(exctype, value, tb))
 
-  def getTestName(self, test):
+  def getTestName(self, test, is_subtest=False):
+    if is_subtest:
+      test_name = self.getTestName(test.test_case)
+      return "{} {}".format(test_name, test._subDescription())
     if hasattr(test, '_testMethodName'):
       if test._testMethodName == "runTest":
         return str(test)
@@ -95,10 +103,13 @@
     TestResult.addSuccess(self, test)
 
   def addError(self, test, err):
+    self.init_suite(test)
+    self.current_failed = True
     TestResult.addError(self, test, err)
 
     err = self._exc_info_to_string(err, test)
 
+    self.messages.testStarted(self.getTestName(test))
     self.messages.testError(self.getTestName(test),
                             message='Error', details=err)
 
@@ -108,6 +119,8 @@
     return error_value.split('assert')[-1].strip()
 
   def addFailure(self, test, err):
+    self.init_suite(test)
+    self.current_failed = True
     TestResult.addFailure(self, test, err)
 
     error_value = smart_str(err[1])
@@ -119,7 +132,7 @@
     self_find_second = self.find_second(error_value)
     quotes = ["'", '"']
     if (self_find_first[0] == self_find_first[-1] and self_find_first[0] in quotes and
-        self_find_second[0] == self_find_second[-1] and self_find_second[0] in quotes):
+            self_find_second[0] == self_find_second[-1] and self_find_second[0] in quotes):
       # let's unescape strings to show sexy multiline diff in PyCharm.
       # By default all caret return chars are escaped by testing framework
       first = self._unescape(self_find_first)
@@ -128,10 +141,13 @@
       first = second = ""
     err = self._exc_info_to_string(err, test)
 
+    self.messages.testStarted(self.getTestName(test))
     self.messages.testFailed(self.getTestName(test),
                              message='Failure', details=err, expected=first, actual=second)
 
   def addSkip(self, test, reason):
+    self.init_suite(test)
+    self.current_failed = True
     self.messages.testIgnored(self.getTestName(test), message=reason)
 
   def __getSuite(self, test):
@@ -149,10 +165,10 @@
       try:
         source_file = inspect.getsourcefile(test.__class__)
         if source_file:
-            source_dir_splitted = source_file.split("/")[:-1]
-            source_dir = "/".join(source_dir_splitted) + "/"
+          source_dir_splitted = source_file.split("/")[:-1]
+          source_dir = "/".join(source_dir_splitted) + "/"
         else:
-            source_dir = ""
+          source_dir = ""
       except TypeError:
         source_dir = ""
 
@@ -163,20 +179,52 @@
     return (suite, location, suite_location)
 
   def startTest(self, test):
+    self.current_failed = False
+    setattr(test, "startTime", datetime.datetime.now())
+
+  def init_suite(self, test):
     suite, location, suite_location = self.__getSuite(test)
     if suite != self.current_suite:
       if self.current_suite:
         self.messages.testSuiteFinished(self.current_suite)
       self.current_suite = suite
       self.messages.testSuiteStarted(self.current_suite, location=suite_location)
-    setattr(test, "startTime", datetime.datetime.now())
-    self.messages.testStarted(self.getTestName(test), location=location)
+    return location
 
   def stopTest(self, test):
     start = getattr(test, "startTime", datetime.datetime.now())
     d = datetime.datetime.now() - start
-    duration=d.microseconds / 1000 + d.seconds * 1000 + d.days * 86400000
-    self.messages.testFinished(self.getTestName(test), duration=int(duration))
+    duration = d.microseconds / 1000 + d.seconds * 1000 + d.days * 86400000
+    if not self.subtest_suite:
+      if not self.current_failed:
+        location = self.init_suite(test)
+        self.messages.testStarted(self.getTestName(test), location=location)
+        self.messages.testFinished(self.getTestName(test), duration=int(duration))
+    else:
+      self.messages.testSuiteFinished(self.subtest_suite)
+      self.subtest_suite = None
+
+
+  def addSubTest(self, test, subtest, err):
+    suite_name = self.getTestName(test)  # + " (subTests)"
+    if not self.subtest_suite:
+      self.subtest_suite = suite_name
+      self.messages.testSuiteStarted(self.subtest_suite)
+    else:
+      if suite_name != self.subtest_suite:
+        self.messages.testSuiteFinished(self.subtest_suite)
+        self.subtest_suite = suite_name
+        self.messages.testSuiteStarted(self.subtest_suite)
+
+    name = self.getTestName(subtest, True)
+    if err is not None:
+      error = self._exc_info_to_string(err, test)
+      self.messages.testStarted(name)
+      self.messages.testFailed(name, message='Failure', details=error)
+    else:
+      self.messages.testStarted(name)
+      self.messages.testFinished(name)
+
 
   def endLastSuite(self):
     if self.current_suite:
@@ -187,6 +235,7 @@
     # do not use text.decode('string_escape'), it leads to problems with different string encodings given
     return text.replace("\\n", "\n")
 
+
 class TeamcityTestRunner(object):
   def __init__(self, stream=sys.stdout):
     self.stream = stream