in_positional_semantics should be set to GLOBAL for fully replicated values to avoid recompilation.

Split _pjit_jaxpr into 2 functions so that passing `is_gda` as an argument to _pjit_jaxpr can be avoided, which was leading to cache invalidation.

PiperOrigin-RevId: 434825926
This commit is contained in:
Yash Katariya 2022-03-15 12:31:51 -07:00 committed by jax authors
parent 4848c75b88
commit 846c480fa9
4 changed files with 102 additions and 51 deletions

View File

@ -735,7 +735,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
av if ips == _PositionalSemantics.GLOBAL else mesh._local_to_global(ax, av)
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
]
in_is_gda = [ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics]
in_is_global = [ips == _PositionalSemantics.GLOBAL or not ia
for ips, ia in safe_zip(in_positional_semantics, mesh_in_axes)]
tiling_method: pxla.TilingMethod
if config.experimental_xmap_spmd_lowering_manual:
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
@ -746,7 +747,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
f, 'xmap', name, mesh,
mesh_in_axes, mesh_out_axes, donated_invars,
use_spmd_lowering, global_in_avals,
tiling_method=tiling_method, in_is_gda=in_is_gda)
tiling_method=tiling_method, in_is_global=in_is_global)
else:
return dispatch.lower_xla_callable(
f, None, backend, name, donated_invars, *((a, None) for a in in_avals))

View File

@ -238,19 +238,21 @@ def pjit(fun: Callable,
maps._PositionalSemantics.GLOBAL if isinstance(a, GDA) else maps._positional_semantics.val
for a in args_flat)
out_positional_semantics = maps._positional_semantics.val
jaxpr, in_axis_resources_flat, out_axis_resources_flat = _pjit_jaxpr(
flat_fun, mesh, local_in_avals, in_tree,
hashable_pytree(in_axis_resources),
HashableFunction(out_tree, closure=()),
hashable_pytree(out_axis_resources),
in_positional_semantics, out_positional_semantics,
tuple(isinstance(a, GDA) for a in args_flat))
in_axis_resources_flat = tree_map(_maybe_replace_from_gda_with_pspec,
in_axis_resources_flat, tuple(args_flat))
global_in_avals, canonicalized_in_axis_resources_flat = _process_in_axis_resources(
mesh, local_in_avals, hashable_pytree(in_axis_resources), in_tree,
in_positional_semantics, tuple(isinstance(a, GDA) for a in args_flat))
jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
flat_fun, mesh, global_in_avals, HashableFunction(out_tree, closure=()),
hashable_pytree(out_axis_resources))
canonicalized_in_axis_resources_flat = tree_map(
_maybe_replace_from_gda_with_pspec,
canonicalized_in_axis_resources_flat, tuple(args_flat))
params = dict(
jaxpr=jaxpr,
in_axis_resources=in_axis_resources_flat,
out_axis_resources=out_axis_resources_flat,
in_axis_resources=canonicalized_in_axis_resources_flat,
out_axis_resources=canonicalized_out_axis_resources_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unnamed function>'),
@ -270,11 +272,13 @@ def pjit(fun: Callable,
def lower(*args, **kwargs):
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params(*args, **kwargs)
in_is_global = _calc_is_global_sequence(
params['in_positional_semantics'], params['in_axis_resources'])
lowering = _pjit_lower(
params['jaxpr'], params['in_axis_resources'],
params['out_axis_resources'], params['resource_env'],
params['donated_invars'], params['name'],
params['in_positional_semantics'], params['out_positional_semantics'])
in_is_global)
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
@ -352,12 +356,9 @@ class PytreeLeaf:
def __repr__(self): return "pytree leaf"
@lu.cache
def _pjit_jaxpr(fun, mesh, local_in_avals,
in_tree, in_axis_resources_thunk,
out_tree, out_axis_resources_thunk,
in_positional_semantics, out_positional_semantics, is_gda):
# TODO(yashkatariya): Make this work with FROM_GDA special value.
@cache()
def _process_in_axis_resources(mesh, local_in_avals, in_axis_resources_thunk,
in_tree, in_positional_semantics, is_gda):
in_axis_resources_flat = flatten_axis_resources(
"pjit in_axis_resources", in_tree,
in_axis_resources_thunk(), tupled_args=True)
@ -388,7 +389,11 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
global_in_avals = local_to_global(in_positional_semantics, mesh,
local_in_avals, canonicalized_in_axis_resources_flat)
return tuple(global_in_avals), canonicalized_in_axis_resources_flat
@lu.cache
def _pjit_jaxpr(fun, mesh, global_in_avals, out_tree, out_axis_resources_thunk):
prev_positional_val = maps._positional_semantics.val
try:
maps._positional_semantics.val = maps._PositionalSemantics.GLOBAL
@ -407,8 +412,7 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
allow_uneven_sharding=False)
canonicalized_out_axis_resources_flat = tree_map(_create_cpspec, out_axis_resources_flat)
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
return _ListWithW([jaxpr, canonicalized_in_axis_resources_flat,
canonicalized_out_axis_resources_flat])
return _ListWithW([jaxpr, canonicalized_out_axis_resources_flat])
class SpecSync(IntEnum):
@ -492,9 +496,6 @@ class ParsedPartitionSpec:
f"unsafe_user_spec={self.unsafe_user_spec}, "
f"sync={self.sync})")
REPLICATED = ParsedPartitionSpec(None, ())
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
"""ParsedPartitionSpecs that are canonicalized.
@ -524,6 +525,9 @@ class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
f"sync={self.sync})")
REPLICATED = CanonicalizedParsedPartitionSpec(ParsedPartitionSpec(None, ()))
def _prepare_axis_resources(axis_resources,
arg_name,
allow_unconstrained_dims=False):
@ -595,10 +599,10 @@ def _pjit_call_impl(*args, jaxpr,
in_axis_resources, out_axis_resources,
resource_env, donated_invars, name,
in_positional_semantics, out_positional_semantics):
in_is_global = _calc_is_global_sequence(in_positional_semantics, in_axis_resources)
compiled = _pjit_lower(
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics).compile()
resource_env, donated_invars, name, in_is_global).compile()
distributed_debug_log(("Running pjit'd function", name),
("mesh", resource_env.physical_mesh))
return compiled.unsafe_call(*args)
@ -612,7 +616,7 @@ def _pjit_lower(
resource_env,
donated_invars,
name: str,
in_positional_semantics, out_positional_semantics):
in_is_global: Sequence[bool]):
# in_axis_resources and out_axis_resources are canonicalized to avoid
# recompilation (since pjit_lower is cached) if its compiled with `None` but
# in the next call `P(None)` is passed. Those are the same thing so should be
@ -623,12 +627,10 @@ def _pjit_lower(
f = core.jaxpr_as_fun(jaxpr)
f.__name__ = name
fun = lu.wrap_init(f)
in_is_gda = [ips == maps._PositionalSemantics.GLOBAL
for ips in in_positional_semantics]
return pxla.lower_mesh_computation(
fun, 'pjit', name, resource_env.physical_mesh,
in_axes, out_axes, donated_invars,
True, jaxpr.in_avals, tiling_method=None, in_is_gda=in_is_gda)
True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env,
@ -782,8 +784,14 @@ def _pjit_partial_eval(trace, *in_tracers,
out_positional_semantics=out_positional_semantics)
if num_residuals:
executable = _pjit_lower(**known_params).compile(
_allow_propagation_to_outputs=True, _allow_compile_replicated=False)
in_is_global = _calc_is_global_sequence(
known_params['in_positional_semantics'], known_params['in_axis_resources'])
executable = _pjit_lower(
known_params["jaxpr"], known_params["in_axis_resources"],
known_params["out_axis_resources"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global).compile(_allow_propagation_to_outputs=True,
_allow_compile_replicated=False)
output_op_sharding = \
executable.xla_executable.hlo_modules()[0].spmd_output_sharding
output_sharding_specs = parse_op_sharding(output_op_sharding, mesh)
@ -1053,6 +1061,13 @@ def local_to_global(positional_semantics, mesh, avals, axes):
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
]
def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
  """Returns a per-argument tuple of bools: True when the argument's aval
  should be treated as global.

  An argument is global either because its positional semantics are GLOBAL
  (e.g. it is a GDA) or because it is fully replicated (empty partitions),
  in which case local and global views coincide.
  """
  flags = []
  for semantics, parsed_spec in safe_zip(in_positional_semantics,
                                         in_axis_resources):
    is_replicated = parsed_spec.partitions == ()
    flags.append(semantics == maps._PositionalSemantics.GLOBAL or is_replicated)
  return tuple(flags)
def _create_cpspec(x):
  """Canonicalizes a parsed partition spec, passing the FROM_GDA sentinel
  through unchanged."""
  if _is_from_gda(x):
    return x
  return CanonicalizedParsedPartitionSpec(x)

View File

@ -2140,7 +2140,7 @@ def lower_mesh_computation(
spmd_lowering: bool,
global_in_avals: Sequence[core.ShapedArray],
tiling_method: Optional[TilingMethod],
in_is_gda: Sequence[bool]):
in_is_global: Sequence[bool]):
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = new_name_stack(wrap_name(fun_name, api_name))
@ -2236,7 +2236,7 @@ def lower_mesh_computation(
return MeshComputation(
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
global_out_avals=global_out_avals, in_axes=in_axes, out_axes=out_axes,
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_gda=in_is_gda)
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_global=in_is_global)
class MeshComputation:
@ -2277,13 +2277,13 @@ class MeshComputation:
return self._executable
def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_global):
input_specs, input_indices, input_avals = [], [], []
num_local_devices = len(global_mesh.local_devices)
for gaval, axis, is_gda in safe_zip(global_in_avals, in_axes, in_is_gda):
for gaval, axis, is_global in safe_zip(global_in_avals, in_axes, in_is_global):
# TODO(yashkatariya): Don't calculate input_indices and input_specs for GDA
# as GDA doesn't need it.
if is_gda or not axis:
if is_global:
aval = gaval
mesh = global_mesh
else:
@ -2292,9 +2292,11 @@ def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
spec = (mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, axis)
if aval is not core.abstract_unit else None)
# We special case this logic to support fully replicated non-GDA values
# with non-contiguous submeshes
if not axis and not is_gda:
# We special case this logic to support fully replicated values because
# the mesh is global mesh and the indices returned by `spec_to_indices` will
# represent index for each device in the global mesh. But here we want
# indices for the local devices of the global mesh.
if not axis:
index = tuple((slice(None),) * aval.ndim for _ in range(num_local_devices))
else:
index = spec_to_indices(aval.shape, spec) if spec is not None else None
@ -2323,7 +2325,7 @@ class MeshExecutable:
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
spmd_lowering: bool, tuple_args: bool,
in_is_gda: Sequence[bool],
in_is_global: Sequence[bool],
_allow_propagation_to_outputs: bool,
_allow_compile_replicated: bool) -> 'MeshExecutable':
assert not mesh.empty
@ -2345,7 +2347,7 @@ class MeshExecutable:
_allow_propagation_to_outputs
input_specs, input_indices, input_avals = _get_input_metadata(
global_in_avals, mesh, in_axes, in_is_gda)
global_in_avals, mesh, in_axes, in_is_global)
# Calculate local information here instead of calculating it in
# `avals_to_results_handler` because pmap also uses this function.
handle_outs = global_avals_to_results_handler(global_out_avals, out_axes, mesh)

View File

@ -1078,20 +1078,32 @@ class GDAPjitTest(jtu.JaxTestCase):
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
input_shape, global_mesh, mesh_axes, cb)
trace_counter = [0]
@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
def f(x, y):
trace_counter[0] += 1
return x @ y.T
before_lower_cache = pjit_lib._pjit_lower.cache_info()
f(gda_obj, gda_obj)
self.assertListEqual(trace_counter, [1])
after_lower_cache1 = pjit_lib._pjit_lower.cache_info()
self.assertEqual(before_lower_cache.hits, after_lower_cache1.hits)
self.assertEqual(before_lower_cache.misses + 1, after_lower_cache1.misses)
f(gda_obj, gda_obj)
self.assertListEqual(trace_counter, [1])
after_lower_cache2 = pjit_lib._pjit_lower.cache_info()
self.assertEqual(after_lower_cache1.hits + 1, after_lower_cache2.hits)
self.assertEqual(after_lower_cache1.misses, after_lower_cache2.misses)
f(input_data, input_data)
self.assertListEqual(trace_counter, [2])
after_lower_cache3 = pjit_lib._pjit_lower.cache_info()
self.assertEqual(after_lower_cache2.hits, after_lower_cache3.hits)
self.assertEqual(after_lower_cache2.misses + 1, after_lower_cache3.misses)
f(gda_obj, input_data)
self.assertListEqual(trace_counter, [3])
after_lower_cache4 = pjit_lib._pjit_lower.cache_info()
self.assertEqual(after_lower_cache3.hits, after_lower_cache4.hits)
self.assertEqual(after_lower_cache3.misses + 1, after_lower_cache4.misses)
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_partition_spec_mismatch_semantically_equivalent(self):
@ -1143,7 +1155,7 @@ class GDAPjitTest(jtu.JaxTestCase):
def f(x):
return x
with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
with global_mesh:
out_gda = f(input_gda)
self.assertEqual(out_gda.mesh_axes, ())
@ -1151,7 +1163,28 @@ class GDAPjitTest(jtu.JaxTestCase):
f(out_gda)
after_cache = pjit_lib._pjit_lower.cache_info()
self.assertNotEqual(id(before_cache), id(after_cache))
self.assertEqual(before_cache.hits + 1, after_cache.hits)
self.assertEqual(before_cache.misses, after_cache.misses)
def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
  # Regression test: when in_axis_resources is fully replicated (P(None)),
  # calling the pjit'd function first with a plain ndarray and then with the
  # GDA it produced must reuse the same _pjit_lower cache entry — i.e. the
  # second call is a cache hit, not a new miss/compile.
  # NOTE(review): indentation reconstructed from a whitespace-stripped diff
  # view — the `with` nesting below should be confirmed against the repo.
  global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
  global_input_shape = (8, 2)
  mesh_axes = P(None)  # fully replicated spec
  global_data = np.arange(
      prod(global_input_shape)).reshape(global_input_shape)
  with jax._src.config.parallel_functions_output_gda(True):
    f = pjit(lambda x: x, in_axis_resources=mesh_axes,
             out_axis_resources=mesh_axes)
    with global_mesh:
      # First call: ndarray input, GDA output (due to the config flag above).
      out_gda = f(global_data)
      self.assertEqual(out_gda.mesh_axes, ())
      before_cache = pjit_lib._pjit_lower.cache_info()
      # Second call: GDA input — must hit the cache, not recompile.
      f(out_gda)
      after_cache = pjit_lib._pjit_lower.cache_info()
      self.assertEqual(before_cache.hits + 1, after_cache.hits)
      self.assertEqual(before_cache.misses, after_cache.misses)