[shard-map] fix eager shmap+prngs, revise phys aval/sharding logic
Co-authored-by: Yash Katariya <yashkatariya@google.com>
parent f6da71c807
commit 728a5ed96a
@@ -545,11 +545,15 @@ def flatten_lowering_ir_args(
 _module_name_regex = re.compile(r"[^\w.-]")

-def sharded_aval(aval: core.ShapedArray,
-                 sharding: Optional[xc.OpSharding]) -> core.ShapedArray:
+def sharded_aval(aval: core.AbstractValue,
+                 sharding: Optional[xc.OpSharding]) -> core.AbstractValue:
   """Returns the new aval sharded based on sharding proto."""
   if sharding is None:
     return aval
+  if isinstance(aval, core.AbstractToken):
+    return aval
+  if not isinstance(aval, core.ShapedArray):
+    raise NotImplementedError
+
   if (sharding.type == xc.OpSharding.Type.REPLICATED or
       sharding.type == xc.OpSharding.Type.MANUAL):
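For intuition (not part of the commit): sharded_aval maps a global aval to the per-shard aval implied by an OpSharding proto. A minimal runnable sketch of that idea, using a hypothetical ToyOpSharding stand-in for xc.OpSharding and assuming evenly divisible tiling:

# Illustrative sketch only; ToyOpSharding is a stand-in, not JAX's implementation.
from dataclasses import dataclass
from typing import Tuple

@dataclass
class ToyOpSharding:
  tile_assignment_dimensions: Tuple[int, ...]  # number of tiles per dimension

def toy_sharded_shape(shape: Tuple[int, ...], s: ToyOpSharding) -> Tuple[int, ...]:
  # Each dimension is split across the tiles assigned along it.
  return tuple(d // t for d, t in zip(shape, s.tile_assignment_dimensions))

# A global (8, 4) array tiled 4 ways along dim 0 yields (2, 4) per-device shards.
assert toy_sharded_shape((8, 4), ToyOpSharding((4, 1))) == (2, 4)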
@@ -620,22 +624,10 @@ def lower_jaxpr_to_module(
   if not xb.is_known_platform(platform):
     raise ValueError(f"Unknown platform {platform}")
   input_output_aliases = None
-  in_avals = jaxpr.in_avals
-  if arg_shardings is not None:
-    in_avals = [
-        sharded_aval(in_aval, in_sharding)
-        for in_aval, in_sharding in zip(in_avals, arg_shardings)
-    ]
-  out_avals = jaxpr.out_avals
-  if result_shardings is not None:
-    out_avals = []
-    for out_aval, out_sharding in zip(jaxpr.out_avals, result_shardings):
-      if (out_aval is not core.abstract_token and
-          core.is_opaque_dtype(out_aval.dtype)):
-        # TODO(frostig,mattjj,necula): asserts a single physical aval
-        out_aval, = out_aval.dtype._rules.physical_avals(out_aval)
-      out_avals.append(sharded_aval(out_aval, out_sharding))
+  in_avals = (jaxpr.in_avals if arg_shardings is None else
+              map(sharded_aval, jaxpr.in_avals, arg_shardings))
+  out_avals = (jaxpr.out_avals if result_shardings is None else
+               map(sharded_aval, jaxpr.out_avals, result_shardings))
   if platform in _platforms_with_donation:
     input_output_aliases, donated_args = _set_up_aliases(
         in_avals, out_avals, donated_args)
@@ -643,9 +635,7 @@ def lower_jaxpr_to_module(
   if unlowerable_effects:
     raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
   if any(donated_args):
     # TODO(tomhennigan): At call time we should mark these buffers as deleted.
-    unused_donations = [str(a) for a, d in zip(in_avals, donated_args)
-                        if d]
+    unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
     msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
     if platform not in _platforms_with_donation:
       msg = f"Donation is not implemented for {platform}.\n{msg}"
@@ -660,8 +650,7 @@ def lower_jaxpr_to_module(
   dim_vars: Sequence[str]
   if not config.jax_dynamic_shapes:
     # Find the dimension variables
-    all_dim_poly = [d
-                    for aval in jaxpr.in_avals if hasattr(aval, "shape")
+    all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
                     for d in aval.shape if not core.is_constant_dim(d)]
     dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new.get_vars()),
                                              all_dim_poly, set())))
@@ -886,14 +875,18 @@ def lower_jaxpr_to_fun(
   ctx.symbol_table.insert(func_op)
   ir_arg_shardings = None
   if arg_shardings is not None:
+    in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
     ir_arg_shardings = util.flatten(
-        [[sharding] * len(types) for sharding, types
-         in zip(arg_shardings, input_types)])
+        [[_to_physical_op_sharding(a, s)] * len(types)
+         for a, s, types in zip(in_avals, arg_shardings, input_types)])
+    del in_avals
   ir_result_shardings = None
   if result_shardings is not None:
+    out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
     ir_result_shardings = util.flatten(
-        [[sharding] * len(types)
-         for sharding, types in zip(result_shardings, output_types)])
+        [[_to_physical_op_sharding(a, s)] * len(types)
+         for a, s, types in zip(out_avals, result_shardings, output_types)])
+    del out_avals

   if (replicated_args is not None or ir_arg_shardings is not None
       or input_output_aliases is not None):
@@ -1001,6 +994,15 @@ def lower_jaxpr_to_fun(

   return func_op

+def _to_physical_op_sharding(
+    aval: Optional[core.AbstractValue], sharding: Optional[xc.OpSharding]
+) -> Optional[xc.OpSharding]:
+  if (isinstance(aval, core.ShapedArray) and core.is_opaque_dtype(aval.dtype)
+      and sharding is not None):
+    return aval.dtype._rules.physical_op_sharding(aval, sharding)
+  return sharding
+
+
 def _emit_lowering_rule_as_fun(lowering_rule,
                                ctx: LoweringRuleContext) -> func_dialect.FuncOp:
   """Emits the contents of a lowering rule as a private function."""
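For intuition about _to_physical_op_sharding (again, not part of the commit): ordinary dtypes pass the proto through unchanged, while opaque (extended) dtypes such as PRNG keys rewrite it via their dtype rules so the IR annotation matches the physical representation. A toy sketch of that dispatch, with hypothetical ToyAval/ToyKeyDtype/ToyRules stand-ins and tuples of tile counts standing in for OpSharding protos:

# Illustrative stand-ins only; none of these classes are JAX internals.
class ToyRules:
  @staticmethod
  def physical_op_sharding(aval, sharding):
    # Append a trailing tile dim of 1 for each hidden physical dimension.
    return sharding + (1,) * len(aval.dtype.key_shape)

class ToyKeyDtype:
  key_shape = (2,)   # hidden trailing physical shape, e.g. 2 x uint32
  _rules = ToyRules
  opaque = True

class ToyFloatDtype:
  opaque = False

class ToyAval:
  def __init__(self, dtype):
    self.dtype = dtype

def to_physical_op_sharding(aval, sharding):
  if sharding is not None and getattr(aval.dtype, 'opaque', False):
    return aval.dtype._rules.physical_op_sharding(aval, sharding)
  return sharding

assert to_physical_op_sharding(ToyAval(ToyFloatDtype()), (4,)) == (4,)
assert to_physical_op_sharding(ToyAval(ToyKeyDtype()), (4,)) == (4, 1)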
@@ -2327,13 +2327,15 @@ def _get_and_check_device_assignment(
   return xb.get_device_backend(final_device_assignment[0]), final_device_assignment


+MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
+
 @profiler.annotate_function
 def lower_sharding_computation(
     fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
     api_name: str,
     fun_name: str,
-    in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]],
-    out_shardings: Union[Sequence[Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]], UnspecifiedValue],
+    in_shardings: Sequence[MaybeSharding],
+    out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
     *,
@@ -2384,17 +2386,17 @@ def lower_sharding_computation(

   if _is_unspecified(out_shardings):
     out_shardings = (_UNSPECIFIED,) * len(global_out_avals)
-  # mypy doesn't understand that out_sharding here is always a sequence.
-  assert len(out_shardings) == len(global_out_avals), (  # type: ignore
-      len(out_shardings), len(global_out_avals))  # type: ignore
+  assert isinstance(out_shardings, tuple)
+  assert len(out_shardings) == len(global_out_avals), (
+      len(out_shardings), len(global_out_avals))

   # Device assignment across all inputs, outputs and shardings inside jaxpr
   # should be the same.
   jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
   backend, device_assignment = _get_and_check_device_assignment(
       it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
-               [(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],  # type: ignore
-               [(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)  # type: ignore
+               [(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
+               [(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
                 for js, source_info in jaxpr_sharding]),
       devices_from_context)

@@ -2402,8 +2404,8 @@ def lower_sharding_computation(
       devices_from_context or
       len(device_assignment) > 1 or
       any(not _is_unspecified(i) for i in in_shardings) or
-      any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or  # type: ignore
-      any(not _is_unspecified(o) for o in out_shardings))  # type: ignore
+      any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or
+      any(not _is_unspecified(o) for o in out_shardings))

   in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
                        if _is_unspecified(i) else i for i in in_shardings)
@@ -2445,7 +2447,7 @@ def lower_sharding_computation(
   # and don't need to evaluate their arguments.
   if (not always_lower and not (jaxpr.effects or has_outfeed) and
       (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
-      all(_is_unspecified(o) for o in out_shardings)):  # type: ignore
+      all(_is_unspecified(o) for o in out_shardings)):
     return MeshComputation(
         str(name_stack), None, True, donated_invars, jaxpr=jaxpr, consts=consts,
         global_in_avals=global_in_avals, global_out_avals=global_out_avals,
@@ -2469,25 +2471,8 @@ def lower_sharding_computation(
   axis_ctx: mlir.AxisContext

   if nreps == 1:
-    in_op_shardings = []
-    for aval, i in safe_zip(global_in_avals, in_shardings):
-      if aval is core.abstract_token:
-        in_op_shardings.append(None)
-      elif core.is_opaque_dtype(aval.dtype):
-        in_op_shardings.append(aval.dtype._rules.physical_op_sharding(aval, i))
-      else:
-        in_op_shardings.append(i._to_xla_op_sharding(aval.ndim))  # type: ignore[union-attr]
-
-    # TODO(yashkatariya): Fix the HLO produced if out_partitions is
-    # [None, OpShardingProto] has the sharding annotations.
-    out_op_shardings = []
-    for aval, o in safe_zip(global_out_avals, out_shardings):  # type: ignore[arg-type]
-      if _is_unspecified(o) or aval is core.abstract_token:
-        out_op_shardings.append(None)
-      elif core.is_opaque_dtype(aval.dtype):
-        out_op_shardings.append(aval.dtype._rules.physical_op_sharding(aval, o))
-      else:
-        out_op_shardings.append(o._to_xla_op_sharding(aval.ndim))  # type: ignore[union-attr]
+    in_op_shardings = map(_to_logical_op_sharding, global_in_avals, in_shardings)
+    out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
     axis_ctx = mlir.ShardingContext(device_assignment)
   else:
@@ -2556,6 +2541,19 @@ def lower_sharding_computation(
       committed=committed,
       pmap_nreps=nreps)

+def _to_logical_op_sharding(
+    aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
+) -> Optional[xc.OpSharding]:
+  if _is_unspecified(sharding) or is_auto(sharding):
+    return None
+  elif isinstance(aval, ShapedArray):
+    assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
+    return sharding._to_xla_op_sharding(aval.ndim)
+  elif isinstance(aval, core.AbstractToken):
+    return None
+  else:
+    raise TypeError(aval)
+

 @profiler.annotate_function
 def lower_mesh_computation(
@@ -2575,7 +2573,7 @@ def lower_mesh_computation(
   backend = xb.get_device_backend(mesh.devices.flat[0])
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

-  auto_spmd_lowering = check_if_any_auto(in_shardings + out_shardings)  # type: ignore
+  auto_spmd_lowering = check_if_any_auto((*in_shardings, *out_shardings))

   if auto_spmd_lowering and not spmd_lowering:
     raise ValueError('Enable spmd_lowering to use auto spmd lowering.')
@@ -2653,25 +2651,8 @@ def lower_mesh_computation(
   out_partitions: Optional[List[Optional[xc.OpSharding]]]
   axis_ctx: mlir.AxisContext
   if spmd_lowering:
-    in_partitions = []
-    for aval, i in safe_zip(global_in_avals, in_shardings):
-      if is_auto(i):
-        in_partitions.append(None)
-      elif core.is_opaque_dtype(aval.dtype):
-        in_partitions.append(aval.dtype._rules.physical_op_sharding(aval, i))
-      else:
-        in_partitions.append(i._to_xla_op_sharding(aval.ndim))  # type: ignore[union-attr]
-
-    # TODO(yashkatariya): Fix the HLO produced if out_partitions is
-    # [None, OpShardingProto] has the sharding annotations.
-    out_partitions = []
-    for aval, o in safe_zip(global_out_avals, out_shardings):
-      if is_auto(o) or _is_unspecified(o):
-        out_partitions.append(None)
-      elif core.is_opaque_dtype(aval.dtype):
-        out_partitions.append(aval.dtype._rules.physical_op_sharding(aval, o))
-      else:
-        out_partitions.append(o._to_xla_op_sharding(aval.ndim))  # type: ignore[union-attr]
+    in_partitions = map(_to_logical_op_sharding, global_in_avals, in_shardings)
+    out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(in_jaxpr_avals)
     axis_ctx = mlir.SPMDAxisContext(mesh, manual_axes)
   else:
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations

 import abc
 from functools import partial, reduce
@@ -176,7 +176,12 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):

   _device = property(op.attrgetter('_base_array._device'))
   _committed = property(op.attrgetter('_base_array._committed'))
-  sharding = property(op.attrgetter('_base_array.sharding'))
+
+  @property
+  def sharding(self):
+    aval = keys_shaped_array(self.impl, self.shape)
+    phys_sharding = self._base_array.sharding
+    return KeyTyRules.logical_op_sharding(aval, phys_sharding)

   def _is_scalar(self):
     base_ndim = len(self.impl.key_shape)
@@ -250,14 +255,14 @@ class PRNGKeyArray(metaclass=PRNGKeyArrayMeta):
   # `typing.type_check_only`.

   @property
-  def T(self) -> 'PRNGKeyArray': assert False
-  def __getitem__(self, _) -> 'PRNGKeyArray': assert False
-  def ravel(self, *_, **__) -> 'PRNGKeyArray': assert False
-  def squeeze(self, *_, **__) -> 'PRNGKeyArray': assert False
-  def swapaxes(self, *_, **__) -> 'PRNGKeyArray': assert False
-  def take(self, *_, **__) -> 'PRNGKeyArray': assert False
-  def transpose(self, *_, **__) -> 'PRNGKeyArray': assert False
-  def flatten(self, *_, **__) -> 'PRNGKeyArray': assert False
+  def T(self) -> PRNGKeyArray: assert False
+  def __getitem__(self, _) -> PRNGKeyArray: assert False
+  def ravel(self, *_, **__) -> PRNGKeyArray: assert False
+  def squeeze(self, *_, **__) -> PRNGKeyArray: assert False
+  def swapaxes(self, *_, **__) -> PRNGKeyArray: assert False
+  def take(self, *_, **__) -> PRNGKeyArray: assert False
+  def transpose(self, *_, **__) -> PRNGKeyArray: assert False
+  def flatten(self, *_, **__) -> PRNGKeyArray: assert False


 _set_device_array_base_attributes(PRNGKeyArray, include=[
@@ -303,9 +308,10 @@ def make_key_array_phys_sharding(aval, sharding, is_sharding_from_xla):
   elif is_sharding_from_xla:
     return sharding
   else:
+    sharding_proto = sharding._to_xla_op_sharding(aval.ndim)
     return GSPMDSharding(
         sharding._device_assignment,
-        KeyTyRules.physical_op_sharding(aval, sharding))
+        KeyTyRules.physical_op_sharding(aval, sharding_proto))


 class KeyTyRules:
@@ -316,16 +322,25 @@ class KeyTyRules:
                              jnp.dtype('uint32'))]

   @staticmethod
-  def physical_op_sharding(aval, sharding):
-    op_sharding = sharding._to_xla_op_sharding(aval.ndim)
+  def physical_op_sharding(aval, op_sharding_proto):
     key_shape = aval.dtype.impl.key_shape

-    new_op_sharding = op_sharding.clone()
+    new_op_sharding = op_sharding_proto.clone()
     tad = list(new_op_sharding.tile_assignment_dimensions)
     tad.extend([1] * len(key_shape))
     new_op_sharding.tile_assignment_dimensions = tad
     return new_op_sharding

   @staticmethod
+  def logical_op_sharding(aval, phys_sharding) -> GSPMDSharding:
+    key_shape = aval.dtype.impl.key_shape
+    phys_op_sharding = phys_sharding._to_xla_op_sharding(
+        aval.ndim + len(key_shape))
+    logical_op_sharding = phys_op_sharding.clone()
+    tad = list(logical_op_sharding.tile_assignment_dimensions)
+    tad = tad[:-len(key_shape)]
+    logical_op_sharding.tile_assignment_dimensions = tad
+    return GSPMDSharding(phys_sharding._device_assignment, logical_op_sharding)
+
+  @staticmethod
   def result_handler(sticky_device, aval):
     def handler(_, buf):
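For intuition about the pair above (an illustrative sketch, not the commit's code): physical_op_sharding widens a proto over the logical key shape by appending one trailing tile dimension of 1 per physical key dimension, and logical_op_sharding strips those trailing dimensions again. A runnable sketch of that round trip, where ToyProto is a hypothetical stand-in mimicking only the clone() and tile_assignment_dimensions pieces of xc.OpSharding:

# ToyProto is a stand-in for xc.OpSharding, just enough for the demo.
import copy
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyProto:
  tile_assignment_dimensions: List[int] = field(default_factory=list)

  def clone(self) -> 'ToyProto':
    return copy.deepcopy(self)

def physical_proto(proto: ToyProto, key_shape) -> ToyProto:
  new = proto.clone()
  tad = list(new.tile_assignment_dimensions)
  tad.extend([1] * len(key_shape))          # cover the trailing physical dims
  new.tile_assignment_dimensions = tad
  return new

def logical_proto(proto: ToyProto, key_shape) -> ToyProto:
  new = proto.clone()
  new.tile_assignment_dimensions = list(
      new.tile_assignment_dimensions[:-len(key_shape)])
  return new

key_shape = (2,)                            # e.g. a threefry key is 2 x uint32
p = ToyProto([4])                           # logical shape (4,) split 4 ways
phys = physical_proto(p, key_shape)
assert phys.tile_assignment_dimensions == [4, 1]
assert logical_proto(phys, key_shape).tile_assignment_dimensions == [4]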
@@ -367,7 +382,6 @@ class KeyTyRules:

     phys_sharding = make_key_array_phys_sharding(
         aval, out_sharding, is_out_sharding_from_xla)
-
     phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
                                       is_out_sharding_from_xla)
     def handler(bufs):
@@ -33,6 +33,7 @@ from jax._src import debugging
 from jax._src import linear_util as lu
 from jax._src import ops
 from jax._src import pjit
+from jax._src import prng
 from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import util
@@ -473,9 +474,11 @@ def _xla_shard(mesh, names, aval_in, aval_out, x):
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
   result_type, = mlir.aval_to_ir_types(aval_out)
   axes = {name: i for i, ns in names.items() for name in ns}
-  sharding_proto = pxla.new_mesh_sharding_specs(mesh.shape, mesh.axis_names)(
-      aval_in.ndim, axes).sharding_proto()
-  sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=set())
+  shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes)  # type: ignore
+                              )._to_xla_op_sharding(aval_in.ndim)
+  if core.is_opaque_dtype(aval_in.dtype):
+    shard_proto = aval_in.dtype._rules.physical_op_sharding(aval_in, shard_proto)
+  sx = mlir.wrap_with_sharding_op(x, shard_proto, unspecified_dims=set())
   return [mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, set())]

 def _xla_unshard(mesh, names, aval_in, aval_out, xs):
@@ -484,9 +487,11 @@ def _xla_unshard(mesh, names, aval_in, aval_out, xs):
   result_type, = mlir.aval_to_ir_types(aval_out)
   sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set())
   axes = {name: i for i, ns in names.items() for name in ns}
-  sharding_proto = pxla.new_mesh_sharding_specs(mesh.shape, mesh.axis_names)(
-      aval_out.ndim, axes).sharding_proto()
-  return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, set())
+  shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes)  # type: ignore
+                              )._to_xla_op_sharding(aval_out.ndim)
+  if core.is_opaque_dtype(aval_out.dtype):
+    shard_proto = aval_out.dtype._rules.physical_op_sharding(aval_out, shard_proto)
+  return mlir.wrap_with_shard_to_full_op(result_type, sx, shard_proto, set())

 # Eager evaluation

@@ -652,7 +657,7 @@ def _standard_rep_rule(_, *in_rep, **__):
 for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
                   windowed_reductions.__dict__.values(), fft.__dict__.values(),
                   linalg.__dict__.values(), ops.__dict__.values(),
-                  ad_util.__dict__.values(),
+                  ad_util.__dict__.values(), prng.__dict__.values(),
                   custom_derivatives.__dict__.values()):
   if isinstance(o, core.Primitive): register_standard(o)

@@ -2839,8 +2839,11 @@ class FooTyRules:
     return [core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))]

   @staticmethod
-  def physical_op_sharding(aval, sharding):
-    return sharding._to_xla_op_sharding(aval.ndim)
+  def physical_op_sharding(aval, op_sharding_proto):
+    new_op_sharding = op_sharding_proto.clone()
+    tad = list(new_op_sharding.tile_assignment_dimensions)
+    new_op_sharding.tile_assignment_dimensions = [*tad, 1]
+    return new_op_sharding

   @staticmethod
   def result_handler(sticky_device, aval):
@@ -36,7 +36,7 @@ from jax import stages
 from jax.errors import JAXTypeError
 from jax import lax
 from jax.lax import with_sharding_constraint
-from jax import prng
+from jax._src import prng
 from jax.sharding import PartitionSpec as P
 from jax.experimental.maps import xmap
 from jax.experimental import multihost_utils
@@ -553,6 +553,22 @@ class ShardMapTest(jtu.JaxTestCase):
     jax.eval_shape(jax.grad(lambda x: jax.remat(f)(x).sum().astype('float32')),
                    xs)

+  def test_prngkeyarray_eager(self):
+    # https://github.com/google/jax/issues/15398
+    mesh = jtu.create_global_mesh((4,), ('x',))
+    sharding = jax.sharding.NamedSharding(mesh, P('x'))
+
+    rng = jax.random.PRNGKey(0)
+    sharded_rng = jax.random.split(rng, num=4)
+    sharded_rng = jax.device_put(sharded_rng, sharding)
+
+    def f(key):
+      return jax.random.randint(key[0], shape=(1, 16), minval=0, maxval=16,
+                                dtype=jnp.int32)
+
+    g = shard_map(f, mesh, in_specs=(P('x', None),), out_specs=P('x', None))
+    _ = g(sharded_rng)  # don't crash!
+

 class FunSpec(NamedTuple):
   name: str