Add sharding to convert_element_type_p primitive.

There are two reasons for doing this:

* Avoid an extra allocation by producing the output directly on the sharding the user specified. If you instead `device_put` the output of `_convert_element_type`, you pay for two transfers. This path is performance-critical (it runs whenever users pass `device`), so extra transfers should be avoided. See the sketch after this list.

* It streamlines adding `device` arguments to all `jnp` functions: one place (`_convert_element_type`) handles the logic of putting the result on the right device.
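
As a rough sketch of what this enables (mirroring the tests added in this
change; the 2x2 mesh is illustrative and assumes at least 4 devices):

    import jax, jax.numpy as jnp
    import numpy as np
    from jax._src import test_util as jtu
    from jax._src.lax import lax as lax_internal
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    inp = np.arange(16).reshape(8, 2)

    # The cast lands directly on sharding `s` (one transfer) instead of
    # cast-then-device_put (two transfers).
    out = lax_internal._convert_element_type(
        inp, new_dtype=np.float32, weak_type=False, sharding=s)
    assert out.sharding == s

    # jnp.array gains a `device` argument that routes through the same path.
    out2 = jnp.array(inp, device=s)
    assert out2.sharding == s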

Also fixes: https://github.com/google/jax/issues/17422

PiperOrigin-RevId: 650621659
Author: Yash Katariya (2024-07-09 07:32:38 -07:00); committed by jax authors
Commit 0426388d31 (parent 4b260cdc6b)
13 changed files with 165 additions and 54 deletions


@@ -220,7 +220,7 @@ class SourceInfo(NamedTuple):
   eqn_name: str


-def jaxpr_shardings(
+def get_intermediate_shardings(
     jaxpr: core.Jaxpr,
 ) -> Iterator[tuple[Sharding, SourceInfo]]:
   from jax._src import pjit
@@ -246,7 +246,7 @@ def jaxpr_shardings(
       yield from ((s, source_info) for s in eqn.params['devices']
                   if isinstance(s, Sharding) and s.memory_kind is not None)
   for subjaxpr in core.subjaxprs(jaxpr):
-    yield from jaxpr_shardings(subjaxpr)
+    yield from get_intermediate_shardings(subjaxpr)


def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:


@@ -883,7 +883,7 @@ def _check_lowering(lowering) -> None:
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
       "device_assignment", "jaxpr_debug_info", "shape_poly_state",
       "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
-      "pgle_profiler"]
+      "pgle_profiler", "intermediate_shardings"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")


@@ -505,7 +505,7 @@ def _make_convert_element_type_harness(name,
       "convert_element_type",
       f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}",
       lambda arg: (lax.convert_element_type_p.bind(
-          arg, new_dtype=np.dtype(new_dtype), weak_type=False)),
+          arg, new_dtype=np.dtype(new_dtype), weak_type=False, sharding=None)),
       [RandArg(shape, dtype)],
       shape=shape,
       dtype=dtype,


@@ -2237,13 +2237,14 @@ def lower_sharding_computation(
   # Device assignment across all inputs, outputs and shardings inside jaxpr
   # should be the same.
-  jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
+  unique_intermediate_shardings = list(util.stable_unique(
+      dispatch.get_intermediate_shardings(jaxpr)))
   backend, device_assignment = _get_and_check_device_assignment(
       it.chain(
           ((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)),
           ((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)),
           ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
-            for js, source_info in util.stable_unique(jaxpr_sharding))),
+            for js, source_info in unique_intermediate_shardings)),
       devices_from_context)

   platforms = lowering_platforms or (backend.platform,)
@@ -2254,14 +2255,15 @@ def lower_sharding_computation(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
-      any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
+      any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
       any(not is_unspecified(o) for o in out_shardings))

   da_object = _create_da_object(tuple(device_assignment))

   all_default_mem_kind = are_all_shardings_default_mem_kind(
       da_object,
-      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding]))
+      it.chain(in_shardings, out_shardings,
+               [js for js, _ in unique_intermediate_shardings]))

   # TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when
   # JAX puts memory kinds in the types of jaxpr.
@@ -2321,7 +2323,8 @@ def lower_sharding_computation(
       shape_poly_state=shape_poly_state,
       all_default_mem_kind=all_default_mem_kind,
       all_args_info=all_args_info,
-      pgle_profiler=pgle_profiler)
+      pgle_profiler=pgle_profiler,
+      intermediate_shardings=[s for s, _ in unique_intermediate_shardings])


def _to_logical_sharding(
@@ -2704,20 +2707,21 @@ def _get_out_sharding_from_orig_sharding(
   return out


def maybe_recover_user_shardings(
-    old_shardings, new_shardings, old_avals, new_avals):
+    old_shardings, new_shardings, old_avals, new_avals,
+    intermediate_shardings=None):
   if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
     return new_shardings

-  orig_in_s = None
-  orig_aval = None
-  for oi, aval in safe_zip(old_shardings, old_avals):
-    if type(oi) in _orig_out_sharding_handlers:
-      orig_in_s = oi
-      orig_aval = aval
-      break
-  if orig_in_s is not None:
-    return _get_out_sharding_from_orig_sharding(
-        new_shardings, new_avals, orig_in_s, orig_aval)
+  for oi, o_aval in safe_zip(old_shardings, old_avals):
+    if oi is not None and type(oi) in _orig_out_sharding_handlers:
+      return _get_out_sharding_from_orig_sharding(
+          new_shardings, new_avals, oi, o_aval)
+
+  if intermediate_shardings is not None:
+    for i in intermediate_shardings:
+      if i is not None and type(i) in _orig_out_sharding_handlers:
+        return _get_out_sharding_from_orig_sharding(
+            new_shardings, new_avals, i, None)

   return new_shardings
@@ -2997,7 +3001,8 @@ class UnloadedMeshExecutable:
       all_default_mem_kind: bool = True,
       all_args_info: AllArgsInfo | None = None,
       compiler_options=None,
-      pgle_profiler: profiler.PGLEProfiler | None = None
+      pgle_profiler: profiler.PGLEProfiler | None = None,
+      intermediate_shardings: Sequence[JSharding] | None = None,
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
@@ -3054,7 +3059,8 @@ class UnloadedMeshExecutable:
         xla_executable, in_layouts, out_layouts, len(ordered_effects))

     out_shardings = maybe_recover_user_shardings(
-        in_shardings, out_shardings, global_in_avals, global_out_avals)
+        in_shardings, out_shardings, global_in_avals, global_out_avals,
+        intermediate_shardings)

     out_shardings = finalize_out_shardings(out_shardings, da)


@@ -517,14 +517,16 @@ def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
   return _convert_element_type(operand, new_dtype, weak_type=False)

def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None,
-                          weak_type: bool = False):
+                          weak_type: bool = False,
+                          sharding: Sharding | None = None):
   if hasattr(operand, '__jax_array__'):
     operand = operand.__jax_array__()

   if (dtypes.issubdtype(new_dtype, dtypes.extended) or
       dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)):
-    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
-                                       weak_type=bool(weak_type))
+    return convert_element_type_p.bind(
+        operand, new_dtype=new_dtype, weak_type=bool(weak_type),
+        sharding=sharding)

   # Don't canonicalize old_dtype because x64 context might cause
   # un-canonicalized operands to be passed in.
@@ -553,11 +555,13 @@ def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None,
   if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
       isinstance(operand, Array) and
       not (isinstance(operand, core.Tracer) and
-           isinstance(core.get_aval(operand), core.ConcreteArray))):
+           isinstance(core.get_aval(operand), core.ConcreteArray)) and
+      (sharding is None or getattr(operand, 'sharding', None) == sharding)):
     return type_cast(Array, operand)
   else:
-    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
-                                       weak_type=bool(weak_type))
+    return convert_element_type_p.bind(
+        operand, new_dtype=new_dtype, weak_type=bool(weak_type),
+        sharding=sharding)

def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
  """Elementwise bitcast.
@@ -1341,7 +1345,8 @@ def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
   dtype = dtypes.canonicalize_dtype(dtype)
   bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
                 broadcasted_iota(np.int32, shape, 1))
-  return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)
+  return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False,
+                                     sharding=None)

def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
  """This utility function exists for creating Kronecker delta arrays."""
@@ -1351,8 +1356,9 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
   iotas = [broadcasted_iota(np.uint32, base_shape, i)
            for i in range(len(base_shape))]
   eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
-  result = convert_element_type_p.bind(_reduce(operator.and_, eyes),
-                                       new_dtype=dtype, weak_type=False)
+  result = convert_element_type_p.bind(
+      _reduce(operator.and_, eyes), new_dtype=dtype, weak_type=False,
+      sharding=None)
   return broadcast_in_dim(result, shape, axes)

def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
@@ -1362,7 +1368,8 @@ def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
   bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0),
                     asarray(core.dimension_as_value(offset)).astype(np.int32)),
                 broadcasted_iota(np.int32, shape, 1))
-  return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False)
+  return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False,
+                                     sharding=None)

def stop_gradient(x: T) -> T:
  """Stops gradient computation.
@@ -2469,10 +2476,12 @@ ad.defjvp_zero(lt_to_p)
mlir.register_lowering(lt_to_p, partial(_compare_lower_hlo, "LT", True))

-def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
+def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type,
+                                     sharding):
   return operand.shape

-def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type):
+def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,
+                                     sharding):
   if (operand.dtype != new_dtype and
       ((dtypes.issubdtype(operand.dtype, dtypes.extended) and
         not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or
@@ -2483,10 +2492,12 @@ def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type):
                      f"to {dtype_to_string(new_dtype)}")
   return new_dtype

-def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type):
+def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type,
+                                         sharding):
   return weak_type

-def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type):
+def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type,
+                                         sharding):
   assert ad.is_undefined_primal(operand)
   old_dtype = operand.aval.dtype
   old_weak_type = dtypes.is_weakly_typed(operand)
@@ -2495,16 +2506,17 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type):
   elif core.primal_dtype_to_tangent_dtype(old_dtype) == dtypes.float0:
     return [ad_util.Zero(operand.aval.update(dtype=dtypes.float0, weak_type=False))]
   else:
-    return [convert_element_type_p.bind(ct, new_dtype=old_dtype,
-                                        weak_type=old_weak_type)]
+    return [convert_element_type_p.bind(
+        ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)]

-def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
+def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type,
+                                   sharding):
   if core.primal_dtype_to_tangent_dtype(new_dtype) == dtypes.float0:
     tangent_aval = core.raise_to_shaped(core.get_aval(tangent))
     return ad_util.Zero(tangent_aval.update(dtype=dtypes.float0, weak_type=False))
   else:
     return convert_element_type_p.bind(tangent, new_dtype=new_dtype,
-                                       weak_type=weak_type)
+                                       weak_type=weak_type, sharding=sharding)

def _convert_elt_type_folding_rule(consts, eqn):
  # We constant-fold convert_element_types applied to constants if those
@@ -2544,11 +2556,19 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
   # don't print new_dtype because the output binder shows it, don't print
   # weak_type when false
   params = dict(eqn.params)
   del params['new_dtype']  # output binder shows it
   if not params['weak_type']: del params['weak_type']  # don't show trivial case
+  if params['sharding'] is None:
+    del params['sharding']  # don't show trivial case
   return core._pp_eqn(eqn.replace(params=params), context, settings)

convert_element_type_p = Primitive('convert_element_type')

+def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding):
+  operand = core.Primitive.bind(convert_element_type_p, operand,
+                                new_dtype=new_dtype, weak_type=weak_type,
+                                sharding=sharding)
+  if sharding is not None:
+    operand = jax.lax.with_sharding_constraint(operand, sharding)
+  return operand
+convert_element_type_p.def_custom_bind(_convert_element_type_bind)
+
convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval(
    partial(standard_abstract_eval, convert_element_type_p,
@@ -2560,12 +2580,12 @@ batching.defvectorized(convert_element_type_p)
pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
pe.def_trivial_padding(convert_element_type_p)
-# TODO(mattjj): un-comment the next line (see #9456)
-# core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule
+core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule

def _real_dtype(dtype): return np.finfo(dtype).dtype

-def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
+def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
+                                sharding):
   aval_in, = ctx.avals_in
   aval_out, = ctx.avals_out
   if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
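
A note on the `def_custom_bind` above: after the primitive is bound, a
non-None sharding is re-asserted with `with_sharding_constraint`, which keeps
the requested sharding visible to the partitioner when tracing under `jit`.
A minimal sketch of the observable effect (the 1-D mesh is illustrative and
assumes at least 2 devices):

    import jax
    import numpy as np
    from jax._src import test_util as jtu
    from jax._src.lax import lax as lax_internal
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jtu.create_global_mesh((2,), ('x',))
    s = NamedSharding(mesh, P('x'))

    @jax.jit
    def f(x):
      # Binding convert_element_type_p with sharding=s also applies
      # with_sharding_constraint, so the constraint survives tracing.
      return lax_internal._convert_element_type(
          x, np.float32, weak_type=False, sharding=s)

    print(jax.make_jaxpr(f)(np.arange(8)))  # carries the sharding constraint
    assert f(np.arange(8)).sharding == s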


@@ -734,7 +734,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
   Args:
     a: input array
     axes: optionally specify the permutation using a length-`a.ndim` sequence of integers
-      ``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e
+      ``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e.
       reverses the order of all axes.

   Returns:
@@ -3227,7 +3227,8 @@ deprecations.register("jax-numpy-array-none")

@util.implements(np.array, lax_description=_ARRAY_DOC)
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
-          order: str | None = "K", ndmin: int = 0) -> Array:
+          order: str | None = "K", ndmin: int = 0,
+          *, device: xc.Device | Sharding | None = None) -> Array:
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
@@ -3239,10 +3240,12 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
   # whenever x is weak, but avoids introducing weak types with something like
   # array([1, 2, 3])
   weak_type = dtype is None and dtypes.is_weakly_typed(object)
+  sharding = canonicalize_device_to_sharding(device)

   # Use device_put to avoid a copy for ndarray inputs.
   if (not copy and isinstance(object, np.ndarray) and
-      (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim)):
+      (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and
+      device is None):
     # Keep the output uncommitted.
     return jax.device_put(object)
@@ -3326,12 +3329,19 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
     raise TypeError(f"Unexpected input type for array: {type(object)}")

   out_array: Array = lax_internal._convert_element_type(
-      out, dtype, weak_type=weak_type)
+      out, dtype, weak_type=weak_type, sharding=sharding)
   if ndmin > ndim(out_array):
     out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
   return out_array

+def canonicalize_device_to_sharding(device: xc.Device | Sharding | None
+                                    ) -> Sharding | None:
+  if isinstance(device, xc.Device):
+    return SingleDeviceSharding(device)
+  return device
+
def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
  try:
    dtypes.dtype(x)
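
The new `canonicalize_device_to_sharding` helper has a simple contract: a
concrete `Device` is wrapped in a `SingleDeviceSharding`, while a `Sharding`
(or None) passes through unchanged. A quick sketch of that contract (assuming
the helper is imported from the lax_numpy module this diff defines it in):

    import jax
    from jax._src.numpy.lax_numpy import canonicalize_device_to_sharding
    from jax.sharding import SingleDeviceSharding

    d = jax.devices()[0]
    assert canonicalize_device_to_sharding(d) == SingleDeviceSharding(d)
    s = SingleDeviceSharding(d)
    assert canonicalize_device_to_sharding(s) is s        # Shardings pass through
    assert canonicalize_device_to_sharding(None) is None  # as does None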


@@ -1344,9 +1344,10 @@ def _convert_helper(x, *, to_dtype):
   raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}")

def _convert_element_type_lowering_rule(
-    ctx: LoweringRuleContext, x, *, new_dtype, weak_type
+    ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
):
   del weak_type
+  del sharding
   out_aval = ctx.avals_out[0]
   old_dtype = ctx.avals_in[0].dtype
   out_type = aval_to_ir_type(out_aval)


@@ -309,7 +309,7 @@ def _broadcast_in_dim_lowering_rule(

@register_lowering_rule(lax.convert_element_type_p)
def _convert_element_type_lowering_rule(
-    ctx: LoweringRuleContext, x, *, new_dtype, weak_type
+    ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
):
   return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype))


@@ -1526,7 +1526,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:

@register_lowering(lax.convert_element_type_p)
def _convert_element_type_lowering_rule(
-    ctx: LoweringRuleContext, x, *, new_dtype, weak_type
+    ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
):
   [x_aval] = ctx.avals_in
   x = _ensure_ir_value(x, x_aval)


@@ -636,6 +636,7 @@ def _infer_params_impl(
       HashableFunction(res_paths, closure=()),
       IgnoreKey(ji.inline))
   _attr_update(flat_fun, in_type, attr_token, attrs_tracked)
+
   out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
       out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
       ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),


@@ -2007,7 +2007,7 @@ tf_impl[lax.le_to_p] = handle_boolean_args(partial(_total_order_cond, tf.math.le
tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky

-def _convert_element_type(operand, *, new_dtype, weak_type=False):
+def _convert_element_type(operand, *, new_dtype, weak_type=False, sharding=None):
   old_dtype = operand.dtype.as_numpy_dtype
   if (dtypes.issubdtype(old_dtype, np.complexfloating) and
       not dtypes.issubdtype(new_dtype, np.complexfloating)):


@@ -99,7 +99,8 @@ def argwhere(
) -> Array: ...

around = round

def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True,
-          order: str | None = ..., ndmin: int = ...) -> Array: ...
+          order: str | None = ..., ndmin: int = ..., *,
+          device: _Device | _Sharding | None = None) -> Array: ...

def array_equal(
   a1: ArrayLike, a2: ArrayLike, equal_nan: builtins.bool = ...
) -> Array: ...


@@ -37,6 +37,7 @@ from jax import dtypes
 from jax import stages
 from jax.errors import JAXTypeError
 from jax import lax
+from jax._src.lax import lax as lax_internal
 from jax.lax import with_sharding_constraint
 from jax._src import prng
 from jax.sharding import PartitionSpec as P, Mesh
@@ -4240,6 +4241,77 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp)
     self.assertEqual(out.sharding, s2)

+  def test_convert_element_type_sharding(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = np.arange(16).reshape(8, 2)
+
+    out = lax_internal._convert_element_type(
+        inp, new_dtype=np.float32, weak_type=False, sharding=s)
+    self.assertArraysEqual(out, inp.astype('float32'))
+    self.assertEqual(out.dtype, np.float32)
+    self.assertEqual(out.sharding, s)
+
+  def test_jnp_array_sharding(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = np.arange(16).reshape(8, 2)
+
+    out = jnp.array(inp, device=s)
+    self.assertArraysEqual(out, inp)
+    self.assertEqual(out.sharding, s)
+
+  def test_jnp_array_inside_jit_sharding(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = np.arange(16).reshape(8, 2)
+
+    @jax.jit
+    def f():
+      return jnp.array(inp, dtype=np.float32, device=s)
+
+    out = f()
+    print(f.trace().jaxpr)
+    self.assertArraysEqual(out, inp.astype('float32'))
+    self.assertEqual(out.sharding, s)
+    self.assertEqual(out.dtype, np.float32)
+
+    @jax.jit
+    def g(x):
+      return jnp.array(x, dtype=np.float32, device=s)
+
+    out2 = g(inp)
+    self.assertArraysEqual(out2, inp.astype('float32'))
+    self.assertEqual(out2.sharding, s)
+    self.assertEqual(out2.dtype, np.float32)
+
+  def test_jnp_array_reshard_error(self):
+    if jax.device_count() < 2:
+      self.skipTest('Requires >=2 devices')
+    arr = jax.device_put(np.arange(8), jax.devices()[0])
+    with self.assertRaisesRegex(ValueError, "Received incompatible devices.*"):
+      jnp.array(arr, device=jax.devices()[1])
+
+  def test_jnp_array_sharded_array_no_op(self):
+    inp = np.arange(16).reshape(8, 2)
+    arr = jax.device_put(inp, jax.devices()[0])
+
+    out = lax_internal._convert_element_type(
+        arr, sharding=SingleDeviceSharding(jax.devices()[0]))
+    self.assertArraysEqual(out, inp)
+    self.assertEqual(out.unsafe_buffer_pointer(), arr.unsafe_buffer_pointer())
+
+  def test_wsc_named_sharding_nullary(self):
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    s = NamedSharding(mesh, P())
+
+    @jax.jit
+    def f():
+      return jax.lax.with_sharding_constraint(jnp.arange(8), s)
+
+    out = f()
+    self.assertEqual(out.sharding, s)
+

def spec_regex(s):
  return str(s).replace(r"(", r"\(").replace(r")", r"\)")