mirror of https://github.com/ROCm/jax.git
Add `specialize` on `jax.jit` so that we can delete the duplicate code in `jax.make_jaxpr`.

You can now do (in addition to `make_jaxpr`): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`

PiperOrigin-RevId: 628748620
parent 06760511b2
commit 1956ff7d7b
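In practice the new entry point looks like this; a minimal sketch assuming a JAX build that includes this commit (the function `f` and input mirror the new test at the bottom of the patch):

```python
import jax
import jax.numpy as jnp

def f(x):
  return x * 2

# specialize() traces f for these argument types but stops before lowering
# and compilation; the result bundles the closed jaxpr and the output pytree def.
specialized = jax.jit(f).specialize(jnp.arange(8))
print(specialized.jaxpr)     # the traced core.ClosedJaxpr
print(specialized.out_tree)  # PyTreeDef of f's output structure
```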
jax/_src/api.py
@@ -2370,43 +2370,25 @@ def make_jaxpr(fun: Callable,
         g:f32[] = mul f c
       in (g,) }
   """
   check_callable(fun)
   static_argnums = _ensure_index_tuple(static_argnums)

-  def abstractify(args, kwargs):
-    flat_args, in_tree = tree_flatten((args, kwargs))
-    if abstracted_axes is None:
-      return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
-    else:
-      axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
-      in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
-      in_avals, keep_inputs = unzip2(in_type)
-      return in_avals, in_tree, keep_inputs
-
-  try:
-    hash(fun)
-    weakref.ref(fun)
-  except TypeError:
-    fun = partial(fun)
-
   @wraps(fun)
   @api_boundary
   def make_jaxpr_f(*args, **kwargs):
-    f = lu.wrap_init(fun)
-    if static_argnums:
-      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
-      f, args = argnums_partial(f, dyn_argnums, args)
-    in_avals, in_tree, keep_inputs = abstractify(args, kwargs)
-    in_type = tuple(zip(in_avals, keep_inputs))
-    f, out_tree = flatten_fun(f, in_tree)
-    f = lu.annotate(f, in_type)
-    debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
     with ExitStack() as stack:
       for axis_name, size in axis_env or []:
         stack.enter_context(core.extend_axis_env(axis_name, size, None))
-      jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(
-          f, debug_info=debug_info)
-    closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
+      specialized = jit(fun, static_argnums=static_argnums,
+                        abstracted_axes=abstracted_axes).specialize(*args, **kwargs)
     if return_shape:
-      out_avals, _ = unzip2(out_type)
-      out_shapes_flat = [
-          ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
-      return closed_jaxpr, tree_unflatten(out_tree(), out_shapes_flat)
-    return closed_jaxpr
+      out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None))
+             for o in specialized.jaxpr.out_avals]
+      return specialized.jaxpr, tree_unflatten(specialized.out_tree, out)
+    return specialized.jaxpr

   make_jaxpr_f.__module__ = "jax"
   if hasattr(fun, "__qualname__"):
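With the duplicated tracing path deleted, `make_jaxpr` becomes a thin wrapper over `specialize`. A hedged sketch of the equivalence this buys (the function `g` is illustrative, not from the commit):

```python
import jax
import jax.numpy as jnp

def g(x, y):
  return jnp.sin(x) + y

# Both calls now trace through jit(...).specialize(...), so for the same
# abstract inputs they should print the same jaxpr.
print(jax.make_jaxpr(g)(1.0, 2.0))
print(jax.jit(g).specialize(1.0, 2.0).jaxpr)
```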
jax/_src/pjit.py
@@ -462,7 +462,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
         '_experimental_lowering_parameters', mlir.LoweringParameters())

     (args_flat, flat_global_in_avals, params, in_tree, out_tree,
-     donated_invars, arg_names, ()) = _infer_params(jit_info, args, kwargs)
+     donated_invars, arg_names, _) = _infer_params(jit_info, args, kwargs)
     try:
       lowering = _resolve_and_lower(
           args_flat, **params, lowering_parameters=lowering_parameters)
@@ -490,9 +490,15 @@ def _make_jit_wrapper(jit_info: PjitInfo):
            for x, s in zip(params['jaxpr'].out_avals, out_s)]
     return tree_unflatten(out_tree, out)

+  @api_boundary
+  def specialize(*args, **kwargs) -> stages.Specialized:
+    _, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
+    return stages.Specialized(params['jaxpr'], out_tree)
+
   wrapped = _cpp_pjit(jit_info)
   wrapped.lower = lower
   wrapped.eval_shape = eval_shape
+  wrapped.specialize = specialize
   return wrapped

@@ -686,6 +692,9 @@ class JitWrapped(stages.Wrapped):
     """See ``jax.eval_shape``."""
     raise NotImplementedError

+  def specialize(self, *args, **kwargs) -> stages.Specialized:
+    raise NotImplementedError
+

 # in_shardings and out_shardings can't be None as the default value
 # because `None` means that the input is fully replicated.
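`specialize` joins `lower` and `eval_shape` as a staging-style method on the jit wrapper. A short sketch of the three entry points side by side (the lambda and input are illustrative):

```python
import jax
import jax.numpy as jnp

jitted = jax.jit(lambda x: x + 1)
x = jnp.ones((4,))

spec = jitted.specialize(x)    # stages.Specialized: jaxpr + out_tree, no lowering
lowered = jitted.lower(x)      # stages.Lowered: staged out for the compiler
shapes = jitted.eval_shape(x)  # output ShapeDtypeStruct(s), nothing executed
```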
jax/_src/stages.py
@@ -413,6 +413,15 @@ class CompiledCallParams(NamedTuple):
   out_tree: tree_util.PyTreeDef


+# TODO(yashkatariya): Make Specialized inherit from `Stage`.
+class Specialized:
+  __slots__ = ["jaxpr", "out_tree"]
+
+  def __init__(self, jaxpr: core.ClosedJaxpr, out_tree):
+    self.jaxpr = jaxpr
+    self.out_tree = out_tree
+
+
 class Compiled(Stage):
   """Compiled representation of a function specialized to types/values.

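The two slots are enough to rebuild structured outputs, which is exactly what the new `return_shape` branch of `make_jaxpr` does. A hedged sketch using public tree utilities (the dict-returning `h` is illustrative):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_unflatten

def h(x):
  return {'double': x * 2, 'half': x / 2}

spec = jax.jit(h).specialize(jnp.arange(4.0))

# Rebuild the output pytree of ShapeDtypeStructs from the flat out_avals,
# mirroring make_jaxpr(..., return_shape=True) after this change.
shapes = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in spec.jaxpr.out_avals]
print(tree_unflatten(spec.out_tree, shapes))
```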
jax/stages.py
@@ -30,4 +30,5 @@ from jax._src.stages import (
     Lowered as Lowered,
     Wrapped as Wrapped,
     ArgInfo as ArgInfo,
+    Specialized as Specialized
 )
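The re-export makes the new class reachable from the public `jax.stages` namespace, e.g.:

```python
import jax
import jax.numpy as jnp
from jax.stages import Specialized

spec = jax.jit(lambda x: x * x).specialize(jnp.arange(4.0))
assert isinstance(spec, Specialized)
```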
tests/lax_control_flow_test.py
@@ -2511,7 +2511,8 @@ class LaxControlFlowTest(jtu.JaxTestCase):
     scan_fun = lambda c, xs: lax.scan(f, c, xs)

     def new_jaxpr():
-      jaxpr = jax.make_jaxpr(scan_fun)(c, xs).jaxpr
+      # partial avoids a cache_hit in make_jaxpr.
+      jaxpr = jax.make_jaxpr(partial(scan_fun))(c, xs).jaxpr
       scan = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'scan')
       return jaxpr, scan

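The `partial` wrapper matters because the tracing cache keys on function identity: a fresh `partial` object each call gives a distinct key, so `make_jaxpr` re-traces instead of reusing a cached trace. A minimal illustration (function and inputs are illustrative):

```python
import jax
from functools import partial

def f(x):
  return x + 1

jax.make_jaxpr(f)(1.0)           # traces f
jax.make_jaxpr(f)(1.0)           # same function object: eligible for a cache hit
jax.make_jaxpr(partial(f))(1.0)  # new partial object: distinct key, fresh trace
```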
tests/pjit_test.py
@@ -4056,6 +4056,14 @@ class ArrayPjitTest(jtu.JaxTestCase):

     jax.vmap(jax.grad(model), in_axes=(None, 0))(params, x)  # doesn't crash

+  def test_jit_specialize(self):
+    def f(x):
+      return x * 2
+
+    specialized = jax.jit(f).specialize(jnp.arange(8))
+    self.assertLen(specialized.jaxpr.eqns, 1)
+    self.assertEqual(specialized.out_tree.num_leaves, 1)
+

 class TempSharding(Sharding):
