mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove code to support jaxlib <= 0.4.33.
This commit is contained in:
parent
c0240764bc
commit
d3f63a66b8
@ -62,7 +62,6 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -3055,14 +3054,9 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
fastpath_data = None
|
||||
return outs, fastpath_data, False # Do not remove cache entry
|
||||
|
||||
if xla_extension_version >= 286:
|
||||
return xc._xla.pjit(
|
||||
self.unsafe_call.name, None, aot_cache_miss, [], [],
|
||||
JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
|
||||
else:
|
||||
return xc._xla.pjit(
|
||||
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.dispatch_registry, cc_shard_arg)
|
||||
return xc._xla.pjit(
|
||||
self.unsafe_call.name, None, aot_cache_miss, [], [],
|
||||
JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
|
||||
|
||||
def cc_shard_arg(x, sharding, layout):
|
||||
return shard_args([sharding], [layout], [x])[0]
|
||||
|
@ -47,7 +47,6 @@ from jax._src.lax.lax import (
|
||||
from jax._src.lib import gpu_solver
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -709,8 +708,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
|
||||
out_aval = ctx.avals_out[0]
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
|
||||
ctx_args = (ctx,)
|
||||
w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand,
|
||||
w, vl, vr, info = lapack.geev_hlo(ctx, operand_aval.dtype, operand,
|
||||
input_shape_vals=op_shape_vals,
|
||||
jobvl=compute_left_eigenvectors,
|
||||
jobvr=compute_right_eigenvectors)
|
||||
@ -2033,8 +2031,7 @@ def _svd_cpu_gpu_lowering(
|
||||
compute_uv=compute_uv)
|
||||
else:
|
||||
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
|
||||
ctx_args = (ctx,)
|
||||
s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand,
|
||||
s, u, vt, info = gesvd_impl(ctx, operand_aval.dtype, operand,
|
||||
full_matrices=full_matrices,
|
||||
compute_uv=compute_uv,
|
||||
a_shape_vals=a_shape_vals)
|
||||
@ -2477,9 +2474,7 @@ batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
|
||||
def _hessenberg_cpu_hlo(ctx, a):
|
||||
a_aval, = ctx.avals_in
|
||||
batch_dims = a_aval.shape[:-2]
|
||||
# TODO(b/344892332): Remove the conditional after the compatibility period.
|
||||
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 34) else ()
|
||||
a, taus, info = lapack.gehrd_hlo(*ctx_args, a_aval.dtype, a)
|
||||
a, taus, info = lapack.gehrd_hlo(ctx, a_aval.dtype, a)
|
||||
ok = mlir.compare_hlo(
|
||||
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
"EQ", "SIGNED")
|
||||
|
107
jax/_src/pjit.py
107
jax/_src/pjit.py
@ -62,7 +62,6 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import sharding
|
||||
from jax._src.mesh import AbstractMesh
|
||||
from jax._src.sharding_impls import (
|
||||
@ -322,28 +321,11 @@ _cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)
|
||||
_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)
|
||||
|
||||
|
||||
if xla_extension_version < 286:
|
||||
def _get_cpp_global_cache(pjit_has_explicit_sharding):
|
||||
if pjit_has_explicit_sharding:
|
||||
return xc._xla.PjitFunctionCache()
|
||||
else:
|
||||
return _cpp_pjit_cache_fun_only
|
||||
|
||||
def _pjit_explicit_sharding_and_layout(
|
||||
in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
|
||||
device, backend) -> bool:
|
||||
return (device is not None or
|
||||
backend is not None or
|
||||
any(not is_unspecified(i) for i in in_shardings_flat) or
|
||||
any(not is_unspecified(o) for o in out_shardings_flat) or
|
||||
any(i is not None for i in in_layouts_flat) or
|
||||
any(o is not None for o in out_layouts_flat))
|
||||
else:
|
||||
def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore
|
||||
if contains_explicit_attributes:
|
||||
return _cpp_pjit_cache_explicit_attributes
|
||||
else:
|
||||
return _cpp_pjit_cache_fun_only
|
||||
def _get_cpp_global_cache(contains_explicit_attributes: bool):
|
||||
if contains_explicit_attributes:
|
||||
return _cpp_pjit_cache_explicit_attributes
|
||||
else:
|
||||
return _cpp_pjit_cache_fun_only
|
||||
|
||||
|
||||
def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
|
||||
@ -364,35 +346,24 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
|
||||
|
||||
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
|
||||
|
||||
if xla_extension_version >= 286:
|
||||
cache_key = pxla.JitGlobalCppCacheKeys(
|
||||
donate_argnums=jit_info.donate_argnums,
|
||||
donate_argnames=jit_info.donate_argnames,
|
||||
device=jit_info.device, backend=jit_info.backend,
|
||||
in_shardings_treedef=jit_info.in_shardings_treedef,
|
||||
in_shardings_leaves=jit_info.in_shardings_leaves,
|
||||
out_shardings_treedef=jit_info.out_shardings_treedef,
|
||||
out_shardings_leaves=jit_info.out_shardings_leaves,
|
||||
in_layouts_treedef=jit_info.in_layouts_treedef,
|
||||
in_layouts_leaves=jit_info.in_layouts_leaves,
|
||||
out_layouts_treedef=jit_info.out_layouts_treedef,
|
||||
out_layouts_leaves=jit_info.out_layouts_leaves,
|
||||
use_resource_env=jit_info.use_resource_env)
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
|
||||
jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
|
||||
pxla.cc_shard_arg,
|
||||
_get_cpp_global_cache(cache_key.contains_explicit_attributes))
|
||||
else:
|
||||
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
|
||||
jit_info.in_shardings_leaves, jit_info.out_shardings_leaves,
|
||||
jit_info.in_layouts_leaves, jit_info.out_layouts_leaves,
|
||||
jit_info.device, jit_info.backend)
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
|
||||
jit_info.static_argnames, jit_info.donate_argnums,
|
||||
tree_util.dispatch_registry, pxla.cc_shard_arg,
|
||||
_get_cpp_global_cache(has_explicit_sharding))
|
||||
cache_key = pxla.JitGlobalCppCacheKeys(
|
||||
donate_argnums=jit_info.donate_argnums,
|
||||
donate_argnames=jit_info.donate_argnames,
|
||||
device=jit_info.device, backend=jit_info.backend,
|
||||
in_shardings_treedef=jit_info.in_shardings_treedef,
|
||||
in_shardings_leaves=jit_info.in_shardings_leaves,
|
||||
out_shardings_treedef=jit_info.out_shardings_treedef,
|
||||
out_shardings_leaves=jit_info.out_shardings_leaves,
|
||||
in_layouts_treedef=jit_info.in_layouts_treedef,
|
||||
in_layouts_leaves=jit_info.in_layouts_leaves,
|
||||
out_layouts_treedef=jit_info.out_layouts_treedef,
|
||||
out_layouts_leaves=jit_info.out_layouts_leaves,
|
||||
use_resource_env=jit_info.use_resource_env)
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
|
||||
jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
|
||||
pxla.cc_shard_arg,
|
||||
_get_cpp_global_cache(cache_key.contains_explicit_attributes))
|
||||
|
||||
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
||||
cpp_pjitted_f._fun = fun
|
||||
@ -1752,26 +1723,18 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
resource_env, donated_invars, name, keep_unused, inline)
|
||||
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
|
||||
if xla_extension_version >= 286:
|
||||
cache_key = pxla.JitGlobalCppCacheKeys(
|
||||
donate_argnums=donated_argnums, donate_argnames=None,
|
||||
device=None, backend=None,
|
||||
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
|
||||
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
|
||||
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
|
||||
out_layouts_treedef=None, out_layouts_leaves=out_layouts,
|
||||
use_resource_env=resource_env is not None)
|
||||
return xc._xla.pjit(
|
||||
name, f, call_impl_cache_miss, [], [], cache_key,
|
||||
tree_util.dispatch_registry, pxla.cc_shard_arg,
|
||||
_get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
|
||||
else:
|
||||
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
|
||||
in_shardings, out_shardings, in_layouts, out_layouts, None, None)
|
||||
return xc._xla.pjit(
|
||||
name, f, call_impl_cache_miss, [], [], donated_argnums,
|
||||
tree_util.dispatch_registry, pxla.cc_shard_arg,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
cache_key = pxla.JitGlobalCppCacheKeys(
|
||||
donate_argnums=donated_argnums, donate_argnames=None,
|
||||
device=None, backend=None,
|
||||
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
|
||||
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
|
||||
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
|
||||
out_layouts_treedef=None, out_layouts_leaves=out_layouts,
|
||||
use_resource_env=resource_env is not None)
|
||||
return xc._xla.pjit(
|
||||
name, f, call_impl_cache_miss, [], [], cache_key,
|
||||
tree_util.dispatch_registry, pxla.cc_shard_arg,
|
||||
_get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
|
||||
|
||||
pjit_p.def_impl(_pjit_call_impl)
|
||||
|
||||
|
@ -536,7 +536,6 @@ from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
@ -1079,117 +1078,6 @@ def _outside_call_impl(*args, **params):
|
||||
outside_call_p.def_impl(_outside_call_impl)
|
||||
|
||||
|
||||
def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
|
||||
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
|
||||
builder.set_sharding(sharding_proto)
|
||||
try:
|
||||
return op_fn(*args, **kwargs)
|
||||
finally:
|
||||
builder.clear_sharding()
|
||||
|
||||
def _outside_call_translation_rule(ctx,
|
||||
avals_in,
|
||||
avals_out,
|
||||
*args_op: XlaOp,
|
||||
has_token,
|
||||
identity,
|
||||
device_index,
|
||||
flat_results_aval=(),
|
||||
**params):
|
||||
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
|
||||
assert has_token
|
||||
use_outfeed = _use_outfeed(ctx.platform)
|
||||
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
|
||||
current_token = args_op[-2]
|
||||
current_itoken = args_op[-1]
|
||||
comp = ctx.builder
|
||||
assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), (
|
||||
"The last two arguments must be tokens")
|
||||
|
||||
args_to_outfeed = args_op[:-2]
|
||||
# Some platforms refuse to infeed empty arrays. We generate constants
|
||||
# instead.
|
||||
non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)),
|
||||
flat_results_aval))
|
||||
need_callback_results_on_device = (not identity and
|
||||
len(non_empty_flat_results_aval) > 0)
|
||||
send_infeed = use_outfeed and need_callback_results_on_device
|
||||
generated_infeed = False # Keep track if we emitted an infeed op
|
||||
|
||||
_raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
|
||||
callback_id = _register_callback(
|
||||
functools.partial(
|
||||
_outside_call_run_callback,
|
||||
send_infeed=send_infeed,
|
||||
identity=identity,
|
||||
flat_results_aval=flat_results_aval,
|
||||
**params))
|
||||
next_token = _callback_handler_data.receiver.add_outfeed(
|
||||
comp, current_token, callback_id, args_to_outfeed, device_index)
|
||||
if identity:
|
||||
results = list(args_to_outfeed)
|
||||
next_itoken = current_itoken
|
||||
else:
|
||||
empty_results = [
|
||||
xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
|
||||
for aval in flat_results_aval
|
||||
if _aval_is_empty(aval)
|
||||
]
|
||||
if non_empty_flat_results_aval:
|
||||
assert need_callback_results_on_device
|
||||
after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
|
||||
# We shard the infeed as AssignedDevice(device_index). This must match the
|
||||
# outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
|
||||
# this kind of sharding, we use a custom translation for infeed.
|
||||
array_sharding_proto = xla_client.OpSharding()
|
||||
array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
|
||||
array_sharding_proto.tile_assignment_dimensions = [1]
|
||||
array_sharding_proto.tile_assignment_devices = [device_index]
|
||||
|
||||
token_sharding_proto = xla_client.OpSharding()
|
||||
token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
|
||||
infeed_sharding_proto = xla.tuple_sharding_proto(
|
||||
[array_sharding_proto] * len(non_empty_flat_results_aval) +
|
||||
[token_sharding_proto])
|
||||
|
||||
shape = [
|
||||
shape.with_major_to_minor_layout_if_absent()
|
||||
for x in non_empty_flat_results_aval
|
||||
for shape in xla.aval_to_xla_shapes(x)
|
||||
]
|
||||
|
||||
build_infeed = functools.partial(xops.InfeedWithToken,
|
||||
after_outfeed_itoken,
|
||||
xla_client.Shape.tuple_shape(shape))
|
||||
outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
|
||||
build_infeed)
|
||||
outs = xops.GetTupleElement(outs_and_token, 0)
|
||||
next_itoken = xops.GetTupleElement(outs_and_token, 1)
|
||||
non_empty_results = [
|
||||
xops.GetTupleElement(outs, i)
|
||||
for i in range(len(non_empty_flat_results_aval))
|
||||
]
|
||||
generated_infeed = True
|
||||
results = [
|
||||
empty_results.pop(0)
|
||||
if _aval_is_empty(result_aval) else non_empty_results.pop(0)
|
||||
for result_aval in flat_results_aval
|
||||
]
|
||||
else:
|
||||
results = empty_results
|
||||
next_itoken = current_itoken
|
||||
|
||||
assert generated_infeed == send_infeed, (
|
||||
f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")
|
||||
assert identity or len(results) == len(flat_results_aval), (
|
||||
f"got {len(results)} but expected {len(flat_results_aval)}. "
|
||||
f"identity = {identity}")
|
||||
return results + [next_token, next_itoken]
|
||||
|
||||
if xla_extension_version < 287:
|
||||
xla.register_translation(outside_call_p, _outside_call_translation_rule)
|
||||
|
||||
|
||||
def _outside_call_outfeed_lowering(ctx: mlir.LoweringRuleContext,
|
||||
*args_op,
|
||||
identity,
|
||||
@ -1318,25 +1206,14 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
|
||||
platform = ctx.module_context.platforms[0]
|
||||
use_outfeed = _use_outfeed(platform)
|
||||
if use_outfeed:
|
||||
if xla_extension_version < 287:
|
||||
return mlir.xla_fallback_lowering(outside_call_p)(
|
||||
ctx,
|
||||
*args,
|
||||
has_token=has_token,
|
||||
identity=identity,
|
||||
device_index=device_index,
|
||||
flat_results_aval=flat_results_aval,
|
||||
**params,
|
||||
)
|
||||
else:
|
||||
return _outside_call_outfeed_lowering(
|
||||
ctx, *args,
|
||||
has_token=has_token,
|
||||
identity=identity,
|
||||
flat_results_aval=flat_results_aval,
|
||||
device_index=device_index,
|
||||
**params,
|
||||
)
|
||||
return _outside_call_outfeed_lowering(
|
||||
ctx, *args,
|
||||
has_token=has_token,
|
||||
identity=identity,
|
||||
flat_results_aval=flat_results_aval,
|
||||
device_index=device_index,
|
||||
**params,
|
||||
)
|
||||
else:
|
||||
# TODO(necula): It seems that on CPU, with custom call, the device_index
|
||||
# does not work, and the callback is always run on device_index=0
|
||||
@ -1405,10 +1282,7 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
|
||||
f"identity = {identity}")
|
||||
return list(results) + [next_token, next_itoken]
|
||||
|
||||
if xla_extension_version < 287:
|
||||
mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")
|
||||
else:
|
||||
mlir.register_lowering(outside_call_p, _outside_call_lowering)
|
||||
mlir.register_lowering(outside_call_p, _outside_call_lowering)
|
||||
|
||||
def _outside_call_run_callback(
|
||||
arrays, device, *,
|
||||
|
@ -67,7 +67,6 @@ from jax.sharding import PartitionSpec as P
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -705,15 +704,12 @@ class CompatTest(bctu.CompatTestBase):
|
||||
cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name]
|
||||
)
|
||||
self.run_one_test(func, data, rtol=rtol, atol=atol)
|
||||
# TODO(b/344892332): Remove the check after the compatibility period.
|
||||
has_xla_ffi_support = jaxlib_version >= (0, 4, 34)
|
||||
if has_xla_ffi_support:
|
||||
with config.export_ignore_forward_compatibility(True):
|
||||
# FFI Kernel test
|
||||
data = self.load_testdata(
|
||||
cpu_hessenberg_lapack_gehrd.data_2024_08_31[dtype_name]
|
||||
)
|
||||
self.run_one_test(func, data, rtol=rtol, atol=atol)
|
||||
with config.export_ignore_forward_compatibility(True):
|
||||
# FFI Kernel test
|
||||
data = self.load_testdata(
|
||||
cpu_hessenberg_lapack_gehrd.data_2024_08_31[dtype_name]
|
||||
)
|
||||
self.run_one_test(func, data, rtol=rtol, atol=atol)
|
||||
|
||||
def test_approx_top_k(self):
|
||||
def func():
|
||||
|
@ -3975,8 +3975,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
size_im = 11
|
||||
atol = None
|
||||
|
||||
if (name in {"arccos", "arcsin", "arcsinh", "arccosh"}
|
||||
or name in {"arctan", "arctanh"} and jax._src.lib.version > (0, 4, 31)):
|
||||
if name in {"arccos", "arcsin", "arcsinh", "arccosh", "arctan", "arctanh"}:
|
||||
# TODO(pearu): eliminate this if-block when a fix to mpmath#787
|
||||
# becomes available
|
||||
extra_prec_multiplier = 20
|
||||
@ -4132,16 +4131,6 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
elif name == 'arccos':
|
||||
regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real')
|
||||
|
||||
elif name == 'arctan' and jax._src.lib.version <= (0, 4, 31):
|
||||
if dtype == np.complex64:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj',
|
||||
'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.real', 'mnegj.imag', 'mposj.imag')
|
||||
if dtype == np.complex128:
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mnegj.real')
|
||||
|
||||
elif name == 'arctanh' and jax._src.lib.version <= (0, 4, 31):
|
||||
regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
|
||||
|
||||
elif name in {'cos', 'sin'}:
|
||||
regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')
|
||||
|
||||
|
@ -25,7 +25,6 @@ from jax._src import config
|
||||
from jax._src.layout import Layout, DeviceLocalLayout as DLL
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -46,9 +45,6 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
|
||||
def test_auto_layout(self):
|
||||
# Remove this condition when xla_extension_version >= 285
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
|
||||
self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape1 = (128, 128)
|
||||
shape2 = (128, 128)
|
||||
@ -114,9 +110,6 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T)
|
||||
|
||||
def test_default_layout(self):
|
||||
# Remove this condition when xla_extension_version >= 285
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
|
||||
self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (4, 4, 2)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -156,9 +149,6 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
out_shardings=DLL.AUTO).lower(sds).compile()
|
||||
|
||||
def test_in_layouts_out_layouts(self):
|
||||
# Remove this condition when xla_extension_version >= 285
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
|
||||
self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (8, 8)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -183,9 +173,6 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x')))
|
||||
|
||||
def test_sharding_and_layouts(self):
|
||||
# Remove this condition when xla_extension_version >= 285
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
|
||||
self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
shape = (4, 8)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
@ -477,9 +464,6 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
jax.device_put(inp, l)
|
||||
|
||||
def test_concrete_layout_in_shardings(self):
|
||||
# Remove this condition when xla_extension_version >= 285
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285:
|
||||
self.skipTest("Requires xla_extension_version >= 285 for GPU backend.")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
shape = (16, 128)
|
||||
|
@ -35,7 +35,6 @@ from jax._src.sharding_impls import (NamedSharding, PositionalSharding,
|
||||
TransferToMemoryKind, PartitionSpec as P)
|
||||
from jax.experimental.compute_on import compute_on
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax._src.lib import xla_extension_version
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -416,8 +415,6 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
out, np_inp * np_inp, s_dev, "device")
|
||||
|
||||
def test_parameter_streaming(self):
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
|
||||
self.skipTest("Requires xla_extension_version >= 289")
|
||||
_, s_host, np_inp, inp_host = _create_inputs(
|
||||
(8, 2), P("x", "y"), mem_kind="pinned_host")
|
||||
s_dev = s_host.with_memory_kind('device')
|
||||
@ -461,8 +458,6 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
out, np_inp, s_host, 'pinned_host')
|
||||
|
||||
def test_parameter_streaming_with_scalar_and_constant(self):
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
|
||||
self.skipTest("Requires xla_extension_version >= 289")
|
||||
mesh = jtu.create_mesh((2, 2), ("x", "y"))
|
||||
scalar_inp = 1
|
||||
s_host = NamedSharding(mesh, P(), memory_kind="pinned_host")
|
||||
@ -512,8 +507,6 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
def test_parameter_and_output_streaming_with_scalar(self):
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
|
||||
self.skipTest("Requires xla_extension_version >= 289")
|
||||
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
|
||||
self.skipTest("This test requires an xla_version >= 2.")
|
||||
|
||||
@ -581,8 +574,6 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out_hbm.sharding, out_s)
|
||||
|
||||
def test_output_streaming(self):
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
|
||||
self.skipTest("Requires xla_extension_version >= 289")
|
||||
mesh = jtu.create_mesh((1, 1), ("x", "y"))
|
||||
np_inp = np.arange(16.0).reshape(8, 2)
|
||||
s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device")
|
||||
@ -599,8 +590,6 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out_host.sharding, s_host)
|
||||
|
||||
def test_weight_offload_with_dp_on_output(self):
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
|
||||
self.skipTest("Requires xla_extension_version >= 289")
|
||||
_, s_dev, np_inp, inp_dev = _create_inputs(
|
||||
(8, 2), P("x", "y"), mem_kind="device")
|
||||
s_host = s_dev.with_memory_kind('pinned_host')
|
||||
@ -616,8 +605,6 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
out_host, np_inp * 2, s_host, 'pinned_host')
|
||||
|
||||
def test_output_streaming_inside_scan(self):
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
|
||||
self.skipTest("Requires xla_extension_version >= 289")
|
||||
if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2:
|
||||
self.skipTest("This test requires an xla_version >= 2.")
|
||||
mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z"))
|
||||
@ -650,8 +637,6 @@ class DevicePutTest(jtu.JaxTestCase):
|
||||
self.assertEqual(t.shape, t_copy.shape)
|
||||
|
||||
def test_close_over_host_constant_and_stream(self):
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
|
||||
self.skipTest("Requires xla_extension_version >= 289")
|
||||
|
||||
_, s_host, np_inp, inp_host = _create_inputs(
|
||||
(8, 2), P("x", "y"), mem_kind="pinned_host")
|
||||
@ -1562,8 +1547,6 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
|
||||
|
||||
def test_remat_scan_layout_change_offloadable(self):
|
||||
if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289:
|
||||
self.skipTest("Requires xla_extension_version >= 289")
|
||||
mesh = jtu.create_mesh((2,), ("x",))
|
||||
shape = (256, 128)
|
||||
np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
|
||||
|
@ -59,7 +59,6 @@ from jax._src.lib.mlir import dialects
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import curry, unzip2
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -661,10 +660,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
jax.grad(f)(x) # Warm up the cache.
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
jax.grad(f)(x)
|
||||
if xla_extension_version >= 286:
|
||||
self.assertEqual(count[0], 0) # no cache miss i.e. cache hit
|
||||
else:
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count[0], 0) # no cache miss i.e. cache hit
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testEvalJaxpr(self):
|
||||
@ -4590,8 +4586,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
' match the mesh shape of the target sharding.*'):
|
||||
with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y')))
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 286,
|
||||
"Requires xla_extension_version >= 286")
|
||||
def test_global_jit_cpp_cache_hit_out_shardings(self):
|
||||
mesh = jtu.create_mesh((2,), 'x')
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
|
@ -24,7 +24,6 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import flatten_util
|
||||
from jax import tree_util
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.tree_util import flatten_one_level, prefix_errors
|
||||
import jax.numpy as jnp
|
||||
@ -485,10 +484,8 @@ class TreeTest(jtu.JaxTestCase):
|
||||
[([1], (2,), {"a": [1]})],
|
||||
re.escape("Custom node type mismatch"),
|
||||
),
|
||||
*(
|
||||
[]
|
||||
if xla_extension_version < 288
|
||||
else [(None, [2], re.escape("Expected None, got [2]."))]
|
||||
(
|
||||
(None, [2], re.escape("Expected None, got [2]."))
|
||||
),
|
||||
)
|
||||
def testFlattenUpToErrors(self, tree, xs, error):
|
||||
|
Loading…
x
Reference in New Issue
Block a user