Merge pull request #18551 from mattjj:reshape-error-message

PiperOrigin-RevId: 582876150
This commit is contained in:
jax authors 2023-11-15 19:00:00 -08:00
commit aa35e6395f
2 changed files with 25 additions and 7 deletions

View File

@ -23,6 +23,7 @@ __all__ = ['register_jax_array_methods']
import abc
from functools import partial, wraps
import math
from typing import Any, Optional, Union
import numpy as np
@ -115,22 +116,35 @@ def _transpose(a: Array, *args: Any) -> Array:
def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape:
"""Fixes a -1 value in newshape, if present."""
# other errors, like having more than one -1, are caught downstream, in
# reshape_shape_rule.
orig_newshape = newshape # for error messages
try:
iter(newshape) # type: ignore[arg-type]
except:
newshape = [newshape]
newshape = core.canonicalize_shape(newshape) # type: ignore[arg-type]
neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1]
if len(neg1s) == 1:
if len(neg1s) > 1:
raise ValueError("can only specify one unknown axis size with a `-1` value, "
f"got {orig_newshape}")
if neg1s:
i, = neg1s
sz = core.cancel_divide_tracers(np.shape(a), (*newshape[:i], *newshape[i+1:]))
other_sizes = (*newshape[:i], *newshape[i+1:])
if (all(isinstance(d, int) for d in (*np.shape(a), *other_sizes)) and
np.size(a) % math.prod(other_sizes) != 0):
raise ValueError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) "
f"into shape {orig_newshape} because the product of "
f"specified axis sizes ({math.prod(other_sizes)}) does "
f"not evenly divide {np.size(a)}")
sz = core.cancel_divide_tracers(np.shape(a), other_sizes)
if sz is not None:
return (*newshape[:i], sz, *newshape[i+1:])
else:
if (all(isinstance(d, int) for d in (*np.shape(a), *newshape)) and
np.size(a) != math.prod(newshape)):
raise ValueError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) "
f"into shape {orig_newshape} (size {math.prod(newshape)})")
return tuple(-core.divide_shape_sizes(np.shape(a), newshape)
if core.definitely_equal(d, -1) else d
for d in newshape)
if core.definitely_equal(d, -1) else d for d in newshape)
def _reshape(a: Array, *args: Any, order: str = "C") -> Array:
@ -138,6 +152,7 @@ def _reshape(a: Array, *args: Any, order: str = "C") -> Array:
Refer to :func:`jax.numpy.reshape` for full documentation.
"""
__tracebackhide__ = True
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
if order == "C":
return lax.reshape(a, newshape, None)
@ -731,6 +746,7 @@ def _forward_operator_to_aval(name):
def _forward_method_to_aval(name):
def meth(self, *args, **kwargs):
__tracebackhide__ = True
return getattr(self.aval, name).fun(self, *args, **kwargs)
return meth

View File

@ -780,12 +780,14 @@ def isrealobj(x: Any) -> bool:
@util._wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
__tracebackhide__ = True
util.check_arraylike("reshape", a)
try:
# forward to method for ndarrays
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
except AttributeError:
return asarray(a).reshape(newshape, order=order)
pass
return asarray(a).reshape(newshape, order=order)
@util._wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)