mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6275 from google:omnistaging-forever
PiperOrigin-RevId: 365681256
This commit is contained in:
commit
6d0b8327c7
@ -322,6 +322,7 @@ class TracerArrayConversionError(JAXTypeError):
|
||||
abstract values, you may want to read :ref:`faq-different-kinds-of-jax-values`.
|
||||
"""
|
||||
def __init__(self, tracer: "core.Tracer"):
|
||||
# TODO(mattjj, jakevdp): use tracer._origin_msg() here
|
||||
super().__init__(
|
||||
"The numpy.ndarray conversion method __array__() was called on "
|
||||
f"the JAX Tracer object {tracer}")
|
||||
|
@ -45,7 +45,7 @@ from jax.interpreters import masking
|
||||
from jax.lib import xla_bridge as xb
|
||||
from jax.lib import xla_client
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (partial, unzip2, unzip3, unzip4, safe_map, safe_zip,
|
||||
from jax._src.util import (partial, unzip2, unzip3, safe_map, safe_zip,
|
||||
split_list, cache, extend_name_stack)
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
||||
treedef_children, treedef_tuple, tree_multimap,
|
||||
@ -481,17 +481,12 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
|
||||
params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr,
|
||||
body_nconsts=body_nconsts, body_jaxpr=body_jaxpr)
|
||||
|
||||
if config.omnistaging_enabled:
|
||||
partial_eval_jaxpr = pe.partial_eval_jaxpr
|
||||
else:
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
|
||||
|
||||
cond_consts_uk, body_consts_uk, carry_init_uk = split_list(unknowns, [cond_nconsts, body_nconsts])
|
||||
# Fixpoint computation of unknown carry. Each iteration promotes
|
||||
# at least one carry to unknown. We need one last iteration to prepare the jaxpr.
|
||||
carry_uk = carry_init_uk
|
||||
for _ in range(1 + len(carry_uk)):
|
||||
body_jaxpr_known, _, carry_out_uk = partial_eval_jaxpr( # type: ignore
|
||||
body_jaxpr_known, _, carry_out_uk = pe.partial_eval_jaxpr( # type: ignore
|
||||
body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk)
|
||||
if carry_out_uk == carry_uk:
|
||||
break
|
||||
@ -500,7 +495,7 @@ def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts:
|
||||
else:
|
||||
assert False, "Fixpoint not reached"
|
||||
|
||||
cond_jaxpr_known, _, cond_uk = partial_eval_jaxpr( # type: ignore
|
||||
cond_jaxpr_known, _, cond_uk = pe.partial_eval_jaxpr( # type: ignore
|
||||
cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False)
|
||||
|
||||
if cond_uk[0] or all([not uk for uk in unknowns]) or all(unknowns):
|
||||
@ -850,11 +845,6 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
|
||||
unknowns = [t.pval[0] is not None for t in tracers]
|
||||
index_uk, *ops_uk = unknowns
|
||||
|
||||
if config.omnistaging_enabled:
|
||||
partial_eval_jaxpr = pe.partial_eval_jaxpr
|
||||
else:
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
|
||||
|
||||
if index_uk:
|
||||
# When the branch index is unknown, we stage out the whole cond.
|
||||
params = dict(branches=branches, linear=linear)
|
||||
@ -862,13 +852,13 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
|
||||
|
||||
branches_out_uks = []
|
||||
for branch_jaxpr in branches:
|
||||
_, _, out_uks = partial_eval_jaxpr(branch_jaxpr, ops_uk, instantiate=False)
|
||||
_, _, out_uks = pe.partial_eval_jaxpr(branch_jaxpr, ops_uk, instantiate=False)
|
||||
branches_out_uks.append(out_uks)
|
||||
out_uks = [any(uks) for uks in zip(*branches_out_uks)]
|
||||
|
||||
branches_1, branches_2, branch_res_avals = [], [], []
|
||||
for branch_jaxpr in branches:
|
||||
branch_jaxpr_1, branch_jaxpr_2, _ = partial_eval_jaxpr(
|
||||
branch_jaxpr_1, branch_jaxpr_2, _ = pe.partial_eval_jaxpr(
|
||||
branch_jaxpr, ops_uk, instantiate=out_uks)
|
||||
branch_num_res = len(branch_jaxpr_1.out_avals) - len(out_uks)
|
||||
|
||||
@ -1552,22 +1542,11 @@ def _prune_zeros(ts):
|
||||
|
||||
def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
jaxpr, linear, unroll):
|
||||
if not config.omnistaging_enabled and trace.main.trace_type is pe.StagingJaxprTrace: # type: ignore
|
||||
params = dict(reverse=reverse, length=length, num_consts=num_consts,
|
||||
num_carry=num_carry, jaxpr=jaxpr, linear=linear,
|
||||
unroll=unroll)
|
||||
return trace.default_process_primitive(scan_p, tracers, params)
|
||||
|
||||
num_ys = len(jaxpr.out_avals) - num_carry
|
||||
|
||||
unknowns = [t.pval[0] is not None for t in tracers]
|
||||
const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])
|
||||
|
||||
if config.omnistaging_enabled:
|
||||
partial_eval_jaxpr = pe.partial_eval_jaxpr
|
||||
else:
|
||||
partial_eval_jaxpr = partial(pe.partial_eval_jaxpr, trace_type=trace.main.trace_type)
|
||||
|
||||
# Fixpoint computation of which carry are unknown (not a constant): either
|
||||
# unknown from init, or the carry out is unknown. Each iteration promotes
|
||||
# at least one carry to unknown. We need at most len(carry) iterations,
|
||||
@ -1576,7 +1555,7 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
|
||||
carry_uk = init_uk
|
||||
for _ in range(1 + len(carry_uk)):
|
||||
unknowns = const_uk + carry_uk + xs_uk
|
||||
jaxpr_1, jaxpr_2, out_uk = partial_eval_jaxpr(
|
||||
jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
|
||||
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
|
||||
carry_uk_out = out_uk[:num_carry]
|
||||
if carry_uk_out == carry_uk:
|
||||
@ -1744,12 +1723,7 @@ def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr):
|
||||
return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_avals + res2_avals)
|
||||
|
||||
def _make_closed_jaxpr(traceable: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
|
||||
else:
|
||||
pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
jaxpr, pvals_out, consts = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
|
||||
out_avals, _ = unzip2(pvals_out)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, in_avals)
|
||||
return core.ClosedJaxpr(jaxpr, consts)
|
||||
|
||||
|
||||
@ -2623,63 +2597,3 @@ def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
|
||||
ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul)
|
||||
ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min)
|
||||
ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global _initial_style_open_jaxpr, _initial_style_jaxpr, \
|
||||
_initial_style_jaxprs_with_common_consts
|
||||
|
||||
@cache()
|
||||
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
|
||||
wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore
|
||||
return jaxpr, out_pvals, consts, out_tree
|
||||
|
||||
@cache()
|
||||
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
|
||||
jaxpr, out_pvals, consts, out_tree = _initial_style_open_jaxpr(
|
||||
fun, in_tree, in_avals)
|
||||
closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
return closed_jaxpr, consts, out_tree()
|
||||
|
||||
@cache()
|
||||
def _initial_style_jaxprs_with_common_consts(funs: Sequence[Callable],
|
||||
in_tree, in_avals):
|
||||
# When staging the branches of a conditional into jaxprs, constants are
|
||||
# extracted from each branch and converted to jaxpr arguments. To use the
|
||||
# staged jaxprs as the branches to a conditional *primitive*, we need for
|
||||
# their (input) signatures to match. This function "joins" the staged jaxprs:
|
||||
# for each one, it makes another that accepts *all* constants, but only uses
|
||||
# those that it needs (dropping the rest).
|
||||
|
||||
jaxprs, all_out_pvals, all_consts, all_out_trees = unzip4([
|
||||
_initial_style_open_jaxpr(fun, in_tree, in_avals) for fun in funs])
|
||||
|
||||
newvar = core.gensym(jaxprs, suffix='_')
|
||||
all_const_avals = tuple(
|
||||
tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
|
||||
for consts in all_consts)
|
||||
unused_const_vars = tuple(
|
||||
tuple(newvar(aval) for aval in const_avals)
|
||||
for const_avals in all_const_avals)
|
||||
|
||||
def pad_jaxpr_constvars(i, jaxpr):
|
||||
prefix = util.concatenate(unused_const_vars[:i])
|
||||
suffix = util.concatenate(unused_const_vars[i+1:])
|
||||
constvars = prefix + jaxpr.constvars + suffix
|
||||
return core.Jaxpr(constvars=constvars, invars=jaxpr.invars,
|
||||
outvars=jaxpr.outvars, eqns=jaxpr.eqns)
|
||||
|
||||
def type_and_const_convert_jaxpr(jaxpr, out_pvals):
|
||||
return core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
||||
|
||||
jaxprs = [pad_jaxpr_constvars(i, jaxpr) for i, jaxpr in enumerate(jaxprs)]
|
||||
closed_jaxprs = _map(type_and_const_convert_jaxpr, jaxprs, all_out_pvals)
|
||||
|
||||
return (tuple(closed_jaxprs),
|
||||
tuple(util.concatenate(all_consts)),
|
||||
tuple(out_tree() for out_tree in all_out_trees))
|
||||
|
@ -1767,8 +1767,6 @@ def full_like(x: Array, fill_value: Array, dtype: Optional[DType] = None,
|
||||
fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape)
|
||||
weak_type = dtype is None and dtypes.is_weakly_typed(x)
|
||||
dtype = dtype or _dtype(x)
|
||||
if not config.omnistaging_enabled:
|
||||
fill_value = tie_in(x, fill_value)
|
||||
return full(fill_shape, _convert_element_type(fill_value, dtype, weak_type))
|
||||
|
||||
|
||||
@ -4017,10 +4015,7 @@ def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes):
|
||||
if type(t) is ad_util.Zero:
|
||||
return [ad_util.Zero(operand.aval)] + [None] * len(start_indices)
|
||||
else:
|
||||
if config.omnistaging_enabled:
|
||||
zeros = full(operand_shape, 0, operand_dtype)
|
||||
else:
|
||||
zeros = full(operand_shape, tie_in(t, _zero(t)))
|
||||
zeros = full(operand_shape, 0, operand_dtype)
|
||||
return ([dynamic_update_slice(zeros, t, start_indices)] +
|
||||
[None] * len(start_indices))
|
||||
|
||||
@ -4319,10 +4314,7 @@ def _gather_transpose_rule(t, operand, start_indices, *, dimension_numbers,
|
||||
if type(t) is ad_util.Zero:
|
||||
out = ad_util.Zero(operand.aval)
|
||||
else:
|
||||
if config.omnistaging_enabled:
|
||||
zeros = full(operand_shape, _zero(t))
|
||||
else:
|
||||
zeros = full(operand_shape, tie_in(t, _zero(t)))
|
||||
zeros = full(operand_shape, _zero(t))
|
||||
scatter_dnums = ScatterDimensionNumbers(
|
||||
update_window_dims=dimension_numbers.offset_dims,
|
||||
inserted_window_dims=dimension_numbers.collapsed_slice_dims,
|
||||
@ -5924,8 +5916,6 @@ def _top_k_jvp(primals, tangents, *, k):
|
||||
gather_indices = []
|
||||
for i in range(rank-1):
|
||||
_iota = iota(k_idxs.dtype, idx_shape[i])
|
||||
if not config.omnistaging_enabled:
|
||||
_iota = tie_in(operand, _iota)
|
||||
_iota = broadcast_in_dim(_iota, gather_index_shape, (i,))
|
||||
gather_indices.append(_iota)
|
||||
gather_indices.append(reshape(k_idxs, gather_index_shape))
|
||||
@ -5979,14 +5969,7 @@ def create_token(_=None):
|
||||
|
||||
The argument is ignored. It exists for backward compatibility.
|
||||
"""
|
||||
if config.omnistaging_enabled:
|
||||
return create_token_p.bind()
|
||||
else:
|
||||
x = _
|
||||
if x is None:
|
||||
raise ValueError(
|
||||
'create_token needs a tie-in operand unless omnistaging is enabled.')
|
||||
return create_token_p.bind(stop_gradient(x))
|
||||
return create_token_p.bind()
|
||||
|
||||
create_token_p = Primitive("create_token")
|
||||
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
|
||||
@ -6601,71 +6584,3 @@ def _canonicalize_axis(axis, num_dims):
|
||||
if axis < 0:
|
||||
axis = axis + num_dims
|
||||
return axis
|
||||
|
||||
|
||||
tie_in_p = Primitive('tie_in')
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global tie_in
|
||||
|
||||
def tie_in(x: Array, y: Array) -> Array:
|
||||
"""Returns the value of ``y`` but with a fake data dependence on ``x``.
|
||||
|
||||
When staging to XLA (e.g. running under jit or pmap), values that don't depend
|
||||
on computation inputs are computed op-by-op, and folded into the XLA
|
||||
computation as constants.
|
||||
|
||||
``tie_in`` provides a way to explicitly stage values into the computation.
|
||||
When staging to XLA and ``x`` is already staged, then the result of ``tie_in``
|
||||
is ``y``, but staged to XLA. Downstream use of the result will also be staged
|
||||
to XLA.
|
||||
|
||||
For example, ``lax.sin(const)`` would be constant-folded if ``const`` is
|
||||
a constant array, but ``lax.sin(lax.tie_in(x, const))``, will be staged to
|
||||
XLA as long as ``x`` is staged to XLA.
|
||||
"""
|
||||
if config.omnistaging_enabled:
|
||||
return y
|
||||
else:
|
||||
return tie_in_p.bind(x, y)
|
||||
|
||||
# If lax has already been imported, we need to monkey-patch the
|
||||
# lax/__init__.py import of tie_in. If not (i.e. if this is running at lax
|
||||
# module creation time) then we'll get an import error.
|
||||
try:
|
||||
jax.lax.tie_in = tie_in
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def _tie_in_transpose_rule(t, x, y):
|
||||
if ad.is_undefined_primal(x):
|
||||
return [ad_util.Zero(x.aval), t]
|
||||
else:
|
||||
return [ad_util.Zero.from_value(x), t]
|
||||
|
||||
def _tie_in_batch_rule(batched_args, batch_dims):
|
||||
y = tie_in(*batched_args)
|
||||
_, bdim_y = batch_dims
|
||||
return y, bdim_y
|
||||
|
||||
def _tie_in_impl(x, y):
|
||||
core.check_valid_jaxtype(x)
|
||||
core.check_valid_jaxtype(y)
|
||||
return y
|
||||
|
||||
def _tie_in_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
x_dot, y_dot = tangents
|
||||
if type(y_dot) is ad_util.Zero or core.get_aval(y_dot).dtype is dtypes.float0:
|
||||
return y, y_dot # skip tying in in this case
|
||||
else:
|
||||
return ad.linear_jvp(tie_in_p, primals, tangents)
|
||||
|
||||
tie_in_p.def_impl(_tie_in_impl)
|
||||
tie_in_p.def_abstract_eval(lambda x, y: raise_to_shaped(y))
|
||||
xla.translations[tie_in_p] = lambda c, x, y: y
|
||||
ad.primitive_jvps[tie_in_p] = _tie_in_jvp
|
||||
ad.primitive_transposes[tie_in_p] = partial(ad.linear_transpose2, _tie_in_transpose_rule)
|
||||
batching.primitive_batchers[tie_in_p] = _tie_in_batch_rule
|
||||
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
|
||||
|
@ -25,18 +25,15 @@ import numpy as np
|
||||
from jax import core
|
||||
from jax import dtypes
|
||||
from jax import tree_util
|
||||
from jax._src import source_info_util
|
||||
from . import lax
|
||||
from jax.core import ShapedArray, AxisName, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.util import partial, unzip2, prod, canonicalize_axis, safe_map, moveaxis
|
||||
from jax.lib import xla_client as xc
|
||||
from jax.lib import xla_bridge as xb
|
||||
from jax.config import config
|
||||
from jax._src.numpy import lax_numpy
|
||||
|
||||
xops = xc.ops
|
||||
@ -1044,15 +1041,9 @@ def all_gather(x, axis_name, *, axis_index_groups=None):
|
||||
[ 4. 5. 6. 7.]]
|
||||
"""
|
||||
axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
|
||||
# The all_gather primitive doesn't work when omni-staging is disabled.
|
||||
if not config.omnistaging_enabled:
|
||||
return _all_gather_via_psum(x, all_gather_dimension=0, axis_name=axis_name,
|
||||
axis_index_groups=axis_index_groups, axis_size=axis_size)
|
||||
|
||||
def bind(x):
|
||||
return all_gather_p.bind(x, all_gather_dimension=0, axis_name=axis_name,
|
||||
axis_index_groups=axis_index_groups, axis_size=axis_size)
|
||||
|
||||
bind = partial(all_gather_p.bind, all_gather_dimension=0,
|
||||
axis_name=axis_name, axis_index_groups=axis_index_groups,
|
||||
axis_size=axis_size)
|
||||
return tree_util.tree_map(bind, x)
|
||||
|
||||
def _all_gather_via_psum(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
|
||||
@ -1337,42 +1328,3 @@ xla.parallel_translations[pgather_p] = _pgather_parallel_translation
|
||||
batching.primitive_batchers[pgather_p] = _pgather_batcher
|
||||
batching.collective_rules[pgather_p] = _pgather_collective_batcher
|
||||
core.axis_substitution_rules[pgather_p] = partial(_subst_all_names_in_param, 'axes')
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global axis_index
|
||||
|
||||
psum_p.bind = partial(core.Primitive.bind, psum_p) # type: ignore
|
||||
psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p)) # type: ignore
|
||||
pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args) # type: ignore
|
||||
|
||||
def _axis_index_bind(*, axis_name):
|
||||
dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
|
||||
frame = dynamic_axis_env[axis_name]
|
||||
sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
|
||||
nreps = dynamic_axis_env.nreps
|
||||
trace = frame.pmap_trace
|
||||
|
||||
out_aval = _axis_index_abstract_eval(
|
||||
nreps=nreps, sizes=sizes, axis_name=axis_name)
|
||||
out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
|
||||
eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
|
||||
dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
|
||||
source_info_util.current())
|
||||
out_tracer.recipe = eqn
|
||||
|
||||
return out_tracer
|
||||
|
||||
def _axis_index_translation_rule(c, nreps, sizes, axis_name):
|
||||
div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
|
||||
mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
|
||||
def _axis_index_abstract_eval(*, nreps, sizes, axis_name):
|
||||
return ShapedArray((), np.int32, named_shape={axis_name: sizes[-1]})
|
||||
|
||||
axis_index_p.def_custom_bind(_axis_index_bind)
|
||||
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
|
||||
xla.translations[axis_index_p] = _axis_index_translation_rule
|
||||
|
@ -2306,12 +2306,7 @@ def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
|
||||
|
||||
normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims)
|
||||
normalizer = normalizer - ddof
|
||||
if config.omnistaging_enabled:
|
||||
normalizer_mask = lax.le(normalizer, 0)
|
||||
else:
|
||||
zero = lax.full_like(normalizer, 0, shape=())
|
||||
normalizer_mask = lax.le(normalizer, zero)
|
||||
|
||||
normalizer_mask = lax.le(normalizer, 0)
|
||||
result = nansum(centered, axis, keepdims=keepdims)
|
||||
result = where(normalizer_mask, nan, result)
|
||||
divisor = where(normalizer_mask, 1, normalizer)
|
||||
@ -4353,8 +4348,6 @@ def _take_along_axis(arr, indices, axis):
|
||||
j += 1
|
||||
elif idx_shape[i] != 1:
|
||||
iota = lax.iota(_dtype(indices), out_shape[i])
|
||||
if not config.omnistaging_enabled:
|
||||
iota = lax.tie_in(arr, iota)
|
||||
iota = lax.broadcast_in_dim(iota, gather_index_shape, (j,))
|
||||
gather_indices.append(iota)
|
||||
slice_sizes.append(1)
|
||||
|
@ -57,18 +57,6 @@ def unzip3(xyzs):
|
||||
zs.append(z)
|
||||
return tuple(xs), tuple(ys), tuple(zs)
|
||||
|
||||
def unzip4(wxyzs):
|
||||
ws = []
|
||||
xs = []
|
||||
ys = []
|
||||
zs = []
|
||||
for w, x, y, z in wxyzs:
|
||||
ws.append(w)
|
||||
xs.append(x)
|
||||
ys.append(y)
|
||||
zs.append(z)
|
||||
return tuple(ws), tuple(xs), tuple(ys), tuple(zs)
|
||||
|
||||
def subvals(lst, replace):
|
||||
lst = list(lst)
|
||||
for i, v in replace:
|
||||
|
43
jax/api.py
43
jax/api.py
@ -206,7 +206,7 @@ def jit(
|
||||
[-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748
|
||||
-0.85743 -0.78232 0.76827 0.59566 ]
|
||||
"""
|
||||
if FLAGS.experimental_cpp_jit and config.omnistaging_enabled:
|
||||
if FLAGS.experimental_cpp_jit:
|
||||
return _cpp_jit(fun, static_argnums, device, backend, donate_argnums)
|
||||
else:
|
||||
return _python_jit(fun, static_argnums, device, backend, donate_argnums)
|
||||
@ -644,16 +644,10 @@ def xla_computation(fun: Callable,
|
||||
"xla_computation in_parts", in_tree.children()[0], in_parts))
|
||||
jaxtree_fun, out_tree = flatten_fun(f, in_tree)
|
||||
avals = map(shaped_abstractify, args_flat)
|
||||
if config.omnistaging_enabled:
|
||||
with ExitStack() as stack:
|
||||
for axis_name, size in axis_env or []:
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
|
||||
else:
|
||||
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
jaxtree_fun, pvals, instantiate=True, stage_out=True) # type: ignore
|
||||
out_avals = [raise_to_shaped(pval.get_aval()) for pval in out_pvals]
|
||||
with ExitStack() as stack:
|
||||
for axis_name, size in axis_env or []:
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
|
||||
if out_parts is None:
|
||||
@ -1549,10 +1543,6 @@ def pmap(
|
||||
local_axis_size = _mapped_axis_size(in_tree, args, in_axes_flat, "pmap", kws=True)
|
||||
for arg in args: _check_arg(arg)
|
||||
flat_fun, out_tree = flatten_fun(f, in_tree)
|
||||
if not config.omnistaging_enabled and out_axes != 0:
|
||||
raise ValueError("out_axes supported only with omnistaging enabled")
|
||||
if not config.omnistaging_enabled and any(in_axis not in {None, 0} for in_axis in in_axes_flat):
|
||||
raise ValueError("in_axes other than 0 and None only supported with omnistaging enabled")
|
||||
if any(out_axis is None for out_axis in tree_flatten(out_axes)):
|
||||
raise NotImplementedError("None out_axes in pmap are not supported yet")
|
||||
# NOTE: We don't put out_tree() in the closure, because it's (1) non-hashable,
|
||||
@ -2048,19 +2038,10 @@ def make_jaxpr(fun: Callable,
|
||||
jax_args, in_tree = tree_flatten((args, kwargs))
|
||||
jaxtree_fun, out_tree = flatten_fun(wrapped, in_tree)
|
||||
in_avals = map(shaped_abstractify, jax_args)
|
||||
if config.omnistaging_enabled:
|
||||
with ExitStack() as stack:
|
||||
for axis_name, size in axis_env or []:
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
|
||||
else:
|
||||
if axis_env:
|
||||
raise NotImplementedError(
|
||||
"axis_env argument to make_jaxpr only supported with omnistaging.")
|
||||
in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
jaxtree_fun, in_pvals, instantiate=True, stage_out=True) # type: ignore
|
||||
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
|
||||
with ExitStack() as stack:
|
||||
for axis_name, size in axis_env or []:
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, in_avals)
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
if return_shape:
|
||||
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
|
||||
@ -2488,11 +2469,7 @@ class CustomTransformsFunction(object):
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
|
||||
in_pvals = [pe.PartialVal.unknown(raise_to_shaped(core.get_aval(x)))
|
||||
for x in args_flat]
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
|
||||
else:
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
|
||||
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
|
||||
in_tree=in_tree, out_tree=out_tree(),
|
||||
num_consts=len(consts))
|
||||
|
@ -59,11 +59,10 @@ class Config:
|
||||
self.FLAGS = NameSpace(self.read, self.update)
|
||||
self.use_absl = False
|
||||
self._contextmanager_flags = set()
|
||||
self._update_hooks = {}
|
||||
|
||||
# TODO(mattjj): delete these when only omnistaging is available
|
||||
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
|
||||
self._omnistaging_disablers = []
|
||||
self._update_hooks = {}
|
||||
|
||||
def update(self, name, val):
|
||||
if self.use_absl:
|
||||
@ -160,13 +159,8 @@ class Config:
|
||||
"see https://github.com/google/jax/blob/master/design_notes/omnistaging.md.\n"
|
||||
"To remove this warning, unset the JAX_OMNISTAGING environment variable.")
|
||||
|
||||
def register_omnistaging_disabler(self, disabler):
|
||||
if self.omnistaging_enabled:
|
||||
self._omnistaging_disablers.append(disabler)
|
||||
|
||||
def enable_omnistaging(self):
|
||||
if not self.omnistaging_enabled:
|
||||
raise Exception("can't re-enable omnistaging after it's been disabled")
|
||||
return # TODO(mattjj): remove all callers
|
||||
|
||||
def disable_omnistaging(self):
|
||||
warnings.warn(
|
||||
@ -341,6 +335,7 @@ FLAGS = flags.FLAGS
|
||||
already_configured_with_absl = False
|
||||
|
||||
|
||||
# TODO(mattjj): remove all uses of this flag
|
||||
flags.DEFINE_bool(
|
||||
'jax_omnistaging',
|
||||
bool_env('JAX_OMNISTAGING', True),
|
||||
|
160
jax/core.py
160
jax/core.py
@ -1819,166 +1819,6 @@ def pp_kv_pairs(kv_pairs):
|
||||
else:
|
||||
return pp('')
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
|
||||
new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env, \
|
||||
eval_context, extra_jit_context
|
||||
|
||||
class TraceStack:
|
||||
upward: List[MainTrace]
|
||||
downward: List[MainTrace]
|
||||
|
||||
def __init__(self):
|
||||
self.upward = []
|
||||
self.downward = []
|
||||
|
||||
def next_level(self, bottom: bool) -> int:
|
||||
if bottom:
|
||||
return - (len(self.downward) + 1)
|
||||
else:
|
||||
return len(self.upward)
|
||||
|
||||
def push(self, main_trace: MainTrace, bottom: bool) -> None:
|
||||
if bottom:
|
||||
self.downward.append(main_trace)
|
||||
else:
|
||||
self.upward.append(main_trace)
|
||||
|
||||
def pop(self, bottom: bool) -> None:
|
||||
if bottom:
|
||||
self.downward.pop()
|
||||
else:
|
||||
self.upward.pop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return 'Trace stack\n{} ---\n{}'.format(
|
||||
map(' {}\n'.format, self.upward[::-1]),
|
||||
map(' {}\n'.format, self.downward))
|
||||
|
||||
def copy(self):
|
||||
new = TraceStack()
|
||||
new.upward = self.upward[:]
|
||||
new.downward = self.downward[:]
|
||||
return new
|
||||
|
||||
class TraceState:
|
||||
trace_stack: TraceStack
|
||||
substack: List[Sublevel]
|
||||
initial_style: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.trace_stack = TraceStack() # type: ignore
|
||||
self.substack = [Sublevel(0)]
|
||||
self.initial_style = False
|
||||
|
||||
def copy(self):
|
||||
new = TraceState()
|
||||
new.trace_stack = self.trace_stack.copy()
|
||||
new.substack = self.substack[:]
|
||||
new.initial_style = self.initial_style
|
||||
return new
|
||||
|
||||
def extra_jit_context(trace_stack):
|
||||
return None
|
||||
|
||||
thread_local_state = ThreadLocalState()
|
||||
|
||||
def reset_trace_state() -> bool:
|
||||
"Reset the global trace state and return True if it was already clean."
|
||||
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
|
||||
thread_local_state.trace_state.trace_stack.downward or
|
||||
thread_local_state.trace_state.trace_stack.upward):
|
||||
thread_local_state.trace_state.__init__() # type: ignore
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@contextmanager
|
||||
def new_main(trace_type: Type[Trace], bottom=False, **payload) -> Generator[MainTrace, None, None]:
|
||||
level = thread_local_state.trace_state.trace_stack.next_level(bottom)
|
||||
main = MainTrace(level, trace_type, **payload)
|
||||
thread_local_state.trace_state.trace_stack.push(main, bottom)
|
||||
|
||||
try:
|
||||
yield main
|
||||
finally:
|
||||
thread_local_state.trace_state.trace_stack.pop(bottom)
|
||||
|
||||
if config.jax_check_tracer_leaks:
|
||||
t = ref(main)
|
||||
del main
|
||||
if t() is not None:
|
||||
print(thread_local_state.trace_state.trace_stack)
|
||||
raise Exception('Leaked trace {}'.format(t()))
|
||||
|
||||
def find_top_trace(xs) -> Optional[Trace]:
|
||||
top_trace = max((x._trace for x in xs if isinstance(x, Tracer)),
|
||||
key=attrgetter('level'), default=None)
|
||||
return top_trace and top_trace.main.with_cur_sublevel()
|
||||
|
||||
@contextmanager
|
||||
def eval_context():
|
||||
yield # dummy implementation for forward compatibility
|
||||
|
||||
def bind(self, *args, **kwargs):
|
||||
assert not config.jax_enable_checks or all(isinstance(arg, Tracer)
|
||||
or valid_jaxtype(arg) for arg in args), args
|
||||
top_trace = find_top_trace(args)
|
||||
if top_trace is None:
|
||||
return self.impl(*args, **kwargs)
|
||||
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
|
||||
if self.multiple_results:
|
||||
return map(full_lower, out_tracer)
|
||||
else:
|
||||
return full_lower(out_tracer)
|
||||
Primitive.bind = bind # type: ignore
|
||||
|
||||
def call_bind(primitive: Union['CallPrimitive', 'MapPrimitive'],
|
||||
fun: lu.WrappedFun, *args, **params):
|
||||
out_axes_transforms = _IgnoreElemList()
|
||||
if primitive.map_primitive:
|
||||
out_axes_thunk = params['out_axes_thunk']
|
||||
# The new thunk depends deterministically on the old thunk and the wrapped function.
|
||||
# Any caching already has to include the wrapped function as part of the key, so we
|
||||
# only use the previous thunk for equality checks.
|
||||
@as_hashable_function(closure=out_axes_thunk)
|
||||
def new_out_axes_thunk():
|
||||
out_axes = out_axes_thunk()
|
||||
for t in out_axes_transforms:
|
||||
out_axes = t(out_axes)
|
||||
return out_axes
|
||||
params = dict(params, out_axes_thunk=new_out_axes_thunk)
|
||||
params_tuple = tuple(params.items())
|
||||
top_trace = find_top_trace(args)
|
||||
level = (thread_local_state.trace_state.trace_stack.next_level(True)
|
||||
if top_trace is None else top_trace.level)
|
||||
params_tuple = tuple(params.items())
|
||||
fun, env_trace_todo = process_env_traces(
|
||||
fun, primitive, level, params_tuple, out_axes_transforms)
|
||||
if top_trace is None:
|
||||
with new_sublevel():
|
||||
outs = primitive.impl(fun, *args, **params)
|
||||
else:
|
||||
tracers = map(top_trace.full_raise, args)
|
||||
outs = primitive.process(top_trace, fun, tracers, params)
|
||||
return apply_todos(env_trace_todo(), map(full_lower, outs))
|
||||
|
||||
@contextmanager
|
||||
def extend_axis_env(axis_name, size: int, tag: Any):
|
||||
yield
|
||||
|
||||
@contextmanager
def initial_style_staging():
  """Set `trace_state.initial_style` to True for the duration of the block.

  The previous value of the flag is saved on entry and restored on exit,
  including when the `with` body raises.
  """
  trace_state = thread_local_state.trace_state
  prev, trace_state.initial_style = trace_state.initial_style, True
  try:
    yield
  finally:
    trace_state.initial_style = prev
|
||||
|
||||
# Casting float0 array to a float-valued zero array.
|
||||
def zeros_like_float0(array, dtype=None):
|
||||
if not dtype:
|
||||
|
@ -213,16 +213,8 @@ class custom_jvp:
|
||||
args_flat, in_tree = tree_flatten(dyn_args)
|
||||
flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
|
||||
if config.omnistaging_enabled:
|
||||
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
|
||||
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
|
||||
else:
|
||||
if _initial_style_staging():
|
||||
out_flat = custom_jvp_call_jaxpr(flat_fun, flat_jvp, *args_flat) # type: ignore
|
||||
out_tree = out_tree1()
|
||||
else:
|
||||
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
|
||||
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
|
||||
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
|
||||
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
|
||||
return tree_unflatten(out_tree, out_flat)
|
||||
|
||||
def _add_args(f, extra_args):
|
||||
@ -489,21 +481,10 @@ class custom_vjp:
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
|
||||
flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
|
||||
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
|
||||
if config.omnistaging_enabled:
|
||||
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
|
||||
out_trees=out_trees)
|
||||
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
|
||||
out_tree = aux if fst else aux[0]
|
||||
else:
|
||||
if _initial_style_staging():
|
||||
out_flat = custom_vjp_call_jaxpr(flat_fun, flat_fwd, flat_bwd, # type: ignore
|
||||
*args_flat, out_trees=out_trees)
|
||||
out_tree = out_tree()
|
||||
else:
|
||||
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
|
||||
*args_flat, out_trees=out_trees)
|
||||
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
|
||||
out_tree = aux if fst else aux[0]
|
||||
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat,
|
||||
out_trees=out_trees)
|
||||
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
|
||||
out_tree = aux if fst else aux[0]
|
||||
return tree_unflatten(out_tree, out_flat)
|
||||
|
||||
@partial(partial, tree_map)
|
||||
@ -662,8 +643,6 @@ def _custom_vjp_call_jaxpr_vmap(
|
||||
fwd_jaxpr_thunk=batched_fwd_jaxpr_thunk, bwd=batched_bwd,
|
||||
out_trees=out_trees, num_consts=num_consts)
|
||||
out_dims = out_dims2[0] if out_dims2 else out_dims1
|
||||
if not config.omnistaging_enabled:
|
||||
out_dims = out_dims[:len(batched_outs)]
|
||||
return batched_outs, out_dims
|
||||
batching.initial_style_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
|
||||
|
||||
@ -674,76 +653,6 @@ batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
xla.translations[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Install the non-omnistaging implementations for custom derivatives.

  Rebinds the module globals `_initial_style_jaxpr`, `custom_jvp_call_jaxpr`
  and `custom_vjp_call_jaxpr`, and overwrites `bind` / `post_process` on
  `CustomJVPCallPrimitive` and `CustomVJPCallPrimitive`.
  """
  global _initial_style_jaxpr, custom_vjp_call_jaxpr, custom_jvp_call_jaxpr

  def _initial_style_jaxpr(fun, in_avals):
    """Trace `fun` on unknown inputs of the given avals; return (jaxpr, consts)."""
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
                                         bottom=True, stage_out=False)  # type: ignore
    assert not any(isinstance(c, core.Tracer) for c in consts)
    return jaxpr, consts

  def jvp_bind(self, fun, jvp, *args):
    """`bind` for the custom-JVP call primitive: dispatch on the top trace."""
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    # `top_trace and top_trace.level` is None when there is no top trace.
    fun, env_trace_todo1 = core.process_env_traces(
        fun, self, top_trace and top_trace.level, (), None)
    jvp, env_trace_todo2 = core.process_env_traces(
        jvp, self, top_trace and top_trace.level, (), None)
    if top_trace is None:
      with core.new_sublevel():
        outs = self.impl(fun, jvp, *args)
    else:
      tracers = map(top_trace.full_raise, args)
      outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers)
    _, env_trace_todo = lu.merge_linear_aux(env_trace_todo1, env_trace_todo2)
    # Any pending env-trace work is treated as an error here.
    if env_trace_todo:
      raise core.UnexpectedTracerError
    return map(core.full_lower, outs)
  CustomJVPCallPrimitive.bind = jvp_bind  # type: ignore

  def jvp_post_process(self, trace, out_tracers, params):
    """Post-processing is unsupported in this mode; always raises."""
    raise core.UnexpectedTracerError
  CustomJVPCallPrimitive.post_process = jvp_post_process  # type: ignore

  def vjp_bind(self, fun, fwd, bwd, *args, out_trees):
    """`bind` for the custom-VJP call primitive: dispatch on the top trace."""
    args = map(core.full_lower, args)
    top_trace = core.find_top_trace(args)
    if top_trace is None:
      # No tracers involved: just run the primal function.
      outs = fun.call_wrapped(*args)
    else:
      tracers = map(top_trace.full_raise, args)
      outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers,
                                               out_trees=out_trees)
    return map(core.full_lower, outs)
  CustomVJPCallPrimitive.bind = vjp_bind  # type: ignore

  def vjp_post_process(self, trace, out_tracers, params):
    """Post-processing is unsupported in this mode; always raises."""
    raise core.UnexpectedTracerError
  CustomVJPCallPrimitive.post_process = vjp_post_process  # type: ignore

  def custom_jvp_call_jaxpr(fun: Callable, jvp: Callable, *args):
    """Stage `fun` to a closed jaxpr and bind `custom_jvp_call_jaxpr_p`.

    The JVP rule is traced lazily through a memoized thunk, on the input
    avals repeated twice (`in_avals * 2`).
    """
    in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
    # NOTE(review): the original comment says "consts can be tracers!", but
    # `_initial_style_jaxpr` above asserts that none are -- confirm which holds.
    fun_jaxpr, consts = _initial_style_jaxpr(fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
    jvp_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(jvp, in_avals * 2))
    return custom_jvp_call_jaxpr_p.bind(
        *consts, *args, fun_jaxpr=closed_fun_jaxpr,
        jvp_jaxpr_thunk=jvp_jaxpr_thunk, num_consts=len(consts))

  def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
    """Stage `fun` to a closed jaxpr and bind `custom_vjp_call_jaxpr_p`.

    The forward rule is traced lazily through a memoized thunk; `bwd` and
    `out_trees` are passed through as primitive parameters.
    """
    in_avals = [raise_to_shaped(core.get_aval(x)) for x in args]
    # NOTE(review): same apparent mismatch with the assert in
    # `_initial_style_jaxpr` as in `custom_jvp_call_jaxpr` -- confirm.
    fun_jaxpr, consts = _initial_style_jaxpr(fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
    fwd_jaxpr_thunk = pe._memoize(lambda: _initial_style_jaxpr(fwd, in_avals))
    return custom_vjp_call_jaxpr_p.bind(
        *consts, *args, fun_jaxpr=closed_fun_jaxpr,
        fwd_jaxpr_thunk=fwd_jaxpr_thunk, bwd=bwd, out_trees=out_trees,
        num_consts=len(consts))
|
||||
|
||||
|
||||
def custom_gradient(fun):
|
||||
"""Convenience function for defining custom VJP rules (aka custom gradients).
|
||||
|
||||
@ -813,11 +722,7 @@ def custom_gradient(fun):
|
||||
ans_flat, out_tree = tree_flatten((ans,))
|
||||
rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
|
||||
ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
|
||||
else:
|
||||
ans_pvals = [pe.PartialVal.unknown(a) for a in ans_avals]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(rule, ans_pvals, instantiate=True)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
|
||||
return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
|
||||
|
||||
def bwd(res, cts):
|
||||
@ -921,15 +826,8 @@ def closure_convert(fun, *example_args):
|
||||
|
||||
@cache()
|
||||
def _closure_convert_for_avals(fun, in_tree, in_avals):
|
||||
if config.omnistaging_enabled:
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
||||
else:
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
|
||||
wrapped_fun, in_pvals, instantiate=True, stage_out=False) # type: ignore
|
||||
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
|
||||
out_tree = out_tree()
|
||||
|
||||
# We only want to closure convert for constants with respect to which we're
|
||||
|
@ -27,7 +27,6 @@ import numpy as np
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax import ad_util, core, lax, xla
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.util import unzip2, wrap_name
|
||||
import jax.numpy as jnp
|
||||
import jax.linear_util as lu
|
||||
@ -274,10 +273,6 @@ def _def_passthrough(prim, argnums=(0,)):
|
||||
_def_passthrough(lax.select_p, (0, 1, 2))
|
||||
_def_passthrough(lax.broadcast_in_dim_p)
|
||||
_def_passthrough(xla.device_put_p)
|
||||
try:
|
||||
_def_passthrough(lax_internal.tie_in_p, (0, 1))
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
class _DoubleDouble:
|
||||
|
@ -152,16 +152,9 @@ def convert(fun: Callable, *,
|
||||
|
||||
def converted_fun(*args: TfVal) -> TfVal:
|
||||
# TODO: is there a better way to check if we are inside a transformation?
|
||||
if config.omnistaging_enabled:
|
||||
if not core.trace_state_clean():
|
||||
raise ValueError("convert must be used outside all JAX transformations."
|
||||
+ f"Trace state: {core.thread_local_state.trace_state}")
|
||||
else:
|
||||
if (core.thread_local_state.trace_state.trace_stack.downward or
|
||||
core.thread_local_state.trace_state.trace_stack.upward or
|
||||
core.thread_local_state.trace_state.substack != [core.Sublevel(0)]):
|
||||
raise ValueError("convert must be used outside all JAX transformations."
|
||||
+ f"Trace state: {core.thread_local_state.trace_state}")
|
||||
if not core.trace_state_clean():
|
||||
raise ValueError("convert must be used outside all JAX transformations."
|
||||
+ f"Trace state: {core.thread_local_state.trace_state}")
|
||||
|
||||
def check_arg(a):
|
||||
if not _is_tfval(a):
|
||||
@ -267,8 +260,7 @@ def _interpret_fun(fun: lu.WrappedFun,
|
||||
in_vals: Sequence[TfVal],
|
||||
in_avals: Sequence[core.AbstractValue]
|
||||
) -> Sequence[Tuple[TfVal, core.AbstractValue]]:
|
||||
new_main = core.new_base_main if config.omnistaging_enabled else core.new_main
|
||||
with new_main(TensorFlowTrace) as main: # type: ignore
|
||||
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
|
||||
fun = _interpret_subtrace(fun, main, in_avals)
|
||||
out_vals: Sequence[Tuple[TfVal, core.AbstractValue]] = fun.call_wrapped(*in_vals)
|
||||
del main
|
||||
@ -813,10 +805,6 @@ tf_not_yet_impl = [
|
||||
"call_tf",
|
||||
]
|
||||
|
||||
try:
|
||||
tf_impl[lax.tie_in_p] = lambda x, y: y
|
||||
except AttributeError:
|
||||
pass
|
||||
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
|
||||
tf_impl[ad_util.zeros_like_p] = tf.zeros_like
|
||||
|
||||
|
@ -379,7 +379,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
self.TransformConvertAndCompare(f, arg, None)
|
||||
self.TransformConvertAndCompare(f, arg, "grad")
|
||||
|
||||
@jtu.skip_on_flag('jax_omnistaging', False)
|
||||
def test_convert_nullary_func(self):
|
||||
# Even nullary functions are converted to TF (as opposed to constant-folded
|
||||
# in JAX prior to conversion).
|
||||
@ -389,7 +388,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
f_tf_graph = f_tf.get_concrete_function().graph.as_graph_def()
|
||||
self.assertIn('op: "Sin"', str(f_tf_graph))
|
||||
|
||||
@jtu.skip_on_flag('jax_omnistaging', False)
|
||||
def test_convert_of_nested_independent_jit(self):
|
||||
def func(x):
|
||||
def inner1(y):
|
||||
@ -449,7 +447,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
ValueError, "convert must be used outside all JAX transformations"):
|
||||
self.TransformConvertAndCompare(outer, np.ones((4,)), transform)
|
||||
|
||||
@jtu.skip_on_flag('jax_omnistaging', False)
|
||||
def test_name_scope(self):
|
||||
log = []
|
||||
|
||||
|
@ -682,28 +682,6 @@ def _make_add_any_harness(name, *, shapes=((2,), (2,)), dtype=np.float32):
|
||||
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.boolean):
|
||||
_make_add_any_harness("dtypes", dtype=dtype)
|
||||
|
||||
for rhs_dtype in jtu.dtypes.all:
|
||||
lhs_dtype = np.float32
|
||||
lhs_shape = (2, 3)
|
||||
rhs_shape = (4, 5)
|
||||
define(
|
||||
lax.tie_in_p,
|
||||
f"lhs={jtu.format_shape_dtype_string(lhs_shape, lhs_dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)}",
|
||||
lax.tie_in_p.bind,
|
||||
[RandArg(lhs_shape, lhs_dtype),
|
||||
RandArg(rhs_shape, rhs_dtype)],
|
||||
jax_unimplemented=[
|
||||
Limitation(
|
||||
"requires omnistaging to be disabled",
|
||||
enabled=config.omnistaging_enabled)
|
||||
],
|
||||
dtype=rhs_dtype,
|
||||
lhs_shape=lhs_shape,
|
||||
lhs_dtype=lhs_dtype,
|
||||
rhs_shape=rhs_shape,
|
||||
rhs_dtype=rhs_dtype,
|
||||
primitive=lax.tie_in_p)
|
||||
|
||||
for dtype in jtu.dtypes.all:
|
||||
shape: Tuple[int, ...] = (20, 20)
|
||||
define(
|
||||
|
@ -129,8 +129,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
all_primitives = tuple(sorted(all_primitives, key=str))
|
||||
for p in all_primitives:
|
||||
# TODO: remove tie_in once omnistaging is on by default
|
||||
if p.name == "axis_index" or p.name == "tie_in":
|
||||
if p.name == "axis_index":
|
||||
continue
|
||||
if p.name in tf_not_yet_impl:
|
||||
self.assertNotIn(
|
||||
|
@ -589,6 +589,3 @@ def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr,
|
||||
del jvp_jaxpr_thunk
|
||||
return jet(core.jaxpr_as_fun(fun_jaxpr), primals_in, series_in)
|
||||
jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule
|
||||
|
||||
|
||||
deflinear(lax.tie_in_p)
|
||||
|
@ -118,7 +118,6 @@ from jax import tree_util
|
||||
from jax import numpy as jnp
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.util import safe_map
|
||||
from jax.config import config
|
||||
|
||||
|
||||
class Scope(object):
|
||||
@ -291,25 +290,15 @@ class Scope(object):
|
||||
def start_subtrace(self):
  """Starts a nested trace, returns the Trace object.

  Pushes a new `core.MainTrace` for `pe.JaxprTrace` onto the thread-local
  trace stack and increments `self._count_subtraces`.
  """
  # TODO: This follows the __enter__ part of core.new_main.
  # The `config.omnistaging_enabled` branch and the unreachable duplicate that
  # followed it are collapsed into this single unconditional path.
  level = core.thread_local_state.trace_state.trace_stack.next_level()
  main = core.MainTrace(level, pe.JaxprTrace)
  core.thread_local_state.trace_state.trace_stack.push(main)
  self._count_subtraces += 1
  return pe.JaxprTrace(main, core.cur_sublevel())
|
||||
|
||||
def end_subtrace(self):
  """Ends the innermost nested trace.

  Pops one entry from the thread-local trace stack and decrements
  `self._count_subtraces`.
  """
  # TODO: This follows the __exit__ part of core.new_main
  # Exactly one pop per call: the conditional pop and the unconditional pop
  # that followed it would otherwise remove two entries.
  core.thread_local_state.trace_state.trace_stack.pop()
  self._count_subtraces -= 1
|
||||
|
||||
|
||||
|
@ -550,13 +550,8 @@ def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_
|
||||
# (in this case via partial_eval) before we call into backward_pass again.
|
||||
typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
|
||||
unknowns = map(is_undefined_primal, primals_in)
|
||||
if config.omnistaging_enabled:
|
||||
primal_jaxpr, tangent_jaxpr, out_unknowns = \
|
||||
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore
|
||||
else:
|
||||
primal_jaxpr, tangent_jaxpr, out_unknowns = \
|
||||
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
|
||||
trace_type=None) # type: ignore
|
||||
primal_jaxpr, tangent_jaxpr, out_unknowns = \
|
||||
pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True) # type: ignore
|
||||
|
||||
def do_transpose(primals_in, cotangents_in):
|
||||
# NOTE: This is passing in undefined primals in place of tangent arguments, but it
|
||||
@ -706,12 +701,7 @@ def defvjp_all(prim, custom_vjp):
|
||||
primals_out = [primals_out]
|
||||
out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
|
||||
ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
|
||||
else:
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
|
||||
instantiate=True)
|
||||
jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
|
||||
tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
|
||||
num_res=len(res), out_avals=out_avals)
|
||||
return primals_out + tangents_out
|
||||
@ -766,17 +756,3 @@ class CustomVJPException(Exception):
|
||||
"closed-over value into the custom_vjp function as an argument, and "
|
||||
"adapting the custom_vjp fwd and bwd rules.")
|
||||
super().__init__(msg)
|
||||
|
||||
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebind the module-global `jvp_jaxpr` to its non-omnistaging version."""
  global jvp_jaxpr

  def jvp_jaxpr(jaxpr, nonzeros, instantiate):
    """Trace the JVP of a closed jaxpr.

    Args:
      jaxpr: closed jaxpr to differentiate; must have one `in_aval` per entry
        of `nonzeros`.
      nonzeros: one flag per input; a tangent input is added only for inputs
        whose flag is truthy.
      instantiate: forwarded to `jvp`.

    Returns:
      A pair `(ClosedJaxpr, out_nonzeros)` where `out_nonzeros` is the value
      produced by the `f_jvp_traceable` auxiliary thunk.
    """
    assert len(jaxpr.in_avals) == len(nonzeros)
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
    # Inputs of the traced function: all primals, then the nonzero tangents.
    tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
    avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
    pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
    jaxpr_out, _, consts = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
    return core.ClosedJaxpr(jaxpr_out, consts), out_nonzeros()
|
||||
|
@ -48,7 +48,7 @@ def batchfun(axis_name, axis_size, in_dims, *in_vals):
|
||||
axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
|
||||
in_dims = in_dims() if callable(in_dims) else in_dims
|
||||
in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int)
|
||||
and not isinstance(core.get_aval(x), core.AbstractUnit) # non-omnistaging
|
||||
and not isinstance(core.get_aval(x), core.AbstractUnit)
|
||||
else ax for x, ax in zip(in_vals, in_dims)]
|
||||
with core.new_main(BatchTrace, axis_name=axis_name) as main:
|
||||
with core.extend_axis_env(axis_name, axis_size, main):
|
||||
@ -477,20 +477,4 @@ def _merge_bdims(x, y):
|
||||
return x # arbitrary
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebind the module-global `batch_jaxpr` to its non-omnistaging version."""
  global batch_jaxpr

  def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name):
    """Trace a batched version of a closed jaxpr.

    Args:
      jaxpr: closed jaxpr to batch.
      axis_size: size of the batch axis.
      in_batched: one flag per input; batched inputs are mapped along axis 0,
        the others are passed with dimension `None`.
      instantiate: forwarded to `batch_subtrace_instantiate`.
      axis_name: forwarded to `batchfun`.

    Returns:
      A pair `(ClosedJaxpr, out_batched)` where `out_batched` is the value
      produced by the `batch_subtrace_instantiate` auxiliary thunk.
    """
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    f, out_batched = batch_subtrace_instantiate(f, instantiate, axis_size)
    f = batchfun(f, axis_name, axis_size, [0 if b else None for b in in_batched])
    avals_in = [core.unmapped_aval(axis_size, 0, aval) if b else aval
                for aval, b in zip(jaxpr.in_avals, in_batched)]
    in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
    # The output pvals are not needed (the original unzipped them into an
    # unused local), so they are discarded directly.
    jaxpr_out, _, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
    return core.ClosedJaxpr(jaxpr_out, consts_out), out_batched()
|
||||
|
||||
|
||||
collective_rules: Dict[core.Primitive, Callable] = {}
|
||||
|
@ -26,7 +26,6 @@ from ..api_util import flatten_fun_nokwargs
|
||||
from ..tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
from .._src.util import safe_map, safe_zip, split_list
|
||||
from .. import custom_derivatives
|
||||
from ..config import config
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
@ -214,14 +213,9 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
|
||||
complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)
|
||||
|
||||
in_avals = map(abstract, primals_in + primals_out + primals_out)
|
||||
if config.omnistaging_enabled:
|
||||
# TODO: Actually we do know some of the inputs, because they might be literals!
|
||||
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
|
||||
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
|
||||
else:
|
||||
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr( # type: ignore
|
||||
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
|
||||
instantiate=True, stage_out=False) # type: ignore
|
||||
# TODO: Actually we do know some of the inputs, because they might be literals!
|
||||
ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
|
||||
complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
|
||||
assert not ivjp_jaxpr.constvars # That might happen some time, but don't bother until then
|
||||
ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])
|
||||
|
||||
@ -231,12 +225,8 @@ def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotang
|
||||
unknowns = (map(ad.is_undefined_primal, primals_in) +
|
||||
map(ad.is_undefined_primal, primals_out) +
|
||||
[False] * len(cts_in))
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
|
||||
ivjp_jaxpr, unknowns, instantiate=False) # type:ignore
|
||||
else:
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
|
||||
ivjp_jaxpr, unknowns, instantiate=False, trace_type=None) # type: ignore
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr( # type: ignore
|
||||
ivjp_jaxpr, unknowns, instantiate=False) # type:ignore
|
||||
unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
|
||||
# Make sure we're able to compute all cotangents. We don't really care if we
|
||||
# can reconstruct the primals or not, although failure to do so might result in
|
||||
@ -312,15 +302,3 @@ def get_primitive_inverse(p):
|
||||
def definverse(primitive, inverse_rule):
  """Register `inverse_rule` as the inverse of `primitive`.

  Stores the rule in `primitive_inverses` under `primitive`, replacing any
  existing entry, and returns the rule unchanged.
  """
  primitive_inverses[primitive] = inverse_rule
  return inverse_rule
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebind the module-global `_initial_style_jaxpr` to its non-omnistaging version."""
  # NOTE(review): `custom_jvp_call` is declared global but never assigned in
  # this function -- confirm whether the declaration is still needed.
  global _initial_style_jaxpr, custom_jvp_call

  def _initial_style_jaxpr(fun, in_avals):
    """Trace `fun` on unknown inputs of the given avals into a ClosedJaxpr."""
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
                                         bottom=True, stage_out=False)  # type: ignore
    # The constants are closed over below, so none of them may be a Tracer.
    assert not any(isinstance(c, core.Tracer) for c in consts)
    return core.ClosedJaxpr(jaxpr, consts)
|
||||
|
@ -17,7 +17,7 @@ from collections import namedtuple
|
||||
import contextlib
|
||||
import functools
|
||||
from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple,
|
||||
List, Union, cast, Type, no_type_check)
|
||||
List, Union, cast)
|
||||
from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
@ -163,11 +163,6 @@ class JaxprTrace(Trace):
|
||||
|
||||
# We use process_call to handle both call and map primitives.
|
||||
def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
|
||||
if not config.omnistaging_enabled:
|
||||
if (self.main.trace_type is StagingJaxprTrace # type: ignore
|
||||
and primitive in staged_out_calls): # type: ignore
|
||||
tracers = map(self.instantiate_const_abstracted, tracers)
|
||||
|
||||
if primitive in call_partial_eval_rules:
|
||||
return call_partial_eval_rules[primitive](self, primitive, f, tracers, params)
|
||||
|
||||
@ -406,15 +401,8 @@ call_param_updaters: Dict[core.Primitive, Callable] = {}
|
||||
|
||||
|
||||
def abstract_eval_fun(fun, *avals, **params):
  """Abstractly evaluate `fun` on `avals` and return the output avals.

  Args:
    fun: the function to trace; wrapped with `params` via `lu.wrap_init`.
    *avals: abstract values for the inputs.
    **params: parameters bound into the wrapped function.

  Returns:
    The output abstract values from `trace_to_jaxpr_dynamic`; each is
    asserted to be an `AbstractValue`.
  """
  # `fun` is traced exactly once. The `config.omnistaging_enabled` branch and
  # the second trace that followed it are collapsed into this single call.
  _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
  assert all(isinstance(aval, AbstractValue) for aval in avals_out)
  return avals_out
|
||||
|
||||
|
||||
@ -775,13 +763,8 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
|
||||
# Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
|
||||
in_pvals = [t.pval for t in instantiated_tracers]
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
|
||||
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
|
||||
else:
|
||||
with core.initial_style_staging(): # type: ignore
|
||||
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
|
||||
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
|
||||
jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
|
||||
f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
|
||||
|
||||
# Convert consts to inputs, since they may contain Tracer instances.
|
||||
jaxpr = convert_constvars_jaxpr(jaxpr)
|
||||
@ -792,12 +775,8 @@ def _remat_partial_eval(trace, _, f, tracers, params):
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
|
||||
in_unknowns = ([False] * len(consts) +
|
||||
[not t.is_known() for t in it.chain(env_tracers, tracers)])
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
|
||||
closed_jaxpr, in_unknowns, instantiate=False) # type: ignore
|
||||
else:
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
|
||||
closed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type) # type: ignore
|
||||
jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
|
||||
closed_jaxpr, in_unknowns, instantiate=False) # type: ignore
|
||||
out_knowns = [not b for b in out_unknowns]
|
||||
out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)
|
||||
|
||||
@ -1191,7 +1170,6 @@ def _memoize(thunk):
|
||||
|
||||
|
||||
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
|
||||
assert config.omnistaging_enabled
|
||||
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f) # type: ignore
|
||||
main.jaxpr_stack = () # type: ignore
|
||||
@ -1221,7 +1199,6 @@ def extend_jaxpr_stack(main, frame):
|
||||
main.jaxpr_stack = main.jaxpr_stack[:-1]
|
||||
|
||||
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
|
||||
assert config.omnistaging_enabled
|
||||
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
|
||||
main.source_info = fun_sourceinfo(fun.f) # type: ignore
|
||||
main.jaxpr_stack = () # type: ignore
|
||||
@ -1234,7 +1211,6 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
|
||||
# use trace_to_jaxpr directly because of an interaction with the current
|
||||
# custom_derivatives.py, which we work around by adding the EvalTrace.
|
||||
# TODO(mattjj): alias to trace_to_jaxpr after revising custom_derivatives.py
|
||||
assert config.omnistaging_enabled
|
||||
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
|
||||
return trace_to_jaxpr(fun, in_pvals)
|
||||
|
||||
@ -1247,124 +1223,3 @@ def fun_sourceinfo(fun):
|
||||
return f"{fun.__name__} at {filename}:{lineno}"
|
||||
except AttributeError:
|
||||
return "<unknown>"
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
@no_type_check
|
||||
def omnistaging_disabler() -> None:
|
||||
global trace_to_jaxpr, partial_eval_jaxpr, staged_out_calls, StagingJaxprTrace
|
||||
|
||||
def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal],
                   instantiate: Union[bool, Sequence[bool]] = False,
                   stage_out=False, bottom=False,
                   trace_type: Optional[Type[Trace]] = None,
                   ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]:
  """Traces a function into a Jaxpr, given PartialVals for inputs.

  Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the
  computation that depends on unknown inputs. The `out_pvals` are the PartialVal
  for the outputs. The intermediate values that depend only on known inputs and
  are needed to compute the output of `jaxpr` are in `consts` and are passed in
  as the constvars of the `jaxpr`. The handling of the known outputs depends on
  `instantiate`.

  When `trace_type` is not given, the trace type is `StagingJaxprTrace` if
  `stage_out` is true and `JaxprTrace` otherwise. `bottom` is forwarded to
  `core.new_main`.

  For example, given `fun` defined as follows::

    def fun(ki, ui):  # ki will be a known input in this example
      ka = ki + 2
      kb = ka + 3
      return (kb, ui + ka)

  with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only
  computation that depends on unknown inputs is `ui + ka` and will be the only
  computation in the body of the `jaxpr`. This computation depends on the known
  intermediate value `ka`, which will be computed statically. Currently, such
  constants are either embedded in the Jaxpr if they are scalars, or passed as a
  constvar to `jaxpr`, and then the value of the actual constant will be in
  `consts`:

  When `instantiate=False` we get::

    jaxpr =
      { lambda ka ; ki ui.
        let c = add ui ka
        in (*, c) }   # known outputs are `*`
    out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)]
    consts = [3]  # the constant for `ka`

  When `instantiate=True` we get::

    jaxpr =
      { lambda ka kb ; ki ui.
        let c = add ui ka
        in (kb, c) }   # known outputs are explicit
    out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)]
    consts = [3, 6]  # values for `ka` and `kb` constvars
  """
  trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace)
  with core.new_main(trace_type, bottom=bottom) as main:
    fun = trace_to_subjaxpr(fun, main, instantiate)
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    assert not env
    # Drop the local reference to the main trace before leaving the context.
    del main

  return jaxpr, out_pvals, consts
|
||||
|
||||
def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
                       instantiate: Union[bool, Sequence[bool]],
                       trace_type: Optional[Type[core.Trace]]
                       ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
  """Split a closed jaxpr into a known part and an unknown part.

  Args:
    jaxpr: the closed jaxpr to partially evaluate.
    unknowns: one flag per input of `jaxpr`; true marks the input as unknown.
    instantiate: forwarded to the inner `trace_to_jaxpr` call.
    trace_type: forwarded to the inner `trace_to_jaxpr` call.

  Returns:
    `(jaxpr_1, jaxpr_2, uk_out)`: `jaxpr_1` produces the known outputs followed
    by the residuals, `jaxpr_2` takes the original inputs followed by the
    residuals, and `uk_out` flags which outputs of `jaxpr` are unknown.
  """
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))

  # Filled exactly once by `fun` below, when the outer trace calls it.
  cell = []
  def fun(*vals):
    pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
             for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)]
    jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate,
                                                    trace_type=trace_type)
    out_pvs_2, out_consts_2 = unzip2(out_pvals_2)
    cell.append((out_pvs_2, jaxpr_2, len(consts_2)))
    # Outputs of the known computation: known output values, then residuals.
    return out_consts_2 + consts_2

  # The abstract_unit here doesn't really matter, because trace_to_jaxpr completely ignores
  # the avals, and it will never actually reach any primitives, because the `fun` above will
  # execute the jaxpr with the right avals (it reconstructs `pvals` inside).
  pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval)
           for aval, uk in zip(jaxpr.in_avals, unknowns)]
  jaxpr_1, _, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True)
  (out_pvs_2, jaxpr_2, num_res), = cell
  assert len(jaxpr_2.constvars) == num_res

  #   jaxpr :: a -> b
  # jaxpr_1 :: a1 -> [b1, res]
  # jaxpr_2 :: res | a2 -> b2
  # jaxpr_2 :: [a2, res] -> b2
  jaxpr_2 = convert_constvars_jaxpr(jaxpr_2)
  # After conversion the residuals come first; rotate them to the end.
  jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
  for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
    if not unknown:
      var.aval = abstract_unit

  uk_out = [pv is not None for pv in out_pvs_2]

  return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
  """Handle a custom-JVP call by running ``fun`` directly on ``tracers``.

  The ``prim`` and ``jvp`` arguments are accepted but not used: per the note
  at the top of ``JaxprTrace``, this method is expected to be reached only
  when staging out, where the custom differentiation rule is not needed and
  is therefore dropped.

  Returns:
    Whatever ``fun.call_wrapped(*tracers)`` returns.
  """
  legacy_mode = not config.omnistaging_enabled
  if legacy_mode:
    # Without omnistaging, only the staging trace type should end up here.
    assert self.main.trace_type is StagingJaxprTrace
  outs = fun.call_wrapped(*tracers)
  return outs


# Install the function above as a method on JaxprTrace.
JaxprTrace.process_custom_jvp_call = process_custom_jvp_call
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
  """Handle a custom-VJP call by running ``fun`` directly on ``tracers``.

  The ``prim``, ``fwd``, ``bwd`` and ``out_trees`` arguments are accepted but
  not used; the custom differentiation rule is dropped for the same reason as
  in ``process_custom_jvp_call``.

  Returns:
    Whatever ``fun.call_wrapped(*tracers)`` returns.
  """
  legacy_mode = not config.omnistaging_enabled
  if legacy_mode:
    # Without omnistaging, only the staging trace type should end up here.
    assert self.main.trace_type is StagingJaxprTrace
  outs = fun.call_wrapped(*tracers)
  return outs


# Install the function above as a method on JaxprTrace.
JaxprTrace.process_custom_vjp_call = process_custom_vjp_call
|
||||
|
||||
# Registry of call primitives; other modules add their call primitive to this
# set (e.g. via `pe.staged_out_calls.add(...)`).
staged_out_calls = set()

# Subclass of JaxprTrace with no overrides. It is the trace type selected when
# `trace_to_jaxpr` is asked to stage out, and the custom JVP/VJP handlers
# assert on it when omnistaging is disabled.
class StagingJaxprTrace(JaxprTrace): pass
|
||||
|
@ -35,8 +35,7 @@ import itertools as it
|
||||
import operator as op
|
||||
import threading
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
|
||||
Type, Union, Iterable, no_type_check, NamedTuple,
|
||||
TYPE_CHECKING)
|
||||
Type, Union, Iterable, NamedTuple, TYPE_CHECKING)
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
@ -46,7 +45,7 @@ from .. import core
|
||||
from .. import linear_util as lu
|
||||
from ..abstract_arrays import array_types
|
||||
from ..core import ConcreteArray, ShapedArray
|
||||
from .._src.util import (partial, unzip2, unzip3, prod, safe_map, safe_zip,
|
||||
from .._src.util import (partial, unzip3, prod, safe_map, safe_zip,
|
||||
extend_name_stack, wrap_name, assert_unreachable,
|
||||
tuple_insert, tuple_delete, curry)
|
||||
from ..lib import xla_bridge as xb
|
||||
@ -672,48 +671,25 @@ def parallel_callable(fun: lu.WrappedFun,
|
||||
else:
|
||||
local_devices = None # type: ignore
|
||||
|
||||
if config.omnistaging_enabled:
|
||||
sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
|
||||
for axis, aval in safe_zip(in_axes, avals))
|
||||
if any(s is not None for s in global_arg_shapes):
|
||||
# TODO(skye): we could take this branch unconditionally if we handled
|
||||
# grad of global_arg_shapes correctly.
|
||||
global_sharded_avals = [
|
||||
aval.update(shape=shape) if shape is not None else aval
|
||||
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
|
||||
else:
|
||||
global_sharded_avals = sharded_avals # type: ignore
|
||||
logging.vlog(2, "sharded_avals: %s", sharded_avals)
|
||||
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
|
||||
|
||||
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
|
||||
for axis, aval in safe_zip(in_axes, avals))
|
||||
if any(s is not None for s in global_arg_shapes):
|
||||
# TODO(skye): we could take this branch unconditionally if we handled
|
||||
# grad of global_arg_shapes correctly.
|
||||
global_sharded_avals = [
|
||||
aval.update(shape=shape) if shape is not None else aval
|
||||
for shape, aval in safe_zip(global_arg_shapes, sharded_avals)]
|
||||
else:
|
||||
@lu.wrap_init
|
||||
def dynamic_fun(dummy, *args):
|
||||
with extend_dynamic_axis_env(axis_name, dummy._trace, global_axis_size): # type: ignore
|
||||
return fun.call_wrapped(*args)
|
||||
|
||||
sharded_avals = tuple(shard_aval(axis_size, axis, aval) if axis is not None else aval
|
||||
for axis, aval in safe_zip(in_axes, avals))
|
||||
pvals = [pe.PartialVal.unknown(aval) for aval in sharded_avals]
|
||||
# We add a dummy first invar, to carry the trace details to `dynamic_fun`
|
||||
pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr( # type: ignore
|
||||
dynamic_fun, [pval] + pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
|
||||
jaxpr.invars = jaxpr.invars[1:] # ignore dummy
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
out_pvs, out_consts = unzip2(out_pvals)
|
||||
global_sharded_avals = sharded_avals # type: ignore
|
||||
logging.vlog(2, "sharded_avals: %s", sharded_avals)
|
||||
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
|
||||
|
||||
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
|
||||
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
out_axes = out_axes_thunk()
|
||||
if config.omnistaging_enabled:
|
||||
assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))
|
||||
else:
|
||||
assert len(out_pvals) == len(out_axes), (len(out_pvals), len(out_axes))
|
||||
assert all(out_axis == 0 for out_axis in out_axes)
|
||||
assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))
|
||||
|
||||
# TODO(skye,mattjj): allow more collectives on multi-host as we test them, but
|
||||
# for now raise an error
|
||||
@ -724,21 +700,6 @@ def parallel_callable(fun: lu.WrappedFun,
|
||||
if is_multi_host_pmap:
|
||||
check_multihost_collective_allowlist(jaxpr)
|
||||
|
||||
if not config.omnistaging_enabled:
|
||||
if all(pv is None for pv in out_pvs):
|
||||
# When the output doesn't depend on the input we don't need to compile an
|
||||
# XLA computation at all; we handle this as a special case so we can stage
|
||||
# out multi-replica XLA computations regardless of the hardware available.
|
||||
# The 'None' values here are just dummies we know will be ignored.
|
||||
handlers = [
|
||||
_pval_to_result_handler( # type: ignore
|
||||
axis_size, None, None, None, pval, local_devices, backend_name) # type: ignore
|
||||
for pval in out_pvals # type: ignore
|
||||
]
|
||||
results = [handler(None) for handler in handlers]
|
||||
return lambda *_: results
|
||||
|
||||
|
||||
# TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
|
||||
jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
|
||||
num_local_replicas = axis_size * jaxpr_replicas
|
||||
@ -893,31 +854,25 @@ def parallel_callable(fun: lu.WrappedFun,
|
||||
if spec is not None else None
|
||||
for aval, spec in safe_zip(avals, input_sharding_specs)]
|
||||
handle_args = partial(shard_args, compiled.local_devices(), input_indices)
|
||||
if config.omnistaging_enabled:
|
||||
nouts = len(out_sharded_avals)
|
||||
if out_parts is None:
|
||||
out_parts = (None,) * nouts
|
||||
if local_out_parts is None:
|
||||
local_out_parts = (None,) * nouts
|
||||
nouts = len(out_sharded_avals)
|
||||
if out_parts is None:
|
||||
out_parts = (None,) * nouts
|
||||
if local_out_parts is None:
|
||||
local_out_parts = (None,) * nouts
|
||||
|
||||
local_out_avals = [get_local_aval(aval, parts, lparts)
|
||||
for aval, parts, lparts
|
||||
in safe_zip(out_sharded_avals, out_parts, local_out_parts)]
|
||||
local_unmapped_avals = [core.unmapped_aval(axis_size, out_axis, aval)
|
||||
if out_axis is not None else aval
|
||||
for aval, out_axis in safe_zip(local_out_avals, out_axes)]
|
||||
local_out_avals = [get_local_aval(aval, parts, lparts)
|
||||
for aval, parts, lparts
|
||||
in safe_zip(out_sharded_avals, out_parts, local_out_parts)]
|
||||
local_unmapped_avals = [core.unmapped_aval(axis_size, out_axis, aval)
|
||||
if out_axis is not None else aval
|
||||
for aval, out_axis in safe_zip(local_out_avals, out_axes)]
|
||||
|
||||
out_specs = [_pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
|
||||
parts, aval, out_axis)
|
||||
if aval is not core.abstract_unit else None
|
||||
for parts, aval, out_axis in safe_zip(local_out_parts, local_out_avals, out_axes)]
|
||||
handle_outs = avals_to_results_handler(
|
||||
num_local_replicas, local_num_partitions, out_specs, local_unmapped_avals)
|
||||
else:
|
||||
handle_outs = _pvals_to_results_handler(axis_size, num_local_replicas, # type: ignore
|
||||
local_num_partitions,
|
||||
local_out_parts, out_pvals,
|
||||
compiled.local_devices(), backend)
|
||||
out_specs = [_pmap_sharding_spec(num_local_replicas, axis_size, local_num_partitions,
|
||||
parts, aval, out_axis)
|
||||
if aval is not core.abstract_unit else None
|
||||
for parts, aval, out_axis in safe_zip(local_out_parts, local_out_avals, out_axes)]
|
||||
handle_outs = avals_to_results_handler(
|
||||
num_local_replicas, local_num_partitions, out_specs, local_unmapped_avals)
|
||||
|
||||
if hasattr(backend, "wrap_execute_replicated"):
|
||||
return backend.wrap_execute_replicated(compiled, compiled.local_devices(),
|
||||
@ -1390,7 +1345,6 @@ def mesh_callable(fun: lu.WrappedFun,
|
||||
spmd_lowering: bool,
|
||||
*local_in_untiled_avals,
|
||||
tile_by_mesh_axes: bool):
|
||||
assert config.omnistaging_enabled
|
||||
local_mesh = mesh.local_mesh
|
||||
global_axis_sizes = mesh.shape
|
||||
local_axis_sizes = local_mesh.shape
|
||||
@ -1585,102 +1539,6 @@ def maybe_extend_axis_env(*args, **kwargs):
|
||||
with core.extend_axis_env(*args, **kwargs):
|
||||
yield
|
||||
|
||||
@config.register_omnistaging_disabler
|
||||
@no_type_check
|
||||
def omnistaging_disabler() -> None:
|
||||
global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
|
||||
_thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
|
||||
apply_parallel_primitive, parallel_pure_rules, \
|
||||
_pvals_to_results_handler, _pval_to_result_handler, replicate, \
|
||||
axis_index, maybe_extend_axis_env
|
||||
|
||||
@contextmanager
def maybe_extend_axis_env(*args, **kwargs):
  """No-op context manager: accepts any arguments, ignores them, and yields once."""
  yield
|
||||
|
||||
def _pvals_to_results_handler(
    size, nrep, npart,
    out_parts: Optional[Tuple[PartitionsOrReplicated, ...]],
    out_pvals, devices, backend):
  """Build a function that post-processes all output buffers at once.

  One per-output handler is created with ``_pval_to_result_handler`` for each
  entry of ``out_pvals``, paired with the matching entry of ``out_parts``
  (or ``None`` for every output when ``out_parts`` is ``None``).

  Returns:
    A callable taking a sequence of per-output buffers and returning the list
    obtained by applying each per-output handler to its buffers.
  """
  parts_per_output = (None,) * len(out_pvals) if out_parts is None else out_parts

  per_output_handlers = []
  for pval, parts in safe_zip(out_pvals, parts_per_output):  # type: ignore
    per_output_handlers.append(
        _pval_to_result_handler(size, nrep, npart, parts, pval, devices, backend))

  def handler(out_bufs):
    return [handle(bufs)
            for handle, bufs in safe_zip(per_output_handlers, out_bufs)]

  return handler
|
||||
|
||||
def _pval_to_result_handler(axis_size, nrep, npart, parts, pval, devices, backend):
  """Build the result handler for a single output described by ``pval``.

  ``pval`` is unpacked as ``(aval, const)``. When ``aval`` is ``None`` the
  output is a constant: the returned handler ignores its argument and returns
  either ``core.unit`` or the result of ``replicate(const, ...)``. Otherwise
  the handler comes from ``aval_to_result_handler`` using a sharding spec and
  indices derived from ``aval`` (both ``None`` for ``core.abstract_unit``).
  """
  if devices:
    # Every supplied device must belong to this host.
    assert all(d.host_id == xb.host_id(backend) for d in devices)
  aval, const = pval
  if aval is None:
    if nrep is None:
      nrep = axis_size
      # If 'const' is a ShardedDeviceArray, it must have come from a pmap nested
      # inside the one we're currently evaluating, and we should replicate
      # 'const' across the total number of devices needed. We don't necessarily
      # know the nested pmap's axis_size (e.g. the jaxpr for
      # pmap(pmap(lambda x: 3)) is trivial, with no pmaps), but we can use the
      # axis size of the output 'const'.
      # TODO: we might be doing unnecessary device transfers in the inner pmap.
      if isinstance(const, ShardedDeviceArray):
        nrep *= len(const)

    # The constant is prepared once here; the handler just returns it.
    bcast_const = (core.unit if const is core.unit
                   else replicate(const, axis_size, nrep, devices, backend))  # type: ignore
    return lambda _: bcast_const  # type: ignore
  else:
    if aval is not core.abstract_unit:
      # The unsharded result has a leading axis of length `axis_size`.
      unsharded_aval = aval.update(shape=(axis_size,) + aval.shape)
      sharding_spec = _pmap_sharding_spec(nrep, axis_size, npart, parts, aval, 0)
      indices = spec_to_indices(unsharded_aval.shape, sharding_spec)
    else:
      sharding_spec = indices = None
      unsharded_aval = aval
    return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
|
||||
|
||||
@contextmanager
def extend_dynamic_axis_env(axis_name, pmap_trace, hard_size):
  """Push a frame onto the thread-local dynamic axis env for the ``with`` body.

  A ``DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size)`` is appended on
  entry and the last frame is popped on exit, including when the body raises.
  """
  env = _thread_local_state.dynamic_axis_env
  frame = DynamicAxisEnvFrame(axis_name, pmap_trace, hard_size)
  env.append(frame)
  try:
    yield
  finally:
    env.pop()
|
||||
|
||||
def unmapped_device_count(backend=None):
  """Return the device count divided by the product of active axis sizes.

  The product runs over ``hard_size`` of every frame in the thread-local
  dynamic axis env. The division must be exact and the quotient positive;
  both conditions are checked with an assertion.
  """
  frames = _thread_local_state.dynamic_axis_env
  mapped = prod(frame.hard_size for frame in frames)
  total_devices = xb.device_count(backend)
  unmapped, ragged = divmod(total_devices, mapped)
  assert not ragged and unmapped > 0
  return unmapped
|
||||
|
||||
def apply_parallel_primitive(prim, *args, **params):
  """Apply a collective primitive op-by-op using the dynamic axis env.

  This path is taken for a collective, such as a psum, that has no data
  dependence on the argument of a pmapped function, e.g. when user code
  writes ``axis_size = psum(1, 'i')``.

  The ``axes`` and ``axis_index_groups`` entries are removed from ``params``.
  The shape passed to the rule is the length of the first index group when
  groups are given, and otherwise the ``hard_size`` of the frame(s) looked up
  by axis name. The remaining params are forwarded to
  ``parallel_pure_rules[prim]``.
  """
  axis_env = _thread_local_state.dynamic_axis_env
  axis_name = params.pop('axes')
  index_groups = params.pop('axis_index_groups')

  if index_groups is not None:
    shape = (len(index_groups[0]),)
  elif isinstance(axis_name, (list, tuple)):
    # Several axis names: one shape entry per named axis.
    shape = tuple(axis_env[name].hard_size for name in axis_name)
  else:
    shape = (axis_env[axis_name].hard_size,)

  rule = parallel_pure_rules[prim]
  return rule(*args, shape=shape, **params)
|
||||
|
||||
pe.staged_out_calls.add(xla_pmap_p) # type: ignore
|
||||
|
||||
parallel_pure_rules = {} # type: ignore
|
||||
|
||||
class DynamicAxisEnvFrame(object):
|
||||
__slots__ = ["name", "pmap_trace", "hard_size"]
|
||||
|
@ -86,20 +86,7 @@ def _sharded_callable(
|
||||
logging.vlog(2, "in_parts: %s", in_parts)
|
||||
logging.vlog(2, "local_in_parts: %s", local_in_parts)
|
||||
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun, global_abstract_args)
|
||||
else:
|
||||
in_pvals = [pe.PartialVal.unknown(aval) for aval in global_abstract_args]
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore
|
||||
instantiate=False, bottom=True) # type: ignore
|
||||
|
||||
# TODO(skye): add tests for equationless jaxpr cases
|
||||
if not jaxpr.eqns and all(outvar.aval is core.abstract_unit
|
||||
for outvar in jaxpr.outvars):
|
||||
return lambda *_: [
|
||||
const if pv is None else core.unit for pv, const in out_pvals
|
||||
]
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(fun, global_abstract_args)
|
||||
|
||||
if xb.get_backend().platform not in ["tpu", "gpu"]:
|
||||
# TODO(skye): fall back to regular jit?
|
||||
@ -182,7 +169,6 @@ def _sharded_callable(
|
||||
|
||||
handle_args = partial(pxla.shard_args, compiled.local_devices(),
|
||||
input_indices)
|
||||
assert config.omnistaging_enabled
|
||||
handle_outs = _avals_to_results_handler(nrep, local_nparts, # type: ignore
|
||||
local_out_parts, local_out_avals)
|
||||
return partial(_execute_spatially_partitioned, compiled, handle_args,
|
||||
@ -467,29 +453,3 @@ def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
|
||||
A new version of ``x`` with the specified sharding applied.
|
||||
"""
|
||||
return sharding_constraint_p.bind(x, partitions=partitions)
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Define the pval-based result-handler helpers as module globals.

  Registered through ``config.register_omnistaging_disabler``. When called, it
  binds ``_pvals_to_results_handler`` and ``_pval_to_result_handler`` at module
  level.
  """
  global _pvals_to_results_handler, _pval_to_result_handler

  def _pvals_to_results_handler(nrep, npart, partitions, out_pvals):
    # One handler per output, pairing each partition spec with its pval.
    handlers = []
    for parts, out_pval in safe_zip(partitions, out_pvals):  # type: ignore
      handlers.append(_pval_to_result_handler(npart, parts, out_pval))

    def handler(out_bufs):
      return [handle(bufs) for handle, bufs in zip(handlers, out_bufs)]

    return handler

  def _pval_to_result_handler(npart, parts, pval):
    pv, const = pval
    if pv is None:
      raise NotImplementedError  # TODO(skye): handle constant outputs
    if pv is core.abstract_unit:
      spec = indices = None
    else:
      spec = pxla.partitioned_sharding_spec(npart, parts, pv)
      indices = pxla.spec_to_indices(pv.shape, spec)
    return pxla.aval_to_result_handler(spec, indices, pv)
|
||||
|
@ -649,25 +649,16 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
"got device={} and backend={}".format(device, backend))
|
||||
|
||||
abstract_args, arg_devices = unzip2(arg_specs)
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
|
||||
if any(isinstance(c, core.Tracer) for c in consts):
|
||||
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
|
||||
else:
|
||||
pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
|
||||
jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore
|
||||
fun, pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
|
||||
if any(isinstance(c, core.Tracer) for c in consts):
|
||||
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
|
||||
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
|
||||
jaxpr = apply_outfeed_rewriter(jaxpr)
|
||||
|
||||
nreps = jaxpr_replicas(jaxpr)
|
||||
device = _xla_callable_device(nreps, backend, device, arg_devices)
|
||||
backend = device.platform if device else backend
|
||||
if config.omnistaging_enabled:
|
||||
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
|
||||
else:
|
||||
out_avals = [pval.get_aval() for pval in pvals]
|
||||
result_handlers = map(partial(_pval_to_result_handler, device), pvals) # type: ignore
|
||||
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
|
||||
|
||||
# Computations that only produce constants and/or only rearrange their inputs,
|
||||
# which are often produced from partial evaluation, don't need compilation,
|
||||
@ -953,16 +944,9 @@ def lower_fun(fun, multiple_results, parallel=False, with_avals=False):
|
||||
wrapped_fun = lu.wrap_init(fun, params)
|
||||
if not multiple_results:
|
||||
wrapped_fun = _tuple_output(wrapped_fun)
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
|
||||
outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '',
|
||||
*xla_args)
|
||||
else:
|
||||
pvals = [pe.PartialVal.unknown(a) for a in avals]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
|
||||
stage_out=True) # type: ignore
|
||||
xla_consts = _xla_consts(c, consts)
|
||||
outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
|
||||
outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '',
|
||||
*xla_args)
|
||||
if multiple_results or any(v.aval._num_buffers > 1 for v in jaxpr.outvars):
|
||||
return xops.Tuple(c, outs)
|
||||
else:
|
||||
@ -980,17 +964,9 @@ def _array_aval_from_xla_shape(xla_shape):
|
||||
|
||||
def lower_fun_initial_style(fun):
|
||||
def f(c, axis_env, name_stack, avals, backend, *xla_args, **params):
|
||||
if config.omnistaging_enabled:
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
|
||||
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
|
||||
name_stack, *xla_args)
|
||||
else:
|
||||
pvals = [pe.PartialVal.unknown(a) for a in avals]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr(
|
||||
lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True) # type: ignore
|
||||
xla_consts = _xla_consts(c, consts)
|
||||
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack,
|
||||
*xla_args)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
|
||||
outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
|
||||
name_stack, *xla_args)
|
||||
return xops.Tuple(c, outs)
|
||||
return f
|
||||
|
||||
@ -1438,18 +1414,3 @@ def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend,
|
||||
c, axis_env, in_nodes, name_stack, name="core_call",
|
||||
backend=backend, call_jaxpr=call_jaxpr)
|
||||
call_translations[core.call_p] = _call_translation_rule
|
||||
|
||||
|
||||
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Define the pval-based result handler and register ``xla_call_p``.

  Registered through ``config.register_omnistaging_disabler``. When called, it
  binds ``_pval_to_result_handler`` at module level and adds ``xla_call_p`` to
  ``pe.staged_out_calls``.
  """
  global _pval_to_result_handler

  def _pval_to_result_handler(device, pval):
    pv, const = pval
    if pv is not None:
      return aval_to_result_handler(device, pv)
    # Constant output: place it on `device` when one is given, then return a
    # handler that ignores its argument and yields the constant.
    if device:
      const = _device_put_impl(const, device)
    return lambda _: const

  pe.staged_out_calls.add(xla_call_p)  # type: ignore
|
||||
|
@ -283,7 +283,6 @@ from jax._src.lax.lax import (
|
||||
tanh,
|
||||
tanh_p,
|
||||
tie_in,
|
||||
tie_in_p,
|
||||
top_k,
|
||||
top_k_p,
|
||||
transpose,
|
||||
|
@ -485,9 +485,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
def test_omnistaging(self):
|
||||
# See https://github.com/google/jax/issues/5206
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
key_list = [None]
|
||||
|
||||
def init():
|
||||
@ -1551,10 +1548,7 @@ class APITest(jtu.JaxTestCase):
|
||||
def f():
|
||||
return jnp.zeros((3, 4))
|
||||
|
||||
if config.omnistaging_enabled:
|
||||
xla_comp = api.xla_computation(f)()
|
||||
else:
|
||||
xla_comp = api.xla_computation(f, instantiate_const_outputs=True)()
|
||||
xla_comp = api.xla_computation(f)()
|
||||
out_shape, = xla_comp.program_shape().result_shape().tuple_shapes()
|
||||
self.assertEqual(out_shape.dimensions(), (3, 4))
|
||||
|
||||
@ -1608,8 +1602,6 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertIn('sharding={{devices=[4,1]0,1,2,3}, {replicated}}', hlo_text)
|
||||
|
||||
def test_xla_computation_psum_constant(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test requires omnistaging")
|
||||
f = lambda: jax.lax.psum(1, "i")
|
||||
api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash
|
||||
|
||||
@ -1856,11 +1848,10 @@ class APITest(jtu.JaxTestCase):
|
||||
api.pmap(f, 'i')(x, x)
|
||||
|
||||
# With in_axes and out_axes
|
||||
if config.omnistaging_enabled:
|
||||
for x_in, y_in, x_out, y_out in it.product(*((0, 1, 2) for _ in range(4))):
|
||||
with jtu.assert_num_jit_and_pmap_compilations(1):
|
||||
for _ in range(2):
|
||||
api.pmap(f, 'i', in_axes=(x_in, y_in), out_axes=(x_out, y_out))(x, x)
|
||||
for x_in, y_in, x_out, y_out in it.product(*((0, 1, 2) for _ in range(4))):
|
||||
with jtu.assert_num_jit_and_pmap_compilations(1):
|
||||
for _ in range(2):
|
||||
api.pmap(f, 'i', in_axes=(x_in, y_in), out_axes=(x_out, y_out))(x, x)
|
||||
|
||||
# Forward-mode AD on the outside
|
||||
with jtu.assert_num_jit_and_pmap_compilations(1):
|
||||
@ -1962,7 +1953,7 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertEqual(outer_jaxpr.eqns[0].primitive.name, 'xla_call')
|
||||
subjaxpr_1 = outer_jaxpr.eqns[0].params["call_jaxpr"]
|
||||
self.assertEqual(str(subjaxpr_1), str(inner_jaxpr))
|
||||
self.assertLen(inner_jaxpr.eqns, 2 if config.omnistaging_enabled else 3)
|
||||
self.assertLen(inner_jaxpr.eqns, 2)
|
||||
self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
|
||||
self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')
|
||||
|
||||
@ -2043,9 +2034,6 @@ class APITest(jtu.JaxTestCase):
|
||||
api.jit(func1)(2.)
|
||||
|
||||
def test_escaped_tracer_omnistaging(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test is omnistaging-specific")
|
||||
|
||||
count = 1
|
||||
|
||||
@jit
|
||||
@ -2067,9 +2055,6 @@ class APITest(jtu.JaxTestCase):
|
||||
g()
|
||||
|
||||
def test_escaped_tracer_omnistaging_top_trace(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test is omnistaging-specific")
|
||||
|
||||
count = 1
|
||||
|
||||
def f(_, __):
|
||||
@ -2117,16 +2102,6 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines)
|
||||
self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines)
|
||||
|
||||
def test_omnistaging_flag(self):
|
||||
if FLAGS.jax_omnistaging:
|
||||
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
|
||||
self.assertLen(jaxpr.jaxpr.eqns, 1)
|
||||
else:
|
||||
# omnistaging can be enabled programmatically without setting the flag,
|
||||
# but that shouldn't happen in tests
|
||||
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
|
||||
self.assertLen(jaxpr.jaxpr.eqns, 0)
|
||||
|
||||
def test_eval_context(self):
|
||||
@jit
|
||||
def f():
|
||||
@ -2136,9 +2111,6 @@ class APITest(jtu.JaxTestCase):
|
||||
f() # doesn't crash
|
||||
|
||||
def test_concrete_error_because_arg(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test is omnistaging-specific")
|
||||
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
if x > y:
|
||||
@ -2151,9 +2123,6 @@ class APITest(jtu.JaxTestCase):
|
||||
f(1, 2)
|
||||
|
||||
def test_concrete_error_because_const(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test is omnistaging-specific")
|
||||
|
||||
@jax.jit
|
||||
def f():
|
||||
assert jnp.add(1, 1) > 0
|
||||
@ -2165,18 +2134,12 @@ class APITest(jtu.JaxTestCase):
|
||||
def test_xla_computation_zeros_doesnt_device_put(self):
|
||||
raise unittest.SkipTest("broken test") # TODO(mattjj): fix
|
||||
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test is omnistaging-specific")
|
||||
|
||||
with jtu.count_device_put() as count:
|
||||
api.xla_computation(lambda: jnp.zeros(3))()
|
||||
self.assertEqual(count[0], 0)
|
||||
|
||||
def test_join_concrete_arrays_with_omnistaging(self):
|
||||
# https://github.com/google/jax/issues/4622
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test is omnistaging-specific")
|
||||
|
||||
x = jnp.array([1., 2., 3.])
|
||||
y = jnp.array([1., 2., 4.])
|
||||
|
||||
@ -2223,9 +2186,6 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(x, jax.interpreters.xla.Token)
|
||||
|
||||
def test_leak_checker_catches_a_jit_leak(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
lst = []
|
||||
|
||||
@ -2238,9 +2198,6 @@ class APITest(jtu.JaxTestCase):
|
||||
f(3)
|
||||
|
||||
def test_leak_checker_catches_a_pmap_leak(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
lst = []
|
||||
|
||||
@ -2253,9 +2210,6 @@ class APITest(jtu.JaxTestCase):
|
||||
f(np.ones(1))
|
||||
|
||||
def test_leak_checker_catches_a_grad_leak(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
lst = []
|
||||
|
||||
@ -2267,9 +2221,6 @@ class APITest(jtu.JaxTestCase):
|
||||
api.grad(f)(3.)
|
||||
|
||||
def test_leak_checker_avoids_false_positives(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
@jit
|
||||
def f(x):
|
||||
@ -2285,9 +2236,6 @@ class APITest(jtu.JaxTestCase):
|
||||
api.vmap(f)(np.ones((1, 1))) # doesn't crash
|
||||
|
||||
def test_leak_checker_catches_a_scan_leak(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
lst = []
|
||||
|
||||
@ -2297,17 +2245,11 @@ class APITest(jtu.JaxTestCase):
|
||||
lax.scan(to_scan, 1., np.arange(3.))
|
||||
|
||||
def test_leak_checker_avoids_false_positives_scan(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
to_scan = lambda c, x: (jnp.sin(c), None)
|
||||
lax.scan(to_scan, 1., np.arange(3.)) # doesn't crash
|
||||
|
||||
def test_leak_checker_avoids_false_positives_scan_jvp(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
to_scan = lambda c, x: (c, None)
|
||||
|
||||
@ -2316,9 +2258,6 @@ class APITest(jtu.JaxTestCase):
|
||||
api.jvp(f, (3.,), (1.,)) # doesn't crash
|
||||
|
||||
def test_leak_checker_avoids_false_positives_scan_vmap(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
to_scan = lambda c, _: (1., None)
|
||||
|
||||
@ -2328,9 +2267,6 @@ class APITest(jtu.JaxTestCase):
|
||||
f(np.arange(5.)) # doesn't crash
|
||||
|
||||
def test_leak_checker_avoids_false_positives_scan_vmap_2(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
to_scan = lambda c, _: (c, None)
|
||||
|
||||
@ -2340,9 +2276,6 @@ class APITest(jtu.JaxTestCase):
|
||||
f(np.arange(5.)) # doesn't crash
|
||||
|
||||
def test_leak_checker_catches_a_sublevel_leak(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
with jax.checking_leaks():
|
||||
@jit
|
||||
def f(x):
|
||||
@ -2804,25 +2737,8 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
jax.grad(scan_bug)(1.0) # doesn't crash
|
||||
|
||||
def test_remat_jit_static_argnum(self):
|
||||
# https://github.com/google/jax/issues/2833
|
||||
if config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works without omnistaging") # see next test
|
||||
|
||||
def f(a_bool, y):
|
||||
if a_bool:
|
||||
return y + 1
|
||||
else:
|
||||
return y
|
||||
|
||||
api.jit(api.remat(f, concrete=True), static_argnums=0)(True, 1) # no crash
|
||||
|
||||
|
||||
def test_remat_jit_static_argnum_omnistaging(self):
|
||||
# https://github.com/google/jax/issues/2833
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging") # see previous test
|
||||
|
||||
def named_call(f):
|
||||
def named_f(*args):
|
||||
f_ = lu.wrap_init(lambda: (f(*args),))
|
||||
@ -2894,9 +2810,6 @@ class RematTest(jtu.JaxTestCase):
|
||||
|
||||
def test_escaped_tracer_remat(self):
|
||||
# b/169779185
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
def f():
|
||||
seq = [jnp.zeros([])]
|
||||
def g():
|
||||
@ -2926,18 +2839,11 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
def fun(x):
|
||||
return (x, 1., np.zeros(1))
|
||||
|
||||
if config.omnistaging_enabled:
|
||||
expected = """
|
||||
{ lambda a ; b.
|
||||
let
|
||||
in (b, 1.0, a) }
|
||||
"""
|
||||
else:
|
||||
expected = """
|
||||
{ lambda b ; a.
|
||||
let
|
||||
in (a, 1.0, b) }
|
||||
"""
|
||||
expected = """
|
||||
{ lambda a ; b.
|
||||
let
|
||||
in (b, 1.0, a) }
|
||||
"""
|
||||
|
||||
jaxpr = api.make_jaxpr(fun)(0.)
|
||||
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
|
||||
@ -2949,43 +2855,21 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
lambda xt: xt + x,
|
||||
x + 2.,
|
||||
lambda xf: xf - x)
|
||||
if config.omnistaging_enabled:
|
||||
expected = """
|
||||
{ lambda ; a.
|
||||
let b = ge a 0.0
|
||||
c = add a 1.0
|
||||
d = add a 2.0
|
||||
e = convert_element_type[ new_dtype=int32
|
||||
weak_type=False ] b
|
||||
f = cond[ branches=( { lambda ; e_ a b c.
|
||||
let d = sub c a
|
||||
in (d,) }
|
||||
{ lambda ; a f_ b c.
|
||||
let d = add b a
|
||||
in (d,) } )
|
||||
linear=(False, False, False, False) ] e a a c d
|
||||
in (f,) }
|
||||
"""
|
||||
else:
|
||||
expected = """
|
||||
{ lambda ; a.
|
||||
let b = ge a 0.0
|
||||
c = convert_element_type[ new_dtype=int32
|
||||
weak_type=False ] b
|
||||
d = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] a
|
||||
e = convert_element_type[ new_dtype=float32
|
||||
weak_type=False ] a
|
||||
f = add a 1.0
|
||||
g = add a 2.0
|
||||
h = cond[ branches=( { lambda ; e_ c a b.
|
||||
let d = sub b c
|
||||
in (d,) }
|
||||
{ lambda ; c f_ a b.
|
||||
let d = add a c
|
||||
in (d,) } )
|
||||
linear=(False, False, False, False) ] c d e f g
|
||||
in (h,) }
|
||||
expected = """
|
||||
{ lambda ; a.
|
||||
let b = ge a 0.0
|
||||
c = add a 1.0
|
||||
d = add a 2.0
|
||||
e = convert_element_type[ new_dtype=int32
|
||||
weak_type=False ] b
|
||||
f = cond[ branches=( { lambda ; e_ a b c.
|
||||
let d = sub c a
|
||||
in (d,) }
|
||||
{ lambda ; a f_ b c.
|
||||
let d = add b a
|
||||
in (d,) } )
|
||||
linear=(False, False, False, False) ] e a a c d
|
||||
in (f,) }
|
||||
"""
|
||||
jaxpr = api.make_jaxpr(f)(3.)
|
||||
self.assertMultiLineStrippedEqual(expected, str(jaxpr))
|
||||
@ -3005,18 +2889,12 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
self.assertEqual(shape_tree, expected)
|
||||
|
||||
def test_make_jaxpr_axis_env(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
def f(x):
|
||||
return x - lax.psum(x, 'i')
|
||||
jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2)
|
||||
self.assertIn('psum', str(jaxpr))
|
||||
|
||||
def test_make_jaxpr_named(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
def f(x):
|
||||
return x - lax.psum(x, 'i')
|
||||
|
||||
@ -3276,9 +3154,6 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_closed_over_tracers_error_message(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
def f(x):
|
||||
@api.custom_jvp
|
||||
def g(y):
|
||||
@ -3326,9 +3201,6 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def test_nondiff_arg_hiding_jvp_tracer(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
def f(x):
|
||||
@partial(api.custom_jvp, nondiff_argnums=(0,))
|
||||
def g(h, x):
|
||||
@ -3600,9 +3472,6 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
|
||||
def test_nondiff_argnums_vmap_tracer(self):
|
||||
# https://github.com/google/jax/issues/3964
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
@partial(jax.custom_jvp, nondiff_argnums=(0, 2))
|
||||
def sample(shape, param, seed):
|
||||
return jax.random.uniform(key=seed, shape=shape, minval=param)
|
||||
@ -3622,9 +3491,6 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
(1.,), (1.,))
|
||||
|
||||
def test_fun_with_nested_calls_2(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
def call(f, *args):
|
||||
f = api.custom_jvp(f)
|
||||
f.defjvp(lambda primals, tangents: (f(*primals), sum(tangents)))
|
||||
@ -3647,9 +3513,6 @@ class CustomJVPTest(jtu.JaxTestCase):
|
||||
api.vmap(fun_with_nested_calls_2)(jnp.arange(3.))
|
||||
|
||||
def test_closure_with_vmap(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test only works with omnistaging")
|
||||
|
||||
# https://github.com/google/jax/issues/3822
|
||||
alpha = np.float32(2.)
|
||||
|
||||
@ -4788,9 +4651,6 @@ class InvertibleADTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.ignore_warning(message="Values that an @invertible function closes")
|
||||
def test_invertible_basic(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("Test requires omnistaging")
|
||||
|
||||
def f(x):
|
||||
return (jnp.exp(x) * 4) * x
|
||||
|
||||
|
@ -15,7 +15,6 @@
|
||||
|
||||
import itertools as it
|
||||
import numpy as np
|
||||
from unittest import skipIf
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
@ -976,8 +975,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
(lax.pmin, jnp.min)]
|
||||
for vmap_names in [('i',), ('i', 'j'), ('j', 'i')]
|
||||
for collective_names in it.permutations(vmap_names))
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testCommAssocCollective(self, collective, bulk_op, vmap_names, collective_names):
|
||||
x = jnp.arange(3 * 4 * 5, dtype=jnp.float32).reshape((3, 4, 5))
|
||||
|
||||
@ -993,8 +990,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
if collective is lax.psum:
|
||||
jtu.check_grads(f, (x,), 2, eps=1)
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testPPermute(self):
|
||||
nelem = 10
|
||||
ntests = 10
|
||||
@ -1013,8 +1008,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
|
||||
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
|
||||
for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testAllToAll(self, vmap_axis, split_axis, concat_axis):
|
||||
shape = (4, 4, 4, 4)
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
@ -1029,8 +1022,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
|
||||
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
|
||||
for split_axis, concat_axis, vmap_axis in it.product(range(2), range(2), range(3)))
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testAllToAllSplitAxis(self, vmap_axis, split_axis, concat_axis):
|
||||
shape = (4, 4, 4)
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
@ -1084,16 +1075,12 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
jax.vmap(lambda *xs: xs, in_axes=(0, None), out_axes=(0, -2))(
|
||||
np.arange(5), 7)
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testAxisIndex(self):
|
||||
x = np.arange(10)
|
||||
self.assertAllClose(
|
||||
vmap(lambda x: x - lax.axis_index('i'), axis_name='i')(x),
|
||||
x - np.arange(x.shape[0]))
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testCollectivePdot(self):
|
||||
def f(x, y):
|
||||
return lax.pdot(x, y, 'i')
|
||||
@ -1110,8 +1097,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
z = vmap(f, axis_name='i', in_axes=(0, 0), out_axes=None)(x, y)
|
||||
self.assertAllClose(z, jnp.dot(x.T, y))
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testCollectivePdotBatching(self):
|
||||
def f(x, y):
|
||||
return lax.pdot(x, y, 'i')
|
||||
@ -1122,8 +1107,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
zs = vmap(vmap(f, axis_name='i', in_axes=(1, 0), out_axes=None))(xs, ys)
|
||||
self.assertAllClose(zs, jnp.einsum('nij,njk->nik', xs, ys))
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testPdotJvp(self):
|
||||
def f(x, y):
|
||||
return lax.pdot(x, y, 'i')
|
||||
@ -1139,8 +1122,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(z, jnp.dot(x, y))
|
||||
self.assertAllClose(z_dot, jnp.dot(x_dot, y) + jnp.dot(x, y_dot))
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testPdotVjp(self):
|
||||
def f(x, y):
|
||||
return lax.pdot(x, y, 'i')
|
||||
@ -1165,8 +1146,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
y = vmap(f)(a=jnp.array([1]), b=jnp.array([2])) # doesn't work
|
||||
self.assertAllClose(x, y)
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testAllGatherToUnmapped(self):
|
||||
def f(x):
|
||||
return lax.all_gather(x, axis_name='i')
|
||||
@ -1175,8 +1154,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
# Original mapped axis becomes first axis of unmapped return value.
|
||||
self.assertAllClose(vmap(f, axis_name='i', in_axes=1, out_axes=None)(x), x.T)
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testBatchedAllGather(self):
|
||||
def f(x):
|
||||
return lax.all_gather(x, axis_name='i')
|
||||
@ -1188,8 +1165,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
res = vmap(vmap(f, axis_name='j'), axis_name='i', out_axes=None)(x)
|
||||
self.assertAllClose(res, x.T)
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testAllGatherVjp(self):
|
||||
def f(x):
|
||||
return lax.all_gather(x, axis_name='i')
|
||||
@ -1201,8 +1176,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
x_bar, = vmap(lambda x, y_bar: vjp(f, x)[1](y_bar), axis_name='i')(x, y_bar)
|
||||
self.assertAllClose(x_bar, np.sum(y_bar, axis=0))
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testAllGatherOfConst(self):
|
||||
def f(x):
|
||||
a = lax.all_gather(jnp.ones_like(x), axis_name='i')
|
||||
@ -1226,8 +1199,6 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
for shape in [(7,), (5, 8)]
|
||||
for axis in range(len(shape))
|
||||
)
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
|
@ -112,8 +112,6 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.ignore_warning(message=".*is an experimental.*")
|
||||
def testXmap(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
|
||||
f = jax.experimental.maps.xmap(
|
||||
lambda x: 0. / x,
|
||||
|
@ -32,7 +32,6 @@ zip, unsafe_zip = safe_zip, zip
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
|
||||
class DJaxTests(jtu.JaxTestCase):
|
||||
|
||||
def test_identity_typechecks(self):
|
||||
@ -79,7 +78,6 @@ class DJaxTests(jtu.JaxTestCase):
|
||||
djax.typecheck_jaxpr(jaxpr)
|
||||
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
|
||||
@skipIf(jax.config.x64_enabled, "only 32bit for now")
|
||||
class DJaxXLATests(jtu.JaxTestCase):
|
||||
|
||||
@ -123,7 +121,6 @@ class DJaxXLATests(jtu.JaxTestCase):
|
||||
self.assertAllClose(np.array(ans), expected, check_dtypes=False)
|
||||
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
|
||||
@skipIf(jax.config.x64_enabled, "only 32bit for now")
|
||||
class DJaxADTests(jtu.JaxTestCase):
|
||||
|
||||
@ -160,7 +157,6 @@ class DJaxADTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(np.array(z_dot), expected_z_dot, check_dtypes=False)
|
||||
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled, "requires omnistaging")
|
||||
@skipIf(jax.config.x64_enabled, "only 32bit for now")
|
||||
class DJaxBatchingTests(jtu.JaxTestCase):
|
||||
|
||||
|
@ -20,7 +20,7 @@ import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional, Sequence
|
||||
from unittest import SkipTest, skipIf
|
||||
from unittest import SkipTest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -262,8 +262,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
9.00 )""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_with_result_no_arg(self):
|
||||
def tap_func(arg, transforms):
|
||||
testing_stream.write(f"called tap_func with {arg}")
|
||||
@ -278,8 +276,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_result_unused(self):
|
||||
def tap_func(arg, transforms):
|
||||
testing_stream.write(f"called tap_func with {arg}")
|
||||
@ -415,8 +411,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
11""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_jit_result_unused(self):
|
||||
"""We can id_print even if we don't use the result."""
|
||||
|
||||
@ -593,8 +587,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
testcase_name=f"_with_jit_{with_jit}",
|
||||
with_jit=with_jit)
|
||||
for with_jit in [True, False]))
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_cond(self, with_jit=False):
|
||||
"""A conditional"""
|
||||
|
||||
@ -630,8 +622,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
dict(testcase_name=f"_with_jit_{with_jit}",
|
||||
with_jit=with_jit)
|
||||
for with_jit in [True, False]))
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_while_cond(self, with_jit=False):
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
||||
@ -713,8 +703,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
testcase_name=f"_with_jit_{with_jit}",
|
||||
with_jit=with_jit)
|
||||
for with_jit in [True, False]))
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_scan_cond(self, with_jit=True):
|
||||
def func(x):
|
||||
x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
|
||||
@ -864,8 +852,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
4""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_jvp(self):
|
||||
jvp_fun1 = lambda x, xt: api.jvp(fun1, (x,), (xt,))
|
||||
res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
|
||||
@ -882,9 +868,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
testing_stream.reset()
|
||||
|
||||
def test_tap_grad_primal_unused(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("Test requires omnistaging")
|
||||
|
||||
# The output of id_print is not needed for backwards pass
|
||||
def func(x):
|
||||
return 2. * hcb.id_print(x * 3., what="x * 3",
|
||||
@ -948,9 +931,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
testing_stream.reset()
|
||||
|
||||
def test_tap_grad_grad(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("Test requires omnistaging")
|
||||
|
||||
def func(x):
|
||||
y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
|
||||
return x * (y * 3.)
|
||||
@ -976,8 +956,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
2.00""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_grad_pytree(self):
|
||||
def func(x):
|
||||
x4, x5 = hcb.id_print((x * 2., x * 3.), what="pair",
|
||||
@ -1000,8 +978,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
0.00 )""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_jvp_float0(self):
|
||||
def f(x, yint):
|
||||
x, yint = hcb.id_tap(lambda arg, _: arg, (x, yint))
|
||||
@ -1010,8 +986,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
res = api.jvp(f, (2., 3), (0.2, np.zeros((), dtypes.float0)))
|
||||
self.assertAllClose((6., 0.6), res)
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_grad_float0(self):
|
||||
def func(x, yint):
|
||||
x, yint = hcb.id_print((x, yint), what="pair", output_stream=testing_stream)
|
||||
@ -1147,8 +1121,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
[2 2 2 3 4]""", testing_stream.output)
|
||||
testing_stream.reset()
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_transforms(self):
|
||||
def power(x, n):
|
||||
x, n = hcb.id_print((x, n), output_stream=testing_stream)
|
||||
@ -1315,8 +1287,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
|
||||
testing_stream.reset()
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_jvp_pmap_vmap(self):
|
||||
# A matrix M[ijk] = i * 100 + j * 10 * k
|
||||
nr_devices = len(devices())
|
||||
@ -1537,8 +1507,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
what: ct_b
|
||||
1.""", testing_stream.output)
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_tap_mask(self):
|
||||
|
||||
@partial(api.mask, in_shapes=['n'], out_shape='')
|
||||
@ -1714,9 +1682,6 @@ class HostCallbackIdTapTest(jtu.JaxTestCase):
|
||||
self.assertMultiLineStrippedEqual(expected, testing_stream.output)
|
||||
|
||||
def test_tap_named_call(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("Test requires omnistaging")
|
||||
|
||||
def tap_scalar(init, do_print=False):
|
||||
@partial(api.named_call, name="step")
|
||||
def step(acc, step_nr):
|
||||
@ -1775,8 +1740,6 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
res_inside = fun(2, use_outside=False)
|
||||
self.assertAllClose(res_inside, fun(2, use_outside=True))
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"test works only with omnistaging enabled")
|
||||
def test_call_no_result(self):
|
||||
def f_outside(arg):
|
||||
self.call_log_testing_stream(lambda x: None, arg,
|
||||
|
@ -76,7 +76,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
y, token = lax.infeed(
|
||||
token, shape=jax.ShapedArray((3, 4), jnp.float32))
|
||||
token = lax.outfeed(token, y + np.float32(1))
|
||||
return x - 1 if config.omnistaging_enabled else lax.tie_in(token, x - 1)
|
||||
return x - 1
|
||||
|
||||
x = np.float32(7.5)
|
||||
y = np.random.randn(3, 4).astype(np.float32)
|
||||
@ -101,7 +101,7 @@ class InfeedTest(jtu.JaxTestCase):
|
||||
def f(n):
|
||||
token = lax.create_token(n)
|
||||
token = lax.fori_loop(0, n, doubler, token)
|
||||
return n if config.omnistaging_enabled else lax.tie_in(token, n)
|
||||
return n
|
||||
|
||||
device = jax.local_devices()[0]
|
||||
n = 10
|
||||
|
@ -18,7 +18,7 @@ from functools import partial
|
||||
import itertools
|
||||
import operator
|
||||
import re
|
||||
from unittest import SkipTest, skipIf
|
||||
from unittest import SkipTest
|
||||
import textwrap
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -335,7 +335,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = np.array([4, 3, 4, 3])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@skipIf(not config.omnistaging_enabled, "test only works with omnistaging")
|
||||
def testWhileLoopAxisIndexBatched(self):
|
||||
def fun(x):
|
||||
return lax.while_loop(lambda x: x < lax.axis_index('i'), lambda x: x + 2, x)
|
||||
@ -996,7 +995,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
|
||||
|
||||
@skipIf(not config.omnistaging_enabled, "test only works with omnistaging")
|
||||
def testCondGradVmapNan(self):
|
||||
eps = 1e-3
|
||||
|
||||
@ -2564,8 +2562,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
*_, ext_res = vjp_fun.args[0].args[0]
|
||||
self.assertIsInstance(ext_res, xla.DeviceArray)
|
||||
|
||||
@skipIf(not config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def test_scan_vmap_collectives(self):
|
||||
def scan_f(state, x):
|
||||
s = lax.psum(state, 'i') * x
|
||||
|
@ -2291,8 +2291,6 @@ class LaxTest(jtu.JaxTestCase):
|
||||
lambda x, y: dict(x=x['x'] + y['x']), [0])
|
||||
|
||||
def test_select_jvp_complexity(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("test requires omnistaging")
|
||||
jaxpr = jax.make_jaxpr(lambda x: jax.jvp(lambda x: lax.select(True, x, x),
|
||||
(x,), (1.,)))(1.)
|
||||
self.assertLen(jaxpr.jaxpr.eqns, 2)
|
||||
@ -2510,8 +2508,6 @@ class LaxNamedShapeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
def test_abstract_eval_collective(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("test requires omnistaging")
|
||||
with core.extend_axis_env('i', 10, None):
|
||||
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
|
||||
expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
|
||||
|
@ -66,11 +66,10 @@ class MetadataTest(jtu.JaxTestCase):
|
||||
self.assertRegex(hlo, 'op_type="cos"')
|
||||
self.assertRegex(hlo, 'op_type="mul"')
|
||||
# TODO(mattjj,jekbradbury): update these tests post-omnistaging
|
||||
if not config.omnistaging_enabled:
|
||||
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
|
||||
self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
|
||||
self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\('
|
||||
'jvp\\(foo\\)\\)\\)/mul"')
|
||||
# self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
|
||||
# self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
|
||||
# self.assertRegex(hlo, 'op_name=".*jit\\(transpose\\('
|
||||
# 'jvp\\(foo\\)\\)\\)/mul"')
|
||||
|
||||
def test_cond_metadata(self):
|
||||
def true_fun(x):
|
||||
|
@ -120,10 +120,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
|
||||
x2_uncommitted = jnp.array([2, 3])
|
||||
z1, z2, z3 = jax.jit(lambda x, y: (y, 1, x))(x_uncommitted, x2_uncommitted)
|
||||
self.assert_uncommitted_to_device(z1, devices[0])
|
||||
if config.omnistaging_enabled:
|
||||
self.assert_uncommitted_to_device(z2, devices[0])
|
||||
else:
|
||||
self.assertIs(z2, 1)
|
||||
self.assert_uncommitted_to_device(z2, devices[0])
|
||||
self.assert_uncommitted_to_device(z3, devices[0])
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ import gc
|
||||
import os
|
||||
from random import shuffle
|
||||
from typing import Optional, cast
|
||||
from unittest import SkipTest, skipIf
|
||||
from unittest import SkipTest
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
@ -51,17 +51,11 @@ prev_xla_flags = None
|
||||
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
|
||||
|
||||
def all_bdims(*shapes, pmap):
|
||||
if pmap and not config.omnistaging_enabled:
|
||||
bdims = ((None, 0) for shape in shapes)
|
||||
else:
|
||||
bdims = (it.chain([cast(Optional[int], None)],
|
||||
range(len(shape) + 1))
|
||||
for shape in shapes)
|
||||
bdims = (it.chain([cast(Optional[int], None)], range(len(shape) + 1))
|
||||
for shape in shapes)
|
||||
return (t for t in it.product(*bdims) if not all(e is None for e in t))
|
||||
|
||||
def out_bdims(shape, pmap):
|
||||
if pmap and not config.omnistaging_enabled:
|
||||
return (0,)
|
||||
return (d[0] for d in all_bdims(shape, pmap=pmap) if d[0] is not None)
|
||||
|
||||
|
||||
@ -157,9 +151,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testTrees(self):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("all_to_all doesn't work without omnistaging")
|
||||
|
||||
ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0)
|
||||
def protate(x, axis_name):
|
||||
n = lax.psum(1, axis_name)
|
||||
@ -217,9 +208,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
for split_axis, concat_axis in it.product(range(2), range(2)))
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testAllToAll(self, split_axis, concat_axis):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("all_to_all doesn't work without omnistaging")
|
||||
|
||||
pmap_in_axis = 0
|
||||
shape = (xla_bridge.device_count(),) * 3
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
@ -240,9 +228,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
for split_axis, concat_axis in it.product(range(2), range(2)))
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testAllToAllSplitAxis(self, split_axis, concat_axis):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("all_to_all doesn't work without omnistaging")
|
||||
|
||||
if xla_bridge.device_count() < 4:
|
||||
raise SkipTest("test requires at least four devices")
|
||||
pmap_in_axis = 0
|
||||
@ -620,8 +605,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testGradOfGather(self):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("all_to_all doesn't work without omnistaging")
|
||||
@partial(pmap, axis_name='i')
|
||||
def f(x):
|
||||
return lax.all_gather(x, axis_name='i')
|
||||
@ -893,29 +876,18 @@ class PmapTest(jtu.JaxTestCase):
|
||||
device_count = xla_bridge.device_count()
|
||||
f = pmap(lambda x: 3)
|
||||
x = jnp.arange(device_count + 1)
|
||||
if config.omnistaging_enabled:
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
(r"compiling computation that requires \d+ logical devices, "
|
||||
r"but only \d+ XLA devices are available .*"),
|
||||
lambda: f(x))
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
(r"compiling computation that requires \d+ logical devices, "
|
||||
r"but only \d+ XLA devices are available .*"),
|
||||
lambda: f(x))
|
||||
|
||||
# TODO(mattjj): test error message with explicit devices
|
||||
# f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
|
||||
# x = jnp.arange(2)
|
||||
# self.assertRaisesRegex(
|
||||
# ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
|
||||
# r"local devices are available.", lambda: f(x))
|
||||
else:
|
||||
self.assertRaisesRegex(
|
||||
ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
|
||||
r"local devices are available.", lambda: f(x))
|
||||
|
||||
f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
|
||||
x = jnp.arange(2)
|
||||
self.assertRaisesRegex(
|
||||
ValueError, "Cannot replicate across 2 replicas because only 1 "
|
||||
"local devices are available.", lambda: f(x))
|
||||
# TODO(mattjj): test error message with explicit devices
|
||||
# f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
|
||||
# x = jnp.arange(2)
|
||||
# self.assertRaisesRegex(
|
||||
# ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
|
||||
# r"local devices are available.", lambda: f(x))
|
||||
|
||||
def testNestedPmapConstant(self):
|
||||
if xla_bridge.device_count() == 1:
|
||||
@ -967,35 +939,22 @@ class PmapTest(jtu.JaxTestCase):
|
||||
f = pmap(pmap(lambda x: 3))
|
||||
shape = (2, xla_bridge.device_count() // 2 + 1, 3)
|
||||
x = jnp.arange(prod(shape)).reshape(shape)
|
||||
if config.omnistaging_enabled:
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
(r"compiling computation that requires \d+ logical devices, "
|
||||
r"but only \d+ XLA devices are available .*"),
|
||||
lambda: f(x))
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
(r"compiling computation that requires \d+ logical devices, "
|
||||
r"but only \d+ XLA devices are available .*"),
|
||||
lambda: f(x))
|
||||
|
||||
# TODO(mattjj): check error message with explicit devices
|
||||
# if xla_bridge.device_count() > 1:
|
||||
# f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
|
||||
# shape = (2, xla_bridge.device_count() // 2, 3)
|
||||
# x = jnp.arange(prod(shape)).reshape(shape)
|
||||
# self.assertRaisesRegex(
|
||||
# ValueError,
|
||||
# (r"compiling computation that requires \d+ replicas, "
|
||||
# r"but only \d+ XLA devices are available"),
|
||||
# lambda: f(x))
|
||||
else:
|
||||
self.assertRaisesRegex(
|
||||
ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
|
||||
r"local devices are available.", lambda: f(x))
|
||||
|
||||
if xla_bridge.device_count() > 1:
|
||||
f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
|
||||
shape = (2, xla_bridge.device_count() // 2, 3)
|
||||
x = jnp.arange(prod(shape)).reshape(shape)
|
||||
self.assertRaisesRegex(
|
||||
ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
|
||||
r"local devices are available.", lambda: f(x))
|
||||
# TODO(mattjj): check error message with explicit devices
|
||||
# if xla_bridge.device_count() > 1:
|
||||
# f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
|
||||
# shape = (2, xla_bridge.device_count() // 2, 3)
|
||||
# x = jnp.arange(prod(shape)).reshape(shape)
|
||||
# self.assertRaisesRegex(
|
||||
# ValueError,
|
||||
# (r"compiling computation that requires \d+ replicas, "
|
||||
# r"but only \d+ XLA devices are available"),
|
||||
# lambda: f(x))
|
||||
|
||||
def testCollectiveConstant(self):
|
||||
device_count = xla_bridge.device_count()
|
||||
@ -1049,8 +1008,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f('i')(x), expected_j.T, check_dtypes=False)
|
||||
|
||||
def testAxisIndexNd(self):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("axis_index doesn't work without omnistaging")
|
||||
device_count = xla_bridge.device_count()
|
||||
if device_count < 4:
|
||||
raise SkipTest("test requires at least four devices")
|
||||
@ -1149,8 +1106,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testPswapaxes(self):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("all_to_all doesn't work without omnistaging")
|
||||
device_count = xla_bridge.device_count()
|
||||
shape = (device_count, 3, device_count, 5)
|
||||
x = np.arange(prod(shape)).reshape(shape)
|
||||
@ -1161,8 +1116,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testGradOfPswapaxes(self):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("all_to_all doesn't work without omnistaging")
|
||||
device_count = xla_bridge.device_count()
|
||||
shape = (device_count, 1, device_count)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
@ -1179,8 +1132,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testAllToAllReplicaGroups(self):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("all_to_all doesn't work without omnistaging")
|
||||
# If num_devices = 4, these would be the inputs/outputs:
|
||||
# input = [[0, 1], [2, 3], [4, 5], [6, 7]]
|
||||
# axis_index_groups = [[0, 1], [2, 3]]
|
||||
@ -1210,8 +1161,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testGradOfAllToAllReplicaGroups(self):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("all_to_all doesn't work without omnistaging")
|
||||
device_count = xla_bridge.device_count()
|
||||
if device_count % 2 != 0:
|
||||
raise SkipTest('test requires an even number of devices')
|
||||
@ -1254,7 +1203,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapBatchMatmul(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
|
||||
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
|
||||
@ -1264,7 +1212,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapBatchMatmulJit(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
|
||||
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
|
||||
@ -1274,7 +1221,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapPsumConstant(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
def f(_):
|
||||
return lax.psum(1, 'i')
|
||||
@ -1284,7 +1230,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapPsum(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
def f(x):
|
||||
return x / lax.psum(x, 'i')
|
||||
@ -1294,7 +1239,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapAxisIndex(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
def f(x):
|
||||
return x * lax.axis_index('i')
|
||||
@ -1304,7 +1248,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapOfJit(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
n = 4 * xla_bridge.device_count()
|
||||
def f(x):
|
||||
return 3 * x
|
||||
@ -1342,7 +1285,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapDevicePersistence(self):
|
||||
if not config.omnistaging_enabled: raise SkipTest("requires omnistaging")
|
||||
device_count = xla_bridge.device_count()
|
||||
shape = (2 * 2 * device_count, 2, 3)
|
||||
|
||||
@ -1651,8 +1593,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
mapped_fn(indices) # doesn't crash
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testPdotBasic(self):
|
||||
num_devices = jax.device_count()
|
||||
|
||||
@ -1679,8 +1619,6 @@ class PmapTest(jtu.JaxTestCase):
|
||||
for axis in range(len(shape))
|
||||
)
|
||||
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
|
||||
if not config.omnistaging_enabled:
|
||||
self.skipTest("test requires omnistaging")
|
||||
if xla_bridge.device_count() < shape[axis]:
|
||||
raise SkipTest(f"test requires at least {shape[axis]} devices")
|
||||
if (jtu.device_under_test() == 'cpu' and
|
||||
@ -1757,8 +1695,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
{"testcase_name": "_collective={}".format(collective.__name__).replace(" ", ""),
|
||||
"collective": collective}
|
||||
for collective in [lax.psum, lax.pmean, lax.pmax, lax.pmin])
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testCollectivesWithVmap(self, collective):
|
||||
def f(map1, map2):
|
||||
@partial(map1, axis_name='i')
|
||||
@ -1775,8 +1711,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f(jax.pmap, jax.vmap)(x, x), y)
|
||||
self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), y)
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testPPermuteWithVmap(self):
|
||||
perm = [(0, 1), (1, 0)]
|
||||
|
||||
@ -1796,8 +1730,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
|
||||
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
|
||||
for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testAllToAllInVmap(self, split_axis, concat_axis, vmap_axis):
|
||||
def f(x):
|
||||
@ -1867,8 +1799,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
{"testcase_name": f"_split={split_axis}_concat={concat_axis}",
|
||||
"split_axis": split_axis, "concat_axis": concat_axis}
|
||||
for split_axis, concat_axis in it.product(range(3), range(3)))
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testAllToAllVsVmap(self, split_axis, concat_axis):
|
||||
def f(x):
|
||||
@ -1884,8 +1814,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
"axes": axes, "split_axis": split_axis, "concat_axis": concat_axis}
|
||||
for axes, split_axis, concat_axis
|
||||
in it.product([('i', 'j'), ('j', 'i')], range(3), range(3)))
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testAllToAllMultipleAxesVsVmap(self, axes, split_axis, concat_axis):
|
||||
raise SkipTest("multi-axis all_to_all broken after #4835") # TODO(mattjj,apaszke)
|
||||
@ -1900,8 +1828,6 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(pmap(pmap(f, axis_name='j'), axis_name='i')(x),
|
||||
vmap(vmap(f, axis_name='j'), axis_name='i')(x))
|
||||
|
||||
@skipIf(not jax.config.omnistaging_enabled,
|
||||
"vmap collectives only supported when omnistaging is enabled")
|
||||
def testAllGatherWithVmap(self):
|
||||
def f(map2):
|
||||
@partial(jax.pmap, axis_name='i')
|
||||
@ -2062,7 +1988,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
expected = np.sin(x + 3.)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
|
||||
def testPmapInAxesBasic(self):
|
||||
@partial(pmap, in_axes=(1, 2))
|
||||
def f(x, y):
|
||||
@ -2075,7 +2000,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f(x, y),
|
||||
jnp.sin(x.transpose((1, 0, 2)) + y.transpose((2, 0, 1))))
|
||||
|
||||
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
|
||||
def testPmapInAxesGrad(self):
|
||||
def f(x, y, z):
|
||||
return jnp.sin(x + y + z)
|
||||
@ -2096,7 +2020,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(jax.grad(lambda args: fp(*args).sum())((x, y, z)),
|
||||
jax.grad(lambda args: fv(*args).sum())((x, y, z)))
|
||||
|
||||
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
|
||||
def testPmapOutAxesBasic(self):
|
||||
@partial(pmap, in_axes=(1, None), out_axes=(2, None))
|
||||
def f(x, y):
|
||||
@ -2109,7 +2032,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f(x, y),
|
||||
(jnp.sin(x.transpose((1, 0, 2)) + y).transpose((1, 2, 0)), y * 2))
|
||||
|
||||
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_{in_axes}_{out_axes}",
|
||||
"in_axes": in_axes, "out_axes": out_axes}
|
||||
@ -2128,7 +2050,6 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(pmap(f, in_axes=in_axes, out_axes=out_axes), args,
|
||||
order=2, atol=2e-2, rtol=2e-2, eps=1e-3)
|
||||
|
||||
@skipIf(not config.omnistaging_enabled, "test requires omnistaging")
|
||||
def testPmapPostProcess(self):
|
||||
def mk_case(map_fun):
|
||||
def f(x, y):
|
||||
|
@ -752,8 +752,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
|
||||
|
||||
def testNoOpByOpUnderHash(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("test requires omnistaging")
|
||||
def fail(*args, **kwargs): assert False
|
||||
apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
|
||||
try:
|
||||
@ -916,8 +914,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
random.choice(key, 5, 2, replace=True)
|
||||
|
||||
def test_eval_shape_big_random_array(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("after deleting lazy constants, requires omnistaging")
|
||||
def f(x):
|
||||
return random.normal(random.PRNGKey(x), (int(1e12),))
|
||||
with jax.enable_checks(False): # check_jaxpr will materialize array
|
||||
@ -977,8 +973,6 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
api.jit(random.PRNGKey)(seed)
|
||||
|
||||
def test_random_split_doesnt_device_put_during_tracing(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("test requires omnistaging")
|
||||
key = random.PRNGKey(1).block_until_ready()
|
||||
with jtu.count_device_put() as count:
|
||||
api.jit(random.split)(key)
|
||||
|
@ -47,8 +47,6 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
{"testcase_name": "_jit={}".format(jit), "jit": jit}
|
||||
for jit in ["python", "cpp", None]))
|
||||
def test_make_array(self, jit):
|
||||
if jit == "cpp" and not config.omnistaging_enabled:
|
||||
self.skipTest("cpp_jit requires omnistaging")
|
||||
func = _maybe_jit(jit, lambda: jnp.arange(10.0))
|
||||
dtype_start = func().dtype
|
||||
with enable_x64():
|
||||
@ -64,9 +62,6 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
"enable_or_disable": f
|
||||
} for jit in ["python", "cpp", None] for f in [enable_x64, disable_x64]))
|
||||
def test_correctly_capture_default(self, jit, enable_or_disable):
|
||||
if jit == "cpp" and not config.omnistaging_enabled:
|
||||
self.skipTest("cpp_jit requires omnistaging")
|
||||
|
||||
# The fact we defined a jitted function with a block with a different value
|
||||
# of `config.enable_x64` has no impact on the output.
|
||||
with enable_or_disable():
|
||||
@ -87,8 +82,6 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
def test_near_singular_inverse(self, jit):
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.skipTest("64-bit inverse not available on TPU")
|
||||
if jit == "cpp" and not config.omnistaging_enabled:
|
||||
self.skipTest("cpp_jit requires omnistaging")
|
||||
@partial(_maybe_jit, jit, static_argnums=1)
|
||||
def near_singular_inverse(key, N, eps):
|
||||
X = random.uniform(key, (N, N))
|
||||
@ -111,8 +104,6 @@ class X64ContextTests(jtu.JaxTestCase):
|
||||
{"testcase_name": "_jit={}".format(jit), "jit": jit}
|
||||
for jit in ["python", "cpp", None]))
|
||||
def test_while_loop(self, jit):
|
||||
if jit == "cpp" and not config.omnistaging_enabled:
|
||||
self.skipTest("cpp_jit requires omnistaging")
|
||||
@partial(_maybe_jit, jit)
|
||||
def count_to(N):
|
||||
return lax.while_loop(lambda x: x < N, lambda x: x + 1.0, 0.0)
|
||||
|
@ -235,10 +235,7 @@ def schedules(sizes: Dict[str, int]
|
||||
|
||||
|
||||
class XMapTestCase(jtu.BufferDonationTestCase):
|
||||
def setUp(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
super().setUp()
|
||||
pass
|
||||
|
||||
|
||||
# A mixin that enables SPMD lowering tests
|
||||
@ -689,9 +686,6 @@ class NamedNNTest(XMapTestCase):
|
||||
|
||||
|
||||
class NewPrimitiveTest(XMapTestCase):
|
||||
def setUp(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
|
||||
def testGatherPositional(self):
|
||||
x = jnp.arange(27).reshape((9, 3))
|
||||
@ -1031,11 +1025,6 @@ class PDotTests(XMapTestCase):
|
||||
|
||||
class XMapErrorTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise SkipTest("xmap requires omnistaging")
|
||||
super().setUp()
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('x', 2)])
|
||||
def testRepeatedAxisResource(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user