matmul raises IndexError exception with input shapes (2, 1, 2) and (2,) #2264

Open
antonwolfy opened this issue Jan 15, 2025 · 0 comments

@antonwolfy
Contributor

The example below raises an `IndexError` instead of returning a result:

```python
import dpnp, numpy

dpnp.__version__
# Out: '0.17.0dev4+3.g498e705d848.dirty'

a = dpnp.ones((2, 1, 2))
b = dpnp.ones((2,))

a @ b
```
```pytb
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[10], line 1
----> 1 a @ b

File ~/code/dpnp/dpnp/dpnp_array.py:489, in dpnp_array.__matmul__(self, other)
    487 def __matmul__(self, other):
    488     """Return ``self@value``."""
--> 489     return dpnp.matmul(self, other)

File ~/code/dpnp/dpnp/dpnp_iface_linearalgebra.py:851, in matmul(x1, x2, out, casting, order, dtype, subok, signature, axes, axis)
    846 if axis is not None:
    847     raise NotImplementedError(
    848         "axis keyword argument is only supported by its default value."
    849     )
--> 851 return dpnp_matmul(
    852     x1,
    853     x2,
    854     out=out,
    855     casting=casting,
    856     order=order,
    857     dtype=dtype,
    858     axes=axes,
    859 )

File ~/code/dpnp/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py:951, in dpnp_matmul(x1, x2, out, casting, order, dtype, axes)
    949         else:  # call_flag == "gemm_batch"
    950             assert call_flag == "gemm_batch"
--> 951             result = _gemm_batch_matmul(
    952                 exec_q,
    953                 x1,
    954                 x2,
    955                 result,
    956             )
    958 if NumPy_special_behavior:
    959     result = dpnp.tile(result, out.shape)

File ~/code/dpnp/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py:373, in _gemm_batch_matmul(exec_q, x1, x2, res)
    371 x2_shape = x2.shape
    372 x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1]))
--> 373 x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1]))
    374 orig_shape = res.shape
    375 res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1]))

IndexError: tuple index out of range
```
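The failing line in `_gemm_batch_matmul` builds a 3-D target shape from `x2_shape[-2]` and `x2_shape[-1]`, which appears to assume that `x2` has at least two dimensions. For the 1-D `b` of shape `(2,)` there is no `x2_shape[-2]`. The pure-Python snippet below reproduces the same error without dpnp; `gemm_batch_shape` is a hypothetical helper that only mirrors the indexing on line 373, not a dpnp function.

```python
def gemm_batch_shape(shape):
    """Hypothetical helper: build (-1, shape[-2], shape[-1]) like line 373 does."""
    return (-1, shape[-2], shape[-1])

# A 3-D operand such as x1 with shape (2, 1, 2) has both trailing axes.
print(gemm_batch_shape((2, 1, 2)))  # (-1, 1, 2)

# A 1-D operand such as x2 with shape (2,) has no shape[-2].
try:
    gemm_batch_shape((2,))
except IndexError as exc:
    print("IndexError:", exc)       # IndexError: tuple index out of range
```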

The same operation works as expected with NumPy and returns an array of shape `(2, 1)`:

```python
na, nb = a.asnumpy(), b.asnumpy()
na @ nb
# Out:
# array([[2.],
#        [2.]])
```
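For reference, NumPy documents that when the second argument of `matmul` is 1-D, it is promoted to a matrix by appending a unit axis, and that axis is removed again after the multiplication. A minimal NumPy-only sketch of that rule follows. `matmul_1d_rhs` is a hypothetical name used for illustration; it is neither a dpnp nor a NumPy function, and this is not necessarily how dpnp should handle the case internally.

```python
import numpy as np

def matmul_1d_rhs(x1, x2):
    """Hypothetical helper: emulate matmul for a 1-D right-hand operand."""
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    if x2.ndim != 1:
        raise ValueError("this sketch only handles a 1-D second operand")
    x2_col = x2[:, np.newaxis]   # promote (n,) -> (n, 1)
    res = np.matmul(x1, x2_col)  # (2, 1, 2) @ (2, 1) -> (2, 1, 1)
    return res[..., 0]           # drop the appended axis -> (2, 1)

a = np.ones((2, 1, 2))
b = np.ones((2,))
out = matmul_1d_rhs(a, b)
print(out.shape)                    # (2, 1)
print(np.array_equal(out, a @ b))   # True
```

Because the promoted operand is 2-D, the batched path only ever sees operands with at least two dimensions, which is the assumption line 373 relies on.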
@antonwolfy changed the title from "matmul raises unexpected IndexError exception with input shapes (2, 1, 2) and (2,)" to "matmul raises IndexError exception with input shapes (2, 1, 2) and (2,)" on Jan 15, 2025.