Merge pull request #16386 from axch:ragged-einsum

PiperOrigin-RevId: 542887557
This commit is contained in:
jax authors 2023-06-23 10:00:07 -07:00
commit 63415a9184
6 changed files with 111 additions and 14 deletions

View File

@ -1238,7 +1238,10 @@ def dedup_referents(itr: Iterable[Any]) -> List[Any]:
return list({HashableWrapper(get_referent(x)):x for x in itr}.values())
def definitely_equal(x, y):
return x is y or same_referent(x, y) or symbolic_equal_dim(x, y)
if isinstance(x, Tracer) or isinstance(y, Tracer):
return same_referent(x, y)
else:
return symbolic_equal_dim(x, y)
# -------------------- abstract values --------------------
@ -1968,6 +1971,10 @@ def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:
return (len(s1) == len(s2) and
all(unsafe_map(symbolic_equal_dim, s1, s2)))
def definitely_equal_shape(s1: Shape, s2: Shape) -> bool:
return (len(s1) == len(s2) and
all(unsafe_map(definitely_equal, s1, s2)))
def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.symbolic_equal(*ds) or handler.greater_equal(*ds)

View File

@ -134,6 +134,10 @@ class RaggedAxis:
new_ragged_axes = [(move_axis(ax), sizes) for ax, sizes in self.ragged_axes]
return RaggedAxis(dst, new_ragged_axes)
def transpose_ragged_axes(self, perm):
new_ragged_axes = [(perm[ax], size) for ax, size in self.ragged_axes]
return RaggedAxis(self.stacked_axis, new_ragged_axes)
def make_batch_axis(
ndim: int, stacked_axis: int, ragged_axes: List[Tuple[int, Array]]
) -> Union[int, RaggedAxis]:
@ -143,6 +147,24 @@ def make_batch_axis(
else:
return canonicalize_axis(stacked_axis, ndim)
def bdim_as_shape(
bdim: Union[int, RaggedAxis], data_shape: core.Shape) -> core.Shape:
if isinstance(bdim, RaggedAxis):
result = list(data_shape)
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
for ragged_axis, segment_lens in bdim.ragged_axes:
result[ragged_axis] = IndexedAxisSize(binder, segment_lens)
return tuple(result)
else:
return data_shape
def shape_as_bdim(
stacked_axis: int, data_shape: core.Shape) -> Union[int, RaggedAxis]:
# This assumes that there is only one binder in the data_shape.
ragged_axes = [(i, size.lengths) for i, size in enumerate(data_shape)
if isinstance(size, IndexedAxisSize)]
return make_batch_axis(len(data_shape), stacked_axis, ragged_axes)
def _update_annotation(
f: lu.WrappedFun, orig_type: Optional[core.InputType],

View File

@ -114,10 +114,10 @@ def _try_broadcast_shapes(
result_shape.append(ds[0])
else:
# if all dims are equal (or 1), the result is the non-1 size (or 1)
non_1s = [d for d in ds if not core.symbolic_equal_dim(d, 1)]
non_1s = [d for d in ds if not core.definitely_equal(d, 1)]
if not non_1s:
result_shape.append(1)
elif all(core.symbolic_equal_dim(non_1s[0], d) for d in non_1s[1:]):
elif all(core.definitely_equal(non_1s[0], d) for d in non_1s[1:]):
result_shape.append(non_1s[0])
else:
return None
@ -174,7 +174,7 @@ def _broadcast_ranks(s1, s2):
s1, s2 = s2, s1
assert len(s1) <= len(s2)
s1_ = s2[len(s2) - len(s1):]
if core.symbolic_equal_shape(s1_, s1): return s2
if core.definitely_equal_shape(s1_, s1): return s2
else: raise ValueError
def _identity(x): return x
@ -2600,25 +2600,47 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
(lhs.ndim, rhs.ndim), (left_stack_dim, right_stack_dim), dimension_numbers)
new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
(lhs.ndim, rhs.ndim), (left_stack_dim, right_stack_dim),
dimension_numbers)
# TODO Should probably check that any ragged dimensions have corresponding
# sizes, because otherwise the dot product is technically undefined.
#
# This masking is not strictly necessary for non-contraction dimensions;
# we could micro-optimize here by avoiding computing that mask.
if type(lbd) is RaggedAxis:
lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd)
lhs_shape = batching.bdim_as_shape(lbd, lhs.shape)
else:
lhs_shape = lhs.shape
if type(rbd) is RaggedAxis:
rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd)
rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
else:
rhs_shape = rhs.shape
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type)
result_batch_dim = batching.shape_as_bdim(
result_stack_dim,
_dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
return batched_out, result_batch_dim
def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
# there are three kinds of dimensions in a dot_general:
# There are three kinds of dimensions in a dot_general:
# - contraction dimensions appear in lhs and rhs but not the result
# - batch dimensions appear in lhs, rhs, and result
# - tensor product dimensions appear in the result and one of lhs or rhs
# The dimensions of the result are ordered as
# - Batch dimensions
# - Q: In what order? The order of appearance in lhs, rhs, or
# dimension_numbers?
# - Tensor dimensions from the LHS
# - Tensor dimensions from the RHS
lhs_ndim, rhs_ndim = ndims
# lbd and rbd are "batch" dimensions in the sense of dimensions being
# vmapped, not to be confused with "batch" dimensions in the sense of
# explicitly present dimensions that this dot_general is zipping together.
lbd, rbd = batch_dims
assert lbd is not None or rbd is not None
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
@ -2627,19 +2649,24 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
return tuple(np.add(dims, np.greater_equal(dims, b)))
if type(lbd) is type(rbd) is int:
# adding a batch dimension
# The vmapped dimensions become an additional batch dimension in the
# batched dot_general, which we arbitrarily put first.
lhs_batch = (lbd,) + bump_dims(lhs_batch, lbd)
rhs_batch = (rbd,) + bump_dims(rhs_batch, rbd)
lhs_contract = bump_dims(lhs_contract, lbd)
rhs_contract = bump_dims(rhs_contract, rbd)
result_batch_dim = 0
elif (type(lbd) is int and rbd is None):
# The left vmapped dimension becomes an additional tensor dimension in the
# batched dot_general.
lhs_tensor = [d for d in range(lhs_ndim)
if d not in lhs_batch and d not in lhs_contract]
result_batch_dim = len(lhs_batch) + int(sum(np.less(lhs_tensor, lbd)))
lhs_batch = bump_dims(lhs_batch, lbd)
lhs_contract = bump_dims(lhs_contract, lbd)
elif (type(rbd) is int and lbd is None):
# The right vmapped dimension becomes an additional tensor dimension in the
# batched dot_general.
rhs_tensor = [d for d in range(rhs_ndim)
if d not in rhs_batch and d not in rhs_contract]
result_batch_dim = (lhs_ndim - len(lhs_contract) +
@ -2647,6 +2674,7 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
rhs_batch = bump_dims(rhs_batch, rbd)
rhs_contract = bump_dims(rhs_contract, rbd)
else:
# We wouldn't be here if we didn't have at least one vmapped dimension.
assert False
new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
@ -3374,8 +3402,13 @@ def _transpose_shape_rule(operand, *, permutation):
def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
operand, = batched_args
bdim, = batch_dims
perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation)
return transpose(operand, perm), 0
stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim
perm = (stack_dim,) + tuple(i if i < stack_dim else i+1 for i in permutation)
if isinstance(bdim, RaggedAxis):
result_bdim = bdim.move_stacked_axis(0).transpose_ragged_axes(perm)
else:
result_bdim = 0
return transpose(operand, perm), result_bdim
def _transpose_lower(ctx, x, *, permutation):
aval_out, = ctx.avals_out

View File

@ -3305,7 +3305,7 @@ def _einsum(
return operand, names
def filter_singleton_dims(operand, names, other_shape, other_names):
eq = core.symbolic_equal_dim
eq = core.definitely_equal
keep = [not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1)
for i, j in enumerate(map(other_names.find, names))]
sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim)))

View File

@ -389,7 +389,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
shape = (shape,)
shape = core.canonicalize_shape(shape) # check that shape is concrete
arr_shape = np.shape(arr)
if core.symbolic_equal_shape(arr_shape, shape):
if core.definitely_equal_shape(arr_shape, shape):
return arr
else:
nlead = len(shape) - len(arr_shape)
@ -399,7 +399,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
if nlead < 0 or not compatible:
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
raise ValueError(msg.format(arr_shape, shape))
diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
diff, = np.where(tuple(not core.definitely_equal(arr_d, shape_d)
for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
new_dims = tuple(range(nlead)) + tuple(nlead + diff)
kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))

View File

@ -1526,7 +1526,7 @@ class PileTest(jtu.JaxTestCase):
self.assertIsInstance(y, batching.Pile)
self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32'))
def test_pile_map_matrix_dot(self):
def test_pile_map_matrix_dot_ragged_contract(self):
sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.pile_axis
)(sizes)
@ -1537,6 +1537,17 @@ class PileTest(jtu.JaxTestCase):
self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)),
check_dtypes=False)
def test_pile_map_matrix_dot_ragged_tensor(self):
sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def func(size):
lhs_one_d = jnp.arange(size, dtype='int32') + 1
lhs_two_d = jax.lax.broadcast_in_dim(lhs_one_d, (size, 2), (0,))
rhs = jax.lax.broadcasted_iota('int32', (2, 4), 0) + 1
return jnp.dot(lhs_two_d, rhs)
p = jax.vmap(func, out_axes=batching.pile_axis)(sizes)
self.assertIsInstance(p, batching.Pile)
self.assertEqual(p.data.shape, (3, 5, 4))
def test_broadcast_in_dim_while_ragged(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def func(size):
@ -1605,6 +1616,30 @@ class PileTest(jtu.JaxTestCase):
data = jax.lax.broadcasted_iota('int32', (3, 5, 5), 2)
self.assertAllClose(p.data, data)
def test_transpose_ragged(self):
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def func(size):
one_d = jnp.arange(size, dtype='int32')
two_d = jnp.broadcast_to(one_d, (7, size))
return jnp.transpose(two_d, [1, 0])
p = jax.vmap(func, out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
self.assertAllClose(p.data, data)
def test_ragged_einsum(self):
x_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def fprop_layer(x_size):
one_d = jnp.arange(x_size, dtype='int32')
x = jax.lax.broadcast_in_dim(one_d, (x_size, 11), [0])
wqkv = jax.lax.broadcasted_iota('int32', (3, 2, 7, 11), 1)
qkv = jnp.einsum('te,ihqe->ithq', x, wqkv)
return qkv
p = jax.vmap(fprop_layer, out_axes=batching.pile_axis)(x_sizes)
self.assertIsInstance(p, batching.Pile)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
self.assertEqual(p.data.shape, (3, 3, 5, 2, 7))
def pile_map(f):
def mapped(*piles):
return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,