Merge pull request #6275 from google:omnistaging-forever

PiperOrigin-RevId: 365681256
This commit is contained in:
jax authors 2021-03-29 15:43:09 -07:00
commit 6d0b8327c7
39 changed files with 174 additions and 1516 deletions

View File

@ -322,6 +322,7 @@ class TracerArrayConversionError(JAXTypeError):
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
"""
def __init__(self, tracer: "core.Tracer"):
# TODO(mattjj, jakevdp): use tracer._origin_msg() here
super().__init__(
"The numpy.ndarray conversion method __array__() was called on "
f"the JAX Tracer object {tracer}")

View File

@ -45,7 +45,7 @@ from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax._src.traceback_util import api_boundary
from jax._src.util import (partial, unzip2, unzip3, unzip4, safe_map, safe_zip,
from jax._src.util import (partial, unzip2, unzip3, safe_map, safe_zip,
split_list, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
treedef_children, treedef_tuple, tree_multimap,
@ -481,17 +481,12 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)
if config.omnistaging_enabled:
partial_eval_jaxpr = pe.partial_eval_jaxpr
else:
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
# Fixpoint computation of unknown carry. Each iteration promotes
# at least one carry to unknown. We need one last iteration to prepare the jaxpr.
carry_uk = carry_init_uk
for _ in range(1 + len(carry_uk)):
body_jaxpr_known, _, carry_out_uk = partial_eval_jaxpr( # type: ignore
body_jaxpr_known, _, carry_out_uk = pe.partial_eval_jaxpr( # type: ignore
body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
if carry_out_uk == carry_uk:
break
@ -500,7 +495,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
else:
assert False, "Fixpoint not reached"
cond_jaxpr_known, _, cond_uk = partial_eval_jaxpr( # type: ignore
cond_jaxpr_known, _, cond_uk = pe.partial_eval_jaxpr( # type: ignore
cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)
if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
@ -850,11 +845,6 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
unknowns = [t.pval[0] is not None for t in tracers]
index_uk, *ops_uk = unknowns
if config.omnistaging_enabled:
partial_eval_jaxpr = pe.partial_eval_jaxpr
else:
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
if index_uk:
# When the branch index is unknown, we stage out the whole cond.
params = dict(branches=branches, linear=linear)
@ -862,13 +852,13 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
branches_out_uks = []
for branch_jaxpr in branches:
_, _, out_uks = partial_eval_jaxpr(branch_jaxpr, ops_uk, instantiate=False)
_, _, out_uks = pe.partial_eval_jaxpr(branch_jaxpr, ops_uk, instantiate=False)
branches_out_uks.append(out_uks)
out_uks = [any(uks) for uks in zip(*branches_out_uks)]
branches_1, branches_2, branch_res_avals = [], [], []
for branch_jaxpr in branches:
branch_jaxpr_1, branch_jaxpr_2, _ = partial_eval_jaxpr(
branch_jaxpr_1, branch_jaxpr_2, _ = pe.partial_eval_jaxpr(
branch_jaxpr, ops_uk, instantiate=out_uks)
branch_num_res = len(branch_jaxpr_1.out_avals) - len(out_uks)
@ -1552,22 +1542,11 @@ def _prune_zeros(ts):
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: # type: ignore
params = dict(reverse=reverse, length=length, num_consts=num_consts,
num_carry=num_carry, jaxpr=jaxpr, linear=linear,
unroll=unroll)
return trace.default_process_primitive(scan_p, tracers, params)
num_ys = len(jaxpr.out_avals) - num_carry
unknowns = [t.pval[0] is not None for t in tracers]
const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])
if config.omnistaging_enabled:
partial_eval_jaxpr = pe.partial_eval_jaxpr
else:
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
# Fixpoint computation of which carry are unknown (not a constant): either
# unknown from init, or the carry out is unknown. Each iteration promotes
# at least one carry to unknown. We need at most len(carry) iterations,
@ -1576,7 +1555,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
carry_uk = init_uk
for _ in range(1 + len(carry_uk)):
unknowns = const_uk + carry_uk + xs_uk
jaxpr_1, jaxpr_2, out_uk = partial_eval_jaxpr(
jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
carry_uk_out = out_uk[:num_carry]
if carry_uk_out == carry_uk:
@ -1744,12 +1723,7 @@ def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
if config.omnistaging_enabled:
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
else:
pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
out_avals, _ = unzip2(pvals_out)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
return core.ClosedJaxpr(jaxpr, consts)
@ -2623,63 +2597,3 @@ def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _initial_style_open_jaxpr, _initial_style_jaxpr, \
_initial_style_jaxprs_with_common_consts
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
with core.initial_style_staging(): # type: ignore
jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore
return jaxpr, out_pvals, consts, out_tree
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
jaxpr, out_pvals, consts, out_tree = _initial_style_open_jaxpr(
fun, in_tree, in_avals)
closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
return closed_jaxpr, consts, out_tree()
@cache()
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
in_tree, in_avals):
# When staging the branches of a conditional into jaxprs, constants are
# extracted from each branch and converted to jaxpr arguments. To use the
# staged jaxprs as the branches to a conditional *primitive*, we need for
# their (input) signatures to match. This function "joins" the staged jaxprs:
# for each one, it makes another that accepts *all* constants, but only uses
# those that it needs (dropping the rest).
jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([
_initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs])
newvar = core.gensym(jaxprs, suffix='_')
all_const_avals = tuple(
tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
for consts in all_consts)
unused_const_vars = tuple(
tuple(newvar(aval) for aval in const_avals)
for const_avals in all_const_avals)
def pad_jaxpr_constvars(i, jaxpr):
prefix = util.concatenate(unused_const_vars[:i])
suffix = util.concatenate(unused_const_vars[i+1:])
constvars = prefix + jaxpr.constvars + suffix
return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
def type_and_const_convert_jaxpr(jaxpr, out_pvals):
return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
closed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals)
return (tuple(closed_jaxprs),
tuple(util.concatenate(all_consts)),
tuple(out_tree() for out_tree in all_out_trees))

View File

@ -1767,8 +1767,6 @@ def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape)
weak_type = dtype is None and dtypes.is_weakly_typed(x)
dtype = dtype or _dtype(x)
if not config.omnistaging_enabled:
fill_value = tie_in(x, fill_value)
return full(fill_shape, _convert_element_type(fill_value, dtype, weak_type))
@ -4017,10 +4015,7 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
if type(t) is ad_util.Zero:
return [ad_util.Zero(operand.aval)] + [None] * len(start_indices)
else:
if config.omnistaging_enabled:
zeros = full(operand_shape, 0, operand_dtype)
else:
zeros = full(operand_shape, tie_in(t, _zero(t)))
zeros = full(operand_shape, 0, operand_dtype)
return ([dynamic_update_slice(zeros, t, start_indices)] +
[None] * len(start_indices))
@ -4319,10 +4314,7 @@ def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
if type(t) is ad_util.Zero:
out = ad_util.Zero(operand.aval)
else:
if config.omnistaging_enabled:
zeros = full(operand_shape, _zero(t))
else:
zeros = full(operand_shape, tie_in(t, _zero(t)))
zeros = full(operand_shape, _zero(t))
scatter_dnums = ScatterDimensionNumbers(
update_window_dims=dimension_numbers.offset_dims,
inserted_window_dims=dimension_numbers.collapsed_slice_dims,
@ -5924,8 +5916,6 @@ def _top_k_jvp(primals, tangents, *, k):
gather_indices = []
for i in range(rank-1):
_iota = iota(k_idxs.dtype, idx_shape[i])
if not config.omnistaging_enabled:
_iota = tie_in(operand, _iota)
_iota = broadcast_in_dim(_iota, gather_index_shape, (i,))
gather_indices.append(_iota)
gather_indices.append(reshape(k_idxs, gather_index_shape))
@ -5979,14 +5969,7 @@ def create_token(_=None):
The argument is ignored. It exists for backward compatibility.
"""
if config.omnistaging_enabled:
return create_token_p.bind()
else:
x = _
if x is None:
raise ValueError(
'create_token needs a tie-in operand unless omnistaging is enabled.')
return create_token_p.bind(stop_gradient(x))
return create_token_p.bind()
create_token_p = Primitive("create_token")
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
@ -6601,71 +6584,3 @@ def _canonicalize_axis(axis, num_dims):
if axis < 0:
axis = axis + num_dims
return axis
tie_in_p = Primitive('tie_in')
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global tie_in
def tie_in(x: Array, y: Array) -> Array:
"""Returns the value of ``y`` but with a fake data dependence on ``x``.
When staging to XLA (e.g. running under jit or pmap), values that don't depend
on computation inputs are computed op-by-op, and folded into the XLA
computation as constants.
``tie_in`` provides a way to explicitly stage values into the computation.
When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
is ``y``, but staged to XLA. Downstream use of the result will also be staged
to XLA.
For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
XLA as long as ``x`` is staged to XLA.
"""
if config.omnistaging_enabled:
return y
else:
return tie_in_p.bind(x, y)
# If lax has already been imported, we need to monkey-patch the
# lax/__init__.py import of tie_in. If not (i.e. if this is running at lax
# module creation time) then we'll get an import error.
try:
jax.lax.tie_in = tie_in
except AttributeError:
pass
def _tie_in_transpose_rule(t, x, y):
if ad.is_undefined_primal(x):
return [ad_util.Zero(x.aval), t]
else:
return [ad_util.Zero.from_value(x), t]
def _tie_in_batch_rule(batched_args, batch_dims):
y = tie_in(*batched_args)
_, bdim_y = batch_dims
return y, bdim_y
def _tie_in_impl(x, y):
core.check_valid_jaxtype(x)
core.check_valid_jaxtype(y)
return y
def _tie_in_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
if type(y_dot) is ad_util.Zero or core.get_aval(y_dot).dtype is dtypes.float0:
return y, y_dot # skip tying in in this case
else:
return ad.linear_jvp(tie_in_p, primals, tangents)
tie_in_p.def_impl(_tie_in_impl)
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
xla.translations[tie_in_p] = lambda c, x, y: y
ad.primitive_jvps[tie_in_p] = _tie_in_jvp
ad.primitive_transposes[tie_in_p] = partial(ad.linear_transpose2, _tie_in_transpose_rule)
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]

View File

@ -25,18 +25,15 @@ import numpy as np
from jax import core
from jax import dtypes
from jax import tree_util
from jax._src import source_info_util
from . import lax
from jax.core import ShapedArray, AxisName, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax._src.util import partial, unzip2, prod, canonicalize_axis, safe_map, moveaxis
from jax.lib import xla_client as xc
from jax.lib import xla_bridge as xb
from jax.config import config
from jax._src.numpy import lax_numpy
xops = xc.ops
@ -1044,15 +1041,9 @@ def all_gather(x, axis_name, *, axis_index_groups=None):
[ 4. 5. 6. 7.]]
"""
axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
# The all_gather primitive doesn't work when omni-staging is disabled.
if not config.omnistaging_enabled:
return _all_gather_via_psum(x, all_gather_dimension=0, axis_name=axis_name,
axis_index_groups=axis_index_groups, axis_size=axis_size)
def bind(x):
return all_gather_p.bind(x, all_gather_dimension=0, axis_name=axis_name,
axis_index_groups=axis_index_groups, axis_size=axis_size)
bind = partial(all_gather_p.bind, all_gather_dimension=0,
axis_name=axis_name, axis_index_groups=axis_index_groups,
axis_size=axis_size)
return tree_util.tree_map(bind, x)
def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
@ -1337,42 +1328,3 @@ xla.parallel_translations[pgather_p] = _pgather_parallel_translation
batching.primitive_batchers[pgather_p] = _pgather_batcher
batching.collective_rules[pgather_p] = _pgather_collective_batcher
core.axis_substitution_rules[pgather_p] = partial(_subst_all_names_in_param, 'axes')
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global axis_index
psum_p.bind = partial(core.Primitive.bind, psum_p) # type: ignore
psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p)) # type: ignore
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args) # type: ignore
def _axis_index_bind(*, axis_name):
dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
frame = dynamic_axis_env[axis_name]
sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
nreps = dynamic_axis_env.nreps
trace = frame.pmap_trace
out_aval = _axis_index_abstract_eval(
nreps=nreps, sizes=sizes, axis_name=axis_name)
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
source_info_util.current())
out_tracer.recipe = eqn
return out_tracer
def _axis_index_translation_rule(c, nreps, sizes, axis_name):
div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
def _axis_index_abstract_eval(*, nreps, sizes, axis_name):
return ShapedArray((), np.int32, named_shape={axis_name: sizes[-1]})
axis_index_p.def_custom_bind(_axis_index_bind)
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
xla.translations[axis_index_p] = _axis_index_translation_rule

View File

@ -2306,12 +2306,7 @@ def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims)
normalizer = normalizer - ddof
if config.omnistaging_enabled:
normalizer_mask = lax.le(normalizer, 0)
else:
zero = lax.full_like(normalizer, 0, shape=())
normalizer_mask = lax.le(normalizer, zero)
normalizer_mask = lax.le(normalizer, 0)
result = nansum(centered, axis, keepdims=keepdims)
result = where(normalizer_mask, nan, result)
divisor = where(normalizer_mask, 1, normalizer)
@ -4353,8 +4348,6 @@ def _take_along_axis(arr, indices, axis):
j += 1
elif idx_shape[i] != 1:
iota = lax.iota(_dtype(indices), out_shape[i])
if not config.omnistaging_enabled:
iota = lax.tie_in(arr, iota)
iota = lax.broadcast_in_dim(iota, gather_index_shape, (j,))
gather_indices.append(iota)
slice_sizes.append(1)

View File

@ -57,18 +57,6 @@ def unzip3(xyzs):
zs.append(z)
return tuple(xs), tuple(ys), tuple(zs)
def unzip4(wxyzs):
ws = []
xs = []
ys = []
zs = []
for w, x, y, z in wxyzs:
ws.append(w)
xs.append(x)
ys.append(y)
zs.append(z)
return tuple(ws), tuple(xs), tuple(ys), tuple(zs)
def subvals(lst, replace):
lst = list(lst)
for i, v in replace:

View File

@ -206,7 +206,7 @@ def jit(
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232 0.76827 0.59566 ]
"""
if FLAGS.experimental_cpp_jit and config.omnistaging_enabled:
if FLAGS.experimental_cpp_jit:
return _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
else:
return _python_jit(fun, static_argnums, device, backend, donate_argnums)
@ -644,16 +644,10 @@ def xla_computation(fun: Callable,
"xla_computation in_parts", in_tree.children()[0], in_parts))
jaxtree_fun, out_tree = flatten_fun(f, in_tree)
avals = map(shaped_abstractify, args_flat)
if config.omnistaging_enabled:
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
else:
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, pvals, instantiate=True, stage_out=True) # type: ignore
out_avals = [raise_to_shaped(pval.get_aval()) for pval in out_pvals]
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
if out_parts is None:
@ -1549,10 +1543,6 @@ def pmap(
local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap", kws=True)
for arg in args: _check_arg(arg)
flat_fun, out_tree = flatten_fun(f, in_tree)
if not config.omnistaging_enabled and out_axes != 0:
raise ValueError("out_axes supported only with omnistaging enabled")
if not config.omnistaging_enabled and any(in_axis not in {None, 0} for in_axis in in_axes_flat):
raise ValueError("in_axes other than 0 and None only supported with omnistaging enabled")
if any(out_axis is None for out_axis in tree_flatten(out_axes)):
raise NotImplementedError("None out_axes in pmap are not supported yet")
# NOTE: We don't put out_tree() in the closure, because it's (1) non-hashable,
@ -2048,19 +2038,10 @@ def make_jaxpr(fun: Callable,
jax_args, in_tree = tree_flatten((args, kwargs))
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
in_avals = map(shaped_abstractify, jax_args)
if config.omnistaging_enabled:
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
else:
if axis_env:
raise NotImplementedError(
"axis_env argument to make_jaxpr only supported with omnistaging.")
in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, in_pvals, instantiate=True, stage_out=True) # type: ignore
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
if return_shape:
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
@ -2488,11 +2469,7 @@ class CustomTransformsFunction(object):
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_pvals = [pe.PartialVal.unknown(raise_to_shaped(core.get_aval(x)))
for x in args_flat]
if config.omnistaging_enabled:
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
else:
with core.initial_style_staging(): # type: ignore
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
in_tree=in_tree, out_tree=out_tree(),
num_consts=len(consts))

View File

@ -59,11 +59,10 @@ class Config:
self.FLAGS = NameSpace(self.read, self.update)
self.use_absl = False
self._contextmanager_flags = set()
self._update_hooks = {}
# TODO(mattjj): delete these when only omnistaging is available
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
self._omnistaging_disablers = []
self._update_hooks = {}
def update(self, name, val):
if self.use_absl:
@ -160,13 +159,8 @@ class Config:
"see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.\n"
"To remove this warning, unset the JAX_OMNISTAGING environment variable.")
def register_omnistaging_disabler(self, disabler):
if self.omnistaging_enabled:
self._omnistaging_disablers.append(disabler)
def enable_omnistaging(self):
if not self.omnistaging_enabled:
raise Exception("can't re-enable omnistaging after it's been disabled")
return # TODO(mattjj): remove all callers
def disable_omnistaging(self):
warnings.warn(
@ -341,6 +335,7 @@ FLAGS = flags.FLAGS
already_configured_with_absl = False
# TODO(mattjj): remove all uses of this flag
flags.DEFINE_bool(
'jax_omnistaging',
bool_env('JAX_OMNISTAGING', True),

View File

@ -1819,166 +1819,6 @@ def pp_kv_pairs(kv_pairs):
else:
return pp('')
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env, \
eval_context, extra_jit_context
class TraceStack:
upward: List[MainTrace]
downward: List[MainTrace]
def __init__(self):
self.upward = []
self.downward = []
def next_level(self, bottom: bool) -> int:
if bottom:
return - (len(self.downward) + 1)
else:
return len(self.upward)
def push(self, main_trace: MainTrace, bottom: bool) -> None:
if bottom:
self.downward.append(main_trace)
else:
self.upward.append(main_trace)
def pop(self, bottom: bool) -> None:
if bottom:
self.downward.pop()
else:
self.upward.pop()
def __repr__(self) -> str:
return 'Trace stack\n{} ---\n{}'.format(
map(' {}\n'.format, self.upward[::-1]),
map(' {}\n'.format, self.downward))
def copy(self):
new = TraceStack()
new.upward = self.upward[:]
new.downward = self.downward[:]
return new
class TraceState:
trace_stack: TraceStack
substack: List[Sublevel]
initial_style: bool
def __init__(self) -> None:
self.trace_stack = TraceStack() # type: ignore
self.substack = [Sublevel(0)]
self.initial_style = False
def copy(self):
new = TraceState()
new.trace_stack = self.trace_stack.copy()
new.substack = self.substack[:]
new.initial_style = self.initial_style
return new
def extra_jit_context(trace_stack):
return None
thread_local_state = ThreadLocalState()
def reset_trace_state() -> bool:
"Reset the global trace state and return True if it was already clean."
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
thread_local_state.trace_state.trace_stack.downward or
thread_local_state.trace_state.trace_stack.upward):
thread_local_state.trace_state.__init__() # type: ignore
return False
else:
return True
@contextmanager
def new_main(trace_type: Type[Trace], bottom=False, **payload) -> Generator[MainTrace, None, None]:
level = thread_local_state.trace_state.trace_stack.next_level(bottom)
main = MainTrace(level, trace_type, **payload)
thread_local_state.trace_state.trace_stack.push(main, bottom)
try:
yield main
finally:
thread_local_state.trace_state.trace_stack.pop(bottom)
if config.jax_check_tracer_leaks:
t = ref(main)
del main
if t() is not None:
print(thread_local_state.trace_state.trace_stack)
raise Exception('Leaked trace {}'.format(t()))
def find_top_trace(xs) -> Optional[Trace]:
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
key=attrgetter('level'), default=None)
return top_trace and top_trace.main.with_cur_sublevel()
@contextmanager
def eval_context():
yield # dummy implementation for forward compatibility
def bind(self, *args, **kwargs):
assert not config.jax_enable_checks or all(isinstance(arg, Tracer)
or valid_jaxtype(arg) for arg in args), args
top_trace = find_top_trace(args)
if top_trace is None:
return self.impl(*args, **kwargs)
tracers = map(top_trace.full_raise, args)
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
if self.multiple_results:
return map(full_lower, out_tracer)
else:
return full_lower(out_tracer)
Primitive.bind = bind # type: ignore
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
fun: lu.WrappedFun, *args, **params):
out_axes_transforms = _IgnoreElemList()
if primitive.map_primitive:
out_axes_thunk = params['out_axes_thunk']
# The new thunk depends deterministically on the old thunk and the wrapped function.
# Any caching already has to include the wrapped function as part of the key, so we
# only use the previous thunk for equality checks.
@as_hashable_function(closure=out_axes_thunk)
def new_out_axes_thunk():
out_axes = out_axes_thunk()
for t in out_axes_transforms:
out_axes = t(out_axes)
return out_axes
params = dict(params, out_axes_thunk=new_out_axes_thunk)
params_tuple = tuple(params.items())
top_trace = find_top_trace(args)
level = (thread_local_state.trace_state.trace_stack.next_level(True)
if top_trace is None else top_trace.level)
params_tuple = tuple(params.items())
fun, env_trace_todo = process_env_traces(
fun, primitive, level, params_tuple, out_axes_transforms)
if top_trace is None:
with new_sublevel():
outs = primitive.impl(fun, *args, **params)
else:
tracers = map(top_trace.full_raise, args)
outs = primitive.process(top_trace, fun, tracers, params)
return apply_todos(env_trace_todo(), map(full_lower, outs))
@contextmanager
def extend_axis_env(axis_name, size: int, tag: Any):
yield
@contextmanager
def initial_style_staging():
trace_state = thread_local_state.trace_state
prev, trace_state.initial_style = trace_state.initial_style, True
try:
yield
finally:
trace_state.initial_style = prev
# Casting float0 array to a float-valued zero array.
def zeros_like_float0(array, dtype=None):
if not dtype:

View File

@ -213,16 +213,8 @@ class custom_jvp:
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
if config.omnistaging_enabled:
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
else:
if _initial_style_staging():
out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat) # type: ignore
out_tree = out_tree1()
else:
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
return tree_unflatten(out_tree, out_flat)
def _add_args(f, extra_args):
@ -489,21 +481,10 @@ class custom_vjp:
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
if config.omnistaging_enabled:
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
else:
if _initial_style_staging():
out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd, # type: ignore
*args_flat, out_trees=out_trees)
out_tree = out_tree()
else:
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
return tree_unflatten(out_tree, out_flat)
@partial(partial, tree_map)
@ -662,8 +643,6 @@ def _custom_vjp_call_jaxpr_vmap(
fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
out_trees=out_trees, num_consts=num_consts)
out_dims = out_dims2[0] if out_dims2 else out_dims1
if not config.omnistaging_enabled:
out_dims = out_dims[:len(batched_outs)]
return batched_outs, out_dims
batching.initial_style_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
@ -674,76 +653,6 @@ batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _initial_style_jaxpr, custom_vjp_call_jaxpr, custom_jvp_call_jaxpr
def _initial_style_jaxpr(fun, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
bottom=True, stage_out=False) # type: ignore
assert not any(isinstance(c, core.Tracer) for c in consts)
return jaxpr, consts
def jvp_bind(self, fun, jvp, *args):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
fun, env_trace_todo1 = core.process_env_traces(
fun, self, top_trace and top_trace.level, (), None)
jvp, env_trace_todo2 = core.process_env_traces(
jvp, self, top_trace and top_trace.level, (), None)
if top_trace is None:
with core.new_sublevel():
outs = self.impl(fun, jvp, *args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
_, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
if env_trace_todo:
raise core.UnexpectedTracerError
return map(core.full_lower, outs)
CustomJVPCallPrimitive.bind = jvp_bind # type: ignore
def jvp_post_process(self, trace, out_tracers, params):
raise core.UnexpectedTracerError
CustomJVPCallPrimitive.post_process = jvp_post_process # type: ignore
def vjp_bind(self, fun, fwd, bwd, *args, out_trees):
args = map(core.full_lower, args)
top_trace = core.find_top_trace(args)
if top_trace is None:
outs = fun.call_wrapped(*args)
else:
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
out_trees=out_trees)
return map(core.full_lower, outs)
CustomVJPCallPrimitive.bind = vjp_bind # type: ignore
def vjp_post_process(self, trace, out_tracers, params):
raise core.UnexpectedTracerError
CustomVJPCallPrimitive.post_process = vjp_post_process # type: ignore
def custom_jvp_call_jaxpr(fun: Callable, jvp: Callable, *args):
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
fun_jaxpr, consts = _initial_style_jaxpr(fun, in_avals) # consts can be tracers!
closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
jvp_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2))
return custom_jvp_call_jaxpr_p.bind(
*consts, *args, fun_jaxpr=closed_fun_jaxpr,
jvp_jaxpr_thunk=jvp_jaxpr_thunk, num_consts=len(consts))
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
fun_jaxpr, consts = _initial_style_jaxpr(fun, in_avals) # consts can be tracers!
closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
fwd_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(fwd, in_avals))
return custom_vjp_call_jaxpr_p.bind(
*consts, *args, fun_jaxpr=closed_fun_jaxpr,
fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, out_trees=out_trees,
num_consts=len(consts))
def custom_gradient(fun):
"""Convenience function for defining custom VJP rules (aka custom gradients).
@ -813,11 +722,7 @@ def custom_gradient(fun):
ans_flat, out_tree = tree_flatten((ans,))
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
if config.omnistaging_enabled:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
else:
ans_pvals = [pe.PartialVal.unknown(a) for a in ans_avals]
jaxpr, _, consts = pe.trace_to_jaxpr(rule, ans_pvals, instantiate=True)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
def bwd(res, cts):
@ -921,15 +826,8 @@ def closure_convert(fun, *example_args):
@cache()
def _closure_convert_for_avals(fun, in_tree, in_avals):
if config.omnistaging_enabled:
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
else:
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
with core.initial_style_staging(): # type: ignore
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
out_tree = out_tree()
# We only want to closure convert for constants with respect to which we're

View File

@ -27,7 +27,6 @@ import numpy as np
from jax.tree_util import tree_flatten, tree_unflatten
from jax.api_util import flatten_fun_nokwargs
from jax import ad_util, core, lax, xla
from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, wrap_name
import jax.numpy as jnp
import jax.linear_util as lu
@ -274,10 +273,6 @@ def _def_passthrough(prim, argnums=(0,)):
_def_passthrough(lax.select_p, (0, 1, 2))
_def_passthrough(lax.broadcast_in_dim_p)
_def_passthrough(xla.device_put_p)
try:
_def_passthrough(lax_internal.tie_in_p, (0, 1))
except AttributeError:
pass
class _DoubleDouble:

View File

@ -152,16 +152,9 @@ def convert(fun: Callable, *,
def converted_fun(*args: TfVal) -> TfVal:
# TODO: is there a better way to check if we are inside a transformation?
if config.omnistaging_enabled:
if not core.trace_state_clean():
raise ValueError("convert must be used outside all JAX transformations."
+ f"Trace state: {core.thread_local_state.trace_state}")
else:
if (core.thread_local_state.trace_state.trace_stack.downward or
core.thread_local_state.trace_state.trace_stack.upward or
core.thread_local_state.trace_state.substack != [core.Sublevel(0)]):
raise ValueError("convert must be used outside all JAX transformations."
+ f"Trace state: {core.thread_local_state.trace_state}")
if not core.trace_state_clean():
raise ValueError("convert must be used outside all JAX transformations."
+ f"Trace state: {core.thread_local_state.trace_state}")
def check_arg(a):
if not _is_tfval(a):
@ -267,8 +260,7 @@ def _interpret_fun(fun: lu.WrappedFun,
in_vals: Sequence[TfVal],
in_avals: Sequence[core.AbstractValue]
) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
new_main = core.new_base_main if config.omnistaging_enabled else core.new_main
with new_main(TensorFlowTrace) as main: # type: ignore
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
fun = _interpret_subtrace(fun, main, in_avals)
out_vals: Sequence[Tuple[TfVal, core.AbstractValue]] = fun.call_wrapped(*in_vals)
del main
@ -813,10 +805,6 @@ tf_not_yet_impl = [
"call_tf",
]
try:
tf_impl[lax.tie_in_p] = lambda x, y: y
except AttributeError:
pass
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
tf_impl[ad_util.zeros_like_p] = tf.zeros_like

View File

@ -379,7 +379,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.TransformConvertAndCompare(f, arg, None)
self.TransformConvertAndCompare(f, arg, "grad")
@jtu.skip_on_flag('jax_omnistaging', False)
def test_convert_nullary_func(self):
# Even nullary functions are converted to TF (as opposed to constant-folded
# in JAX prior to conversion).
@ -389,7 +388,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf_graph = f_tf.get_concrete_function().graph.as_graph_def()
self.assertIn('op: "Sin"', str(f_tf_graph))
@jtu.skip_on_flag('jax_omnistaging', False)
def test_convert_of_nested_independent_jit(self):
def func(x):
def inner1(y):
@ -449,7 +447,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
ValueError, "convert must be used outside all JAX transformations"):
self.TransformConvertAndCompare(outer, np.ones((4,)), transform)
@jtu.skip_on_flag('jax_omnistaging', False)
def test_name_scope(self):
log = []

View File

@ -682,28 +682,6 @@ def _make_add_any_harness(name, *, shapes=((2,), (2,)), dtype=np.float32):
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
_make_add_any_harness("dtypes", dtype=dtype)
for rhs_dtype in jtu.dtypes.all:
lhs_dtype = np.float32
lhs_shape = (2, 3)
rhs_shape = (4, 5)
define(
lax.tie_in_p,
f"lhs={jtu.format_shape_dtype_string(lhs_shape, lhs_dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}",
lax.tie_in_p.bind,
[RandArg(lhs_shape, lhs_dtype),
RandArg(rhs_shape, rhs_dtype)],
jax_unimplemented=[
Limitation(
"requires omnistaging to be disabled",
enabled=config.omnistaging_enabled)
],
dtype=rhs_dtype,
lhs_shape=lhs_shape,
lhs_dtype=lhs_dtype,
rhs_shape=rhs_shape,
rhs_dtype=rhs_dtype,
primitive=lax.tie_in_p)
for dtype in jtu.dtypes.all:
shape: Tuple[int, ...] = (20, 20)
define(

View File

@ -129,8 +129,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
all_primitives = tuple(sorted(all_primitives, key=str))
for p in all_primitives:
# TODO: remove tie_in once omnistaging is on by default
if p.name == "axis_index" or p.name == "tie_in":
if p.name == "axis_index":
continue
if p.name in tf_not_yet_impl:
self.assertNotIn(

View File

@ -589,6 +589,3 @@ def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr,
del jvp_jaxpr_thunk
return jet(core.jaxpr_as_fun(fun_jaxpr), primals_in, series_in)
jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule
deflinear(lax.tie_in_p)

View File

@ -118,7 +118,6 @@ from jax import tree_util
from jax import numpy as jnp
from jax.interpreters import partial_eval as pe
from jax._src.util import safe_map
from jax.config import config
class Scope(object):
@ -291,25 +290,15 @@ class Scope(object):
def start_subtrace(self):
"""Starts a nested trace, returns the Trace object."""
# TODO: This follows the __enter__ part of core.new_main.
if config.omnistaging_enabled:
level = core.thread_local_state.trace_state.trace_stack.next_level()
main = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(main)
self._count_subtraces += 1
return pe.JaxprTrace(main, core.cur_sublevel())
else:
level = core.thread_local_state.trace_state.trace_stack.next_level(False)
main = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(main, False)
self._count_subtraces += 1
return pe.JaxprTrace(main, core.cur_sublevel())
level = core.thread_local_state.trace_state.trace_stack.next_level()
main = core.MainTrace(level, pe.JaxprTrace)
core.thread_local_state.trace_state.trace_stack.push(main)
self._count_subtraces += 1
return pe.JaxprTrace(main, core.cur_sublevel())
def end_subtrace(self):
# TODO: This follows the __exit__ part of core.new_main
if config.omnistaging_enabled:
core.thread_local_state.trace_state.trace_stack.pop()
else:
core.thread_local_state.trace_state.trace_stack.pop(False)
core.thread_local_state.trace_state.trace_stack.pop()
self._count_subtraces -= 1

View File

@ -550,13 +550,8 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_
# (in this case via partial_eval) before we call into backward_pass again.
typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
unknowns = map(is_undefined_primal, primals_in)
if config.omnistaging_enabled:
primal_jaxpr, tangent_jaxpr, out_unknowns = \
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore
else:
primal_jaxpr, tangent_jaxpr, out_unknowns = \
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
trace_type=None) # type: ignore
primal_jaxpr, tangent_jaxpr, out_unknowns = \
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore
def do_transpose(primals_in, cotangents_in):
# NOTE: This is passing in undefined primals in place of tangent arguments, but it
@ -706,12 +701,7 @@ def defvjp_all(prim, custom_vjp):
primals_out = [primals_out]
out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
if config.omnistaging_enabled:
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
else:
with core.initial_style_staging(): # type: ignore
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
instantiate=True)
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
num_res=len(res), out_avals=out_avals)
return primals_out + tangents_out
@ -766,17 +756,3 @@ class CustomVJPException(Exception):
"closed-over value into the custom_vjp function as an argument, and "
"adapting the custom_vjp fwd and bwd rules.")
super().__init__(msg)
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global jvp_jaxpr
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
assert len(jaxpr.in_avals) == len(nonzeros)
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
jaxpr_out, _, consts = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
return core.ClosedJaxpr(jaxpr_out, consts), out_nonzeros()

View File

@ -48,7 +48,7 @@ def batchfun(axis_name, axis_size, in_dims, *in_vals):
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
in_dims = in_dims() if callable(in_dims) else in_dims
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
and not isinstance(core.get_aval(x), core.AbstractUnit) # non-omnistaging
and not isinstance(core.get_aval(x), core.AbstractUnit)
else ax for x, ax in zip(in_vals, in_dims)]
with core.new_main(BatchTrace, axis_name=axis_name) as main:
with core.extend_axis_env(axis_name, axis_size, main):
@ -477,20 +477,4 @@ def _merge_bdims(x, y):
return x # arbitrary
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global batch_jaxpr
def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
for aval, b in zip(jaxpr.in_avals, in_batched)]
in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
avals_out, _ = unzip2(pvals_out)
return core.ClosedJaxpr(jaxpr_out, consts_out), out_batched()
collective_rules: Dict[core.Primitive, Callable] = {}

View File

@ -26,7 +26,6 @@ from ..api_util import flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, register_pytree_node
from .._src.util import safe_map, safe_zip, split_list
from .. import custom_derivatives
from ..config import config
map = safe_map
zip = safe_zip
@ -214,14 +213,9 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)
in_avals = map(abstract, primals_in + primals_out + primals_out)
if config.omnistaging_enabled:
# TODO: Actually we do know some of the inputs, because they might be literals!
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
else:
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( # type: ignore
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
instantiate=True, stage_out=False) # type: ignore
# TODO: Actually we do know some of the inputs, because they might be literals!
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
assert not ivjp_jaxpr.constvars # That might happen some time, but don't bother until then
ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])
@ -231,12 +225,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
unknowns = (map(ad.is_undefined_primal, primals_in) +
map(ad.is_undefined_primal, primals_out) +
[False] * len(cts_in))
if config.omnistaging_enabled:
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
ivjp_jaxpr, unknowns, instantiate=False) # type:ignore
else:
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
ivjp_jaxpr, unknowns, instantiate=False, trace_type=None) # type: ignore
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
ivjp_jaxpr, unknowns, instantiate=False) # type:ignore
unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
# Make sure we're able to compute all cotangents. We don't really care if we
# can reconstruct or primals or not, although failure to do so might result in
@ -312,15 +302,3 @@ def get_primitive_inverse(p):
def definverse(primitive, inverse_rule):
primitive_inverses[primitive] = inverse_rule
return inverse_rule
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _initial_style_jaxpr, custom_jvp_call
def _initial_style_jaxpr(fun, in_avals):
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
bottom=True, stage_out=False) # type: ignore
assert not any(isinstance(c, core.Tracer) for c in consts)
return core.ClosedJaxpr(jaxpr, consts)

View File

@ -17,7 +17,7 @@ from collections import namedtuple
import contextlib
import functools
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
List, Union, cast, Type, no_type_check)
List, Union, cast)
from weakref import ref
import numpy as np
@ -163,11 +163,6 @@ class JaxprTrace(Trace):
# We use process_call to handle both call and map primitives.
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
if not config.omnistaging_enabled:
if (self.main.trace_type is StagingJaxprTrace # type: ignore
and primitive in staged_out_calls): # type: ignore
tracers = map(self.instantiate_const_abstracted, tracers)
if primitive in call_partial_eval_rules:
return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)
@ -406,15 +401,8 @@ call_param_updaters: Dict[core.Primitive, Callable] = {}
def abstract_eval_fun(fun, *avals, **params):
if config.omnistaging_enabled:
_, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
else:
pvals_in = [PartialVal.unknown(a) for a in avals]
_, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in,
instantiate=True, stage_out=True) # type: ignore
avals_out, _ = unzip2(pvals_out)
for aval_out in avals_out:
assert isinstance(aval_out, AbstractValue) # instantiate=True
_, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
return avals_out
@ -775,13 +763,8 @@ def _remat_partial_eval(trace, _, f, tracers, params):
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
in_pvals = [t.pval for t in instantiated_tracers]
if config.omnistaging_enabled:
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
else:
with core.initial_style_staging(): # type: ignore
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
# Convert consts to inputs, since they may contain Tracer instances.
jaxpr = convert_constvars_jaxpr(jaxpr)
@ -792,12 +775,8 @@ def _remat_partial_eval(trace, _, f, tracers, params):
closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
in_unknowns = ([False] * len(consts) +
[not t.is_known() for t in it.chain(env_tracers, tracers)])
if config.omnistaging_enabled:
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
closed_jaxpr, in_unknowns, instantiate=False) # type: ignore
else:
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
closed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type) # type: ignore
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
closed_jaxpr, in_unknowns, instantiate=False) # type: ignore
out_knowns = [not b for b in out_unknowns]
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
@ -1191,7 +1170,6 @@ def _memoize(thunk):
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
assert config.omnistaging_enabled
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f) # type: ignore
main.jaxpr_stack = () # type: ignore
@ -1221,7 +1199,6 @@ def extend_jaxpr_stack(main, frame):
main.jaxpr_stack = main.jaxpr_stack[:-1]
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
assert config.omnistaging_enabled
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f) # type: ignore
main.jaxpr_stack = () # type: ignore
@ -1234,7 +1211,6 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
# use trace_to_jaxpr directly because of an interaction with the curent
# custom_derivatives.py, which we work around by adding the EvalTrace.
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
assert config.omnistaging_enabled
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
return trace_to_jaxpr(fun, in_pvals)
@ -1247,124 +1223,3 @@ def fun_sourceinfo(fun):
return f"{fun.__name__} at {filename}:{lineno}"
except AttributeError:
return "<unknown>"
@config.register_omnistaging_disabler
@no_type_check
def omnistaging_disabler() -> None:
global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: Union[bool, Sequence[bool]] = False,
stage_out=False, bottom=False,
trace_type: Optional[Type[Trace]] = None,
) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
"""Traces a function into a Jaxpr, given PartialVals for inputs.
Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
computation that depends on unknown inputs. The `out_pvals` are the PartialVal
for the outputs. The intermediate values that depend only on known inputs and
are needed to compute the output of `jaxpr` are in `consts` and are passed in
as the constvars of the `jaxpr`. The handling of the known outputs depends on
`instantiate`.
For example, given `fun` defined as follows::
def fun(ki, ui): # ki will be a known input in this example
ka = ki + 2
kb = ka + 3
return (kb, ui + ka)
with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
computation that depends on unknown inputs is `ui + ka` and will be the only
computation in the body of the `jaxpr`. This computation depends on the known
intermediate value `ka`, which will be computed statically. Currently, such
constants are either embedded in the Jaxpr if they are scalars, or passed as a
constvar to `jaxpr`, and then the value of the actual constant will be in
`consts`:
When `instantiate=False` we get::
jaxpr =
{ lambda ka ; ki ui.
let c = add ui ka
in (*, c) } # known outputs are `*`
out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
consts = [3] # the constant for `ka`
When `instantiate=True` we get::
jaxpr =
{ lambda ka kb ; ki ui.
let c = add ui ka
in (kb, c) } # known output are explicit
out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
consts = [3, 6] # values for `ka` and `kb` constvars
"""
trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
with core.new_main(trace_type, bottom=bottom) as main:
fun = trace_to_subjaxpr(fun, main, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
del main
return jaxpr, out_pvals, consts
def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
instantiate: Union[bool, Sequence[bool]],
trace_type: Optional[Type[core.Trace]]
) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
cell = []
def fun(*vals):
pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
trace_type=trace_type)
out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
return out_consts_2 + consts_2
# The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
# the avals, and it will never actually reach any primitives, because the `fun` above will
# execute the jaxpr with the right avals (it reconstructs `pvals` inside).
pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
for aval, uk in zip(jaxpr.in_avals, unknowns)]
jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
(out_pvs_2, jaxpr_2, num_res), = cell
assert len(jaxpr_2.constvars) == num_res
# jaxpr :: a -> b
# jaxpr_1 :: a1 -> [b1, res]
# jaxpr_2 :: res | a2 -> b2
# jaxpr_2 :: [a2, res] -> b2
jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
if not unknown:
var.aval = abstract_unit
uk_out = [pv is not None for pv in out_pvs_2]
return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
# See comment at top of `JaxprTrace`. This method should be reachable
# only when we stage out, and in that case we drop the custom differentiation
# rules, because we do not need them.
if not config.omnistaging_enabled:
assert self.main.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)
JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
# See comment in the above process_custom_jvp_call method.
if not config.omnistaging_enabled:
assert self.main.trace_type is StagingJaxprTrace
return fun.call_wrapped(*tracers)
JaxprTrace.process_custom_vjp_call = process_custom_vjp_call
staged_out_calls = set()
class StagingJaxprTrace(JaxprTrace): pass

View File

@ -35,8 +35,7 @@ import itertools as it
import operator as op
import threading
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
Type, Union, Iterable, no_type_check, NamedTuple,
TYPE_CHECKING)
Type, Union, Iterable, NamedTuple, TYPE_CHECKING)
from absl import logging
import numpy as np
@ -46,7 +45,7 @@ from .. import core
from .. import linear_util as lu
from ..abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray
from .._src.util import (partial, unzip2, unzip3, prod, safe_map, safe_zip,
from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, curry)
from ..lib import xla_bridge as xb
@ -672,48 +671,25 @@ def parallel_callable(fun: lu.WrappedFun,
else:
local_devices = None # type: ignore
if config.omnistaging_enabled:
sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(in_axes, avals))
if any(s is not None for s in global_arg_shapes):
# TODO(skye): we could take this branch unconditionally if we handled
# grad of global_arg_shapes correctly.
global_sharded_avals = [
aval.update(shape=shape) if shape is not None else aval
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
else:
global_sharded_avals = sharded_avals # type: ignore
logging.vlog(2, "sharded_avals: %s", sharded_avals)
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(in_axes, avals))
if any(s is not None for s in global_arg_shapes):
# TODO(skye): we could take this branch unconditionally if we handled
# grad of global_arg_shapes correctly.
global_sharded_avals = [
aval.update(shape=shape) if shape is not None else aval
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
else:
@lu.wrap_init
def dynamic_fun(dummy, *args):
with extend_dynamic_axis_env(axis_name, dummy._trace, global_axis_size): # type: ignore
return fun.call_wrapped(*args)
sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(in_axes, avals))
pvals = [pe.PartialVal.unknown(aval) for aval in sharded_avals]
# We add a dummy first invar, to carry the trace details to `dynamic_fun`
pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env
jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
jaxpr.invars = jaxpr.invars[1:] # ignore dummy
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
out_pvs, out_consts = unzip2(out_pvals)
global_sharded_avals = sharded_avals # type: ignore
logging.vlog(2, "sharded_avals: %s", sharded_avals)
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
out_axes = out_axes_thunk()
if config.omnistaging_enabled:
assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))
else:
assert len(out_pvals) == len(out_axes), (len(out_pvals), len(out_axes))
assert all(out_axis == 0 for out_axis in out_axes)
assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))
# TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
# for now raise an error
@ -724,21 +700,6 @@ def parallel_callable(fun: lu.WrappedFun,
if is_multi_host_pmap:
check_multihost_collective_allowlist(jaxpr)
if not config.omnistaging_enabled:
if all(pv is None for pv in out_pvs):
# When the output doesn't depend on the input we don't need to compile an
# XLA computation at all; we handle this as a special case so we can stage
# out multi-replica XLA computations regardless of the hardware available.
# The 'None' values here are just dummies we know will be ignored.
handlers = [
_pval_to_result_handler( # type: ignore
axis_size, None, None, None, pval, local_devices, backend_name) # type: ignore
for pval in out_pvals # type: ignore
]
results = [handler(None) for handler in handlers]
return lambda *_: results
# TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
num_local_replicas = axis_size * jaxpr_replicas
@ -893,31 +854,25 @@ def parallel_callable(fun: lu.WrappedFun,
if spec is not None else None
for aval, spec in safe_zip(avals, input_sharding_specs)]
handle_args = partial(shard_args, compiled.local_devices(), input_indices)
if config.omnistaging_enabled:
nouts = len(out_sharded_avals)
if out_parts is None:
out_parts = (None,) * nouts
if local_out_parts is None:
local_out_parts = (None,) * nouts
nouts = len(out_sharded_avals)
if out_parts is None:
out_parts = (None,) * nouts
if local_out_parts is None:
local_out_parts = (None,) * nouts
local_out_avals = [get_local_aval(aval, parts, lparts)
for aval, parts, lparts
in safe_zip(out_sharded_avals, out_parts, local_out_parts)]
local_unmapped_avals = [core.unmapped_aval(axis_size, out_axis, aval)
if out_axis is not None else aval
for aval, out_axis in safe_zip(local_out_avals, out_axes)]
local_out_avals = [get_local_aval(aval, parts, lparts)
for aval, parts, lparts
in safe_zip(out_sharded_avals, out_parts, local_out_parts)]
local_unmapped_avals = [core.unmapped_aval(axis_size, out_axis, aval)
if out_axis is not None else aval
for aval, out_axis in safe_zip(local_out_avals, out_axes)]
out_specs = [_pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
parts, aval, out_axis)
if aval is not core.abstract_unit else None
for parts, aval, out_axis in safe_zip(local_out_parts, local_out_avals, out_axes)]
handle_outs = avals_to_results_handler(
num_local_replicas, local_num_partitions, out_specs, local_unmapped_avals)
else:
handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas, # type: ignore
local_num_partitions,
local_out_parts, out_pvals,
compiled.local_devices(), backend)
out_specs = [_pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
parts, aval, out_axis)
if aval is not core.abstract_unit else None
for parts, aval, out_axis in safe_zip(local_out_parts, local_out_avals, out_axes)]
handle_outs = avals_to_results_handler(
num_local_replicas, local_num_partitions, out_specs, local_unmapped_avals)
if hasattr(backend, "wrap_execute_replicated"):
return backend.wrap_execute_replicated(compiled, compiled.local_devices(),
@ -1390,7 +1345,6 @@ def mesh_callable(fun: lu.WrappedFun,
spmd_lowering: bool,
*local_in_untiled_avals,
tile_by_mesh_axes: bool):
assert config.omnistaging_enabled
local_mesh = mesh.local_mesh
global_axis_sizes = mesh.shape
local_axis_sizes = local_mesh.shape
@ -1585,102 +1539,6 @@ def maybe_extend_axis_env(*args, **kwargs):
with core.extend_axis_env(*args, **kwargs):
yield
@config.register_omnistaging_disabler
@no_type_check
def omnistaging_disabler() -> None:
global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
_thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
apply_parallel_primitive, parallel_pure_rules, \
_pvals_to_results_handler, _pval_to_result_handler, replicate, \
axis_index, maybe_extend_axis_env
@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
yield
def _pvals_to_results_handler(
size, nrep, npart,
out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
out_pvals, devices, backend):
if out_parts is None:
out_parts = (None,) * len(out_pvals)
handlers = [
_pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend)
for pval, parts in safe_zip(out_pvals, out_parts) # type: ignore
]
def handler(out_bufs):
return [h(bufs) for h, bufs in safe_zip(handlers, out_bufs)]
return handler
def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend):
if devices:
assert all(d.host_id == xb.host_id(backend) for d in devices)
aval, const = pval
if aval is None:
if nrep is None:
nrep = axis_size
# If 'const' is a ShardedDeviceArray, it must have come from a pmap nested
# inside the one we're currently evaluating, and we should replicate
# 'const' across the total number of devices needed. We don't necessarily
# know the nested pmap's axis_size (e.g. the jaxpr for
# pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the
# axis size of the output 'const'.
# TODO: we might be doing unnecessary device transfers in the inner pmap.
if isinstance(const, ShardedDeviceArray):
nrep *= len(const)
bcast_const = (core.unit if const is core.unit
else replicate(const, axis_size, nrep, devices, backend)) # type: ignore
return lambda _: bcast_const # type: ignore
else:
if aval is not core.abstract_unit:
unsharded_aval = aval.update(shape=(axis_size,) + aval.shape)
sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, aval, 0)
indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
else:
sharding_spec = indices = None
unsharded_aval = aval
return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
@contextmanager
def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
dynamic_axis_env = _thread_local_state.dynamic_axis_env
dynamic_axis_env.append(DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size))
try:
yield
finally:
dynamic_axis_env.pop()
def unmapped_device_count(backend=None):
dynamic_axis_env = _thread_local_state.dynamic_axis_env
mapped = prod(frame.hard_size for frame in dynamic_axis_env)
unmapped, ragged = divmod(xb.device_count(backend), mapped)
assert not ragged and unmapped > 0
return unmapped
def apply_parallel_primitive(prim, *args, **params):
# This is the op-by-op version of applying a collective primitive, like a psum
# that doesn't have a data dependence on the argument of a pmap function. In
# particular, this code gets hit when we write `axis_size = psum(1, 'i')`. We
# look up information in the dynamic axis env.
dynamic_axis_env = _thread_local_state.dynamic_axis_env
axis_name = params.pop('axes')
axis_index_groups = params.pop('axis_index_groups')
if axis_index_groups is not None:
shape = (len(axis_index_groups[0]),)
else:
logical_size = lambda frame: frame.hard_size
if isinstance(axis_name, (list, tuple)):
shape = tuple(logical_size(dynamic_axis_env[name]) for name in axis_name)
else:
shape = (logical_size(dynamic_axis_env[axis_name]),)
return parallel_pure_rules[prim](*args, shape=shape, **params)
pe.staged_out_calls.add(xla_pmap_p) # type: ignore
parallel_pure_rules = {} # type: ignore
class DynamicAxisEnvFrame(object):
__slots__ = ["name", "pmap_trace", "hard_size"]

View File

@ -86,20 +86,7 @@ def _sharded_callable(
logging.vlog(2, "in_parts: %s", in_parts)
logging.vlog(2, "local_in_parts: %s", local_in_parts)
if config.omnistaging_enabled:
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
fun, global_abstract_args)
else:
in_pvals = [pe.PartialVal.unknown(aval) for aval in global_abstract_args]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore
instantiate=False, bottom=True) # type: ignore
# TODO(skye): add tests for equationless jaxpr cases
if not jaxpr.eqns and all(outvar.aval is core.abstract_unit
for outvar in jaxpr.outvars):
return lambda *_: [
const if pv is None else core.unit for pv, const in out_pvals
]
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(fun, global_abstract_args)
if xb.get_backend().platform not in ["tpu", "gpu"]:
# TODO(skye): fall back to regular jit?
@ -182,7 +169,6 @@ def _sharded_callable(
handle_args = partial(pxla.shard_args, compiled.local_devices(),
input_indices)
assert config.omnistaging_enabled
handle_outs = _avals_to_results_handler(nrep, local_nparts, # type: ignore
local_out_parts, local_out_avals)
return partial(_execute_spatially_partitioned, compiled, handle_args,
@ -467,29 +453,3 @@ def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
A new version of ``x`` with the specified sharding applied.
"""
return sharding_constraint_p.bind(x, partitions=partitions)
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _pvals_to_results_handler, _pval_to_result_handler
def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
handlers = [_pval_to_result_handler(npart, parts, out_pval)
for parts, out_pval in safe_zip(partitions, out_pvals)] # type: ignore
def handler(out_bufs):
return [h(bufs) for h, bufs in zip(handlers, out_bufs)]
return handler
def _pval_to_result_handler(npart, parts, pval):
pv, const = pval
if pv is None:
raise NotImplementedError # TODO(skye): handle constant outputs
else:
if pv is not core.abstract_unit:
spec = pxla.partitioned_sharding_spec(npart, parts, pv)
indices = pxla.spec_to_indices(pv.shape, spec)
else:
spec = indices = None
return pxla.aval_to_result_handler(spec, indices, pv)

View File

@ -649,25 +649,16 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = unzip2(arg_specs)
if config.omnistaging_enabled:
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
if any(isinstance(c, core.Tracer) for c in consts):
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
else:
pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore
fun, pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
if any(isinstance(c, core.Tracer) for c in consts):
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
jaxpr = apply_outfeed_rewriter(jaxpr)
nreps = jaxpr_replicas(jaxpr)
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = device.platform if device else backend
if config.omnistaging_enabled:
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
else:
out_avals = [pval.get_aval() for pval in pvals]
result_handlers = map(partial(_pval_to_result_handler, device), pvals) # type: ignore
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
@ -953,16 +944,9 @@ def lower_fun(fun, multiple_results, parallel=False, with_avals=False):
wrapped_fun = lu.wrap_init(fun, params)
if not multiple_results:
wrapped_fun = _tuple_output(wrapped_fun)
if config.omnistaging_enabled:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '',
*xla_args)
else:
pvals = [pe.PartialVal.unknown(a) for a in avals]
jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
stage_out=True) # type: ignore
xla_consts = _xla_consts(c, consts)
outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '',
*xla_args)
if multiple_results or any(v.aval._num_buffers > 1 for v in jaxpr.outvars):
return xops.Tuple(c, outs)
else:
@ -980,17 +964,9 @@ def _array_aval_from_xla_shape(xla_shape):
def lower_fun_initial_style(fun):
def f(c, axis_env, name_stack, avals, backend, *xla_args, **params):
if config.omnistaging_enabled:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
name_stack, *xla_args)
else:
pvals = [pe.PartialVal.unknown(a) for a in avals]
jaxpr, _, consts = pe.trace_to_jaxpr(
lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True) # type: ignore
xla_consts = _xla_consts(c, consts)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack,
*xla_args)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
name_stack, *xla_args)
return xops.Tuple(c, outs)
return f
@ -1438,18 +1414,3 @@ def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend,
c, axis_env, in_nodes, name_stack, name="core_call",
backend=backend, call_jaxpr=call_jaxpr)
call_translations[core.call_p] = _call_translation_rule
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global _pval_to_result_handler
def _pval_to_result_handler(device, pval):
pv, const = pval
if pv is None:
const = _device_put_impl(const, device) if device else const
return lambda _: const
else:
return aval_to_result_handler(device, pv)
pe.staged_out_calls.add(xla_call_p) # type: ignore

View File

@ -283,7 +283,6 @@ from jax._src.lax.lax import (
tanh,
tanh_p,
tie_in,
tie_in_p,
top_k,
top_k_p,
transpose,

View File

@ -485,9 +485,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
def test_omnistaging(self):
# See https://github.com/google/jax/issues/5206
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
key_list = [None]
def init():
@ -1551,10 +1548,7 @@ class APITest(jtu.JaxTestCase):
def f():
return jnp.zeros((3, 4))
if config.omnistaging_enabled:
xla_comp = api.xla_computation(f)()
else:
xla_comp = api.xla_computation(f, instantiate_const_outputs=True)()
xla_comp = api.xla_computation(f)()
out_shape, = xla_comp.program_shape().result_shape().tuple_shapes()
self.assertEqual(out_shape.dimensions(), (3, 4))
@ -1608,8 +1602,6 @@ class APITest(jtu.JaxTestCase):
self.assertIn('sharding={{devices=[4,1]0,1,2,3}, {replicated}}', hlo_text)
def test_xla_computation_psum_constant(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test requires omnistaging")
f = lambda: jax.lax.psum(1, "i")
api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash
@ -1856,11 +1848,10 @@ class APITest(jtu.JaxTestCase):
api.pmap(f, 'i')(x, x)
# With in_axes and out_axes
if config.omnistaging_enabled:
for x_in, y_in, x_out, y_out in it.product(*((0, 1, 2) for _ in range(4))):
with jtu.assert_num_jit_and_pmap_compilations(1):
for _ in range(2):
api.pmap(f, 'i', in_axes=(x_in, y_in), out_axes=(x_out, y_out))(x, x)
for x_in, y_in, x_out, y_out in it.product(*((0, 1, 2) for _ in range(4))):
with jtu.assert_num_jit_and_pmap_compilations(1):
for _ in range(2):
api.pmap(f, 'i', in_axes=(x_in, y_in), out_axes=(x_out, y_out))(x, x)
# Forward-mode AD on the outside
with jtu.assert_num_jit_and_pmap_compilations(1):
@ -1962,7 +1953,7 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
self.assertLen(inner_jaxpr.eqns, 2 if config.omnistaging_enabled else 3)
self.assertLen(inner_jaxpr.eqns, 2)
self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')
@ -2043,9 +2034,6 @@ class APITest(jtu.JaxTestCase):
api.jit(func1)(2.)
def test_escaped_tracer_omnistaging(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
count = 1
@jit
@ -2067,9 +2055,6 @@ class APITest(jtu.JaxTestCase):
g()
def test_escaped_tracer_omnistaging_top_trace(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
count = 1
def f(_, __):
@ -2117,16 +2102,6 @@ class APITest(jtu.JaxTestCase):
self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)
def test_omnistaging_flag(self):
if FLAGS.jax_omnistaging:
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
self.assertLen(jaxpr.jaxpr.eqns, 1)
else:
# omnistaging can be enabled programmatically without setting the flag,
# but that shouldn't happen in tests
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
self.assertLen(jaxpr.jaxpr.eqns, 0)
def test_eval_context(self):
@jit
def f():
@ -2136,9 +2111,6 @@ class APITest(jtu.JaxTestCase):
f() # doesn't crash
def test_concrete_error_because_arg(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
@jax.jit
def f(x, y):
if x > y:
@ -2151,9 +2123,6 @@ class APITest(jtu.JaxTestCase):
f(1, 2)
def test_concrete_error_because_const(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
@jax.jit
def f():
assert jnp.add(1, 1) > 0
@ -2165,18 +2134,12 @@ class APITest(jtu.JaxTestCase):
def test_xla_computation_zeros_doesnt_device_put(self):
raise unittest.SkipTest("broken test") # TODO(mattjj): fix
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
with jtu.count_device_put() as count:
api.xla_computation(lambda: jnp.zeros(3))()
self.assertEqual(count[0], 0)
def test_join_concrete_arrays_with_omnistaging(self):
# https://github.com/google/jax/issues/4622
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
x = jnp.array([1., 2., 3.])
y = jnp.array([1., 2., 4.])
@ -2223,9 +2186,6 @@ class APITest(jtu.JaxTestCase):
self.assertIsInstance(x, jax.interpreters.xla.Token)
def test_leak_checker_catches_a_jit_leak(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
lst = []
@ -2238,9 +2198,6 @@ class APITest(jtu.JaxTestCase):
f(3)
def test_leak_checker_catches_a_pmap_leak(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
lst = []
@ -2253,9 +2210,6 @@ class APITest(jtu.JaxTestCase):
f(np.ones(1))
def test_leak_checker_catches_a_grad_leak(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
lst = []
@ -2267,9 +2221,6 @@ class APITest(jtu.JaxTestCase):
api.grad(f)(3.)
def test_leak_checker_avoids_false_positives(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
@jit
def f(x):
@ -2285,9 +2236,6 @@ class APITest(jtu.JaxTestCase):
api.vmap(f)(np.ones((1, 1))) # doesn't crash
def test_leak_checker_catches_a_scan_leak(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
lst = []
@ -2297,17 +2245,11 @@ class APITest(jtu.JaxTestCase):
lax.scan(to_scan, 1., np.arange(3.))
def test_leak_checker_avoids_false_positives_scan(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
to_scan = lambda c, x: (jnp.sin(c), None)
lax.scan(to_scan, 1., np.arange(3.)) # doesn't crash
def test_leak_checker_avoids_false_positives_scan_jvp(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
to_scan = lambda c, x: (c, None)
@ -2316,9 +2258,6 @@ class APITest(jtu.JaxTestCase):
api.jvp(f, (3.,), (1.,)) # doesn't crash
def test_leak_checker_avoids_false_positives_scan_vmap(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
to_scan = lambda c, _: (1., None)
@ -2328,9 +2267,6 @@ class APITest(jtu.JaxTestCase):
f(np.arange(5.)) # doesn't crash
def test_leak_checker_avoids_false_positives_scan_vmap_2(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
to_scan = lambda c, _: (c, None)
@ -2340,9 +2276,6 @@ class APITest(jtu.JaxTestCase):
f(np.arange(5.)) # doesn't crash
def test_leak_checker_catches_a_sublevel_leak(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
with jax.checking_leaks():
@jit
def f(x):
@ -2804,25 +2737,8 @@ class RematTest(jtu.JaxTestCase):
jax.grad(scan_bug)(1.0) # doesn't crash
def test_remat_jit_static_argnum(self):
# https://github.com/google/jax/issues/2833
if config.omnistaging_enabled:
raise unittest.SkipTest("test only works without omnistaging") # see next test
def f(a_bool, y):
if a_bool:
return y + 1
else:
return y
api.jit(api.remat(f, concrete=True), static_argnums=0)(True, 1) # no crash
def test_remat_jit_static_argnum_omnistaging(self):
# https://github.com/google/jax/issues/2833
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging") # see previous test
def named_call(f):
def named_f(*args):
f_ = lu.wrap_init(lambda: (f(*args),))
@ -2894,9 +2810,6 @@ class RematTest(jtu.JaxTestCase):
def test_escaped_tracer_remat(self):
# b/169779185
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def f():
seq = [jnp.zeros([])]
def g():
@ -2926,18 +2839,11 @@ class JaxprTest(jtu.JaxTestCase):
def fun(x):
return (x, 1., np.zeros(1))
if config.omnistaging_enabled:
expected = """
{ lambda a ; b.
let
in (b, 1.0, a) }
"""
else:
expected = """
{ lambda b ; a.
let
in (a, 1.0, b) }
"""
expected = """
{ lambda a ; b.
let
in (b, 1.0, a) }
"""
jaxpr = api.make_jaxpr(fun)(0.)
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
@ -2949,43 +2855,21 @@ class JaxprTest(jtu.JaxTestCase):
lambda xt: xt + x,
x + 2.,
lambda xf: xf - x)
if config.omnistaging_enabled:
expected = """
{ lambda ; a.
let b = ge a 0.0
c = add a 1.0
d = add a 2.0
e = convert_element_type[ new_dtype=int32
weak_type=False ] b
f = cond[ branches=( { lambda ; e_ a b c.
let d = sub c a
in (d,) }
{ lambda ; a f_ b c.
let d = add b a
in (d,) } )
linear=(False, False, False, False) ] e a a c d
in (f,) }
"""
else:
expected = """
{ lambda ; a.
let b = ge a 0.0
c = convert_element_type[ new_dtype=int32
weak_type=False ] b
d = convert_element_type[ new_dtype=float32
weak_type=False ] a
e = convert_element_type[ new_dtype=float32
weak_type=False ] a
f = add a 1.0
g = add a 2.0
h = cond[ branches=( { lambda ; e_ c a b.
let d = sub b c
in (d,) }
{ lambda ; c f_ a b.
let d = add a c
in (d,) } )
linear=(False, False, False, False) ] c d e f g
in (h,) }
expected = """
{ lambda ; a.
let b = ge a 0.0
c = add a 1.0
d = add a 2.0
e = convert_element_type[ new_dtype=int32
weak_type=False ] b
f = cond[ branches=( { lambda ; e_ a b c.
let d = sub c a
in (d,) }
{ lambda ; a f_ b c.
let d = add b a
in (d,) } )
linear=(False, False, False, False) ] e a a c d
in (f,) }
"""
jaxpr = api.make_jaxpr(f)(3.)
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
@ -3005,18 +2889,12 @@ class JaxprTest(jtu.JaxTestCase):
self.assertEqual(shape_tree, expected)
def test_make_jaxpr_axis_env(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def f(x):
return x - lax.psum(x, 'i')
jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2)
self.assertIn('psum', str(jaxpr))
def test_make_jaxpr_named(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def f(x):
return x - lax.psum(x, 'i')
@ -3276,9 +3154,6 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_closed_over_tracers_error_message(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def f(x):
@api.custom_jvp
def g(y):
@ -3326,9 +3201,6 @@ class CustomJVPTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def test_nondiff_arg_hiding_jvp_tracer(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def f(x):
@partial(api.custom_jvp, nondiff_argnums=(0,))
def g(h, x):
@ -3600,9 +3472,6 @@ class CustomJVPTest(jtu.JaxTestCase):
def test_nondiff_argnums_vmap_tracer(self):
# https://github.com/google/jax/issues/3964
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
@partial(jax.custom_jvp, nondiff_argnums=(0, 2))
def sample(shape, param, seed):
return jax.random.uniform(key=seed, shape=shape, minval=param)
@ -3622,9 +3491,6 @@ class CustomJVPTest(jtu.JaxTestCase):
(1.,), (1.,))
def test_fun_with_nested_calls_2(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
def call(f, *args):
f = api.custom_jvp(f)
f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents)))
@ -3647,9 +3513,6 @@ class CustomJVPTest(jtu.JaxTestCase):
api.vmap(fun_with_nested_calls_2)(jnp.arange(3.))
def test_closure_with_vmap(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test only works with omnistaging")
# https://github.com/google/jax/issues/3822
alpha = np.float32(2.)
@ -4788,9 +4651,6 @@ class InvertibleADTest(jtu.JaxTestCase):
@jtu.ignore_warning(message="Values that an @invertible function closes")
def test_invertible_basic(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("Test requires omnistaging")
def f(x):
return (jnp.exp(x) * 4) * x

View File

@ -15,7 +15,6 @@
import itertools as it
import numpy as np
from unittest import skipIf
from absl.testing import absltest
from absl.testing import parameterized
@ -976,8 +975,6 @@ class BatchingTest(jtu.JaxTestCase):
(lax.pmin, jnp.min)]
for vmap_names in [('i',), ('i', 'j'), ('j', 'i')]
for collective_names in it.permutations(vmap_names))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testCommAssocCollective(self, collective, bulk_op, vmap_names, collective_names):
x = jnp.arange(3 * 4 * 5, dtype=jnp.float32).reshape((3, 4, 5))
@ -993,8 +990,6 @@ class BatchingTest(jtu.JaxTestCase):
if collective is lax.psum:
jtu.check_grads(f, (x,), 2, eps=1)
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testPPermute(self):
nelem = 10
ntests = 10
@ -1013,8 +1008,6 @@ class BatchingTest(jtu.JaxTestCase):
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAllToAll(self, vmap_axis, split_axis, concat_axis):
shape = (4, 4, 4, 4)
x = np.arange(np.prod(shape)).reshape(shape)
@ -1029,8 +1022,6 @@ class BatchingTest(jtu.JaxTestCase):
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
for split_axis, concat_axis, vmap_axis in it.product(range(2), range(2), range(3)))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAllToAllSplitAxis(self, vmap_axis, split_axis, concat_axis):
shape = (4, 4, 4)
x = np.arange(np.prod(shape)).reshape(shape)
@ -1084,16 +1075,12 @@ class BatchingTest(jtu.JaxTestCase):
jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -2))(
np.arange(5), 7)
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAxisIndex(self):
x = np.arange(10)
self.assertAllClose(
vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x),
x - np.arange(x.shape[0]))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testCollectivePdot(self):
def f(x, y):
return lax.pdot(x, y, 'i')
@ -1110,8 +1097,6 @@ class BatchingTest(jtu.JaxTestCase):
z = vmap(f, axis_name='i', in_axes=(0, 0), out_axes=None)(x, y)
self.assertAllClose(z, jnp.dot(x.T, y))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testCollectivePdotBatching(self):
def f(x, y):
return lax.pdot(x, y, 'i')
@ -1122,8 +1107,6 @@ class BatchingTest(jtu.JaxTestCase):
zs = vmap(vmap(f, axis_name='i', in_axes=(1, 0), out_axes=None))(xs, ys)
self.assertAllClose(zs, jnp.einsum('nij,njk->nik', xs, ys))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testPdotJvp(self):
def f(x, y):
return lax.pdot(x, y, 'i')
@ -1139,8 +1122,6 @@ class BatchingTest(jtu.JaxTestCase):
self.assertAllClose(z, jnp.dot(x, y))
self.assertAllClose(z_dot, jnp.dot(x_dot, y) + jnp.dot(x, y_dot))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testPdotVjp(self):
def f(x, y):
return lax.pdot(x, y, 'i')
@ -1165,8 +1146,6 @@ class BatchingTest(jtu.JaxTestCase):
y = vmap(f)(a=jnp.array([1]), b=jnp.array([2])) # doesn't work
self.assertAllClose(x, y)
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAllGatherToUnmapped(self):
def f(x):
return lax.all_gather(x, axis_name='i')
@ -1175,8 +1154,6 @@ class BatchingTest(jtu.JaxTestCase):
# Original mapped axis becomes first axis of unmapped return value.
self.assertAllClose(vmap(f, axis_name='i', in_axes=1, out_axes=None)(x), x.T)
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testBatchedAllGather(self):
def f(x):
return lax.all_gather(x, axis_name='i')
@ -1188,8 +1165,6 @@ class BatchingTest(jtu.JaxTestCase):
res = vmap(vmap(f, axis_name='j'), axis_name='i', out_axes=None)(x)
self.assertAllClose(res, x.T)
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAllGatherVjp(self):
def f(x):
return lax.all_gather(x, axis_name='i')
@ -1201,8 +1176,6 @@ class BatchingTest(jtu.JaxTestCase):
x_bar, = vmap(lambda x, y_bar: vjp(f, x)[1](y_bar), axis_name='i')(x, y_bar)
self.assertAllClose(x_bar, np.sum(y_bar, axis=0))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAllGatherOfConst(self):
def f(x):
a = lax.all_gather(jnp.ones_like(x), axis_name='i')
@ -1226,8 +1199,6 @@ class BatchingTest(jtu.JaxTestCase):
for shape in [(7,), (5, 8)]
for axis in range(len(shape))
)
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)

View File

@ -112,8 +112,6 @@ class DebugNaNsTest(jtu.JaxTestCase):
@jtu.ignore_warning(message=".*is an experimental.*")
def testXmap(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
f = jax.experimental.maps.xmap(
lambda x: 0. / x,

View File

@ -32,7 +32,6 @@ zip, unsafe_zip = safe_zip, zip
jax.config.parse_flags_with_absl()
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
class DJaxTests(jtu.JaxTestCase):
def test_identity_typechecks(self):
@ -79,7 +78,6 @@ class DJaxTests(jtu.JaxTestCase):
djax.typecheck_jaxpr(jaxpr)
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
@skipIf(jax.config.x64_enabled, "only 32bit for now")
class DJaxXLATests(jtu.JaxTestCase):
@ -123,7 +121,6 @@ class DJaxXLATests(jtu.JaxTestCase):
self.assertAllClose(np.array(ans), expected, check_dtypes=False)
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
@skipIf(jax.config.x64_enabled, "only 32bit for now")
class DJaxADTests(jtu.JaxTestCase):
@ -160,7 +157,6 @@ class DJaxADTests(jtu.JaxTestCase):
self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
@skipIf(jax.config.x64_enabled, "only 32bit for now")
class DJaxBatchingTests(jtu.JaxTestCase):

View File

@ -20,7 +20,7 @@ import re
import threading
import time
from typing import Callable, Optional, Sequence
from unittest import SkipTest, skipIf
from unittest import SkipTest
from absl.testing import absltest
from absl.testing import parameterized
@ -262,8 +262,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
9.00 )""", testing_stream.output)
testing_stream.reset()
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_with_result_no_arg(self):
def tap_func(arg, transforms):
testing_stream.write(f"called tap_func with {arg}")
@ -278,8 +276,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
testing_stream.output)
testing_stream.reset()
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_result_unused(self):
def tap_func(arg, transforms):
testing_stream.write(f"called tap_func with {arg}")
@ -415,8 +411,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
11""", testing_stream.output)
testing_stream.reset()
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_jit_result_unused(self):
"""We can id_print even if we don't use the result."""
@ -593,8 +587,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_cond(self, with_jit=False):
"""A conditional"""
@ -630,8 +622,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
dict(testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_while_cond(self, with_jit=False):
def func(x):
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
@ -713,8 +703,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
testcase_name=f"_with_jit_{with_jit}",
with_jit=with_jit)
for with_jit in [True, False]))
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_scan_cond(self, with_jit=True):
def func(x):
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
@ -864,8 +852,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
4""", testing_stream.output)
testing_stream.reset()
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_jvp(self):
jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,))
res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
@ -882,9 +868,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
testing_stream.reset()
def test_tap_grad_primal_unused(self):
if not config.omnistaging_enabled:
raise SkipTest("Test requires omnistaging")
# The output of id_print is not needed for backwards pass
def func(x):
return 2. * hcb.id_print(x * 3., what="x * 3",
@ -948,9 +931,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
testing_stream.reset()
def test_tap_grad_grad(self):
if not config.omnistaging_enabled:
raise SkipTest("Test requires omnistaging")
def func(x):
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
return x * (y * 3.)
@ -976,8 +956,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
2.00""", testing_stream.output)
testing_stream.reset()
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_grad_pytree(self):
def func(x):
x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair",
@ -1000,8 +978,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
0.00 )""", testing_stream.output)
testing_stream.reset()
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_jvp_float0(self):
def f(x, yint):
x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint))
@ -1010,8 +986,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
res = api.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0)))
self.assertAllClose((6., 0.6), res)
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_grad_float0(self):
def func(x, yint):
x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream)
@ -1147,8 +1121,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
[2 2 2 3 4]""", testing_stream.output)
testing_stream.reset()
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_transforms(self):
def power(x, n):
x, n = hcb.id_print((x, n), output_stream=testing_stream)
@ -1315,8 +1287,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
testing_stream.reset()
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_jvp_pmap_vmap(self):
# A matrix M[ijk] = i * 100 + j * 10 * k
nr_devices = len(devices())
@ -1537,8 +1507,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
what: ct_b
1.""", testing_stream.output)
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_tap_mask(self):
@partial(api.mask, in_shapes=['n'], out_shape='')
@ -1714,9 +1682,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
def test_tap_named_call(self):
if not config.omnistaging_enabled:
raise SkipTest("Test requires omnistaging")
def tap_scalar(init, do_print=False):
@partial(api.named_call, name="step")
def step(acc, step_nr):
@ -1775,8 +1740,6 @@ class HostCallbackCallTest(jtu.JaxTestCase):
res_inside = fun(2, use_outside=False)
self.assertAllClose(res_inside, fun(2, use_outside=True))
@skipIf(not config.omnistaging_enabled,
"test works only with omnistaging enabled")
def test_call_no_result(self):
def f_outside(arg):
self.call_log_testing_stream(lambda x: None, arg,

View File

@ -76,7 +76,7 @@ class InfeedTest(jtu.JaxTestCase):
y, token = lax.infeed(
token, shape=jax.ShapedArray((3, 4), jnp.float32))
token = lax.outfeed(token, y + np.float32(1))
return x - 1 if config.omnistaging_enabled else lax.tie_in(token, x - 1)
return x - 1
x = np.float32(7.5)
y = np.random.randn(3, 4).astype(np.float32)
@ -101,7 +101,7 @@ class InfeedTest(jtu.JaxTestCase):
def f(n):
token = lax.create_token(n)
token = lax.fori_loop(0, n, doubler, token)
return n if config.omnistaging_enabled else lax.tie_in(token, n)
return n
device = jax.local_devices()[0]
n = 10

View File

@ -18,7 +18,7 @@ from functools import partial
import itertools
import operator
import re
from unittest import SkipTest, skipIf
from unittest import SkipTest
import textwrap
from absl.testing import absltest
@ -335,7 +335,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = np.array([4, 3, 4, 3])
self.assertAllClose(ans, expected, check_dtypes=False)
@skipIf(not config.omnistaging_enabled, "test only works with omnistaging")
def testWhileLoopAxisIndexBatched(self):
def fun(x):
return lax.while_loop(lambda x: x < lax.axis_index('i'), lambda x: x + 2, x)
@ -996,7 +995,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
@skipIf(not config.omnistaging_enabled, "test only works with omnistaging")
def testCondGradVmapNan(self):
eps = 1e-3
@ -2564,8 +2562,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
*_, ext_res = vjp_fun.args[0].args[0]
self.assertIsInstance(ext_res, xla.DeviceArray)
@skipIf(not config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def test_scan_vmap_collectives(self):
def scan_f(state, x):
s = lax.psum(state, 'i') * x

View File

@ -2291,8 +2291,6 @@ class LaxTest(jtu.JaxTestCase):
lambda x, y: dict(x=x['x'] + y['x']), [0])
def test_select_jvp_complexity(self):
if not config.omnistaging_enabled:
raise SkipTest("test requires omnistaging")
jaxpr = jax.make_jaxpr(lambda x: jax.jvp(lambda x: lax.select(True, x, x),
(x,), (1.,)))(1.)
self.assertLen(jaxpr.jaxpr.eqns, 2)
@ -2510,8 +2508,6 @@ class LaxNamedShapeTest(jtu.JaxTestCase):
self.assertEqual(out, expected)
def test_abstract_eval_collective(self):
if not config.omnistaging_enabled:
raise SkipTest("test requires omnistaging")
with core.extend_axis_env('i', 10, None):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5})

View File

@ -66,11 +66,10 @@ class MetadataTest(jtu.JaxTestCase):
self.assertRegex(hlo, 'op_type="cos"')
self.assertRegex(hlo, 'op_type="mul"')
# TODO(mattjj,jekbradbury): update these tests post-omnistaging
if not config.omnistaging_enabled:
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\('
'jvp\\(foo\\)\\)\\)/mul"')
# self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
# self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
# self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\('
# 'jvp\\(foo\\)\\)\\)/mul"')
def test_cond_metadata(self):
def true_fun(x):

View File

@ -120,10 +120,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
x2_uncommitted = jnp.array([2, 3])
z1, z2, z3 = jax.jit(lambda x, y: (y, 1, x))(x_uncommitted, x2_uncommitted)
self.assert_uncommitted_to_device(z1, devices[0])
if config.omnistaging_enabled:
self.assert_uncommitted_to_device(z2, devices[0])
else:
self.assertIs(z2, 1)
self.assert_uncommitted_to_device(z2, devices[0])
self.assert_uncommitted_to_device(z3, devices[0])

View File

@ -20,7 +20,7 @@ import gc
import os
from random import shuffle
from typing import Optional, cast
from unittest import SkipTest, skipIf
from unittest import SkipTest
import warnings
import weakref
@ -51,17 +51,11 @@ prev_xla_flags = None
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
def all_bdims(*shapes, pmap):
if pmap and not config.omnistaging_enabled:
bdims = ((None, 0) for shape in shapes)
else:
bdims = (it.chain([cast(Optional[int], None)],
range(len(shape) + 1))
for shape in shapes)
bdims = (it.chain([cast(Optional[int], None)], range(len(shape) + 1))
for shape in shapes)
return (t for t in it.product(*bdims) if not all(e is None for e in t))
def out_bdims(shape, pmap):
if pmap and not config.omnistaging_enabled:
return (0,)
return (d[0] for d in all_bdims(shape, pmap=pmap) if d[0] is not None)
@ -157,9 +151,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testTrees(self):
if not config.omnistaging_enabled:
self.skipTest("all_to_all doesn't work without omnistaging")
ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0)
def protate(x, axis_name):
n = lax.psum(1, axis_name)
@ -217,9 +208,6 @@ class PmapTest(jtu.JaxTestCase):
for split_axis, concat_axis in it.product(range(2), range(2)))
@ignore_slow_all_to_all_warning()
def testAllToAll(self, split_axis, concat_axis):
if not config.omnistaging_enabled:
self.skipTest("all_to_all doesn't work without omnistaging")
pmap_in_axis = 0
shape = (xla_bridge.device_count(),) * 3
x = np.arange(np.prod(shape)).reshape(shape)
@ -240,9 +228,6 @@ class PmapTest(jtu.JaxTestCase):
for split_axis, concat_axis in it.product(range(2), range(2)))
@ignore_slow_all_to_all_warning()
def testAllToAllSplitAxis(self, split_axis, concat_axis):
if not config.omnistaging_enabled:
self.skipTest("all_to_all doesn't work without omnistaging")
if xla_bridge.device_count() < 4:
raise SkipTest("test requires at least four devices")
pmap_in_axis = 0
@ -620,8 +605,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testGradOfGather(self):
if not config.omnistaging_enabled:
self.skipTest("all_to_all doesn't work without omnistaging")
@partial(pmap, axis_name='i')
def f(x):
return lax.all_gather(x, axis_name='i')
@ -893,29 +876,18 @@ class PmapTest(jtu.JaxTestCase):
device_count = xla_bridge.device_count()
f = pmap(lambda x: 3)
x = jnp.arange(device_count + 1)
if config.omnistaging_enabled:
self.assertRaisesRegex(
ValueError,
(r"compiling computation that requires \d+ logical devices, "
r"but only \d+ XLA devices are available .*"),
lambda: f(x))
self.assertRaisesRegex(
ValueError,
(r"compiling computation that requires \d+ logical devices, "
r"but only \d+ XLA devices are available .*"),
lambda: f(x))
# TODO(mattjj): test error message with explicit devices
# f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
# x = jnp.arange(2)
# self.assertRaisesRegex(
# ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
# r"local devices are available.", lambda: f(x))
else:
self.assertRaisesRegex(
ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
r"local devices are available.", lambda: f(x))
f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
x = jnp.arange(2)
self.assertRaisesRegex(
ValueError, "Cannot replicate across 2 replicas because only 1 "
"local devices are available.", lambda: f(x))
# TODO(mattjj): test error message with explicit devices
# f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
# x = jnp.arange(2)
# self.assertRaisesRegex(
# ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
# r"local devices are available.", lambda: f(x))
def testNestedPmapConstant(self):
if xla_bridge.device_count() == 1:
@ -967,35 +939,22 @@ class PmapTest(jtu.JaxTestCase):
f = pmap(pmap(lambda x: 3))
shape = (2, xla_bridge.device_count() // 2 + 1, 3)
x = jnp.arange(prod(shape)).reshape(shape)
if config.omnistaging_enabled:
self.assertRaisesRegex(
ValueError,
(r"compiling computation that requires \d+ logical devices, "
r"but only \d+ XLA devices are available .*"),
lambda: f(x))
self.assertRaisesRegex(
ValueError,
(r"compiling computation that requires \d+ logical devices, "
r"but only \d+ XLA devices are available .*"),
lambda: f(x))
# TODO(mattjj): check error message with explicit devices
# if xla_bridge.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
# shape = (2, xla_bridge.device_count() // 2, 3)
# x = jnp.arange(prod(shape)).reshape(shape)
# self.assertRaisesRegex(
# ValueError,
# (r"compiling computation that requires \d+ replicas, "
# r"but only \d+ XLA devices are available"),
# lambda: f(x))
else:
self.assertRaisesRegex(
ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
r"local devices are available.", lambda: f(x))
if xla_bridge.device_count() > 1:
f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
shape = (2, xla_bridge.device_count() // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
self.assertRaisesRegex(
ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
r"local devices are available.", lambda: f(x))
# TODO(mattjj): check error message with explicit devices
# if xla_bridge.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
# shape = (2, xla_bridge.device_count() // 2, 3)
# x = jnp.arange(prod(shape)).reshape(shape)
# self.assertRaisesRegex(
# ValueError,
# (r"compiling computation that requires \d+ replicas, "
# r"but only \d+ XLA devices are available"),
# lambda: f(x))
def testCollectiveConstant(self):
device_count = xla_bridge.device_count()
@ -1049,8 +1008,6 @@ class PmapTest(jtu.JaxTestCase):
self.assertAllClose(f('i')(x), expected_j.T, check_dtypes=False)
def testAxisIndexNd(self):
if not config.omnistaging_enabled:
self.skipTest("axis_index doesn't work without omnistaging")
device_count = xla_bridge.device_count()
if device_count < 4:
raise SkipTest("test requires at least four devices")
@ -1149,8 +1106,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testPswapaxes(self):
if not config.omnistaging_enabled:
self.skipTest("all_to_all doesn't work without omnistaging")
device_count = xla_bridge.device_count()
shape = (device_count, 3, device_count, 5)
x = np.arange(prod(shape)).reshape(shape)
@ -1161,8 +1116,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testGradOfPswapaxes(self):
if not config.omnistaging_enabled:
self.skipTest("all_to_all doesn't work without omnistaging")
device_count = xla_bridge.device_count()
shape = (device_count, 1, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
@ -1179,8 +1132,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testAllToAllReplicaGroups(self):
if not config.omnistaging_enabled:
self.skipTest("all_to_all doesn't work without omnistaging")
# If num_devices = 4, these would be the inputs/outputs:
# input = [[0, 1], [2, 3], [4, 5], [6, 7]]
# axis_index_groups = [[0, 1], [2, 3]]
@ -1210,8 +1161,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testGradOfAllToAllReplicaGroups(self):
if not config.omnistaging_enabled:
self.skipTest("all_to_all doesn't work without omnistaging")
device_count = xla_bridge.device_count()
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
@ -1254,7 +1203,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapBatchMatmul(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
@ -1264,7 +1212,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapBatchMatmulJit(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
@ -1274,7 +1221,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapPsumConstant(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
def f(_):
return lax.psum(1, 'i')
@ -1284,7 +1230,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapPsum(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
def f(x):
return x / lax.psum(x, 'i')
@ -1294,7 +1239,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapAxisIndex(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
def f(x):
return x * lax.axis_index('i')
@ -1304,7 +1248,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapOfJit(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
n = 4 * xla_bridge.device_count()
def f(x):
return 3 * x
@ -1342,7 +1285,6 @@ class PmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapDevicePersistence(self):
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
device_count = xla_bridge.device_count()
shape = (2 * 2 * device_count, 2, 3)
@ -1651,8 +1593,6 @@ class PmapTest(jtu.JaxTestCase):
mapped_fn(indices) # doesn't crash
@ignore_xmap_warning()
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testPdotBasic(self):
num_devices = jax.device_count()
@ -1679,8 +1619,6 @@ class PmapTest(jtu.JaxTestCase):
for axis in range(len(shape))
)
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
if not config.omnistaging_enabled:
self.skipTest("test requires omnistaging")
if xla_bridge.device_count() < shape[axis]:
raise SkipTest(f"test requires at least {shape[axis]} devices")
if (jtu.device_under_test() == 'cpu' and
@ -1757,8 +1695,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
{"testcase_name": "_collective={}".format(collective.__name__).replace(" ", ""),
"collective": collective}
for collective in [lax.psum, lax.pmean, lax.pmax, lax.pmin])
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testCollectivesWithVmap(self, collective):
def f(map1, map2):
@partial(map1, axis_name='i')
@ -1775,8 +1711,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
self.assertAllClose(f(jax.pmap, jax.vmap)(x, x), y)
self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), y)
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testPPermuteWithVmap(self):
perm = [(0, 1), (1, 0)]
@ -1796,8 +1730,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
@ignore_slow_all_to_all_warning()
def testAllToAllInVmap(self, split_axis, concat_axis, vmap_axis):
def f(x):
@ -1867,8 +1799,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
{"testcase_name": f"_split={split_axis}_concat={concat_axis}",
"split_axis": split_axis, "concat_axis": concat_axis}
for split_axis, concat_axis in it.product(range(3), range(3)))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
@ignore_slow_all_to_all_warning()
def testAllToAllVsVmap(self, split_axis, concat_axis):
def f(x):
@ -1884,8 +1814,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
"axes": axes, "split_axis": split_axis, "concat_axis": concat_axis}
for axes, split_axis, concat_axis
in it.product([('i', 'j'), ('j', 'i')], range(3), range(3)))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
@ignore_slow_all_to_all_warning()
def testAllToAllMultipleAxesVsVmap(self, axes, split_axis, concat_axis):
raise SkipTest("multi-axis all_to_all broken after #4835") # TODO(mattjj,apaszke)
@ -1900,8 +1828,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
self.assertAllClose(pmap(pmap(f, axis_name='j'), axis_name='i')(x),
vmap(vmap(f, axis_name='j'), axis_name='i')(x))
@skipIf(not jax.config.omnistaging_enabled,
"vmap collectives only supported when omnistaging is enabled")
def testAllGatherWithVmap(self):
def f(map2):
@partial(jax.pmap, axis_name='i')
@ -2062,7 +1988,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
expected = np.sin(x + 3.)
self.assertAllClose(ans, expected, check_dtypes=False)
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
def testPmapInAxesBasic(self):
@partial(pmap, in_axes=(1, 2))
def f(x, y):
@ -2075,7 +2000,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
self.assertAllClose(f(x, y),
jnp.sin(x.transpose((1, 0, 2)) + y.transpose((2, 0, 1))))
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
def testPmapInAxesGrad(self):
def f(x, y, z):
return jnp.sin(x + y + z)
@ -2096,7 +2020,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
self.assertAllClose(jax.grad(lambda args: fp(*args).sum())((x, y, z)),
jax.grad(lambda args: fv(*args).sum())((x, y, z)))
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
def testPmapOutAxesBasic(self):
@partial(pmap, in_axes=(1, None), out_axes=(2, None))
def f(x, y):
@ -2109,7 +2032,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
self.assertAllClose(f(x, y),
(jnp.sin(x.transpose((1, 0, 2)) + y).transpose((1, 2, 0)), y * 2))
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{in_axes}_{out_axes}",
"in_axes": in_axes, "out_axes": out_axes}
@ -2128,7 +2050,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
jtu.check_grads(pmap(f, in_axes=in_axes, out_axes=out_axes), args,
order=2, atol=2e-2, rtol=2e-2, eps=1e-3)
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
def testPmapPostProcess(self):
def mk_case(map_fun):
def f(x, y):

View File

@ -752,8 +752,6 @@ class LaxRandomTest(jtu.JaxTestCase):
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
def testNoOpByOpUnderHash(self):
if not config.omnistaging_enabled:
raise SkipTest("test requires omnistaging")
def fail(*args, **kwargs): assert False
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
try:
@ -916,8 +914,6 @@ class LaxRandomTest(jtu.JaxTestCase):
random.choice(key, 5, 2, replace=True)
def test_eval_shape_big_random_array(self):
if not config.omnistaging_enabled:
raise SkipTest("after deleting lazy constants, requires omnistaging")
def f(x):
return random.normal(random.PRNGKey(x), (int(1e12),))
with jax.enable_checks(False): # check_jaxpr will materialize array
@ -977,8 +973,6 @@ class LaxRandomTest(jtu.JaxTestCase):
api.jit(random.PRNGKey)(seed)
def test_random_split_doesnt_device_put_during_tracing(self):
if not config.omnistaging_enabled:
raise SkipTest("test requires omnistaging")
key = random.PRNGKey(1).block_until_ready()
with jtu.count_device_put() as count:
api.jit(random.split)(key)

View File

@ -47,8 +47,6 @@ class X64ContextTests(jtu.JaxTestCase):
{"testcase_name": "_jit={}".format(jit), "jit": jit}
for jit in ["python", "cpp", None]))
def test_make_array(self, jit):
if jit == "cpp" and not config.omnistaging_enabled:
self.skipTest("cpp_jit requires omnistaging")
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
dtype_start = func().dtype
with enable_x64():
@ -64,9 +62,6 @@ class X64ContextTests(jtu.JaxTestCase):
"enable_or_disable": f
} for jit in ["python", "cpp", None] for f in [enable_x64, disable_x64]))
def test_correctly_capture_default(self, jit, enable_or_disable):
if jit == "cpp" and not config.omnistaging_enabled:
self.skipTest("cpp_jit requires omnistaging")
# The fact we defined a jitted function with a block with a different value
# of `config.enable_x64` has no impact on the output.
with enable_or_disable():
@ -87,8 +82,6 @@ class X64ContextTests(jtu.JaxTestCase):
def test_near_singular_inverse(self, jit):
if jtu.device_under_test() == "tpu":
self.skipTest("64-bit inverse not available on TPU")
if jit == "cpp" and not config.omnistaging_enabled:
self.skipTest("cpp_jit requires omnistaging")
@partial(_maybe_jit, jit, static_argnums=1)
def near_singular_inverse(key, N, eps):
X = random.uniform(key, (N, N))
@ -111,8 +104,6 @@ class X64ContextTests(jtu.JaxTestCase):
{"testcase_name": "_jit={}".format(jit), "jit": jit}
for jit in ["python", "cpp", None]))
def test_while_loop(self, jit):
if jit == "cpp" and not config.omnistaging_enabled:
self.skipTest("cpp_jit requires omnistaging")
@partial(_maybe_jit, jit)
def count_to(N):
return lax.while_loop(lambda x: x < N, lambda x: x + 1.0, 0.0)

View File

@ -235,10 +235,7 @@ def schedules(sizes: Dict[str, int]
class XMapTestCase(jtu.BufferDonationTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
pass
# A mixin that enables SPMD lowering tests
@ -689,9 +686,6 @@ class NamedNNTest(XMapTestCase):
class NewPrimitiveTest(XMapTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
def testGatherPositional(self):
x = jnp.arange(27).reshape((9, 3))
@ -1031,11 +1025,6 @@ class PDotTests(XMapTestCase):
class XMapErrorTest(jtu.JaxTestCase):
def setUp(self):
if not config.omnistaging_enabled:
raise SkipTest("xmap requires omnistaging")
super().setUp()
@ignore_xmap_warning()
@with_mesh([('x', 2)])
def testRepeatedAxisResource(self):