Merge pull request #15694 from mattjj:djax-reshape

PiperOrigin-RevId: 526194423
This commit is contained in:
jax authors 2023-04-21 19:42:27 -07:00
commit 13fe3810d2
3 changed files with 98 additions and 12 deletions

View File

@@ -48,7 +48,8 @@ from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, as_hashable_function,
HashableFunction, HashableWrapper, weakref_lru_cache)
HashableFunction, HashableWrapper, weakref_lru_cache,
partition_list)
import jax._src.pretty_printer as pp
from jax._src.lib import jax_jit
from jax._src import traceback_util
@@ -1211,6 +1212,9 @@ def same_referent(x: Any, y: Any) -> bool:
def dedup_referents(itr: Iterable[Any]) -> List[Any]:
return list({HashableWrapper(get_referent(x)):x for x in itr}.values())
def definitely_equal(x, y):
  """Whether `x` and `y` are known to denote the same dimension size.

  Cheap checks run first: object identity, then shared referent; only then
  do we fall back to the symbolic dimension comparison.
  """
  if x is y:
    return True
  return same_referent(x, y) or symbolic_equal_dim(x, y)
# -------------------- abstract values --------------------
@@ -1479,7 +1483,9 @@ class ShapedArray(UnshapedArray):
return ShapedArray(shape, dtype, weak_type, named_shape)
ndim = property(lambda self: len(self.shape))
size = property(lambda self: math.prod(self.shape))
size = property(lambda self:
0 if any(type(d) is int and d == 0 for d in self.shape)
else math.prod(self.shape))
broadcast: ClassVar[Optional[aval_method]] = None
transpose: ClassVar[Optional[aval_method]] = None
@@ -1626,7 +1632,9 @@ class DShapedArray(UnshapedArray):
self.weak_type = weak_type
ndim = property(lambda self: len(self.shape))
size = property(lambda self: math.prod(self.shape))
size = property(lambda self:
0 if any(type(d) is int and d == 0 for d in self.shape)
else math.prod(self.shape))
def str_short(self, short_dtypes=False) -> str:
del short_dtypes # ignored
@@ -1903,7 +1911,7 @@ def is_constant_shape(s: Shape) -> bool:
return all(is_constant_dim(d) for d in s)
def symbolic_equal_dim(d1: DimSize, d2: DimSize) -> bool:
  """Whether dimension sizes `d1` and `d2` are (symbolically) equal.

  Fast path: object identity or a shared referent. Otherwise canonicalize
  both dims and delegate to their shared dimension handler.
  """
  # Only the same_referent form of the fast path is kept: the older
  # `get_referent(d1) is get_referent(d2)` line duplicated it (a stale diff
  # pre-image) and has been removed.
  if d1 is d2 or same_referent(d1, d2): return True
  handler, ds = _dim_handler_and_canonical(d1, d2)
  return handler.symbolic_equal(*ds)
@@ -1946,8 +1954,32 @@ def divide_shape_sizes(s1: Shape, s2: Shape) -> DimSize:
return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
def same_shape_sizes(s1: Shape, s2: Shape) -> bool:
  """Whether shapes `s1` and `s2` contain the same total number of elements."""
  # First try cancelling pairwise-equal tracer dims; when that resolves the
  # comparison, the sizes agree exactly when the leftover factor is 1.
  cancelled = cancel_divide_tracers(s1, s2)
  if cancelled is not None:
    return cancelled == 1
  # Otherwise fall back to the handler-based size division.
  return 1 == divide_shape_sizes(s1, s2)
def cancel_divide_tracers(num, denom):
  """Attempt prod(num) / prod(denom) by cancelling pairwise-equal tracer dims.

  Returns None when neither side has tracer dims, when the tracer dims don't
  fully cancel, or when the concrete division would be by zero; otherwise
  returns the product of the uncancelled numerator tracer dims times the
  ratio of the concrete sizes.
  """
  # Split each shape into concrete dims (kept) and Tracer dims (to cancel).
  partition = lambda l: partition_list([isinstance(d, Tracer) for d in l], l)
  num, num_tracers = partition(num)
  denom, denom_tracers = partition(denom)
  if num_tracers or denom_tracers:
    # factor is the product of numerator tracers left after cancellation,
    # or None if some denominator tracer had no equal match.
    factor = _cancel_divide(num_tracers, denom_tracers)
    if factor is not None:
      size1 = math.prod(num)
      size2 = math.prod(denom)
      # Guard the // below against size2 == 0 (the size1 == size2 arm also
      # admits the 0/0 case, which yields factor * 1).
      if size1 == size2 or size2 != 0:
        # NOTE(review): floor division — presumably callers only reach this
        # with exactly divisible sizes; confirm against call sites.
        return factor * (size1 // size2 if size1 != size2 else 1)
  # Implicit None: no tracer dims involved, or cancellation failed.
def _cancel_divide(num, denom):
  """Cancel each dim in `denom` against some equal dim in `num`.

  Returns the product of the uncancelled `num` dims, or None when a `denom`
  dim has no equal counterpart left in `num`.
  """
  remaining = list(num)
  for d in denom:
    for j, candidate in enumerate(remaining):
      if definitely_equal(d, candidate):
        del remaining[j]
        break
    else:
      return None  # `d` has nothing left to cancel against
  return math.prod(remaining)
def is_empty_shape(s: Shape) -> bool:
  """Whether shape `s` has zero elements, i.e. some dim is (symbolically) 0."""
  for d in s:
    if symbolic_equal_dim(d, 0):
      return True
  return False

View File

@@ -41,7 +41,10 @@ from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape
from jax._src.util import safe_zip
from jax._src.util import safe_zip, safe_map
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
### add method and operator overloads to arraylike classes
@@ -119,11 +122,15 @@ def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape:
try:
iter(newshape) # type: ignore[arg-type]
except:
iterable = False
else:
iterable = True
newshape = core.canonicalize_shape(newshape if iterable else [newshape]) # type: ignore[arg-type]
return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
newshape = [newshape]
newshape = core.canonicalize_shape(newshape) # type: ignore[arg-type]
neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1]
if len(neg1s) == 1:
i, = neg1s
sz = core.cancel_divide_tracers(np.shape(a), (*newshape[:i], *newshape[i+1:]))
if sz is not None:
return (*newshape[:i], sz, *newshape[i+1:])
return tuple(-core.divide_shape_sizes(np.shape(a), newshape)
if core.symbolic_equal_dim(d, -1) else d
for d in newshape)
@@ -337,7 +344,7 @@ def _multi_slice(arr: ArrayLike,
DeviceArray method here to avoid circular imports.
"""
results: List[Array] = []
for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims):
for starts, limits, removed in zip(start_indices, limit_indices, removed_dims):
sliced = lax.slice(arr, starts, limits)
if removed:
sliced = lax.squeeze(sliced, removed)
@@ -472,7 +479,7 @@ def allow_pass_by_position_with_warning(f):
warnings.warn(
f"jnp.ndarray.at[...].{f.__name__}: Passing '{keywords[0]}' by position is deprecated. "
f"Pass by keyword instead", category=FutureWarning, stacklevel=2)
converted_kwargs = dict(zip(keywords, args[n_positional:]))
converted_kwargs = dict(unsafe_zip(keywords, args[n_positional:]))
return f(*args[:n_positional], **converted_kwargs, **kwargs)
else:
return f(*args, **kwargs)

View File

@@ -582,6 +582,53 @@ class DynamicShapeStagingTest(jtu.JaxTestCase):
self.assertIsInstance(three_, int)
self.assertEqual(three_, 3)
def test_zero_size_checking(self):
  # `core.definitely_equal` must be usable on a possibly-traced size without
  # raising: the nonempty input should stage at least one equation (the
  # negation), while the statically-empty input should stage none.
  def negate_unless_empty(x):
    if core.definitely_equal(x.size, 0):
      return x
    else:
      return -x

  nonempty = jnp.zeros(1)
  jaxpr = jax.make_jaxpr(negate_unless_empty, abstracted_axes={0: 'n'})(nonempty)  # doesn't crash
  self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 1)

  empty = jnp.zeros((2, 0))
  jaxpr = jax.make_jaxpr(negate_unless_empty, abstracted_axes={0: 'n'})(empty)  # doesn't crash
  self.assertLen(jaxpr.jaxpr.eqns, 0)
def test_flattening_basic(self):
  # Flattening reshapes under an abstracted leading axis should stage out
  # compactly. NOTE(review): the equation-count bounds below are tied to
  # tracing internals and may need updating if staging changes.
  x = jnp.zeros((2, 3, 4, 5))
  # don't need to divide or multiply any dynamic axis sizes
  jaxpr = jax.make_jaxpr(lambda x: x.reshape(x.shape[0], -1),
                         abstracted_axes={0: 'n'})(x)
  self.assertLen(jaxpr.jaxpr.eqns, 1)
  jaxpr = jax.make_jaxpr(lambda x: x.reshape(3, x.shape[0], -1),
                         abstracted_axes={0: 'n'})(x)
  self.assertLen(jaxpr.jaxpr.eqns, 1)
  jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, x.shape[0]),
                         abstracted_axes={0: 'n'})(x)
  self.assertLen(jaxpr.jaxpr.eqns, 1)

  # don't need to divide but do need a dynamic axis size in multiplication
  # (so to typecheck we'd need nontrivial reductions)
  jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1),
                         abstracted_axes={0: 'n'})(x)
  self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3)  # may have mul with 1
  # Last two staged equations: the size multiplication, then the reshape.
  self.assertEqual(str(jaxpr.jaxpr.eqns[-2].primitive), 'mul')
  self.assertEqual(str(jaxpr.jaxpr.eqns[-1].primitive), 'reshape')
  jaxpr = jax.make_jaxpr(lambda x: x.reshape(2, -1),
                         abstracted_axes={0: 'n'})(x)
  self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3)
  jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, 12), abstracted_axes={0: 'n'})(x)
  self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3)

  # do need divide, also shouldn't typecheck
  _ = jax.make_jaxpr(lambda x: x.reshape(x.shape[0], x.shape[0], -1),
                     abstracted_axes={0: 'n'})(x)  # don't crash
@unittest.skip("Test does not work with jax.Array")
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
class DynamicShapeAutodiffTest(jtu.JaxTestCase):