From fff33f90b209bdc930e1164f0fa7eac92243dbdf Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 1 Nov 2024 14:00:10 -0700 Subject: [PATCH] Add `compiler_options` argument to `jax.jit`. This also exists on the `Compiled` object via the AOT path, i.e. `jit(f).lower(*args).compile(compiler_options={})`. PiperOrigin-RevId: 692283964 --- jax/_src/api.py | 3 +- jax/_src/checkify.py | 4 +- jax/_src/interpreters/pxla.py | 38 ++++++------ jax/_src/pjit.py | 86 ++++++++++++++++++---------- jax/experimental/jax2tf/jax2tf.py | 1 + jax/experimental/sparse/transform.py | 5 +- tests/api_test.py | 28 ++++++++- 7 files changed, 109 insertions(+), 56 deletions(-)
diff --git a/jax/_src/api.py b/jax/_src/api.py index 250743821..cc42a37b0 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -151,6 +151,7 @@ def jit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, + compiler_options: dict[str, Any] | None = None, ) -> pjit.JitWrapped: """Sets up ``fun`` for just-in-time compilation with XLA. @@ -280,7 +281,7 @@ def jit( return pjit.make_jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env=False) + keep_unused, inline, compiler_options, use_resource_env=False) @contextmanager
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 944bf303b..55db5d13e 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -898,7 +898,8 @@ error_checks[lax.while_p] = while_loop_error_check def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, inline, keep_unused): + resource_env, donated_invars, name, inline, keep_unused, + compiler_options_kvs): # jaxpr to checked_jaxpr err_vals, err_tree = jtu.tree_flatten(error) new_vals_in = [*err_vals, *vals_in] @@ -929,6 +930,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, name=name, inline=inline, keep_unused=keep_unused, + compiler_options_kvs=compiler_options_kvs, ) return tree_unflatten(out_tree, err_and_out) error_checks[pjit.pjit_p] = pjit_error_check
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e59d8c89e..04d479fb7 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2121,6 +2121,7 @@ def lower_sharding_computation( *, keep_unused: bool, context_mesh: mesh_lib.Mesh | None, + compiler_options_kvs: tuple[tuple[str, Any], ...], lowering_platforms: tuple[str, ...]
| None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None, @@ -2247,6 +2248,7 @@ def lower_sharding_computation( module, donated_invars, platforms, + compiler_options_kvs, global_in_avals=global_in_avals, global_out_avals=global_out_avals, in_shardings=in_shardings, @@ -2298,11 +2300,13 @@ class MeshComputation(stages.XlaLowering): def __init__(self, name: str, hlo: ir.Module, donated_invars: Sequence[bool], platforms: Sequence[str], + compiler_options_kvs: tuple[tuple[str, Any], ...], **compile_args): self._name = name self._hlo = hlo self._donated_invars = donated_invars self._platforms = platforms + self._compiler_options_kvs = compiler_options_kvs self.compile_args = compile_args self._executable = None @@ -2312,11 +2316,14 @@ class MeshComputation(stages.XlaLowering): return self._hlo def compile(self, compiler_options=None) -> MeshExecutable: - if self._executable is None or compiler_options is not None: + t_compiler_options = (() if compiler_options is None else + tuple(compiler_options.items())) + compiler_options_kvs = self._compiler_options_kvs + t_compiler_options + if self._executable is None or compiler_options_kvs: executable = UnloadedMeshExecutable.from_hlo( self._name, self._hlo, **self.compile_args, - compiler_options=compiler_options) - if compiler_options is None: + compiler_options_kvs=compiler_options_kvs) + if not compiler_options_kvs: self._executable = executable return executable return self._executable @@ -2581,8 +2588,7 @@ def create_compile_options( else: xla_device_assignment = np_dev.reshape((num_replicas, num_partitions)) - fdo_profile = (None if compiler_options is None else - compiler_options.pop("fdo_profile", None)) + fdo_profile = compiler_options.pop("fdo_profile", None) compile_options = compiler.get_compile_options( num_replicas=num_replicas, @@ -2614,17 +2620,11 @@ def create_compile_options( def _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, - da, pmap_nreps, compiler_options_keys, - compiler_options_values, - pgle_profiler): + da, pmap_nreps, compiler_options_kvs, pgle_profiler): # One would normally just write: dev = np.array(device_assignment) # The formulation below is substantially faster if there are many devices. 
dev = np.vectorize(lambda i: da[i], otypes=[object])(np.arange(len(da))) - - if compiler_options_keys is None: - compiler_options = None - else: - compiler_options = dict(safe_zip(compiler_options_keys, compiler_options_values)) + compiler_options = dict(compiler_options_kvs) compile_options = create_compile_options( computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, @@ -2788,22 +2788,18 @@ class UnloadedMeshExecutable: committed: bool, in_layouts: MaybeLayout, out_layouts: MaybeLayout, + compiler_options_kvs: tuple[tuple[str, Any], ...], pmap_nreps: int = 1, mut: MutationData | None = None, shape_poly_state: mlir.ShapePolyLoweringState | None = None, all_default_mem_kind: bool = True, all_args_info: AllArgsInfo | None = None, - compiler_options=None, pgle_profiler: profiler.PGLEProfiler | None = None, intermediate_shardings: Sequence[JSharding] | None = None, context_mesh: mesh_lib.Mesh | None = None ) -> MeshExecutable: if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) - compiler_options_keys = tuple( - compiler_options.keys()) if compiler_options is not None else None - compiler_options_values = tuple( - compiler_options.values()) if compiler_options is not None else None if isinstance(device_assignment, xc.DeviceList): da = device_assignment else: @@ -2826,7 +2822,7 @@ class UnloadedMeshExecutable: hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, - compiler_options_keys, compiler_options_values, pgle_profiler) + compiler_options_kvs, pgle_profiler) if auto_spmd_lowering: assert mesh is not None @@ -2918,6 +2914,7 @@ class JitGlobalCppCacheKeys: out_layouts_treedef: PyTreeDef | None = None out_layouts_leaves: tuple[Any, ...] | None = None use_resource_env: bool = False + compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None @functools.cached_property def contains_explicit_attributes(self): @@ -2928,7 +2925,8 @@ class JitGlobalCppCacheKeys: any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or any(i is not None for i in self.in_layouts_leaves) or - any(o is not None for o in self.out_layouts_leaves)) + any(o is not None for o in self.out_layouts_leaves) or + self.compiler_options_kvs) def reflatten_outputs_for_dispatch(out_tree, out_flat): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 3d8df8664..604acfb39 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -164,6 +164,7 @@ class PjitInfo(NamedTuple): inline: bool abstracted_axes: Any | None use_resource_env: bool # False for jit, True for pjit + compiler_options_kvs: tuple[tuple[str, Any], ...] # Hash and compare PjitInfo by identity when used as a cache key. 
def __hash__(self): @@ -357,7 +358,8 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo): in_layouts_leaves=jit_info.in_layouts_leaves, out_layouts_treedef=jit_info.out_layouts_treedef, out_layouts_leaves=jit_info.out_layouts_leaves, - use_resource_env=jit_info.use_resource_env) + use_resource_env=jit_info.use_resource_env, + compiler_options_kvs=jit_info.compiler_options_kvs) cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, cache_key, tree_util.dispatch_registry, @@ -398,7 +400,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, static_argnames: str | Iterable[str] | None, device: xc.Device | None, backend: str | None, abstracted_axes: Any | None, keep_unused: bool, - inline: bool, use_resource_env: bool) -> PjitInfo: + inline: bool, compiler_options: dict[str, Any] | None, + use_resource_env: bool) -> PjitInfo: """Parses the arguments to jit/pjit. Performs any preprocessing and validation of the arguments that we can do @@ -453,6 +456,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) + compiler_options_kvs = (() if compiler_options is None else + tuple(compiler_options.items())) return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -470,7 +475,8 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, - use_resource_env=use_resource_env) + use_resource_env=use_resource_env, + compiler_options_kvs=compiler_options_kvs) def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): @@ -514,12 +520,13 @@ def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any, static_argnames: str | Iterable[str] | None, device: xc.Device | None, backend: str | None, abstracted_axes: Any | None, keep_unused: bool, - inline: bool, use_resource_env: bool) -> Any: + inline: bool, compiler_options: dict[str, Any] | None, + use_resource_env: bool) -> Any: """jit() and pjit() are thin wrappers around this function.""" jit_info = _parse_jit_arguments( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env) + keep_unused, inline, compiler_options, use_resource_env) return _make_jit_wrapper(fun, jit_info) @@ -676,6 +683,7 @@ def _infer_params_impl( name=fun_qual_name(flat_fun), keep_unused=ji.keep_unused, inline=ji.inline, + compiler_options_kvs=ji.compiler_options_kvs, ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), donated_invars, dbg.arg_names if dbg else None, len(consts), @@ -815,6 +823,7 @@ def pjit( backend: str | None = None, inline: bool = False, abstracted_axes: Any | None = None, + compiler_options: dict[str, Any] | None = None, ) -> JitWrapped: """Makes ``fun`` compiled and automatically partitioned across multiple devices. 
@@ -987,7 +996,7 @@ def pjit( return make_jit( fun, in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums, static_argnames, device, backend, abstracted_axes, - keep_unused, inline, use_resource_env=True) + keep_unused, inline, compiler_options, use_resource_env=True) def hashable_pytree(pytree): @@ -1594,25 +1603,25 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, - lowering_platforms, lowering_parameters, pgle_profiler): + lowering_platforms, lowering_parameters, pgle_profiler, + compiler_options_kvs): in_shardings = _resolve_in_shardings(args, in_shardings) in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals) - lowered = _pjit_lower( + return _pjit_lower( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, name, keep_unused, inline, + donated_invars, name, keep_unused, inline, compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) - return lowered def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): global _most_recent_pjit_call_executable - compile_options = None - pgle_profiler = None + pgle_compile_options, pgle_profiler = {}, None pgle_profiler_dict = _most_recent_pjit_call_executable.weak_pgle_profiler_dict if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: if jaxpr not in pgle_profiler_dict: @@ -1626,8 +1635,9 @@ def _pjit_call_impl_python( # be None. fdo_profile = pgle_profiler.consume_fdo_profile() if fdo_profile is not None: - compile_options = {'fdo_profile': fdo_profile} + pgle_compile_options['fdo_profile'] = fdo_profile + compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items()) # TODO(patrios): Do not pass mutable profile session through cached lowering # chain. Instead we need to move profilers dictionary to pxla module and use # module as key. Right now we can't do that since there is no way to evict _pjit_lower_cached cache for in PGLE mode. @@ -1638,8 +1648,9 @@ def _pjit_call_impl_python( donated_invars=donated_invars, name=name, keep_unused=keep_unused, inline=inline, lowering_platforms=None, lowering_parameters=mlir.LoweringParameters(), - pgle_profiler=pgle_profiler - ).compile(compile_options) + pgle_profiler=pgle_profiler, + compiler_options_kvs=compiler_options_kvs, + ).compile() _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled # This check is expensive so only do it if enable_checks is on. @@ -1693,7 +1704,7 @@ def _pjit_call_impl_python( @weakref_lru_cache def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, - keep_unused, inline): + keep_unused, inline, compiler_options_kvs): # The input jaxpr to `_get_jaxpr_as_fun` is under a weakref_lru_cache so # returning `core.jaxpr_as_fun(jaxpr)` directly creates a strong reference to # the jaxpr defeating the purpose of weakref_lru_cache. 
So return a function @@ -1706,15 +1717,15 @@ def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, def _pjit_call_impl(*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, - donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): def call_impl_cache_miss(*args_, **kwargs_): out_flat, compiled = _pjit_call_impl_python( *args, jaxpr=jaxpr, in_shardings=in_shardings, out_shardings=out_shardings, in_layouts=in_layouts, out_layouts=out_layouts, resource_env=resource_env, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, compiler_options_kvs=compiler_options_kvs) pgle_profiler = _read_pgle_profiler(jaxpr) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects, @@ -1723,7 +1734,8 @@ def _pjit_call_impl(*args, jaxpr, f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline) + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs) donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) cache_key = pxla.JitGlobalCppCacheKeys( donate_argnums=donated_argnums, donate_argnames=None, @@ -1757,6 +1769,7 @@ def _pjit_lower_cached( name: str, keep_unused: bool, inline: bool, + compiler_options_kvs: tuple[tuple[str, Any], ...], *, lowering_platforms: tuple[str, ...] | None, lowering_parameters: mlir.LoweringParameters, @@ -1767,6 +1780,7 @@ def _pjit_lower_cached( jaxpr, api_name, name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), keep_unused=keep_unused, context_mesh=mesh, + compiler_options_kvs=compiler_options_kvs, lowering_platforms=lowering_platforms, lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) @@ -1911,7 +1925,7 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings, def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, - donated_invars, keep_unused, inline): + donated_invars, keep_unused, inline, compiler_options_kvs): effects = list(ctx.tokens_in.effects()) output_types = map(mlir.aval_to_ir_type, ctx.avals_out) output_types = [mlir.token_type()] * len(effects) + output_types @@ -1939,7 +1953,8 @@ mlir.register_lowering(pjit_p, _pjit_lowering) def _pjit_batcher(axis_data, vals_in, dims_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) @@ -1974,7 +1989,8 @@ def _pjit_batcher(axis_data, vals_in, dims_in, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs( vals_in, vals_out, axes_out) @@ -2024,7 +2040,8 @@ def _pjit_batcher_for_sharding( def _pjit_jvp(primals_in, tangents_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): if any(isinstance(c, core.MutableArray) for c in jaxpr.consts): jaxpr, mut_primals = 
pxla._move_mutable_consts(jaxpr) mut_tangents = map(ad_util.zeros_like_jaxval, mut_primals) @@ -2056,7 +2073,8 @@ def _pjit_jvp(primals_in, tangents_in, donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)]) assert len(primals_out) == len(jaxpr.jaxpr.outvars) @@ -2069,7 +2087,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp def _pjit_partial_eval(trace, *in_tracers, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, - name, keep_unused, inline): + name, keep_unused, inline, compiler_options_kvs): in_pvals = [t.pval for t in in_tracers] known_ins = tuple(pv.is_known() for pv in in_pvals) @@ -2127,7 +2145,8 @@ def _pjit_partial_eval(trace, *in_tracers, in_layouts=keep_where(in_layouts, known_ins), out_layouts=known_out_layouts, resource_env=resource_env, donated_invars=keep_where(donated_invars, known_ins), - name=name, keep_unused=keep_unused, inline=inline) + name=name, keep_unused=keep_unused, inline=inline, + compiler_options_kvs=compiler_options_kvs) assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals) assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals) @@ -2161,7 +2180,8 @@ def _pjit_partial_eval(trace, *in_tracers, (False,) * num_residuals), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()] unknown_out_avals = unknown_jaxpr.out_avals unknown_tracers_out = [ @@ -2241,7 +2261,8 @@ def _pjit_transpose_trace(fun, in_avals): def _pjit_transpose(cts_in, *primals_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, - resource_env, donated_invars, name, keep_unused, inline): + resource_env, donated_invars, name, keep_unused, inline, + compiler_options_kvs): def prune_type(ty, xs, maybe_zeros): return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) @@ -2292,7 +2313,8 @@ def _pjit_transpose(cts_in, *primals_in, donated_invars=(False,) * len(primals_and_nz_cts_in), name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) if attrs_tracked: final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) @@ -2358,6 +2380,8 @@ def _pjit_pp_rule(eqn, context, settings): if (params['resource_env'] is None or params['resource_env'].physical_mesh.empty): del params['resource_env'] + if not params['compiler_options_kvs']: + del params['compiler_options_kvs'] # Move name= to the front to make the resulting equation easier to scan. 
del params["name"] diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 972d1b3dd..783661e71 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3581,6 +3581,7 @@ def _pjit(*args: TfVal, name: str, keep_unused: bool, inline: bool, + compiler_options_kvs, _in_avals: Sequence[core.ShapedArray], _out_aval: Sequence[core.ShapedArray]) -> TfVal: del donated_invars diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 5348dd62a..7c5a96650 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -762,7 +762,7 @@ sparse_rules_bcoo[lax.while_p] = _while_sparse def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, - keep_unused, inline): + keep_unused, inline, compiler_options_kvs): if any(donated_invars): raise NotImplementedError("sparse xla_call with donated_invars") @@ -798,7 +798,8 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, donated_invars=donated_invars, name=name, keep_unused=keep_unused, - inline=inline) + inline=inline, + compiler_options_kvs=compiler_options_kvs) return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat)) sparse_rules_bcoo[pjit.pjit_p] = _pjit_sparse diff --git a/tests/api_test.py b/tests/api_test.py index e98f4299c..bb1d24729 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1373,6 +1373,18 @@ class JitTest(jtu.BufferDonationTestCase): } ) + def test_compile_options_jit(self): + def f(x): + return jnp.sqrt(x ** 2) + 1. + + f_jit = jit( + f, + compiler_options={ + "xla_embed_ir_in_executable": True, + "xla_dump_max_hlo_modules": 200, + "xla_gpu_auto_spmd_partitioning_memory_budget_ratio": 0.5, + })(1.0) # doesn't crash. + def test_jit_lower_compile_with_compiler_options_invalid(self): def f(x): return jnp.sqrt(x ** 2) + 1. @@ -1390,7 +1402,21 @@ class JitTest(jtu.BufferDonationTestCase): lambda: lowered.compile( compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) - def test_jit_lower_compile_with_compiler_options_multiple(self): + def test_jit_compile_with_compiler_options_multiple(self): + def f(x): + return jnp.sqrt(x ** 2) + 1. + + with jtu.count_jit_compilation_cache_miss() as count: + jit(f, compiler_options={"xla_embed_ir_in_executable": True})(1.) + jit(f, compiler_options={"xla_embed_ir_in_executable": False})(1.) + self.assertEqual(count[0], 2) + + # We should still error on invalid options after some valid compiles + with self.assertRaisesRegex( + xla_extension.XlaRuntimeError, "No such compile option: 'invalid_key'"): + jit(f, compiler_options={"invalid_key": "invalid_value"})(1.) + + def test_lower_compile_with_compiler_options_multiple(self): def f(x): return jnp.sqrt(x ** 2) + 1.