Adding support for a special value for in_axis_resources (pjit.FROM_GSDA) when GSDA is an input.

PiperOrigin-RevId: 411148899
This commit is contained in:
Yash Katariya 2021-11-19 14:47:32 -08:00 committed by jax authors
parent 34a2ffcfb6
commit d1de309410
3 changed files with 186 additions and 49 deletions

View File

@ -2524,7 +2524,8 @@ def _pjit(*args: TfVal,
resource_env: maps.ResourceEnv,
donated_invars,
name: str,
positional_semantics,
in_positional_semantics,
out_positional_semantics,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray) -> TfVal:
del donated_invars

View File

@ -45,6 +45,11 @@ from .._src.util import (extend_name_stack, HashableFunction, safe_zip,
split_list, cache, tuple_insert)
xops = xc._xla.ops
class _FromGsdaSingleton:
pass
FROM_GSDA = _FromGsdaSingleton()
# TODO(yashkatariya): Add pjit microbenchmarks.
def pjit(fun: Callable,
in_axis_resources,
out_axis_resources,
@ -211,16 +216,22 @@ def pjit(fun: Callable,
donated_invars = (False,) * len(args_flat)
local_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
# TODO(yashkatariya): Remove `is_gsda` check when special value for in_axis_resources
# is added for GSDA.
is_gsda = all(isinstance(a, GSDA) for a in args_flat)
# TODO(yashkatariya): This is a hack. This should go away when avals have
# is_global attribute.
in_positional_semantics = tuple(
maps._PositionalSemantics.GLOBAL
if type(a) is GSDA else maps._positional_semantics
for a in args_flat)
out_positional_semantics = maps._positional_semantics
jaxpr, in_axis_resources_flat, out_axis_resources_flat = _pjit_jaxpr(
flat_fun, mesh, local_in_avals, in_tree,
hashable_pytree(in_axis_resources),
HashableFunction(out_tree, closure=()),
hashable_pytree(out_axis_resources),
(maps._PositionalSemantics.GLOBAL
if is_gsda else maps._positional_semantics))
in_positional_semantics, out_positional_semantics,
tuple(isinstance(a, GSDA) for a in args_flat))
in_axis_resources_flat = tree_map(_canonicalize_spec, in_axis_resources_flat,
tuple(args_flat))
params = dict(
jaxpr=jaxpr,
in_axis_resources=in_axis_resources_flat,
@ -228,7 +239,8 @@ def pjit(fun: Callable,
resource_env=resource_env,
donated_invars=donated_invars,
name=flat_fun.__name__,
positional_semantics=maps._positional_semantics)
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics)
return args_flat, params, in_tree, out_tree(), donate_argnums
@wraps(fun)
@ -245,7 +257,8 @@ def pjit(fun: Callable,
lowering = _pjit_lower(
params['jaxpr'], params['in_axis_resources'],
params['out_axis_resources'], params['resource_env'],
params['donated_invars'], params['name'], maps._positional_semantics)
params['donated_invars'], params['name'],
params['in_positional_semantics'], params['out_positional_semantics'])
return Lowered(lowering, in_tree, out_tree, donate_argnums, no_kwargs=True)
wrapped.lower = lower
@ -274,13 +287,19 @@ def flatten_axis_resources(what, tree, axis_resources, tupled_args):
def _pjit_jaxpr(fun, mesh, local_in_avals,
in_tree, in_axis_resources_thunk,
out_tree, out_axis_resources_thunk,
positional_semantics):
in_positional_semantics, out_positional_semantics, is_gsda):
# TODO(yashkatariya): Make this work with FROM_GSDA special value.
in_axis_resources_flat = flatten_axis_resources(
"pjit in_axis_resources", in_tree,
in_axis_resources_thunk(), tupled_args=True)
# This check must come before the local_to_global call below. Otherwise, if
# `FROM_GSDA` is passed for any input that is not a GSDA, an ugly error
# message is raised, because get_array_mapping (called in local_to_global)
# cannot be applied to FROM_GSDA.
tree_map(_check_resources_mismatch, in_axis_resources_flat, is_gsda)
_check_shapes_against_resources("pjit arguments", False, mesh.local_mesh.shape,
local_in_avals, in_axis_resources_flat)
global_in_avals = local_to_global(positional_semantics, mesh,
global_in_avals = local_to_global(in_positional_semantics, mesh,
local_in_avals, in_axis_resources_flat)
prev_positional = maps._positional_semantics
@ -383,13 +402,22 @@ def _prepare_axis_resources(axis_resources, arg_name):
# to explicitly declare them as such
entries, treedef = tree_flatten(axis_resources, is_leaf=lambda x: x is None)
what = f"{arg_name} leaf specifications"
entries = [ParsedPartitionSpec.from_user_input(entry, what) for entry in entries]
entries = [
entry if entry is FROM_GSDA else ParsedPartitionSpec.from_user_input(
entry, what) for entry in entries
]
_check_unique_resources(entries, arg_name)
return tree_unflatten(treedef, entries), entries, treedef
def _check_resources_mismatch(in_axis_resources_flat, is_gsda):
  """Reject `FROM_GSDA` as the resource spec of an argument that is not a GSDA.

  Args:
    in_axis_resources_flat: the resource spec of a single flattened argument.
    is_gsda: whether that argument is a GSDA.

  Raises:
    ValueError: if the argument is not a GSDA but its spec is `FROM_GSDA`.
  """
  if is_gsda:
    # Any spec, including FROM_GSDA, is acceptable here for a GSDA argument.
    return
  if in_axis_resources_flat is FROM_GSDA:
    raise ValueError('For a non-GSDA input, the corresponding resource in '
                     'in_axis_resources cannot be `pjit.FROM_GSDA`.')
def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
if arg_axis_resources is FROM_GSDA: continue
resource_counts = Counter(it.chain.from_iterable(arg_axis_resources))
if not resource_counts: continue
if resource_counts.most_common(1)[0][1] > 1:
@ -402,6 +430,8 @@ def _check_unique_resources(axis_resources, arg_name):
def _check_shapes_against_resources(what: str, is_global_shape: bool, mesh_shape, flat_avals, flat_axis_resources):
global_str = " global" if is_global_shape else ""
for aval, aval_axis_resources in zip(flat_avals, flat_axis_resources):
if aval_axis_resources is FROM_GSDA:
continue
shape = aval.shape
if len(shape) < len(aval_axis_resources):
raise ValueError(f"One of {what} was given the resource assignment "
@ -430,10 +460,11 @@ pjit_p.multiple_results = True
# Lowers `jaxpr` with `_pjit_lower`, compiles the result, logs the call via
# `distributed_debug_log` (function name and physical mesh), and returns
# `compiled.unsafe_call(*args)`.
# NOTE(review): this hunk lists both sides of the change without +/- markers.
# The lines using a single `positional_semantics` appear to be the removed
# form; the lines using `in_positional_semantics` / `out_positional_semantics`
# appear to be their replacements.
def _pjit_call_impl(*args, jaxpr,
in_axis_resources, out_axis_resources,
resource_env, donated_invars, name,
positional_semantics):
in_positional_semantics, out_positional_semantics):
compiled = _pjit_lower(
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, positional_semantics).compile()
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics).compile()
distributed_debug_log(("Running pjit'd function", name),
("mesh", resource_env.physical_mesh))
return compiled.unsafe_call(*args)
@ -447,30 +478,36 @@ def _pjit_lower(
resource_env,
donated_invars,
name: str,
positional_semantics):
in_positional_semantics, out_positional_semantics):
in_axes = [get_array_mapping(axes) for axes in in_axis_resources]
out_axes = [get_array_mapping(axes) for axes in out_axis_resources]
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
f = core.jaxpr_as_fun(jaxpr)
f.__name__ = name
fun = lu.wrap_init(f)
local_in_avals = global_to_local(positional_semantics, resource_env.physical_mesh,
# TODO(yashkatariya): Passing in out_positional_semantics is a hack.
# This logic should be replaced once an `is_global` attribute exists on avals.
local_in_avals = global_to_local(out_positional_semantics, resource_env.physical_mesh,
jaxpr.in_avals, in_axis_resources)
# TODO(yashkatariya): Passing positional_semantics is a hack. This should go
# away when avals have is_global attribute.
return pxla.lower_mesh_computation(
fun, name, resource_env.physical_mesh,
in_axes, out_axes, donated_invars,
True, local_in_avals, tile_by_mesh_axes=False)
# Abstract evaluation for pjit: returns `jaxpr.out_avals` passed through
# `global_to_local` with the physical mesh and `out_axis_resources`. All other
# keyword parameters are accepted and ignored via `**_`.
# NOTE(review): this hunk lists both sides of the change without +/- markers.
# The first `def`/`return` pair (using `positional_semantics`) appears to be
# the removed form; the second pair (using `out_positional_semantics`) appears
# to be its replacement.
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env, positional_semantics, **_):
return global_to_local(positional_semantics, resource_env.physical_mesh,
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env,
out_positional_semantics, **_):
return global_to_local(out_positional_semantics, resource_env.physical_mesh,
jaxpr.out_avals, out_axis_resources)
pjit_p.def_abstract_eval(_pjit_abstract_eval)
def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, positional_semantics):
resource_env, donated_invars, in_positional_semantics,
out_positional_semantics):
mesh = resource_env.physical_mesh
subc = xc.XlaBuilder(f"pjit_{name}")
@ -503,7 +540,8 @@ def _pjit_batcher(insert_axis,
axis_size, axis_name, main_type,
vals_in, dims_in,
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, positional_semantics):
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics):
# batch_jaxpr expects all batching dimensions to be equal to 0
vals_in = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(vals_in, dims_in)]
@ -527,7 +565,8 @@ def _pjit_batcher(insert_axis,
resource_env=resource_env,
donated_invars=donated_invars,
name=name,
positional_semantics=positional_semantics)
in_positional_semantics=in_positional_semantics,
out_positional_semantics=out_positional_semantics)
dims_out = [0 if batched else batching.not_mapped for batched in is_mapped_out]
return vals_out, dims_out
batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False)
@ -536,7 +575,8 @@ pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True)
def _pjit_jvp(primals_in, tangents_in,
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, positional_semantics):
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics):
is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
jaxpr, is_nz_tangents_in, instantiate=False)
@ -553,7 +593,8 @@ def _pjit_jvp(primals_in, tangents_in,
resource_env=resource_env,
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
name=wrap_name(name, 'jvp'),
positional_semantics=positional_semantics)
in_positional_semantics=(*in_positional_semantics, *_filter_zeros_in(in_positional_semantics)),
out_positional_semantics=out_positional_semantics)
primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
assert len(primals_out) == len(jaxpr.jaxpr.outvars)
@ -565,7 +606,8 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, positional_semantics):
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics):
# XXX: At the moment all residuals get fully replicated, which is extremely
# wasteful and might quickly lead to OOM errors.
mesh = resource_env.physical_mesh
@ -597,7 +639,8 @@ def _pjit_partial_eval(trace, *in_tracers,
resource_env=resource_env,
donated_invars=keep_where(donated_invars, known_ins),
name=name,
positional_semantics=positional_semantics)
in_positional_semantics=keep_where(in_positional_semantics, known_ins),
out_positional_semantics=out_positional_semantics)
if num_residuals:
executable = _pjit_lower(**known_params).compile(
@ -636,11 +679,16 @@ def _pjit_partial_eval(trace, *in_tracers,
donated_invars=(keep_where(donated_invars, unknown_ins) +
(False,) * num_residuals),
name=name,
positional_semantics=positional_semantics)
in_positional_semantics=(keep_where(
in_positional_semantics, unknown_ins) + (out_positional_semantics,) * num_residuals),
out_positional_semantics=out_positional_semantics)
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
unknown_tracers_out = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
for aval in global_to_local(positional_semantics, mesh, unknown_jaxpr.out_avals,
unknown_params['out_axis_resources'])]
unknown_tracers_out = [
pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
for aval in global_to_local(unknown_params["out_positional_semantics"],
mesh, unknown_jaxpr.out_avals,
unknown_params["out_axis_resources"])
]
eqn = pe.new_eqn_recipe((*unknown_tracers_in, *residual_tracers),
unknown_tracers_out,
pjit_p,
@ -653,7 +701,8 @@ pe.custom_partial_eval_rules[pjit_p] = _pjit_partial_eval
def _pjit_transpose(reduce_axes, cts_in, *primals_in,
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, positional_semantics):
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics):
mesh = resource_env.physical_mesh
def prune_type(ty, xs, maybe_zeros):
@ -668,8 +717,12 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
*prune_type(ad.UndefinedPrimal, in_axis_resources, primals_in),
*prune_type(ad.Zero, out_axis_resources, cts_in)
)
transpose_in_positional_semantics = (
*prune_type(ad.UndefinedPrimal, in_positional_semantics, primals_in),
*prune_type(ad.Zero, (out_positional_semantics,) * len(cts_in), cts_in)
)
global_cts_in_avals = local_to_global(
positional_semantics,
transpose_in_positional_semantics,
mesh,
[core.raise_to_shaped(core.get_aval(ct)) for ct in primals_and_nz_cts_in],
transpose_in_axis_resources)
@ -692,7 +745,8 @@ def _pjit_transpose(reduce_axes, cts_in, *primals_in,
resource_env=resource_env,
donated_invars=(False,) * len(primals_and_nz_cts_in),
name=name,
positional_semantics=positional_semantics)
in_positional_semantics=transpose_in_positional_semantics,
out_positional_semantics=out_positional_semantics)
return tree_unflatten(cts_out_treedef, nz_cts_out)
ad.reducing_transposes[pjit_p] = _pjit_transpose
@ -815,16 +869,39 @@ def get_aval_sharding_proto(aval: core.AbstractValue,
return sharding_spec.sharding_proto()
# Returns one aval per input: an aval whose positional semantics is GLOBAL is
# passed through unchanged; every other aval is converted with
# `mesh.global_to_local` using the array mapping of its axis resources.
# NOTE(review): this hunk lists both sides of the change without +/- markers.
# The first four body lines (early `return avals` plus a plain `zip`
# comprehension) appear to be the removed implementation; the remaining lines
# appear to be the replacement, which decides per aval.
def global_to_local(positional_semantics, mesh, avals, axes):
if positional_semantics == maps._PositionalSemantics.GLOBAL:
return avals
return [mesh.global_to_local(get_array_mapping(aval_axes), aval)
for aval, aval_axes in zip(avals, axes)]
# A single _PositionalSemantics value is expanded to one entry per element of
# `axes`, so the same setting applies to every aval.
if isinstance(positional_semantics, maps._PositionalSemantics):
positional_semantics = [positional_semantics] * len(axes)
return [
aval if ps == maps._PositionalSemantics.GLOBAL else mesh.global_to_local(
get_array_mapping(aval_axes), aval)
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
]
def local_to_global(positional_semantics, mesh, avals, axes):
  """Convert avals with `mesh.local_to_global`, skipping those marked GLOBAL.

  Args:
    positional_semantics: either one `maps._PositionalSemantics` value, which
      is applied to every aval, or a sequence with one entry per aval.
    mesh: object providing `local_to_global(array_mapping, aval)`.
    avals: the abstract values to convert, one per argument.
    axes: axis resources, one per aval; each is turned into an array mapping
      with `get_array_mapping`.

  Returns:
    A list in input order: the aval itself where its positional semantics is
    GLOBAL, otherwise the result of `mesh.local_to_global`.
  """
  # Accept a bare enum value, as `global_to_local` does, so callers that pass
  # one setting for all arguments keep working instead of failing in safe_zip.
  if isinstance(positional_semantics, maps._PositionalSemantics):
    positional_semantics = [positional_semantics] * len(axes)
  return [
      aval if ps == maps._PositionalSemantics.GLOBAL else mesh.local_to_global(
          get_array_mapping(aval_axes), aval)
      for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
  ]
def _canonicalize_spec(in_axis_resources_flat: ParsedPartitionSpec, arg):
  """Return the parsed spec to use for `arg`.

  For a non-GSDA argument the given spec is returned unchanged. For a GSDA
  argument the spec derived from the array's own `_mesh_axes` is returned,
  after checking that a user-provided spec (anything other than `FROM_GSDA`)
  agrees with it.

  Raises:
    ValueError: if `arg` is a GSDA and the given spec is neither `FROM_GSDA`
      nor equal to the spec derived from the GSDA.
  """
  if not isinstance(arg, GSDA):
    return in_axis_resources_flat

  gsda_ppspec = gsda_mesh_axes_to_parsed_pspec(arg._mesh_axes)
  if in_axis_resources_flat is not FROM_GSDA:
    # An explicit spec was supplied; it has to agree with the GSDA's own.
    if in_axis_resources_flat != gsda_ppspec:
      raise ValueError(
          'Got an input GSDA to pjit with different partitioning than specified in '
          'the in_axis_resources argument to pjit. The paritioning must match, or '
          'use `jax.experimental.pjit.FROM_GSDA` in `in_axis_resources`. '
          f'Got GSDA spec: {gsda_ppspec}, pjit spec: {in_axis_resources_flat}')
  return gsda_ppspec
def gsda_mesh_axes_to_parsed_pspec(mesh_axes) -> ParsedPartitionSpec:
  """Build a `ParsedPartitionSpec` from a GSDA's mesh axes.

  `mesh_axes` is used as-is when it already is a `PartitionSpec`; otherwise
  its elements are unpacked into a new `PartitionSpec` first.
  """
  pspec = (mesh_axes if isinstance(mesh_axes, PartitionSpec)
           else PartitionSpec(*mesh_axes))
  return ParsedPartitionSpec.from_user_input(pspec, arg_name='GSDA mesh_axes')
# -------------------- XLA OpSharding to PartitionSpec --------------------
# Note that OpSharding is more expressive than PartitionSpecs, so it's not

View File

@ -33,7 +33,8 @@ from jax.experimental import PartitionSpec as P
from jax.experimental.maps import xmap, mesh, Mesh
from jax.experimental import gsda
import jax.experimental.pjit as pjit_lib
from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
SpecSync, FROM_GSDA)
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src.lib import xla_client
@ -609,7 +610,7 @@ class GSDAPjitTest(jtu.JaxTestCase):
global_input_shape, global_mesh, mesh_axes, cb)
with jax._src.config.gsda_out(True):
@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
@partial(pjit, in_axis_resources=FROM_GSDA, out_axis_resources=P('x', 'y'))
def f(x):
return x @ x.T
expected_matrix_mul = input_data @ input_data.T
@ -622,12 +623,11 @@ class GSDAPjitTest(jtu.JaxTestCase):
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
out1 = f(input_data)
self.assertIsInstance(out1, gsda.GlobalShardedDeviceArray)
self.assertEqual(out1.shape, (8, 8))
self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
for s in out.local_shards:
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
with self.assertRaisesRegex(
ValueError,
('For a non-GSDA input, the corresponding resource in '
'in_axis_resources cannot be `pjit.FROM_GSDA`.')):
f(input_data)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_multi_input_multi_output(self):
@ -654,7 +654,8 @@ class GSDAPjitTest(jtu.JaxTestCase):
with jax._src.config.gsda_out(True):
@partial(
pjit,
in_axis_resources=(mesh_axes1, mesh_axes2, mesh_axes3, mesh_axes4),
# A single `FROM_GSDA` applies to every input.
in_axis_resources=FROM_GSDA,
out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
def f(x, y, z, a):
return x @ x.T, y, z, a
@ -701,6 +702,41 @@ class GSDAPjitTest(jtu.JaxTestCase):
for s in out4.local_shards:
self.assertArraysEqual(s.data, input_data[s.index])
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_mixed_inputs(self):
  # One pjit call receives a GSDA (spec taken from the array via FROM_GSDA)
  # together with a NumPy array (explicit PartitionSpec).
  shape = (8, 2)
  host_data = np.arange(prod(shape)).reshape(shape)
  gsda_arg = gsda.GlobalShardedDeviceArray.from_callback(
      shape, create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'),
      lambda index: host_data[index])

  with jax._src.config.gsda_out(True):
    @partial(pjit,
             in_axis_resources=(FROM_GSDA, P('x', 'y')),
             out_axis_resources=(P('x', 'y'), P(('x', 'y'))))
    def f(x, y):
      return x @ x.T, y @ y.T

    want = host_data @ host_data.T
    out1, out2 = f(gsda_arg, host_data)

    # Both outputs hold the same product; only the per-shard shape differs
    # because the two outputs use different out_axis_resources.
    for out, shard_shape in ((out1, (2, 4)), (out2, (1, 8))):
      self.assertIsInstance(out, gsda.GlobalShardedDeviceArray)
      self.assertEqual(out.shape, (8, 8))
      self.assertEqual(out.local_shards[0].data.shape, shard_shape)
      self.assertDictEqual(out._global_mesh.shape, {'x': 4, 'y': 2})
      for shard in out.local_shards:
        self.assertArraysEqual(shard.data, want[shard.index])
@jtu.with_mesh([('x', 2), ('y', 2)])
def test_pjit_gsda_mesh_mismatch(self):
@ -718,11 +754,34 @@ class GSDAPjitTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Pjit's mesh and GSDA's mesh should be equal."):
@partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
@partial(pjit, in_axis_resources=FROM_GSDA, out_axis_resources=P('x', 'y'))
def f(x):
return x
f(gsda_obj)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gsda_wrong_resource_for_gsda_input(self):
  # The GSDA is built with mesh axes ['x'] while pjit is given P('x', 'y'),
  # so the call must fail with the spec-mismatch error.
  shape = (8, 2)
  host_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  gsda_arg = gsda.GlobalShardedDeviceArray.from_callback(
      shape, create_global_mesh((4, 2), ('x', 'y')), ['x'],
      lambda index: host_data[index])

  expected_message = (
      "Got an input GSDA to pjit with different partitioning than specified "
      "in the in_axis_resources argument to pjit. The paritioning must "
      "match, or use `jax.experimental.pjit.FROM_GSDA` in `in_axis_resources`. "
      "Got GSDA spec: <partitions=(('x',),) sync=2>, "
      "pjit spec: <partitions=(('x',), ('y',)) sync=2>")

  with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
    @partial(pjit, in_axis_resources=P('x', 'y'), out_axis_resources=P('x', 'y'))
    def f(x):
      return x

    f(gsda_arg)
def spec_regex(s):