Deprecated xla_call_p since it has been replaced with pjit.pjit_p

PiperOrigin-RevId: 518921538
This commit is contained in:
Yash Katariya 2023-03-23 11:43:49 -07:00 committed by jax authors
parent f461c4ef0c
commit a9e48af260
13 changed files with 36 additions and 173 deletions

View File

@ -1276,20 +1276,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
return out_nodes, tokens_out
def _xla_call_lower(ctx, *args,
backend=None, name, call_jaxpr, donated_invars, inline=None,
device=None, keep_unused=None):
del device, donated_invars, inline, keep_unused # Ignored.
out_nodes, tokens = _call_lowering(
name, util.wrap_name(name, "jit"), call_jaxpr, backend,
ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in,
*args, dim_var_values=ctx.dim_var_values)
ctx.set_tokens_out(tokens)
return out_nodes
register_lowering(xla.xla_call_p, _xla_call_lower)
def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
out_nodes, tokens = _call_lowering(
name, name, call_jaxpr, backend, ctx.module_context,
ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
@ -1297,9 +1284,9 @@ def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
ctx.set_tokens_out(tokens)
return out_nodes
register_lowering(core.call_p, partial(_core_call_lowering, name="core_call"))
register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
partial(_core_call_lowering, name="core_closed_call"))
partial(core_call_lowering, name="core_closed_call"))
def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
broadcast_dimensions) -> ir.Value:

View File

@ -902,12 +902,7 @@ class MapTrace(core.Trace):
return MapTracer(self, outvals, out_shard_axes)
def process_call(self, call_primitive, fun, tracers, params):
if call_primitive is not xla.xla_call_p: raise NotImplementedError
bind = HashableFunction(
lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs),
(call_primitive, fun))
fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
return self.process_primitive(fake_primitive, tracers, params)
raise NotImplementedError
def process_map(self, call_primitive, fun, tracers, params):
if params['devices'] is not None:
@ -1998,15 +1993,14 @@ def _pmap_dce_rule(used_outputs, eqn):
# Set param update handlers to update `donated_invars` just like xla_call_p
pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
partial(pe.call_partial_eval_custom_rule,
'call_jaxpr', _pmap_partial_eval_custom_params_updater,
res_aval=_pmap_partial_eval_custom_res_maker)
pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
ad.call_transpose_param_updaters[xla_pmap_p] = \
ad.call_transpose_param_updaters[xla.xla_call_p]
ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)

View File

@ -28,17 +28,14 @@ from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol,
import numpy as np
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.util import (safe_zip, safe_map, partition_list)
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Shape
@ -157,7 +154,6 @@ xla_shape_handlers: Dict[Type[core.AbstractValue],
xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
# IR constants
# TODO(mattjj): try to remove this canonicalize_dtype stuff
@ -348,11 +344,11 @@ def jaxpr_collectives(jaxpr):
### xla_call underlying jit
# TODO(yashkatariya): Remove one month after March 23, 2023.
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
def _xla_call_partial_eval_update_params(
def xla_call_partial_eval_update_params(
params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
) -> core.ParamDict:
donated_invars = params['donated_invars']
@ -366,57 +362,18 @@ def _xla_call_partial_eval_update_params(
# Any new inputs are prepended to the left, so mark those as not donated.
donated_invars = [False] * num_new_inputs + donated_invars
return dict(params, donated_invars=tuple(donated_invars))
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
def _xla_call_jvp_update_params(params, nz_tangents):
def xla_call_jvp_update_params(params, nz_tangents):
donated_invars = params['donated_invars']
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
def xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
donated_cotangents = [False for nz in nonzero_cts if nz]
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
def _xla_call_partial_eval_custom_params_updater(
unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
# pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
new_params_known = dict(params_known, donated_invars=tuple(donated_known))
# added num_res new inputs to jaxpr_staged, so extend donated_invars
_, donated_staged_ = partition_list(inst_in, params_staged['donated_invars'])
donated_staged = [False] * num_res + donated_staged_
new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged))
return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
_xla_call_partial_eval_custom_params_updater)
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
settings: core.JaxprPpSettings,
) -> pp.Doc:
printed_params = {k:v for k, v in eqn.params.items() if
k == 'call_jaxpr' or k == 'name' or
k == 'backend' and v is not None or
k == 'device' and v is not None or
k == 'donated_invars' and any(v)}
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
core.pp_eqn_rules[xla_call_p] = _pp_xla_call
### translation tables

View File

@ -846,7 +846,7 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
# NOTE: We don't have to handle spmd_{in|out}_axes here, because
# SPMD batching always gets involved as the last transform before XLA translation
ad.JVPTrace.process_xmap = ad.JVPTrace.process_call # type: ignore
ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params
def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
all_args, in_tree_def = tree_flatten(((), args, cts_in)) # empty consts

View File

@ -1655,16 +1655,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
jaxpr=new_jaxpr,
num_carry=num_carry + 2,
linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:])))
elif eqn.primitive is xla.xla_call_p:
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
eqns.append(
eqn.replace(
invars=eqn.invars + [input_token_var, input_itoken_var],
outvars=eqn.outvars + [output_token_var, output_itoken_var],
params=dict(
eqn.params,
call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False))))
elif eqn.primitive is pxla.xla_pmap_p:
# We broadcast the input token into an array of tokens
call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
@ -1762,12 +1752,10 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
eqns.append(
core.new_jaxpr_eqn(
eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var, input_itoken_var],
pred1_and_token1, xla.xla_call_p,
pred1_and_token1, core.call_p,
dict(
call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_before",
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
inline=False),
name="cond_before"),
transformed_cond_jaxpr.jaxpr.effects,
eqn.source_info))
# Make a new cond "lambda pred, carry, token, itoken: pred"
@ -1808,22 +1796,18 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
new_body_invars_body_constvars + new_body_invars_carry +
[new_body_invars_token, new_body_invars_itoken],
new_body_carry2 + [new_body_token2, new_body_itoken2],
xla.xla_call_p,
core.call_p,
dict(
call_jaxpr=transformed_body_jaxpr.jaxpr,
name="body",
donated_invars=(False,) * len(transformed_body_jaxpr.in_avals),
inline=False),
name="body"),
transformed_body_jaxpr.effects,
eqn.source_info),
core.new_jaxpr_eqn(
new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
[new_body_pred2, new_body_token3, new_body_itoken3], xla.xla_call_p,
[new_body_pred2, new_body_token3, new_body_itoken3], core.call_p,
dict(
call_jaxpr=transformed_cond_jaxpr.jaxpr,
name="cond_body",
donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
inline=False),
name="cond_body"),
transformed_cond_jaxpr.effects,
eqn.source_info)
]

View File

@ -1479,27 +1479,8 @@ class TensorFlowTrace(core.Trace):
avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
interpreted_fun = _interpret_subtrace(fun, self.main, avals)
extra_name_stack = None
if call_primitive == xla.xla_call_p:
extra_name_stack = util.wrap_name(params["name"], "jit")
with _extended_name_stack(extra_name_stack):
with core.new_sublevel():
if call_primitive == xla.xla_call_p:
if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
# Make a nested tf.function(jit_compile=True)
store_tf_res_avals: Sequence[core.ShapedArray] = []
def f_tf(*tf_args):
nonlocal store_tf_res_avals
tf_res_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
_call_wrapped_with_new_constant_cache(interpreted_fun, tf_args,
fresh_constant_cache=False)
tf_res_vals, tf_res_avals = util.unzip2(tf_res_out)
store_tf_res_avals = tf_res_avals
return tf_res_vals
tf_vals_out = tf.function(f_tf, autograph=False, jit_compile=True)(*vals)
vals_out = zip(tf_vals_out, store_tf_res_avals)
else:
vals_out = interpreted_fun.call_wrapped(*vals)
else:
vals_out = interpreted_fun.call_wrapped(*vals)
return [TensorFlowTracer(self, v, a) for v, a in vals_out]
@ -1572,7 +1553,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
# Call primitives are inlined
for unexpected in [core.call_p, xla.xla_call_p, maps.xmap_p]:
for unexpected in [core.call_p, maps.xmap_p]:
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.

View File

@ -267,13 +267,6 @@ register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series
call_param_updaters = {}
def _xla_call_param_updater(params, num_inputs):
donated_invars = params['donated_invars']
if any(donated_invars):
raise NotImplementedError("donated_invars not supported with jet")
return dict(params, donated_invars=(False,) * num_inputs)
call_param_updaters[xla.xla_call_p] = _xla_call_param_updater
### rule definitions

View File

@ -581,25 +581,8 @@ class ShardMapTrace(core.Trace):
return ShardMapTracer(self, out_rep, out_vals)
def process_call(self, call_primitive, fun, tracers, params):
if call_primitive is not xla.xla_call_p: raise NotImplementedError
fun, jaxpr = _grab_jaxpr_shadily(fun) # TODO remove with initial-style jit
bind = partial(call_primitive.bind, fun) # TODO caching (compat w/ jaxpr())
fake_primitive = pxla.FakePrimitive(multiple_results=True, bind=bind)
_rep_rules[fake_primitive] = lambda *_, **__: set() # pytype: disable=container-type-mismatch
out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
out_vals = [t.val for t in out_tracers_]
if self.check:
out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
else:
out_rep = [set()] * len(out_vals)
return map(partial(ShardMapTracer, self), out_rep, out_vals)
raise NotImplementedError
@lu.transformation_with_aux
def _grab_jaxpr_shadily(*args):
out = yield args, {}
main = core.thread_local_state.trace_state.trace_stack.dynamic # forgive me
jaxpr, _ = main.jaxpr_stack[-1].to_jaxpr(out)
yield out, jaxpr
class ShardMapTracer(core.Tracer):
rep: Set[AxisName]
@ -711,10 +694,6 @@ def _axis_index_rule(mesh, *, axis_name):
def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
return _output_rep(mesh, jaxpr.jaxpr, in_rep)
@register_rule(xla.xla_call_p)
def _jit_rule(mesh, *in_rep, jaxpr, **kwargs):
return _output_rep(mesh, jaxpr, in_rep)
@register_rule(debugging.debug_callback_p)
def _debug_callback_rule(mesh, *in_rep, **_):
return []

View File

@ -420,13 +420,6 @@ def eval_sparse(
if prim not in sparse_rules_bcoo:
_raise_unimplemented_primitive(prim)
out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params)
else:
if prim is xla.xla_call_p:
# TODO(vanderplas,frostig): workaround for binding call primitives
# within a jaxpr interpreter
params = eqn.params.copy()
fun = lu.wrap_init(core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals), **params)
else:
out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
out_bufs = out_bufs if prim.multiple_results else [out_bufs]
@ -759,18 +752,6 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n
sparse_rules_bcoo[lax.while_p] = _while_sparse
def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
if any(donated_invars):
raise NotImplementedError("sparse xla_call with donated_invars")
sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, pe.ClosedJaxpr(call_jaxpr, ()), *spvalues)
fun = lu.wrap_init(core.jaxpr_as_fun(sp_call_jaxpr))
args_flat, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues))
donated_invars = tuple(False for arg in args_flat)
out_flat = xla.xla_call_p.bind(fun, *args_flat, donated_invars=donated_invars, **params)
return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))
sparse_rules_bcoo[xla.xla_call_p] = _xla_call_sparse
def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
resource_env, donated_invars, name, keep_unused, inline):

View File

@ -34,9 +34,9 @@ from jax._src.interpreters.mlir import (
_call_lowering as _call_lowering,
_lowerings as _lowerings,
_platform_specific_lowerings as _platform_specific_lowerings,
_xla_call_lower as _xla_call_lower,
aval_to_ir_type as aval_to_ir_type,
aval_to_ir_types as aval_to_ir_types,
core_call_lowering as core_call_lowering,
dense_bool_elements as dense_bool_elements,
dense_int_elements as dense_int_elements,
dtype_to_ir_type as dtype_to_ir_type,

View File

@ -29,8 +29,7 @@ from jax._src.interpreters.xla import (
register_translation as register_translation,
sharding_to_proto as sharding_to_proto,
translations as translations,
xla_call as xla_call,
xla_call_p as xla_call_p,
xla_call_p as _deprecated_xla_call_p,
xla_destructure as xla_destructure,
xla_shape_handlers as xla_shape_handlers,
device_put as _deprecated_device_put,
@ -83,6 +82,13 @@ _deprecations = {
),
_deprecated_device_put,
),
"xla_call_p": (
(
"jax.interpreters.xla.xla_call_p is deprecated. Please use"
" jax.experimental.pjit.pjit_p instead."
),
_deprecated_xla_call_p,
),
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
@ -98,4 +104,5 @@ if typing.TYPE_CHECKING:
from jax._src.interpreters.xla import (
device_put as device_put,
)
from jax._src.interpreters.xla import xla_call_p as xla_call_p
del typing

View File

@ -205,7 +205,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
self.assertIn(foo_effect, jaxpr.jaxpr.effects)
self.assertIn(bar_effect, jaxpr.jaxpr.effects)
def test_xla_call_primitive_inherits_effects(self):
def test_jit_primitive_inherits_effects(self):
@jax.jit
def f(x):

View File

@ -100,7 +100,7 @@ class NameStackTest(jtu.JaxTestCase):
hlo_text = _get_hlo(f)(2)
self.assertIn('foo/jit(core_call)/bar', hlo_text)
def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
def test_jit_jaxpr_should_not_store_outer_name_stack(self):
@jax.named_scope('foo')
def f(x):
@jax.jit