mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Create a helper function for creating compile options
PiperOrigin-RevId: 638374966
This commit is contained in:
parent
fafe740400
commit
beaabae8f6
@ -2208,6 +2208,7 @@ def lower_sharding_computation(
|
||||
str(name_stack),
|
||||
module,
|
||||
donated_invars,
|
||||
platforms,
|
||||
global_in_avals=global_in_avals,
|
||||
global_out_avals=global_out_avals,
|
||||
in_shardings=in_shardings,
|
||||
@ -2222,7 +2223,6 @@ def lower_sharding_computation(
|
||||
kept_var_idx=kept_var_idx,
|
||||
mut=mut,
|
||||
backend=backend,
|
||||
platforms=platforms,
|
||||
device_assignment=da_object,
|
||||
committed=committed,
|
||||
in_layouts=in_layouts,
|
||||
@ -2390,6 +2390,7 @@ def lower_mesh_computation(
|
||||
str(name_stack),
|
||||
lowering_result.module,
|
||||
donated_invars,
|
||||
platforms,
|
||||
global_in_avals=global_in_avals,
|
||||
global_out_avals=global_out_avals,
|
||||
in_shardings=in_shardings,
|
||||
@ -2403,7 +2404,6 @@ def lower_mesh_computation(
|
||||
keepalive=lowering_result.keepalive,
|
||||
kept_var_idx=set(range(len(global_in_avals))),
|
||||
backend=backend,
|
||||
platforms=platforms,
|
||||
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
|
||||
committed=True,
|
||||
in_layouts=(None,) * len(global_in_avals),
|
||||
@ -2417,10 +2417,7 @@ class MeshComputation(stages.XlaLowering):
|
||||
_executable: MeshExecutable | None
|
||||
|
||||
def __init__(self, name: str, hlo: ir.Module,
|
||||
donated_invars: Sequence[bool],
|
||||
# TODO(necula): fix this when internal clients stop using this
|
||||
# constructor directly.
|
||||
platforms: Sequence[str] | None = None,
|
||||
donated_invars: Sequence[bool], platforms: Sequence[str],
|
||||
**compile_args):
|
||||
self._name = name
|
||||
self._hlo = hlo
|
||||
@ -2690,39 +2687,23 @@ def get_logical_mesh_ids(mesh_shape):
|
||||
return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
|
||||
allow_prop_to_outputs, host_callbacks, backend,
|
||||
da, pmap_nreps, compiler_options_keys,
|
||||
compiler_options_values,
|
||||
pgle_profiler):
|
||||
# TODO(phawkins): One would normally just write:
|
||||
# dev = np.array(device_assignment)
|
||||
# The formulation below is substantially faster if there are many devices.
|
||||
# If we were to optimize __getattr__ on xc.Device we might not need this
|
||||
# workaround.
|
||||
dev = np.vectorize(lambda i: da[i], otypes=[object])(
|
||||
np.arange(len(da))
|
||||
)
|
||||
def create_compile_options(
|
||||
computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering,
|
||||
allow_prop_to_inputs, allow_prop_to_outputs, backend,
|
||||
np_dev, pmap_nreps, compiler_options):
|
||||
if pmap_nreps > 1:
|
||||
num_replicas, num_partitions = pmap_nreps, 1
|
||||
elif spmd_lowering:
|
||||
num_replicas, num_partitions = 1, dev.size
|
||||
num_replicas, num_partitions = 1, np_dev.size
|
||||
else:
|
||||
num_replicas, num_partitions = dev.size, 1
|
||||
num_replicas, num_partitions = np_dev.size, 1
|
||||
|
||||
if pmap_nreps > 1:
|
||||
# In `jit` device_assignment is set to None when num_replicas > 1. Do
|
||||
# the same thing here too.
|
||||
xla_device_assignment = None
|
||||
else:
|
||||
xla_device_assignment = dev.reshape((num_replicas, num_partitions))
|
||||
|
||||
if compiler_options_keys is None:
|
||||
compiler_options = None
|
||||
else:
|
||||
compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
|
||||
xla_device_assignment = np_dev.reshape((num_replicas, num_partitions))
|
||||
|
||||
fdo_profile = (None if compiler_options is None else
|
||||
compiler_options.pop("fdo_profile", None))
|
||||
@ -2749,6 +2730,29 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
compile_options.parameter_is_tupled_arguments = tuple_args
|
||||
opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
|
||||
opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
|
||||
return compile_options
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
|
||||
allow_prop_to_outputs, host_callbacks, backend,
|
||||
da, pmap_nreps, compiler_options_keys,
|
||||
compiler_options_values,
|
||||
pgle_profiler):
|
||||
# One would normally just write: dev = np.array(device_assignment)
|
||||
# The formulation below is substantially faster if there are many devices.
|
||||
dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da)))
|
||||
|
||||
if compiler_options_keys is None:
|
||||
compiler_options = None
|
||||
else:
|
||||
compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
|
||||
|
||||
compile_options = create_compile_options(
|
||||
computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering,
|
||||
allow_prop_to_inputs, allow_prop_to_outputs, backend,
|
||||
dev, pmap_nreps, compiler_options)
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
|
Loading…
x
Reference in New Issue
Block a user