From 0426388d31fba7c79b313ff9c4a3e875db4f8af0 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 9 Jul 2024 07:32:38 -0700 Subject: [PATCH] Add `sharding` to `convert_element_type_p` primitive. There are 2 reasons for doing this: * Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, then you pay the cost of 2 transfers which is not ideal at all since this path would be critical (when users use `device`) and we should avoid doing extra transfers at all costs. * This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device. Also fixes: https://github.com/google/jax/issues/17422 PiperOrigin-RevId: 650621659 --- jax/_src/dispatch.py | 4 +- jax/_src/export/_export.py | 2 +- jax/_src/internal_test_util/test_harnesses.py | 2 +- jax/_src/interpreters/pxla.py | 42 ++++++----- jax/_src/lax/lax.py | 66 +++++++++++------ jax/_src/numpy/lax_numpy.py | 18 +++-- jax/_src/pallas/mosaic/lowering.py | 3 +- jax/_src/pallas/mosaic_gpu/lowering.py | 2 +- jax/_src/pallas/triton/lowering.py | 2 +- jax/_src/pjit.py | 1 + jax/experimental/jax2tf/jax2tf.py | 2 +- jax/numpy/__init__.pyi | 3 +- tests/pjit_test.py | 72 +++++++++++++++++++ 13 files changed, 165 insertions(+), 54 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index ec7bb81af..10c1947f3 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -220,7 +220,7 @@ class SourceInfo(NamedTuple): eqn_name: str -def jaxpr_shardings( +def get_intermediate_shardings( jaxpr: core.Jaxpr, ) -> Iterator[tuple[Sharding, SourceInfo]]: from jax._src import pjit @@ -246,7 +246,7 @@ def jaxpr_shardings( yield from ((s, source_info) for s in eqn.params['devices'] if isinstance(s, Sharding) and s.memory_kind is not None) for subjaxpr in core.subjaxprs(jaxpr): - yield from jaxpr_shardings(subjaxpr) + yield from get_intermediate_shardings(subjaxpr) def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool: diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 799d764ea..c54a9f4b6 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -883,7 +883,7 @@ def _check_lowering(lowering) -> None: "keepalive", "host_callbacks", "pmap_nreps", "committed", "device_assignment", "jaxpr_debug_info", "shape_poly_state", "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info", - "pgle_profiler"] + "pgle_profiler", "intermediate_shardings"] for compile_arg in lowering.compile_args.keys(): if compile_arg not in allowed_compile_args: raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]") diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 2b22944c1..4bf6d1ceb 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -505,7 +505,7 @@ def _make_convert_element_type_harness(name, "convert_element_type", f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}", lambda arg: (lax.convert_element_type_p.bind( - arg, new_dtype=np.dtype(new_dtype), weak_type=False)), + arg, new_dtype=np.dtype(new_dtype), weak_type=False, sharding=None)), [RandArg(shape, dtype)], shape=shape, dtype=dtype, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 8385d801e..62119859c 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2237,13 +2237,14 @@ def lower_sharding_computation( # Device assignment across all inputs, outputs and shardings inside jaxpr # should be the same. - jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr)) + unique_intermediate_shardings = list(util.stable_unique( + dispatch.get_intermediate_shardings(jaxpr))) backend, device_assignment = _get_and_check_device_assignment( it.chain( ((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)), ((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)), ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) - for js, source_info in util.stable_unique(jaxpr_sharding))), + for js, source_info in unique_intermediate_shardings)), devices_from_context) platforms = lowering_platforms or (backend.platform,) @@ -2254,14 +2255,15 @@ def lower_sharding_computation( devices_from_context or len(device_assignment) > 1 or any(not is_unspecified(i) for i in in_shardings) or - any(not is_unspecified(js) for js, _ in jaxpr_sharding) or + any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or any(not is_unspecified(o) for o in out_shardings)) da_object = _create_da_object(tuple(device_assignment)) all_default_mem_kind = are_all_shardings_default_mem_kind( da_object, - it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding])) + it.chain(in_shardings, out_shardings, + [js for js, _ in unique_intermediate_shardings])) # TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when # JAX puts memory kinds in the types of jaxpr. @@ -2321,7 +2323,8 @@ def lower_sharding_computation( shape_poly_state=shape_poly_state, all_default_mem_kind=all_default_mem_kind, all_args_info=all_args_info, - pgle_profiler=pgle_profiler) + pgle_profiler=pgle_profiler, + intermediate_shardings=[s for s, _ in unique_intermediate_shardings]) def _to_logical_sharding( @@ -2704,20 +2707,21 @@ def _get_out_sharding_from_orig_sharding( return out def maybe_recover_user_shardings( - old_shardings, new_shardings, old_avals, new_avals): + old_shardings, new_shardings, old_avals, new_avals, + intermediate_shardings=None): if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings): return new_shardings - orig_in_s = None - orig_aval = None - for oi, aval in safe_zip(old_shardings, old_avals): - if type(oi) in _orig_out_sharding_handlers: - orig_in_s = oi - orig_aval = aval - break - if orig_in_s is not None: - return _get_out_sharding_from_orig_sharding( - new_shardings, new_avals, orig_in_s, orig_aval) + for oi, o_aval in safe_zip(old_shardings, old_avals): + if oi is not None and type(oi) in _orig_out_sharding_handlers: + return _get_out_sharding_from_orig_sharding( + new_shardings, new_avals, oi, o_aval) + + if intermediate_shardings is not None: + for i in intermediate_shardings: + if i is not None and type(i) in _orig_out_sharding_handlers: + return _get_out_sharding_from_orig_sharding( + new_shardings, new_avals, i, None) return new_shardings @@ -2997,7 +3001,8 @@ class UnloadedMeshExecutable: all_default_mem_kind: bool = True, all_args_info: AllArgsInfo | None = None, compiler_options=None, - pgle_profiler: profiler.PGLEProfiler | None = None + pgle_profiler: profiler.PGLEProfiler | None = None, + intermediate_shardings: Sequence[JSharding] | None = None, ) -> MeshExecutable: if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) @@ -3054,7 +3059,8 @@ class UnloadedMeshExecutable: xla_executable, in_layouts, out_layouts, len(ordered_effects)) out_shardings = maybe_recover_user_shardings( - in_shardings, out_shardings, global_in_avals, global_out_avals) + in_shardings, out_shardings, global_in_avals, global_out_avals, + intermediate_shardings) out_shardings = finalize_out_shardings(out_shardings, da) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 62071350d..53c8f92ec 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -517,14 +517,16 @@ def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: return _convert_element_type(operand, new_dtype, weak_type=False) def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None, - weak_type: bool = False): + weak_type: bool = False, + sharding: Sharding | None = None): if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() if (dtypes.issubdtype(new_dtype, dtypes.extended) or dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)): - return convert_element_type_p.bind(operand, new_dtype=new_dtype, - weak_type=bool(weak_type)) + return convert_element_type_p.bind( + operand, new_dtype=new_dtype, weak_type=bool(weak_type), + sharding=sharding) # Don't canonicalize old_dtype because x64 context might cause # un-canonicalized operands to be passed in. @@ -553,11 +555,13 @@ def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array) and not (isinstance(operand, core.Tracer) and - isinstance(core.get_aval(operand), core.ConcreteArray))): + isinstance(core.get_aval(operand), core.ConcreteArray)) and + (sharding is None or getattr(operand, 'sharding', None) == sharding)): return type_cast(Array, operand) else: - return convert_element_type_p.bind(operand, new_dtype=new_dtype, - weak_type=bool(weak_type)) + return convert_element_type_p.bind( + operand, new_dtype=new_dtype, weak_type=bool(weak_type), + sharding=sharding) def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array: """Elementwise bitcast. @@ -1341,7 +1345,8 @@ def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: dtype = dtypes.canonicalize_dtype(dtype) bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)), broadcasted_iota(np.int32, shape, 1)) - return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False) + return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False, + sharding=None) def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array: """This utility function exists for creating Kronecker delta arrays.""" @@ -1351,8 +1356,9 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array: iotas = [broadcasted_iota(np.uint32, base_shape, i) for i in range(len(base_shape))] eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])] - result = convert_element_type_p.bind(_reduce(operator.and_, eyes), - new_dtype=dtype, weak_type=False) + result = convert_element_type_p.bind( + _reduce(operator.and_, eyes), new_dtype=dtype, weak_type=False, + sharding=None) return broadcast_in_dim(result, shape, axes) def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: @@ -1362,7 +1368,8 @@ def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0), asarray(core.dimension_as_value(offset)).astype(np.int32)), broadcasted_iota(np.int32, shape, 1)) - return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False) + return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False, + sharding=None) def stop_gradient(x: T) -> T: """Stops gradient computation. @@ -2469,10 +2476,12 @@ ad.defjvp_zero(lt_to_p) mlir.register_lowering(lt_to_p, partial(_compare_lower_hlo, "LT", True)) -def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): +def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type, + sharding): return operand.shape -def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type): +def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type, + sharding): if (operand.dtype != new_dtype and ((dtypes.issubdtype(operand.dtype, dtypes.extended) and not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or @@ -2483,10 +2492,12 @@ def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type): f"to {dtype_to_string(new_dtype)}") return new_dtype -def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type): +def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type, + sharding): return weak_type -def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type): +def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type, + sharding): assert ad.is_undefined_primal(operand) old_dtype = operand.aval.dtype old_weak_type = dtypes.is_weakly_typed(operand) @@ -2495,16 +2506,17 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type): elif core.primal_dtype_to_tangent_dtype(old_dtype) == dtypes.float0: return [ad_util.Zero(operand.aval.update(dtype=dtypes.float0, weak_type=False))] else: - return [convert_element_type_p.bind(ct, new_dtype=old_dtype, - weak_type=old_weak_type)] + return [convert_element_type_p.bind( + ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)] -def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type): +def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type, + sharding): if core.primal_dtype_to_tangent_dtype(new_dtype) == dtypes.float0: tangent_aval = core.raise_to_shaped(core.get_aval(tangent)) return ad_util.Zero(tangent_aval.update(dtype=dtypes.float0, weak_type=False)) else: return convert_element_type_p.bind(tangent, new_dtype=new_dtype, - weak_type=weak_type) + weak_type=weak_type, sharding=sharding) def _convert_elt_type_folding_rule(consts, eqn): # We constant-fold convert_element_types applied to constants if those @@ -2544,11 +2556,19 @@ def _convert_elt_type_pp_rule(eqn, context, settings): # don't print new_dtype because the output binder shows it, don't print # weak_type when false params = dict(eqn.params) - del params['new_dtype'] # output binder shows it - if not params['weak_type']: del params['weak_type'] # don't show trivial case + if params['sharding'] is None: + del params['sharding'] # don't show trivial case return core._pp_eqn(eqn.replace(params=params), context, settings) convert_element_type_p = Primitive('convert_element_type') +def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): + operand = core.Primitive.bind(convert_element_type_p, operand, + new_dtype=new_dtype, weak_type=weak_type, + sharding=sharding) + if sharding is not None: + operand = jax.lax.with_sharding_constraint(operand, sharding) + return operand +convert_element_type_p.def_custom_bind(_convert_element_type_bind) convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p)) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, @@ -2560,12 +2580,12 @@ batching.defvectorized(convert_element_type_p) pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule pe.def_trivial_padding(convert_element_type_p) -# TODO(mattjj): un-comment the next line (see #9456) -# core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule +core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule def _real_dtype(dtype): return np.finfo(dtype).dtype -def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type): +def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, + sharding): aval_in, = ctx.avals_in aval_out, = ctx.avals_out if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c80354bf8..b1f493324 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -734,7 +734,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: Args: a: input array axes: optionally specify the permutation using a length-`a.ndim` sequence of integers - ``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e + ``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e. reverses the order of all axes. Returns: @@ -3227,7 +3227,8 @@ deprecations.register("jax-numpy-array-none") @util.implements(np.array, lax_description=_ARRAY_DOC) def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, - order: str | None = "K", ndmin: int = 0) -> Array: + order: str | None = "K", ndmin: int = 0, + *, device: xc.Device | Sharding | None = None) -> Array: if order is not None and order != "K": raise NotImplementedError("Only implemented for order='K'") @@ -3239,10 +3240,12 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, # whenever x is weak, but avoids introducing weak types with something like # array([1, 2, 3]) weak_type = dtype is None and dtypes.is_weakly_typed(object) + sharding = canonicalize_device_to_sharding(device) # Use device_put to avoid a copy for ndarray inputs. if (not copy and isinstance(object, np.ndarray) and - (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim)): + (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and + device is None): # Keep the output uncommitted. return jax.device_put(object) @@ -3326,12 +3329,19 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, raise TypeError(f"Unexpected input type for array: {type(object)}") out_array: Array = lax_internal._convert_element_type( - out, dtype, weak_type=weak_type) + out, dtype, weak_type=weak_type, sharding=sharding) if ndmin > ndim(out_array): out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array))) return out_array +def canonicalize_device_to_sharding(device: xc.Device | Sharding | None + ) -> Sharding | None: + if isinstance(device, xc.Device): + return SingleDeviceSharding(device) + return device + + def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: try: dtypes.dtype(x) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 7fec8285d..416a88d4b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1344,9 +1344,10 @@ def _convert_helper(x, *, to_dtype): raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}") def _convert_element_type_lowering_rule( - ctx: LoweringRuleContext, x, *, new_dtype, weak_type + ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): del weak_type + del sharding out_aval = ctx.avals_out[0] old_dtype = ctx.avals_in[0].dtype out_type = aval_to_ir_type(out_aval) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 5b4db68f2..47597dbcf 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -309,7 +309,7 @@ def _broadcast_in_dim_lowering_rule( @register_lowering_rule(lax.convert_element_type_p) def _convert_element_type_lowering_rule( - ctx: LoweringRuleContext, x, *, new_dtype, weak_type + ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype)) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index b11aaa026..d8eaa8d28 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1526,7 +1526,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value: @register_lowering(lax.convert_element_type_p) def _convert_element_type_lowering_rule( - ctx: LoweringRuleContext, x, *, new_dtype, weak_type + ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): [x_aval] = ctx.avals_in x = _ensure_ir_value(x, x_aval) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 992e8d57e..f9b3925c3 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -636,6 +636,7 @@ def _infer_params_impl( HashableFunction(res_paths, closure=()), IgnoreKey(ji.inline)) _attr_update(flat_fun, in_type, attr_token, attrs_tracked) + out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef, ji.out_layouts_leaves, HashableFunction(out_tree, closure=()), diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index b14883b42..cfed7d7d0 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2007,7 +2007,7 @@ tf_impl[lax.le_to_p] = handle_boolean_args(partial(_total_order_cond, tf.math.le tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky -def _convert_element_type(operand, *, new_dtype, weak_type=False): +def _convert_element_type(operand, *, new_dtype, weak_type=False, sharding=None): old_dtype = operand.dtype.as_numpy_dtype if (dtypes.issubdtype(old_dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 5e4735f98..656fe3827 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -99,7 +99,8 @@ def argwhere( ) -> Array: ... around = round def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True, - order: str | None = ..., ndmin: int = ...) -> Array: ... + order: str | None = ..., ndmin: int = ..., *, + device: _Device | _Sharding | None = None) -> Array: ... def array_equal( a1: ArrayLike, a2: ArrayLike, equal_nan: builtins.bool = ... ) -> Array: ... diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9e1ca442a..3b56a4f70 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -37,6 +37,7 @@ from jax import dtypes from jax import stages from jax.errors import JAXTypeError from jax import lax +from jax._src.lax import lax as lax_internal from jax.lax import with_sharding_constraint from jax._src import prng from jax.sharding import PartitionSpec as P, Mesh @@ -4240,6 +4241,77 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertArraysEqual(out, np_inp) self.assertEqual(out.sharding, s2) + def test_convert_element_type_sharding(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + inp = np.arange(16).reshape(8, 2) + + out = lax_internal._convert_element_type( + inp, new_dtype=np.float32, weak_type=False, sharding=s) + self.assertArraysEqual(out, inp.astype('float32')) + self.assertEqual(out.dtype, np.float32) + self.assertEqual(out.sharding, s) + + def test_jnp_array_sharding(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + inp = np.arange(16).reshape(8, 2) + + out = jnp.array(inp, device=s) + self.assertArraysEqual(out, inp) + self.assertEqual(out.sharding, s) + + def test_jnp_array_inside_jit_sharding(self): + mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + inp = np.arange(16).reshape(8, 2) + + @jax.jit + def f(): + return jnp.array(inp, dtype=np.float32, device=s) + + out = f() + print(f.trace().jaxpr) + self.assertArraysEqual(out, inp.astype('float32')) + self.assertEqual(out.sharding, s) + self.assertEqual(out.dtype, np.float32) + + @jax.jit + def g(x): + return jnp.array(x, dtype=np.float32, device=s) + + out2 = g(inp) + self.assertArraysEqual(out2, inp.astype('float32')) + self.assertEqual(out2.sharding, s) + self.assertEqual(out2.dtype, np.float32) + + def test_jnp_array_reshard_error(self): + if jax.device_count() < 2: + self.skipTest('Requires >=2 devices') + arr = jax.device_put(np.arange(8), jax.devices()[0]) + with self.assertRaisesRegex(ValueError, "Received incompatible devices.*"): + jnp.array(arr, device=jax.devices()[1]) + + def test_jnp_array_sharded_array_no_op(self): + inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(inp, jax.devices()[0]) + + out = lax_internal._convert_element_type( + arr, sharding=SingleDeviceSharding(jax.devices()[0])) + self.assertArraysEqual(out, inp) + self.assertEqual(out.unsafe_buffer_pointer(), arr.unsafe_buffer_pointer()) + + def test_wsc_named_sharding_nullary(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + s = NamedSharding(mesh, P()) + + @jax.jit + def f(): + return jax.lax.with_sharding_constraint(jnp.arange(8), s) + + out = f() + self.assertEqual(out.sharding, s) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)")