Deprecated xla_call_p since it has been replaced with pjit.pjit_p

PiperOrigin-RevId: 518921538
Commit: a9e48af260
Parent: f461c4ef0c
@@ -1276,20 +1276,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
   tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
   return out_nodes, tokens_out
 
-def _xla_call_lower(ctx, *args,
-                    backend=None, name, call_jaxpr, donated_invars, inline=None,
-                    device=None, keep_unused=None):
-  del device, donated_invars, inline, keep_unused  # Ignored.
-  out_nodes, tokens = _call_lowering(
-      name, util.wrap_name(name, "jit"), call_jaxpr, backend,
-      ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in,
-      *args, dim_var_values=ctx.dim_var_values)
-  ctx.set_tokens_out(tokens)
-  return out_nodes
-
-register_lowering(xla.xla_call_p, _xla_call_lower)
-
-def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
+def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
   out_nodes, tokens = _call_lowering(
       name, name, call_jaxpr, backend, ctx.module_context,
       ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,

@@ -1297,9 +1284,9 @@ def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
   ctx.set_tokens_out(tokens)
   return out_nodes
 
-register_lowering(core.call_p, partial(_core_call_lowering, name="core_call"))
+register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
 register_lowering(core.closed_call_p,
-                  partial(_core_call_lowering, name="core_closed_call"))
+                  partial(core_call_lowering, name="core_closed_call"))
 
 def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
                      broadcast_dimensions) -> ir.Value:
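Note: the two hunks above (in the MLIR lowering module) delete the xla_call_p lowering rule and rename _core_call_lowering to the public core_call_lowering, which is registered for both core.call_p and core.closed_call_p with functools.partial pinning the name parameter. A minimal, runnable sketch of that registration pattern follows; the registry and context here are illustrative stand-ins, not JAX's real internals:

from functools import partial

_lowerings = {}  # hypothetical registry: primitive name -> lowering rule

def register_lowering(prim, rule):
  """Record the rule that lowers `prim` (a stand-in for JAX's registry)."""
  _lowerings[prim] = rule

def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr=None):
  # Stand-in body: a real rule would lower call_jaxpr and return IR values.
  return f"lowered {name}({len(args)} args) for backend={backend}"

# One generic rule serves several call primitives; partial() pins the label.
register_lowering("core_call", partial(core_call_lowering, name="core_call"))
register_lowering("core_closed_call",
                  partial(core_call_lowering, name="core_closed_call"))

print(_lowerings["core_call"](None, 1, 2))  # lowered core_call(2 args) ...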
@@ -902,12 +902,7 @@ class MapTrace(core.Trace):
     return MapTracer(self, outvals, out_shard_axes)
 
   def process_call(self, call_primitive, fun, tracers, params):
-    if call_primitive is not xla.xla_call_p: raise NotImplementedError
-    bind = HashableFunction(
-        lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs),
-        (call_primitive, fun))
-    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
-    return self.process_primitive(fake_primitive, tracers, params)
+    raise NotImplementedError
 
   def process_map(self, call_primitive, fun, tracers, params):
     if params['devices'] is not None:
@@ -1998,15 +1993,14 @@ def _pmap_dce_rule(used_outputs, eqn):
 
 
 # Set param update handlers to update `donated_invars` just like xla_call_p
-pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
+pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
 pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
     partial(pe.call_partial_eval_custom_rule,
            'call_jaxpr', _pmap_partial_eval_custom_params_updater,
            res_aval=_pmap_partial_eval_custom_res_maker)
 pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
-ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
-ad.call_transpose_param_updaters[xla_pmap_p] = \
-    ad.call_transpose_param_updaters[xla.xla_call_p]
+ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
+ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params
 
 ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
 
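Note: the pmap registrations above stop aliasing through xla_call_p's registry entries and instead reference the now-public update functions directly, so they keep working once the xla_call_p entries are deleted further down. A runnable sketch of the difference, using a toy registry and the partial-eval updater's logic from this diff (the kept_inputs filtering is my assumption about the part of that function elided between hunks):

call_param_updaters = {}

def xla_call_partial_eval_update_params(params, kept_inputs, num_new_inputs):
  donated = params['donated_invars']
  donated = [d for d, kept in zip(donated, kept_inputs) if kept]  # assumed
  # Any new inputs are prepended to the left, so mark those as not donated.
  donated = [False] * num_new_inputs + donated
  return dict(params, donated_invars=tuple(donated))

# Old style: alias another primitive's registry entry -- this raises KeyError
# once the 'xla_call' entry is removed, which is what this commit does:
#   call_param_updaters['xla_pmap'] = call_param_updaters['xla_call']
# New style: reference the function itself, independent of the registry:
call_param_updaters['xla_pmap'] = xla_call_partial_eval_update_params

params = dict(donated_invars=(True, False, True))
print(call_param_updaters['xla_pmap'](params, [True, True, False], 1))
# -> {'donated_invars': (False, True, False)}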
@@ -28,17 +28,14 @@ from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol,
 import numpy as np
 
 from jax.config import config
-from jax.interpreters import partial_eval as pe
 
 from jax._src import core
 from jax._src import device_array
 from jax._src import dtypes
-from jax._src import pretty_printer as pp
 from jax._src import source_info_util
 from jax._src.abstract_arrays import numpy_scalar_types
 from jax._src.core import ConcreteArray, ShapedArray
-from jax._src.interpreters import ad
-from jax._src.util import (safe_zip, safe_map, partition_list)
+from jax._src.util import safe_zip, safe_map
 
 from jax._src.typing import Shape
 

@@ -157,7 +154,6 @@ xla_shape_handlers: Dict[Type[core.AbstractValue],
 xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
 
 
-
 # IR constants
 
 # TODO(mattjj): try to remove this canonicalize_dtype stuff
@@ -348,11 +344,11 @@ def jaxpr_collectives(jaxpr):
 
 ### xla_call underlying jit
 
+# TODO(yashkatariya): Remove after 1 month from March 23, 2023.
 xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
-xla_call = xla_call_p.bind
 
 
-def _xla_call_partial_eval_update_params(
+def xla_call_partial_eval_update_params(
     params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
   ) -> core.ParamDict:
   donated_invars = params['donated_invars']

@@ -366,57 +362,18 @@ def _xla_call_partial_eval_update_params(
   # Any new inputs are prepended to the left, so mark those as not donated.
   donated_invars = [False] * num_new_inputs + donated_invars
   return dict(params, donated_invars=tuple(donated_invars))
-pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
 
-def _xla_call_jvp_update_params(params, nz_tangents):
+def xla_call_jvp_update_params(params, nz_tangents):
   donated_invars = params['donated_invars']
   donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
   new_donated_invars = (*donated_invars, *donated_tangents)
   return dict(params, donated_invars=new_donated_invars)
-ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
 
-def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
+def xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
   donated_invars = params['donated_invars']
   donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
   donated_cotangents = [False for nz in nonzero_cts if nz]
   return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
-ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
 
 
-ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
-
-
-def _xla_call_partial_eval_custom_params_updater(
-    unks_in: Sequence[bool], inst_in: Sequence[bool],
-    kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
-    num_res: int, params_known: dict, params_staged: dict
-  ) -> Tuple[dict, dict]:
-  # pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
-  donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
-  new_params_known = dict(params_known, donated_invars=tuple(donated_known))
-  # added num_res new inputs to jaxpr_staged, so extend donated_invars
-  _, donated_staged_ = partition_list(inst_in, params_staged['donated_invars'])
-  donated_staged = [False] * num_res + donated_staged_
-  new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged))
-  return new_params_known, new_params_staged
-pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
-    partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
-            _xla_call_partial_eval_custom_params_updater)
-pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
-
-pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
-
-
-def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
-                 settings: core.JaxprPpSettings,
-                 ) -> pp.Doc:
-  printed_params = {k:v for k, v in eqn.params.items() if
-                    k == 'call_jaxpr' or k == 'name' or
-                    k == 'backend' and v is not None or
-                    k == 'device' and v is not None or
-                    k == 'donated_invars' and any(v)}
-  return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
-core.pp_eqn_rules[xla_call_p] = _pp_xla_call
-
-
 ### translation tables
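Note: this is the heart of the change in the xla module. xla_call_p itself stays (with a removal TODO) so the deprecated alias further below can still point at it, its bind alias and every registration (partial-eval, DCE, padding, transpose, pretty-printing) are deleted, and the three param-update helpers lose their leading underscore so pmap and xmap can reference them directly. For reference, here is the retained JVP updater from this diff applied to concrete values:

def xla_call_jvp_update_params(params, nz_tangents):
  donated_invars = params['donated_invars']
  donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
  new_donated_invars = (*donated_invars, *donated_tangents)
  return dict(params, donated_invars=new_donated_invars)

# Two primal inputs, the first donated; only the first has a nonzero tangent.
# The JVP call sees primals followed by the nonzero tangents, so the donation
# mask is extended with the donation flags of the nonzero-tangent inputs.
params = dict(donated_invars=(True, False))
print(xla_call_jvp_update_params(params, nz_tangents=[True, False]))
# -> {'donated_invars': (True, False, True)}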
@@ -846,7 +846,7 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
 # NOTE: We don't have to handle spmd_{in|out}_axes here, because
 # SPMD batching always gets involved as the last transform before XLA translation
 ad.JVPTrace.process_xmap = ad.JVPTrace.process_call  # type: ignore
-ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
+ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params
 
 def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
   all_args, in_tree_def = tree_flatten(((), args, cts_in))  # empty consts

@@ -1655,16 +1655,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 jaxpr=new_jaxpr,
                 num_carry=num_carry + 2,
                 linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:])))
-  elif eqn.primitive is xla.xla_call_p:
-    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
-    eqns.append(
-        eqn.replace(
-            invars=eqn.invars + [input_token_var, input_itoken_var],
-            outvars=eqn.outvars + [output_token_var, output_itoken_var],
-            params=dict(
-                eqn.params,
-                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
-                donated_invars=eqn.params["donated_invars"] + (False, False))))
   elif eqn.primitive is pxla.xla_pmap_p:
     # We broadcast the input token into an array of tokens
     call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])

@@ -1762,12 +1752,10 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
     eqns.append(
         core.new_jaxpr_eqn(
             eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var, input_itoken_var],
-            pred1_and_token1, xla.xla_call_p,
+            pred1_and_token1, core.call_p,
             dict(
                 call_jaxpr=transformed_cond_jaxpr.jaxpr,
-                name="cond_before",
-                donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
-                inline=False),
+                name="cond_before"),
             transformed_cond_jaxpr.jaxpr.effects,
             eqn.source_info))
     # Make a new cond "lambda pred, carry, token, itoken: pred"

@@ -1808,22 +1796,18 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
           new_body_invars_body_constvars + new_body_invars_carry +
           [new_body_invars_token, new_body_invars_itoken],
           new_body_carry2 + [new_body_token2, new_body_itoken2],
-          xla.xla_call_p,
+          core.call_p,
          dict(
              call_jaxpr=transformed_body_jaxpr.jaxpr,
-              name="body",
-              donated_invars=(False,) * len(transformed_body_jaxpr.in_avals),
-              inline=False),
+              name="body"),
          transformed_body_jaxpr.effects,
          eqn.source_info),
      core.new_jaxpr_eqn(
          new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
-          [new_body_pred2, new_body_token3, new_body_itoken3], xla.xla_call_p,
+          [new_body_pred2, new_body_token3, new_body_itoken3], core.call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
-              name="cond_body",
-              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
-              inline=False),
+              name="cond_body"),
          transformed_cond_jaxpr.effects,
          eqn.source_info)
   ]
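Note: in the outfeed rewriter above, the three synthesized call equations switch from xla.xla_call_p to core.call_p, and their params shrink to just call_jaxpr and name; the donated_invars/inline entries existed only because xla_call_p required them. A plain-Python illustration of the parameter sets involved (the helper is hypothetical, not JAX internals):

# Hypothetical helper mirroring the shape of the params dicts in this diff.
def call_eqn_params(primitive, call_jaxpr, name, num_inputs):
  if primitive == "xla_call":  # the old, now-deprecated primitive
    return dict(call_jaxpr=call_jaxpr, name=name,
                donated_invars=(False,) * num_inputs, inline=False)
  return dict(call_jaxpr=call_jaxpr, name=name)  # core.call_p needs only these

old = call_eqn_params("xla_call", "<cond_jaxpr>", "cond_before", num_inputs=3)
new = call_eqn_params("core_call", "<cond_jaxpr>", "cond_before", num_inputs=3)
print(sorted(old))  # ['call_jaxpr', 'donated_invars', 'inline', 'name']
print(sorted(new))  # ['call_jaxpr', 'name']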
@@ -1479,28 +1479,9 @@ class TensorFlowTrace(core.Trace):
     avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
     interpreted_fun = _interpret_subtrace(fun, self.main, avals)
     extra_name_stack = None
-    if call_primitive == xla.xla_call_p:
-      extra_name_stack = util.wrap_name(params["name"], "jit")
     with _extended_name_stack(extra_name_stack):
       with core.new_sublevel():
-        if call_primitive == xla.xla_call_p:
-          if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
-            # Make a nested tf.function(jit_compile=True)
-            store_tf_res_avals: Sequence[core.ShapedArray] = []
-            def f_tf(*tf_args):
-              nonlocal store_tf_res_avals
-              tf_res_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
-                  _call_wrapped_with_new_constant_cache(interpreted_fun, tf_args,
-                                                        fresh_constant_cache=False)
-              tf_res_vals, tf_res_avals = util.unzip2(tf_res_out)
-              store_tf_res_avals = tf_res_avals
-              return tf_res_vals
-            tf_vals_out = tf.function(f_tf, autograph=False, jit_compile=True)(*vals)
-            vals_out = zip(tf_vals_out, store_tf_res_avals)
-          else:
-            vals_out = interpreted_fun.call_wrapped(*vals)
-        else:
-          vals_out = interpreted_fun.call_wrapped(*vals)
+        vals_out = interpreted_fun.call_wrapped(*vals)
     return [TensorFlowTracer(self, v, a) for v, a in vals_out]
 
   def post_process_call(self, call_primitive: core.Primitive,

@@ -1572,7 +1553,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
 
 
 # Call primitives are inlined
-for unexpected in [core.call_p, xla.xla_call_p, maps.xmap_p]:
+for unexpected in [core.call_p, maps.xmap_p]:
   tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
 
 # Primitives that are not yet implemented must be explicitly declared here.
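Note: with the xla_call_p branches gone, jax2tf's process_call handles every call primitive the same way: it just invokes the wrapped function. A small sketch of the call_wrapped mechanism it relies on (this uses jax.linear_util, which was still the public import path at the time of this commit):

from jax import linear_util as lu

def f(x, y):
  return x + y

# wrap_init produces a WrappedFun; call_wrapped applies it to concrete args,
# which is all process_call now does for any call primitive it encounters.
wrapped = lu.wrap_init(f)
print(wrapped.call_wrapped(1, 2))  # 3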
@@ -267,13 +267,6 @@ register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series
 
 call_param_updaters = {}
 
-def _xla_call_param_updater(params, num_inputs):
-  donated_invars = params['donated_invars']
-  if any(donated_invars):
-    raise NotImplementedError("donated_invars not supported with jet")
-  return dict(params, donated_invars=(False,) * num_inputs)
-call_param_updaters[xla.xla_call_p] = _xla_call_param_updater
-
 
 ### rule definitions
 
@@ -581,25 +581,8 @@ class ShardMapTrace(core.Trace):
     return ShardMapTracer(self, out_rep, out_vals)
 
   def process_call(self, call_primitive, fun, tracers, params):
-    if call_primitive is not xla.xla_call_p: raise NotImplementedError
-    fun, jaxpr = _grab_jaxpr_shadily(fun)  # TODO remove with initial-style jit
-    bind = partial(call_primitive.bind, fun)  # TODO caching (compat w/ jaxpr())
-    fake_primitive = pxla.FakePrimitive(multiple_results=True, bind=bind)
-    _rep_rules[fake_primitive] = lambda *_, **__: set()  # pytype: disable=container-type-mismatch
-    out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
-    out_vals = [t.val for t in out_tracers_]
-    if self.check:
-      out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
-    else:
-      out_rep = [set()] * len(out_vals)
-    return map(partial(ShardMapTracer, self), out_rep, out_vals)
+    raise NotImplementedError
 
-@lu.transformation_with_aux
-def _grab_jaxpr_shadily(*args):
-  out = yield args, {}
-  main = core.thread_local_state.trace_state.trace_stack.dynamic  # forgive me
-  jaxpr, _ = main.jaxpr_stack[-1].to_jaxpr(out)
-  yield out, jaxpr
 
 class ShardMapTracer(core.Tracer):
   rep: Set[AxisName]

@@ -711,10 +694,6 @@ def _axis_index_rule(mesh, *, axis_name):
 def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
   return _output_rep(mesh, jaxpr.jaxpr, in_rep)
 
-@register_rule(xla.xla_call_p)
-def _jit_rule(mesh, *in_rep, jaxpr, **kwargs):
-  return _output_rep(mesh, jaxpr, in_rep)
-
 @register_rule(debugging.debug_callback_p)
 def _debug_callback_rule(mesh, *in_rep, **_):
   return []
@@ -421,14 +421,7 @@ def eval_sparse(
         _raise_unimplemented_primitive(prim)
       out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params)
     else:
-      if prim is xla.xla_call_p:
-        # TODO(vanderplas,frostig): workaround for binding call primitives
-        # within a jaxpr interpreter
-        params = eqn.params.copy()
-        fun = lu.wrap_init(core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
-        out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals), **params)
-      else:
-        out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
+      out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
       out_bufs = out_bufs if prim.multiple_results else [out_bufs]
       out = []
       for buf, outvar in safe_zip(out_bufs, eqn.outvars):

@@ -759,18 +752,6 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n
 
 sparse_rules_bcoo[lax.while_p] = _while_sparse
 
-def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
-  if any(donated_invars):
-    raise NotImplementedError("sparse xla_call with donated_invars")
-  sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, pe.ClosedJaxpr(call_jaxpr, ()), *spvalues)
-  fun = lu.wrap_init(core.jaxpr_as_fun(sp_call_jaxpr))
-  args_flat, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues))
-  donated_invars = tuple(False for arg in args_flat)
-  out_flat = xla.xla_call_p.bind(fun, *args_flat, donated_invars=donated_invars, **params)
-  return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))
-
-sparse_rules_bcoo[xla.xla_call_p] = _xla_call_sparse
-
 
 def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
                  resource_env, donated_invars, name, keep_unused, inline):
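Note: both deleted sparse rules leaned on the same trick, which survives elsewhere in JAX: core.jaxpr_as_fun turns a ClosedJaxpr back into a callable, so a call primitive can be re-bound with a fresh traceable function. A hedged, self-contained sketch of just that round trip:

import jax
from jax import core

def f(x):
  return x * 2 + 1

closed_jaxpr = jax.make_jaxpr(f)(3.0)  # stage f out to a ClosedJaxpr
g = core.jaxpr_as_fun(closed_jaxpr)    # ...and turn it back into a callable
print(g(3.0))                          # [7.0] -- outputs come back as a list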
@@ -34,9 +34,9 @@ from jax._src.interpreters.mlir import (
   _call_lowering as _call_lowering,
   _lowerings as _lowerings,
   _platform_specific_lowerings as _platform_specific_lowerings,
-  _xla_call_lower as _xla_call_lower,
   aval_to_ir_type as aval_to_ir_type,
   aval_to_ir_types as aval_to_ir_types,
+  core_call_lowering as core_call_lowering,
   dense_bool_elements as dense_bool_elements,
   dense_int_elements as dense_int_elements,
   dtype_to_ir_type as dtype_to_ir_type,
@@ -29,8 +29,7 @@ from jax._src.interpreters.xla import (
   register_translation as register_translation,
   sharding_to_proto as sharding_to_proto,
   translations as translations,
-  xla_call as xla_call,
-  xla_call_p as xla_call_p,
+  xla_call_p as _deprecated_xla_call_p,
   xla_destructure as xla_destructure,
   xla_shape_handlers as xla_shape_handlers,
   device_put as _deprecated_device_put,

@@ -83,6 +82,13 @@ _deprecations = {
     ),
     _deprecated_device_put,
   ),
+  "xla_call_p": (
+    (
+      "jax.interpreters.xla.xla_call_p is deprecated. Please use"
+      " jax.experimental.pjit.pjit_p instead."
+    ),
+    _deprecated_xla_call_p,
+  ),
 }
 
 from jax._src.deprecations import deprecation_getattr as _deprecation_getattr

@@ -98,4 +104,5 @@ if typing.TYPE_CHECKING:
   from jax._src.interpreters.xla import (
     device_put as device_put,
   )
+  from jax._src.interpreters.xla import xla_call_p as xla_call_p
 del typing
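Note: the compatibility shim keeps xla_call_p importable but routes it through the _deprecations table and deprecation_getattr, so plain attribute access warns first; the TYPE_CHECKING re-export keeps type checkers happy without triggering the warning. A self-contained sketch of that module-level pattern (PEP 562 __getattr__; a hypothetical module, simplified from what deprecation_getattr builds):

import warnings

_deprecated_xla_call_p = object()  # stand-in for the real primitive

_deprecations = {
    # name -> (warning message, replacement object)
    "xla_call_p": (
        "jax.interpreters.xla.xla_call_p is deprecated. Please use"
        " jax.experimental.pjit.pjit_p instead.",
        _deprecated_xla_call_p,
    ),
}

def __getattr__(name):
  # PEP 562: called only when normal module attribute lookup fails.
  if name in _deprecations:
    message, obj = _deprecations[name]
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return obj
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")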
@@ -205,7 +205,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
     self.assertIn(foo_effect, jaxpr.jaxpr.effects)
     self.assertIn(bar_effect, jaxpr.jaxpr.effects)
 
-  def test_xla_call_primitive_inherits_effects(self):
+  def test_jit_primitive_inherits_effects(self):
 
     @jax.jit
     def f(x):
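Note: the renamed test keeps checking the same property under the new jit path: effects raised inside a jit-ted function must propagate to the enclosing jaxpr. A hedged sketch of that property using a built-in effectful primitive instead of the test's custom effects:

import jax

@jax.jit
def f(x):
  jax.debug.print("x = {}", x)  # debug_callback is an effectful primitive
  return x + 1

jaxpr = jax.make_jaxpr(f)(1.0)
# The jit call's inner effects are inherited by the outer jaxpr.
print(jaxpr.effects)  # non-empty: contains the debug-print effect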
@@ -100,7 +100,7 @@ class NameStackTest(jtu.JaxTestCase):
     hlo_text = _get_hlo(f)(2)
     self.assertIn('foo/jit(core_call)/bar', hlo_text)
 
-  def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
+  def test_jit_jaxpr_should_not_store_outer_name_stack(self):
     @jax.named_scope('foo')
     def f(x):
       @jax.jit
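Note: likewise here, the test keeps asserting that a jit-ted inner function does not bake the caller's name stack into its own jaxpr; only the title changes. A sketch of the kind of program it lowers (the test's _get_hlo helper is its own; jax.jit(...).lower(...).as_text() is a rough stand-in):

import jax

@jax.named_scope('foo')
def f(x):
  @jax.jit
  def g(y):
    with jax.named_scope('bar'):
      return y + 1
  return g(x)

# Op names in the lowered text carry the nested scopes, e.g. foo/.../bar.
print(jax.jit(f).lower(2).as_text())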