diff --git a/python-stdlib/functools/functools.py b/python-stdlib/functools/functools.py index b3c368e8a..22db605d7 100644 --- a/python-stdlib/functools/functools.py +++ b/python-stdlib/functools/functools.py @@ -26,3 +26,131 @@ def reduce(function, iterable, initializer=None): for element in it: value = function(value, element) return value + + +################################################################################ +### total_ordering class decorator +################################################################################ +# Extracted from the CPython v3.7.17 implementation of Lib/functools.py +# +# Written by Nick Coghlan , +# Raymond Hettinger , +# and Łukasz Langa . +# Copyright (C) 2006-2013 Python Software Foundation. +# See C source code for _functools credits/copyright + +# The total ordering functions all invoke the root magic method directly +# rather than using the corresponding operator. This avoids possible +# infinite recursion that could occur when the operator dispatch logic +# detects a NotImplemented result and then calls a reflected method. + +def _gt_from_lt(self, other, NotImplemented=NotImplemented): + 'Return a > b. Computed by @total_ordering from (not a < b) and (a != b).' + op_result = self.__lt__(other) + if op_result is NotImplemented: + return op_result + return not op_result and self != other + +def _le_from_lt(self, other, NotImplemented=NotImplemented): + 'Return a <= b. Computed by @total_ordering from (a < b) or (a == b).' + op_result = self.__lt__(other) + return op_result or self == other + +def _ge_from_lt(self, other, NotImplemented=NotImplemented): + 'Return a >= b. Computed by @total_ordering from (not a < b).' + op_result = self.__lt__(other) + if op_result is NotImplemented: + return op_result + return not op_result + +def _ge_from_le(self, other, NotImplemented=NotImplemented): + 'Return a >= b. Computed by @total_ordering from (not a <= b) or (a == b).' + op_result = self.__le__(other) + if op_result is NotImplemented: + return op_result + return not op_result or self == other + +def _lt_from_le(self, other, NotImplemented=NotImplemented): + 'Return a < b. Computed by @total_ordering from (a <= b) and (a != b).' + op_result = self.__le__(other) + if op_result is NotImplemented: + return op_result + return op_result and self != other + +def _gt_from_le(self, other, NotImplemented=NotImplemented): + 'Return a > b. Computed by @total_ordering from (not a <= b).' + op_result = self.__le__(other) + if op_result is NotImplemented: + return op_result + return not op_result + +def _lt_from_gt(self, other, NotImplemented=NotImplemented): + 'Return a < b. Computed by @total_ordering from (not a > b) and (a != b).' + op_result = self.__gt__(other) + if op_result is NotImplemented: + return op_result + return not op_result and self != other + +def _ge_from_gt(self, other, NotImplemented=NotImplemented): + 'Return a >= b. Computed by @total_ordering from (a > b) or (a == b).' + op_result = self.__gt__(other) + return op_result or self == other + +def _le_from_gt(self, other, NotImplemented=NotImplemented): + 'Return a <= b. Computed by @total_ordering from (not a > b).' + op_result = self.__gt__(other) + if op_result is NotImplemented: + return op_result + return not op_result + +def _le_from_ge(self, other, NotImplemented=NotImplemented): + 'Return a <= b. Computed by @total_ordering from (not a >= b) or (a == b).' + op_result = self.__ge__(other) + if op_result is NotImplemented: + return op_result + return not op_result or self == other + +def _gt_from_ge(self, other, NotImplemented=NotImplemented): + 'Return a > b. Computed by @total_ordering from (a >= b) and (a != b).' + op_result = self.__ge__(other) + if op_result is NotImplemented: + return op_result + return op_result and self != other + +def _lt_from_ge(self, other, NotImplemented=NotImplemented): + 'Return a < b. Computed by @total_ordering from (not a >= b).' + op_result = self.__ge__(other) + if op_result is NotImplemented: + return op_result + return not op_result + +_convert = { + '__lt__': [('__gt__', _gt_from_lt), + ('__le__', _le_from_lt), + ('__ge__', _ge_from_lt)], + '__le__': [('__ge__', _ge_from_le), + ('__lt__', _lt_from_le), + ('__gt__', _gt_from_le)], + '__gt__': [('__lt__', _lt_from_gt), + ('__ge__', _ge_from_gt), + ('__le__', _le_from_gt)], + '__ge__': [('__le__', _le_from_ge), + ('__gt__', _gt_from_ge), + ('__lt__', _lt_from_ge)] +} + +def total_ordering(cls): + """Class decorator that fills in missing ordering methods""" + # Find user-defined comparisons (not those inherited from object). + roots = {op for op in _convert if getattr(cls, op, None) is not getattr(object, op, None)} + if not roots: + raise ValueError('must define at least one ordering operation: < > <= >=') + root = max(roots) # prefer __lt__ to __le__ to __gt__ to __ge__ + for opname, opfunc in _convert[root]: + if opname not in roots: +# function objects have no attributes in micropython +# opfunc.__name__ = opname + setattr(cls, opname, opfunc) + return cls + + diff --git a/python-stdlib/functools/manifest.py b/python-stdlib/functools/manifest.py index 634413c1e..e3ceca892 100644 --- a/python-stdlib/functools/manifest.py +++ b/python-stdlib/functools/manifest.py @@ -1,3 +1,3 @@ -metadata(version="0.0.7") +metadata(version="0.0.8") module("functools.py") diff --git a/python-stdlib/functools/test_total_ordering.py b/python-stdlib/functools/test_total_ordering.py new file mode 100644 index 000000000..f5f56e95a --- /dev/null +++ b/python-stdlib/functools/test_total_ordering.py @@ -0,0 +1,230 @@ +""" Test code for functools.total_ordering + +Copyright © 2001-2023 Python Software Foundation. All rights reserved. + +This code was extracted from CPython v3.7.17 Lib/test/test_functools.py +""" + +import unittest + +import functools + +class TestTotalOrdering(unittest.TestCase): + + def test_total_ordering_lt(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __lt__(self, other): + return self.value < other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(1) > A(2)) + + def test_total_ordering_le(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __le__(self, other): + return self.value <= other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(1) >= A(2)) + + def test_total_ordering_gt(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __gt__(self, other): + return self.value > other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(2) < A(1)) + + def test_total_ordering_ge(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __ge__(self, other): + return self.value >= other.value + def __eq__(self, other): + return self.value == other.value + self.assertTrue(A(1) < A(2)) + self.assertTrue(A(2) > A(1)) + self.assertTrue(A(1) <= A(2)) + self.assertTrue(A(2) >= A(1)) + self.assertTrue(A(2) <= A(2)) + self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(2) <= A(1)) + +# This test does not appear to work due to the lack of attributes on builtin types. +# This appears to lead to the comparison operators not being inherited by default. +# def test_total_ordering_no_overwrite(self): +# # new methods should not overwrite existing +# import sys +# @functools.total_ordering +# class A(int): +# pass +# self.assertTrue(A(1) < A(2)) +# self.assertTrue(A(2) > A(1)) +# self.assertTrue(A(1) <= A(2)) +# self.assertTrue(A(2) >= A(1)) +# self.assertTrue(A(2) <= A(2)) +# self.assertTrue(A(2) >= A(2)) + + def test_no_operations_defined(self): + with self.assertRaises(ValueError): + @functools.total_ordering + class A: + pass + + def test_type_error_when_not_implemented(self): + # bug 10042; ensure stack overflow does not occur + # when decorated types return NotImplemented + @functools.total_ordering + class ImplementsLessThan: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsLessThan): + return self.value == other.value + return False + def __lt__(self, other): + if isinstance(other, ImplementsLessThan): + return self.value < other.value + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThan: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value == other.value + return False + def __gt__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value > other.value + return NotImplemented + + @functools.total_ordering + class ImplementsLessThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value == other.value + return False + def __le__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value <= other.value + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value == other.value + return False + def __ge__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value >= other.value + return NotImplemented + + @functools.total_ordering + class ComparatorNotImplemented: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ComparatorNotImplemented): + return self.value == other.value + return False + def __lt__(self, other): + return NotImplemented + + with self.subTest("LT < 1"), self.assertRaises(TypeError): + ImplementsLessThan(-1) < 1 + + with self.subTest("LT < LE"), self.assertRaises(TypeError): + ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) + + with self.subTest("LT < GT"), self.assertRaises(TypeError): + ImplementsLessThan(1) < ImplementsGreaterThan(1) + + with self.subTest("LE <= LT"), self.assertRaises(TypeError): + ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) + + with self.subTest("LE <= GE"), self.assertRaises(TypeError): + ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) + + with self.subTest("GT > GE"), self.assertRaises(TypeError): + ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) + + with self.subTest("GT > LT"), self.assertRaises(TypeError): + ImplementsGreaterThan(5) > ImplementsLessThan(5) + + with self.subTest("GE >= GT"), self.assertRaises(TypeError): + ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) + + with self.subTest("GE >= LE"), self.assertRaises(TypeError): + ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) + + with self.subTest("GE when equal"): + a = ComparatorNotImplemented(8) + b = ComparatorNotImplemented(8) + self.assertEqual(a, b) + with self.assertRaises(TypeError): + a >= b + + with self.subTest("LE when equal"): + a = ComparatorNotImplemented(9) + b = ComparatorNotImplemented(9) + self.assertEqual(a, b) + with self.assertRaises(TypeError): + a <= b + +# Leaving pickle support for a later date +# def test_pickle(self): +# for proto in range(pickle.HIGHEST_PROTOCOL + 1): +# for name in '__lt__', '__gt__', '__le__', '__ge__': +# with self.subTest(method=name, proto=proto): +# method = getattr(Orderable_LT, name) +# method_copy = pickle.loads(pickle.dumps(method, proto)) +# self.assertIs(method_copy, method) + +# @functools.total_ordering +# class Orderable_LT: +# def __init__(self, value): +# self.value = value +# def __lt__(self, other): +# return self.value < other.value +# def __eq__(self, other): +# return self.value == other.value + +if __name__ == "__main__": + unittest.main() +