Add compiler_options argument to jax.jit.

The same options can already be passed through the AOT API when compiling a lowered function, i.e. `jit(f).lower(*args).compile(compiler_options={})`.

PiperOrigin-RevId: 692283964
This commit is contained in:
Yash Katariya 2024-11-01 14:00:10 -07:00 committed by jax authors
parent 07858fa98d
commit fff33f90b2
7 changed files with 109 additions and 56 deletions

View File

@ -151,6 +151,7 @@ def jit(
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
compiler_options: dict[str, Any] | None = None,
) -> pjit.JitWrapped:
"""Sets up ``fun`` for just-in-time compilation with XLA.
@ -280,7 +281,7 @@ def jit(
return pjit.make_jit(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes,
keep_unused, inline, use_resource_env=False)
keep_unused, inline, compiler_options, use_resource_env=False)
@contextmanager

View File

@ -898,7 +898,8 @@ error_checks[lax.while_p] = while_loop_error_check
def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
in_shardings, out_shardings,
in_layouts, out_layouts,
resource_env, donated_invars, name, inline, keep_unused):
resource_env, donated_invars, name, inline, keep_unused,
compiler_options_kvs):
# jaxpr to checked_jaxpr
err_vals, err_tree = jtu.tree_flatten(error)
new_vals_in = [*err_vals, *vals_in]
@ -929,6 +930,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
name=name,
inline=inline,
keep_unused=keep_unused,
compiler_options_kvs=compiler_options_kvs,
)
return tree_unflatten(out_tree, err_and_out)
error_checks[pjit.pjit_p] = pjit_error_check

View File

@ -2121,6 +2121,7 @@ def lower_sharding_computation(
*,
keep_unused: bool,
context_mesh: mesh_lib.Mesh | None,
compiler_options_kvs: tuple[tuple[str, Any], ...],
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None,
@ -2247,6 +2248,7 @@ def lower_sharding_computation(
module,
donated_invars,
platforms,
compiler_options_kvs,
global_in_avals=global_in_avals,
global_out_avals=global_out_avals,
in_shardings=in_shardings,
@ -2298,11 +2300,13 @@ class MeshComputation(stages.XlaLowering):
def __init__(self, name: str, hlo: ir.Module,
donated_invars: Sequence[bool], platforms: Sequence[str],
compiler_options_kvs: tuple[tuple[str, Any], ...],
**compile_args):
self._name = name
self._hlo = hlo
self._donated_invars = donated_invars
self._platforms = platforms
self._compiler_options_kvs = compiler_options_kvs
self.compile_args = compile_args
self._executable = None
@ -2312,11 +2316,14 @@ class MeshComputation(stages.XlaLowering):
return self._hlo
def compile(self, compiler_options=None) -> MeshExecutable:
if self._executable is None or compiler_options is not None:
t_compiler_options = (() if compiler_options is None else
tuple(compiler_options.items()))
compiler_options_kvs = self._compiler_options_kvs + t_compiler_options
if self._executable is None or compiler_options_kvs:
executable = UnloadedMeshExecutable.from_hlo(
self._name, self._hlo, **self.compile_args,
compiler_options=compiler_options)
if compiler_options is None:
compiler_options_kvs=compiler_options_kvs)
if not compiler_options_kvs:
self._executable = executable
return executable
return self._executable
@ -2581,8 +2588,7 @@ def create_compile_options(
else:
xla_device_assignment = np_dev.reshape((num_replicas, num_partitions))
fdo_profile = (None if compiler_options is None else
compiler_options.pop("fdo_profile", None))
fdo_profile = compiler_options.pop("fdo_profile", None)
compile_options = compiler.get_compile_options(
num_replicas=num_replicas,
@ -2614,17 +2620,11 @@ def create_compile_options(
def _cached_compilation(computation, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, host_callbacks, backend,
da, pmap_nreps, compiler_options_keys,
compiler_options_values,
pgle_profiler):
da, pmap_nreps, compiler_options_kvs, pgle_profiler):
# One would normally just write: dev = np.array(device_assignment)
# The formulation below is substantially faster if there are many devices.
dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da)))
if compiler_options_keys is None:
compiler_options = None
else:
compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values))
compiler_options = dict(compiler_options_kvs)
compile_options = create_compile_options(
computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering,
@ -2788,22 +2788,18 @@ class UnloadedMeshExecutable:
committed: bool,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
compiler_options_kvs: tuple[tuple[str, Any], ...],
pmap_nreps: int = 1,
mut: MutationData | None = None,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
all_default_mem_kind: bool = True,
all_args_info: AllArgsInfo | None = None,
compiler_options=None,
pgle_profiler: profiler.PGLEProfiler | None = None,
intermediate_shardings: Sequence[JSharding] | None = None,
context_mesh: mesh_lib.Mesh | None = None
) -> MeshExecutable:
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = mlir.refine_polymorphic_shapes(hlo)
compiler_options_keys = tuple(
compiler_options.keys()) if compiler_options is not None else None
compiler_options_values = tuple(
compiler_options.values()) if compiler_options is not None else None
if isinstance(device_assignment, xc.DeviceList):
da = device_assignment
else:
@ -2826,7 +2822,7 @@ class UnloadedMeshExecutable:
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
compiler_options_keys, compiler_options_values, pgle_profiler)
compiler_options_kvs, pgle_profiler)
if auto_spmd_lowering:
assert mesh is not None
@ -2918,6 +2914,7 @@ class JitGlobalCppCacheKeys:
out_layouts_treedef: PyTreeDef | None = None
out_layouts_leaves: tuple[Any, ...] | None = None
use_resource_env: bool = False
compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None
@functools.cached_property
def contains_explicit_attributes(self):
@ -2928,7 +2925,8 @@ class JitGlobalCppCacheKeys:
any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or
any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or
any(i is not None for i in self.in_layouts_leaves) or
any(o is not None for o in self.out_layouts_leaves))
any(o is not None for o in self.out_layouts_leaves) or
self.compiler_options_kvs)
def reflatten_outputs_for_dispatch(out_tree, out_flat):

View File

@ -164,6 +164,7 @@ class PjitInfo(NamedTuple):
inline: bool
abstracted_axes: Any | None
use_resource_env: bool # False for jit, True for pjit
compiler_options_kvs: tuple[tuple[str, Any], ...]
# Hash and compare PjitInfo by identity when used as a cache key.
def __hash__(self):
@ -357,7 +358,8 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
in_layouts_leaves=jit_info.in_layouts_leaves,
out_layouts_treedef=jit_info.out_layouts_treedef,
out_layouts_leaves=jit_info.out_layouts_leaves,
use_resource_env=jit_info.use_resource_env)
use_resource_env=jit_info.use_resource_env,
compiler_options_kvs=jit_info.compiler_options_kvs)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, cache_key, tree_util.dispatch_registry,
@ -398,7 +400,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
static_argnames: str | Iterable[str] | None,
device: xc.Device | None, backend: str | None,
abstracted_axes: Any | None, keep_unused: bool,
inline: bool, use_resource_env: bool) -> PjitInfo:
inline: bool, compiler_options: dict[str, Any] | None,
use_resource_env: bool) -> PjitInfo:
"""Parses the arguments to jit/pjit.
Performs any preprocessing and validation of the arguments that we can do
@ -453,6 +456,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
static_argnames)
compiler_options_kvs = (() if compiler_options is None else
tuple(compiler_options.items()))
return PjitInfo(
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
@ -470,7 +475,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
abstracted_axes=abstracted_axes,
use_resource_env=use_resource_env)
use_resource_env=use_resource_env,
compiler_options_kvs=compiler_options_kvs)
def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
@ -514,12 +520,13 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
static_argnames: str | Iterable[str] | None,
device: xc.Device | None, backend: str | None,
abstracted_axes: Any | None, keep_unused: bool,
inline: bool, use_resource_env: bool) -> Any:
inline: bool, compiler_options: dict[str, Any] | None,
use_resource_env: bool) -> Any:
"""jit() and pjit() are thin wrappers around this function."""
jit_info = _parse_jit_arguments(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes,
keep_unused, inline, use_resource_env)
keep_unused, inline, compiler_options, use_resource_env)
return _make_jit_wrapper(fun, jit_info)
@ -676,6 +683,7 @@ def _infer_params_impl(
name=fun_qual_name(flat_fun),
keep_unused=ji.keep_unused,
inline=ji.inline,
compiler_options_kvs=ji.compiler_options_kvs,
)
return PjitParams(consts, params, in_avals, in_tree, out_tree(),
donated_invars, dbg.arg_names if dbg else None, len(consts),
@ -815,6 +823,7 @@ def pjit(
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
compiler_options: dict[str, Any] | None = None,
) -> JitWrapped:
"""Makes ``fun`` compiled and automatically partitioned across multiple devices.
@ -987,7 +996,7 @@ def pjit(
return make_jit(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes,
keep_unused, inline, use_resource_env=True)
keep_unused, inline, compiler_options, use_resource_env=True)
def hashable_pytree(pytree):
@ -1594,25 +1603,25 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
def _resolve_and_lower(
args, jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name, keep_unused, inline,
lowering_platforms, lowering_parameters, pgle_profiler):
lowering_platforms, lowering_parameters, pgle_profiler,
compiler_options_kvs):
in_shardings = _resolve_in_shardings(args, in_shardings)
in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
jaxpr.in_avals)
lowered = _pjit_lower(
return _pjit_lower(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, name, keep_unused, inline,
donated_invars, name, keep_unused, inline, compiler_options_kvs,
lowering_platforms=lowering_platforms,
lowering_parameters=lowering_parameters,
pgle_profiler=pgle_profiler)
return lowered
def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
global _most_recent_pjit_call_executable
compile_options = None
pgle_profiler = None
pgle_compile_options, pgle_profiler = {}, None
pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
if jaxpr not in pgle_profiler_dict:
@ -1626,8 +1635,9 @@ def _pjit_call_impl_python(
# be None.
fdo_profile = pgle_profiler.consume_fdo_profile()
if fdo_profile is not None:
compile_options = {'fdo_profile': fdo_profile}
pgle_compile_options['fdo_profile'] = fdo_profile
compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
# TODO(patrios): Do not pass mutable profile session through cached lowering
# chain. Instead we need to move profilers dictionary to pxla module and use
# module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode.
@ -1638,8 +1648,9 @@ def _pjit_call_impl_python(
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline, lowering_platforms=None,
lowering_parameters=mlir.LoweringParameters(),
pgle_profiler=pgle_profiler
).compile(compile_options)
pgle_profiler=pgle_profiler,
compiler_options_kvs=compiler_options_kvs,
).compile()
_most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
# This check is expensive so only do it if enable_checks is on.
@ -1693,7 +1704,7 @@ def _pjit_call_impl_python(
@weakref_lru_cache
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, resource_env, donated_invars, name,
keep_unused, inline):
keep_unused, inline, compiler_options_kvs):
# The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so
# returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to
# the jaxpr defeating the purpose of weakref_lru_cache. So return a function
@ -1706,15 +1717,15 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, in_layouts, out_layouts,
resource_env,
donated_invars, name, keep_unused, inline):
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
def call_impl_cache_miss(*args_, **kwargs_):
out_flat, compiled = _pjit_call_impl_python(
*args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, in_layouts=in_layouts,
out_layouts=out_layouts, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline)
inline=inline, compiler_options_kvs=compiler_options_kvs)
pgle_profiler = _read_pgle_profiler(jaxpr)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
@ -1723,7 +1734,8 @@ def _pjit_call_impl(*args, jaxpr,
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline)
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs)
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
cache_key = pxla.JitGlobalCppCacheKeys(
donate_argnums=donated_argnums, donate_argnames=None,
@ -1757,6 +1769,7 @@ def _pjit_lower_cached(
name: str,
keep_unused: bool,
inline: bool,
compiler_options_kvs: tuple[tuple[str, Any], ...],
*,
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
@ -1767,6 +1780,7 @@ def _pjit_lower_cached(
jaxpr, api_name, name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars),
keep_unused=keep_unused, context_mesh=mesh,
compiler_options_kvs=compiler_options_kvs,
lowering_platforms=lowering_platforms,
lowering_parameters=lowering_parameters,
pgle_profiler=pgle_profiler)
@ -1911,7 +1925,7 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, in_layouts, out_layouts, resource_env,
donated_invars, keep_unused, inline):
donated_invars, keep_unused, inline, compiler_options_kvs):
effects = list(ctx.tokens_in.effects())
output_types = map(mlir.aval_to_ir_type, ctx.avals_out)
output_types = [mlir.token_type()] * len(effects) + output_types
@ -1939,7 +1953,8 @@ mlir.register_lowering(pjit_p, _pjit_lowering)
def _pjit_batcher(axis_data, vals_in, dims_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)
@ -1974,7 +1989,8 @@ def _pjit_batcher(axis_data, vals_in, dims_in,
donated_invars=donated_invars,
name=name,
keep_unused=keep_unused,
inline=inline)
inline=inline,
compiler_options_kvs=compiler_options_kvs)
resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
vals_in, vals_out, axes_out)
@ -2024,7 +2040,8 @@ def _pjit_batcher_for_sharding(
def _pjit_jvp(primals_in, tangents_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
if any(isinstance(c, core.MutableArray) for c in jaxpr.consts):
jaxpr, mut_primals = pxla._move_mutable_consts(jaxpr)
mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals)
@ -2056,7 +2073,8 @@ def _pjit_jvp(primals_in, tangents_in,
donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
name=name,
keep_unused=keep_unused,
inline=inline)
inline=inline,
compiler_options_kvs=compiler_options_kvs)
primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
assert len(primals_out) == len(jaxpr.jaxpr.outvars)
@ -2069,7 +2087,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
def _pjit_partial_eval(trace, *in_tracers,
jaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars,
name, keep_unused, inline):
name, keep_unused, inline, compiler_options_kvs):
in_pvals = [t.pval for t in in_tracers]
known_ins = tuple(pv.is_known() for pv in in_pvals)
@ -2127,7 +2145,8 @@ def _pjit_partial_eval(trace, *in_tracers,
in_layouts=keep_where(in_layouts, known_ins),
out_layouts=known_out_layouts, resource_env=resource_env,
donated_invars=keep_where(donated_invars, known_ins),
name=name, keep_unused=keep_unused, inline=inline)
name=name, keep_unused=keep_unused, inline=inline,
compiler_options_kvs=compiler_options_kvs)
assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals)
@ -2161,7 +2180,8 @@ def _pjit_partial_eval(trace, *in_tracers,
(False,) * num_residuals),
name=name,
keep_unused=keep_unused,
inline=inline)
inline=inline,
compiler_options_kvs=compiler_options_kvs)
unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
unknown_out_avals = unknown_jaxpr.out_avals
unknown_tracers_out = [
@ -2241,7 +2261,8 @@ def _pjit_transpose_trace(fun, in_avals):
def _pjit_transpose(cts_in, *primals_in,
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline):
resource_env, donated_invars, name, keep_unused, inline,
compiler_options_kvs):
def prune_type(ty, xs, maybe_zeros):
return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
@ -2292,7 +2313,8 @@ def _pjit_transpose(cts_in, *primals_in,
donated_invars=(False,) * len(primals_and_nz_cts_in),
name=name,
keep_unused=keep_unused,
inline=inline)
inline=inline,
compiler_options_kvs=compiler_options_kvs)
if attrs_tracked:
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
@ -2358,6 +2380,8 @@ def _pjit_pp_rule(eqn, context, settings):
if (params['resource_env'] is None or
params['resource_env'].physical_mesh.empty):
del params['resource_env']
if not params['compiler_options_kvs']:
del params['compiler_options_kvs']
# Move name= to the front to make the resulting equation easier to scan.
del params["name"]

View File

@ -3581,6 +3581,7 @@ def _pjit(*args: TfVal,
name: str,
keep_unused: bool,
inline: bool,
compiler_options_kvs,
_in_avals: Sequence[core.ShapedArray],
_out_aval: Sequence[core.ShapedArray]) -> TfVal:
del donated_invars

View File

@ -762,7 +762,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse
def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, resource_env, donated_invars, name,
keep_unused, inline):
keep_unused, inline, compiler_options_kvs):
if any(donated_invars):
raise NotImplementedError("sparse xla_call with donated_invars")
@ -798,7 +798,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
donated_invars=donated_invars,
name=name,
keep_unused=keep_unused,
inline=inline)
inline=inline,
compiler_options_kvs=compiler_options_kvs)
return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))
sparse_rules_bcoo[pjit.pjit_p] = _pjit_sparse

View File

@ -1373,6 +1373,18 @@ class JitTest(jtu.BufferDonationTestCase):
}
)
def test_compile_options_jit(self):
  """Smoke test: `jit` accepts `compiler_options` and the jitted call runs."""
  def f(x):
    return jnp.sqrt(x ** 2) + 1.

  # The output is deliberately discarded: the test passes as long as
  # compiling and executing with these options doesn't crash.
  jit(
      f,
      compiler_options={
          "xla_embed_ir_in_executable": True,
          "xla_dump_max_hlo_modules": 200,
          "xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5,
      })(1.0)
def test_jit_lower_compile_with_compiler_options_invalid(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
@ -1390,7 +1402,21 @@ class JitTest(jtu.BufferDonationTestCase):
lambda: lowered.compile(
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
def test_jit_lower_compile_with_compiler_options_multiple(self):
def test_jit_compile_with_compiler_options_multiple(self):
  """Checks cache misses across differing `compiler_options`, then a bad key.

  Jits the same function with two different values of one option and expects
  two compilation cache misses, then expects an unknown option key to raise.
  """
  def f(x):
    return jnp.sqrt(x ** 2) + 1.

  with jtu.count_jit_compilation_cache_miss() as count:
    for embed_ir in (True, False):
      jit(f, compiler_options={"xla_embed_ir_in_executable": embed_ir})(1.)
  self.assertEqual(count[0], 2)

  # We should still error on invalid options after some valid compiles
  bad_options = {"invalid_key": "invalid_value"}
  with self.assertRaisesRegex(
      xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'"):
    jit(f, compiler_options=bad_options)(1.)
def test_lower_compile_with_compiler_options_multiple(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.