bpo-31690: Allow the inline flags "a", "L", and "u" to be used as group flags for RE. (#3885)

diff --git a/Lib/sre_compile.py b/Lib/sre_compile.py
index 144620c..e5216b7 100644
--- a/Lib/sre_compile.py
+++ b/Lib/sre_compile.py
@@ -62,6 +62,12 @@
 _ignorecase_fixes = {i: tuple(j for j in t if i != j)
                      for t in _equivalences for i in t}
 
+def _combine_flags(flags, add_flags, del_flags,  # effective flags for a SUBPATTERN
+                   TYPE_FLAGS=sre_parse.TYPE_FLAGS):  # bound once, at def time
+    if add_flags & TYPE_FLAGS:  # group's a/L/u replaces the inherited one
+        flags &= ~TYPE_FLAGS
+    return (flags | add_flags) & ~del_flags  # removals are applied last
+
 def _compile(code, pattern, flags):
     # internal: compile a (sub)pattern
     emit = code.append
@@ -87,15 +93,21 @@
                 emit(op)
                 emit(av)
             elif flags & SRE_FLAG_LOCALE:
-                emit(OP_LOC_IGNORE[op])
+                emit(OP_LOCALE_IGNORE[op])
                 emit(av)
             elif not iscased(av):
                 emit(op)
                 emit(av)
             else:
                 lo = tolower(av)
-                if fixes and lo in fixes:
-                    emit(IN_IGNORE)
+                if not fixes:  # ascii
+                    emit(OP_IGNORE[op])
+                    emit(lo)
+                elif lo not in fixes:
+                    emit(OP_UNICODE_IGNORE[op])
+                    emit(lo)
+                else:
+                    emit(IN_UNI_IGNORE)
                     skip = _len(code); emit(0)
                     if op is NOT_LITERAL:
                         emit(NEGATE)
@@ -104,17 +116,16 @@
                         emit(k)
                     emit(FAILURE)
                     code[skip] = _len(code) - skip
-                else:
-                    emit(OP_IGNORE[op])
-                    emit(lo)
         elif op is IN:
             charset, hascased = _optimize_charset(av, iscased, tolower, fixes)
             if flags & SRE_FLAG_IGNORECASE and flags & SRE_FLAG_LOCALE:
                 emit(IN_LOC_IGNORE)
-            elif hascased:
+            elif not hascased:
+                emit(IN)
+            elif not fixes:  # ascii
                 emit(IN_IGNORE)
             else:
-                emit(IN)
+                emit(IN_UNI_IGNORE)
             skip = _len(code); emit(0)
             _compile_charset(charset, flags, code)
             code[skip] = _len(code) - skip
@@ -153,8 +164,8 @@
             if group:
                 emit(MARK)
                 emit((group-1)*2)
-            # _compile_info(code, p, (flags | add_flags) & ~del_flags)
-            _compile(code, p, (flags | add_flags) & ~del_flags)
+            # _compile_info(code, p, _combine_flags(flags, add_flags, del_flags))
+            _compile(code, p, _combine_flags(flags, add_flags, del_flags))
             if group:
                 emit(MARK)
                 emit((group-1)*2+1)
@@ -210,10 +221,14 @@
                 av = CH_UNICODE[av]
             emit(av)
         elif op is GROUPREF:
-            if flags & SRE_FLAG_IGNORECASE:
-                emit(OP_IGNORE[op])
-            else:
+            if not flags & SRE_FLAG_IGNORECASE:
                 emit(op)
+            elif flags & SRE_FLAG_LOCALE:
+                emit(GROUPREF_LOC_IGNORE)
+            elif not fixes:  # ascii
+                emit(GROUPREF_IGNORE)
+            else:
+                emit(GROUPREF_UNI_IGNORE)
             emit(av-1)
         elif op is GROUPREF_EXISTS:
             emit(op)
@@ -240,7 +255,7 @@
             pass
         elif op is LITERAL:
             emit(av)
-        elif op is RANGE or op is RANGE_IGNORE:
+        elif op is RANGE or op is RANGE_UNI_IGNORE:
             emit(av[0])
             emit(av[1])
         elif op is CHARSET:
@@ -309,9 +324,9 @@
                     hascased = True
                     # There are only two ranges of cased non-BMP characters:
                     # 10400-1044F (Deseret) and 118A0-118DF (Warang Citi),
-                    # and for both ranges RANGE_IGNORE works.
+                    # and for both ranges RANGE_UNI_IGNORE works.
                     if op is RANGE:
-                        op = RANGE_IGNORE
+                        op = RANGE_UNI_IGNORE
                 tail.append((op, av))
             break
 
@@ -456,7 +471,7 @@
             prefixappend(av)
         elif op is SUBPATTERN:
             group, add_flags, del_flags, p = av
-            flags1 = (flags | add_flags) & ~del_flags
+            flags1 = _combine_flags(flags, add_flags, del_flags)
             if flags1 & SRE_FLAG_IGNORECASE and flags1 & SRE_FLAG_LOCALE:
                 break
             prefix1, prefix_skip1, got_all = _get_literal_prefix(p, flags1)
@@ -482,7 +497,7 @@
         if op is not SUBPATTERN:
             break
         group, add_flags, del_flags, pattern = av
-        flags = (flags | add_flags) & ~del_flags
+        flags = _combine_flags(flags, add_flags, del_flags)
         if flags & SRE_FLAG_IGNORECASE and flags & SRE_FLAG_LOCALE:
             return None
 
@@ -631,6 +646,7 @@
                 print_(op)
             elif op in (LITERAL, NOT_LITERAL,
                         LITERAL_IGNORE, NOT_LITERAL_IGNORE,
+                        LITERAL_UNI_IGNORE, NOT_LITERAL_UNI_IGNORE,
                         LITERAL_LOC_IGNORE, NOT_LITERAL_LOC_IGNORE):
                 arg = code[i]
                 i += 1
@@ -647,12 +663,12 @@
                 arg = str(CHCODES[arg])
                 assert arg[:9] == 'CATEGORY_'
                 print_(op, arg[9:])
-            elif op in (IN, IN_IGNORE, IN_LOC_IGNORE):
+            elif op in (IN, IN_IGNORE, IN_UNI_IGNORE, IN_LOC_IGNORE):
                 skip = code[i]
                 print_(op, skip, to=i+skip)
                 dis_(i+1, i+skip)
                 i += skip
-            elif op in (RANGE, RANGE_IGNORE):
+            elif op in (RANGE, RANGE_UNI_IGNORE):
                 lo, hi = code[i: i+2]
                 i += 2
                 print_(op, '%#02x %#02x (%r-%r)' % (lo, hi, chr(lo), chr(hi)))
@@ -671,7 +687,8 @@
                     print_2(_hex_code(code[i: i + 256//_CODEBITS]))
                     i += 256//_CODEBITS
                 level -= 1
-            elif op in (MARK, GROUPREF, GROUPREF_IGNORE):
+            elif op in (MARK, GROUPREF, GROUPREF_IGNORE, GROUPREF_UNI_IGNORE,
+                        GROUPREF_LOC_IGNORE):
                 arg = code[i]
                 i += 1
                 print_(op, arg)
diff --git a/Lib/sre_constants.py b/Lib/sre_constants.py
index 1daa7bd..13deb00 100644
--- a/Lib/sre_constants.py
+++ b/Lib/sre_constants.py
@@ -13,7 +13,7 @@
 
 # update when constants are added or removed
 
-MAGIC = 20170530
+MAGIC = 20171005
 
 from _sre import MAXREPEAT, MAXGROUPS
 
@@ -84,25 +84,37 @@
     CALL
     CATEGORY
     CHARSET BIGCHARSET
-    GROUPREF GROUPREF_EXISTS GROUPREF_IGNORE
-    IN IN_IGNORE
+    GROUPREF GROUPREF_EXISTS
+    IN
     INFO
     JUMP
-    LITERAL LITERAL_IGNORE
+    LITERAL
     MARK
     MAX_UNTIL
     MIN_UNTIL
-    NOT_LITERAL NOT_LITERAL_IGNORE
+    NOT_LITERAL
     NEGATE
     RANGE
     REPEAT
     REPEAT_ONE
     SUBPATTERN
     MIN_REPEAT_ONE
-    RANGE_IGNORE
+
+    GROUPREF_IGNORE
+    IN_IGNORE
+    LITERAL_IGNORE
+    NOT_LITERAL_IGNORE
+
+    GROUPREF_LOC_IGNORE
+    IN_LOC_IGNORE
     LITERAL_LOC_IGNORE
     NOT_LITERAL_LOC_IGNORE
-    IN_LOC_IGNORE
+
+    GROUPREF_UNI_IGNORE
+    IN_UNI_IGNORE
+    LITERAL_UNI_IGNORE
+    NOT_LITERAL_UNI_IGNORE
+    RANGE_UNI_IGNORE
 
     MIN_REPEAT MAX_REPEAT
 """)
@@ -113,7 +125,9 @@
     AT_BEGINNING AT_BEGINNING_LINE AT_BEGINNING_STRING
     AT_BOUNDARY AT_NON_BOUNDARY
     AT_END AT_END_LINE AT_END_STRING
+
     AT_LOC_BOUNDARY AT_LOC_NON_BOUNDARY
+
     AT_UNI_BOUNDARY AT_UNI_NON_BOUNDARY
 """)
 
@@ -123,7 +137,9 @@
     CATEGORY_SPACE CATEGORY_NOT_SPACE
     CATEGORY_WORD CATEGORY_NOT_WORD
     CATEGORY_LINEBREAK CATEGORY_NOT_LINEBREAK
+
     CATEGORY_LOC_WORD CATEGORY_LOC_NOT_WORD
+
     CATEGORY_UNI_DIGIT CATEGORY_UNI_NOT_DIGIT
     CATEGORY_UNI_SPACE CATEGORY_UNI_NOT_SPACE
     CATEGORY_UNI_WORD CATEGORY_UNI_NOT_WORD
@@ -133,18 +149,20 @@
 
 # replacement operations for "ignore case" mode
 OP_IGNORE = {
-    GROUPREF: GROUPREF_IGNORE,
-    IN: IN_IGNORE,
     LITERAL: LITERAL_IGNORE,
     NOT_LITERAL: NOT_LITERAL_IGNORE,
-    RANGE: RANGE_IGNORE,
 }
 
-OP_LOC_IGNORE = {
+OP_LOCALE_IGNORE = {  # variants emitted when SRE_FLAG_LOCALE is set
     LITERAL: LITERAL_LOC_IGNORE,
     NOT_LITERAL: NOT_LITERAL_LOC_IGNORE,
 }
 
+OP_UNICODE_IGNORE = {  # variants for the non-ASCII, non-LOCALE case
+    LITERAL: LITERAL_UNI_IGNORE,
+    NOT_LITERAL: NOT_LITERAL_UNI_IGNORE,
+}
+
 AT_MULTILINE = {
     AT_BEGINNING: AT_BEGINNING_LINE,
     AT_END: AT_END_LINE
diff --git a/Lib/sre_parse.py b/Lib/sre_parse.py
index 5452520..8527412 100644
--- a/Lib/sre_parse.py
+++ b/Lib/sre_parse.py
@@ -65,8 +65,8 @@
     "u": SRE_FLAG_UNICODE,
 }
 
-GLOBAL_FLAGS = (SRE_FLAG_ASCII | SRE_FLAG_LOCALE | SRE_FLAG_UNICODE |
-                SRE_FLAG_DEBUG | SRE_FLAG_TEMPLATE)
+TYPE_FLAGS = SRE_FLAG_ASCII | SRE_FLAG_LOCALE | SRE_FLAG_UNICODE  # mutually exclusive
+GLOBAL_FLAGS = SRE_FLAG_DEBUG | SRE_FLAG_TEMPLATE
 
 class Verbose(Exception):
     pass
@@ -822,7 +822,19 @@
     del_flags = 0
     if char != "-":
         while True:
-            add_flags |= FLAGS[char]
+            flag = FLAGS[char]
+            if source.istext:
+                if char == 'L':
+                    msg = "bad inline flags: cannot use 'L' flag with a str pattern"
+                    raise source.error(msg)
+            else:
+                if char == 'u':
+                    msg = "bad inline flags: cannot use 'u' flag with a bytes pattern"
+                    raise source.error(msg)
+            add_flags |= flag
+            if (flag & TYPE_FLAGS) and (add_flags & TYPE_FLAGS) != flag:
+                msg = "bad inline flags: flags 'a', 'u' and 'L' are incompatible"
+                raise source.error(msg)
             char = sourceget()
             if char is None:
                 raise source.error("missing -, : or )")
@@ -844,7 +856,11 @@
             msg = "unknown flag" if char.isalpha() else "missing flag"
             raise source.error(msg, len(char))
         while True:
-            del_flags |= FLAGS[char]
+            flag = FLAGS[char]
+            if flag & TYPE_FLAGS:
+                msg = "bad inline flags: cannot turn off flags 'a', 'u' and 'L'"
+                raise source.error(msg)
+            del_flags |= flag
             char = sourceget()
             if char is None:
                 raise source.error("missing :")
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index 9cb426a..fc015e4 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -1470,11 +1470,11 @@
             self.assertIsNone(pat.match(b'\xe0'))
         # Incompatibilities
         self.assertRaises(ValueError, re.compile, br'\w', re.UNICODE)
-        self.assertRaises(ValueError, re.compile, br'(?u)\w')
+        self.assertRaises(re.error, re.compile, br'(?u)\w')
         self.assertRaises(ValueError, re.compile, r'\w', re.UNICODE | re.ASCII)
         self.assertRaises(ValueError, re.compile, r'(?u)\w', re.ASCII)
         self.assertRaises(ValueError, re.compile, r'(?a)\w', re.UNICODE)
-        self.assertRaises(ValueError, re.compile, r'(?au)\w')
+        self.assertRaises(re.error, re.compile, r'(?au)\w')
 
     def test_locale_flag(self):
         import locale
@@ -1516,11 +1516,11 @@
             self.assertIsNone(pat.match(bletter))
         # Incompatibilities
         self.assertRaises(ValueError, re.compile, '', re.LOCALE)
-        self.assertRaises(ValueError, re.compile, '(?L)')
+        self.assertRaises(re.error, re.compile, '(?L)')
         self.assertRaises(ValueError, re.compile, b'', re.LOCALE | re.ASCII)
         self.assertRaises(ValueError, re.compile, b'(?L)', re.ASCII)
         self.assertRaises(ValueError, re.compile, b'(?a)', re.LOCALE)
-        self.assertRaises(ValueError, re.compile, b'(?aL)')
+        self.assertRaises(re.error, re.compile, b'(?aL)')
 
     def test_scoped_flags(self):
         self.assertTrue(re.match(r'(?i:a)b', 'Ab'))
@@ -1535,12 +1535,18 @@
         self.assertTrue(re.match(r'(?-x: a) b', ' ab', re.VERBOSE))
         self.assertIsNone(re.match(r'(?-x: a) b', 'ab', re.VERBOSE))
 
-        self.checkPatternError(r'(?a:\w)',
-                               'bad inline flags: cannot turn on global flag', 3)
+        self.assertTrue(re.match(r'\w(?a:\W)\w', '\xe0\xe0\xe0'))
+        self.assertTrue(re.match(r'(?a:\W(?u:\w)\W)', '\xe0\xe0\xe0'))
+        self.assertTrue(re.match(r'\W(?u:\w)\W', '\xe0\xe0\xe0', re.ASCII))
+
         self.checkPatternError(r'(?a)(?-a:\w)',
-                               'bad inline flags: cannot turn off global flag', 8)
+                "bad inline flags: cannot turn off flags 'a', 'u' and 'L'", 8)
         self.checkPatternError(r'(?i-i:a)',
-                               'bad inline flags: flag turned on and off', 5)
+                'bad inline flags: flag turned on and off', 5)
+        self.checkPatternError(r'(?au:a)',
+                "bad inline flags: flags 'a', 'u' and 'L' are incompatible", 4)
+        self.checkPatternError(br'(?aL:a)',
+                "bad inline flags: flags 'a', 'u' and 'L' are incompatible", 4)
 
         self.checkPatternError(r'(?-', 'missing flag', 3)
         self.checkPatternError(r'(?-+', 'missing flag', 3)