[sharding_in_types] Add support for svd_p

PiperOrigin-RevId: 720409750
Yash Katariya 2025-01-27 20:29:25 -08:00 committed by jax authors
parent 24987a90dc
commit ae705fef9c
5 changed files with 156 additions and 36 deletions


@@ -1731,14 +1731,20 @@ def _invalid_shape_error(shape: Shape, context: str=""):
return TypeError(msg)
def _make_lengths_same(sharding, ndim):
if ndim > len(sharding.spec):
return sharding.with_spec(sharding.spec._normalized_spec(ndim))
if ndim < len(sharding.spec):
return sharding.with_spec(sharding.spec[:ndim])
assert False, "unreachable"
# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def modify_spec_for_auto_manual(spec, mesh) -> P:
if all(s is None for s in spec):
return spec
new_spec = [] # type: ignore
for s in spec:
if s is None:
if not s:
new_spec.append(s)
else:
temp_s = s[0] if isinstance(s, tuple) else s
@@ -1748,22 +1754,29 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
else s)
return P(*new_spec)
def _maybe_modify_sharding(sharding):
def _maybe_modify_sharding(sharding, ndim):
if sharding.mesh._are_all_axes_explicit:
return sharding
new_spec = modify_spec_for_auto_manual(sharding.spec, sharding.mesh)
return sharding.with_spec(new_spec)
out = sharding
elif all(s is None for s in sharding.spec):
out = sharding
else:
out = sharding.with_spec(modify_spec_for_auto_manual(
sharding.spec, sharding.mesh))
if (len(out.spec) != ndim and
(out.mesh._are_all_axes_auto or out.mesh._are_all_axes_manual)):
out = _make_lengths_same(out, ndim)
return out
def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding # type: ignore
if sharding is not None:
if len(sharding.spec) != ndim:
out_s = _maybe_modify_sharding(sharding, ndim)
if len(out_s.spec) != ndim:
raise ValueError(
"Length of sharding.spec must be equal to aval's ndim. Got"
f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
out_s = _maybe_modify_sharding(sharding)
f" sharding.spec {out_s.spec} and aval.ndim {ndim}")
else:
context_mesh = mesh_lib.get_abstract_mesh()
if not context_mesh:
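Editor's note: a minimal standalone sketch of the length-normalization step that _make_lengths_same introduces above, using a plain tuple in place of the internal PartitionSpec object; the helper name pad_or_trim_spec is illustrative, not part of this change.

# Illustrative only: pad a too-short spec with None (replicated) or trim a
# too-long one so len(spec) matches the aval's ndim, mirroring what
# _make_lengths_same does for fully-auto/fully-manual meshes.
def pad_or_trim_spec(spec: tuple, ndim: int) -> tuple:
    if ndim > len(spec):
        return spec + (None,) * (ndim - len(spec))
    if ndim < len(spec):
        return spec[:ndim]
    return spec

assert pad_or_trim_spec(('x',), 3) == ('x', None, None)
assert pad_or_trim_spec(('x', None, None), 1) == ('x',)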


@@ -5165,18 +5165,28 @@ def _rev_shape_rule(operand, *, dimensions):
raise TypeError(msg.format(dimensions, operand.ndim))
return operand.shape
def _rev_sharding_rule(operand, *, dimensions):
# TODO(yashkatariya): Will lead to data movement. Maybe just error out and
# require the operand to be unsharded?
return operand.sharding
def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
operand, = batched_args
bdim, = batch_dims
new_dimensions = [i + 1 if i >= bdim else i for i in dimensions]
return rev(operand, new_dimensions), bdim
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev',
sharding_rule=_rev_sharding_rule)
ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule
def _rev_lower(ctx, x, *, dimensions):
return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
aval_out, = ctx.avals_out
out = hlo.reverse(x, mlir.dense_int_array(dimensions))
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(rev_p, _rev_lower)
@@ -5932,7 +5942,10 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
mlir.flatten_ir_values(operands),
dimension=mlir.i64_attr(dimension),
is_stable=ir.BoolAttr.get(is_stable))
scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
scalar_s = (lambda a: a.sharding.with_spec(P())
if config.sharding_in_types.value else lambda _: None)
scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval))
for aval in ctx.avals_in]
scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
comparator = sort.comparator.blocks.append(
*util.flatten(zip(scalar_types, scalar_types)))
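Editor's note: the rev and sort edits above follow a pattern that recurs throughout this change: a shape-preserving op simply forwards its operand's sharding, and the lowering wraps the HLO result with the output aval's sharding when sharding-in-types is enabled. A stripped-down sketch of the rule side (function name hypothetical):

# Sketch of the pattern for shape-preserving ops such as rev above: the
# output reuses the operand's sharding unchanged. As the TODO notes, for
# rev this can still imply cross-device data movement.
def shape_preserving_sharding_rule(operand, **unused_params):
    return operand.sharding

# The sort comparator operates on rank-0 scalars, so its avals are given an
# empty spec (PartitionSpec()) rather than inheriting the operands' specs.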


@@ -40,6 +40,7 @@ from jax._src.interpreters import mlir
from jax._src.lax import control_flow
from jax._src.lax import eigh as lax_eigh
from jax._src.lax import lax as lax_internal
from jax._src.partition_spec import PartitionSpec as P
from jax._src.lax import svd as lax_svd
from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
@@ -960,9 +961,20 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
batch_dims = operand.shape[:-2]
n = operand.shape[-1]
if config.sharding_in_types.value:
batch_s = operand.sharding.spec[:-2]
ns = operand.sharding.spec[-1]
if ns is not None:
raise ValueError(f'n should be unsharded. Got n: {ns}'
' specs. Try marking their specs as None.')
w_s = operand.sharding.with_spec(P(*batch_s + (ns,)))
v_s = operand.sharding.with_spec(P(*batch_s + (ns, ns)))
else:
w_s, v_s = None, None
w = operand.update(shape=batch_dims + (n,),
dtype=lax_internal._complex_basetype(operand.dtype))
v = operand.update(shape=batch_dims + (n, n))
dtype=lax_internal._complex_basetype(operand.dtype),
sharding=w_s)
v = operand.update(shape=batch_dims + (n, n), sharding=v_s)
else:
w, v = operand, operand
return w, v
@@ -1029,16 +1041,23 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
batch_dims = operand.shape[:-2]
n = operand.shape[-1]
d = (
n
if subset_by_index is None
else subset_by_index[1] - subset_by_index[0]
)
v = operand.update(shape=batch_dims + (n, d))
d = (n if subset_by_index is None else
subset_by_index[1] - subset_by_index[0])
if config.sharding_in_types.value:
batch_s = operand.sharding.spec[:-2]
ns, ds = operand.sharding.spec[-1], None
if ns is not None:
raise ValueError(f'n should be unsharded. Got n: {ns} specs. Try '
'marking their specs as None.')
v_s = operand.sharding.with_spec(P(*batch_s + (ns, ds)))
w_s = operand.sharding.with_spec(P(*batch_s + (ds,)))
else:
v_s, w_s = None, None
v = operand.update(shape=batch_dims + (n, d), sharding=v_s)
w = operand.update(
shape=batch_dims + (d,),
dtype=lax_internal._complex_basetype(operand.dtype),
)
sharding=w_s)
else:
v, w = operand, operand
return v, w
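Editor's note: both eigh rules above build output specs the same way: batch dimensions keep the operand's spec and the trailing matrix dimensions must be unsharded. A small self-contained sketch with plain tuples standing in for PartitionSpec (the helper name eigh_output_specs is hypothetical):

# Mirrors the sharding handling in _eigh_abstract_eval: keep the batch
# entries, require the matrix dims to be None, and emit specs for the
# eigenvector output (batch + (n, d)) and eigenvalue output (batch + (d,)).
def eigh_output_specs(operand_spec: tuple) -> tuple:
    *batch_s, ms, ns = operand_spec
    if ms is not None or ns is not None:
        raise ValueError("matrix dims must be unsharded (None)")
    v_spec = (*batch_s, None, None)
    w_spec = (*batch_s, None)
    return v_spec, w_spec

assert eigh_output_specs(('x', None, None)) == (('x', None, None), ('x', None))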
@@ -1249,6 +1268,24 @@ def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
raise TypeError(msg.format(a.shape, b.shape))
return b.shape
def _triangular_solve_sharding_rule(a, b, *, left_side=False, **unused_kwargs):
a_spec, b_spec = a.sharding.spec, b.sharding.spec
if a_spec[-1] != a_spec[-2]:
raise TypeError(
"triangular_solve requires the last two dimensions of a to be equal "
f"in sharding, got a_spec of {a_spec}.")
if a_spec[:-2] != b_spec[:-2]:
raise TypeError(
"triangular_solve requires both arguments to have the same number "
f"of dimensions and equal batch shardings, got {a_spec} and {b_spec}.")
common_dim = -2 if left_side else -1
if a_spec[-1] != b_spec[common_dim]:
raise TypeError(
"Incompatible shardings for arguments to triangular_solve:"
f" {a_spec} and {b_spec}.")
return b.sharding
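Editor's note: a worked check of the three constraints enforced by the rule above, with specs written as plain tuples; check_triangular_solve_specs is purely illustrative.

def check_triangular_solve_specs(a_spec, b_spec, left_side=False):
    # a is (..., n, n); its two square dims must be sharded identically.
    assert a_spec[-1] == a_spec[-2]
    # Batch dims of a and b must be sharded identically.
    assert a_spec[:-2] == b_spec[:-2]
    # The contracted dim of b (-2 when solving from the left, else -1) must
    # match a's trailing dim; the output then inherits b's sharding.
    common_dim = -2 if left_side else -1
    assert a_spec[-1] == b_spec[common_dim]
    return b_spec

assert check_triangular_solve_specs(('x', None, None), ('x', None, None),
                                    left_side=True) == ('x', None, None)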
def _triangular_solve_jvp_rule_a(
g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
unit_diagonal):
@@ -1328,7 +1365,7 @@ def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
triangular_solve_p = standard_primitive(
_triangular_solve_shape_rule, _triangular_solve_dtype_rule,
'triangular_solve')
'triangular_solve', sharding_rule=_triangular_solve_sharding_rule)
ad.defjvp2(triangular_solve_p,
_triangular_solve_jvp_rule_a,
lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
@@ -1346,10 +1383,13 @@ def _triangular_solve_lowering(
transpose = "NO_TRANSPOSE"
else:
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
return [hlo.triangular_solve(
a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))]
out = hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, out_aval)]
return [out]
def _triangular_solve_cpu_lower(
@@ -1802,7 +1842,17 @@ def _geqrf_abstract_eval(operand):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)))
if config.sharding_in_types.value:
spec = operand.sharding.spec
batch_s, ms, ns = spec[:-2], spec[-2], spec[-1]
if ms is not None or ns is not None:
raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
' specs. Try marking their specs as None.')
taus_s = operand.sharding.with_spec(P(*(*batch_s, None)))
else:
taus_s = None
taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)),
sharding=taus_s)
return operand, taus
def _geqrf_batching_rule(batched_args, batch_dims):
@@ -2024,13 +2074,23 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices):
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
k = m if full_matrices else core.min_dim(m, n)
q = operand.update(shape=(*batch_dims, m, k))
r = operand.update(shape=(*batch_dims, k, n))
p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32))
if config.sharding_in_types.value:
*batch_s, ms, ns = operand.sharding.spec
ks = None
if ms is not None or ns is not None:
raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
' specs. Try marking their specs as None.')
q_s = operand.sharding.with_spec(P(*(*batch_s, ms, ks)))
r_s = operand.sharding.with_spec(P(*(*batch_s, ks, ns)))
p_s = operand.sharding.with_spec(P(*(*batch_s, ns)))
else:
q_s, r_s, p_s = None, None, None
q = operand.update(shape=(*batch_dims, m, k), sharding=q_s)
r = operand.update(shape=(*batch_dims, k, n), sharding=r_s)
p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32),
sharding=p_s)
else:
q = operand
r = operand
p = operand
q, r, p = operand, operand, operand
return (q, r, p) if pivoting else (q, r)
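Editor's note: the geqrf and qr rules above share one recipe: batch dimensions inherit the operand's spec, m and n must be unsharded, and every newly created dimension (k, tau, the pivot vector) stays replicated. A compact tuple-based sketch (qr_output_specs is hypothetical):

def qr_output_specs(operand_spec: tuple, pivoting: bool = False):
    *batch_s, ms, ns = operand_spec
    if ms is not None or ns is not None:
        raise ValueError("m and n must be unsharded (None)")
    q_spec = (*batch_s, ms, None)   # q: batch + (m, k); k is new, so None
    r_spec = (*batch_s, None, ns)   # r: batch + (k, n)
    p_spec = (*batch_s, ns)         # column-pivot indices: batch + (n,)
    return (q_spec, r_spec, p_spec) if pivoting else (q_spec, r_spec)

assert qr_output_specs(('x', None, None)) == (('x', None, None),
                                              ('x', None, None))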
def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
@@ -2136,13 +2196,32 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
raise ValueError("full_matrices and subset_by_index cannot both be set")
rank = min(rank, subset_by_index[1] - subset_by_index[0])
if config.sharding_in_types.value:
batch_s = operand.sharding.spec[:-2]
ms = operand.sharding.spec[-2]
ns = operand.sharding.spec[-1]
if ms is not None or ns is not None:
raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
' specs. Try marking their specs as None.')
rank_s = None
s_sharding = operand.sharding.with_spec(P(*batch_s + (rank_s,)))
u_sharding = operand.sharding.with_spec(
P(*batch_s + (ms, ms if full_matrices else rank_s)))
vt_sharding = operand.sharding.with_spec(
P(*batch_s + (ns if full_matrices else rank_s, ns)))
else:
s_sharding, u_sharding, vt_sharding = None, None, None
s = operand.update(
shape=batch_dims + (rank,),
dtype=lax_internal._complex_basetype(operand.dtype),
sharding=s_sharding
)
if compute_uv:
u = operand.update(shape=batch_dims + (m, m if full_matrices else rank))
vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n))
u = operand.update(shape=batch_dims + (m, m if full_matrices else rank),
sharding=u_sharding)
vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n),
sharding=vt_sharding)
return s, u, vt
else:
return s,
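Editor's note: for svd_p itself, the rule above keeps batch shardings on s, u, and vt while requiring m and n to be unsharded; the rank dimension is always replicated. A tuple-based sketch of the same logic (svd_output_specs is illustrative):

def svd_output_specs(operand_spec: tuple, full_matrices: bool,
                     compute_uv: bool):
    *batch_s, ms, ns = operand_spec
    if ms is not None or ns is not None:
        raise ValueError("m and n must be unsharded (None)")
    rank_s = None                        # the rank dim is always replicated
    s_spec = (*batch_s, rank_s)
    if not compute_uv:
        return (s_spec,)
    u_spec = (*batch_s, ms, ms if full_matrices else rank_s)
    vt_spec = (*batch_s, ns if full_matrices else rank_s, ns)
    return s_spec, u_spec, vt_spec

assert svd_output_specs(('x', None, None), False, True) == (
    ('x', None), ('x', None, None), ('x', None, None))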


@@ -1886,6 +1886,10 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
cur_mesh = mesh_lib.get_abstract_mesh()
if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual: # type: ignore
return None
if (cur_mesh._are_all_axes_explicit and # type: ignore
all(s is None for s in operand.sharding.spec) and
all(s is None for s in indices.sharding.spec)):
return None
raise GatherShardingError(
"Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
" the gather indexing.")


@@ -6368,6 +6368,17 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertTupleEqual(out2.sharding._device_assignment,
tuple(mesh2.devices.flat))
@jtu.with_user_mesh((2, 1), ('x', 'y'))
def test_svd(self, mesh):
np_inp = np.zeros([128, 128])
arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, None)))
@jax.jit
def f(x):
return jnp.linalg.norm(x, 2)
f(arr) # doesn't crash
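Editor's note: a rough standalone variant of the test's data setup, assuming at least two devices (e.g. CPU with XLA_FLAGS=--xla_force_host_platform_device_count=2). It omits the user-mesh/axis-type context that jtu.with_user_mesh installs, so it only mirrors the inputs rather than the full sharding-in-types configuration.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
arr = jax.device_put(np.zeros((128, 128)),
                     NamedSharding(mesh, P(None, None)))

@jax.jit
def f(x):
  # The matrix 2-norm is the largest singular value, so this routes through
  # svd_p and exercises the new sharding rule.
  return jnp.linalg.norm(x, 2)

f(arr)  # should not crash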
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):