mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18551 from mattjj:reshape-error-message
PiperOrigin-RevId: 582876150
This commit is contained in:
commit
aa35e6395f
@ -23,6 +23,7 @@ __all__ = ['register_jax_array_methods']
|
||||
|
||||
import abc
|
||||
from functools import partial, wraps
|
||||
import math
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@ -115,22 +116,35 @@ def _transpose(a: Array, *args: Any) -> Array:
|
||||
|
||||
def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape:
|
||||
"""Fixes a -1 value in newshape, if present."""
|
||||
# other errors, like having more than one -1, are caught downstream, in
|
||||
# reshape_shape_rule.
|
||||
orig_newshape = newshape # for error messages
|
||||
try:
|
||||
iter(newshape) # type: ignore[arg-type]
|
||||
except:
|
||||
newshape = [newshape]
|
||||
newshape = core.canonicalize_shape(newshape) # type: ignore[arg-type]
|
||||
neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1]
|
||||
if len(neg1s) == 1:
|
||||
if len(neg1s) > 1:
|
||||
raise ValueError("can only specify one unknown axis size with a `-1` value, "
|
||||
f"got {orig_newshape}")
|
||||
if neg1s:
|
||||
i, = neg1s
|
||||
sz = core.cancel_divide_tracers(np.shape(a), (*newshape[:i], *newshape[i+1:]))
|
||||
other_sizes = (*newshape[:i], *newshape[i+1:])
|
||||
if (all(isinstance(d, int) for d in (*np.shape(a), *other_sizes)) and
|
||||
np.size(a) % math.prod(other_sizes) != 0):
|
||||
raise ValueError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) "
|
||||
f"into shape {orig_newshape} because the product of "
|
||||
f"specified axis sizes ({math.prod(other_sizes)}) does "
|
||||
f"not evenly divide {np.size(a)}")
|
||||
sz = core.cancel_divide_tracers(np.shape(a), other_sizes)
|
||||
if sz is not None:
|
||||
return (*newshape[:i], sz, *newshape[i+1:])
|
||||
else:
|
||||
if (all(isinstance(d, int) for d in (*np.shape(a), *newshape)) and
|
||||
np.size(a) != math.prod(newshape)):
|
||||
raise ValueError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) "
|
||||
f"into shape {orig_newshape} (size {math.prod(newshape)})")
|
||||
return tuple(-core.divide_shape_sizes(np.shape(a), newshape)
|
||||
if core.definitely_equal(d, -1) else d
|
||||
for d in newshape)
|
||||
if core.definitely_equal(d, -1) else d for d in newshape)
|
||||
|
||||
|
||||
def _reshape(a: Array, *args: Any, order: str = "C") -> Array:
|
||||
@ -138,6 +152,7 @@ def _reshape(a: Array, *args: Any, order: str = "C") -> Array:
|
||||
|
||||
Refer to :func:`jax.numpy.reshape` for full documentation.
|
||||
"""
|
||||
__tracebackhide__ = True
|
||||
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
|
||||
if order == "C":
|
||||
return lax.reshape(a, newshape, None)
|
||||
@ -731,6 +746,7 @@ def _forward_operator_to_aval(name):
|
||||
|
||||
def _forward_method_to_aval(name):
|
||||
def meth(self, *args, **kwargs):
|
||||
__tracebackhide__ = True
|
||||
return getattr(self.aval, name).fun(self, *args, **kwargs)
|
||||
return meth
|
||||
|
||||
|
@ -780,12 +780,14 @@ def isrealobj(x: Any) -> bool:
|
||||
|
||||
@util._wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
|
||||
def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
|
||||
__tracebackhide__ = True
|
||||
util.check_arraylike("reshape", a)
|
||||
try:
|
||||
# forward to method for ndarrays
|
||||
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
|
||||
except AttributeError:
|
||||
return asarray(a).reshape(newshape, order=order)
|
||||
pass
|
||||
return asarray(a).reshape(newshape, order=order)
|
||||
|
||||
|
||||
@util._wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)
|
||||
|
Loading…
x
Reference in New Issue
Block a user