Merge pull request #16826 from mattjj:issue16805

PiperOrigin-RevId: 551263673
jax authors 2023-07-26 11:20:31 -07:00
commit 416814df2a
2 changed files with 11 additions and 3 deletions


@@ -2608,7 +2608,7 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
   left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
   right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
   new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
-      (lhs.ndim, rhs.ndim), (left_stack_dim, right_stack_dim),
+      (np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim),
       dimension_numbers)
   # TODO Should probably check that any ragged dimensions have corresponding
   # sizes, because otherwise the dot product is technically undefined.
@@ -2619,12 +2619,12 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
     lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd)
     lhs_shape = batching.bdim_as_shape(lbd, lhs.shape)
   else:
-    lhs_shape = lhs.shape
+    lhs_shape = np.shape(lhs)
   if type(rbd) is RaggedAxis:
     rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd)
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
-    rhs_shape = rhs.shape
+    rhs_shape = np.shape(rhs)
   batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                             precision=precision,
                             preferred_element_type=preferred_element_type)
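
Why np.ndim / np.shape instead of attribute access: the batching rule can receive Python builtin scalars (see the linked issue in the test below), which have no .ndim or .shape attributes, while the NumPy functions handle builtins and arrays alike. A minimal sketch of the difference, assuming plain NumPy semantics:

  import numpy as np

  x = 1.0                          # Python builtin float, as in the linked issue
  print(np.ndim(x), np.shape(x))   # -> 0 ()  works for builtin scalars and arrays
  # x.ndim                         # AttributeError: 'float' object has no attribute 'ndim'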


@@ -2673,6 +2673,14 @@ class LaxTest(jtu.JaxTestCase):
     np.testing.assert_equal(
         np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
 
+  def test_dot_general_batching_python_builtin_arg(self):
+    # https://github.com/google/jax/issues/16805
+    @jax.remat
+    def f(x):
+      return jax.lax.dot_general(x, x, (([], []), ([], [])))
+
+    jax.hessian(f)(1.0)  # don't crash
+
 class LazyConstantTest(jtu.JaxTestCase):
 
   def _Check(self, make_const, expected):
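
For context, a standalone sketch mirroring the new regression test (assuming an installed jax; against a build without this change it is expected to fail inside the dot_general batching rule when the Python float hits the old lhs.ndim / lhs.shape attribute accesses):

  import jax

  # Reproduction of https://github.com/google/jax/issues/16805
  @jax.remat
  def f(x):
    return jax.lax.dot_general(x, x, (([], []), ([], [])))

  jax.hessian(f)(1.0)  # should not crash with this change applied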