diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index bcbab7767..76a9f6fa5 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1276,20 +1276,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
   tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
   return out_nodes, tokens_out
 
-def _xla_call_lower(ctx, *args,
-                    backend=None, name, call_jaxpr, donated_invars, inline=None,
-                    device=None, keep_unused=None):
-  del device, donated_invars, inline, keep_unused  # Ignored.
-  out_nodes, tokens = _call_lowering(
-      name, util.wrap_name(name, "jit"), call_jaxpr, backend,
-      ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in,
-      *args, dim_var_values=ctx.dim_var_values)
-  ctx.set_tokens_out(tokens)
-  return out_nodes
-
-register_lowering(xla.xla_call_p, _xla_call_lower)
-
-def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
+def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
   out_nodes, tokens = _call_lowering(
       name, name, call_jaxpr, backend, ctx.module_context,
       ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
@@ -1297,9 +1284,9 @@ def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
   ctx.set_tokens_out(tokens)
   return out_nodes
 
-register_lowering(core.call_p, partial(_core_call_lowering, name="core_call"))
+register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
 register_lowering(core.closed_call_p,
-                  partial(_core_call_lowering, name="core_closed_call"))
+                  partial(core_call_lowering, name="core_closed_call"))
 
 def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
                      broadcast_dimensions) -> ir.Value:
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 1c8a37c58..2ad6ef83c 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -902,12 +902,7 @@ class MapTrace(core.Trace):
     return MapTracer(self, outvals, out_shard_axes)
 
   def process_call(self, call_primitive, fun, tracers, params):
-    if call_primitive is not xla.xla_call_p: raise NotImplementedError
-    bind = HashableFunction(
-        lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs),
-        (call_primitive, fun))
-    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
-    return self.process_primitive(fake_primitive, tracers, params)
+    raise NotImplementedError
 
   def process_map(self, call_primitive, fun, tracers, params):
     if params['devices'] is not None:
@@ -1998,15 +1993,14 @@ def _pmap_dce_rule(used_outputs, eqn):
 
 
 # Set param update handlers to update `donated_invars` just like xla_call_p
-pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
+pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
 pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
     partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
             _pmap_partial_eval_custom_params_updater,
             res_aval=_pmap_partial_eval_custom_res_maker)
 pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
 
-ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
-ad.call_transpose_param_updaters[xla_pmap_p] = \
-    ad.call_transpose_param_updaters[xla.xla_call_p]
+ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
+ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params
 
 ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
 
diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py
index 826b94cd2..6623069b2 100644
--- a/jax/_src/interpreters/xla.py
+++ b/jax/_src/interpreters/xla.py
@@ -28,17 +28,14 @@ from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol,
 import numpy as np
 
 from jax.config import config
-from jax.interpreters import partial_eval as pe
 
 from jax._src import core
 from jax._src import device_array
 from jax._src import dtypes
-from jax._src import pretty_printer as pp
 from jax._src import source_info_util
 from jax._src.abstract_arrays import numpy_scalar_types
 from jax._src.core import ConcreteArray, ShapedArray
-from jax._src.interpreters import ad
-from jax._src.util import (safe_zip, safe_map, partition_list)
+from jax._src.util import safe_zip, safe_map
 from jax._src.typing import Shape
 
@@ -157,7 +154,6 @@ xla_shape_handlers: Dict[Type[core.AbstractValue],
 xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
 
 
-
 # IR constants
 
 # TODO(mattjj): try to remove this canonicalize_dtype stuff
@@ -348,11 +344,11 @@ def jaxpr_collectives(jaxpr):
 
 
 ### xla_call underlying jit
 
-
+# TODO(yashkatariya): Remove after 1 month from March 23, 2023.
 xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
-xla_call = xla_call_p.bind
-def _xla_call_partial_eval_update_params(
+
+def xla_call_partial_eval_update_params(
     params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
   ) -> core.ParamDict:
   donated_invars = params['donated_invars']
@@ -366,57 +362,18 @@ def _xla_call_partial_eval_update_params(
   # Any new inputs are prepended to the left, so mark those as not donated.
   donated_invars = [False] * num_new_inputs + donated_invars
   return dict(params, donated_invars=tuple(donated_invars))
-pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
 
-def _xla_call_jvp_update_params(params, nz_tangents):
+def xla_call_jvp_update_params(params, nz_tangents):
   donated_invars = params['donated_invars']
   donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
   new_donated_invars = (*donated_invars, *donated_tangents)
   return dict(params, donated_invars=new_donated_invars)
-ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
 
-def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
+def xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
   donated_invars = params['donated_invars']
   donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
   donated_cotangents = [False for nz in nonzero_cts if nz]
   return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
-ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
-
-
-ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
-
-
-def _xla_call_partial_eval_custom_params_updater(
-    unks_in: Sequence[bool], inst_in: Sequence[bool],
-    kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
-    num_res: int, params_known: dict, params_staged: dict
-  ) -> Tuple[dict, dict]:
-  # pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
-  donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
-  new_params_known = dict(params_known, donated_invars=tuple(donated_known))
-  # added num_res new inputs to jaxpr_staged, so extend donated_invars
-  _, donated_staged_ = partition_list(inst_in, params_staged['donated_invars'])
-  donated_staged = [False] * num_res + donated_staged_
-  new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged))
-  return new_params_known, new_params_staged
-pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
-    partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
-            _xla_call_partial_eval_custom_params_updater)
-pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
-
-pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
-
-
-def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
-                 settings: core.JaxprPpSettings,
-                 ) -> pp.Doc:
-  printed_params = {k:v for k, v in eqn.params.items() if
-                    k == 'call_jaxpr' or k == 'name' or
-                    k == 'backend' and v is not None or
-                    k == 'device' and v is not None or
-                    k == 'donated_invars' and any(v)}
-  return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
-core.pp_eqn_rules[xla_call_p] = _pp_xla_call
 
 
 ### translation tables
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index a16d881cd..7e11f950e 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -846,7 +846,7 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
 # NOTE: We don't have to handle spmd_{in|out}_axes here, because
 # SPMD batching always gets involved as the last transform before XLA translation
 ad.JVPTrace.process_xmap = ad.JVPTrace.process_call  # type: ignore
-ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
+ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params
 
 def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
   all_args, in_tree_def = tree_flatten(((), args, cts_in))  # empty consts
diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py
index 47791ff7c..6e3aaa9be 100644
--- a/jax/experimental/host_callback.py
+++ b/jax/experimental/host_callback.py
@@ -1655,16 +1655,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 jaxpr=new_jaxpr,
                 num_carry=num_carry + 2,
                 linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:])))
-  elif eqn.primitive is xla.xla_call_p:
-    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
-    eqns.append(
-        eqn.replace(
-            invars=eqn.invars + [input_token_var, input_itoken_var],
-            outvars=eqn.outvars + [output_token_var, output_itoken_var],
-            params=dict(
-                eqn.params,
-                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
-                donated_invars=eqn.params["donated_invars"] + (False, False))))
   elif eqn.primitive is pxla.xla_pmap_p:
     # We broadcast the input token into an array of tokens
     call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
@@ -1762,12 +1752,10 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
   eqns.append(
       core.new_jaxpr_eqn(
           eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var, input_itoken_var],
-          pred1_and_token1, xla.xla_call_p,
+          pred1_and_token1, core.call_p,
           dict(
               call_jaxpr=transformed_cond_jaxpr.jaxpr,
-              name="cond_before",
-              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
-              inline=False),
+              name="cond_before"),
           transformed_cond_jaxpr.jaxpr.effects,
           eqn.source_info))
   # Make a new cond "lambda pred, carry, token, itoken: pred"
@@ -1808,22 +1796,18 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
           new_body_invars_body_constvars + new_body_invars_carry +
          [new_body_invars_token, new_body_invars_itoken],
           new_body_carry2 + [new_body_token2, new_body_itoken2],
-          xla.xla_call_p,
+          core.call_p,
           dict(
              call_jaxpr=transformed_body_jaxpr.jaxpr,
-              name="body",
-              donated_invars=(False,) * len(transformed_body_jaxpr.in_avals),
-              inline=False),
+              name="body"),
          transformed_body_jaxpr.effects,
          eqn.source_info),
      core.new_jaxpr_eqn(
          new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
-          [new_body_pred2, new_body_token3, new_body_itoken3], xla.xla_call_p,
+          [new_body_pred2, new_body_token3, new_body_itoken3], core.call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
-              name="cond_body",
-              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
-              inline=False),
+              name="cond_body"),
          transformed_cond_jaxpr.effects,
          eqn.source_info)
   ]
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 37b65bf85..37a33b400 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -1479,28 +1479,9 @@ class TensorFlowTrace(core.Trace):
     avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
     interpreted_fun = _interpret_subtrace(fun, self.main, avals)
     extra_name_stack = None
-    if call_primitive == xla.xla_call_p:
-      extra_name_stack = util.wrap_name(params["name"], "jit")
     with _extended_name_stack(extra_name_stack):
       with core.new_sublevel():
-        if call_primitive == xla.xla_call_p:
-          if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
-            # Make a nested tf.function(jit_compile=True)
-            store_tf_res_avals: Sequence[core.ShapedArray] = []
-            def f_tf(*tf_args):
-              nonlocal store_tf_res_avals
-              tf_res_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
-                  _call_wrapped_with_new_constant_cache(interpreted_fun, tf_args,
-                                                        fresh_constant_cache=False)
-              tf_res_vals, tf_res_avals = util.unzip2(tf_res_out)
-              store_tf_res_avals = tf_res_avals
-              return tf_res_vals
-            tf_vals_out = tf.function(f_tf, autograph=False, jit_compile=True)(*vals)
-            vals_out = zip(tf_vals_out, store_tf_res_avals)
-          else:
-            vals_out = interpreted_fun.call_wrapped(*vals)
-        else:
-          vals_out = interpreted_fun.call_wrapped(*vals)
+        vals_out = interpreted_fun.call_wrapped(*vals)
     return [TensorFlowTracer(self, v, a) for v, a in vals_out]
 
   def post_process_call(self, call_primitive: core.Primitive,
@@ -1572,7 +1553,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
 
 
 # Call primitives are inlined
-for unexpected in [core.call_p, xla.xla_call_p, maps.xmap_p]:
+for unexpected in [core.call_p, maps.xmap_p]:
   tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
 
 # Primitives that are not yet implemented must be explicitly declared here.
diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py
index 54323fe2d..f20682a63 100644
--- a/jax/experimental/jet.py
+++ b/jax/experimental/jet.py
@@ -267,13 +267,6 @@ register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series
 
 call_param_updaters = {}
 
-def _xla_call_param_updater(params, num_inputs):
-  donated_invars = params['donated_invars']
-  if any(donated_invars):
-    raise NotImplementedError("donated_invars not supported with jet")
-  return dict(params, donated_invars=(False,) * num_inputs)
-call_param_updaters[xla.xla_call_p] = _xla_call_param_updater
-
 
 ### rule definitions
 
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index b8fc67fd2..e2eb80854 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -581,25 +581,8 @@ class ShardMapTrace(core.Trace):
     return ShardMapTracer(self, out_rep, out_vals)
 
   def process_call(self, call_primitive, fun, tracers, params):
-    if call_primitive is not xla.xla_call_p: raise NotImplementedError
-    fun, jaxpr = _grab_jaxpr_shadily(fun)  # TODO remove with initial-style jit
-    bind = partial(call_primitive.bind, fun)  # TODO caching (compat w/ jaxpr())
-    fake_primitive = pxla.FakePrimitive(multiple_results=True, bind=bind)
-    _rep_rules[fake_primitive] = lambda *_, **__: set()  # pytype: disable=container-type-mismatch
-    out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
-    out_vals = [t.val for t in out_tracers_]
-    if self.check:
-      out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
-    else:
-      out_rep = [set()] * len(out_vals)
-    return map(partial(ShardMapTracer, self), out_rep, out_vals)
+    raise NotImplementedError
 
-@lu.transformation_with_aux
-def _grab_jaxpr_shadily(*args):
-  out = yield args, {}
-  main = core.thread_local_state.trace_state.trace_stack.dynamic  # forgive me
-  jaxpr, _ = main.jaxpr_stack[-1].to_jaxpr(out)
-  yield out, jaxpr
 
 class ShardMapTracer(core.Tracer):
   rep: Set[AxisName]
@@ -711,10 +694,6 @@ def _axis_index_rule(mesh, *, axis_name):
 def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
   return _output_rep(mesh, jaxpr.jaxpr, in_rep)
 
-@register_rule(xla.xla_call_p)
-def _jit_rule(mesh, *in_rep, jaxpr, **kwargs):
-  return _output_rep(mesh, jaxpr, in_rep)
-
 @register_rule(debugging.debug_callback_p)
 def _debug_callback_rule(mesh, *in_rep, **_):
   return []
diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py
index 67b3fdcc5..d79455e9a 100644
--- a/jax/experimental/sparse/transform.py
+++ b/jax/experimental/sparse/transform.py
@@ -421,14 +421,7 @@ def eval_sparse(
         _raise_unimplemented_primitive(prim)
       out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params)
     else:
-      if prim is xla.xla_call_p:
-        # TODO(vanderplas,frostig): workaround for binding call primitives
-        # within a jaxpr interpreter
-        params = eqn.params.copy()
-        fun = lu.wrap_init(core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
-        out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals), **params)
-      else:
-        out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
+      out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
       out_bufs = out_bufs if prim.multiple_results else [out_bufs]
       out = []
       for buf, outvar in safe_zip(out_bufs, eqn.outvars):
@@ -759,18 +752,6 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n
 
 sparse_rules_bcoo[lax.while_p] = _while_sparse
 
-def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
-  if any(donated_invars):
-    raise NotImplementedError("sparse xla_call with donated_invars")
-  sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, pe.ClosedJaxpr(call_jaxpr, ()), *spvalues)
-  fun = lu.wrap_init(core.jaxpr_as_fun(sp_call_jaxpr))
-  args_flat, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues))
-  donated_invars = tuple(False for arg in args_flat)
-  out_flat = xla.xla_call_p.bind(fun, *args_flat, donated_invars=donated_invars, **params)
-  return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))
-
-sparse_rules_bcoo[xla.xla_call_p] = _xla_call_sparse
-
 
 def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
                  resource_env, donated_invars, name, keep_unused, inline):
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 1499d07d8..2eca53a83 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -34,9 +34,9 @@ from jax._src.interpreters.mlir import (
   _call_lowering as _call_lowering,
   _lowerings as _lowerings,
   _platform_specific_lowerings as _platform_specific_lowerings,
-  _xla_call_lower as _xla_call_lower,
   aval_to_ir_type as aval_to_ir_type,
   aval_to_ir_types as aval_to_ir_types,
+  core_call_lowering as core_call_lowering,
   dense_bool_elements as dense_bool_elements,
   dense_int_elements as dense_int_elements,
   dtype_to_ir_type as dtype_to_ir_type,
diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py
index e0da29e7e..b008f09be 100644
--- a/jax/interpreters/xla.py
+++ b/jax/interpreters/xla.py
@@ -29,8 +29,7 @@ from jax._src.interpreters.xla import (
   register_translation as register_translation,
   sharding_to_proto as sharding_to_proto,
   translations as translations,
-  xla_call as xla_call,
-  xla_call_p as xla_call_p,
+  xla_call_p as _deprecated_xla_call_p,
   xla_destructure as xla_destructure,
   xla_shape_handlers as xla_shape_handlers,
   device_put as _deprecated_device_put,
@@ -83,6 +82,13 @@ _deprecations = {
     ),
     _deprecated_device_put,
   ),
+  "xla_call_p": (
+    (
+      "jax.interpreters.xla.xla_call_p is deprecated. Please use"
+      " jax.experimental.pjit.pjit_p instead."
+    ),
+    _deprecated_xla_call_p,
+  ),
 }
 
 from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
@@ -98,4 +104,5 @@ if typing.TYPE_CHECKING:
   from jax._src.interpreters.xla import (
     device_put as device_put,
   )
+  from jax._src.interpreters.xla import xla_call_p as xla_call_p
 del typing
diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py
index f3845320e..72626a6ae 100644
--- a/tests/jaxpr_effects_test.py
+++ b/tests/jaxpr_effects_test.py
@@ -205,7 +205,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
     self.assertIn(foo_effect, jaxpr.jaxpr.effects)
     self.assertIn(bar_effect, jaxpr.jaxpr.effects)
 
-  def test_xla_call_primitive_inherits_effects(self):
+  def test_jit_primitive_inherits_effects(self):
 
     @jax.jit
     def f(x):
diff --git a/tests/name_stack_test.py b/tests/name_stack_test.py
index 182cd8b3a..b3e434180 100644
--- a/tests/name_stack_test.py
+++ b/tests/name_stack_test.py
@@ -100,7 +100,7 @@ class NameStackTest(jtu.JaxTestCase):
     hlo_text = _get_hlo(f)(2)
     self.assertIn('foo/jit(core_call)/bar', hlo_text)
 
-  def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
+  def test_jit_jaxpr_should_not_store_outer_name_stack(self):
    @jax.named_scope('foo')
    def f(x):
      @jax.jit
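
Migration note (illustrative, not part of the patch): the deprecation entry added in jax/interpreters/xla.py points callers of the removed primitive at jax.experimental.pjit.pjit_p. A minimal sketch of the corresponding downstream change, for a jaxpr interpreter that used to special-case xla.xla_call_p equations; it assumes pjit_p equations carry their body as a core.ClosedJaxpr under the "jaxpr" param (the removed xla_call_p used a raw jaxpr under "call_jaxpr"), and the helper name inner_closed_jaxpr is made up for illustration.

# Illustrative sketch only; not part of this change.
from jax._src import core
from jax.experimental import pjit

def inner_closed_jaxpr(eqn: core.JaxprEqn) -> core.ClosedJaxpr:
  # Formerly: `if eqn.primitive is xla.xla_call_p: return eqn.params["call_jaxpr"]`
  if eqn.primitive is pjit.pjit_p:
    return eqn.params["jaxpr"]
  raise NotImplementedError(eqn.primitive)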