Create a helper function for creating compile options

PiperOrigin-RevId: 638374966
This commit is contained in:
Yash Katariya 2024-05-29 12:30:10 -07:00 committed by jax authors
parent fafe740400
commit beaabae8f6

View File

@ -2208,6 +2208,7 @@ def lower_sharding_computation(
str(name_stack),
module,
donated_invars,
platforms,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
in_shardings=in_shardings,
@ -2222,7 +2223,6 @@ def lower_sharding_computation(
kept_var_idx=kept_var_idx,
mut=mut,
backend=backend,
platforms=platforms,
device_assignment=da_object,
committed=committed,
in_layouts=in_layouts,
@ -2390,6 +2390,7 @@ def lower_mesh_computation(
str(name_stack),
lowering_result.module,
donated_invars,
platforms,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
in_shardings=in_shardings,
@ -2403,7 +2404,6 @@ def lower_mesh_computation(
keepalive=lowering_result.keepalive,
kept_var_idx=set(range(len(global_in_avals))),
backend=backend,
platforms=platforms,
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
committed=True,
in_layouts=(None,) * len(global_in_avals),
@ -2417,10 +2417,7 @@ class MeshComputation(stages.XlaLowering):
_executable: MeshExecutable | None
def __init__(self, name: str, hlo: ir.Module,
donated_invars: Sequence[bool],
# TODO(necula): fix this when internal clients stop using this
# constructor directly.
platforms: Sequence[str] | None = None,
donated_invars: Sequence[bool], platforms: Sequence[str],
**compile_args):
self._name = name
self._hlo = hlo
@ -2690,39 +2687,23 @@ def get_logical_mesh_ids(mesh_shape):
return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
@weakref_lru_cache
def _cached_compilation(computation, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, host_callbacks, backend,
da, pmap_nreps, compiler_options_keys,
compiler_options_values,
pgle_profiler):
# TODO(phawkins): One would normally just write:
# dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
# If we were to optimize __getattr__ on xc.Device we might not need this
# workaround.
dev = np.vectorize(lambda i: da[i], otypes=[object])(
np.arange(len(da))
)
def create_compile_options(
computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering,
allow_prop_to_inputs, allow_prop_to_outputs, backend,
np_dev, pmap_nreps, compiler_options):
if pmap_nreps > 1:
num_replicas, num_partitions = pmap_nreps, 1
elif spmd_lowering:
num_replicas, num_partitions = 1, dev.size
num_replicas, num_partitions = 1, np_dev.size
else:
num_replicas, num_partitions = dev.size, 1
num_replicas, num_partitions = np_dev.size, 1
if pmap_nreps > 1:
# In `jit` device_assignment is set to None when num_replicas > 1. Do
# the same thing here too.
xla_device_assignment = None
else:
xla_device_assignment = dev.reshape((num_replicas, num_partitions))
if compiler_options_keys is None:
compiler_options = None
else:
compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
xla_device_assignment = np_dev.reshape((num_replicas, num_partitions))
fdo_profile = (None if compiler_options is None else
compiler_options.pop("fdo_profile", None))
@ -2749,6 +2730,29 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
compile_options.parameter_is_tupled_arguments = tuple_args
opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
return compile_options
@weakref_lru_cache
def _cached_compilation(computation, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, host_callbacks, backend,
da, pmap_nreps, compiler_options_keys,
compiler_options_values,
pgle_profiler):
# One would normally just write: dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da)))
if compiler_options_keys is None:
compiler_options = None
else:
compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
compile_options = create_compile_options(
computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering,
allow_prop_to_inputs, allow_prop_to_outputs, backend,
dev, pmap_nreps, compiler_options)
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",