diff --git a/thrift/lib/py3/test/auto_migrate/sets.py b/thrift/lib/py3/test/auto_migrate/sets.py index cf2e10f3d89..cbf07a618b4 100644 --- a/thrift/lib/py3/test/auto_migrate/sets.py +++ b/thrift/lib/py3/test/auto_migrate/sets.py @@ -19,6 +19,8 @@ import unittest from typing import AbstractSet, Sequence, Tuple +import thrift.python.types as python_types + from testing.types import ( Color, ColorGroups, @@ -222,23 +224,27 @@ def test_hashability(self) -> None: for sub_set in z: hash(sub_set) + def test_set_op_return_type(self) -> None: + x = SetI32({1, 3, 4, 5}) + y = SetI32({1, 2, 4, 6}) + expected_type = python_types.Set if is_auto_migrated() else SetI32 + + self.assertIsInstance(x & y, expected_type) + self.assertIsInstance(x | y, expected_type) + self.assertIsInstance(x ^ y, expected_type) + self.assertIsInstance(x - y, expected_type) + + self.assertIsInstance(x.__rand__(y), expected_type) + self.assertIsInstance(x.__ror__(y), expected_type) + self.assertIsInstance(x.__rxor__(y), expected_type) + self.assertIsInstance(x.__rsub__(y), expected_type) + @brokenInAutoMigrate() def test_is_container(self) -> None: self.assertIsInstance(SetI32Lists(), Container) self.assertIsInstance(SetSetI32Lists(), Container) self.assertIsInstance(SetI32(), Container) - # in thrift-python, we return the frozenset directly, - # which is is probably not intended. - @brokenInAutoMigrate() - def test_set_op_return_type(self) -> None: - x = SetI32({1, 3, 4, 5}) - y = SetI32({1, 2, 4, 6}) - self.assertIs(type(x & y), SetI32) - self.assertIs(type(x | y), SetI32) - self.assertIs(type(x ^ y), SetI32) - self.assertIs(type(x - y), SetI32) - @brokenInAutoMigrate() def test_set_op_type_error_thrift_set(self) -> None: x = SetI32({1, 3, 4, 5}) diff --git a/thrift/lib/python/test/mutable_set_test.py b/thrift/lib/python/test/mutable_set_test.py index 59e0de38d47..6f709354957 100644 --- a/thrift/lib/python/test/mutable_set_test.py +++ b/thrift/lib/python/test/mutable_set_test.py @@ -278,6 +278,15 @@ def test_remove_i32_overflow(self) -> None: with self.assertRaises(OverflowError): mutable_set.remove(2**31) + def test_set_op_return_type(self) -> None: + mutable_set_1 = _create_MutableSet_i32(range(4)) + mutable_set_2 = _create_MutableSet_i32(range(2, 6)) + + self.assertIsInstance(mutable_set_1 & mutable_set_2, MutableSet) + self.assertIsInstance(mutable_set_1 | mutable_set_2, MutableSet) + self.assertIsInstance(mutable_set_1 ^ mutable_set_2, MutableSet) + self.assertIsInstance(mutable_set_1 - mutable_set_2, MutableSet) + def test_pop(self) -> None: mutable_set = _create_MutableSet_i32(range(3)) diff --git a/thrift/lib/python/types.pyx b/thrift/lib/python/types.pyx index 298e65632e8..962061444e2 100644 --- a/thrift/lib/python/types.pyx +++ b/thrift/lib/python/types.pyx @@ -2070,25 +2070,25 @@ cdef class Set(Container): return hash(self._fbthrift_elements) def __and__(Set self, other): - return self._fbthrift_elements & other + return Set(self._fbthrift_val_info, self._fbthrift_elements & other) def __rand__(Set self, other): return other & self._fbthrift_elements def __sub__(Set self, other): - return self._fbthrift_elements - other + return Set(self._fbthrift_val_info, self._fbthrift_elements - other) def __rsub__(Set self, other): return other - self._fbthrift_elements def __or__(Set self, other): - return self._fbthrift_elements | other + return Set(self._fbthrift_val_info, self._fbthrift_elements | other) def __ror__(Set self, other): return other | self._fbthrift_elements def __xor__(Set self, other): - return self._fbthrift_elements ^ other + return Set(self._fbthrift_val_info, self._fbthrift_elements ^ other) def __rxor__(Set self, other): return other ^ self._fbthrift_elements