Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

commit ae705fef9c (parent 24987a90dc)

[sharding_in_types] Add support for svd_p

PiperOrigin-RevId: 720409750
@@ -1731,14 +1731,20 @@ def _invalid_shape_error(shape: Shape, context: str=""):
   return TypeError(msg)
 
 
+def _make_lengths_same(sharding, ndim):
+  if ndim > len(sharding.spec):
+    return sharding.with_spec(sharding.spec._normalized_spec(ndim))
+  if ndim < len(sharding.spec):
+    return sharding.with_spec(sharding.spec[:ndim])
+  assert False, "unreachable"
+
+
 # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
 # Collective too.
 def modify_spec_for_auto_manual(spec, mesh) -> P:
-  if all(s is None for s in spec):
-    return spec
   new_spec = []  # type: ignore
   for s in spec:
-    if s is None:
+    if not s:
       new_spec.append(s)
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
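
Illustration (not part of the commit): a minimal standalone sketch of the rank-normalization idea behind `_make_lengths_same`, using plain tuples in place of the real sharding/spec objects. The helper name is hypothetical and the equal-length case is handled loosely here (the real helper is only called when the lengths differ).

# Toy model: pad a too-short spec with None entries, or drop trailing
# entries from a too-long one, so its length matches the array rank.
def make_lengths_same_toy(spec: tuple, ndim: int) -> tuple:
  if ndim > len(spec):
    return spec + (None,) * (ndim - len(spec))   # pad with unsharded dims
  if ndim < len(spec):
    return spec[:ndim]                           # truncate extra entries
  return spec

assert make_lengths_same_toy(('x',), 3) == ('x', None, None)
assert make_lengths_same_toy(('x', None, None), 2) == ('x', None)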
@@ -1748,22 +1754,29 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
                       else s)
   return P(*new_spec)
 
 
-def _maybe_modify_sharding(sharding):
+def _maybe_modify_sharding(sharding, ndim):
   if sharding.mesh._are_all_axes_explicit:
-    return sharding
-  new_spec = modify_spec_for_auto_manual(sharding.spec, sharding.mesh)
-  return sharding.with_spec(new_spec)
+    out = sharding
+  elif all(s is None for s in sharding.spec):
+    out = sharding
+  else:
+    out = sharding.with_spec(modify_spec_for_auto_manual(
+        sharding.spec, sharding.mesh))
+  if (len(out.spec) != ndim and
+      (out.mesh._are_all_axes_auto or out.mesh._are_all_axes_manual)):
+    out = _make_lengths_same(out, ndim)
+  return out
 
 
 def get_sharding(sharding, ndim):
   from jax._src.sharding_impls import NamedSharding  # type: ignore
 
   if sharding is not None:
-    if len(sharding.spec) != ndim:
+    out_s = _maybe_modify_sharding(sharding, ndim)
+    if len(out_s.spec) != ndim:
       raise ValueError(
           "Length of sharding.spec must be equal to aval's ndim. Got"
-          f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
-    out_s = _maybe_modify_sharding(sharding)
+          f" sharding.spec {out_s.spec} and aval.ndim {ndim}")
   else:
     context_mesh = mesh_lib.get_abstract_mesh()
     if not context_mesh:
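
Again purely illustrative: a toy version of the reordered check in `get_sharding` above. The spec is first normalized against the aval rank (only when every mesh axis is Auto or Manual), and only then is a length mismatch reported. Names are hypothetical; plain tuples stand in for the real spec objects.

# Toy control flow mirroring the updated get_sharding: normalize first
# (only allowed for fully Auto/Manual meshes), then validate the length.
def get_spec_toy(spec: tuple, ndim: int, all_auto_or_manual: bool) -> tuple:
  if len(spec) != ndim and all_auto_or_manual:
    # pad or truncate, as _make_lengths_same does
    spec = (spec + (None,) * (ndim - len(spec))
            if ndim > len(spec) else spec[:ndim])
  if len(spec) != ndim:
    raise ValueError(
        f"Length of sharding.spec must be equal to aval's ndim. Got "
        f"sharding.spec {spec} and aval.ndim {ndim}")
  return spec

assert get_spec_toy(('x',), 3, all_auto_or_manual=True) == ('x', None, None)
try:
  get_spec_toy(('x',), 3, all_auto_or_manual=False)
except ValueError:
  pass  # explicit meshes still require a full-rank spec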
@@ -5165,18 +5165,28 @@ def _rev_shape_rule(operand, *, dimensions):
     raise TypeError(msg.format(dimensions, operand.ndim))
   return operand.shape
 
+def _rev_sharding_rule(operand, *, dimensions):
+  # TODO(yashkatariya): Will lead to data movement. Maybe just error out and
+  # require the operand to be unsharded?
+  return operand.sharding
+
 def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
   operand, = batched_args
   bdim, = batch_dims
   new_dimensions = [i + 1 if i >= bdim else i for i in dimensions]
   return rev(operand, new_dimensions), bdim
 
-rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
+rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev',
+                           sharding_rule=_rev_sharding_rule)
 ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
 batching.primitive_batchers[rev_p] = _rev_batch_rule
 
 def _rev_lower(ctx, x, *, dimensions):
-  return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
+  aval_out, = ctx.avals_out
+  out = hlo.reverse(x, mlir.dense_int_array(dimensions))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]
 mlir.register_lowering(rev_p, _rev_lower)
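
Usage-level sketch (not from the commit): `_rev_sharding_rule` simply returns the operand's sharding, so reversing a sharded array is expected to keep its input sharding. The mesh shape and axis names below are arbitrary assumptions, and at least two devices are assumed.

# Sketch: lax.rev preserves the input sharding. Assumes >= 2 devices.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('x', 'y'))
x = jax.device_put(jnp.arange(8.0).reshape(4, 2),
                   NamedSharding(mesh, P('x', None)))

rev = jax.jit(lambda a: jax.lax.rev(a, dimensions=(0,)))
out = rev(x)
print(out.sharding)  # expected to match the input's NamedSharding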
@@ -5932,7 +5942,10 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
                      mlir.flatten_ir_values(operands),
                      dimension=mlir.i64_attr(dimension),
                      is_stable=ir.BoolAttr.get(is_stable))
-  scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
+  scalar_s = (lambda a: a.sharding.with_spec(P())
+              if config.sharding_in_types.value else lambda _: None)
+  scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval))
+                  for aval in ctx.avals_in]
   scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
   comparator = sort.comparator.blocks.append(
       *util.flatten(zip(scalar_types, scalar_types)))
@@ -40,6 +40,7 @@ from jax._src.interpreters import mlir
 from jax._src.lax import control_flow
 from jax._src.lax import eigh as lax_eigh
 from jax._src.lax import lax as lax_internal
+from jax._src.partition_spec import PartitionSpec as P
 from jax._src.lax import svd as lax_svd
 from jax._src.lax.lax import (
     standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
@@ -960,9 +961,20 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
 
     batch_dims = operand.shape[:-2]
     n = operand.shape[-1]
+    if config.sharding_in_types.value:
+      batch_s = operand.sharding.spec[:-2]
+      ns = operand.sharding.spec[-1]
+      if ns is not None:
+        raise ValueError(f'n should be unsharded. Got n: {ns}'
+                         ' specs. Try marking their specs as None.')
+      w_s = operand.sharding.with_spec(P(*batch_s + (ns,)))
+      v_s = operand.sharding.with_spec(P(*batch_s + (ns, ns)))
+    else:
+      w_s, v_s = None, None
     w = operand.update(shape=batch_dims + (n,),
-                       dtype=lax_internal._complex_basetype(operand.dtype))
-    v = operand.update(shape=batch_dims + (n, n))
+                       dtype=lax_internal._complex_basetype(operand.dtype),
+                       sharding=w_s)
+    v = operand.update(shape=batch_dims + (n, n), sharding=v_s)
   else:
     w, v = operand, operand
   return w, v
@@ -1029,16 +1041,23 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
 
     batch_dims = operand.shape[:-2]
     n = operand.shape[-1]
-    d = (
-        n
-        if subset_by_index is None
-        else subset_by_index[1] - subset_by_index[0]
-    )
-    v = operand.update(shape=batch_dims + (n, d))
+    d = (n if subset_by_index is None else
+         subset_by_index[1] - subset_by_index[0])
+    if config.sharding_in_types.value:
+      batch_s = operand.sharding.spec[:-2]
+      ns, ds = operand.sharding.spec[-1], None
+      if ns is not None:
+        raise ValueError(f'n should be unsharded. Got n: {ns} specs. Try '
+                         'marking their specs as None.')
+      v_s = operand.sharding.with_spec(P(*batch_s + (ns, ds)))
+      w_s = operand.sharding.with_spec(P(*batch_s + (ds,)))
+    else:
+      v_s, w_s = None, None
+    v = operand.update(shape=batch_dims + (n, d), sharding=v_s)
     w = operand.update(
         shape=batch_dims + (d,),
         dtype=lax_internal._complex_basetype(operand.dtype),
-    )
+        sharding=w_s)
   else:
     v, w = operand, operand
   return v, w
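
Usage-level sketch (not from the commit): the eigh rules above carry only the batch sharding to the outputs and require the trailing (n, n) dimensions to be unsharded. A batch of symmetric matrices sharded only over the batch axis satisfies that constraint. Mesh and axis names are arbitrary assumptions; at least two devices are assumed.

# Sketch: eigh on a batch of symmetric matrices, batch dimension sharded,
# trailing (n, n) dims replicated, matching the rule above.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
a = np.random.default_rng(0).normal(size=(4, 8, 8)).astype('float32')
a = a + np.swapaxes(a, -1, -2)                     # make each matrix symmetric
a = jax.device_put(a, NamedSharding(mesh, P('x', None, None)))

w, v = jax.jit(jnp.linalg.eigh)(a)
print(w.shape, v.shape)  # (4, 8) and (4, 8, 8); batch sharding expected to carry over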
@@ -1249,6 +1268,24 @@ def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
     raise TypeError(msg.format(a.shape, b.shape))
   return b.shape
 
+def _triangular_solve_sharding_rule(a, b, *, left_side=False, **unused_kwargs):
+  a_spec, b_spec = a.sharding.spec, b.sharding.spec
+  if a_spec[-1] != a_spec[-2]:
+    raise TypeError(
+        "triangular_solve requires the last two dimensions of a to be equal "
+        f"in sharding, got a_spec of {a_spec}.")
+  if a_spec[:-2] != b_spec[:-2]:
+    raise TypeError(
+        "triangular_solve requires both arguments to have the same number "
+        f"of dimensions and equal batch shardings, got {a_spec} and {b_spec}.")
+  common_dim = -2 if left_side else -1
+  if a_spec[-1] != b_spec[common_dim]:
+    raise TypeError(
+        "Incompatible shardings for arguments to triangular_solve:"
+        f" {a_spec} and {b_spec}.")
+  return b.sharding
+
+
 def _triangular_solve_jvp_rule_a(
     g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
     unit_diagonal):
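
Usage-level sketch (not from the commit): the sharding rule above requires the last two dimensions of `a` to share a spec, the batch specs of `a` and `b` to match, and the contracting dimensions to agree. Sharding both operands only over the batch dimension satisfies all three checks. Mesh and axis names are arbitrary assumptions; at least two devices are assumed.

# Sketch: batched lower-triangular solve with both operands sharded over
# the batch dimension only, consistent with the rule above.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
a = jnp.tril(jnp.ones((4, 3, 3))) + 3 * jnp.eye(3)   # well-conditioned, lower-triangular
b = jnp.ones((4, 3, 2))
a = jax.device_put(a, NamedSharding(mesh, P('x', None, None)))
b = jax.device_put(b, NamedSharding(mesh, P('x', None, None)))

solve = jax.jit(lambda a, b: jax.lax.linalg.triangular_solve(
    a, b, left_side=True, lower=True))
print(solve(a, b).shape)  # (4, 3, 2); result expected to keep b's sharding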
@@ -1328,7 +1365,7 @@ def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
 
 triangular_solve_p = standard_primitive(
     _triangular_solve_shape_rule, _triangular_solve_dtype_rule,
-    'triangular_solve')
+    'triangular_solve', sharding_rule=_triangular_solve_sharding_rule)
 ad.defjvp2(triangular_solve_p,
            _triangular_solve_jvp_rule_a,
            lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
@@ -1346,10 +1383,13 @@ def _triangular_solve_lowering(
     transpose = "NO_TRANSPOSE"
   else:
     transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
-  return [hlo.triangular_solve(
-      a, b, ir.BoolAttr.get(left_side),
-      ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
-      hlo.TransposeAttr.get(transpose))]
+  out = hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
+                             ir.BoolAttr.get(lower),
+                             ir.BoolAttr.get(unit_diagonal),
+                             hlo.TransposeAttr.get(transpose))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, out_aval)]
+  return [out]
 
 
 def _triangular_solve_cpu_lower(
@@ -1802,7 +1842,17 @@ def _geqrf_abstract_eval(operand):
   if operand.ndim < 2:
     raise ValueError("Argument to QR decomposition must have ndims >= 2")
   *batch_dims, m, n = operand.shape
-  taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)))
+  if config.sharding_in_types.value:
+    spec = operand.sharding.spec
+    batch_s, ms, ns = spec[:-2], spec[-2], spec[-1]
+    if ms is not None or ns is not None:
+      raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
+                       ' specs. Try marking their specs as None.')
+    taus_s = operand.sharding.with_spec(P(*(*batch_s, None)))
+  else:
+    taus_s = None
+  taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)),
+                        sharding=taus_s)
   return operand, taus
 
 def _geqrf_batching_rule(batched_args, batch_dims):
@@ -2024,13 +2074,23 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices):
       raise ValueError("Argument to QR decomposition must have ndims >= 2")
     *batch_dims, m, n = operand.shape
     k = m if full_matrices else core.min_dim(m, n)
-    q = operand.update(shape=(*batch_dims, m, k))
-    r = operand.update(shape=(*batch_dims, k, n))
-    p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32))
+    if config.sharding_in_types.value:
+      *batch_s, ms, ns = operand.sharding.spec
+      ks = None
+      if ms is not None or ns is not None:
+        raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
+                         ' specs. Try marking their specs as None.')
+      q_s = operand.sharding.with_spec(P(*(*batch_s, ms, ks)))
+      r_s = operand.sharding.with_spec(P(*(*batch_s, ks, ns)))
+      p_s = operand.sharding.with_spec(P(*(*batch_s, ns)))
+    else:
+      q_s, r_s, p_s = None, None, None
+    q = operand.update(shape=(*batch_dims, m, k), sharding=q_s)
+    r = operand.update(shape=(*batch_dims, k, n), sharding=r_s)
+    p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32),
+                       sharding=p_s)
   else:
-    q = operand
-    r = operand
-    p = operand
+    q, r, p = operand, operand, operand
   return (q, r, p) if pivoting else (q, r)
 
 def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
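
Usage-level sketch (not from the commit): per the rule above, QR of a batch of matrices sharded only over the batch dimension keeps the batch sharding on q and r while the m/n/k dimensions stay unsharded. Mesh and axis names are arbitrary assumptions; at least two devices are assumed.

# Sketch: batched QR with the batch dimension sharded, trailing dims replicated.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
a = jax.device_put(jnp.ones((4, 6, 3)), NamedSharding(mesh, P('x', None, None)))

q, r = jax.jit(jnp.linalg.qr)(a)
print(q.shape, r.shape)  # (4, 6, 3) and (4, 3, 3) in reduced mode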
@@ -2136,13 +2196,32 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
         raise ValueError("full_matrices and subset_by_index cannot both be set")
       rank = min(rank, subset_by_index[1] - subset_by_index[0])
 
+    if config.sharding_in_types.value:
+      batch_s = operand.sharding.spec[:-2]
+      ms = operand.sharding.spec[-2]
+      ns = operand.sharding.spec[-1]
+      if ms is not None or ns is not None:
+        raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
+                         ' specs. Try marking their specs as None.')
+      rank_s = None
+      s_sharding = operand.sharding.with_spec(P(*batch_s + (rank_s,)))
+      u_sharding = operand.sharding.with_spec(
+          P(*batch_s + (ms, ms if full_matrices else rank_s)))
+      vt_sharding = operand.sharding.with_spec(
+          P(*batch_s + (ns if full_matrices else rank_s, ns)))
+    else:
+      s_sharding, u_sharding, vt_sharding = None, None, None
+
     s = operand.update(
         shape=batch_dims + (rank,),
         dtype=lax_internal._complex_basetype(operand.dtype),
+        sharding=s_sharding
     )
     if compute_uv:
-      u = operand.update(shape=batch_dims + (m, m if full_matrices else rank))
-      vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n))
+      u = operand.update(shape=batch_dims + (m, m if full_matrices else rank),
+                         sharding=u_sharding)
+      vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n),
+                          sharding=vt_sharding)
       return s, u, vt
     else:
       return s,
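
Usage-level sketch (not from the commit): the abstract-eval rule above keeps the batch sharding on s, u and vt and requires the m/n dimensions to be unsharded, so SVD of matrices sharded only over the batch dimension is the straightforward case. Mesh and axis names are arbitrary assumptions; at least two devices are assumed.

# Sketch: batched SVD with the batch dimension sharded, matrix dims replicated.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
a = jax.device_put(jnp.ones((4, 5, 3)), NamedSharding(mesh, P('x', None, None)))

u, s, vt = jax.jit(lambda x: jnp.linalg.svd(x, full_matrices=False))(a)
print(u.shape, s.shape, vt.shape)  # (4, 5, 3), (4, 3), (4, 3, 3)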
@@ -1886,6 +1886,10 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
   cur_mesh = mesh_lib.get_abstract_mesh()
   if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:  # type: ignore
     return None
+  if (cur_mesh._are_all_axes_explicit and  # type: ignore
+      all(s is None for s in operand.sharding.spec) and
+      all(s is None for s in indices.sharding.spec)):
+    return None
   raise GatherShardingError(
       "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
       " the gather indexing.")
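
Illustration (not part of the commit): a toy model of the new allowance in `_gather_sharding_rule`. Under an explicit mesh, a gather whose operand and indices are both fully replicated (all-None specs) is now allowed to proceed; any other explicitly sharded case still needs an explicit out_sharding. Plain tuples and strings stand in for the real spec and mesh objects.

def gather_allowed_toy(operand_spec, indices_spec, mesh_kind):
  # Fully Auto or Manual meshes were already allowed before this change.
  if mesh_kind in ('auto', 'manual'):
    return True
  # New case: explicit mesh with fully replicated operand and indices.
  if (mesh_kind == 'explicit'
      and all(s is None for s in operand_spec)
      and all(s is None for s in indices_spec)):
    return True
  return False  # the real rule raises GatherShardingError here

assert gather_allowed_toy((None, None), (None,), 'explicit')
assert not gather_allowed_toy(('x', None), (None,), 'explicit')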
@@ -6368,6 +6368,17 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     self.assertTupleEqual(out2.sharding._device_assignment,
                           tuple(mesh2.devices.flat))
 
+  @jtu.with_user_mesh((2, 1), ('x', 'y'))
+  def test_svd(self, mesh):
+    np_inp = np.zeros([128, 128])
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, None)))
+
+    @jax.jit
+    def f(x):
+      return jnp.linalg.norm(x, 2)
+
+    f(arr)  # doesn't crash
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
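
A hypothetical companion test in the same style (not part of the commit), sketching how svd_p could also be exercised with a batch-sharded input. The name test_svd_batched and the expected output spec are assumptions derived from the _svd_abstract_eval rule above, not behavior asserted by the commit.

  # Hypothetical: batch-sharded SVD under the same user-mesh helper.
  @jtu.with_user_mesh((2, 1), ('x', 'y'))
  def test_svd_batched(self, mesh):
    np_inp = np.zeros([4, 16, 16])
    arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None, None)))

    @jax.jit
    def f(x):
      return jnp.linalg.svd(x, full_matrices=False, compute_uv=False)

    s = f(arr)  # doesn't crash
    self.assertEqual(s.sharding.spec, P('x', None))  # assumed, per the rule above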