mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Set in_positional_semantics to GLOBAL for fully replicated values to avoid recompilation.
Split _pjit_jaxpr into two functions so that `is_gda` no longer has to be passed as an argument to _pjit_jaxpr, which was causing cache invalidation. PiperOrigin-RevId: 434825926
This commit is contained in:
parent
4848c75b88
commit
846c480fa9
@ -735,7 +735,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
av if ips == _PositionalSemantics.GLOBAL else mesh._local_to_global(ax, av)
|
||||
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
|
||||
]
|
||||
in_is_gda = [ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics]
|
||||
in_is_global = [ips == _PositionalSemantics.GLOBAL or not ia
|
||||
for ips, ia in safe_zip(in_positional_semantics, mesh_in_axes)]
|
||||
tiling_method: pxla.TilingMethod
|
||||
if config.experimental_xmap_spmd_lowering_manual:
|
||||
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
|
||||
@ -746,7 +747,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
f, 'xmap', name, mesh,
|
||||
mesh_in_axes, mesh_out_axes, donated_invars,
|
||||
use_spmd_lowering, global_in_avals,
|
||||
tiling_method=tiling_method, in_is_gda=in_is_gda)
|
||||
tiling_method=tiling_method, in_is_global=in_is_global)
|
||||
else:
|
||||
return dispatch.lower_xla_callable(
|
||||
f, None, backend, name, donated_invars, *((a, None) for a in in_avals))
|
||||
|
@ -238,19 +238,21 @@ def pjit(fun: Callable,
|
||||
maps._PositionalSemantics.GLOBAL if isinstance(a, GDA) else maps._positional_semantics.val
|
||||
for a in args_flat)
|
||||
out_positional_semantics = maps._positional_semantics.val
|
||||
jaxpr, in_axis_resources_flat, out_axis_resources_flat = _pjit_jaxpr(
|
||||
flat_fun, mesh, local_in_avals, in_tree,
|
||||
hashable_pytree(in_axis_resources),
|
||||
HashableFunction(out_tree, closure=()),
|
||||
hashable_pytree(out_axis_resources),
|
||||
in_positional_semantics, out_positional_semantics,
|
||||
tuple(isinstance(a, GDA) for a in args_flat))
|
||||
in_axis_resources_flat = tree_map(_maybe_replace_from_gda_with_pspec,
|
||||
in_axis_resources_flat, tuple(args_flat))
|
||||
|
||||
global_in_avals, canonicalized_in_axis_resources_flat = _process_in_axis_resources(
|
||||
mesh, local_in_avals, hashable_pytree(in_axis_resources), in_tree,
|
||||
in_positional_semantics, tuple(isinstance(a, GDA) for a in args_flat))
|
||||
jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
|
||||
flat_fun, mesh, global_in_avals, HashableFunction(out_tree, closure=()),
|
||||
hashable_pytree(out_axis_resources))
|
||||
canonicalized_in_axis_resources_flat = tree_map(
|
||||
_maybe_replace_from_gda_with_pspec,
|
||||
canonicalized_in_axis_resources_flat, tuple(args_flat))
|
||||
|
||||
params = dict(
|
||||
jaxpr=jaxpr,
|
||||
in_axis_resources=in_axis_resources_flat,
|
||||
out_axis_resources=out_axis_resources_flat,
|
||||
in_axis_resources=canonicalized_in_axis_resources_flat,
|
||||
out_axis_resources=canonicalized_out_axis_resources_flat,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=getattr(flat_fun, '__name__', '<unnamed function>'),
|
||||
@ -270,11 +272,13 @@ def pjit(fun: Callable,
|
||||
def lower(*args, **kwargs):
|
||||
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
|
||||
donate_argnums) = infer_params(*args, **kwargs)
|
||||
in_is_global = _calc_is_global_sequence(
|
||||
params['in_positional_semantics'], params['in_axis_resources'])
|
||||
lowering = _pjit_lower(
|
||||
params['jaxpr'], params['in_axis_resources'],
|
||||
params['out_axis_resources'], params['resource_env'],
|
||||
params['donated_invars'], params['name'],
|
||||
params['in_positional_semantics'], params['out_positional_semantics'])
|
||||
in_is_global)
|
||||
|
||||
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
|
||||
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
|
||||
@ -352,12 +356,9 @@ class PytreeLeaf:
|
||||
def __repr__(self): return "pytree leaf"
|
||||
|
||||
|
||||
@lu.cache
|
||||
def _pjit_jaxpr(fun, mesh, local_in_avals,
|
||||
in_tree, in_axis_resources_thunk,
|
||||
out_tree, out_axis_resources_thunk,
|
||||
in_positional_semantics, out_positional_semantics, is_gda):
|
||||
# TODO(yashkatariya): Make this work with FROM_GDA special value.
|
||||
@cache()
|
||||
def _process_in_axis_resources(mesh, local_in_avals, in_axis_resources_thunk,
|
||||
in_tree, in_positional_semantics, is_gda):
|
||||
in_axis_resources_flat = flatten_axis_resources(
|
||||
"pjit in_axis_resources", in_tree,
|
||||
in_axis_resources_thunk(), tupled_args=True)
|
||||
@ -388,7 +389,11 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
|
||||
|
||||
global_in_avals = local_to_global(in_positional_semantics, mesh,
|
||||
local_in_avals, canonicalized_in_axis_resources_flat)
|
||||
return tuple(global_in_avals), canonicalized_in_axis_resources_flat
|
||||
|
||||
|
||||
@lu.cache
|
||||
def _pjit_jaxpr(fun, mesh, global_in_avals, out_tree, out_axis_resources_thunk):
|
||||
prev_positional_val = maps._positional_semantics.val
|
||||
try:
|
||||
maps._positional_semantics.val = maps._PositionalSemantics.GLOBAL
|
||||
@ -407,8 +412,7 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
|
||||
allow_uneven_sharding=False)
|
||||
canonicalized_out_axis_resources_flat = tree_map(_create_cpspec, out_axis_resources_flat)
|
||||
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
|
||||
return _ListWithW([jaxpr, canonicalized_in_axis_resources_flat,
|
||||
canonicalized_out_axis_resources_flat])
|
||||
return _ListWithW([jaxpr, canonicalized_out_axis_resources_flat])
|
||||
|
||||
|
||||
class SpecSync(IntEnum):
|
||||
@ -492,9 +496,6 @@ class ParsedPartitionSpec:
|
||||
f"unsafe_user_spec={self.unsafe_user_spec}, "
|
||||
f"sync={self.sync})")
|
||||
|
||||
REPLICATED = ParsedPartitionSpec(None, ())
|
||||
|
||||
|
||||
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
|
||||
"""ParsedPartitionSpecs that are canonicalized.
|
||||
|
||||
@ -524,6 +525,9 @@ class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
|
||||
f"sync={self.sync})")
|
||||
|
||||
|
||||
REPLICATED = CanonicalizedParsedPartitionSpec(ParsedPartitionSpec(None, ()))
|
||||
|
||||
|
||||
def _prepare_axis_resources(axis_resources,
|
||||
arg_name,
|
||||
allow_unconstrained_dims=False):
|
||||
@ -595,10 +599,10 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, name,
|
||||
in_positional_semantics, out_positional_semantics):
|
||||
in_is_global = _calc_is_global_sequence(in_positional_semantics, in_axis_resources)
|
||||
compiled = _pjit_lower(
|
||||
jaxpr, in_axis_resources, out_axis_resources,
|
||||
resource_env, donated_invars, name, in_positional_semantics,
|
||||
out_positional_semantics).compile()
|
||||
resource_env, donated_invars, name, in_is_global).compile()
|
||||
distributed_debug_log(("Running pjit'd function", name),
|
||||
("mesh", resource_env.physical_mesh))
|
||||
return compiled.unsafe_call(*args)
|
||||
@ -612,7 +616,7 @@ def _pjit_lower(
|
||||
resource_env,
|
||||
donated_invars,
|
||||
name: str,
|
||||
in_positional_semantics, out_positional_semantics):
|
||||
in_is_global: Sequence[bool]):
|
||||
# in_axis_resources and out_axis_resources are canonicalized to avoid
|
||||
# recompilation (since pjit_lower is cached) if it's compiled with `None` but
|
||||
# in the next call `P(None)` is passed. Those are the same thing so should be
|
||||
@ -623,12 +627,10 @@ def _pjit_lower(
|
||||
f = core.jaxpr_as_fun(jaxpr)
|
||||
f.__name__ = name
|
||||
fun = lu.wrap_init(f)
|
||||
in_is_gda = [ips == maps._PositionalSemantics.GLOBAL
|
||||
for ips in in_positional_semantics]
|
||||
return pxla.lower_mesh_computation(
|
||||
fun, 'pjit', name, resource_env.physical_mesh,
|
||||
in_axes, out_axes, donated_invars,
|
||||
True, jaxpr.in_avals, tiling_method=None, in_is_gda=in_is_gda)
|
||||
True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)
|
||||
|
||||
|
||||
def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env,
|
||||
@ -782,8 +784,14 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
out_positional_semantics=out_positional_semantics)
|
||||
|
||||
if num_residuals:
|
||||
executable = _pjit_lower(**known_params).compile(
|
||||
_allow_propagation_to_outputs=True, _allow_compile_replicated=False)
|
||||
in_is_global = _calc_is_global_sequence(
|
||||
known_params['in_positional_semantics'], known_params['in_axis_resources'])
|
||||
executable = _pjit_lower(
|
||||
known_params["jaxpr"], known_params["in_axis_resources"],
|
||||
known_params["out_axis_resources"], known_params["resource_env"],
|
||||
known_params["donated_invars"], known_params["name"],
|
||||
in_is_global).compile(_allow_propagation_to_outputs=True,
|
||||
_allow_compile_replicated=False)
|
||||
output_op_sharding = \
|
||||
executable.xla_executable.hlo_modules()[0].spmd_output_sharding
|
||||
output_sharding_specs = parse_op_sharding(output_op_sharding, mesh)
|
||||
@ -1053,6 +1061,13 @@ def local_to_global(positional_semantics, mesh, avals, axes):
|
||||
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
|
||||
]
|
||||
|
||||
|
||||
def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
  """Return a tuple of per-input flags saying whether each input is global.

  An input is flagged as global when its positional semantics compare equal to
  `maps._PositionalSemantics.GLOBAL`, or when the `partitions` attribute of its
  axis-resource spec is the empty tuple.

  Args:
    in_positional_semantics: one positional-semantics value per input.
    in_axis_resources: one axis-resource spec per input; `partitions` is read
      only for inputs whose semantics are not GLOBAL.

  Returns:
    A tuple with one flag per (semantics, spec) pair, in input order.
  """
  global_semantics = maps._PositionalSemantics.GLOBAL
  flags = []
  for semantics, spec in safe_zip(in_positional_semantics, in_axis_resources):
    # Short-circuit: `spec.partitions` is not touched for GLOBAL inputs.
    flags.append(semantics == global_semantics or spec.partitions == ())
  return tuple(flags)
|
||||
|
||||
|
||||
def _create_cpspec(x):
  """Canonicalize `x` unless `_is_from_gda` says to pass it through.

  Args:
    x: an axis-resource spec, or a value recognized by `_is_from_gda`.

  Returns:
    `x` itself when `_is_from_gda(x)` is truthy; otherwise
    `CanonicalizedParsedPartitionSpec(x)`.
  """
  if _is_from_gda(x):
    return x
  return CanonicalizedParsedPartitionSpec(x)
|
||||
|
||||
|
@ -2140,7 +2140,7 @@ def lower_mesh_computation(
|
||||
spmd_lowering: bool,
|
||||
global_in_avals: Sequence[core.ShapedArray],
|
||||
tiling_method: Optional[TilingMethod],
|
||||
in_is_gda: Sequence[bool]):
|
||||
in_is_global: Sequence[bool]):
|
||||
assert not mesh.empty
|
||||
backend = xb.get_device_backend(mesh.devices.flat[0])
|
||||
name_stack = new_name_stack(wrap_name(fun_name, api_name))
|
||||
@ -2236,7 +2236,7 @@ def lower_mesh_computation(
|
||||
return MeshComputation(
|
||||
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
|
||||
global_out_avals=global_out_avals, in_axes=in_axes, out_axes=out_axes,
|
||||
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_gda=in_is_gda)
|
||||
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_global=in_is_global)
|
||||
|
||||
|
||||
class MeshComputation:
|
||||
@ -2277,13 +2277,13 @@ class MeshComputation:
|
||||
return self._executable
|
||||
|
||||
|
||||
def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
|
||||
def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_global):
|
||||
input_specs, input_indices, input_avals = [], [], []
|
||||
num_local_devices = len(global_mesh.local_devices)
|
||||
for gaval, axis, is_gda in safe_zip(global_in_avals, in_axes, in_is_gda):
|
||||
for gaval, axis, is_global in safe_zip(global_in_avals, in_axes, in_is_global):
|
||||
# TODO(yashkatariya): Don't calculate input_indices and input_specs for GDA
|
||||
# as GDA doesn't need it.
|
||||
if is_gda or not axis:
|
||||
if is_global:
|
||||
aval = gaval
|
||||
mesh = global_mesh
|
||||
else:
|
||||
@ -2292,9 +2292,11 @@ def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
|
||||
|
||||
spec = (mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, axis)
|
||||
if aval is not core.abstract_unit else None)
|
||||
# We special case this logic to support fully replicated non-GDA values
|
||||
# with non-contiguous submeshes
|
||||
if not axis and not is_gda:
|
||||
# We special case this logic to support fully replicated values because
|
||||
# the mesh is the global mesh and the indices returned by `spec_to_indices` will
|
||||
# represent index for each device in the global mesh. But here we want
|
||||
# indices for the local devices of the global mesh.
|
||||
if not axis:
|
||||
index = tuple((slice(None),) * aval.ndim for _ in range(num_local_devices))
|
||||
else:
|
||||
index = spec_to_indices(aval.shape, spec) if spec is not None else None
|
||||
@ -2323,7 +2325,7 @@ class MeshExecutable:
|
||||
in_axes: Sequence[ArrayMapping],
|
||||
out_axes: Sequence[ArrayMapping],
|
||||
spmd_lowering: bool, tuple_args: bool,
|
||||
in_is_gda: Sequence[bool],
|
||||
in_is_global: Sequence[bool],
|
||||
_allow_propagation_to_outputs: bool,
|
||||
_allow_compile_replicated: bool) -> 'MeshExecutable':
|
||||
assert not mesh.empty
|
||||
@ -2345,7 +2347,7 @@ class MeshExecutable:
|
||||
_allow_propagation_to_outputs
|
||||
|
||||
input_specs, input_indices, input_avals = _get_input_metadata(
|
||||
global_in_avals, mesh, in_axes, in_is_gda)
|
||||
global_in_avals, mesh, in_axes, in_is_global)
|
||||
# Calculate local information here instead of calculating it in
|
||||
# `avals_to_results_handler` because pmap also uses this function.
|
||||
handle_outs = global_avals_to_results_handler(global_out_avals, out_axes, mesh)
|
||||
|
@ -1078,20 +1078,32 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
input_shape, global_mesh, mesh_axes, cb)
|
||||
|
||||
trace_counter = [0]
|
||||
@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
|
||||
def f(x, y):
|
||||
trace_counter[0] += 1
|
||||
return x @ y.T
|
||||
|
||||
before_lower_cache = pjit_lib._pjit_lower.cache_info()
|
||||
|
||||
f(gda_obj, gda_obj)
|
||||
self.assertListEqual(trace_counter, [1])
|
||||
after_lower_cache1 = pjit_lib._pjit_lower.cache_info()
|
||||
self.assertEqual(before_lower_cache.hits, after_lower_cache1.hits)
|
||||
self.assertEqual(before_lower_cache.misses + 1, after_lower_cache1.misses)
|
||||
|
||||
f(gda_obj, gda_obj)
|
||||
self.assertListEqual(trace_counter, [1])
|
||||
after_lower_cache2 = pjit_lib._pjit_lower.cache_info()
|
||||
self.assertEqual(after_lower_cache1.hits + 1, after_lower_cache2.hits)
|
||||
self.assertEqual(after_lower_cache1.misses, after_lower_cache2.misses)
|
||||
|
||||
f(input_data, input_data)
|
||||
self.assertListEqual(trace_counter, [2])
|
||||
after_lower_cache3 = pjit_lib._pjit_lower.cache_info()
|
||||
self.assertEqual(after_lower_cache2.hits, after_lower_cache3.hits)
|
||||
self.assertEqual(after_lower_cache2.misses + 1, after_lower_cache3.misses)
|
||||
|
||||
f(gda_obj, input_data)
|
||||
self.assertListEqual(trace_counter, [3])
|
||||
after_lower_cache4 = pjit_lib._pjit_lower.cache_info()
|
||||
self.assertEqual(after_lower_cache3.hits, after_lower_cache4.hits)
|
||||
self.assertEqual(after_lower_cache3.misses + 1, after_lower_cache4.misses)
|
||||
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_partition_spec_mismatch_semantically_equivalent(self):
|
||||
@ -1143,7 +1155,7 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
|
||||
with global_mesh:
|
||||
out_gda = f(input_gda)
|
||||
self.assertEqual(out_gda.mesh_axes, ())
|
||||
|
||||
@ -1151,7 +1163,28 @@ class GDAPjitTest(jtu.JaxTestCase):
|
||||
f(out_gda)
|
||||
after_cache = pjit_lib._pjit_lower.cache_info()
|
||||
|
||||
self.assertNotEqual(id(before_cache), id(after_cache))
|
||||
self.assertEqual(before_cache.hits + 1, after_cache.hits)
|
||||
self.assertEqual(before_cache.misses, after_cache.misses)
|
||||
|
||||
def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
|
||||
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P(None)
|
||||
global_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
|
||||
with jax._src.config.parallel_functions_output_gda(True):
|
||||
f = pjit(lambda x: x, in_axis_resources=mesh_axes,
|
||||
out_axis_resources=mesh_axes)
|
||||
|
||||
with global_mesh:
|
||||
out_gda = f(global_data)
|
||||
self.assertEqual(out_gda.mesh_axes, ())
|
||||
|
||||
before_cache = pjit_lib._pjit_lower.cache_info()
|
||||
f(out_gda)
|
||||
after_cache = pjit_lib._pjit_lower.cache_info()
|
||||
|
||||
self.assertEqual(before_cache.hits + 1, after_cache.hits)
|
||||
self.assertEqual(before_cache.misses, after_cache.misses)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user