mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Increase minimum jaxlib version to 0.1.74.
This commit is contained in:
parent
52421da0f6
commit
3fd3c46f20
@ -109,8 +109,7 @@ flags.DEFINE_bool(
|
||||
"Set this to `False` only if it crashes otherwise and report "
|
||||
"the error to the jax-team.")
|
||||
flags.DEFINE_bool(
|
||||
"experimental_cpp_pmap",
|
||||
bool_env("JAX_CPP_PMAP", jax._src.lib._xla_extension_version >= 39),
|
||||
"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
|
||||
"A flag enabling the C++ jax.pmap fast path. Until the default "
|
||||
"is switched to True, the feature is not supported and possibly broken "
|
||||
"(e.g. it may use unreleased code from jaxlib.")
|
||||
@ -2014,8 +2013,7 @@ def _cpp_pmap(
|
||||
|
||||
return out, fastpath_data
|
||||
|
||||
# TODO(slebedev): Remove the ignore once jaxlib>=0.1.71.
|
||||
cpp_mapped_f = pmap_lib.pmap(fun, cache_miss, # type: ignore[call-arg]
|
||||
cpp_mapped_f = pmap_lib.pmap(fun, cache_miss,
|
||||
static_broadcasted_tuple, pxla._shard_arg)
|
||||
|
||||
f_pmapped = wraps(fun)(cpp_mapped_f)
|
||||
|
@ -55,7 +55,6 @@ import jax._src.lib
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
|
||||
xb = xla_bridge
|
||||
xc = xla_client
|
||||
@ -2801,21 +2800,7 @@ ad.defjvp2(rsqrt_p,
|
||||
lambda g, ans, x:
|
||||
mul(g, mul(_const(x, -0.5), div(ans, x))))
|
||||
|
||||
# TODO(phawkins): remove the fallback translation rule after the minimum jaxlib
|
||||
# is 0.1.70 or newer.
|
||||
if jax._src.lib._xla_extension_version >= 28:
|
||||
_cbrt_translation_rule = None
|
||||
else:
|
||||
def _cbrt_translation_rule(ctx, avals_in, avals_out, x):
|
||||
x_aval, = avals_in
|
||||
return [xops.Mul(
|
||||
xops.Sign(x),
|
||||
xops.Pow(xops.Abs(x),
|
||||
xla.pyval_to_ir_constant(ctx.builder,
|
||||
np.array(1/3, dtype=x_aval.dtype))))]
|
||||
|
||||
cbrt_p = standard_unop(_float, 'cbrt',
|
||||
translation_rule=_cbrt_translation_rule)
|
||||
cbrt_p = standard_unop(_float, 'cbrt')
|
||||
ad.defjvp2(cbrt_p,
|
||||
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
|
||||
|
||||
@ -5900,15 +5885,6 @@ def _compute_argminmax(value_comparator, get_identity,
|
||||
axes)
|
||||
return res[1]
|
||||
|
||||
def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
|
||||
axis, = axes
|
||||
idxs = tie_in(a, broadcasted_iota(index_dtype, a.shape, axis))
|
||||
maxval = np.array(dtypes.iinfo(index_dtype).max, dtype=index_dtype)
|
||||
maxval = broadcast(tie_in(a, maxval), a.shape)
|
||||
maxvals = expand_dims(op(a, (axis,)), (axis,))
|
||||
mask_idxs = select(eq(a, maxvals) | ne(a, a), idxs, maxval)
|
||||
return _reduce_min(mask_idxs, (axis,))
|
||||
|
||||
_argmin_translation_rule = xla.lower_fun(
|
||||
partial(_compute_argminmax, lt, _get_min_identity),
|
||||
multiple_results=False, new_style=True)
|
||||
@ -5922,28 +5898,12 @@ argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
|
||||
weak_type_rule=_strip_weak_type)
|
||||
batching.defreducer(argmin_p)
|
||||
ad.defjvp_zero(argmin_p)
|
||||
if jax._src.lib._xla_extension_version < 41:
|
||||
xla.register_translation(
|
||||
argmin_p,
|
||||
xla.lower_fun(
|
||||
partial(_argminmax_gpu_translation_rule, _reduce_min),
|
||||
multiple_results=False,
|
||||
new_style=True),
|
||||
platform='gpu')
|
||||
|
||||
argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
|
||||
'argmax', _argmax_translation_rule,
|
||||
weak_type_rule=_strip_weak_type)
|
||||
batching.defreducer(argmax_p)
|
||||
ad.defjvp_zero(argmax_p)
|
||||
if jax._src.lib._xla_extension_version < 41:
|
||||
xla.register_translation(
|
||||
argmax_p,
|
||||
xla.lower_fun(
|
||||
partial(_argminmax_gpu_translation_rule, _reduce_max),
|
||||
multiple_results=False,
|
||||
new_style=True),
|
||||
platform='gpu')
|
||||
|
||||
|
||||
def _reduce_logical_shape_rule(operand, *, axes):
|
||||
@ -6927,7 +6887,6 @@ def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
|
||||
def _rng_bit_generator_translation_rule(
|
||||
ctx, avals_in, avals_out, key, *, shape, dtype, algorithm):
|
||||
c = ctx.builder
|
||||
backend_is_gpu = ctx.platform == "gpu"
|
||||
key_shape, key_dtype = c.get_shape(key).dimensions(), c.get_shape(key).numpy_dtype()
|
||||
# While the RngBitGenerator HLO accepts a u64[2] key on all backends, we
|
||||
# typically represent the key argument to this primitive as a u32[4] so as to
|
||||
@ -6938,48 +6897,15 @@ def _rng_bit_generator_translation_rule(
|
||||
(key_shape == (2,) and key_dtype == np.dtype('uint64'))), (key_shape, key_dtype)
|
||||
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
|
||||
if key_dtype == np.dtype('uint32'):
|
||||
# TODO(mattjj): the BitcastConvertType segfaults on GPU
|
||||
# TODO(mattjj): remove fallback when minimum jaxlib is 0.1.72 or newer
|
||||
if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
|
||||
u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
|
||||
key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
|
||||
else:
|
||||
key = _convert_4xU32_to_2xU64_without_bitcast(c, key)
|
||||
u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
|
||||
key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
|
||||
out_key, out_vals = xla.xla_destructure(
|
||||
c, xops.RngBitGenerator(algorithm, key, xla_shape))
|
||||
if key_dtype == np.dtype('uint32'):
|
||||
if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
|
||||
u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
|
||||
out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
|
||||
else:
|
||||
out_key = _convert_2xU64_to_4xU32_without_bitcast(c, out_key)
|
||||
u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
|
||||
out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
|
||||
return [out_key, out_vals]
|
||||
|
||||
def _convert_4xU32_to_2xU64_without_bitcast(c, key):
|
||||
u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
|
||||
new_key = xops.Constant(c, np.zeros(2, dtype=np.dtype('uint64')))
|
||||
_32 = xops.Constant(c, np.array(32, np.uint64))
|
||||
for i in [0, 2]:
|
||||
hi = xops.ConvertElementType(xops.Slice(key, [i] , [i+1], [1]), u64_etype)
|
||||
lo = xops.ConvertElementType(xops.Slice(key, [i+1], [i+2], [1]), u64_etype)
|
||||
elt = xops.Xor(xops.ShiftLeft(hi, _32), lo)
|
||||
new_key = xops.DynamicUpdateSlice(new_key, elt,
|
||||
[xla.pyval_to_ir_constant(c, i // 2)])
|
||||
return new_key
|
||||
|
||||
def _convert_2xU64_to_4xU32_without_bitcast(c, key):
|
||||
u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
|
||||
new_key = xops.Constant(c, np.zeros(4, dtype=np.dtype('uint32')))
|
||||
_32 = xops.Constant(c, np.array(32, np.uint64))
|
||||
for i in [0, 1]:
|
||||
elt = xops.Slice(key, [i], [i+1], [1])
|
||||
hi = xops.ConvertElementType(xops.ShiftRightLogical(elt, _32), u32_etype)
|
||||
lo = xops.ConvertElementType(elt, u32_etype)
|
||||
new_key = xops.DynamicUpdateSlice(new_key, hi,
|
||||
[xla.pyval_to_ir_constant(c, 2 * i)])
|
||||
new_key = xops.DynamicUpdateSlice(new_key, lo,
|
||||
[xla.pyval_to_ir_constant(c, 2 * i + 1)])
|
||||
return new_key
|
||||
|
||||
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
|
||||
return [key.named_shape, key.named_shape]
|
||||
|
@ -41,7 +41,6 @@ from jax._src.lib import cusparse
|
||||
from jax._src.lib import rocsolver
|
||||
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
|
||||
xops = xla_client.ops
|
||||
|
||||
@ -1538,11 +1537,6 @@ def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
c = ctx.builder
|
||||
|
||||
if jaxlib_version < (0, 1, 72):
|
||||
raise NotImplementedError(
|
||||
"The Schur primitive is only implemented for jaxlib versions >= 0.1.72"
|
||||
)
|
||||
|
||||
_cpu_gees = lapack.gees
|
||||
|
||||
if sort_eig_vals:
|
||||
|
@ -120,10 +120,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
hash_obj.update(compile_options_obj.device_assignment.serialize())
|
||||
|
||||
def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
if jax._src.lib.version >= (0, 1, 72):
|
||||
expected_options = 31
|
||||
else:
|
||||
expected_options = 30
|
||||
expected_options = 31
|
||||
assert len(dir(executable_obj)) == expected_options, (
|
||||
f"Unexpected number of executable_build_options fields: "
|
||||
f"{len(dir(executable_obj))}. This likely means that an extra "
|
||||
@ -136,8 +133,7 @@ def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
if executable_obj.device_assignment is not None:
|
||||
hash_obj.update(executable_obj.device_assignment.serialize())
|
||||
_hash_bool(hash_obj, executable_obj.use_spmd_partitioning)
|
||||
if jax._src.lib.version >= (0, 1, 72):
|
||||
_hash_bool(hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output)
|
||||
_hash_bool(hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output)
|
||||
|
||||
def _hash_debug_options(hash_obj, debug_obj):
|
||||
_hash_bool(hash_obj, debug_obj.xla_cpu_enable_fast_math)
|
||||
|
@ -1674,20 +1674,13 @@ translations[lax.rng_uniform_p] = _rng_uniform_lowering
|
||||
# xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
|
||||
# if key_dtype == np.dtype('uint32'):
|
||||
# # TODO(mattjj): the BitcastConvertType segfaults on GPU
|
||||
# # TODO(mattjj): remove fallback when minimum jaxlib is 0.1.72 or newer
|
||||
# if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
|
||||
# u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
|
||||
# key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
|
||||
# else:
|
||||
# key = _convert_4xU32_to_2xU64_without_bitcast(c, key)
|
||||
# u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
|
||||
# key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
|
||||
# out_key, out_vals = xla.xla_destructure(
|
||||
# c, xops.RngBitGenerator(algorithm, key, xla_shape))
|
||||
# if key_dtype == np.dtype('uint32'):
|
||||
# if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
|
||||
# u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
|
||||
# out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
|
||||
# else:
|
||||
# out_key = _convert_2xU64_to_4xU32_without_bitcast(c, out_key)
|
||||
# u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
|
||||
# out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
|
||||
# return [out_key, out_vals]
|
||||
|
||||
|
||||
|
@ -34,14 +34,13 @@ from functools import partial
|
||||
import itertools as it
|
||||
import operator as op
|
||||
import threading
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional,
|
||||
from typing import (Any, Callable, Dict, List, Optional,
|
||||
Sequence, Set, Tuple, Type, Union, Iterable)
|
||||
import sys
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from .._src.config import config
|
||||
from .. import core
|
||||
from .. import linear_util as lu
|
||||
@ -55,7 +54,6 @@ from ..errors import JAXTypeError
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.lib import _xla_extension_version
|
||||
from ..tree_util import tree_flatten, tree_map
|
||||
from . import batching
|
||||
from . import partial_eval as pe
|
||||
@ -97,9 +95,7 @@ MeshDimAssignment = Union[ShardedAxis, Replicated]
|
||||
# mypy will consider this constant to be True at type check time.
|
||||
MYPY = False
|
||||
|
||||
# TODO(jblespiau): Remove the version check when jaxlib 0.1.70 is the minimal
|
||||
# version.
|
||||
if MYPY or (not TYPE_CHECKING and _xla_extension_version < 30):
|
||||
if MYPY:
|
||||
class ShardingSpec:
|
||||
"""Describes the sharding of an ndarray.
|
||||
|
||||
@ -503,8 +499,8 @@ def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
|
||||
|
||||
### lazy device-memory persistence and result handling
|
||||
|
||||
# TODO(jblespiau): Remove when jaxlib 0.1.72 is the minimal version.
|
||||
_USE_CPP_SDA = _xla_extension_version >= 38
|
||||
# TODO(jblespiau): Consider removing this option.
|
||||
_USE_CPP_SDA = True
|
||||
|
||||
|
||||
def make_sharded_device_array(
|
||||
@ -539,9 +535,8 @@ def make_sharded_device_array(
|
||||
if (_USE_CPP_SDA and
|
||||
(not device_buffers or
|
||||
isinstance(device_buffers[0], xb.xla_client.Buffer))):
|
||||
# TODO(slebedev): Remove the ignore once jaxlib>=0.1.71.
|
||||
return pmap_lib.ShardedDeviceArray.make(
|
||||
aval, sharding_spec, device_buffers, # type: ignore[arg-type, call-arg]
|
||||
aval, sharding_spec, device_buffers,
|
||||
indices, aval.weak_type)
|
||||
|
||||
return _ShardedDeviceArray(aval, sharding_spec, device_buffers, indices)
|
||||
@ -1783,12 +1778,8 @@ class MeshExecutable:
|
||||
use_spmd_partitioning=spmd_lowering,
|
||||
)
|
||||
compile_options.parameter_is_tupled_arguments = tuple_args
|
||||
if jax._src.lib.version >= (0, 1, 72):
|
||||
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
|
||||
_allow_propagation_to_outputs
|
||||
elif _allow_propagation_to_outputs:
|
||||
raise RuntimeError("Propagation of SPMD sharding specs to outputs is only supported "
|
||||
"in jaxlib 0.1.72+. Please update your JAX version.")
|
||||
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
|
||||
_allow_propagation_to_outputs
|
||||
|
||||
local_sharding_spec = mesh_sharding_specs(local_axis_sizes, mesh.axis_names)
|
||||
local_input_specs = [local_sharding_spec(aval, aval_in_axes)
|
||||
|
@ -13,4 +13,4 @@
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "0.2.26"
|
||||
_minimum_jaxlib_version = "0.1.69"
|
||||
_minimum_jaxlib_version = "0.1.74"
|
||||
|
@ -82,16 +82,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
# Tensorflow.
|
||||
return api._cpp_jit
|
||||
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version < 40,
|
||||
"Test requires jaxlib 0.1.73")
|
||||
def test_jit_repr(self):
|
||||
def my_function():
|
||||
return
|
||||
jitted = jit(my_function)
|
||||
self.assertEqual(repr(jitted), f"<CompiledFunction of {repr(my_function)}>")
|
||||
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version < 40,
|
||||
"Test requires jaxlib 0.1.73")
|
||||
def test_jit_repr_errors(self):
|
||||
class Callable:
|
||||
def __call__(self): pass
|
||||
@ -692,8 +688,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
np.testing.assert_allclose(f_pruned(*args), 3)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version <= 36,
|
||||
"Test requires jaxlib 0.1.71")
|
||||
def testBuffersAreFreedPromptly(self):
|
||||
# Regression test for a bug where garbage collection was delayed too long
|
||||
# for NumPy buffers that are aliased zero-copy by the runtime.
|
||||
|
@ -97,9 +97,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
f(1)
|
||||
|
||||
def testPmap(self):
|
||||
pmap_funcs = [api._python_pmap]
|
||||
if jax._src.lib._xla_extension_version >= 36:
|
||||
pmap_funcs.append(api._cpp_pmap)
|
||||
pmap_funcs = [api._python_pmap, api._cpp_pmap]
|
||||
|
||||
for pmap in pmap_funcs:
|
||||
f = pmap(lambda x: 0. / x)
|
||||
|
@ -19,7 +19,6 @@ from functools import partial
|
||||
import itertools
|
||||
import typing
|
||||
from typing import Any, Optional, Tuple
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -1241,8 +1240,6 @@ class IndexedUpdateTest(jtu.JaxTestCase):
|
||||
with self.assertNoWarnings():
|
||||
x.at[normalize(idx)].set(0)
|
||||
|
||||
@unittest.skipIf(jax._src.lib.version < (0, 1, 72),
|
||||
"Bug fixed in jaxlib 0.1.72")
|
||||
def testIndexedUpdateAliasingBug(self):
|
||||
# https://github.com/google/jax/issues/7461
|
||||
fn = lambda x: x.at[1:].set(1 + x[:-1])
|
||||
|
@ -1453,8 +1453,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
||||
for dtype in float_types + complex_types))
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testSchur(self, shape, dtype):
|
||||
if jax._src.lib.version < (0, 1, 72):
|
||||
self.skipTest("Schur LAPACK wrapper only implemented for jaxlib versions >= 0.1.72")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
@ -1470,8 +1468,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
|
||||
for dtype in float_types + complex_types))
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testSchurBatching(self, shape, dtype):
|
||||
if jax._src.lib.version < (0, 1, 72):
|
||||
self.skipTest("Schur LAPACK wrapper only implemented for jaxlib versions >= 0.1.72")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
batch_size = 10
|
||||
shape = (batch_size, ) + shape
|
||||
|
@ -35,8 +35,6 @@ config.parse_flags_with_absl()
|
||||
class CloudpickleTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version < 31,
|
||||
"Requires jaxlib 0.1.71")
|
||||
def testPickleOfJittedFunctions(self):
|
||||
|
||||
@jax.jit
|
||||
@ -56,8 +54,6 @@ class CloudpickleTest(jtu.JaxTestCase):
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version < 39,
|
||||
"Requires jaxlib 0.1.72")
|
||||
def testPickleOfPmappedFunctions(self):
|
||||
|
||||
@jax.pmap
|
||||
|
@ -266,7 +266,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertTrue(hasattr(y, "sharding_spec"))
|
||||
|
||||
@check_1d_2d_mesh(set_mesh=True)
|
||||
@unittest.skipIf(jax._src.lib.version < (0, 1, 72), "Needs jaxlib 0.1.72+")
|
||||
def testAutodiff(self, mesh, resources):
|
||||
if len(mesh) != 2: return
|
||||
assert resources == ('x', 'y')
|
||||
|
@ -1821,10 +1821,7 @@ class CppPmapTest(PythonPmapTest):
|
||||
|
||||
@property
|
||||
def pmap(self):
|
||||
if jax._src.lib._xla_extension_version >= 38:
|
||||
return src_api._cpp_pmap
|
||||
else:
|
||||
return src_api._python_pmap
|
||||
return src_api._cpp_pmap
|
||||
|
||||
|
||||
class VmapOfPmapTest(jtu.JaxTestCase):
|
||||
|
@ -20,13 +20,12 @@ import re
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src.tree_util import _process_pytree
|
||||
from jax import flatten_util
|
||||
import jax.numpy as jnp
|
||||
import jax._src.lib
|
||||
|
||||
|
||||
def _dummy_func(*args, **kwargs):
|
||||
return
|
||||
@ -212,9 +211,6 @@ class TreeTest(jtu.JaxTestCase):
|
||||
|
||||
def testTreedefTupleFromChildren(self):
|
||||
# https://github.com/google/jax/issues/7377
|
||||
# TODO(frostig): remove after the minimum jaxlib is is 0.1.70 or newer.
|
||||
if jax._src.lib._xla_extension_version < 29:
|
||||
self.skipTest("fixed in future jaxlib")
|
||||
tree = ((1, 2, (3, 4)), (5,))
|
||||
leaves, treedef1 = tree_util.tree_flatten(tree)
|
||||
treedef2 = tree_util.treedef_tuple(treedef1.children())
|
||||
@ -331,8 +327,6 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertRegex(str(treedef), correct_string)
|
||||
|
||||
def testTreeDefWithEmptyDictStringRepresentation(self):
|
||||
if jax._src.lib._xla_extension_version < 35:
|
||||
self.skipTest("fixed in future jaxlib")
|
||||
self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user