[3.10] bpo-44242: [Enum] remove missing bits test from Flag creation (GH-26586) (GH-26635)
Move the check for missing named flags in flag aliases from Flag creation
to a new *verify* decorator.
(cherry picked from commit eea8148b7dff5ffc7b84433859ac819b1d92a74d)
Co-authored-by: Ethan Furman <ethan@stoneleaf.us>
diff --git a/Lib/enum.py b/Lib/enum.py
index 01f4310..f74cc8c 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -6,10 +6,10 @@
__all__ = [
'EnumType', 'EnumMeta',
'Enum', 'IntEnum', 'StrEnum', 'Flag', 'IntFlag',
- 'auto', 'unique',
- 'property',
+ 'auto', 'unique', 'property', 'verify',
'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP',
'global_flag_repr', 'global_enum_repr', 'global_enum',
+ 'EnumCheck', 'CONTINUOUS', 'NAMED_FLAGS', 'UNIQUE',  # the checks accepted by verify()
]
@@ -89,6 +89,9 @@ def _break_on_call_reduce(self, proto):
setattr(obj, '__module__', '<unknown>')
def _iter_bits_lsb(num):
+ # num must be an integer
+ if isinstance(num, Enum):
+ num = num.value
while num:
b = num & (~num + 1)
yield b
@@ -538,13 +541,6 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
else:
# multi-bit flags are considered aliases
multi_bit_total |= flag_value
- if enum_class._boundary_ is not KEEP:
- missed = list(_iter_bits_lsb(multi_bit_total & ~single_bit_total))
- if missed:
- raise TypeError(
- 'invalid Flag %r -- missing values: %s'
- % (cls, ', '.join((str(i) for i in missed)))
- )
enum_class._flag_mask_ = single_bit_total
#
# set correct __iter__
@@ -688,7 +684,10 @@ def __members__(cls):
return MappingProxyType(cls._member_map_)
def __repr__(cls):
- return "<enum %r>" % cls.__name__
+ if Flag is not None and issubclass(cls, Flag):  # guards against Flag still being None (presumably while enum.py is initializing) -- confirm
+ return "<flag %r>" % cls.__name__
+ else:
+ return "<enum %r>" % cls.__name__
def __reversed__(cls):
"""
@@ -1303,7 +1302,8 @@ def __invert__(self):
else:
# calculate flags not in this member
self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
- self._inverted_._inverted_ = self
+ if isinstance(self._inverted_, self.__class__):
+ self._inverted_._inverted_ = self
return self._inverted_
@@ -1561,6 +1561,91 @@ def convert_class(cls):
return enum_class
return convert_class
+@_simple_enum(StrEnum)
+class EnumCheck:
+    """
+    Various conditions to check an enumeration for; pass them to ``verify()``.
+    """
+    CONTINUOUS = "no skipped integer values"
+    NAMED_FLAGS = "multi-flag aliases may not contain unnamed flags"
+    UNIQUE = "one name per value"
+CONTINUOUS, NAMED_FLAGS, UNIQUE = EnumCheck  # bind members to module names; assumes definition-order iteration
+
+
+class verify:
+    """
+    Check an enumeration for various constraints. (see EnumCheck)
+
+    Used as a class decorator, e.g. ``@verify(UNIQUE, NAMED_FLAGS)``.  Each
+    requested check raises ValueError on failure; otherwise the enumeration
+    is returned unchanged.  Non-Enum classes raise TypeError.
+    """
+    def __init__(self, *checks):
+        self.checks = checks
+    def __call__(self, enumeration):
+        checks = self.checks
+        cls_name = enumeration.__name__
+        if Flag is not None and issubclass(enumeration, Flag):
+            enum_type = 'flag'
+        elif issubclass(enumeration, Enum):
+            enum_type = 'enum'
+        else:
+            raise TypeError("the 'verify' decorator only works with Enum and Flag")
+        for check in checks:
+            if check is UNIQUE:
+                # check for duplicate names
+                duplicates = []
+                for name, member in enumeration.__members__.items():
+                    if name != member.name:
+                        duplicates.append((name, member.name))
+                if duplicates:
+                    alias_details = ', '.join(
+                            ["%s -> %s" % (alias, name) for (alias, name) in duplicates])
+                    raise ValueError('aliases found in %r: %s' %
+                            (enumeration, alias_details))
+            elif check is CONTINUOUS:
+                values = set(e.value for e in enumeration)
+                if len(values) < 2:
+                    continue
+                low, high = min(values), max(values)
+                missing = []
+                if enum_type == 'flag':
+                    # check for powers of two
+                    for i in range(_high_bit(low)+1, _high_bit(high)):
+                        if 2**i not in values:
+                            missing.append(2**i)
+                elif enum_type == 'enum':
+                    # check for every integer strictly between low and high
+                    for i in range(low+1, high):
+                        if i not in values:
+                            missing.append(i)
+                else:
+                    raise Exception('verify: unknown type %r' % enum_type)
+                if missing:
+                    raise ValueError('invalid %s %r: missing values %s' % (
+                            enum_type, cls_name, ', '.join((str(m) for m in missing))))
+            elif check is NAMED_FLAGS:
+                # examine each alias and check for unnamed flags; sets keep
+                # the membership tests O(1)
+                member_names = set(enumeration._member_names_)
+                member_values = {m.value for m in enumeration}
+                missing = []
+                for name, alias in enumeration._member_map_.items():
+                    if name in member_names or alias.value < 0:
+                        # not an alias, or negative (infinitely many bits): skip
+                        continue
+                    values = list(_iter_bits_lsb(alias.value))
+                    missed = [v for v in values if v not in member_values]
+                    if missed:
+                        plural = ('', 's')[len(missed) > 1]
+                        a = ('a ', '')[len(missed) > 1]
+                        missing.append('%r is missing %snamed flag%s for value%s %s' % (
+                                name, a, plural, plural,
+                                ', '.join(str(v) for v in missed)))
+                if missing:
+                    raise ValueError('invalid Flag %r: %s' % (cls_name, '; '.join(missing)))
+        return enumeration
+
def _test_simple_enum(checked_enum, simple_enum):
"""
A function that can be used to test an enum created with :func:`_simple_enum`