Skip to content

Commit

Permalink
Change pack_tensor to do unsafe casting when an explicit dtype is p…
Browse files Browse the repository at this point in the history
…rovided.

This fixes issues when running with NumPy 2.0.

Previously, we would cast using `same_kind`, which mean't only "safe" casts or casts within a kind, like float64 to float32, were allowed. In NumPy 2.0, the meaning of `same_kind` changed such that unsigned and signed integers are no longer deemed the same.

Considering it was still possible for data loss even when using same_kind casting, we propose changing to use `unsafe` when an explicit dtype is provided, requiring the end user to ensure the value being packed in compatible with the dtype.

PiperOrigin-RevId: 678408939
Change-Id: I1770714fb23335f5ee1a9ae5b7f6263394bb8297
  • Loading branch information
tomwardio authored and copybara-github committed Sep 24, 2024
1 parent b6ba877 commit c6b8435
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 7 deletions.
6 changes: 1 addition & 5 deletions dm_env_rpc/v1/spec_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,9 @@ def test_pack_wrong_shape_raises_error(self):
self._spec_manager.pack({'foo': [1, 2]})

def test_pack_wrong_dtype_raises_error(self):
with self.assertRaisesRegex(TypeError, 'int32'):
with self.assertRaises(ValueError):
self._spec_manager.pack({'foo': 'hello'})

def test_pack_cast_float_to_int_raises_error(self):
with self.assertRaisesRegex(TypeError, 'int32'):
self._spec_manager.pack({'foo': [0.5, 1.0, 1]})

def test_pack_cast_int_to_float_is_ok(self):
packed = self._spec_manager.pack({'fuzz': [1, 2]})
self.assertEqual([1.0, 2.0], packed[54].floats.array)
Expand Down
2 changes: 1 addition & 1 deletion dm_env_rpc/v1/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def pack_tensor(
value = value.astype(
dtype=_DM_ENV_RPC_DTYPE_TO_NUMPY_DTYPE.get(dtype, dtype),
copy=False,
casting='same_kind' if value.size else 'unsafe')
casting='unsafe')

packed.shape[:] = value.shape
packer = get_packer(value.dtype.type)
Expand Down
2 changes: 1 addition & 1 deletion dm_env_rpc/v1/tensor_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_packed_rowmajor(self):
np.testing.assert_array_equal([1, 2, 3, 4, 5, 6], tensor.int32s.array)

def test_mixed_scalar_types_raises_exception(self):
with self.assertRaises(TypeError):
with self.assertRaises(ValueError):
tensor_utils.pack_tensor(['hello!', 75], dtype=np.float32)

def test_jagged_arrays_throw_exceptions(self):
Expand Down

0 comments on commit c6b8435

Please sign in to comment.