Add compiler_options argument to jax.jit.

This exists on the `Compiled` object via AOT too, i.e. `jit(f).lower(*args).compile(compiler_options={})`.

PiperOrigin-RevId: 692283964
commit fff33f90b2, parent 07858fa98d
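Example usage (a minimal sketch, assuming the API added by this commit; the flag names are taken from the tests in the diff below, and which options are accepted depends on the installed jaxlib/XLA build):

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sqrt(x ** 2) + 1.

# Per-function compiler options passed at jit time; they are forwarded to XLA
# when the function is compiled.
f_jit = jax.jit(f, compiler_options={"xla_embed_ir_in_executable": True})
f_jit(1.0)

# The same knob on the AOT path: options passed to compile() are appended to
# any options that were already set on jit itself.
lowered = jax.jit(f).lower(1.0)
compiled = lowered.compile(
    compiler_options={"xla_embed_ir_in_executable": True})

Per the cache-miss test below, jitting the same function with different compiler_options values produces separate compilations.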
@@ -151,6 +151,7 @@ def jit(
     backend: str | None = None,
     inline: bool = False,
     abstracted_axes: Any | None = None,
+    compiler_options: dict[str, Any] | None = None,
 ) -> pjit.JitWrapped:
   """Sets up ``fun`` for just-in-time compilation with XLA.

@@ -280,7 +281,7 @@ def jit(
   return pjit.make_jit(
       fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
       static_argnums, static_argnames, device, backend, abstracted_axes,
-      keep_unused, inline, use_resource_env=False)
+      keep_unused, inline, compiler_options, use_resource_env=False)


 @contextmanager
@@ -898,7 +898,8 @@ error_checks[lax.while_p] = while_loop_error_check
 def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                      in_shardings, out_shardings,
                      in_layouts, out_layouts,
-                     resource_env, donated_invars, name, inline, keep_unused):
+                     resource_env, donated_invars, name, inline, keep_unused,
+                     compiler_options_kvs):
   # jaxpr to checked_jaxpr
   err_vals, err_tree = jtu.tree_flatten(error)
   new_vals_in = [*err_vals, *vals_in]
@@ -929,6 +930,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
       name=name,
       inline=inline,
       keep_unused=keep_unused,
+      compiler_options_kvs=compiler_options_kvs,
   )
   return tree_unflatten(out_tree, err_and_out)
 error_checks[pjit.pjit_p] = pjit_error_check
@@ -2121,6 +2121,7 @@ def lower_sharding_computation(
     *,
     keep_unused: bool,
     context_mesh: mesh_lib.Mesh | None,
+    compiler_options_kvs: tuple[tuple[str, Any], ...],
     lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
     pgle_profiler: profiler.PGLEProfiler | None,
@@ -2247,6 +2248,7 @@ def lower_sharding_computation(
       module,
       donated_invars,
       platforms,
+      compiler_options_kvs,
       global_in_avals=global_in_avals,
       global_out_avals=global_out_avals,
       in_shardings=in_shardings,
@@ -2298,11 +2300,13 @@ class MeshComputation(stages.XlaLowering):

   def __init__(self, name: str, hlo: ir.Module,
                donated_invars: Sequence[bool], platforms: Sequence[str],
+               compiler_options_kvs: tuple[tuple[str, Any], ...],
                **compile_args):
     self._name = name
     self._hlo = hlo
     self._donated_invars = donated_invars
     self._platforms = platforms
+    self._compiler_options_kvs = compiler_options_kvs
     self.compile_args = compile_args
     self._executable = None

@@ -2312,11 +2316,14 @@ class MeshComputation(stages.XlaLowering):
     return self._hlo

   def compile(self, compiler_options=None) -> MeshExecutable:
-    if self._executable is None or compiler_options is not None:
+    t_compiler_options = (() if compiler_options is None else
+                          tuple(compiler_options.items()))
+    compiler_options_kvs = self._compiler_options_kvs + t_compiler_options
+    if self._executable is None or compiler_options_kvs:
       executable = UnloadedMeshExecutable.from_hlo(
           self._name, self._hlo, **self.compile_args,
-          compiler_options=compiler_options)
-      if compiler_options is None:
+          compiler_options_kvs=compiler_options_kvs)
+      if not compiler_options_kvs:
         self._executable = executable
       return executable
     return self._executable
@@ -2581,8 +2588,7 @@ def create_compile_options(
   else:
     xla_device_assignment = np_dev.reshape((num_replicas, num_partitions))

-  fdo_profile = (None if compiler_options is None else
-                 compiler_options.pop("fdo_profile", None))
+  fdo_profile = compiler_options.pop("fdo_profile", None)

   compile_options = compiler.get_compile_options(
       num_replicas=num_replicas,
@@ -2614,17 +2620,11 @@ def create_compile_options(
 def _cached_compilation(computation, name, mesh, spmd_lowering,
                         tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
                         allow_prop_to_outputs, host_callbacks, backend,
-                        da, pmap_nreps, compiler_options_keys,
-                        compiler_options_values,
-                        pgle_profiler):
+                        da, pmap_nreps, compiler_options_kvs, pgle_profiler):
   # One would normally just write: dev = np.array(device_assignment)
   # The formulation below is substantially faster if there are many devices.
   dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da)))
-
-  if compiler_options_keys is None:
-    compiler_options = None
-  else:
-    compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
+  compiler_options = dict(compiler_options_kvs)

   compile_options = create_compile_options(
       computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering,
@@ -2788,22 +2788,18 @@ class UnloadedMeshExecutable:
       committed: bool,
       in_layouts: MaybeLayout,
       out_layouts: MaybeLayout,
+      compiler_options_kvs: tuple[tuple[str, Any], ...],
       pmap_nreps: int = 1,
       mut: MutationData | None = None,
       shape_poly_state: mlir.ShapePolyLoweringState | None = None,
       all_default_mem_kind: bool = True,
       all_args_info: AllArgsInfo | None = None,
-      compiler_options=None,
       pgle_profiler: profiler.PGLEProfiler | None = None,
       intermediate_shardings: Sequence[JSharding] | None = None,
       context_mesh: mesh_lib.Mesh | None = None
   ) -> MeshExecutable:
     if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
       hlo = mlir.refine_polymorphic_shapes(hlo)
-    compiler_options_keys = tuple(
-        compiler_options.keys()) if compiler_options is not None else None
-    compiler_options_values = tuple(
-        compiler_options.values()) if compiler_options is not None else None
     if isinstance(device_assignment, xc.DeviceList):
       da = device_assignment
     else:
@@ -2826,7 +2822,7 @@ class UnloadedMeshExecutable:
         hlo, name, mesh, spmd_lowering,
         tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
         allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
-        compiler_options_keys, compiler_options_values, pgle_profiler)
+        compiler_options_kvs, pgle_profiler)

     if auto_spmd_lowering:
       assert mesh is not None
@@ -2918,6 +2914,7 @@ class JitGlobalCppCacheKeys:
   out_layouts_treedef: PyTreeDef | None = None
   out_layouts_leaves: tuple[Any, ...] | None = None
   use_resource_env: bool = False
+  compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None

   @functools.cached_property
   def contains_explicit_attributes(self):
@@ -2928,7 +2925,8 @@ class JitGlobalCppCacheKeys:
             any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or
             any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or
             any(i is not None for i in self.in_layouts_leaves) or
-            any(o is not None for o in self.out_layouts_leaves))
+            any(o is not None for o in self.out_layouts_leaves) or
+            self.compiler_options_kvs)


 def reflatten_outputs_for_dispatch(out_tree, out_flat):
@@ -164,6 +164,7 @@ class PjitInfo(NamedTuple):
   inline: bool
   abstracted_axes: Any | None
   use_resource_env: bool  # False for jit, True for pjit
+  compiler_options_kvs: tuple[tuple[str, Any], ...]

   # Hash and compare PjitInfo by identity when used as a cache key.
   def __hash__(self):
@@ -357,7 +358,8 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
       in_layouts_leaves=jit_info.in_layouts_leaves,
       out_layouts_treedef=jit_info.out_layouts_treedef,
       out_layouts_leaves=jit_info.out_layouts_leaves,
-      use_resource_env=jit_info.use_resource_env)
+      use_resource_env=jit_info.use_resource_env,
+      compiler_options_kvs=jit_info.compiler_options_kvs)
   cpp_pjit_f = xc._xla.pjit(
       fun_name(fun), fun, cache_miss, jit_info.static_argnums,
       jit_info.static_argnames, cache_key, tree_util.dispatch_registry,
@@ -398,7 +400,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
                         static_argnames: str | Iterable[str] | None,
                         device: xc.Device | None, backend: str | None,
                         abstracted_axes: Any | None, keep_unused: bool,
-                        inline: bool, use_resource_env: bool) -> PjitInfo:
+                        inline: bool, compiler_options: dict[str, Any] | None,
+                        use_resource_env: bool) -> PjitInfo:
   """Parses the arguments to jit/pjit.

   Performs any preprocessing and validation of the arguments that we can do
@@ -453,6 +456,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
       fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
       static_argnames)

+  compiler_options_kvs = (() if compiler_options is None else
+                          tuple(compiler_options.items()))
   return PjitInfo(
         fun_sourceinfo=fun_sourceinfo,
         fun_signature=fun_signature,
@@ -470,7 +475,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
         donate_argnames=donate_argnames, device=device, backend=backend,
         keep_unused=keep_unused, inline=inline,
         abstracted_axes=abstracted_axes,
-        use_resource_env=use_resource_env)
+        use_resource_env=use_resource_env,
+        compiler_options_kvs=compiler_options_kvs)


 def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
@@ -514,12 +520,13 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
             static_argnames: str | Iterable[str] | None,
             device: xc.Device | None, backend: str | None,
             abstracted_axes: Any | None, keep_unused: bool,
-             inline: bool, use_resource_env: bool) -> Any:
+             inline: bool, compiler_options: dict[str, Any] | None,
+             use_resource_env: bool) -> Any:
   """jit() and pjit() are thin wrappers around this function."""
   jit_info = _parse_jit_arguments(
       fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
       static_argnums, static_argnames, device, backend, abstracted_axes,
-      keep_unused, inline, use_resource_env)
+      keep_unused, inline, compiler_options, use_resource_env)
   return _make_jit_wrapper(fun, jit_info)


@@ -676,6 +683,7 @@ def _infer_params_impl(
       name=fun_qual_name(flat_fun),
       keep_unused=ji.keep_unused,
       inline=ji.inline,
+      compiler_options_kvs=ji.compiler_options_kvs,
   )
   return PjitParams(consts, params, in_avals, in_tree, out_tree(),
                     donated_invars, dbg.arg_names if dbg else None, len(consts),
@@ -815,6 +823,7 @@ def pjit(
     backend: str | None = None,
     inline: bool = False,
     abstracted_axes: Any | None = None,
+    compiler_options: dict[str, Any] | None = None,
 ) -> JitWrapped:
   """Makes ``fun`` compiled and automatically partitioned across multiple devices.

@@ -987,7 +996,7 @@ def pjit(
   return make_jit(
       fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
       static_argnums, static_argnames, device, backend, abstracted_axes,
-      keep_unused, inline, use_resource_env=True)
+      keep_unused, inline, compiler_options, use_resource_env=True)


 def hashable_pytree(pytree):
@@ -1594,25 +1603,25 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
 def _resolve_and_lower(
     args, jaxpr, in_shardings, out_shardings, in_layouts,
     out_layouts, resource_env, donated_invars, name, keep_unused, inline,
-    lowering_platforms, lowering_parameters, pgle_profiler):
+    lowering_platforms, lowering_parameters, pgle_profiler,
+    compiler_options_kvs):
   in_shardings = _resolve_in_shardings(args, in_shardings)
   in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
                                    jaxpr.in_avals)
-  lowered = _pjit_lower(
+  return _pjit_lower(
       jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
-      donated_invars, name, keep_unused, inline,
+      donated_invars, name, keep_unused, inline, compiler_options_kvs,
      lowering_platforms=lowering_platforms,
      lowering_parameters=lowering_parameters,
      pgle_profiler=pgle_profiler)
-  return lowered

 def _pjit_call_impl_python(
     *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
-    resource_env, donated_invars, name, keep_unused, inline):
+    resource_env, donated_invars, name, keep_unused, inline,
+    compiler_options_kvs):
   global _most_recent_pjit_call_executable

-  compile_options = None
-  pgle_profiler = None
+  pgle_compile_options, pgle_profiler = {}, None
   pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict
   if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
     if jaxpr not in pgle_profiler_dict:
@@ -1626,8 +1635,9 @@ def _pjit_call_impl_python(
     # be None.
     fdo_profile = pgle_profiler.consume_fdo_profile()
     if fdo_profile is not None:
-      compile_options = {'fdo_profile': fdo_profile}
+      pgle_compile_options['fdo_profile'] = fdo_profile

+  compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
   # TODO(patrios): Do not pass mutable profile session through cached lowering
   # chain. Instead we need to move profilers dictionary to pxla module and use
   # module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode.
@@ -1638,8 +1648,9 @@ def _pjit_call_impl_python(
       donated_invars=donated_invars, name=name, keep_unused=keep_unused,
       inline=inline, lowering_platforms=None,
       lowering_parameters=mlir.LoweringParameters(),
-      pgle_profiler=pgle_profiler
-  ).compile(compile_options)
+      pgle_profiler=pgle_profiler,
+      compiler_options_kvs=compiler_options_kvs,
+  ).compile()

   _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
   # This check is expensive so only do it if enable_checks is on.
@@ -1693,7 +1704,7 @@ def _pjit_call_impl_python(
 @weakref_lru_cache
 def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
                       out_layouts, resource_env, donated_invars, name,
-                      keep_unused, inline):
+                      keep_unused, inline, compiler_options_kvs):
   # The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so
   # returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to
   # the jaxpr defeating the purpose of weakref_lru_cache. So return a function
@@ -1706,15 +1717,15 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,

 def _pjit_call_impl(*args, jaxpr,
                     in_shardings, out_shardings, in_layouts, out_layouts,
-                    resource_env,
-                    donated_invars, name, keep_unused, inline):
+                    resource_env, donated_invars, name, keep_unused, inline,
+                    compiler_options_kvs):
   def call_impl_cache_miss(*args_, **kwargs_):
     out_flat, compiled = _pjit_call_impl_python(
         *args, jaxpr=jaxpr, in_shardings=in_shardings,
         out_shardings=out_shardings, in_layouts=in_layouts,
         out_layouts=out_layouts, resource_env=resource_env,
         donated_invars=donated_invars, name=name, keep_unused=keep_unused,
-        inline=inline)
+        inline=inline, compiler_options_kvs=compiler_options_kvs)
     pgle_profiler = _read_pgle_profiler(jaxpr)
     fastpath_data = _get_fastpath_data(
         compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
@@ -1723,7 +1734,8 @@ def _pjit_call_impl(*args, jaxpr,

   f = _get_jaxpr_as_fun(
       jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
-      resource_env, donated_invars, name, keep_unused, inline)
+      resource_env, donated_invars, name, keep_unused, inline,
+      compiler_options_kvs)
   donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
   cache_key = pxla.JitGlobalCppCacheKeys(
       donate_argnums=donated_argnums, donate_argnames=None,
@@ -1757,6 +1769,7 @@ def _pjit_lower_cached(
     name: str,
     keep_unused: bool,
     inline: bool,
+    compiler_options_kvs: tuple[tuple[str, Any], ...],
     *,
     lowering_platforms: tuple[str, ...] | None,
     lowering_parameters: mlir.LoweringParameters,
@@ -1767,6 +1780,7 @@ def _pjit_lower_cached(
       jaxpr, api_name, name, in_shardings, out_shardings,
       in_layouts, out_layouts, tuple(donated_invars),
       keep_unused=keep_unused, context_mesh=mesh,
+      compiler_options_kvs=compiler_options_kvs,
      lowering_platforms=lowering_platforms,
      lowering_parameters=lowering_parameters,
      pgle_profiler=pgle_profiler)
@@ -1911,7 +1925,7 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,

 def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, in_layouts, out_layouts, resource_env,
-                   donated_invars, keep_unused, inline):
+                   donated_invars, keep_unused, inline, compiler_options_kvs):
   effects = list(ctx.tokens_in.effects())
   output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
   output_types = [mlir.token_type()] * len(effects) + output_types
@@ -1939,7 +1953,8 @@ mlir.register_lowering(pjit_p, _pjit_lowering)

 def _pjit_batcher(axis_data, vals_in, dims_in,
                   jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
-                  resource_env, donated_invars, name, keep_unused, inline):
+                  resource_env, donated_invars, name, keep_unused, inline,
+                  compiler_options_kvs):
   segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
   new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)

@@ -1974,7 +1989,8 @@ def _pjit_batcher(axis_data, vals_in, dims_in,
       donated_invars=donated_invars,
       name=name,
       keep_unused=keep_unused,
-      inline=inline)
+      inline=inline,
+      compiler_options_kvs=compiler_options_kvs)

   resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
       vals_in, vals_out, axes_out)
@@ -2024,7 +2040,8 @@ def _pjit_batcher_for_sharding(

 def _pjit_jvp(primals_in, tangents_in,
               jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
-              resource_env, donated_invars, name, keep_unused, inline):
+              resource_env, donated_invars, name, keep_unused, inline,
+              compiler_options_kvs):
   if any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
     jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr)
     mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals)
@@ -2056,7 +2073,8 @@ def _pjit_jvp(primals_in, tangents_in,
       donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
       name=name,
       keep_unused=keep_unused,
-      inline=inline)
+      inline=inline,
+      compiler_options_kvs=compiler_options_kvs)

   primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
   assert len(primals_out) == len(jaxpr.jaxpr.outvars)
@@ -2069,7 +2087,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
 def _pjit_partial_eval(trace, *in_tracers,
                        jaxpr, in_shardings, out_shardings,
                        in_layouts, out_layouts, resource_env, donated_invars,
-                       name, keep_unused, inline):
+                       name, keep_unused, inline, compiler_options_kvs):
   in_pvals = [t.pval for t in in_tracers]

   known_ins = tuple(pv.is_known() for pv in in_pvals)
@@ -2127,7 +2145,8 @@ def _pjit_partial_eval(trace, *in_tracers,
       in_layouts=keep_where(in_layouts, known_ins),
       out_layouts=known_out_layouts, resource_env=resource_env,
       donated_invars=keep_where(donated_invars, known_ins),
-      name=name, keep_unused=keep_unused, inline=inline)
+      name=name, keep_unused=keep_unused, inline=inline,
+      compiler_options_kvs=compiler_options_kvs)
   assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
   assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals)

@@ -2161,7 +2180,8 @@ def _pjit_partial_eval(trace, *in_tracers,
                       (False,) * num_residuals),
       name=name,
       keep_unused=keep_unused,
-      inline=inline)
+      inline=inline,
+      compiler_options_kvs=compiler_options_kvs)
   unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
   unknown_out_avals = unknown_jaxpr.out_avals
   unknown_tracers_out = [
@@ -2241,7 +2261,8 @@ def _pjit_transpose_trace(fun, in_avals):

 def _pjit_transpose(cts_in, *primals_in,
                     jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
-                    resource_env, donated_invars, name, keep_unused, inline):
+                    resource_env, donated_invars, name, keep_unused, inline,
+                    compiler_options_kvs):
   def prune_type(ty, xs, maybe_zeros):
     return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)

@@ -2292,7 +2313,8 @@ def _pjit_transpose(cts_in, *primals_in,
       donated_invars=(False,) * len(primals_and_nz_cts_in),
       name=name,
       keep_unused=keep_unused,
-      inline=inline)
+      inline=inline,
+      compiler_options_kvs=compiler_options_kvs)

   if attrs_tracked:
     final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
@@ -2358,6 +2380,8 @@ def _pjit_pp_rule(eqn, context, settings):
   if (params['resource_env'] is None or
       params['resource_env'].physical_mesh.empty):
     del params['resource_env']
+  if not params['compiler_options_kvs']:
+    del params['compiler_options_kvs']

   # Move name= to the front to make the resulting equation easier to scan.
   del params["name"]
@@ -3581,6 +3581,7 @@ def _pjit(*args: TfVal,
           name: str,
           keep_unused: bool,
           inline: bool,
+          compiler_options_kvs,
           _in_avals: Sequence[core.ShapedArray],
           _out_aval: Sequence[core.ShapedArray]) -> TfVal:
   del donated_invars
@@ -762,7 +762,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse

 def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
                  in_layouts, out_layouts, resource_env, donated_invars, name,
-                 keep_unused, inline):
+                 keep_unused, inline, compiler_options_kvs):
   if any(donated_invars):
     raise NotImplementedError("sparse xla_call with donated_invars")

@@ -798,7 +798,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
       donated_invars=donated_invars,
       name=name,
       keep_unused=keep_unused,
-      inline=inline)
+      inline=inline,
+      compiler_options_kvs=compiler_options_kvs)
   return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))

 sparse_rules_bcoo[pjit.pjit_p] = _pjit_sparse
@@ -1373,6 +1373,18 @@ class JitTest(jtu.BufferDonationTestCase):
         }
     )

+  def test_compile_options_jit(self):
+    def f(x):
+      return jnp.sqrt(x ** 2) + 1.
+
+    f_jit = jit(
+        f,
+        compiler_options={
+            "xla_embed_ir_in_executable": True,
+            "xla_dump_max_hlo_modules": 200,
+            "xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5,
+        })(1.0)  # doesn't crash.
+
   def test_jit_lower_compile_with_compiler_options_invalid(self):
     def f(x):
       return jnp.sqrt(x ** 2) + 1.
@@ -1390,7 +1402,21 @@ class JitTest(jtu.BufferDonationTestCase):
         lambda: lowered.compile(
             compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))

-  def test_jit_lower_compile_with_compiler_options_multiple(self):
+  def test_jit_compile_with_compiler_options_multiple(self):
+    def f(x):
+      return jnp.sqrt(x ** 2) + 1.
+
+    with jtu.count_jit_compilation_cache_miss() as count:
+      jit(f, compiler_options={"xla_embed_ir_in_executable": True})(1.)
+      jit(f, compiler_options={"xla_embed_ir_in_executable": False})(1.)
+    self.assertEqual(count[0], 2)
+
+    # We should still error on invalid options after some valid compiles
+    with self.assertRaisesRegex(
+        xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'"):
+      jit(f, compiler_options={"invalid_key": "invalid_value"})(1.)
+
+  def test_lower_compile_with_compiler_options_multiple(self):
     def f(x):
       return jnp.sqrt(x ** 2) + 1.
