mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #16826 from mattjj:issue16805
PiperOrigin-RevId: 551263673
This commit is contained in:
commit
416814df2a
@ -2608,7 +2608,7 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
|
||||
right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
|
||||
new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
|
||||
(lhs.ndim, rhs.ndim), (left_stack_dim, right_stack_dim),
|
||||
(np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim),
|
||||
dimension_numbers)
|
||||
# TODO Should probably check that any ragged dimensions have corresponding
|
||||
# sizes, because otherwise the dot product is technically undefined.
|
||||
@ -2619,12 +2619,12 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd)
|
||||
lhs_shape = batching.bdim_as_shape(lbd, lhs.shape)
|
||||
else:
|
||||
lhs_shape = lhs.shape
|
||||
lhs_shape = np.shape(lhs)
|
||||
if type(rbd) is RaggedAxis:
|
||||
rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd)
|
||||
rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
|
||||
else:
|
||||
rhs_shape = rhs.shape
|
||||
rhs_shape = np.shape(rhs)
|
||||
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
|
||||
precision=precision,
|
||||
preferred_element_type=preferred_element_type)
|
||||
|
@ -2673,6 +2673,14 @@ class LaxTest(jtu.JaxTestCase):
|
||||
np.testing.assert_equal(
|
||||
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])
|
||||
|
||||
def test_dot_general_batching_python_builtin_arg(self):
|
||||
# https://github.com/google/jax/issues/16805
|
||||
@jax.remat
|
||||
def f(x):
|
||||
return jax.lax.dot_general(x, x, (([], []), ([], [])))
|
||||
|
||||
jax.hessian(f)(1.0) # don't crash
|
||||
|
||||
|
||||
class LazyConstantTest(jtu.JaxTestCase):
|
||||
def _Check(self, make_const, expected):
|
||||
|
Loading…
x
Reference in New Issue
Block a user