Add sharding to convert_element_type_p primitive.
There are two reasons for doing this:

* Avoid an extra allocation by creating the output directly with the
  sharding the user specified. If you instead device_put the output of
  `_convert_element_type`, you pay for two transfers. This path is
  performance-critical whenever users pass `device`, so extra transfers
  should be avoided at all costs.

* It streamlines adding `device` arguments to all `jnp` functions: one
  place (`_convert_element_type`) will handle the logic of putting the
  output on the right device.

Also fixes: https://github.com/google/jax/issues/17422

PiperOrigin-RevId: 650621659
commit 0426388d31
parent 4b260cdc6b
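
Example of what this enables (a minimal sketch based on the tests added in
this commit; the 2x2 mesh assumes four local devices):

    import numpy as np
    import jax.numpy as jnp
    from jax._src import test_util as jtu
    from jax.sharding import NamedSharding, PartitionSpec as P

    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    s = NamedSharding(mesh, P('x', 'y'))
    inp = np.arange(16).reshape(8, 2)

    # The output is created on sharding `s` directly, with no second
    # transfer through device_put.
    out = jnp.array(inp, dtype=np.float32, device=s)
    assert out.sharding == s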
@@ -220,7 +220,7 @@ class SourceInfo(NamedTuple):
   eqn_name: str


-def jaxpr_shardings(
+def get_intermediate_shardings(
     jaxpr: core.Jaxpr,
 ) -> Iterator[tuple[Sharding, SourceInfo]]:
   from jax._src import pjit

@@ -246,7 +246,7 @@ def jaxpr_shardings(
   yield from ((s, source_info) for s in eqn.params['devices']
               if isinstance(s, Sharding) and s.memory_kind is not None)
   for subjaxpr in core.subjaxprs(jaxpr):
-    yield from jaxpr_shardings(subjaxpr)
+    yield from get_intermediate_shardings(subjaxpr)


 def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
@@ -883,7 +883,7 @@ def _check_lowering(lowering) -> None:
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
       "device_assignment", "jaxpr_debug_info", "shape_poly_state",
       "all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
-      "pgle_profiler"]
+      "pgle_profiler", "intermediate_shardings"]
   for compile_arg in lowering.compile_args.keys():
     if compile_arg not in allowed_compile_args:
       raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
@@ -505,7 +505,7 @@ def _make_convert_element_type_harness(name,
       "convert_element_type",
       f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}",
       lambda arg: (lax.convert_element_type_p.bind(
-          arg, new_dtype=np.dtype(new_dtype), weak_type=False)),
+          arg, new_dtype=np.dtype(new_dtype), weak_type=False, sharding=None)),
       [RandArg(shape, dtype)],
       shape=shape,
       dtype=dtype,
@@ -2237,13 +2237,14 @@ def lower_sharding_computation(

   # Device assignment across all inputs, outputs and shardings inside jaxpr
   # should be the same.
-  jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
+  unique_intermediate_shardings = list(util.stable_unique(
+      dispatch.get_intermediate_shardings(jaxpr)))
   backend, device_assignment = _get_and_check_device_assignment(
       it.chain(
           ((i, MismatchType.ARG_SHARDING, None) for i in util.stable_unique(in_shardings)),
           ((o, MismatchType.OUT_SHARDING, None) for o in util.stable_unique(out_shardings)),
           ((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
-           for js, source_info in util.stable_unique(jaxpr_sharding))),
+           for js, source_info in unique_intermediate_shardings)),
       devices_from_context)

   platforms = lowering_platforms or (backend.platform,)
@@ -2254,14 +2255,15 @@ def lower_sharding_computation(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not is_unspecified(i) for i in in_shardings) or
-      any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
+      any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
       any(not is_unspecified(o) for o in out_shardings))

   da_object = _create_da_object(tuple(device_assignment))

   all_default_mem_kind = are_all_shardings_default_mem_kind(
       da_object,
-      it.chain(in_shardings, out_shardings, [js for js, _ in jaxpr_sharding]))
+      it.chain(in_shardings, out_shardings,
+               [js for js, _ in unique_intermediate_shardings]))

   # TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when
   # JAX puts memory kinds in the types of jaxpr.
@@ -2321,7 +2323,8 @@ def lower_sharding_computation(
           shape_poly_state=shape_poly_state,
           all_default_mem_kind=all_default_mem_kind,
           all_args_info=all_args_info,
-          pgle_profiler=pgle_profiler)
+          pgle_profiler=pgle_profiler,
+          intermediate_shardings=[s for s, _ in unique_intermediate_shardings])


 def _to_logical_sharding(
@@ -2704,20 +2707,21 @@ def _get_out_sharding_from_orig_sharding(
   return out

 def maybe_recover_user_shardings(
-    old_shardings, new_shardings, old_avals, new_avals):
+    old_shardings, new_shardings, old_avals, new_avals,
+    intermediate_shardings=None):
   if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
     return new_shardings

-  orig_in_s = None
-  orig_aval = None
-  for oi, aval in safe_zip(old_shardings, old_avals):
-    if type(oi) in _orig_out_sharding_handlers:
-      orig_in_s = oi
-      orig_aval = aval
-      break
-  if orig_in_s is not None:
-    return _get_out_sharding_from_orig_sharding(
-        new_shardings, new_avals, orig_in_s, orig_aval)
+  for oi, o_aval in safe_zip(old_shardings, old_avals):
+    if oi is not None and type(oi) in _orig_out_sharding_handlers:
+      return _get_out_sharding_from_orig_sharding(
+          new_shardings, new_avals, oi, o_aval)
+
+  if intermediate_shardings is not None:
+    for i in intermediate_shardings:
+      if i is not None and type(i) in _orig_out_sharding_handlers:
+        return _get_out_sharding_from_orig_sharding(
+            new_shardings, new_avals, i, None)

   return new_shardings
@@ -2997,7 +3001,8 @@ class UnloadedMeshExecutable:
       all_default_mem_kind: bool = True,
       all_args_info: AllArgsInfo | None = None,
       compiler_options=None,
-      pgle_profiler: profiler.PGLEProfiler | None = None
+      pgle_profiler: profiler.PGLEProfiler | None = None,
+      intermediate_shardings: Sequence[JSharding] | None = None,
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
@@ -3054,7 +3059,8 @@ class UnloadedMeshExecutable:
         xla_executable, in_layouts, out_layouts, len(ordered_effects))

     out_shardings = maybe_recover_user_shardings(
-        in_shardings, out_shardings, global_in_avals, global_out_avals)
+        in_shardings, out_shardings, global_in_avals, global_out_avals,
+        intermediate_shardings)

     out_shardings = finalize_out_shardings(out_shardings, da)
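
The recovery order in maybe_recover_user_shardings matters: user-specified
shardings on the arguments are preferred, and only if none is found do the
intermediate shardings (e.g. one attached via convert_element_type's new
`sharding` param) get a chance. A toy sketch of that priority, not the JAX
internals (which also filter by `_orig_out_sharding_handlers`):

    def first_recoverable(arg_shardings, intermediate_shardings):
      # Prefer shardings the user attached to the arguments...
      for s in arg_shardings:
        if s is not None:
          return s
      # ...then fall back to shardings recorded inside the computation.
      for s in intermediate_shardings or []:
        if s is not None:
          return s
      return None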
@@ -517,14 +517,16 @@ def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
   return _convert_element_type(operand, new_dtype, weak_type=False)

 def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None,
-                          weak_type: bool = False):
+                          weak_type: bool = False,
+                          sharding: Sharding | None = None):
   if hasattr(operand, '__jax_array__'):
     operand = operand.__jax_array__()

   if (dtypes.issubdtype(new_dtype, dtypes.extended) or
       dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)):
-    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
-                                       weak_type=bool(weak_type))
+    return convert_element_type_p.bind(
+        operand, new_dtype=new_dtype, weak_type=bool(weak_type),
+        sharding=sharding)

   # Don't canonicalize old_dtype because x64 context might cause
   # un-canonicalized operands to be passed in.
@@ -553,11 +555,13 @@ def _convert_element_type(operand: ArrayLike, new_dtype: DTypeLike | None = None,
   if ((old_dtype, old_weak_type) == (new_dtype, weak_type) and
       isinstance(operand, Array) and
       not (isinstance(operand, core.Tracer) and
-           isinstance(core.get_aval(operand), core.ConcreteArray))):
+           isinstance(core.get_aval(operand), core.ConcreteArray)) and
+      (sharding is None or getattr(operand, 'sharding', None) == sharding)):
     return type_cast(Array, operand)
   else:
-    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
-                                       weak_type=bool(weak_type))
+    return convert_element_type_p.bind(
+        operand, new_dtype=new_dtype, weak_type=bool(weak_type),
+        sharding=sharding)

 def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
   """Elementwise bitcast.
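
Note that the no-op fast path above now also checks the requested sharding:
a conversion that changes neither dtype nor sharding returns the operand
unchanged. A sketch based on test_jnp_array_sharded_array_no_op added in
this commit:

    import numpy as np
    import jax
    from jax.sharding import SingleDeviceSharding
    from jax._src.lax import lax as lax_internal

    inp = np.arange(16).reshape(8, 2)
    arr = jax.device_put(inp, jax.devices()[0])

    # Same dtype and an already-matching sharding: the input buffer is
    # returned as-is, no copy is made.
    out = lax_internal._convert_element_type(
        arr, sharding=SingleDeviceSharding(jax.devices()[0]))
    assert out.unsafe_buffer_pointer() == arr.unsafe_buffer_pointer()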
@@ -1341,7 +1345,8 @@ def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
   dtype = dtypes.canonicalize_dtype(dtype)
   bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
                 broadcasted_iota(np.int32, shape, 1))
-  return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)
+  return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False,
+                                     sharding=None)

 def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
   """This utility function exists for creating Kronecker delta arrays."""
@@ -1351,8 +1356,9 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
   iotas = [broadcasted_iota(np.uint32, base_shape, i)
            for i in range(len(base_shape))]
   eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
-  result = convert_element_type_p.bind(_reduce(operator.and_, eyes),
-                                       new_dtype=dtype, weak_type=False)
+  result = convert_element_type_p.bind(
+      _reduce(operator.and_, eyes), new_dtype=dtype, weak_type=False,
+      sharding=None)
   return broadcast_in_dim(result, shape, axes)

 def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
@@ -1362,7 +1368,8 @@ def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
   bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0),
                     asarray(core.dimension_as_value(offset)).astype(np.int32)),
                 broadcasted_iota(np.int32, shape, 1))
-  return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False)
+  return convert_element_type_p.bind(bool_tri, new_dtype=dtype, weak_type=False,
+                                     sharding=None)

 def stop_gradient(x: T) -> T:
   """Stops gradient computation.
@@ -2469,10 +2476,12 @@ ad.defjvp_zero(lt_to_p)
 mlir.register_lowering(lt_to_p, partial(_compare_lower_hlo, "LT", True))


-def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
+def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type,
+                                     sharding):
   return operand.shape

-def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type):
+def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type,
+                                     sharding):
   if (operand.dtype != new_dtype and
       ((dtypes.issubdtype(operand.dtype, dtypes.extended) and
         not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or
@@ -2483,10 +2492,12 @@ def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type):
           f"to {dtype_to_string(new_dtype)}")
   return new_dtype

-def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type):
+def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type,
+                                         sharding):
   return weak_type

-def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type):
+def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type,
+                                         sharding):
   assert ad.is_undefined_primal(operand)
   old_dtype = operand.aval.dtype
   old_weak_type = dtypes.is_weakly_typed(operand)
@@ -2495,16 +2506,17 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type):
   elif core.primal_dtype_to_tangent_dtype(old_dtype) == dtypes.float0:
     return [ad_util.Zero(operand.aval.update(dtype=dtypes.float0, weak_type=False))]
   else:
-    return [convert_element_type_p.bind(ct, new_dtype=old_dtype,
-                                        weak_type=old_weak_type)]
+    return [convert_element_type_p.bind(
+        ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)]

-def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type):
+def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type,
+                                   sharding):
   if core.primal_dtype_to_tangent_dtype(new_dtype) == dtypes.float0:
     tangent_aval = core.raise_to_shaped(core.get_aval(tangent))
     return ad_util.Zero(tangent_aval.update(dtype=dtypes.float0, weak_type=False))
   else:
     return convert_element_type_p.bind(tangent, new_dtype=new_dtype,
-                                       weak_type=weak_type)
+                                       weak_type=weak_type, sharding=sharding)

 def _convert_elt_type_folding_rule(consts, eqn):
   # We constant-fold convert_element_types applied to constants if those
@@ -2544,11 +2556,19 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
   # don't print new_dtype because the output binder shows it, don't print
   # weak_type when false
   params = dict(eqn.params)
   del params['new_dtype']  # output binder shows it
   if not params['weak_type']: del params['weak_type']  # don't show trivial case
+  if params['sharding'] is None:
+    del params['sharding']  # don't show trivial case
   return core._pp_eqn(eqn.replace(params=params), context, settings)

 convert_element_type_p = Primitive('convert_element_type')
+def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding):
+  operand = core.Primitive.bind(convert_element_type_p, operand,
+                                new_dtype=new_dtype, weak_type=weak_type,
+                                sharding=sharding)
+  if sharding is not None:
+    operand = jax.lax.with_sharding_constraint(operand, sharding)
+  return operand
+convert_element_type_p.def_custom_bind(_convert_element_type_bind)
 convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
 convert_element_type_p.def_abstract_eval(
     partial(standard_abstract_eval, convert_element_type_p,
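
The custom bind above is what makes `sharding=` take effect under tracing:
the primitive itself stays a pure dtype conversion, and when a sharding is
supplied the bound result is wrapped in `with_sharding_constraint`, so the
partitioner places the output on the requested sharding. A sketch of the
observable behavior, based on test_jnp_array_inside_jit_sharding below
(reusing `s` and `inp` from the sketch near the top):

    import jax

    @jax.jit
    def f(x):
      # Traces to convert_element_type followed by a sharding constraint,
      # so the output lands on `s` directly.
      return jnp.array(x, dtype=np.float32, device=s)

    out = f(inp)
    assert out.sharding == s and out.dtype == np.float32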
@@ -2560,12 +2580,12 @@ batching.defvectorized(convert_element_type_p)
 pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
 pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
 pe.def_trivial_padding(convert_element_type_p)
-# TODO(mattjj): un-comment the next line (see #9456)
-# core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule
+core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule

 def _real_dtype(dtype): return np.finfo(dtype).dtype

-def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
+def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
+                                sharding):
   aval_in, = ctx.avals_in
   aval_out, = ctx.avals_out
   if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
@@ -734,7 +734,7 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
   Args:
     a: input array
     axes: optionally specify the permutation using a length-`a.ndim` sequence of integers
-      ``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e
+      ``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e.
       reverses the order of all axes.

   Returns:
@@ -3227,7 +3227,8 @@ deprecations.register("jax-numpy-array-none")

 @util.implements(np.array, lax_description=_ARRAY_DOC)
 def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
-          order: str | None = "K", ndmin: int = 0) -> Array:
+          order: str | None = "K", ndmin: int = 0,
+          *, device: xc.Device | Sharding | None = None) -> Array:
   if order is not None and order != "K":
     raise NotImplementedError("Only implemented for order='K'")
||||
@ -3239,10 +3240,12 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
# whenever x is weak, but avoids introducing weak types with something like
|
||||
# array([1, 2, 3])
|
||||
weak_type = dtype is None and dtypes.is_weakly_typed(object)
|
||||
sharding = canonicalize_device_to_sharding(device)
|
||||
|
||||
# Use device_put to avoid a copy for ndarray inputs.
|
||||
if (not copy and isinstance(object, np.ndarray) and
|
||||
(dtype is None or dtype == object.dtype) and (ndmin <= object.ndim)):
|
||||
(dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and
|
||||
device is None):
|
||||
# Keep the output uncommitted.
|
||||
return jax.device_put(object)
|
||||
|
||||
@ -3326,12 +3329,19 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
raise TypeError(f"Unexpected input type for array: {type(object)}")
|
||||
|
||||
out_array: Array = lax_internal._convert_element_type(
|
||||
out, dtype, weak_type=weak_type)
|
||||
out, dtype, weak_type=weak_type, sharding=sharding)
|
||||
if ndmin > ndim(out_array):
|
||||
out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
|
||||
return out_array
|
||||
|
||||
|
||||
def canonicalize_device_to_sharding(device: xc.Device | Sharding | None
|
||||
) -> Sharding | None:
|
||||
if isinstance(device, xc.Device):
|
||||
return SingleDeviceSharding(device)
|
||||
return device
|
||||
|
||||
|
||||
def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
|
||||
try:
|
||||
dtypes.dtype(x)
|
||||
|
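
canonicalize_device_to_sharding is why `device=` accepts either a concrete
`Device` or a `Sharding`: a bare device is wrapped in a
`SingleDeviceSharding`, while a sharding passes through unchanged. A short
sketch of the single-device case (imports as in the earlier sketches):

    out = jnp.array(np.arange(8), device=jax.devices()[0])
    # A bare device behaves like the equivalent single-device sharding.
    assert out.sharding == SingleDeviceSharding(jax.devices()[0])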
@@ -1344,9 +1344,10 @@ def _convert_helper(x, *, to_dtype):
     raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}")

 def _convert_element_type_lowering_rule(
-    ctx: LoweringRuleContext, x, *, new_dtype, weak_type
+    ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
 ):
   del weak_type
+  del sharding
   out_aval = ctx.avals_out[0]
   old_dtype = ctx.avals_in[0].dtype
   out_type = aval_to_ir_type(out_aval)
@@ -309,7 +309,7 @@ def _broadcast_in_dim_lowering_rule(

 @register_lowering_rule(lax.convert_element_type_p)
 def _convert_element_type_lowering_rule(
-    ctx: LoweringRuleContext, x, *, new_dtype, weak_type
+    ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
 ):
   return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype))
@@ -1526,7 +1526,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:

 @register_lowering(lax.convert_element_type_p)
 def _convert_element_type_lowering_rule(
-    ctx: LoweringRuleContext, x, *, new_dtype, weak_type
+    ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
 ):
   [x_aval] = ctx.avals_in
   x = _ensure_ir_value(x, x_aval)
@@ -636,6 +636,7 @@ def _infer_params_impl(
       HashableFunction(res_paths, closure=()),
       IgnoreKey(ji.inline))
+  _attr_update(flat_fun, in_type, attr_token, attrs_tracked)

   out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
       out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
       ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),
@@ -2007,7 +2007,7 @@ tf_impl[lax.le_to_p] = handle_boolean_args(partial(_total_order_cond, tf.math.le
 tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky


-def _convert_element_type(operand, *, new_dtype, weak_type=False):
+def _convert_element_type(operand, *, new_dtype, weak_type=False, sharding=None):
   old_dtype = operand.dtype.as_numpy_dtype
   if (dtypes.issubdtype(old_dtype, np.complexfloating) and
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
@@ -99,7 +99,8 @@ def argwhere(
 ) -> Array: ...
 around = round
 def array(object: Any, dtype: DTypeLike | None = ..., copy: builtins.bool = True,
-          order: str | None = ..., ndmin: int = ...) -> Array: ...
+          order: str | None = ..., ndmin: int = ..., *,
+          device: _Device | _Sharding | None = None) -> Array: ...
 def array_equal(
     a1: ArrayLike, a2: ArrayLike, equal_nan: builtins.bool = ...
 ) -> Array: ...
@@ -37,6 +37,7 @@ from jax import dtypes
 from jax import stages
 from jax.errors import JAXTypeError
 from jax import lax
+from jax._src.lax import lax as lax_internal
 from jax.lax import with_sharding_constraint
 from jax._src import prng
 from jax.sharding import PartitionSpec as P, Mesh
@@ -4240,6 +4241,77 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, np_inp)
     self.assertEqual(out.sharding, s2)

+  def test_convert_element_type_sharding(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = np.arange(16).reshape(8, 2)
+
+    out = lax_internal._convert_element_type(
+        inp, new_dtype=np.float32, weak_type=False, sharding=s)
+    self.assertArraysEqual(out, inp.astype('float32'))
+    self.assertEqual(out.dtype, np.float32)
+    self.assertEqual(out.sharding, s)
+
+  def test_jnp_array_sharding(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = np.arange(16).reshape(8, 2)
+
+    out = jnp.array(inp, device=s)
+    self.assertArraysEqual(out, inp)
+    self.assertEqual(out.sharding, s)
+
+  def test_jnp_array_inside_jit_sharding(self):
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    inp = np.arange(16).reshape(8, 2)
+
+    @jax.jit
+    def f():
+      return jnp.array(inp, dtype=np.float32, device=s)
+
+    out = f()
+    print(f.trace().jaxpr)
+    self.assertArraysEqual(out, inp.astype('float32'))
+    self.assertEqual(out.sharding, s)
+    self.assertEqual(out.dtype, np.float32)
+
+    @jax.jit
+    def g(x):
+      return jnp.array(x, dtype=np.float32, device=s)
+
+    out2 = g(inp)
+    self.assertArraysEqual(out2, inp.astype('float32'))
+    self.assertEqual(out2.sharding, s)
+    self.assertEqual(out2.dtype, np.float32)
+
+  def test_jnp_array_reshard_error(self):
+    if jax.device_count() < 2:
+      self.skipTest('Requires >=2 devices')
+    arr = jax.device_put(np.arange(8), jax.devices()[0])
+    with self.assertRaisesRegex(ValueError, "Received incompatible devices.*"):
+      jnp.array(arr, device=jax.devices()[1])
+
+  def test_jnp_array_sharded_array_no_op(self):
+    inp = np.arange(16).reshape(8, 2)
+    arr = jax.device_put(inp, jax.devices()[0])
+
+    out = lax_internal._convert_element_type(
+        arr, sharding=SingleDeviceSharding(jax.devices()[0]))
+    self.assertArraysEqual(out, inp)
+    self.assertEqual(out.unsafe_buffer_pointer(), arr.unsafe_buffer_pointer())
+
+  def test_wsc_named_sharding_nullary(self):
+    mesh = jtu.create_global_mesh((2,), ('x',))
+    s = NamedSharding(mesh, P())
+
+    @jax.jit
+    def f():
+      return jax.lax.with_sharding_constraint(jnp.arange(8), s)
+
+    out = f()
+    self.assertEqual(out.sharding, s)
+

 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")