Add specialize on jax.jit so that we can delete the duplicate code in jax.make_jaxpr.

You can now do (in addition to make_jaxpr): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`

PiperOrigin-RevId: 628748620
This commit is contained in:
Yash Katariya 2024-04-27 18:57:16 -07:00 committed by jax authors
parent 06760511b2
commit 1956ff7d7b
6 changed files with 41 additions and 31 deletions

View File

@@ -2370,43 +2370,25 @@ def make_jaxpr(fun: Callable,
g:f32[] = mul f c
in (g,) }
"""
check_callable(fun)
static_argnums = _ensure_index_tuple(static_argnums)
def abstractify(args, kwargs):
flat_args, in_tree = tree_flatten((args, kwargs))
if abstracted_axes is None:
return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
else:
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
in_avals, keep_inputs = unzip2(in_type)
return in_avals, in_tree, keep_inputs
try:
hash(fun)
weakref.ref(fun)
except TypeError:
fun = partial(fun)
@wraps(fun)
@api_boundary
def make_jaxpr_f(*args, **kwargs):
f = lu.wrap_init(fun)
if static_argnums:
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
f, args = argnums_partial(f, dyn_argnums, args)
in_avals, in_tree, keep_inputs = abstractify(args, kwargs)
in_type = tuple(zip(in_avals, keep_inputs))
f, out_tree = flatten_fun(f, in_tree)
f = lu.annotate(f, in_type)
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
f, debug_info=debug_info)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
specialized = jit(fun, static_argnums=static_argnums,
abstracted_axes=abstracted_axes).specialize(*args, **kwargs)
if return_shape:
out_avals, _ = unzip2(out_type)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
return closed_jaxpr
out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None))
for o in specialized.jaxpr.out_avals]
return specialized.jaxpr, tree_unflatten(specialized.out_tree, out)
return specialized.jaxpr
make_jaxpr_f.__module__ = "jax"
if hasattr(fun, "__qualname__"):

View File

@@ -462,7 +462,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
'_experimental_lowering_parameters', mlir.LoweringParameters())
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donated_invars, arg_names, ()) = _infer_params(jit_info, args, kwargs)
donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs)
try:
lowering = _resolve_and_lower(
args_flat, **params, lowering_parameters=lowering_parameters)
@@ -490,9 +490,15 @@ def _make_jit_wrapper(jit_info: PjitInfo):
for x, s in zip(params['jaxpr'].out_avals, out_s)]
return tree_unflatten(out_tree, out)
@api_boundary
def specialize(*args, **kwargs) -> stages.Specialized:
  """Return a ``stages.Specialized`` holding the traced jaxpr and output pytree."""
  # Only the jaxpr (inside params) and the output pytree structure are
  # consumed here; the remaining _infer_params results are unused.
  (_args_flat, _avals, params, _in_tree, out_tree,
   _donated, _arg_names, _extra) = _infer_params(jit_info, args, kwargs)
  return stages.Specialized(params['jaxpr'], out_tree)
wrapped = _cpp_pjit(jit_info)
wrapped.lower = lower
wrapped.eval_shape = eval_shape
wrapped.specialize = specialize
return wrapped
@@ -686,6 +692,9 @@ class JitWrapped(stages.Wrapped):
"""See ``jax.eval_shape``."""
raise NotImplementedError
def specialize(self, *args, **kwargs) -> stages.Specialized:
  """Trace to a ``stages.Specialized``; the ``jax.jit`` wrapper supplies the implementation."""
  raise NotImplementedError
# in_shardings and out_shardings can't be None as the default value
# because `None` means that the input is fully replicated.

View File

@@ -413,6 +413,15 @@ class CompiledCallParams(NamedTuple):
out_tree: tree_util.PyTreeDef
# TODO(yashkatariya): Make Specialized inherit from `Stage`.
class Specialized:
  """Result of specializing a jit-wrapped function to concrete arguments.

  Produced by ``jax.jit(f).specialize(*args, **kwargs)``. Holds the traced
  ``ClosedJaxpr`` and the pytree structure of the function's output.
  """
  __slots__ = ["jaxpr", "out_tree"]

  def __init__(self, jaxpr: core.ClosedJaxpr, out_tree):
    self.jaxpr = jaxpr        # traced core.ClosedJaxpr
    self.out_tree = out_tree  # output pytree structure

  def __repr__(self):
    # Debug-friendly repr; __slots__ classes have no __dict__ to inspect.
    return f"{type(self).__name__}(jaxpr={self.jaxpr!r}, out_tree={self.out_tree!r})"
class Compiled(Stage):
"""Compiled representation of a function specialized to types/values.

View File

@@ -30,4 +30,5 @@ from jax._src.stages import (
Lowered as Lowered,
Wrapped as Wrapped,
ArgInfo as ArgInfo,
Specialized as Specialized
)

View File

@@ -2511,7 +2511,8 @@ class LaxControlFlowTest(jtu.JaxTestCase):
scan_fun = lambda c, xs: lax.scan(f, c, xs)
def new_jaxpr():
  """Build a fresh jaxpr for scan_fun and return it with its scan eqn."""
  # The first, non-partial trace was a leftover duplicate: its result was
  # immediately overwritten, so only the partial-wrapped trace remains.
  # partial avoids a cache_hit in make_jaxpr.
  jaxpr = jax.make_jaxpr(partial(scan_fun))(c, xs).jaxpr
  scan = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'scan')
  return jaxpr, scan

View File

@@ -4056,6 +4056,14 @@ class ArrayPjitTest(jtu.JaxTestCase):
jax.vmap(jax.grad(model), in_axes=(None, 0))(params, x) # doesn't crash
def test_jit_specialize(self):
  # A single-equation function: specialize should expose exactly one eqn
  # in the traced jaxpr and a single-leaf output pytree.
  def double(x):
    return x * 2

  specialized = jax.jit(double).specialize(jnp.arange(8))
  self.assertLen(specialized.jaxpr.eqns, 1)
  self.assertEqual(specialized.out_tree.num_leaves, 1)
class TempSharding(Sharding):