Deprecated xla_call_p since it has been replaced with pjit.pjit_p
PiperOrigin-RevId: 518921538
This commit is contained in:
parent
f461c4ef0c
commit
a9e48af260
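For downstream code that previously matched or bound xla_call_p, the replacement is pjit's primitive. A minimal sketch of the migration (assumes jax 0.4.x, where jax.jit traces to a single pjit equation; the function f is illustrative):

import jax
from jax.experimental import pjit

def f(x):
  return x * 2.0

# A jit-decorated call now appears as one pjit equation in the jaxpr,
# where older releases produced an xla_call equation.
jaxpr = jax.make_jaxpr(jax.jit(f))(1.0)
print([eqn.primitive.name for eqn in jaxpr.eqns])  # ['pjit']
assert jaxpr.eqns[0].primitive is pjit.pjit_p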
@@ -1276,20 +1276,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
     tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
   return out_nodes, tokens_out
 
-def _xla_call_lower(ctx, *args,
-                    backend=None, name, call_jaxpr, donated_invars, inline=None,
-                    device=None, keep_unused=None):
-  del device, donated_invars, inline, keep_unused  # Ignored.
-  out_nodes, tokens = _call_lowering(
-      name, util.wrap_name(name, "jit"), call_jaxpr, backend,
-      ctx.module_context, ctx.avals_in, ctx.avals_out, ctx.tokens_in,
-      *args, dim_var_values=ctx.dim_var_values)
-  ctx.set_tokens_out(tokens)
-  return out_nodes
-
-register_lowering(xla.xla_call_p, _xla_call_lower)
-
-def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
+def core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
   out_nodes, tokens = _call_lowering(
       name, name, call_jaxpr, backend, ctx.module_context,
       ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
@@ -1297,9 +1284,9 @@ def _core_call_lowering(ctx, *args, name, backend=None, call_jaxpr):
   ctx.set_tokens_out(tokens)
   return out_nodes
 
-register_lowering(core.call_p, partial(_core_call_lowering, name="core_call"))
+register_lowering(core.call_p, partial(core_call_lowering, name="core_call"))
 register_lowering(core.closed_call_p,
-                  partial(_core_call_lowering, name="core_closed_call"))
+                  partial(core_call_lowering, name="core_closed_call"))
 
 def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
                      broadcast_dimensions) -> ir.Value:
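For context, register_lowering above is the same hook available to user-defined primitives. A hedged sketch (my_tanh_p is hypothetical, not part of this commit), using mlir.lower_fun to build the rule by tracing an ordinary function:

import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir

# Hypothetical primitive, for illustration only.
my_tanh_p = core.Primitive("my_tanh")
my_tanh_p.def_impl(lambda x: jnp.tanh(x))
my_tanh_p.def_abstract_eval(lambda x: x)  # shape/dtype unchanged

# lower_fun turns a traceable Python function into an MLIR lowering rule.
mlir.register_lowering(
    my_tanh_p, mlir.lower_fun(lambda x: jnp.tanh(x), multiple_results=False))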
@@ -902,12 +902,7 @@ class MapTrace(core.Trace):
     return MapTracer(self, outvals, out_shard_axes)
 
   def process_call(self, call_primitive, fun, tracers, params):
-    if call_primitive is not xla.xla_call_p: raise NotImplementedError
-    bind = HashableFunction(
-        lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs),
-        (call_primitive, fun))
-    fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
-    return self.process_primitive(fake_primitive, tracers, params)
+    raise NotImplementedError
 
   def process_map(self, call_primitive, fun, tracers, params):
     if params['devices'] is not None:
@@ -1998,15 +1993,14 @@ def _pmap_dce_rule(used_outputs, eqn):
 
 
 # Set param update handlers to update `donated_invars` just like xla_call_p
-pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
+pe.call_param_updaters[xla_pmap_p] = xla.xla_call_partial_eval_update_params
 pe.partial_eval_jaxpr_custom_rules[xla_pmap_p] = \
     partial(pe.call_partial_eval_custom_rule,
             'call_jaxpr', _pmap_partial_eval_custom_params_updater,
             res_aval=_pmap_partial_eval_custom_res_maker)
 pe.dce_rules[xla_pmap_p] = _pmap_dce_rule
-ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
-ad.call_transpose_param_updaters[xla_pmap_p] = \
-    ad.call_transpose_param_updaters[xla.xla_call_p]
+ad.call_param_updaters[xla_pmap_p] = xla.xla_call_jvp_update_params
+ad.call_transpose_param_updaters[xla_pmap_p] = xla.xla_call_transpose_update_params
 
 ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
 
 
@@ -28,17 +28,14 @@ from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol,
 import numpy as np
 
 from jax.config import config
-from jax.interpreters import partial_eval as pe
 
 from jax._src import core
 from jax._src import device_array
 from jax._src import dtypes
-from jax._src import pretty_printer as pp
 from jax._src import source_info_util
 from jax._src.abstract_arrays import numpy_scalar_types
 from jax._src.core import ConcreteArray, ShapedArray
-from jax._src.interpreters import ad
-from jax._src.util import (safe_zip, safe_map, partition_list)
+from jax._src.util import safe_zip, safe_map
 
 from jax._src.typing import Shape
 
@@ -157,7 +154,6 @@ xla_shape_handlers: Dict[Type[core.AbstractValue],
 xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
 
 
-
 # IR constants
 
 # TODO(mattjj): try to remove this canonicalize_dtype stuff
@@ -348,11 +344,11 @@ def jaxpr_collectives(jaxpr):
 
 ### xla_call underlying jit
 
-
+# TODO(yashkatariya): Remove after 1 month from March 23, 2023.
 xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
 xla_call = xla_call_p.bind
 
-def _xla_call_partial_eval_update_params(
+def xla_call_partial_eval_update_params(
     params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
 ) -> core.ParamDict:
   donated_invars = params['donated_invars']
@@ -366,57 +362,18 @@ def _xla_call_partial_eval_update_params(
   # Any new inputs are prepended to the left, so mark those as not donated.
   donated_invars = [False] * num_new_inputs + donated_invars
   return dict(params, donated_invars=tuple(donated_invars))
-pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
 
-def _xla_call_jvp_update_params(params, nz_tangents):
+def xla_call_jvp_update_params(params, nz_tangents):
   donated_invars = params['donated_invars']
   donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
   new_donated_invars = (*donated_invars, *donated_tangents)
   return dict(params, donated_invars=new_donated_invars)
-ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
 
-def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
+def xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
   donated_invars = params['donated_invars']
   donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
   donated_cotangents = [False for nz in nonzero_cts if nz]
   return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
-ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
-
-
-ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
-
-
-def _xla_call_partial_eval_custom_params_updater(
-    unks_in: Sequence[bool], inst_in: Sequence[bool],
-    kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
-    num_res: int, params_known: dict, params_staged: dict
-  ) -> Tuple[dict, dict]:
-  # pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
-  donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
-  new_params_known = dict(params_known, donated_invars=tuple(donated_known))
-  # added num_res new inputs to jaxpr_staged, so extend donated_invars
-  _, donated_staged_ = partition_list(inst_in, params_staged['donated_invars'])
-  donated_staged = [False] * num_res + donated_staged_
-  new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged))
-  return new_params_known, new_params_staged
-pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
-    partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
-            _xla_call_partial_eval_custom_params_updater)
-pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
-
-pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
-
-
-def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
-                 settings: core.JaxprPpSettings,
-                 ) -> pp.Doc:
-  printed_params = {k:v for k, v in eqn.params.items() if
-                    k == 'call_jaxpr' or k == 'name' or
-                    k == 'backend' and v is not None or
-                    k == 'device' and v is not None or
-                    k == 'donated_invars' and any(v)}
-  return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
-core.pp_eqn_rules[xla_call_p] = _pp_xla_call
 
 
 ### translation tables
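The donation bookkeeping in xla_call_jvp_update_params, now exported above for reuse by pmap and xmap, can be checked by hand: tangent inputs inherit their primal's donation flag, and only nonzero tangents become real inputs. Run on concrete values:

donated_invars = (True, False, True)
nz_tangents = (True, False, True)   # the middle tangent is symbolically zero
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
print(new_donated_invars)           # (True, False, True, True, True)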
@@ -846,7 +846,7 @@ core.axis_substitution_rules[xmap_p] = _xmap_axis_subst
 # NOTE: We don't have to handle spmd_{in|out}_axes here, because
 # SPMD batching always gets involved as the last transform before XLA translation
 ad.JVPTrace.process_xmap = ad.JVPTrace.process_call  # type: ignore
-ad.call_param_updaters[xmap_p] = ad.call_param_updaters[xla.xla_call_p]
+ad.call_param_updaters[xmap_p] = xla.xla_call_jvp_update_params
 
 def _xmap_transpose(params, call_jaxpr, args, cts_in, cts_in_avals, reduce_axes):
   all_args, in_tree_def = tree_flatten(((), args, cts_in))  # empty consts
 
@@ -1655,16 +1655,6 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
             jaxpr=new_jaxpr,
             num_carry=num_carry + 2,
             linear=linear[0:nr_const_and_carry] + (False, False) + linear[nr_const_and_carry:])))
-  elif eqn.primitive is xla.xla_call_p:
-    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
-    eqns.append(
-        eqn.replace(
-            invars=eqn.invars + [input_token_var, input_itoken_var],
-            outvars=eqn.outvars + [output_token_var, output_itoken_var],
-            params=dict(
-                eqn.params,
-                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
-                donated_invars=eqn.params["donated_invars"] + (False, False))))
   elif eqn.primitive is pxla.xla_pmap_p:
     # We broadcast the input token into an array of tokens
     call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
@@ -1762,12 +1752,10 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
   eqns.append(
       core.new_jaxpr_eqn(
           eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var, input_itoken_var],
-          pred1_and_token1, xla.xla_call_p,
+          pred1_and_token1, core.call_p,
           dict(
               call_jaxpr=transformed_cond_jaxpr.jaxpr,
-              name="cond_before",
-              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
-              inline=False),
+              name="cond_before"),
           transformed_cond_jaxpr.jaxpr.effects,
           eqn.source_info))
   # Make a new cond "lambda pred, carry, token, itoken: pred"
@@ -1808,22 +1796,18 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
           new_body_invars_body_constvars + new_body_invars_carry +
           [new_body_invars_token, new_body_invars_itoken],
           new_body_carry2 + [new_body_token2, new_body_itoken2],
-          xla.xla_call_p,
+          core.call_p,
           dict(
               call_jaxpr=transformed_body_jaxpr.jaxpr,
-              name="body",
-              donated_invars=(False,) * len(transformed_body_jaxpr.in_avals),
-              inline=False),
+              name="body"),
           transformed_body_jaxpr.effects,
           eqn.source_info),
       core.new_jaxpr_eqn(
           new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2, new_body_itoken2],
-          [new_body_pred2, new_body_token3, new_body_itoken3], xla.xla_call_p,
+          [new_body_pred2, new_body_token3, new_body_itoken3], core.call_p,
           dict(
               call_jaxpr=transformed_cond_jaxpr.jaxpr,
-              name="cond_body",
-              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals),
-              inline=False),
+              name="cond_body"),
           transformed_cond_jaxpr.effects,
           eqn.source_info)
   ]
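As the two hunks above show, rewriting these equations to core.call_p also drops the donated_invars and inline params that only xla_call_p carried; core.call_p needs just call_jaxpr and name. A hedged sketch of binding it directly (assumes jax 0.4.x; g is illustrative):

import jax
from jax import core, linear_util as lu

def g(x):
  return x + 1.0

closed = jax.make_jaxpr(g)(0.0)
fun = lu.wrap_init(core.jaxpr_as_fun(closed))
# Call primitives bind a wrapped function followed by its arguments.
out, = core.call_p.bind(fun, 2.0)
print(out)  # 3.0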
@@ -1479,27 +1479,8 @@ class TensorFlowTrace(core.Trace):
     avals: Sequence[core.ShapedArray] = tuple(t.aval for t in tracers)
     interpreted_fun = _interpret_subtrace(fun, self.main, avals)
     extra_name_stack = None
-    if call_primitive == xla.xla_call_p:
-      extra_name_stack = util.wrap_name(params["name"], "jit")
     with _extended_name_stack(extra_name_stack):
       with core.new_sublevel():
-        if call_primitive == xla.xla_call_p:
-          if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
-            # Make a nested tf.function(jit_compile=True)
-            store_tf_res_avals: Sequence[core.ShapedArray] = []
-            def f_tf(*tf_args):
-              nonlocal store_tf_res_avals
-              tf_res_out: Sequence[Tuple[TfVal, core.ShapedArray]] = \
-                _call_wrapped_with_new_constant_cache(interpreted_fun, tf_args,
-                                                      fresh_constant_cache=False)
-              tf_res_vals, tf_res_avals = util.unzip2(tf_res_out)
-              store_tf_res_avals = tf_res_avals
-              return tf_res_vals
-            tf_vals_out = tf.function(f_tf, autograph=False, jit_compile=True)(*vals)
-            vals_out = zip(tf_vals_out, store_tf_res_avals)
-          else:
-            vals_out = interpreted_fun.call_wrapped(*vals)
-        else:
-          vals_out = interpreted_fun.call_wrapped(*vals)
+        vals_out = interpreted_fun.call_wrapped(*vals)
     return [TensorFlowTracer(self, v, a) for v, a in vals_out]
 
@@ -1572,7 +1553,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
 
 
 # Call primitives are inlined
-for unexpected in [core.call_p, xla.xla_call_p, maps.xmap_p]:
+for unexpected in [core.call_p, maps.xmap_p]:
   tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
 
 # Primitives that are not yet implemented must be explicitly declared here.
 
@@ -267,13 +267,6 @@ register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series
 
 call_param_updaters = {}
 
-def _xla_call_param_updater(params, num_inputs):
-  donated_invars = params['donated_invars']
-  if any(donated_invars):
-    raise NotImplementedError("donated_invars not supported with jet")
-  return dict(params, donated_invars=(False,) * num_inputs)
-call_param_updaters[xla.xla_call_p] = _xla_call_param_updater
-
 
 ### rule definitions
 
@@ -581,25 +581,8 @@ class ShardMapTrace(core.Trace):
     return ShardMapTracer(self, out_rep, out_vals)
 
   def process_call(self, call_primitive, fun, tracers, params):
-    if call_primitive is not xla.xla_call_p: raise NotImplementedError
-    fun, jaxpr = _grab_jaxpr_shadily(fun)  # TODO remove with initial-style jit
-    bind = partial(call_primitive.bind, fun)  # TODO caching (compat w/ jaxpr())
-    fake_primitive = pxla.FakePrimitive(multiple_results=True, bind=bind)
-    _rep_rules[fake_primitive] = lambda *_, **__: set()  # pytype: disable=container-type-mismatch
-    out_tracers_ = self.process_primitive(fake_primitive, tracers, params)
-    out_vals = [t.val for t in out_tracers_]
-    if self.check:
-      out_rep = _output_rep(self.mesh, jaxpr(), [t.rep for t in tracers])
-    else:
-      out_rep = [set()] * len(out_vals)
-    return map(partial(ShardMapTracer, self), out_rep, out_vals)
+    raise NotImplementedError
 
-@lu.transformation_with_aux
-def _grab_jaxpr_shadily(*args):
-  out = yield args, {}
-  main = core.thread_local_state.trace_state.trace_stack.dynamic  # forgive me
-  jaxpr, _ = main.jaxpr_stack[-1].to_jaxpr(out)
-  yield out, jaxpr
-
 class ShardMapTracer(core.Tracer):
   rep: Set[AxisName]
@@ -711,10 +694,6 @@ def _axis_index_rule(mesh, *, axis_name):
 def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
   return _output_rep(mesh, jaxpr.jaxpr, in_rep)
 
-@register_rule(xla.xla_call_p)
-def _jit_rule(mesh, *in_rep, jaxpr, **kwargs):
-  return _output_rep(mesh, jaxpr, in_rep)
-
 @register_rule(debugging.debug_callback_p)
 def _debug_callback_rule(mesh, *in_rep, **_):
   return []
@@ -420,13 +420,6 @@ def eval_sparse(
     if prim not in sparse_rules_bcoo:
       _raise_unimplemented_primitive(prim)
     out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params)
   else:
-    if prim is xla.xla_call_p:
-      # TODO(vanderplas,frostig): workaround for binding call primitives
-      # within a jaxpr interpreter
-      params = eqn.params.copy()
-      fun = lu.wrap_init(core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
-      out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals), **params)
-    else:
-      out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
+    out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
   out_bufs = out_bufs if prim.multiple_results else [out_bufs]
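eval_sparse is a jaxpr interpreter, and this hunk deletes its xla_call_p special case: every remaining equation can be re-bound uniformly through eqn.primitive.bind. A minimal sketch of that interpreter pattern, adapted from JAX's custom-interpreter docs (eval_jaxpr_naive is an illustrative name):

from jax import core
from jax._src.util import safe_map

def eval_jaxpr_naive(jaxpr, consts, *args):
  env = {}
  def read(v):
    return v.val if isinstance(v, core.Literal) else env[v]
  def write(v, val):
    env[v] = val
  safe_map(write, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    # No call-primitive special case is needed once xla_call_p is gone.
    out = eqn.primitive.bind(*safe_map(read, eqn.invars), **eqn.params)
    out = out if eqn.primitive.multiple_results else [out]
    safe_map(write, eqn.outvars, out)
  return safe_map(read, jaxpr.outvars)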
@@ -759,18 +752,6 @@ def _while_sparse(spenv, *spvalues, cond_jaxpr, cond_nconsts, body_jaxpr, body_n
 
 sparse_rules_bcoo[lax.while_p] = _while_sparse
 
-def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
-  if any(donated_invars):
-    raise NotImplementedError("sparse xla_call with donated_invars")
-  sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, pe.ClosedJaxpr(call_jaxpr, ()), *spvalues)
-  fun = lu.wrap_init(core.jaxpr_as_fun(sp_call_jaxpr))
-  args_flat, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues))
-  donated_invars = tuple(False for arg in args_flat)
-  out_flat = xla.xla_call_p.bind(fun, *args_flat, donated_invars=donated_invars, **params)
-  return arrays_to_spvalues(spenv, tree_unflatten(out_tree, out_flat))
-
-sparse_rules_bcoo[xla.xla_call_p] = _xla_call_sparse
-
 
 def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
                  resource_env, donated_invars, name, keep_unused, inline):
@@ -34,9 +34,9 @@ from jax._src.interpreters.mlir import (
   _call_lowering as _call_lowering,
   _lowerings as _lowerings,
   _platform_specific_lowerings as _platform_specific_lowerings,
-  _xla_call_lower as _xla_call_lower,
   aval_to_ir_type as aval_to_ir_type,
   aval_to_ir_types as aval_to_ir_types,
+  core_call_lowering as core_call_lowering,
   dense_bool_elements as dense_bool_elements,
   dense_int_elements as dense_int_elements,
   dtype_to_ir_type as dtype_to_ir_type,
@@ -29,8 +29,7 @@ from jax._src.interpreters.xla import (
   register_translation as register_translation,
   sharding_to_proto as sharding_to_proto,
   translations as translations,
-  xla_call as xla_call,
-  xla_call_p as xla_call_p,
+  xla_call_p as _deprecated_xla_call_p,
   xla_destructure as xla_destructure,
   xla_shape_handlers as xla_shape_handlers,
   device_put as _deprecated_device_put,
@@ -83,6 +82,13 @@ _deprecations = {
     ),
     _deprecated_device_put,
   ),
+  "xla_call_p": (
+    (
+      "jax.interpreters.xla.xla_call_p is deprecated. Please use"
+      " jax.experimental.pjit.pjit_p instead."
+    ),
+    _deprecated_xla_call_p,
+  ),
 }
 
 from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
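The deprecated attribute is served through a PEP 562 module-level __getattr__. A minimal sketch of the pattern deprecation_getattr implements (illustrative; the real helper lives in jax._src.deprecations):

import warnings

def deprecation_getattr(module, deprecations):
  def getattr_(name):
    if name in deprecations:
      message, attr = deprecations[name]
      warnings.warn(message, DeprecationWarning, stacklevel=2)
      return attr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
  return getattr_

# Wired up at module scope:
# __getattr__ = _deprecation_getattr(__name__, _deprecations)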
@@ -98,4 +104,5 @@ if typing.TYPE_CHECKING:
   from jax._src.interpreters.xla import (
     device_put as device_put,
   )
+  from jax._src.interpreters.xla import xla_call_p as xla_call_p
 del typing
@@ -205,7 +205,7 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
     self.assertIn(foo_effect, jaxpr.jaxpr.effects)
     self.assertIn(bar_effect, jaxpr.jaxpr.effects)
 
-  def test_xla_call_primitive_inherits_effects(self):
+  def test_jit_primitive_inherits_effects(self):
 
     @jax.jit
     def f(x):
@@ -100,7 +100,7 @@ class NameStackTest(jtu.JaxTestCase):
     hlo_text = _get_hlo(f)(2)
     self.assertIn('foo/jit(core_call)/bar', hlo_text)
 
-  def test_xla_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
+  def test_jit_jaxpr_should_not_store_outer_name_stack(self):
     @jax.named_scope('foo')
     def f(x):
       @jax.jit