mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Adding support for a special value for in_axis_resources (pjit.FROM_GSDA
) when GSDA is an input.
PiperOrigin-RevId: 411148899
This commit is contained in:
parent
34a2ffcfb6
commit
d1de309410
@ -2524,7 +2524,8 @@ def _pjit(*args: TfVal,
|
||||
resource_env: maps.ResourceEnv,
|
||||
donated_invars,
|
||||
name: str,
|
||||
positional_semantics,
|
||||
in_positional_semantics,
|
||||
out_positional_semantics,
|
||||
_in_avals: Sequence[core.ShapedArray],
|
||||
_out_aval: core.ShapedArray) -> TfVal:
|
||||
del donated_invars
|
||||
|
@ -45,6 +45,11 @@ from .._src.util import (extend_name_stack, HashableFunction, safe_zip,
|
||||
split_list, cache, tuple_insert)
|
||||
xops = xc._xla.ops
|
||||
|
||||
class _FromGsdaSingleton:
|
||||
pass
|
||||
FROM_GSDA = _FromGsdaSingleton()
|
||||
|
||||
# TODO(yashkatariya): Add pjit microbenchmarks.
|
||||
def pjit(fun: Callable,
|
||||
in_axis_resources,
|
||||
out_axis_resources,
|
||||
@ -211,16 +216,22 @@ def pjit(fun: Callable,
|
||||
donated_invars = (False,) * len(args_flat)
|
||||
|
||||
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
|
||||
# TODO(yashkatariya): Remove `is_gsda` check when special value for in_axis_resources
|
||||
# is added for GSDA.
|
||||
is_gsda = all(isinstance(a, GSDA) for a in args_flat)
|
||||
# TODO(yashkatariya): This is a hack. This should go away when avals have
|
||||
# is_global attribute.
|
||||
in_positional_semantics = tuple(
|
||||
maps._PositionalSemantics.GLOBAL
|
||||
if type(a) is GSDA else maps._positional_semantics
|
||||
for a in args_flat)
|
||||
out_positional_semantics = maps._positional_semantics
|
||||
jaxpr, in_axis_resources_flat, out_axis_resources_flat = _pjit_jaxpr(
|
||||
flat_fun, mesh, local_in_avals, in_tree,
|
||||
hashable_pytree(in_axis_resources),
|
||||
HashableFunction(out_tree, closure=()),
|
||||
hashable_pytree(out_axis_resources),
|
||||
(maps._PositionalSemantics.GLOBAL
|
||||
if is_gsda else maps._positional_semantics))
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
tuple(isinstance(a, GSDA) for a in args_flat))
|
||||
in_axis_resources_flat = tree_map(_canonicalize_spec, in_axis_resources_flat,
|
||||
tuple(args_flat))
|
||||
params = dict(
|
||||
jaxpr=jaxpr,
|
||||
in_axis_resources=in_axis_resources_flat,
|
||||
@ -228,7 +239,8 @@ def pjit(fun: Callable,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=flat_fun.__name__,
|
||||
positional_semantics=maps._positional_semantics)
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics)
|
||||
return args_flat, params, in_tree, out_tree(), donate_argnums
|
||||
|
||||
@wraps(fun)
|
||||
@ -245,7 +257,8 @@ def pjit(fun: Callable,
|
||||
lowering = _pjit_lower(
|
||||
params['jaxpr'], params['in_axis_resources'],
|
||||
params['out_axis_resources'], params['resource_env'],
|
||||
params['donated_invars'], params['name'], maps._positional_semantics)
|
||||
params['donated_invars'], params['name'],
|
||||
params['in_positional_semantics'], params['out_positional_semantics'])
|
||||
return Lowered(lowering, in_tree, out_tree, donate_argnums, no_kwargs=True)
|
||||
|
||||
wrapped.lower = lower
|
||||
@ -274,13 +287,19 @@ def flatten_axis_resources(what, tree, axis_resources, tupled_args):
|
||||
def _pjit_jaxpr(fun, mesh, local_in_avals,
|
||||
in_tree, in_axis_resources_thunk,
|
||||
out_tree, out_axis_resources_thunk,
|
||||
positional_semantics):
|
||||
in_positional_semantics, out_positional_semantics, is_gsda):
|
||||
# TODO(yashkatariya): Make this work with FROM_GSDA special value.
|
||||
in_axis_resources_flat = flatten_axis_resources(
|
||||
"pjit in_axis_resources", in_tree,
|
||||
in_axis_resources_thunk(), tupled_args=True)
|
||||
# This check should be above local_to_global call below otherwise if
|
||||
# `FROM_GSDA` is passed to any input other than GSDA, a ugly error message
|
||||
# will be raised because get_array_mapping (in local_to_global) of a
|
||||
# FROM_GSDA cannot happen.
|
||||
tree_map(_check_resources_mismatch, in_axis_resources_flat, is_gsda)
|
||||
_check_shapes_against_resources("pjit arguments", False, mesh.local_mesh.shape,
|
||||
local_in_avals, in_axis_resources_flat)
|
||||
global_in_avals = local_to_global(positional_semantics, mesh,
|
||||
global_in_avals = local_to_global(in_positional_semantics, mesh,
|
||||
local_in_avals, in_axis_resources_flat)
|
||||
|
||||
prev_positional = maps._positional_semantics
|
||||
@ -383,13 +402,22 @@ def _prepare_axis_resources(axis_resources, arg_name):
|
||||
# to explicitly declare them as such
|
||||
entries, treedef = tree_flatten(axis_resources, is_leaf=lambda x: x is None)
|
||||
what = f"{arg_name} leaf specifications"
|
||||
entries = [ParsedPartitionSpec.from_user_input(entry, what) for entry in entries]
|
||||
entries = [
|
||||
entry if entry is FROM_GSDA else ParsedPartitionSpec.from_user_input(
|
||||
entry, what) for entry in entries
|
||||
]
|
||||
_check_unique_resources(entries, arg_name)
|
||||
return tree_unflatten(treedef, entries), entries, treedef
|
||||
|
||||
def _check_resources_mismatch(in_axis_resources_flat, is_gsda):
|
||||
if not is_gsda and in_axis_resources_flat is FROM_GSDA:
|
||||
raise ValueError('For a non-GSDA input, the corresponding resource in '
|
||||
'in_axis_resources cannot be `pjit.FROM_GSDA`.')
|
||||
|
||||
def _check_unique_resources(axis_resources, arg_name):
|
||||
for arg_axis_resources in axis_resources:
|
||||
if not arg_axis_resources: continue
|
||||
if arg_axis_resources is FROM_GSDA: continue
|
||||
resource_counts = Counter(it.chain.from_iterable(arg_axis_resources))
|
||||
if not resource_counts: continue
|
||||
if resource_counts.most_common(1)[0][1] > 1:
|
||||
@ -402,6 +430,8 @@ def _check_unique_resources(axis_resources, arg_name):
|
||||
def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape, flat_avals, flat_axis_resources):
|
||||
global_str = " global" if is_global_shape else ""
|
||||
for aval, aval_axis_resources in zip(flat_avals, flat_axis_resources):
|
||||
if aval_axis_resources is FROM_GSDA:
|
||||
continue
|
||||
shape = aval.shape
|
||||
if len(shape) < len(aval_axis_resources):
|
||||
raise ValueError(f"One of {what} was given the resource assignment "
|
||||
@ -430,10 +460,11 @@ pjit_p.multiple_results = True
|
||||
def _pjit_call_impl(*args, jaxpr,
|
||||
in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, name,
|
||||
positional_semantics):
|
||||
in_positional_semantics, out_positional_semantics):
|
||||
compiled = _pjit_lower(
|
||||
jaxpr, in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, name, positional_semantics).compile()
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics).compile()
|
||||
distributed_debug_log(("Running pjit'd function", name),
|
||||
("mesh", resource_env.physical_mesh))
|
||||
return compiled.unsafe_call(*args)
|
||||
@ -447,30 +478,36 @@ def _pjit_lower(
|
||||
resource_env,
|
||||
donated_invars,
|
||||
name: str,
|
||||
positional_semantics):
|
||||
in_positional_semantics, out_positional_semantics):
|
||||
in_axes = [get_array_mapping(axes) for axes in in_axis_resources]
|
||||
out_axes = [get_array_mapping(axes) for axes in out_axis_resources]
|
||||
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
|
||||
f = core.jaxpr_as_fun(jaxpr)
|
||||
f.__name__ = name
|
||||
fun = lu.wrap_init(f)
|
||||
local_in_avals = global_to_local(positional_semantics, resource_env.physical_mesh,
|
||||
# TODO(yashkatariya): Passing in out_positional_semantics is a hack.
|
||||
# This logic should get replaced with `is_global` attribute exists on aval.
|
||||
local_in_avals = global_to_local(out_positional_semantics, resource_env.physical_mesh,
|
||||
jaxpr.in_avals, in_axis_resources)
|
||||
# TODO(yashkatariya): Passing positional_semantics is a hack. This should go
|
||||
# away when avals have is_global attribute.
|
||||
return pxla.lower_mesh_computation(
|
||||
fun, name, resource_env.physical_mesh,
|
||||
in_axes, out_axes, donated_invars,
|
||||
True, local_in_avals, tile_by_mesh_axes=False)
|
||||
|
||||
|
||||
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env, positional_semantics, **_):
|
||||
return global_to_local(positional_semantics, resource_env.physical_mesh,
|
||||
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env,
|
||||
out_positional_semantics, **_):
|
||||
return global_to_local(out_positional_semantics, resource_env.physical_mesh,
|
||||
jaxpr.out_avals, out_axis_resources)
|
||||
pjit_p.def_abstract_eval(_pjit_abstract_eval)
|
||||
|
||||
|
||||
def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
|
||||
jaxpr, in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, positional_semantics):
|
||||
resource_env, donated_invars, in_positional_semantics,
|
||||
out_positional_semantics):
|
||||
mesh = resource_env.physical_mesh
|
||||
subc = xc.XlaBuilder(f"pjit_{name}")
|
||||
|
||||
@ -503,7 +540,8 @@ def _pjit_batcher(insert_axis,
|
||||
axis_size, axis_name, main_type,
|
||||
vals_in, dims_in,
|
||||
jaxpr, in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, name, positional_semantics):
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics):
|
||||
# batch_jaxpr expects all batching dimensions to be equal to 0
|
||||
vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
|
||||
else x for x, d in zip(vals_in, dims_in)]
|
||||
@ -527,7 +565,8 @@ def _pjit_batcher(insert_axis,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=name,
|
||||
positional_semantics=positional_semantics)
|
||||
in_positional_semantics=in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics)
|
||||
dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
|
||||
return vals_out, dims_out
|
||||
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
|
||||
@ -536,7 +575,8 @@ pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True)
|
||||
|
||||
def _pjit_jvp(primals_in, tangents_in,
|
||||
jaxpr, in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, name, positional_semantics):
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics):
|
||||
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
|
||||
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
|
||||
jaxpr, is_nz_tangents_in, instantiate=False)
|
||||
@ -553,7 +593,8 @@ def _pjit_jvp(primals_in, tangents_in,
|
||||
resource_env=resource_env,
|
||||
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
|
||||
name=wrap_name(name, 'jvp'),
|
||||
positional_semantics=positional_semantics)
|
||||
in_positional_semantics=(*in_positional_semantics, *_filter_zeros_in(in_positional_semantics)),
|
||||
out_positional_semantics=out_positional_semantics)
|
||||
|
||||
primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
|
||||
assert len(primals_out) == len(jaxpr.jaxpr.outvars)
|
||||
@ -565,7 +606,8 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
|
||||
|
||||
def _pjit_partial_eval(trace, *in_tracers,
|
||||
jaxpr, in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, name, positional_semantics):
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics):
|
||||
# XXX: At the moment all residuals get fully replicated, which is extremely
|
||||
# wasteful and might quickly lead to OOM errors.
|
||||
mesh = resource_env.physical_mesh
|
||||
@ -597,7 +639,8 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
resource_env=resource_env,
|
||||
donated_invars=keep_where(donated_invars, known_ins),
|
||||
name=name,
|
||||
positional_semantics=positional_semantics)
|
||||
in_positional_semantics=keep_where(in_positional_semantics, known_ins),
|
||||
out_positional_semantics=out_positional_semantics)
|
||||
|
||||
if num_residuals:
|
||||
executable = _pjit_lower(**known_params).compile(
|
||||
@ -636,11 +679,16 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
donated_invars=(keep_where(donated_invars, unknown_ins) +
|
||||
(False,) * num_residuals),
|
||||
name=name,
|
||||
positional_semantics=positional_semantics)
|
||||
in_positional_semantics=(keep_where(
|
||||
in_positional_semantics, unknown_ins) + (out_positional_semantics,) * num_residuals),
|
||||
out_positional_semantics=out_positional_semantics)
|
||||
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
|
||||
unknown_tracers_out = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
|
||||
for aval in global_to_local(positional_semantics, mesh, unknown_jaxpr.out_avals,
|
||||
unknown_params['out_axis_resources'])]
|
||||
unknown_tracers_out = [
|
||||
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
|
||||
for aval in global_to_local(unknown_params["out_positional_semantics"],
|
||||
mesh, unknown_jaxpr.out_avals,
|
||||
unknown_params["out_axis_resources"])
|
||||
]
|
||||
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
|
||||
unknown_tracers_out,
|
||||
pjit_p,
|
||||
@ -653,7 +701,8 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
|
||||
|
||||
def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
jaxpr, in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, name, positional_semantics):
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics):
|
||||
mesh = resource_env.physical_mesh
|
||||
|
||||
def prune_type(ty, xs, maybe_zeros):
|
||||
@ -668,8 +717,12 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
*prune_type(ad.UndefinedPrimal, in_axis_resources, primals_in),
|
||||
*prune_type(ad.Zero, out_axis_resources, cts_in)
|
||||
)
|
||||
transpose_in_positional_semantics = (
|
||||
*prune_type(ad.UndefinedPrimal, in_positional_semantics, primals_in),
|
||||
*prune_type(ad.Zero, (out_positional_semantics,) * len(cts_in), cts_in)
|
||||
)
|
||||
global_cts_in_avals = local_to_global(
|
||||
positional_semantics,
|
||||
transpose_in_positional_semantics,
|
||||
mesh,
|
||||
[core.raise_to_shaped(core.get_aval(ct)) for ct in primals_and_nz_cts_in],
|
||||
transpose_in_axis_resources)
|
||||
@ -692,7 +745,8 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
|
||||
resource_env=resource_env,
|
||||
donated_invars=(False,) * len(primals_and_nz_cts_in),
|
||||
name=name,
|
||||
positional_semantics=positional_semantics)
|
||||
in_positional_semantics=transpose_in_positional_semantics,
|
||||
out_positional_semantics=out_positional_semantics)
|
||||
return tree_unflatten(cts_out_treedef, nz_cts_out)
|
||||
ad.reducing_transposes[pjit_p] = _pjit_transpose
|
||||
|
||||
@ -815,16 +869,39 @@ def get_aval_sharding_proto(aval: core.AbstractValue,
|
||||
return sharding_spec.sharding_proto()
|
||||
|
||||
def global_to_local(positional_semantics, mesh, avals, axes):
|
||||
if positional_semantics == maps._PositionalSemantics.GLOBAL:
|
||||
return avals
|
||||
return [mesh.global_to_local(get_array_mapping(aval_axes), aval)
|
||||
for aval, aval_axes in zip(avals, axes)]
|
||||
if isinstance(positional_semantics, maps._PositionalSemantics):
|
||||
positional_semantics = [positional_semantics] * len(axes)
|
||||
return [
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL else mesh.global_to_local(
|
||||
get_array_mapping(aval_axes), aval)
|
||||
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
|
||||
]
|
||||
|
||||
def local_to_global(positional_semantics, mesh, avals, axes):
|
||||
if positional_semantics == maps._PositionalSemantics.GLOBAL:
|
||||
return avals
|
||||
return [mesh.local_to_global(get_array_mapping(aval_axes), aval)
|
||||
for aval, aval_axes in zip(avals, axes)]
|
||||
return [
|
||||
aval if ps == maps._PositionalSemantics.GLOBAL else mesh.local_to_global(
|
||||
get_array_mapping(aval_axes), aval)
|
||||
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
|
||||
]
|
||||
|
||||
def _canonicalize_spec(in_axis_resources_flat: ParsedPartitionSpec, arg):
|
||||
if isinstance(arg, GSDA):
|
||||
gsda_ppspec = gsda_mesh_axes_to_parsed_pspec(arg._mesh_axes)
|
||||
if in_axis_resources_flat is not FROM_GSDA and in_axis_resources_flat != gsda_ppspec:
|
||||
raise ValueError(
|
||||
'Got an input GSDA to pjit with different partitioning than specified in '
|
||||
'the in_axis_resources argument to pjit. The paritioning must match, or '
|
||||
'use `jax.experimental.pjit.FROM_GSDA` in `in_axis_resources`. '
|
||||
f'Got GSDA spec: {gsda_ppspec}, pjit spec: {in_axis_resources_flat}')
|
||||
return gsda_ppspec
|
||||
return in_axis_resources_flat
|
||||
|
||||
def gsda_mesh_axes_to_parsed_pspec(mesh_axes) -> ParsedPartitionSpec:
|
||||
if not isinstance(mesh_axes, PartitionSpec):
|
||||
pspec = PartitionSpec(*mesh_axes)
|
||||
else:
|
||||
pspec = mesh_axes
|
||||
return ParsedPartitionSpec.from_user_input(pspec, arg_name='GSDA mesh_axes')
|
||||
|
||||
# -------------------- XLA OpSharding to PartitionSpec --------------------
|
||||
# Note that OpSharding is more expressive than PartitionSpecs, so it's not
|
||||
|
@ -33,7 +33,8 @@ from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.maps import xmap, mesh, Mesh
|
||||
from jax.experimental import gsda
|
||||
import jax.experimental.pjit as pjit_lib
|
||||
from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
|
||||
from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
|
||||
SpecSync, FROM_GSDA)
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax._src.lib import xla_client
|
||||
@ -609,7 +610,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with jax._src.config.gsda_out(True):
|
||||
@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
|
||||
@partial(pjit, in_axis_resources=FROM_GSDA, out_axis_resources=P('x', 'y'))
|
||||
def f(x):
|
||||
return x @ x.T
|
||||
expected_matrix_mul = input_data @ input_data.T
|
||||
@ -622,12 +623,11 @@ class GSDAPjitTest(jtu.JaxTestCase):
|
||||
for s in out.local_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
out1 = f(input_data)
|
||||
self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
|
||||
self.assertEqual(out1.shape, (8, 8))
|
||||
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
|
||||
for s in out.local_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
('For a non-GSDA input, the corresponding resource in '
|
||||
'in_axis_resources cannot be `pjit.FROM_GSDA`.')):
|
||||
f(input_data)
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_pjit_gsda_multi_input_multi_output(self):
|
||||
@ -654,7 +654,8 @@ class GSDAPjitTest(jtu.JaxTestCase):
|
||||
with jax._src.config.gsda_out(True):
|
||||
@partial(
|
||||
pjit,
|
||||
in_axis_resources=(mesh_axes1, mesh_axes2, mesh_axes3, mesh_axes4),
|
||||
# `FROM_GSDA` will be replicated for all the inputs.
|
||||
in_axis_resources=FROM_GSDA,
|
||||
out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
|
||||
def f(x, y, z, a):
|
||||
return x @ x.T, y, z, a
|
||||
@ -701,6 +702,41 @@ class GSDAPjitTest(jtu.JaxTestCase):
|
||||
for s in out4.local_shards:
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_pjit_gsda_mixed_inputs(self):
|
||||
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return input_data[index]
|
||||
|
||||
gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
with jax._src.config.gsda_out(True):
|
||||
@partial(pjit,
|
||||
in_axis_resources=(FROM_GSDA, P('x', 'y')),
|
||||
out_axis_resources=(P('x', 'y'), P(('x', 'y'))))
|
||||
def f(x, y):
|
||||
return x @ x.T, y @ y.T
|
||||
expected_matrix_mul = input_data @ input_data.T
|
||||
|
||||
out1, out2 = f(gsda_obj, input_data)
|
||||
self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
|
||||
self.assertEqual(out1.shape, (8, 8))
|
||||
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
|
||||
self.assertDictEqual(out1._global_mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out1.local_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
self.assertIsInstance(out2, gsda.GlobalShardedDeviceArray)
|
||||
self.assertEqual(out2.shape, (8, 8))
|
||||
self.assertEqual(out2.local_shards[0].data.shape, (1, 8))
|
||||
self.assertDictEqual(out2._global_mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out2.local_shards:
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def test_pjit_gsda_mesh_mismatch(self):
|
||||
@ -718,11 +754,34 @@ class GSDAPjitTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Pjit's mesh and GSDA's mesh should be equal."):
|
||||
@partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
|
||||
@partial(pjit, in_axis_resources=FROM_GSDA, out_axis_resources=P('x', 'y'))
|
||||
def f(x):
|
||||
return x
|
||||
f(gsda_obj)
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_pjit_gsda_wrong_resource_for_gsda_input(self):
|
||||
global_mesh = create_global_mesh((4, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = ['x']
|
||||
global_input_data = np.arange(
|
||||
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
|
||||
def cb(index):
|
||||
return global_input_data[index]
|
||||
|
||||
gsda_obj = gsda.GlobalShardedDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, cb)
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError,
|
||||
("Got an input GSDA to pjit with different partitioning than specified "
|
||||
"in the in_axis_resources argument to pjit. The paritioning must "
|
||||
"match, or use `jax.experimental.pjit.FROM_GSDA` in `in_axis_resources`. "
|
||||
"Got GSDA spec: <partitions=(('x',),) sync=2>, "
|
||||
"pjit spec: <partitions=(('x',), ('y',)) sync=2>")):
|
||||
@partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
|
||||
def f(x):
|
||||
return x
|
||||
f(gsda_obj)
|
||||
|
||||
|
||||
def spec_regex(s):
|
||||
|
Loading…
x
Reference in New Issue
Block a user