Merge branch 'google:main' into patch-1

This commit is contained in:
Samuel Agyakwa 2023-10-12 12:29:49 -07:00
commit d69e810e90
81 changed files with 927 additions and 620 deletions

View File

@ -72,6 +72,20 @@ build:cuda --@xla//xla/python:enable_gpu=true
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
build:cuda --define=xla_python_enable_gpu=true
# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
# packages.
# This has pros and cons:
# * pro: we'll ignore other CUDA installations, which has frequently confused
# users in the past. By setting RPATH, we'll always use the NVIDIA pip
# packages if they are installed.
# * con: the user cannot override the CUDA installation location
# via LD_LIBRARY_PATH, if the nvidia-... pip packages are installed. This is
# acceptable, because the workaround is "remove the nvidia-..." pip packages.
# The list of CUDA pip packages that JAX depends on are present in setup.py.
build:cuda --linkopt=-Wl,--disable-new-dtags
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --@xla//xla/python:enable_gpu=true

View File

@ -29,6 +29,7 @@ jobs:
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
env:
DATE_12_WEEKS_AGO: $(python3 -c "import datetime; print(datetime.date.today() - datetime.timedelta(weeks=12))")
ENABLE_PJRT_COMPATIBILITY: false
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
timeout-minutes: 60
defaults:
@ -59,6 +60,8 @@ jobs:
pip install requests
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
env:
ENABLE_PJRT_COMPATIBILITY: true
pip install .
pip install --pre jaxlib \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

View File

@ -25,6 +25,13 @@ Remember to align the itemized text with the first line of an item within a list
# jaxlib 0.4.19
* Changes
* jaxlib will now always prefer pip-installed NVIDIA CUDA libraries
(nvidia-... packages) over any other CUDA installation if they are
installed, including installations named in `LD_LIBRARY_PATH`. If this
causes problems and the intent is to use a system-installed CUDA, the fix is
to remove the pip installed CUDA library packages.
# jax 0.4.18 (Oct 6, 2023)
# jaxlib 0.4.18 (Oct 6, 2023)

View File

@ -116,8 +116,7 @@ def _nan_check_posthook(fun, args, kwargs, output):
fun._cache_miss(*args, **kwargs)[0] # probably won't return
def _update_debug_special_global(_):
if (config.config._read("jax_debug_nans") or
config.config._read("jax_debug_infs")):
if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
jax_jit.global_state().post_hook = _nan_check_posthook
else:
jax_jit.global_state().post_hook = None

View File

@ -234,6 +234,7 @@ def pure_callback(
may behave in unexpected ways, particularly under transformation.
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
whose structure matches the expected output of the callback function at runtime.
:class:`jax.ShapeDtypeStruct` is often used to define leaf values.
*args: arguments to be passed to the callback function
sharding: optional sharding that specifies the device from which the callback should
be invoked.
@ -480,6 +481,7 @@ def io_callback(
more efficient execution.
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
whose structure matches the expected output of the callback function at runtime.
:class:`jax.ShapeDtypeStruct` is often used to define leaf values.
*args: arguments to be passed to the callback function
sharding: optional sharding that specifies the device from which the callback should
be invoked.

View File

@ -305,7 +305,7 @@ def compile_or_get_cached(
try:
cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend,
config.config.jax_use_original_compilation_cache_key_generation,
config.use_original_compilation_cache_key_generation.value,
)
except xc._xla.XlaRuntimeError as ex:
logger.error("compile_or_get_cached: unable to generate cache key, "
@ -324,7 +324,7 @@ def compile_or_get_cached(
# TODO(b/293308239) Instrument metrics for new cache savings and cache hit
# rate after it is enabled.
if config.config.jax_use_original_compilation_cache_key_generation:
if config.use_original_compilation_cache_key_generation.value:
# TODO(b/293308239) Remove metrics for the original cache after the new
# compilation cache key implementation is fully rolled out.
monitoring.record_event('/jax/compilation_cache/cache_hits_original')

View File

@ -598,6 +598,16 @@ class NameSpace:
config = Config()
_read = config._read
update = config.update
define_bool_state = config.define_bool_state
define_enum_state = config.define_enum_state
define_int_state = config.define_int_state
define_float_state = config.define_float_state
define_string_state = config.define_string_state
parse_flags_with_absl = config.parse_flags_with_absl
flags = config
FLAGS = flags.FLAGS

View File

@ -63,8 +63,8 @@ class State:
if local_device_ids:
visible_devices = ','.join(str(x) for x in local_device_ids) # type: ignore[union-attr]
logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
config.config.update("jax_cuda_visible_devices", visible_devices)
config.config.update("jax_rocm_visible_devices", visible_devices)
config.update("jax_cuda_visible_devices", visible_devices)
config.update("jax_rocm_visible_devices", visible_devices)
self.process_id = process_id

View File

@ -1184,7 +1184,10 @@ def lower_jaxpr_to_fun(
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
# A lowering context just for function body entry/exit code.
entry_lowering_ctx = LoweringRuleContext(
ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values)
module_context=ctx, primitive=None,
avals_in=[], avals_out=None,
tokens_in=TokenSet.create([]), tokens_out=None,
axis_size_env=None, dim_var_values=dim_var_values)
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = [
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
@ -1414,7 +1417,8 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
tokens_in = tokens.subset(effects)
avals_in = map(aval, eqn.invars)
rule_ctx = LoweringRuleContext(
module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in,
module_context=eqn_ctx, primitive=eqn.primitive,
avals_in=avals_in,
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
tokens_out=None, dim_var_values=dim_var_values)
if config.dynamic_shapes.value:
@ -1973,12 +1977,14 @@ def cache_lowering(f):
The lowering will be emitted out-of-line in a separate function, together with
a call to that function. If the same primitive is called with the same shapes
and parameters, a new call to the original function will be added, without
emitting a new function.
emitting a new function. We allow for different lowering for the same
primitive for different platforms in the same module.
"""
@functools.wraps(f)
def cached_lowering(ctx, *args, **params):
assert ctx.primitive is not None
key = (ctx.primitive, tuple(ctx.avals_in), tuple(ctx.avals_out),
key = (f, ctx.primitive,
tuple(ctx.avals_in), tuple(ctx.avals_out),
tuple(params.items()))
try:
func = ctx.module_context.cached_primitive_lowerings.get(key)
@ -2503,7 +2509,8 @@ def custom_call(
def reduce_window(
ctx: LoweringRuleContext, *,
ctx: LoweringRuleContext,
*,
# Base name to be used for the reducer function
reducer_name: str,
# Compute the reducer body given the reducer.

View File

@ -26,6 +26,7 @@ import numpy as np
from jax import lax
from jax import numpy as jnp
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import effects
@ -40,7 +41,6 @@ from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
_ensure_index_tuple, donation_vector,
shaped_abstractify, check_callable)
from jax._src.array import ArrayImpl
from jax._src.config import config
from jax._src.errors import JAXTypeError
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -670,8 +670,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
name=name),
source_info_util.new_source_info(), resource_env, {})
jaxpr = plan.subst_axes_with_resources(jaxpr)
use_spmd_lowering = config.experimental_xmap_spmd_lowering
ensure_fixed_sharding = config.experimental_xmap_ensure_fixed_sharding
use_spmd_lowering = _SPMD_LOWERING.value
ensure_fixed_sharding = _ENSURE_FIXED_SHARDING.value
if use_spmd_lowering and ensure_fixed_sharding:
jaxpr = _fix_inferred_spmd_sharding(jaxpr, resource_env)
@ -686,7 +686,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
mesh = resource_env.physical_mesh
tiling_method: pxla.TilingMethod
if config.experimental_xmap_spmd_lowering_manual:
if _SPMD_LOWERING_MANUAL.value:
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
tiling_method = pxla.TileManual(manual_mesh_axes)
else:
@ -1284,7 +1284,7 @@ batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap
def _xmap_lowering_rule(ctx, *args, **kwargs):
if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
if config.experimental_xmap_spmd_lowering_manual:
if _SPMD_LOWERING_MANUAL.value:
return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs)
else:
return _xmap_lowering_rule_spmd(ctx, *args, **kwargs)
@ -1404,7 +1404,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
for dim, dim_extra_axis in enumerate(extra):
if dim_extra_axis is None: continue
assert dim_extra_axis not in axes
assert not config.jax_enable_checks or all(v != dim for v in axes.values())
assert not config.enable_checks.value or all(v != dim for v in axes.values())
axes[dim_extra_axis] = dim
add_spmd_axes(mesh_in_axes, spmd_in_axes)
add_spmd_axes(mesh_out_axes, spmd_out_axes)
@ -1839,38 +1839,34 @@ def _clear_compilation_cache(_):
def _ensure_spmd_and(f):
def update(v):
if v and not config.experimental_xmap_spmd_lowering:
if v and not _SPMD_LOWERING.value:
raise RuntimeError("This flag requires enabling the experimental_xmap_spmd_lowering flag")
return f(v)
return update
try:
config.define_bool_state(
name="experimental_xmap_spmd_lowering",
default=False,
help=("When set, multi-device xmap computations will be compiled through "
"the XLA SPMD partitioner instead of explicit cross-replica collectives. "
"Not supported on CPU!"),
update_global_hook=_clear_compilation_cache,
update_thread_local_hook=_thread_local_flag_unsupported)
config.define_bool_state(
name="experimental_xmap_spmd_lowering_manual",
default=False,
help=("When set, multi-device xmap computations will be compiled using "
"the MANUAL partitioning feature of the XLA SPMD partitioner instead of "
"sharding constraints on vectorized code. "
"Requires experimental_xmap_spmd_lowering!"),
update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
update_thread_local_hook=_thread_local_flag_unsupported)
config.define_bool_state(
name="experimental_xmap_ensure_fixed_sharding",
default=False,
help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will "
"try to limit the flexibility of the automated SPMD partitioner heuristics "
"by emitting additional sharding annotations for program intermediates."),
update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
update_thread_local_hook=_thread_local_flag_unsupported)
except Exception:
raise ImportError("jax.experimental.maps has to be imported before JAX flags "
"are parsed")
_SPMD_LOWERING = config.define_bool_state(
name="experimental_xmap_spmd_lowering",
default=False,
help=("When set, multi-device xmap computations will be compiled through "
"the XLA SPMD partitioner instead of explicit cross-replica collectives. "
"Not supported on CPU!"),
update_global_hook=_clear_compilation_cache,
update_thread_local_hook=_thread_local_flag_unsupported)
_SPMD_LOWERING_MANUAL = config.define_bool_state(
name="experimental_xmap_spmd_lowering_manual",
default=False,
help=("When set, multi-device xmap computations will be compiled using "
"the MANUAL partitioning feature of the XLA SPMD partitioner instead of "
"sharding constraints on vectorized code. "
"Requires experimental_xmap_spmd_lowering!"),
update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
update_thread_local_hook=_thread_local_flag_unsupported)
_ENSURE_FIXED_SHARDING = config.define_bool_state(
name="experimental_xmap_ensure_fixed_sharding",
default=False,
help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will "
"try to limit the flexibility of the automated SPMD partitioner heuristics "
"by emitting additional sharding annotations for program intermediates."),
update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
update_thread_local_hook=_thread_local_flag_unsupported)

View File

@ -21,9 +21,9 @@ import warnings
import numpy as np
from jax import config
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import util
@ -91,11 +91,12 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
if not dtypes.safe_to_cast(y, x):
# TODO(jakevdp): change this to an error after the deprecation period.
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)} "
f"with jax_numpy_dtype_promotion={config.jax_numpy_dtype_promotion!r}. "
"In future JAX releases this will result in an error.",
FutureWarning)
warnings.warn(
"scatter inputs have incompatible types: cannot safely cast value "
f"from dtype={lax.dtype(y)} to dtype={lax.dtype(x)} with "
f"jax_numpy_dtype_promotion={config.numpy_dtype_promotion.value!r}. "
"In future JAX releases this will result in an error.",
FutureWarning)
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = jnp._index_to_gather(jnp.shape(x), idx,

View File

@ -789,44 +789,37 @@ def _dot_general_lowering_rule(
return vector.ShapeCastOp(out_type, red).result
if lhs_dims == (1,):
lhs_dim_attr = ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>")
transpose_lhs = False
elif lhs_dims == (0,):
lhs_dim_attr = ir.Attribute.parse("affine_map<(i, j, k) -> (k, i)>")
transpose_lhs = True
else:
raise NotImplementedError
if rhs_dims == (0,):
rhs_dim_attr = ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>")
transpose_rhs = False
elif rhs_dims == (1,):
rhs_dim_attr = ir.Attribute.parse("affine_map<(i, j, k) -> (j, k)>")
out_tile = arith.ConstantOp(
out_type, ir.DenseElementsAttr.get_splat(out_type, val)
)
op = vector.ContractionOp(
out_type,
x,
y,
out_tile,
indexing_maps=ir.ArrayAttr.get([
lhs_dim_attr,
rhs_dim_attr,
ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
]),
iterator_types=ir.ArrayAttr.get([
ir.Attribute.parse("#vector.iterator_type<parallel>"),
ir.Attribute.parse("#vector.iterator_type<parallel>"),
ir.Attribute.parse("#vector.iterator_type<reduction>"),
]),
)
transpose_rhs = True
else:
raise NotImplementedError
if precision is not None:
if precision[0] != precision[1]:
raise NotImplementedError("Per-operand dot precision unsupported")
precision = precision[0]
if precision is None or precision == lax.Precision.DEFAULT:
pass # That's the default in Mosaic.
precision_attr = None # That's the default in Mosaic.
elif precision == lax.Precision.HIGHEST:
op.attributes["precision"] = ir.Attribute.parse(
precision_attr = ir.Attribute.parse(
"#tpu.contract_precision<fp32>"
)
else:
raise NotImplementedError(f"Unsupported dot precision: {precision}")
out_tile = arith.ConstantOp(
out_type, ir.DenseElementsAttr.get_splat(out_type, val)
)
op = tpu.MatmulOp(
out_type, x, y, out_tile,
transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs,
precision=precision_attr
)
return op.result

View File

@ -28,6 +28,7 @@ from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import state
@ -1183,6 +1184,9 @@ def _closed_call_lowering_rule(
triton_lowering_rules[jax_core.closed_call_p] = _closed_call_lowering_rule
triton_lowering_rules[custom_derivatives.custom_jvp_call_p] = (
_closed_call_lowering_rule
)
def _remat_lowering_rule(ctx: TritonLoweringRuleContext, *args, jaxpr, **_):

View File

@ -42,16 +42,13 @@ from jax._src.interpreters import mlir
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
from jax._src import api
from jax._src import pjit as pjit_lib
from jax._src import config as jax_config
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes as _dtypes
from jax._src import monitoring
from jax._src import stages
from jax._src.interpreters import pxla
from jax._src.config import (bool_env, config,
raise_persistent_cache_errors,
persistent_cache_min_compile_time_secs)
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
from jax._src.util import unzip2
from jax._src.public_test_util import ( # noqa: F401
@ -64,18 +61,18 @@ from jax._src import xla_bridge
# jax.test_util. Functionality appearing here is for internal use only, and
# may be changed or removed at any time and without any deprecation cycle.
_TEST_DUT = jax_config.DEFINE_string(
_TEST_DUT = config.DEFINE_string(
'jax_test_dut', '',
help=
'Describes the device under test in case special consideration is required.'
)
_NUM_GENERATED_CASES = jax_config.DEFINE_integer(
_NUM_GENERATED_CASES = config.DEFINE_integer(
'jax_num_generated_cases',
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
help='Number of generated cases to test')
_MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer(
_MAX_CASES_SAMPLING_RETRIES = config.DEFINE_integer(
'max_cases_sampling_retries',
int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
'Number of times a failed test sample should be retried. '
@ -83,25 +80,25 @@ _MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer(
'sampling process is terminated.'
)
_SKIP_SLOW_TESTS = jax_config.DEFINE_bool(
_SKIP_SLOW_TESTS = config.DEFINE_bool(
'jax_skip_slow_tests',
bool_env('JAX_SKIP_SLOW_TESTS', False),
config.bool_env('JAX_SKIP_SLOW_TESTS', False),
help='Skip tests marked as slow (> 5 sec).'
)
_TEST_TARGETS = jax_config.DEFINE_string(
_TEST_TARGETS = config.DEFINE_string(
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
'Regular expression specifying which tests to run, called via re.search on '
'the test name. If empty or unspecified, run all tests.'
)
_EXCLUDE_TEST_TARGETS = jax_config.DEFINE_string(
_EXCLUDE_TEST_TARGETS = config.DEFINE_string(
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
'Regular expression specifying which tests NOT to run, called via re.search '
'on the test name. If empty or unspecified, run all tests.'
)
TEST_WITH_PERSISTENT_COMPILATION_CACHE = jax_config.DEFINE_bool(
TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.DEFINE_bool(
'jax_test_with_persistent_compilation_cache',
bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
help='If enabled, the persistent compilation cache will be enabled for all '
'test cases. This can be used to increase compilation cache coverage.')
@ -299,7 +296,7 @@ def supported_dtypes():
np.uint8, np.uint16, np.uint32, np.uint64,
_dtypes.bfloat16, np.float16, np.float32, np.float64,
np.complex64, np.complex128}
if not config.x64_enabled:
if not config.enable_x64.value:
types -= {np.uint64, np.int64, np.float64, np.complex128}
return types
@ -532,7 +529,7 @@ def rand_fullrange(rng, standardize_nans=False):
# leads to overflows in this case; sample from signed ints instead.
if dtype == np.uint64:
vals = vals.astype(np.int64)
elif dtype == np.uint32 and not config.x64_enabled:
elif dtype == np.uint32 and not config.enable_x64.value:
vals = vals.astype(np.int32)
vals = vals.reshape(shape)
# Non-standard NaNs cause errors in numpy equality assertions.
@ -915,8 +912,8 @@ class JaxTestCase(parameterized.TestCase):
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
cls._compilation_cache_exit_stack = ExitStack()
stack = cls._compilation_cache_exit_stack
stack.enter_context(raise_persistent_cache_errors(True))
stack.enter_context(persistent_cache_min_compile_time_secs(0))
stack.enter_context(config.raise_persistent_cache_errors(True))
stack.enter_context(config.persistent_cache_min_compile_time_secs(0))
tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
compilation_cache.initialize_cache(tmp_dir)
@ -963,7 +960,7 @@ class JaxTestCase(parameterized.TestCase):
self.assertDtypesMatch(x, y)
def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
if not config.x64_enabled and canonicalize_dtypes:
if not config.enable_x64.value and canonicalize_dtypes:
self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True),
_dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True))
else:
@ -1133,26 +1130,6 @@ def with_and_without_mesh(f):
('Mesh', (('x', 2),), (('i', 'x'),))
))(with_mesh_from_kwargs(f))
old_spmd_lowering_flag = None
def set_spmd_lowering_flag(val: bool):
global old_spmd_lowering_flag
old_spmd_lowering_flag = config.experimental_xmap_spmd_lowering
config.update('experimental_xmap_spmd_lowering', val)
def restore_spmd_lowering_flag():
if old_spmd_lowering_flag is None: return
config.update('experimental_xmap_spmd_lowering', old_spmd_lowering_flag)
old_spmd_manual_lowering_flag = None
def set_spmd_manual_lowering_flag(val: bool):
global old_spmd_manual_lowering_flag
old_spmd_manual_lowering_flag = config.experimental_xmap_spmd_lowering_manual
config.update('experimental_xmap_spmd_lowering_manual', val)
def restore_spmd_manual_lowering_flag():
if old_spmd_manual_lowering_flag is None: return
config.update('experimental_xmap_spmd_lowering_manual', old_spmd_manual_lowering_flag)
def create_global_mesh(mesh_shape, axis_names):
size = math.prod(mesh_shape)
if len(jax.devices()) < size:

View File

@ -38,7 +38,7 @@ from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager
import numpy as np
mosaic_use_cpp_passes = config.config.define_bool_state(
_MOSAIC_USE_CPP_PASSES = config.define_bool_state(
name="mosaic_use_cpp_passes",
default=False,
help=(
@ -54,13 +54,13 @@ tpu = tpu_mosaic.tpu
apply_vector_layout = tpu_mosaic.apply_vector_layout
infer_memref_layout = tpu_mosaic.infer_memref_layout
mosaic_allow_hlo = config.config.define_bool_state(
_MOSAIC_ALLOW_HLO = config.define_bool_state(
name="jax_mosaic_allow_hlo",
default=False,
help="Allow hlo dialects in Mosaic",
)
mosaic_dump_mlir = config.config.define_bool_state(
_MOSAIC_DUMP_MLIR = config.define_bool_state(
name="jax_mosaic_dump_mlir",
default=False,
help="Print mlir module after each pass",
@ -243,7 +243,7 @@ def _lower_tpu_kernel(
)
dump_mlir(module, "initial module")
if mosaic_allow_hlo.value:
if _MOSAIC_ALLOW_HLO.value:
# Run hlo dialect conversion: hlo -> linalg -> vector.
pipeline = [
"hlo-legalize-to-arithmetic",
@ -255,7 +255,7 @@ def _lower_tpu_kernel(
)
dump_mlir(module, "after hlo conversion module")
if mosaic_use_cpp_passes.value:
if _MOSAIC_USE_CPP_PASSES.value:
pipeline = [
(
f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
@ -278,7 +278,7 @@ def _lower_tpu_kernel(
module.operation.verify()
dump_mlir(module, "after infer vector layout pass")
if mosaic_use_cpp_passes.value:
if _MOSAIC_USE_CPP_PASSES.value:
pipeline = [
(
"func.func(tpu-apply-vector-layout{sublane-count=8"
@ -418,6 +418,6 @@ def _lowered_as_tpu_kernel(
def dump_mlir(module: ir.Module, msg: str):
"""A helper function to print mlir module with a message."""
if mosaic_dump_mlir.value:
if _MOSAIC_DUMP_MLIR.value:
print(f"[jax_mosaic_dump_mlir] {msg}")
print(module)

View File

@ -15,8 +15,6 @@
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570
from __future__ import annotations
from jax._src.core import (
AbstractToken as AbstractToken,
AbstractValue as AbstractValue,
@ -27,7 +25,6 @@ from jax._src.core import (
ConcreteArray as ConcreteArray,
ConcretizationTypeError as ConcretizationTypeError,
DShapedArray as DShapedArray,
DimSize as DimSize,
DropVar as DropVar,
Effect as Effect,
Effects as Effects,
@ -50,7 +47,6 @@ from jax._src.core import (
OutputType as OutputType,
ParamDict as ParamDict,
Primitive as Primitive,
Shape as Shape,
ShapedArray as ShapedArray,
Sublevel as Sublevel,
TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING,
@ -60,15 +56,11 @@ from jax._src.core import (
TraceStack as TraceStack,
TraceState as TraceState,
Tracer as Tracer,
TracerArrayConversionError as TracerArrayConversionError,
TracerIntegerConversionError as TracerIntegerConversionError,
UnexpectedTracerError as UnexpectedTracerError,
UnshapedArray as UnshapedArray,
Value as Value,
Var as Var,
abstract_token as abstract_token,
apply_todos as apply_todos,
as_hashable_function as as_hashable_function,
as_named_shape as as_named_shape,
aval_mapping_handlers as aval_mapping_handlers,
axis_frame as axis_frame,
@ -82,7 +74,6 @@ from jax._src.core import (
check_type as check_type,
check_valid_jaxtype as check_valid_jaxtype,
closed_call_p as closed_call_p,
collections as collections,
concrete_aval as concrete_aval,
concrete_or_error as concrete_or_error,
concretization_function_error as concretization_function_error,
@ -92,7 +83,6 @@ from jax._src.core import (
definitely_equal as definitely_equal, # TODO(necula): remove this API
dimension_as_value as dimension_as_value, # TODO(necula): remove this API
do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr,
dtypes as dtypes,
ensure_compile_time_eval as ensure_compile_time_eval,
escaped_tracer_error as escaped_tracer_error,
eval_context as eval_context,
@ -114,13 +104,10 @@ from jax._src.core import (
lattice_join as lattice_join,
leaked_tracer_error as leaked_tracer_error,
literalable_types as literalable_types,
lu as lu,
map as map,
map_bind as map_bind,
map_bind_with_continuation as map_bind_with_continuation,
mapped_aval as mapped_aval,
maybe_find_leaked_tracers as maybe_find_leaked_tracers,
namedtuple as namedtuple,
new_base_main as new_base_main,
new_jaxpr_eqn as new_jaxpr_eqn,
new_main as new_main,
@ -128,8 +115,6 @@ from jax._src.core import (
no_axis_name as no_axis_name,
no_effects as no_effects,
outfeed_primitives as outfeed_primitives,
partial as partial,
pp as pp,
pp_aval as pp_aval,
pp_eqn as pp_eqn,
pp_eqn_rules as pp_eqn_rules,
@ -150,11 +135,7 @@ from jax._src.core import (
raise_as_much_as_possible as raise_as_much_as_possible,
raise_to_shaped as raise_to_shaped,
raise_to_shaped_mappings as raise_to_shaped_mappings,
ref as ref,
reset_trace_state as reset_trace_state,
safe_map as safe_map,
safe_zip as safe_zip,
source_info_util as source_info_util,
stash_axis_env as stash_axis_env,
str_eqn_compact as str_eqn_compact,
subjaxprs as subjaxprs,
@ -165,12 +146,8 @@ from jax._src.core import (
substitute_vars_in_output_ty as substitute_vars_in_output_ty,
thread_local_state as thread_local_state,
token as token,
total_ordering as total_ordering,
trace_state_clean as trace_state_clean,
traceback_util as traceback_util,
traverse_jaxpr_params as traverse_jaxpr_params,
tuple_delete as tuple_delete,
tuple_insert as tuple_insert,
typecheck as typecheck,
typecompat as typecompat,
typematch as typematch,
@ -178,7 +155,130 @@ from jax._src.core import (
used_axis_names as used_axis_names,
used_axis_names_jaxpr as used_axis_names_jaxpr,
valid_jaxtype as valid_jaxtype,
zip as zip,
)
symbolic_equal_dim = definitely_equal # TODO(necula): remove this API
from jax._src import core as _src_core
_deprecations = {
# Added Oct 11, 2023:
"DimSize": (
"jax.core.DimSize is deprecated. Use DimSize = int | Any.",
_src_core.DimSize,
),
"Shape": (
"jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
_src_core.Shape,
),
"TracerArrayConversionError": (
"jax.core.TracerArrayConversionError is deprecated. Use jax.errors.TracerArrayConversionError",
_src_core.TracerArrayConversionError,
),
"TracerIntegerConversionError": (
"jax.core.TracerIntegerConversionError is deprecated. Use jax.errors.TracerIntegerConversionError",
_src_core.TracerIntegerConversionError,
),
"UnexpectedTracerError": (
"jax.core.UnexpectedTracerError is deprecated. Use jax.errors.UnexpectedTracerError",
_src_core.UnexpectedTracerError,
),
"as_hashable_function": (
"jax.core.as_hashable_function is deprecated. Use jax.util.as_hashable_function directly.",
_src_core.as_hashable_function,
),
"collections": (
"jax.core.collections is deprecated. Use the collections module directly.",
_src_core.collections,
),
"dtypes": (
"jax.core.dtypes is deprecated. Use jax.dtypes directly.",
_src_core.dtypes,
),
"lu": (
"jax.core.lu is deprecated. Use lu = jax.extend.linear_util",
_src_core.lu,
),
"map": (
"jax.core.map is deprecated. Use the built-in map function.",
_src_core.map,
),
"namedtuple": (
"jax.core.namedtuple is deprecated. Use collections.namedtuple directly.",
_src_core.namedtuple,
),
"partial": (
"jax.core.partial is deprecated. Use functools.partial directly.",
_src_core.partial,
),
"pp": (
"jax.core.pp is deprecated. jax._src.pretty_printer is a non-public API.",
_src_core.pp,
),
"ref": (
"jax.core.ref is deprecated. Use weakref.ref directly.",
_src_core.ref,
),
"safe_map": (
"jax.core.safe_map is deprecated. Use jax.util.safe_map directly.",
_src_core.safe_map,
),
"safe_zip": (
"jax.core.safe_zip is deprecated. Use jax.util.safe_zip directly.",
_src_core.safe_zip,
),
"source_info_util": (
"jax.core.source_info_util is deprecated. jax._src.source_info_util is a non-public API.",
_src_core.source_info_util,
),
"total_ordering": (
"jax.core.total_ordering is deprecated. Use functools.total_ordering directly.",
_src_core.total_ordering,
),
"traceback_util": (
"jax.core.traceback_util is deprecated. jax._src.traceback_util is a non-public API.",
_src_core.traceback_util,
),
"tuple_delete": (
"jax.core.tuple_delete is deprecated. Use tuple_delete = lambda t, i: (*t[:i], *t[i+1:])",
_src_core.tuple_delete,
),
"tuple_insert": (
"jax.core.tuple_insert is deprecated. Use tuple_insert = lambda t, v, i: (*t[:i], v, *t[i:])",
_src_core.tuple_insert,
),
"zip": (
"jax.core.zip is deprecated. Use the built-in zip function.",
_src_core.zip,
),
}
import typing
if typing.TYPE_CHECKING:
DimSize = _src_core.DimSize
Shape = _src_core.Shape
TracerArrayConversionError = _src_core.TracerArrayConversionError
TracerIntegerConversionError = _src_core.TracerIntegerConversionError
UnexpectedTracerError = _src_core.UnexpectedTracerError
as_hashable_function = _src_core.as_hashable_function
collections = _src_core.collections
dtypes = _src_core.dtypes
lu = _src_core.lu
map = _src_core.map
namedtuple = _src_core.namedtuple
partial = _src_core.partial
pp = _src_core.pp
ref = _src_core.ref
safe_map = _src_core.safe_map
safe_zip = _src_core.safe_zip
source_info_util = _src_core.source_info_util
total_ordering = _src_core.total_ordering
traceback_util = _src_core.traceback_util
tuple_delete = _src_core.tuple_delete
tuple_insert = _src_core.tuple_insert
zip = _src_core.zip
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _src_core

View File

@ -28,9 +28,9 @@ from absl import logging
import numpy as np
import jax
from jax import config
from jax import sharding
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src.interpreters import mlir
@ -116,6 +116,8 @@ class DisabledSafetyCheck:
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 8
Sharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
@dataclasses.dataclass(frozen=True)
class Exported:
"""A JAX function lowered to StableHLO.
@ -133,9 +135,8 @@ class Exported:
expressions in the shapes, with dimension variables among those in
`in_avals.
in_shardings: the flattened input shardings. Only for the inputs that are
specified in `module_kept_var_idx`. If `None` then it is equivalent
to unspecified shardings.
out_shardings: the flattened output shardings, as long as `in_avals`.
specified in `module_kept_var_idx`.
out_shardings: the flattened output shardings, as long as `out_avals`.
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
'cuda', 'rocm'. See below for the calling convention for when
there are multiple lowering platforms.
@ -226,8 +227,8 @@ class Exported:
out_tree: tree_util.PyTreeDef
out_avals: tuple[core.AbstractValue, ...]
in_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
out_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
in_shardings: tuple[Sharding, ...]
out_shardings: tuple[Sharding, ...]
lowering_platform: str # For backwards compatibility
lowering_platforms: tuple[str, ...]
disabled_checks: Sequence[DisabledSafetyCheck]
@ -386,7 +387,7 @@ def export(fun_jax: Callable,
exported = jax_export.export(f_jax)(*args, **kwargs)
"""
fun_name = getattr(fun_jax, "__name__", "unknown")
version = config.jax_serialization_version
version = config.jax_serialization_version.value
if (version < minimum_supported_serialization_version or
version > maximum_supported_serialization_version):
raise ValueError(
@ -826,14 +827,64 @@ def _check_module(mod: ir.Module, *,
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
raise ValueError(msg)
def expand_in_shardings(in_shardings: tuple[Sharding, ...],
module_kept_var_idx: Sequence[int],
nr_inputs: int) -> tuple[Sharding, ...]:
"""Expands in_shardings with unspecified shardings for inputs not kept.
Assumes in_shardings corresponds to module_kept_var_idx.
"""
assert len(in_shardings) == len(module_kept_var_idx)
assert nr_inputs >= len(module_kept_var_idx)
all_in_shardings: list[Sharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings):
all_in_shardings[idx] = in_s
return tuple(all_in_shardings)
# TODO(yashkatariya, necula): remove this function once we relax the checks
# in the jit front-end.
def canonical_shardings(
in_shardings: Sequence[Sharding],
out_shardings: Sequence[Sharding]
) -> tuple[Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]],
Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]]]:
"""Prepares canonical in_ and out_shardings for a jit invocation.
The pjit front-end is picky about what in- and out-shardings it accepts,
e.g., if all are unspecified then the whole sharding should be the
sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are
replaced with the replicated sharding.
"""
# Prepare a replicated sharding, search in both the input and output shardings
specified_shardings = [
s for s in itertools.chain(in_shardings, out_shardings)
if not sharding_impls.is_unspecified(s)]
if specified_shardings:
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
else:
replicated_s = None
def canonicalize(
ss: Sequence[Sharding]) -> Union[pxla.UnspecifiedValue,
Sequence[sharding.XLACompatibleSharding]]:
if all(sharding_impls.is_unspecified(s) for s in ss):
return sharding_impls.UNSPECIFIED
return tuple(
s if not sharding_impls.is_unspecified(s) else replicated_s
for s in ss)
return (canonicalize(in_shardings), canonicalize(out_shardings))
def _get_vjp_fun(primal_fun: Callable, *,
in_tree: tree_util.PyTreeDef,
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue],
module_kept_var_idx: tuple[int, ...],
in_shardings,
out_shardings,
in_shardings: tuple[Sharding, ...],
out_shardings: tuple[Sharding, ...],
apply_jit: bool
) -> tuple[Callable, Sequence[core.AbstractValue]]:
# Since jax.vjp does not handle kwargs, it is easier to do all the work
@ -855,36 +906,11 @@ def _get_vjp_fun(primal_fun: Callable, *,
itertools.chain(in_avals,
map(lambda a: a.at_least_vspace(), out_avals)))
# Expand in_shardings to all in_avals even not kept ones.
all_in_shardings = [sharding_impls.UNSPECIFIED] * len(in_avals)
for idx, in_s in zip(sorted(module_kept_var_idx),
in_shardings): # type: ignore
all_in_shardings[idx] = in_s # type: ignore
all_shardings = all_in_shardings + list(out_shardings) # type: ignore
# Cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated.
specified_shardings = [
s for s in all_shardings if not sharding_impls.is_unspecified(s)]
vjp_in_shardings: Any # The primal inputs followed by output cotangents
vjp_out_shardings: Any # The primal output cotangents
if 0 == len(specified_shardings):
vjp_in_shardings = sharding_impls.UNSPECIFIED
vjp_out_shardings = sharding_impls.UNSPECIFIED
else:
if len(specified_shardings) < len(all_shardings):
# There are some specified, but not all; pjit front-end does not liwk
in_s = specified_shardings[0] # pjit will enforce that all have same devices
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
all_shardings = [
s if not sharding_impls.is_unspecified(s) else replicated_s
for s in all_shardings]
vjp_in_shardings = tuple(all_shardings)
vjp_out_shardings = tuple(all_shardings[:len(in_avals)])
if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings):
vjp_out_shardings = sharding_impls.UNSPECIFIED
all_in_shardings = expand_in_shardings(in_shardings,
module_kept_var_idx, len(in_avals))
vjp_in_shardings, vjp_out_shardings = canonical_shardings(
tuple(itertools.chain(all_in_shardings, out_shardings)),
all_in_shardings)
if apply_jit:
return pjit.pjit(fun_vjp_jax,
@ -1037,6 +1063,13 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
if exported.uses_shape_polymorphism:
ctx.module_context.shape_poly_state.uses_dim_vars = True
# Apply in_shardings
all_in_shardings = expand_in_shardings(exported.in_shardings,
exported.module_kept_var_idx,
len(args))
args = tuple(
wrap_with_sharding(ctx, exported, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(args, ctx.avals_in, all_in_shardings))
submodule = ir.Module.parse(exported.mlir_module())
symtab = ir.SymbolTable(submodule.operation)
# The called function may have been exported with polymorphic shapes and called
@ -1072,12 +1105,44 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
kept_args)
# The ctx.avals_out already contain the abstract values refined by
# _call_exported_abstract_eval.
return tuple(
results = tuple(
convert_shape(out, out_aval, refined_out_aval)
for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
# Apply out_shardings
results = tuple(
wrap_with_sharding(ctx, exported, x, x_aval, x_sharding)
for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings)
)
return results
for _p in ("cpu", "tpu", "cuda", "rocm"):
mlir.register_lowering(call_exported_p,
functools.partial(_call_exported_lowering, platform=_p),
platform=_p)
def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
exported: Exported,
x: ir.Value,
x_aval: core.AbstractValue,
x_sharding: Sharding) -> ir.Value:
if sharding_impls.is_unspecified(x_sharding):
return x
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):
ctx_device_assignment = axis_context.device_assignment
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
ctx_device_assignment = list(axis_context.mesh.devices.flat)
else:
raise NotImplementedError(type(axis_context))
assert isinstance(x_sharding, sharding_impls.XLACompatibleSharding)
sharding_device_assignment = x_sharding._device_assignment
if len(ctx_device_assignment) != len(sharding_device_assignment):
raise NotImplementedError(
f"Exported module {exported.fun_name} was lowered for "
f"{len(sharding_device_assignment)} devices and is called in a context with "
f"{len(ctx_device_assignment)} devices"
)
return mlir.wrap_with_sharding_op(
ctx, x, x_aval,
x_sharding._to_xla_hlo_sharding(x_aval.ndim).to_proto())

View File

@ -48,9 +48,9 @@ import numpy as np
import opt_einsum
import jax
from jax import config
from jax.interpreters import xla
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects
@ -651,7 +651,7 @@ class _DimExpr():
raise InconclusiveDimensionOperation("")
remainder = 0
if config.jax_enable_checks:
if config.enable_checks.value:
assert self == divisor * quotient + remainder
return quotient, remainder
except InconclusiveDimensionOperation:

View File

@ -17,14 +17,14 @@ import os
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax.experimental.jax2tf.examples import saved_model_main
from jax.experimental.jax2tf.tests import tf_test_util
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
FLAGS = flags.FLAGS

View File

@ -26,7 +26,6 @@ from absl.testing import absltest, parameterized
import numpy as np
import jax
from jax import config
from jax import lax
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import back_compat_test_util as bctu
@ -60,6 +59,7 @@ import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
config.parse_flags_with_absl()
@ -159,7 +159,7 @@ class CompatTest(bctu.CompatTestBase):
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"):
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
@ -179,7 +179,7 @@ class CompatTest(bctu.CompatTestBase):
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_eig_lapack_geev(self, dtype_name="f32"):
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
@ -271,7 +271,7 @@ class CompatTest(bctu.CompatTestBase):
for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
# For lax.linalg.eigh
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
@ -327,7 +327,7 @@ class CompatTest(bctu.CompatTestBase):
for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
# For lax.linalg.qr
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
@ -395,7 +395,7 @@ class CompatTest(bctu.CompatTestBase):
for dtype_name in ("f32", "f64", "c64", "c128"))
def test_cpu_lu_lapack_getrf(self, dtype_name:str):
# For lax.linalg.lu on CPU.
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
c64=np.complex64, c128=np.complex128)[dtype_name]
@ -480,7 +480,7 @@ class CompatTest(bctu.CompatTestBase):
for dtype_name in ("f32", "f64", "c64", "c128")])
@jax.default_matmul_precision("float32")
def test_cpu_schur_lapack_gees(self, dtype_name="f32"):
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
@ -509,7 +509,7 @@ class CompatTest(bctu.CompatTestBase):
for dtype_name in ("f32", "f64", "c64", "c128"))
@jax.default_matmul_precision("float32")
def test_cpu_svd_lapack_gesdd(self, dtype_name="f32"):
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
@ -534,7 +534,7 @@ class CompatTest(bctu.CompatTestBase):
for dtype_name in ("f32", "f64", "c64", "c128")])
@jax.default_matmul_precision("float32")
def test_cpu_triangular_solve_blas_trsm(self, dtype_name="f32"):
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
self.skipTest("Test disabled for x32 mode")
dtype = dict(f32=np.float32, f64=np.float64,
@ -669,16 +669,11 @@ class CompatTest(bctu.CompatTestBase):
# replace this strict check with something else.
data = self.load_testdata(stablehlo_dynamic_rng_bit_generator.data_2023_06_17)
prev_default_prng_impl = jax.config.jax_default_prng_impl
try:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
with config.default_prng_impl("unsafe_rbg"):
self.run_one_test(
func, data, polymorphic_shapes=(None, "b0, b1"),
# Recent serializations also include shape_assertion, tested with dynamic_top_k
expect_current_custom_calls=["stablehlo.dynamic_rng_bit_generator", "shape_assertion"])
finally:
jax.config.update("jax_default_prng_impl", prev_default_prng_impl)
def test_stablehlo_dynamic_top_k(self):
# stablehlo.dynamic_top_k is used temporarily for a top_k with dynamism

View File

@ -20,9 +20,9 @@ from typing import Any, Callable, Optional, Union
import jax
from jax import lax
from jax import numpy as jnp
from jax import config
from jax._src import test_util as jtu
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax.experimental.jax2tf.tests import primitive_harness
import numpy as np
@ -114,7 +114,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
"""Checks if this limitation is enabled for dtype and device and mode."""
native_serialization_mask = (
Jax2TfLimitation.FOR_NATIVE
if config.jax2tf_default_native_serialization
if config.jax2tf_default_native_serialization.value
else Jax2TfLimitation.FOR_NON_NATIVE)
return ((mode is None or mode in self.modes) and
(self.native_serialization & native_serialization_mask) and

View File

@ -19,7 +19,6 @@ import contextlib
import math
import os
import re
from typing import Callable, Optional
import unittest
from absl import logging
@ -31,20 +30,17 @@ from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import sharding
from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import hlo
import jax._src.xla_bridge
from jax import config
from jax._src import xla_bridge as xb
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map
from jax.experimental import pjit
from jax.interpreters import mlir
from jax.sharding import PartitionSpec as P
import numpy as np
@ -233,7 +229,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
with_function=[True, False],
)
def test_converts_64bit(self, dtype=np.int64, with_function=False):
if not config.jax_enable_x64:
if not config.enable_x64.value:
self.skipTest("requires x64 mode")
big_const = np.full((5,), 2 ** 33, dtype=dtype)
self.ConvertAndCompare(jnp.sin, big_const)
@ -249,7 +245,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_64bit_behavior_enable_x64_readme(self):
# Tests some of the examples from the README
if not config.jax_enable_x64:
if not config.enable_x64.value:
self.skipTest("requires x64 mode")
# JAX and TF have different default float types if JAX_ENABLE_X64=1
@ -266,7 +262,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_64bit_behavior_not_enable_x64_readme(self):
# Tests some of the examples from the README
if config.jax_enable_x64:
if config.enable_x64.value:
self.skipTest("requires not x64 mode")
# JAX and TF have same default float types if JAX_ENABLE_X64=0
@ -817,7 +813,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
arg = np.array(3.)
f_tf = jax2tf.convert(jax.grad(remat_f))
f_tf_hlo = self.TfToHlo(f_tf, arg)
if jax.config.jax_remat_opt_barrier:
if config.remat_opt_barrier.value:
self.assertRegex(f_tf_hlo, r"opt-barrier")
else:
self.assertRegex(f_tf_hlo,
@ -849,7 +845,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf = jax2tf.convert(f_jax)
# for native serialization the HLO we get from TF is constant-folded, so this
# test fails.
if not config.jax2tf_default_native_serialization:
if not config.jax2tf_default_native_serialization.value:
self.assertIn("sine(", self.TfToHlo(f_tf))
def test_convert_of_nested_independent_jit(self):
@ -954,7 +950,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
out = jax2tf.convert(caller_jax, with_gradient=False)(2.)
return out
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
else:
graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def())
@ -984,7 +980,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_shared_constants(self):
# Check that the constants are shared properly in converted functions
# See https://github.com/google/jax/issues/7992.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const = np.random.uniform(size=256).astype(np.float32) # A shared constant
def f(x):
@ -996,7 +992,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_shared_constants_under_cond(self):
# Check that the constants are shared properly in converted functions
# See https://github.com/google/jax/issues/7992.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const_size = 512
const = np.random.uniform(size=const_size).astype(np.float32) # A shared constant
@ -1012,7 +1008,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_shared_constants_under_scan(self):
# See https://github.com/google/jax/issues/7992.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const_size = 512
const = np.random.uniform(size=const_size).astype(np.float32) # A shared constant
@ -1031,7 +1027,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_shared_constants_under_jit(self):
# We do not share constants under jit.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const = np.random.uniform(size=(16, 16)).astype(np.float32) # A shared constant
@jax.jit
@ -1047,7 +1043,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
# randint has the property that the TF lowering of the randbits_p
# primitive generates constants that did not exist in the Jaxpr. As such
# it has created new errors related to the sharing of the constants.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
key = jax.random.PRNGKey(42)
@ -1296,7 +1292,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
jax_comp = jax.xla_computation(f_while)(x)
backend = jax._src.xla_bridge.get_backend()
backend = xb.get_backend()
modules = backend.compile(jax_comp).hlo_modules()
jax_opt_hlo = modules[0].to_string()
print(f"JAX OPT HLO = {jax_opt_hlo}")
@ -1342,7 +1338,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.fail(f"{op.name} does not start with {scope_name}.")
def test_name_scope_polymorphic(self):
if config.jax2tf_default_native_serialization and not config.jax_dynamic_shapes:
if (config.jax2tf_default_native_serialization.value and
not config.dynamic_shapes.value):
self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")
def func_jax(x, y):

View File

@ -28,9 +28,10 @@ import unittest
from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src import maps # Needed for config flags.
from jax import config
import numpy as np
@ -139,7 +140,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
raise unittest.SkipTest("Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation")
# The CPU/GPU have more supported types than TPU.
self.assertEqual("cpu", jtu.device_under_test(), "The documentation can be generated only on CPU")
self.assertTrue(config.x64_enabled, "The documentation must be generated with JAX_ENABLE_X64=1")
self.assertTrue(config.enable_x64.value, "The documentation must be generated with JAX_ENABLE_X64=1")
with open(os.path.join(os.path.dirname(__file__),
'../g3doc/jax_primitives_coverage.md.template')) as f:

View File

@ -29,14 +29,6 @@ from jax.experimental.jax2tf.tests import primitive_harness
def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
return re.compile("(" + "|".join(parts) + ")")
# TODO(necula): Failures to be investigated (on multiple platforms)
_known_failures = make_disjunction_regexp(
"cumsum_",
"cumprod_",
)
# TODO(necula): Failures to be investigated (on GPU).
_known_failures_gpu = make_disjunction_regexp(
# Failures due to failure to export custom call targets for GPU, these
@ -85,13 +77,8 @@ class PrimitiveTest(jtu.JaxTestCase):
#one_containing="",
)
def test_prim(self, harness: primitive_harness.Harness):
if (
_known_failures.search(harness.fullname)
or (
jtu.device_under_test() == "gpu"
and _known_failures_gpu.search(harness.fullname)
)
):
if (jtu.device_under_test() == "gpu"
and _known_failures_gpu.search(harness.fullname)):
self.skipTest("failure to be investigated")
func_jax = harness.dyn_fun
@ -108,13 +95,7 @@ class PrimitiveTest(jtu.JaxTestCase):
for d in self.__class__.devices
if d.platform not in unimplemented_platforms
]
logging.info(
"Using devices %s",
[
(str(d), d.platform, d.device_kind, d.client.platform)
for d in devices
],
)
logging.info("Using devices %s", [str(d) for d in devices])
# lowering_platforms uses "cuda" instead of "gpu"
lowering_platforms: list[str] = [
p if p != "gpu" else "cuda"

View File

@ -67,8 +67,7 @@ from jax._src import random as jax_random
# then the test file has to import jtu first (to define the flags) which is not
# desired if the test file is outside of this project (we don't want a
# dependency on jtu outside of jax repo).
jax.config.parse_flags_with_absl()
FLAGS = config.FLAGS
config.parse_flags_with_absl()
Rng = Any # A random number generator
DType = Any

View File

@ -72,7 +72,7 @@ from jax._src.interpreters import xla
import numpy as np
import tensorflow as tf # type: ignore[import]
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
# Import after parsing flags
from jax.experimental.jax2tf.tests import tf_test_util

View File

@ -38,6 +38,7 @@ from jax import lax
import jax.numpy as jnp
from jax import random
from jax import tree_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import util
@ -51,9 +52,6 @@ from jax.experimental.jax2tf.tests import tf_test_util
import tensorflow as tf # type: ignore[import]
from jax import config
from jax._src.config import numpy_dtype_promotion
config.parse_flags_with_absl()
# Import after parsing flags
@ -723,7 +721,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expected_output_signature=(
# for native serialization we cannot refine the inferred shape of the
# output if the input is more specific than polymorphic_shapes.
tf.TensorSpec([2, 3]) if not config.jax2tf_default_native_serialization
tf.TensorSpec([2, 3]) if not config.jax2tf_default_native_serialization.value
else tf.TensorSpec([2, None])))
check_shape_poly(self,
@ -1300,7 +1298,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# for native serialization we cannot refine the inferred shape of the
# output if the input is more specific than polymorphic_shapes.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0]))
self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1]))
else:
@ -1424,9 +1422,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
def test_prng(self):
# The PRNG implementation uses opaque types, test shape polymorphism
try:
prev_custom_prng = config.jax_enable_custom_prng
config.update("jax_enable_custom_prng", True)
with config.enable_custom_prng(True):
def f_jax(x): # x: f32[b1, b2]
key = random.PRNGKey(123) # key: key<fry>[]
@ -1452,8 +1448,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
check_shape_poly(self, f_jax,
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b1, b2"])
finally:
config.update("jax_enable_custom_prng", prev_custom_prng)
def test_saved_model(self):
f_jax = jnp.sin
@ -1788,7 +1782,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# should be a symbolic dimension
self.assertTrue(isinstance(res, int) or hasattr(res, "dimension_as_value"))
if config.jax_enable_x64:
if config.enable_x64.value:
# Outside jax2tf, x.shape[0] is a Python (64-bit) integer and for most
# operations here JAX is not involved at all because the other operand
# is a Python or NumPy constant. So the result will be 64-bits. But under
@ -1833,7 +1827,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertTrue(d1.aval.weak_type), d1
return d0 + np.array(5., dtype=np.float32) + d1 + x[0]
with numpy_dtype_promotion("strict"):
with config.numpy_dtype_promotion("strict"):
# strict type promotion is sensitive to weak_types
check_shape_poly(self,
f_jax,
@ -2236,7 +2230,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=[poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for poly in ["b, ...", "b, w, w"]
for left in ([True, False] if dtype == np.float32 else [True])
@ -2523,7 +2517,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x, full_matrices: lax.linalg.qr(x, full_matrices=full_matrices),
arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices)],
polymorphic_shapes=[poly],
tol=(None if config.jax2tf_default_native_serialization else 1e-5))
tol=(None if config.jax2tf_default_native_serialization.value else 1e-5))
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
# m and n must be static for now
for shape, poly, full_matrices in [
@ -2793,7 +2787,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=[poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for compute_schur_vectors in [True, False]
for (shape, poly) in [
@ -2914,7 +2908,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=[a_poly, b_poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for (left_side, a_shape, b_shape, a_poly, b_poly) in [
(True, (3, 4, 4), (3, 4, 5), "b, ...", "b, ..."),
@ -3065,14 +3059,14 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
)
def test_harness(self, harness: PolyHarness):
if harness.expect_error == expect_error_associative_scan and (
not config.jax2tf_default_native_serialization
not config.jax2tf_default_native_serialization.value
or jtu.test_device_matches(["tpu"])
):
harness.expect_error = (None, None)
# Exclude some harnesses that are known to fail for native serialization
# FOR NATIVE SERIALIZATION
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
if not harness.enable_xla:
raise unittest.SkipTest("disabled for native_serialization and enable_xla=False")
@ -3123,7 +3117,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
"native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")
# FOR GRAPH SERIALIZATION
if not config.jax2tf_default_native_serialization:
if not config.jax2tf_default_native_serialization.value:
if ("random_gamma_threefry_non_partitionable" in harness.fullname and
jtu.test_device_matches(["cpu"])):
harness.tol = 1e-6

View File

@ -30,8 +30,11 @@ import unittest
from absl.testing import absltest
import jax
from jax._src import compiler
from jax._src import config
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax import lax
from jax.experimental import jax2tf
from jax.experimental import pjit
@ -40,19 +43,18 @@ from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
from jax._src import compiler
from jax._src import xla_bridge
import numpy as np
import tensorflow as tf # type: ignore[import]
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
# Must come after initializing the flags
from jax.experimental.jax2tf.tests import tf_test_util
prev_xla_flags = None
prev_spmd_lowering_flag = None
topology = None
@ -74,7 +76,9 @@ def setUpModule():
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
jtu.set_spmd_lowering_flag(True)
global prev_spmd_lowering_flag
prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)
def tearDownModule():
@ -83,7 +87,7 @@ def tearDownModule():
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag)
class ShardingTest(tf_test_util.JaxToTfTestCase):

View File

@ -29,7 +29,6 @@ import jax
from jax import lax
from jax import tree_util
from jax import vmap
from jax import config
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import (
nfold_vmap, _count_stored_elements,
@ -41,6 +40,7 @@ from jax._src.interpreters import mlir
import jax.numpy as jnp
from jax.util import safe_zip, unzip2, split_list
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src.interpreters import ad
@ -784,7 +784,7 @@ def _bcoo_dot_general_fallback(data, indices, spinfo):
def _bcoo_dot_general_gpu_impl(lhs_data, lhs_indices, rhs, *,
dimension_numbers, preferred_element_type,
lhs_spinfo):
if not config.jax_bcoo_cusparse_lowering:
if not config.bcoo_cusparse_lowering.value:
return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,

View File

@ -25,7 +25,6 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax import config
from jax import lax
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
@ -37,6 +36,7 @@ from jax.experimental.sparse.util import (
from jax.util import split_list, safe_zip
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src.lax.lax import DotDimensionNumbers, _dot_general_batch_dim_nums
@ -623,7 +623,7 @@ def _bcsr_dot_general_gpu_lowering(
ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):
if not config.jax_bcoo_cusparse_lowering:
if not config.bcoo_cusparse_lowering.value:
return _bcsr_dot_general_default_lowering(
ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
dimension_numbers=dimension_numbers,

View File

@ -144,7 +144,7 @@ setup(
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add an version constraint
# here.
"nvidia-nvjitlink=cu12>=12.2",
"nvidia-nvjitlink-cu12>=12.2",
],
# Target that does not depend on the CUDA pip wheels, for those who want

View File

@ -40,14 +40,13 @@ import weakref
from absl import logging
from absl.testing import absltest, parameterized
import jax
from jax import config
from jax import custom_derivatives as custom_derivatives_public
from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit
from jax import lax
from jax import tree_util
from jax._src import api, api_util, dtypes, lib
from jax._src import array
from jax._src import config as config_internal
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
@ -76,7 +75,6 @@ from jax.sharding import PartitionSpec as P
import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def _check_instance(self, x):
@ -233,7 +231,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
assert len(side) == 2 # but should still cache
f(one, two, z=np.zeros(3)) # doesn't crash
if config.x64_enabled:
if config.enable_x64.value:
# In the above call, three is of a new type (int64), thus it should
# trigger a new compilation.
assert len(side) == 3
@ -879,7 +877,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
# TODO(frostig): remove `wrap` once we always enable_custom_prng
def wrap(arr):
arr = np.array(arr, dtype=np.uint32)
if config.jax_enable_custom_prng:
if config.enable_custom_prng.value:
return prng.random_wrap(arr, impl=jax.random.default_prng_impl())
else:
return arr
@ -1057,7 +1055,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
compiled = lowered.compile()
self.assertAllClose(compiled(1.), 2.)
self.assertEqual(lowered.in_avals, compiled.in_avals)
expected_dtype = np.float64 if config.x64_enabled else np.float32
expected_dtype = np.float64 if config.enable_x64.value else np.float32
for obj in [lowered, compiled]:
self.assertEqual(
obj.in_avals,
@ -2913,7 +2911,7 @@ class APITest(jtu.JaxTestCase):
def test_dtype_warning(self):
# cf. issue #1230
if config.x64_enabled:
if config.enable_x64.value:
raise unittest.SkipTest("test only applies when x64 is disabled")
def check_warning(warn, nowarn):
@ -4029,25 +4027,15 @@ class APITest(jtu.JaxTestCase):
def test_dot_precision_flag(self):
x = jnp.zeros((2, 2))
prev_val = config._read("jax_default_matmul_precision")
try:
config.FLAGS.jax_default_matmul_precision = "tensorfloat32"
with config.default_matmul_precision("tensorfloat32"):
jnp.dot(x, x) # doesn't crash
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
finally:
config.FLAGS.jax_default_matmul_precision = prev_val
self.assertIn('Precision.HIGH', str(jaxpr))
self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
prev_val = config._read("jax_default_matmul_precision")
try:
config.update('jax_default_matmul_precision','tensorfloat32')
with config.default_matmul_precision("tensorfloat32"):
jnp.dot(x, x) # doesn't crash
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
finally:
config.update('jax_default_matmul_precision', prev_val)
self.assertIn('Precision.HIGH', str(jaxpr))
self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
def test_dot_precision_forces_retrace(self):
num_traces = 0
@ -4067,7 +4055,7 @@ class APITest(jtu.JaxTestCase):
for f in [f_jit, f_cond]:
# Use _read() to read the flag value rather than threadlocal value.
precision = config._read('jax_default_matmul_precision')
precision = config._read("jax_default_matmul_precision")
try:
num_traces = 0
x = jnp.zeros((2, 2))
@ -4078,7 +4066,7 @@ class APITest(jtu.JaxTestCase):
with jax.default_matmul_precision("tensorfloat32"):
f(x)
self.assertEqual(num_traces, 2)
FLAGS.jax_default_matmul_precision = "float32"
config.update("jax_default_matmul_precision", "float32")
f(x)
self.assertGreaterEqual(num_traces, 2)
nt = num_traces
@ -4087,7 +4075,7 @@ class APITest(jtu.JaxTestCase):
f(x)
self.assertEqual(num_traces, nt + 1)
finally:
FLAGS.jax_default_matmul_precision = precision
config.update("jax_default_matmul_precision", precision)
def test_backward_pass_ref_dropping(self):
refs = []
@ -4239,9 +4227,9 @@ class APITest(jtu.JaxTestCase):
for f in [f_jit, f_cond]:
# Use _read() to read the flag value rather than threadlocal value.
allow_promotion = config._read('jax_numpy_rank_promotion')
allow_promotion = config._read("jax_numpy_rank_promotion")
try:
FLAGS.jax_numpy_rank_promotion = "allow"
config.update("jax_numpy_rank_promotion", "allow")
num_traces = 0
@jax.jit
def f(x):
@ -4256,7 +4244,7 @@ class APITest(jtu.JaxTestCase):
with jax.numpy_rank_promotion("warn"):
f(x)
self.assertEqual(num_traces, 2)
FLAGS.jax_numpy_rank_promotion = "raise"
config.update("jax_numpy_rank_promotion", "raise")
f(x)
self.assertGreaterEqual(num_traces, 2)
nt = num_traces
@ -4265,7 +4253,7 @@ class APITest(jtu.JaxTestCase):
f(x)
self.assertEqual(num_traces, nt + 1)
finally:
FLAGS.jax_numpy_rank_promotion = allow_promotion
config.update("jax_numpy_rank_promotion", allow_promotion)
def test_grad_negative_argnums(self):
def f(x, y):
@ -4409,7 +4397,7 @@ class APITest(jtu.JaxTestCase):
inp = jnp.arange(8)
with config_internal.log_compiles(True):
with config.log_compiles(True):
with self.assertLogs(level='WARNING') as cm:
add(inp)
jax.clear_caches()
@ -6217,7 +6205,7 @@ class JaxprTest(jtu.JaxTestCase):
def test_elide_trivial_convert_element_types(self):
# since we apply convert_element_type to a numpy.ndarray, the primitive is
# still bound and thus would appear in the jaxpr if we didn't clean it up
if config.x64_enabled:
if config.enable_x64.value:
x = np.arange(3, dtype='float64')
else:
x = np.arange(3, dtype='float32')
@ -7464,7 +7452,7 @@ class CustomJVPTest(jtu.JaxTestCase):
def test_custom_jvp_implicit_broadcasting(self):
# https://github.com/google/jax/issues/6357
if config.x64_enabled:
if config.enable_x64.value:
raise unittest.SkipTest("test only applies when x64 is disabled")
@jax.custom_jvp

View File

@ -17,9 +17,9 @@ import unittest
from absl.testing import absltest
import jax
from jax import config
import jax.dlpack
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
@ -115,7 +115,8 @@ class DLPackTest(jtu.JaxTestCase):
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testTensorFlowToJax(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
if (not config.enable_x64.value and
dtype in [jnp.int64, jnp.uint64, jnp.float64]):
raise self.skipTest("x64 types are disabled by jax_enable_x64")
if (jtu.test_device_matches(["gpu"]) and
not tf.config.list_physical_devices("GPU")):
@ -138,8 +139,8 @@ class DLPackTest(jtu.JaxTestCase):
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testJaxToTensorFlow(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
jnp.float64]:
if (not config.enable_x64.value and
dtype in [jnp.int64, jnp.uint64, jnp.float64]):
self.skipTest("x64 types are disabled by jax_enable_x64")
if (jtu.test_device_matches(["gpu"]) and
not tf.config.list_physical_devices("GPU")):
@ -159,7 +160,7 @@ class DLPackTest(jtu.JaxTestCase):
# See https://github.com/google/jax/issues/11895
x = jax.dlpack.from_dlpack(
tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64)))
dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
self.assertEqual(x.dtype, dtype_expected)
@jtu.sample_product(

View File

@ -35,7 +35,6 @@ from jax._src.lib import xla_extension_version
config.parse_flags_with_absl()
FLAGS = config.FLAGS
class CacheKeyTest(jtu.JaxTestCase):

View File

@ -21,15 +21,15 @@ import numpy as np
import jax
from jax import lax
import jax._src.test_util as jtu
from jax._src.lib import xla_extension
from jax import config
from jax.experimental import checkify
from jax.experimental import pjit
from jax.sharding import NamedSharding
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError
from jax._src.lib import xla_extension
import jax.numpy as jnp
config.parse_flags_with_absl()
@ -1305,12 +1305,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
class LowerableChecksTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
self.prev = config.jax_experimental_unsafe_xla_runtime_errors
config.update("jax_experimental_unsafe_xla_runtime_errors", True)
def tearDown(self):
config.update("jax_experimental_unsafe_xla_runtime_errors", self.prev)
super().tearDown()
self.enter_context(config.xla_runtime_errors(True))
@jtu.run_on_devices("cpu", "gpu")
def test_jit(self):

View File

@ -23,18 +23,15 @@ import warnings
from absl.testing import absltest
import jax
from jax import config
from jax import jit
from jax import lax
from jax import pmap
from jax._src import compilation_cache as cc
from jax._src import compiler
from jax._src import config
from jax._src import monitoring
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.config import persistent_cache_min_compile_time_secs
from jax._src.config import raise_persistent_cache_errors
from jax._src.config import use_original_compilation_cache_key_generation
from jax._src.lib import xla_client
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
@ -43,7 +40,6 @@ import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
FAKE_COMPILE_TIME = 10
_counts = Counter() # Map event name to count
@ -228,9 +224,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
f = jit(lambda x: x * x)
with raise_persistent_cache_errors(False), mock.patch.object(
cc._cache.__class__, "put"
) as mock_put, warnings.catch_warnings(record=True) as w:
with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._cache.__class__, "put") as mock_put,
warnings.catch_warnings(record=True) as w,
):
mock_put.side_effect = RuntimeError("test error")
self.assertEqual(f(2), 4)
self.assertLen(w, 1)
@ -247,9 +245,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
f = jit(lambda x: x * x)
with raise_persistent_cache_errors(False), mock.patch.object(
cc._cache.__class__, "get"
) as mock_get, warnings.catch_warnings(record=True) as w:
with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._cache.__class__, "get") as mock_get,
warnings.catch_warnings(record=True) as w,
):
mock_get.side_effect = RuntimeError("test error")
self.assertEqual(f(2), 4)
if len(w) > 1:
@ -264,8 +264,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
)
def test_min_compile_time(self):
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
):
cc.initialize_cache(tmpdir)
@ -282,8 +283,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(files_in_cache, 1)
def test_cache_saving_metric(self):
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2):
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
):
cc.initialize_cache(tmpdir)
durations = Counter() # Map metric name to time duration.
@ -346,8 +349,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_cache_misses_metric(self):
previous_counts = Counter(_counts)
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2):
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
):
cc.initialize_cache(tmpdir)
# Mock time to create a long compilation time and make cache misses.
@ -362,8 +367,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_cache_hits_original_metric(self):
previous_counts = Counter(_counts)
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2), use_original_compilation_cache_key_generation(True):
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
config.use_original_compilation_cache_key_generation(True),
):
cc.initialize_cache(tmpdir)
# Mock time to create a long compilation time, cache saved.

View File

@ -24,17 +24,14 @@ from absl.testing import parameterized
import numpy as np
import jax
from jax._src import dtypes
from jax import numpy as jnp
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
bool_dtypes = [np.dtype('bool')]
np_signed_dtypes = [np.dtype('int8'), np.dtype('int16'), np.dtype('int32'),
@ -102,7 +99,7 @@ class DtypesTest(jtu.JaxTestCase):
True: _EXPECTED_CANONICALIZE_X64,
False: _EXPECTED_CANONICALIZE_X32,
}
for in_dtype, expected_dtype in expected[config.x64_enabled].items():
for in_dtype, expected_dtype in expected[config.enable_x64.value].items():
self.assertEqual(dtypes.canonicalize_dtype(in_dtype), expected_dtype)
@parameterized.named_parameters(
@ -371,7 +368,7 @@ class DtypesTest(jtu.JaxTestCase):
dtypes.dtype(None)
def testDefaultDtypes(self):
precision = config.jax_default_dtype_bits
precision = config.default_dtype_bits.value
assert precision in ['32', '64']
self.assertEqual(dtypes.bool_, np.bool_)
self.assertEqual(dtypes.int_, np.int32 if precision == '32' else np.int64)
@ -450,7 +447,7 @@ class TestPromotionTables(jtu.JaxTestCase):
# Note: * here refers to weakly-typed values
typecodes = \
['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*']
if config.x64_enabled:
if config.enable_x64.value:
expected = [
['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*'],
['u1','u1','u2','u4','u8','i2','i2','i4','i8','bf','f2','f4','f8','c4','c8','u1','f*','c*'],

View File

@ -32,7 +32,6 @@ from jax._src import core
from jax._src import test_util as jtu
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")

View File

@ -22,16 +22,14 @@ from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, lax
from jax._src import config as jax_config
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import traceback_util
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def get_exception(etype, f):
@ -43,7 +41,7 @@ def get_exception(etype, f):
def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
filter_mode="remove_frames"):
with jax_config.traceback_filtering(filter_mode):
with config.traceback_filtering(filter_mode):
test.assertRaises(etype, f)
e = get_exception(etype, f)
c = e.__cause__

View File

@ -11,21 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import functools
import logging
import math
import re
from typing import Optional, Sequence
from typing import Sequence
import unittest
from absl.testing import absltest
import jax
from jax import numpy as jnp
from jax import tree_util
from jax.config import config
from jax.experimental.export import export
from jax.experimental import pjit
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
@ -38,6 +43,17 @@ import numpy as np
config.parse_flags_with_absl()
prev_xla_flags = None
def setUpModule():
global prev_xla_flags
# This will control the CPU devices. On TPU we always have 2 devices
prev_xla_flags = jtu.set_host_platform_device_count(2)
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
prev_xla_flags()
# A primitive for testing multi-platform lowering. Takes one argument and
# adds a different value to it: cpu=2., tpu=3., cuda=.4, rocm=5.
_testing_multi_platform_p = core.Primitive("testing_multi_platform")
@ -72,28 +88,39 @@ for platform in ["cpu", "tpu", "rocm"]:
def _testing_multi_platform_func(x):
return _testing_multi_platform_p.bind(x)
def _testing_multi_platform_fun_expected(x):
return x + _testing_multi_platform_to_add[xb.canonicalize_platform(jtu.device_under_test())]
def _testing_multi_platform_fun_expected(x,
platform: str | None = None):
return x + _testing_multi_platform_to_add[
xb.canonicalize_platform(platform or jtu.device_under_test())
]
class JaxExportTest(jtu.JaxTestCase):
def override_serialization_version(self, version_override: int):
version = config.jax_serialization_version
version = config.jax_serialization_version.value
if version != version_override:
self.addCleanup(functools.partial(config.update,
"jax_serialization_version",
version_override))
config.update("jax_serialization_version", version_override)
self.enter_context(config.jax_serialization_version(version_override))
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version)
config.jax_serialization_version.value)
@classmethod
def setUpClass(cls):
# Find the available platforms
cls.platforms = []
for backend in ["cpu", "gpu", "tpu"]:
try:
jax.devices(backend)
except RuntimeError:
continue
cls.platforms.append(backend)
super(JaxExportTest, cls).setUpClass()
def setUp(self):
super().setUp()
# Run tests with the maximum supported version by default
self.override_serialization_version(
export.maximum_supported_serialization_version)
export.maximum_supported_serialization_version)
def test_basic_export_only(self):
def my_fun(x):
@ -351,7 +378,7 @@ class JaxExportTest(jtu.JaxTestCase):
export.poly_spec((3, 4), np.float32, "w, h"))
# Peek at the module
module_str = exp.mlir_module()
self.assertEqual(config.jax_serialization_version >= 7,
self.assertEqual(config.jax_serialization_version.value >= 7,
"shape_assertion" in module_str)
self.assertIn("jax.uses_shape_polymorphism = true",
module_str)
@ -586,7 +613,7 @@ class JaxExportTest(jtu.JaxTestCase):
)),
])
def test_shape_constraints_errors(self, *,
shape, poly_spec: str, expect_error: Optional[str] = None):
shape, poly_spec: str, expect_error: str | None = None):
def f_jax(x): # x: f32[a + 2*b, a, a + b + c]
return 0.
@ -600,49 +627,170 @@ class JaxExportTest(jtu.JaxTestCase):
export.poly_spec(x.shape, x.dtype, poly_spec))
export.call_exported(exp)(x)
def test_with_sharding(self):
nr_devices = 2
if len(jax.devices()) < nr_devices:
self.skipTest("Need at least 2 devices")
export_devices = jax.devices()[0:nr_devices]
export_mesh = Mesh(export_devices, axis_names=("x",))
a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))
@functools.partial(
jax.jit,
in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),),
out_shardings=jax.sharding.NamedSharding(export_mesh, P(None, "x")))
def f_jax(b): # b: f32[16 // DEVICES, 4]
return b * 2.
res_native = f_jax(a)
exp = export.export(f_jax)(a)
run_devices = export_devices[::-1] # We can use other devices
run_mesh = Mesh(run_devices, "y")
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P()))
expected_re = re.compile(
# The top-level input it replicated
r"func.func .* @main\(%arg0: tensor<16x4xf32> {mhlo.sharding = \"{replicated}\"}\).*"
# We apply the in_shardings for f_jax
r".*custom_call @Sharding\(%arg0\) {mhlo.sharding = \"{devices=\[2,1\]<=\[2\]}\"}.*"
r"%1 = .*call @call_exported_f_jax.*"
# We apply the out_shardings for f_jax
r".*custom_call @Sharding\(%1\) {mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
re.DOTALL)
hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text()
self.assertRegex(hlo, expected_re)
res_exported = export.call_exported(exp)(a_device)
self.assertAllClose(res_native, res_exported)
# Test error reporting
with self.assertRaisesRegex(
NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
_ = export.call_exported(exp)(a)
with self.assertRaisesRegex(
NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
_ = jax.jit(
export.call_exported(exp),
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
)(a)
@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
in_shardings=in_shardings, out_shardings=out_shardings)
for in_shardings in ("missing", None, "P")
for out_shardings in ("missing", None, "P")
])
def test_grad_with_sharding(self, in_shardings="P", out_shardings=None):
if len(jax.devices()) < 2:
self.skipTest("Test requires at least 2 devices")
x_shape = (10, 20)
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
def f_jax(x): # x: f32[10,20] -> f32[20,10]
return jnp.sin(x.T)
pjit_kwargs = {}
if in_shardings != "missing":
pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
if out_shardings != "missing":
pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
f_jax = pjit.pjit(f_jax, **pjit_kwargs)
with Mesh(jax.devices()[:2], "x"):
exp = export.export(f_jax)(x)
exp_vjp = exp.vjp()
vjp_module_str = str(exp_vjp.mlir_module())
if in_shardings == "P":
primal_in_sharding = "{devices=[1,2]<=[2]}"
else:
primal_in_sharding = "{replicated}"
if out_shardings == "P":
primal_out_sharding = "{devices=[2,1]<=[2]}"
else:
primal_out_sharding = "{replicated}"
main = re.search(
r"func.func public @main\(%arg0: tensor<10x20xf32> {mhlo.sharding = \"([^\"]+)\""
r".*%arg1: tensor<20x10xf32> {mhlo.sharding = \"([^\"]+)\""
# result
r".* -> \(tensor<10x20xf32>.*mhlo.sharding = \"([^\"]+)\"",
vjp_module_str)
self.assertEqual(
main.groups(),
(primal_in_sharding, primal_out_sharding, primal_in_sharding))
# Custom calls for the primal input shape
primal_in_calls = re.findall(
r"custom_call @Sharding.* {mhlo.sharding = \"(.+)\"} : .*tensor<10x20xf32>",
vjp_module_str)
self.assertTrue(
all(s == primal_in_sharding for s in primal_in_calls),
primal_in_calls
)
# Custom calls for the primal output shape
primal_out_calls = re.findall(
r"custom_call @Sharding.* {mhlo.sharding = \"(.+)\"} : .*tensor<20x10xf32>",
vjp_module_str)
self.assertTrue(
all(s == primal_out_sharding for s in primal_out_calls),
primal_in_calls
)
def test_multi_platform(self):
if jtu.test_device_matches(["gpu"]):
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
x = np.arange(5, dtype=np.float32)
x = np.arange(8, dtype=np.float32)
exp = export.export(_testing_multi_platform_func,
lowering_platforms=('cpu', 'tpu'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
lowering_platforms=("cpu", "tpu", "cuda"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
module_str = str(exp.mlir_module())
expected_main_re = (
r"@main\("
r"%arg0: tensor<i..> {jax.platform_index = true}.*, "
r"%arg1: tensor<5xf32>.* ->")
r"%arg1: tensor<8xf32>.* ->")
self.assertRegex(module_str, expected_main_re)
self.assertIn("jax.uses_shape_polymorphism = true",
module_str)
res = export.call_exported(exp)(x)
self.assertAllClose(res, _testing_multi_platform_fun_expected(x))
# Call with argument placed on different plaforms
for platform in self.__class__.platforms:
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp)(x_device)
self.assertAllClose(
res_exp,
_testing_multi_platform_fun_expected(x, platform=platform))
def test_multi_platform_nested(self):
if jtu.test_device_matches(["tpu"]):
# The outer export is not applicable to TPU
raise unittest.SkipTest("Not intended for running on TPU")
x = np.arange(5, dtype=np.float32)
exp = export.export(_testing_multi_platform_func,
lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda'))
lowering_platforms=("cpu", "tpu", "cuda"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
# Now serialize the call to the exported using a different sequence of
# lowering platforms, but included in the lowering platforms for the
# nested exported.
exp2 = export.export(export.call_exported(exp),
lowering_platforms=('cpu', 'cuda'))(x)
res2 = export.call_exported(exp2)(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
lowering_platforms=("cpu", "cuda"))(x)
# Call with argument placed on different plaforms
for platform in self.__class__.platforms:
if platform == "tpu": continue
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp2)(x_device)
self.assertAllClose(
res_exp,
_testing_multi_platform_fun_expected(x, platform=platform))
def test_multi_platform_nested_inside_single_platform_export(self):
x = np.arange(5, dtype=np.float32)
exp = export.export(_testing_multi_platform_func,
lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda'))
lowering_platforms=("cpu", "tpu", "cuda"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
# Now serialize the call for the current platform.
exp2 = export.export(export.call_exported(exp))(x)
@ -657,7 +805,7 @@ class JaxExportTest(jtu.JaxTestCase):
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
exp = export.export(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)),
lowering_platforms=('cpu', 'tpu'))(
lowering_platforms=("cpu", "tpu"))(
export.poly_spec((5, 6), np.float32, "b1, b2")
)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
@ -668,6 +816,31 @@ class JaxExportTest(jtu.JaxTestCase):
res2 = export.call_exported(exp2)(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))
def test_multi_platform_and_sharding(self):
export_devices = jax.devices()[0:2]
export_mesh = Mesh(export_devices, axis_names=("x",))
a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))
@functools.partial(
jax.jit,
in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),),
out_shardings=jax.sharding.NamedSharding(export_mesh, P(None, "x")))
def f_jax(b): # b: f32[16 // DEVICES, 4]
return b * 2.
res_native = f_jax(a)
exp = export.export(f_jax,
lowering_platforms=("cpu", "tpu", "cuda"))(a)
# Call with argument placed on different plaforms
for platform in self.__class__.platforms:
run_devices = jax.devices(platform)[0:len(export_devices)]
if len(run_devices) != len(export_devices):
continue
run_mesh = Mesh(run_devices, ("x",))
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None))
res_exp = export.call_exported(exp)(a_device)
self.assertArraysAllClose(res_native, res_exp)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -23,11 +23,11 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import numpy as jnp
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.numpy.util import promote_dtypes_complex
from jax import config
config.parse_flags_with_absl()
FFT_NORMS = [None, "ortho", "forward", "backward"]
@ -129,7 +129,7 @@ class FftTest(jtu.JaxTestCase):
@parameterized.parameters((np.float32,), (np.float64,))
def testLaxIrfftDoesNotMutateInputs(self, dtype):
if dtype == np.float64 and not config.x64_enabled:
if dtype == np.float64 and not config.enable_x64.value:
raise self.skipTest("float64 requires jax_enable_x64=true")
x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]],
dtype=dtypes.to_complex_dtype(dtype))
@ -162,7 +162,7 @@ class FftTest(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker)
# Test gradient for differentiable types.
if (config.x64_enabled and
if (config.enable_x64.value and
dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
# TODO(skye): can we be more precise?
tol = 0.15

View File

@ -28,17 +28,17 @@ from absl.testing import absltest
import jax
from jax import ad_checkpoint
from jax._src import core
from jax import config
from jax import dtypes
from jax.experimental import host_callback as hcb
from jax.sharding import PartitionSpec as P
from jax.experimental import pjit
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax.experimental import host_callback as hcb
from jax.experimental import pjit
from jax.sharding import PartitionSpec as P
from jax._src import core
from jax._src import xla_bridge
from jax._src import test_util as jtu
from jax._src.lib import xla_client
xops = xla_client.ops
@ -46,7 +46,6 @@ xops = xla_client.ops
import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
class _TestingOutputStream:
@ -333,7 +332,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( 6.00 9.00 )""")
def test_tap_eval_exception(self):
if not FLAGS.jax_host_callback_outfeed:
if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall")
# Simulate a tap error
def tap_err(*args, **kwargs):
@ -818,7 +817,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(100, count)
def test_tap_jit_tap_exception(self):
if not FLAGS.jax_host_callback_outfeed:
if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall")
# Simulate a tap error
def tap_err(*args, **kwargs):
@ -1541,7 +1540,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
@jtu.sample_product(device_index=[0, 1])
def test_tap_pjit(self, device_index=0):
if (device_index != 0 and
not FLAGS.jax_host_callback_outfeed and
not hcb._HOST_CALLBACK_OUTFEED.value and
jtu.test_device_matches(["cpu"])):
# See comment in host_callback.py.
raise SkipTest("device_index works only with outfeed on CPU")

View File

@ -39,7 +39,6 @@ except ImportError:
tf = None
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):

View File

@ -17,14 +17,14 @@ import inspect
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import core
from jax._src import api_util
from jax._src.interpreters import pxla
from jax import dtypes
from jax._src import lib as jaxlib
from jax import numpy as jnp
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import lib as jaxlib
from jax._src import test_util as jtu
from jax import config
from jax._src.interpreters import pxla
import numpy as np
config.parse_flags_with_absl()
@ -136,7 +136,7 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)
# Complex
if not (config.x64_enabled and jtu.test_device_matches(["tpu"])):
if not (config.enable_x64.value and jtu.test_device_matches(["tpu"])):
# No TPU support for complex128.
res = np.asarray(_cpp_device_put(1 + 1j, device))
self.assertEqual(res, 1 + 1j)
@ -145,7 +145,7 @@ class JaxJitTest(jtu.JaxTestCase):
def test_arg_signature_of_value(self):
"""Tests the C++ code-path."""
jax_enable_x64 = config.x64_enabled
jax_enable_x64 = config.enable_x64.value
# 1. Numpy scalar types
for dtype in jtu.supported_dtypes():

View File

@ -19,20 +19,20 @@ import warnings
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax._src import core
from jax import lax
from jax._src import effects
from jax._src import linear_util as lu
from jax import config
from jax.experimental import maps
from jax.experimental import pjit
from jax._src.interpreters import ad
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import config
from jax._src import core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
import numpy as np
config.parse_flags_with_absl()
@ -297,8 +297,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
self.old_x64 = config.jax_enable_x64
config.update('jax_enable_x64', False)
self.enter_context(config.enable_x64(False))
self._old_lowering = mlir._lowerings[effect_p]
def _effect_lowering(ctx, *, effect):
if effects.ordered_effects.contains(effect):
@ -315,7 +314,6 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def tearDown(self):
super().tearDown()
dispatch.runtime_tokens.clear()
config.update('jax_enable_x64', self.old_x64)
mlir.register_lowering(effect_p, self._old_lowering)
def test_can_lower_lowerable_effect(self):

View File

@ -20,10 +20,10 @@ from absl.testing import absltest
import jax
from jax import jit, make_jaxpr, numpy as jnp
from jax._src import config
from jax._src import jaxpr_util
from jax._src.lib import xla_client
from jax._src import test_util as jtu
from jax import config
from jax._src.lib import xla_client
config.parse_flags_with_absl()
@ -67,15 +67,15 @@ class JaxprStatsTest(jtu.JaxTestCase):
hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)
t = '64' if config.x64_enabled else '32'
t = '64' if config.enable_x64.value else '32'
shapes = [
f'add :: float{t}[]',
f'sin :: float{t}[]',
f'cos :: float{t}[]',
f'reduce_sum :: float{t}[]',
f'concatenate :: float{t}[2]',
f'pjit :: float{t}[] *',
]
shapes.append(f'pjit :: float{t}[] *')
for k in shapes:
self.assertEqual(hist[k], 1)

View File

@ -33,7 +33,6 @@ from jax.test_util import check_grads
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
compatible_shapes = [[(3,)],

View File

@ -30,13 +30,13 @@ from jax import lax
from jax import numpy as jnp
from jax import ops
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import util
from jax._src.util import NumpyComplexWarning
from jax._src.lax import lax as lax_internal
from jax._src.util import NumpyComplexWarning
from jax import config
config.parse_flags_with_absl()
# We disable the whitespace continuation check in this file because otherwise it
@ -60,7 +60,7 @@ class IndexSpec(typing.NamedTuple):
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
# TODO(mattjj,dougalm): add higher-order check
default_tol = 1e-6 if config.x64_enabled else 1e-2
default_tol = 1e-6 if config.enable_x64.value else 1e-2
atol = atol or default_tol
rtol = rtol or default_tol
eps = eps or default_tol

View File

@ -31,12 +31,11 @@ import jax.ops
from jax import lax
from jax import numpy as jnp
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
@ -612,7 +611,7 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
np.issubdtype(shift_dtype, np.signedinteger)
has_32 = any(np.iinfo(d).bits == 32 for d in dtypes)
promoting_to_64 = has_32 and signed_mix
if promoting_to_64 and not config.x64_enabled:
if promoting_to_64 and not config.enable_x64.value:
self.skipTest("np.right_shift/left_shift promoting to int64"
"differs from jnp in 32 bit mode.")

View File

@ -31,8 +31,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning
jax.config.parse_flags_with_absl()
FLAGS = config.FLAGS
config.parse_flags_with_absl()
numpy_version = jtu.numpy_version()

View File

@ -43,17 +43,16 @@ from jax import numpy as jnp
from jax import tree_util
from jax.test_util import check_grads
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import safe_zip, NumpyComplexWarning
from jax._src import array
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
numpy_version = jtu.numpy_version()
@ -2333,7 +2332,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testFrexp(self, shape, dtype, rng_factory):
# integer types are converted to float64 in numpy's implementation
if (dtype not in [jnp.bfloat16, np.float16, np.float32]
and not config.x64_enabled):
and not config.enable_x64.value):
self.skipTest("Only run float64 testcase when float64 is enabled.")
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@ -2405,7 +2404,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v)
self.assertEqual(out_int32.dtype, np.int32)
if config.x64_enabled:
if config.enable_x64.value:
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
self.assertEqual(out_int64.dtype, np.int64)
else:
@ -3160,7 +3159,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.array(val)
# explicit uint64 should work
if config.x64_enabled:
if config.enable_x64.value:
self.assertEqual(np.uint64(val), jnp.array(val, dtype='uint64'))
def testArrayFromList(self):
@ -3534,7 +3533,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs.
shape=[(0,), (32,), (2, 16)],
a_dtype=all_dtypes,
dtype=(*all_dtypes, None) if config.x64_enabled else all_dtypes,
dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes,
)
def testView(self, shape, a_dtype, dtype):
if jtu.test_device_matches(["tpu"]):
@ -4645,7 +4644,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
endpoint, base, dtype):
if (dtype in int_dtypes and
jtu.test_device_matches(["gpu", "tpu"]) and
not config.x64_enabled):
not config.enable_x64.value):
raise unittest.SkipTest("GPUx32 truncated exponentiation"
" doesn't exactly match other platforms.")
rng = jtu.rand_default(self.rng())
@ -4724,29 +4723,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
check_dtypes=False, atol=tol, rtol=tol)
def testDisableNumpyRankPromotionBroadcasting(self):
try:
prev_flag = config._read('jax_numpy_rank_promotion')
FLAGS.jax_numpy_rank_promotion = "allow"
with jax.numpy_rank_promotion('allow'):
jnp.ones(2) + jnp.ones((1, 2)) # works just fine
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag
try:
prev_flag = config._read('jax_numpy_rank_promotion')
FLAGS.jax_numpy_rank_promotion = "raise"
with jax.numpy_rank_promotion('raise'):
self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2)))
jnp.ones(2) + 3 # don't want to raise for scalars
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag
try:
prev_flag = config._read('jax_numpy_rank_promotion')
FLAGS.jax_numpy_rank_promotion = "warn"
with jax.numpy_rank_promotion('warn'):
self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on "
r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2)))
jnp.ones(2) + 3 # don't want to warn for scalars
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag
@unittest.skip("Test fails on CI, perhaps due to JIT caching")
def testDisableNumpyRankPromotionBroadcastingDecorator(self):
@ -5077,7 +5064,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol)
def testDefaultDtypes(self):
precision = config.jax_default_dtype_bits
precision = config.default_dtype_bits.value
assert precision in ['32', '64']
self.assertEqual(jnp.bool_, np.bool_)
self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64)

View File

@ -26,7 +26,6 @@ from jax._src.numpy.ufunc_api import get_if_single_primitive
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
def scalar_add(x, y):

View File

@ -31,7 +31,7 @@ from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
float_types = jtu.dtypes.floating

View File

@ -28,7 +28,6 @@ from jax.scipy import special as lsp_special
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]

View File

@ -23,7 +23,6 @@ from absl.testing import absltest
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
linear_sizes = [16, 97, 128]

View File

@ -37,7 +37,6 @@ from jax.scipy import cluster as lsp_cluster
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
compatible_shapes = [[(), ()],

View File

@ -34,21 +34,21 @@ from jax.test_util import check_grads
from jax import tree_util
import jax.util
from jax.interpreters import xla
from jax._src.interpreters import mlir
from jax.interpreters import batching
from jax.interpreters import xla
from jax._src import array
from jax._src import config
from jax._src import dtypes
from jax._src.interpreters import pxla
from jax._src import test_util as jtu
from jax._src import lax_reference
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.internal_test_util import lax_test_util
from jax._src.util import NumpyComplexWarning
from jax import config
config.parse_flags_with_absl()
@ -96,7 +96,7 @@ class LaxTest(jtu.JaxTestCase):
dtype=rec.dtypes)
for rec in lax_test_util.lax_ops()))
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
if (not config.x64_enabled and op_name == "nextafter"
if (not config.enable_x64.value and op_name == "nextafter"
and dtype == np.float64):
raise SkipTest("64-bit mode disabled")
rng = rng_factory(self.rng())
@ -293,7 +293,7 @@ class LaxTest(jtu.JaxTestCase):
)
@jax.default_matmul_precision("float32")
def testConvPreferredElement(self, lhs_shape, rhs_shape, dtype, preferred_element_type):
if (not config.x64_enabled and
if (not config.enable_x64.value and
(dtype == np.float64 or preferred_element_type == np.float64
or dtype == np.int64 or preferred_element_type == np.int64
or dtype == np.complex128 or preferred_element_type == np.complex128)):
@ -1033,7 +1033,7 @@ class LaxTest(jtu.JaxTestCase):
)
def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype,
preferred_element_type):
if (not config.x64_enabled and
if (not config.enable_x64.value and
(dtype == np.float64 or preferred_element_type == np.float64
or dtype == np.int64 or preferred_element_type == np.int64)):
raise SkipTest("64-bit mode disabled")
@ -1682,7 +1682,7 @@ class LaxTest(jtu.JaxTestCase):
],
)
def testReduce(self, op, reference_op, init_val, shape, dtype, dims, primitive):
if not config.x64_enabled and dtype in (np.float64, np.int64, np.uint64):
if not config.enable_x64.value and dtype in (np.float64, np.int64, np.uint64):
raise SkipTest("x64 mode is disabled.")
def reference_fun(operand):
if hasattr(reference_op, "reduce"):
@ -2622,7 +2622,7 @@ class LaxTest(jtu.JaxTestCase):
def testRngBitGenerator(self):
# This test covers the original behavior of lax.rng_bit_generator, which
# required x64=True, and only checks shapes and jit invariance.
if not config.x64_enabled:
if not config.enable_x64.value:
raise SkipTest("RngBitGenerator requires 64bit key")
key = np.array((1, 2)).astype(np.uint64)

View File

@ -29,8 +29,6 @@ from jax._src import util
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

View File

@ -36,7 +36,6 @@ from jax._src.util import safe_map, safe_zip
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -29,13 +29,12 @@ from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.numpy.util import promote_dtypes_inexact
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))
@ -1619,7 +1618,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype):
if ((rdtype in [np.float64, np.complex128]
or cdtype in [np.float64, np.complex128])
and not config.x64_enabled):
and not config.enable_x64.value):
self.skipTest("Only run float64 testcase when float64 is enabled.")
int_types_excl_i8 = set(int_types) - {np.int8}
@ -1640,7 +1639,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("rocm")
def testToeplitzSymmetricConstruction(self, shape, dtype):
if (dtype in [np.float64, np.complex128]
and not config.x64_enabled):
and not config.enable_x64.value):
self.skipTest("Only run float64 testcase when float64 is enabled.")
int_types_excl_i8 = set(int_types) - {np.int8}

View File

@ -28,7 +28,7 @@ import jax._src.test_util as jtu
# parsing to work correctly with bazel (otherwise we could avoid importing
# absltest/absl logging altogether).
from absl.testing import absltest
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
@contextlib.contextmanager

View File

@ -27,7 +27,7 @@ from jax import numpy as jnp
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
npr.seed(0)

View File

@ -29,10 +29,11 @@ import jax
from jax import config
from jax._src import core
from jax._src import distributed
import jax.numpy as jnp
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import pjit
import jax.numpy as jnp
try:
import portpicker
@ -258,7 +259,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
self.xmap_spmd_lowering_enabled = jax.config.experimental_xmap_spmd_lowering
self.xmap_spmd_lowering_enabled = maps._SPMD_LOWERING.value
jax.config.update("experimental_xmap_spmd_lowering", True)
def tearDown(self):

View File

@ -34,7 +34,7 @@ from jax import random
import jax
import jax.numpy as jnp
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
class NNFunctionsTest(jtu.JaxTestCase):

View File

@ -25,12 +25,12 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import random
from jax._src import config
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import state
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax.config import config
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
from jax.experimental import pallas as pl
@ -139,6 +139,7 @@ class PallasTest(parameterized.TestCase):
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
class PallasCallTest(PallasTest):
def test_add_one(self):
@ -213,8 +214,9 @@ class PallasCallTest(PallasTest):
out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32),
grid=1)
def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref):
mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])[:, None, None]
o_ref[...] = jnp.where(mask, x_ref[in_idx_ref[()]], 0)
mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])
o_ref[...] = jnp.where(jax.lax.broadcast_in_dim(mask, (4, 2, 2), (0,)),
x_ref[in_idx_ref[()]], 0)
x = jnp.arange(7 * 2 * 2.).reshape(7, 2, 2)
for ii in range(7):
@ -225,7 +227,6 @@ class PallasCallTest(PallasTest):
np.testing.assert_allclose(out[oi], x[ii])
np.testing.assert_allclose(out[oi + 1:], jnp.zeros_like(out[oi + 1:]))
@parameterized.parameters(*[
((), (2,), ()),
((1,), (2,), (0,)),
@ -712,9 +713,32 @@ class PallasCallTest(PallasTest):
np.testing.assert_allclose(lock, 0)
np.testing.assert_allclose(count, num_threads)
def test_custom_jvp_call(self):
@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
def softmax(x, axis=-1):
unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)
@softmax.defjvp
def softmax_jvp(axis, primals, tangents):
(x,), (x_dot,) = primals, tangents
y = softmax(x, axis)
return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))
m, n = 16, 32
x = random.normal(random.PRNGKey(0), (m, n))
@functools.partial(self.pallas_call, out_shape=x, grid=1)
def softmax_kernel(x_ref, y_ref):
y_ref[:] = softmax(x_ref[:])
np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)
class PallasCallInterpreterTest(PallasCallTest):
INTERPRET = True
class PallasControlFlowTest(PallasTest):
def setUp(self):
@ -727,9 +751,7 @@ class PallasControlFlowTest(PallasTest):
# fori_loop handles i64 index variables, i.e. error: 'scf.for' op along
# control flow edge from Region #0 to Region #0: source type #0
# 'tensor<4xf64>' should match input type #0 'tensor<4xf32>'
orig_val = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", True)
try:
with config.enable_x64(True):
@functools.partial(self.pallas_call,
out_shape=jax.ShapeDtypeStruct((4,), jnp.float64),
grid=1,
@ -744,8 +766,6 @@ class PallasControlFlowTest(PallasTest):
np.testing.assert_allclose(np.arange(1, 5.) * 3,
f(jnp.arange(1, 5., dtype=jnp.float64)))
finally:
jax.config.update("jax_enable_x64", orig_val)
def test_cond_simple(self):
arg = jnp.float32(0.)

View File

@ -31,6 +31,7 @@ import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import config
from jax._src import maps
from jax._src import test_util as jtu
from jax import dtypes
from jax import stages
@ -60,9 +61,10 @@ from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2, safe_zip
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
prev_xla_flags = None
prev_spmd_lowering_flag = None
def setUpModule():
@ -75,7 +77,9 @@ def setUpModule():
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
jtu.set_spmd_lowering_flag(True)
global prev_spmd_lowering_flag
prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)
def tearDownModule():
if prev_xla_flags is None:
@ -83,8 +87,7 @@ def tearDownModule():
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag)
def create_array(global_shape, global_mesh, mesh_axes, global_data=None,

View File

@ -32,30 +32,28 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax import lax
from jax._src.lax import parallel
from jax._src import api as src_api
from jax import random
from jax._src import core
from jax._src.lib import xla_extension_version
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import config as jax_config
from jax import lax
from jax import random
from jax import tree_util
from jax.ad_checkpoint import checkpoint as new_checkpoint
import jax.numpy as jnp
from jax._src import api as src_api
from jax._src import array
from jax._src import core
from jax._src import config
from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension
from jax._src.util import safe_map, safe_zip
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src import array
from jax._src.sharding_impls import PmapSharding
from jax.ad_checkpoint import checkpoint as new_checkpoint
from jax._src.lax import parallel
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import safe_map, safe_zip
from jax import config
config.parse_flags_with_absl()
prev_xla_flags = None
@ -125,7 +123,7 @@ def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
if devices is None:
devices = jax.devices()
pmap_sharding = PmapSharding(np.array(devices), sharding_spec)
pmap_sharding = jax.sharding.PmapSharding(np.array(devices), sharding_spec)
return array.make_array_from_callback(
input_shape, pmap_sharding, lambda idx: input_data[idx]), input_data
@ -188,7 +186,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# the default order of pmap for single-host jobs.
device_order = jax.devices()
pmap_sharding = pmap(lambda x: x)(np.arange(jax.device_count())).sharding
if jax.config.jax_pmap_shmap_merge:
if config.pmap_shmap_merge.value:
self.assertListEqual(device_order, pmap_sharding._device_assignment)
else:
self.assertListEqual(device_order, pmap_sharding.devices.tolist())
@ -1283,7 +1281,7 @@ class PythonPmapTest(jtu.JaxTestCase):
expected = np.repeat(3, device_count)
self.assertAllClose(ans, expected, check_dtypes=False)
if not config.jax_disable_jit:
if not config.disable_jit.value:
f = self.pmap(lambda x: (x, 3))
x = np.arange(device_count)
with jtu.assert_num_jit_and_pmap_compilations(1):
@ -1307,7 +1305,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# Test that 'ans' was properly replicated across devices.
ans_devices = ans.sharding._device_assignment
# TODO(mattjj,sharadmv): fix physical layout with eager pmap, remove 'if'
if not config.jax_disable_jit:
if not config.disable_jit.value:
self.assertEqual(ans_devices, tuple(devices))
def testPmapConstantError(self):
@ -1371,7 +1369,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertTrue(ans.sharding == expected_sharded.sharding)
def testNestedPmapConstantError(self):
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("error test doesn't apply with disable_jit")
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2 + 1, 3)
@ -1837,7 +1835,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitOfPmapWarningMessage(self):
device_count = jax.device_count()
if device_count == 1 or config.jax_disable_jit:
if device_count == 1 or config.disable_jit.value:
raise SkipTest("test requires at least two devices")
def foo(x): return x
@ -1853,7 +1851,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitOfPmapOutputSharding(self):
device_count = jax.device_count()
if device_count == 1 or config.jax_disable_jit:
if device_count == 1 or config.disable_jit.value:
raise SkipTest("test requires at least two devices")
@jax.jit
@ -1872,7 +1870,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitOfPmapLowerHasReplicaAttributes(self):
device_count = jax.device_count()
if device_count == 1 or config.jax_disable_jit:
if device_count == 1 or config.disable_jit.value:
raise SkipTest("test requires at least two devices")
@jax.jit
@ -2037,7 +2035,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def test_grad_of_pmap_compilation_caching(self, axis_size):
if len(jax.local_devices()) < axis_size:
raise SkipTest("too few devices for test")
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("caching doesn't apply with jit disabled")
@jax.pmap
@ -2059,7 +2057,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(count[0], 0) # cache hits on fwd and bwd
def testSizeOverflow(self):
if config.jax_disable_jit:
if config.disable_jit.value:
# TODO(sharadmv, mattjj): investigate and fix this issue
raise SkipTest("OOMs in eager mode")
x = jnp.arange(1)
@ -2171,7 +2169,7 @@ class CppPmapTest(PythonPmapTest):
@property
def pmap(self):
if jax.config.jax_pmap_shmap_merge:
if config.pmap_shmap_merge.value:
return src_api.pmap
return src_api._cpp_pmap
@ -2206,7 +2204,7 @@ class CppPmapTest(PythonPmapTest):
pmaped_f(inputs)
self.assertEqual(pmaped_f._cache_size, 1)
jax_config.update_thread_local_jit_state()
config.update_thread_local_jit_state()
pmaped_f(inputs)
self.assertEqual(pmaped_f._cache_size, 1)
@ -2511,7 +2509,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
f(jnp.ones(jax.device_count() + 1))
def testBadAxisSizeErrorNested(self):
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("error doesn't apply when jit is disabled")
f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')),
axis_name='j'),
@ -2526,7 +2524,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNestedPmaps(self):
if jax.device_count() % 2 != 0:
raise SkipTest
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("disable_jit requires num devices to equal axis size")
# Devices specified in outer pmap are OK
@ -2545,7 +2543,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNestedPmapsBools(self):
if jax.device_count() % 2 != 0:
raise SkipTest
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("disable_jit requires num devices to equal axis size")
# Devices specified in outer pmap are OK
@ -3150,7 +3148,7 @@ class ArrayPmapTest(jtu.JaxTestCase):
self.assertArraysEqual(w, jnp.cos(jnp.sin(x) ** 2))
def test_same_out_sharding_id(self):
if config.jax_disable_jit:
if config.disable_jit.value:
self.skipTest('Skip this under eager pmap mode.')
shape = (jax.device_count(), 2)
arr, inp_data = create_input_array_for_pmap(shape)
@ -3215,8 +3213,8 @@ class EagerPmapMixin:
def setUp(self):
super().setUp()
self.eager_pmap_enabled = config.jax_eager_pmap
self.jit_disabled = config.jax_disable_jit
self.eager_pmap_enabled = config.eager_pmap.value
self.jit_disabled = config.disable_jit.value
config.update('jax_disable_jit', True)
config.update('jax_eager_pmap', True)

View File

@ -26,13 +26,13 @@ from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax.experimental import io_callback
from jax.experimental import maps
from jax.experimental import pjit
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map
@ -40,7 +40,7 @@ import jax.numpy as jnp
from jax.sharding import Mesh
import numpy as np
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
def _format_multiline(text):
@ -632,8 +632,10 @@ class PureCallbackTest(jtu.JaxTestCase):
if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
raise unittest.SkipTest('Manual partitioning needed for pure_callback')
jtu.set_spmd_lowering_flag(True)
jtu.set_spmd_manual_lowering_flag(True)
spmd_lowering = maps._SPMD_LOWERING.value
spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
try:
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
@ -659,8 +661,11 @@ class PureCallbackTest(jtu.JaxTestCase):
out, np.sin(np.arange(jax.local_device_count()))
)
finally:
jtu.restore_spmd_manual_lowering_flag()
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', spmd_lowering)
config.update(
'experimental_xmap_spmd_lowering_manual',
spmd_manual_lowering,
)
def test_cant_take_grad_of_pure_callback(self):

View File

@ -17,13 +17,13 @@ import unittest
from absl.testing import absltest
import jax
from jax import config
import jax.dlpack
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
import jax.numpy as jnp
from jax._src import test_util as jtu
config.parse_flags_with_absl()
@ -69,8 +69,11 @@ class DLPackTest(jtu.JaxTestCase):
@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
def testJaxToTorch(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
jnp.complex128]:
if not config.enable_x64.value and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,
]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)
@ -89,7 +92,7 @@ class DLPackTest(jtu.JaxTestCase):
if xla_extension_version < 186:
self.skipTest("Need xla_extension_version >= 186")
if not config.x64_enabled and dtype in [
if not config.enable_x64.value and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,
@ -113,13 +116,16 @@ class DLPackTest(jtu.JaxTestCase):
# See https://github.com/google/jax/issues/11895
x = jax.dlpack.from_dlpack(
torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64)))
dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
self.assertEqual(x.dtype, dtype_expected)
@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
def testTorchToJax(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
jnp.complex128]:
if not config.enable_x64.value and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,
]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())
@ -142,8 +148,11 @@ class DLPackTest(jtu.JaxTestCase):
if xla_extension_version < 191:
self.skipTest("Need xla_extension_version >= 191")
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
jnp.complex128]:
if not config.enable_x64.value and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,
]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())

View File

@ -16,18 +16,18 @@
import functools
import jax
from jax import config
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg
from jax._src.lax import qdwh
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax import qdwh
from absl.testing import absltest
config.parse_flags_with_absl()
_JAX_ENABLE_X64_QDWH = config.x64_enabled
_JAX_ENABLE_X64_QDWH = config.enable_x64.value
# Input matrix data type for QdwhTest.
_QDWH_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64_QDWH else np.float32

View File

@ -37,7 +37,7 @@ from jax import vmap
from jax._src import prng as prng_internal
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex

View File

@ -40,7 +40,7 @@ from jax.interpreters import xla
from jax._src import random as jax_random
from jax._src import prng as prng_internal
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
PRNG_IMPLS = list(prng_internal.prngs.items())

View File

@ -43,7 +43,7 @@ import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@ -1382,7 +1382,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
return jtu.create_global_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)
@ -1391,7 +1391,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
self.assertAllClose(expected, out, check_dtypes=False)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)
@ -1401,7 +1401,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
@parameterized.named_parameters(
(name + f'_check_rep={check_rep}', *params, check_rep)
for (name, *params) in sample(config.FLAGS.jax_num_generated_cases, sample_shmap)
for (name, *params) in sample(jtu._NUM_GENERATED_CASES.value, sample_shmap)
for check_rep in [True, False]
)
@jax.default_matmul_precision("float32")
@ -1414,7 +1414,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
@jax.default_matmul_precision("float32")
def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
mesh = self.make_mesh(mesh)
@ -1433,7 +1433,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases,
sample(jtu._NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
@ -1456,7 +1456,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases,
sample(jtu._NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
mesh = self.make_mesh(mesh)

View File

@ -41,7 +41,6 @@ from jax.util import split_list
import numpy as np
config.parse_flags_with_absl()
FLAGS = config.FLAGS
COMPATIBLE_SHAPE_PAIRS = [
[(), ()],

View File

@ -46,7 +46,6 @@ import numpy as np
import scipy.sparse
config.parse_flags_with_absl()
FLAGS = config.FLAGS
all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex

View File

@ -48,7 +48,7 @@ from jax._src.state.primitives import (get_p, swap_p, addupdate_p,
from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
AccumEffect, AbstractRef)
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()
class StatePrimitivesTest(jtu.JaxTestCase):
@ -831,7 +831,7 @@ if CAN_USE_HYPOTHESIS:
@hp.given(get_vmap_params())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_get_vmap(self, get_vmap_param: GetVmapParams):
indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims
@ -870,7 +870,7 @@ if CAN_USE_HYPOTHESIS:
@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_set_vmap(self, set_vmap_param: SetVmapParams):
if jtu.test_device_matches(["gpu"]):
self.skipTest("Scatter is nondeterministic on GPU")
@ -915,7 +915,7 @@ if CAN_USE_HYPOTHESIS:
@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_addupdate_vmap(self, set_vmap_param: SetVmapParams):
indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims
@ -1538,7 +1538,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_jvp(self, data):
spec = data.draw(func_spec())
@ -1563,7 +1563,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_linearize(self, data):
spec = data.draw(func_spec())
@ -1589,7 +1589,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_vjp(self, data):
spec = data.draw(func_spec())

View File

@ -16,18 +16,18 @@
import functools
import jax
from jax import config
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg
from jax._src.lax import svd
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax import svd
from absl.testing import absltest
config.parse_flags_with_absl()
_JAX_ENABLE_X64 = config.x64_enabled
_JAX_ENABLE_X64 = config.enable_x64.value
# Input matrix data type for SvdTest.
_SVD_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64 else np.float32

View File

@ -23,7 +23,6 @@ from jax._src import util
from jax import config
from jax._src.util import weakref_lru_cache
config.parse_flags_with_absl()
FLAGS = config.FLAGS
try:
from jax._src.lib import utils as jaxlib_utils

View File

@ -20,17 +20,16 @@ import warnings
from absl import logging
from absl.testing import absltest
from jax._src import compiler
from jax._src import config as jax_config
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.config import config
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_client as xc
config.parse_flags_with_absl()
FLAGS = config.FLAGS
mock = absltest.mock
@ -65,7 +64,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
# --jax_xla_profile_version takes precedence.
jax_flag_profile = 1
another_profile = 2
with jax_config.jax_xla_profile_version(jax_flag_profile):
with config.jax_xla_profile_version(jax_flag_profile):
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda: another_profile):
self.assertEqual(
@ -283,10 +282,10 @@ class GetBackendTest(jtu.JaxTestCase):
fail_quietly=False, experimental=experimental)
def setUp(self):
super().setUp()
self._orig_factories = xb._backend_factories
xb._backend_factories = {}
self._orig_jax_platforms = config._read("jax_platforms")
config.FLAGS.jax_platforms = ""
self.enter_context(config.jax_platforms(""))
self._save_backend_state()
self._reset_backend_state()
@ -294,8 +293,8 @@ class GetBackendTest(jtu.JaxTestCase):
self._register_factory("cpu", 0)
def tearDown(self):
super().tearDown()
xb._backend_factories = self._orig_factories
config.FLAGS.jax_platforms = self._orig_jax_platforms
self._restore_backend_state()
def _save_backend_state(self):
@ -380,10 +379,7 @@ class GetBackendTest(jtu.JaxTestCase):
self._register_factory("platform_A", 20, assert_used_at_most_once=True)
self._register_factory("platform_B", 10, assert_used_at_most_once=True)
orig_jax_platforms = config._read("jax_platforms")
try:
config.FLAGS.jax_platforms = "cpu,platform_A"
with config.jax_platforms("cpu,platform_A"):
backend = xb.get_backend()
self.assertEqual(backend.platform, "cpu")
# Only specified backends initialized.
@ -395,10 +391,6 @@ class GetBackendTest(jtu.JaxTestCase):
with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
backend = xb.get_backend("platform_B")
finally:
config.FLAGS.jax_platforms = orig_jax_platforms
def test_experimental_warning(self):
self._register_factory("platform_A", 20, experimental=True)

View File

@ -33,24 +33,24 @@ import jax.scipy as jscipy
from jax._src import test_util as jtu
from jax import vmap
from jax import lax
from jax._src import core
from jax._src.core import NamedShape
from jax.experimental import maps
from jax._src import array
from jax._src.sharding_impls import NamedSharding
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.experimental.maps import xmap, serial_loop, SerialLoop
from jax.ad_checkpoint import checkpoint
from jax.errors import JAXTypeError
from jax._src.nn import initializers as nn_initializers
from jax.experimental.maps import xmap, serial_loop, SerialLoop
from jax.experimental.pjit import pjit
from jax.interpreters import batching
from jax.sharding import PartitionSpec as P
from jax._src import array
from jax._src import core
from jax._src import maps
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.util import unzip2
from jax._src.core import NamedShape
from jax._src.lax import parallel as lax_parallel
from jax._src.lax.parallel import pgather
from jax.interpreters import batching
from jax.ad_checkpoint import checkpoint
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.nn import initializers as nn_initializers
from jax._src.sharding_impls import NamedSharding
from jax._src.util import unzip2
from jax import config
config.parse_flags_with_absl()
@ -246,10 +246,11 @@ class XMapTestCase(jtu.BufferDonationTestCase):
class SPMDTestMixin:
def setUp(self):
super().setUp()
jtu.set_spmd_lowering_flag(True)
self.spmd_lowering = maps._SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)
def tearDown(self):
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
class ManualSPMDTestMixin:
@ -257,12 +258,14 @@ class ManualSPMDTestMixin:
if not hasattr(xla_client.OpSharding.Type, "MANUAL"):
raise SkipTest
super().setUp()
jtu.set_spmd_lowering_flag(True)
jtu.set_spmd_manual_lowering_flag(True)
self.spmd_lowering = maps._SPMD_LOWERING.value
self.spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
def tearDown(self):
jtu.restore_spmd_manual_lowering_flag()
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering)
@jtu.pytest_mark_if_available('multiaccelerator')
@ -433,7 +436,7 @@ class XMapTest(XMapTestCase):
m_size = math.prod([2] + [2] * (len(mesh) - 2))
self.assertListEqual(y_op_sharding.tile_assignment_dimensions(),
[2, 1, 1, m_size])
if config.experimental_xmap_spmd_lowering:
if maps._SPMD_LOWERING.value:
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
# Make sure that there are non-partial sharding specs in the HLO
if xla_extension_version >= 180:
@ -746,7 +749,7 @@ class XMapTest(XMapTestCase):
axis_resources={'i': 'x'})
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
hlo = f.lower(x).as_text(dialect='stablehlo')
if config.experimental_xmap_spmd_lowering:
if maps._SPMD_LOWERING.value:
self.assertIn("mhlo.num_partitions = 2", hlo)
self.assertIn("mhlo.num_replicas = 1", hlo)
else:
@ -1201,7 +1204,7 @@ class NewPrimitiveTest(XMapTestCase):
@jtu.with_and_without_mesh
def testGather(self, mesh, axis_resources):
if axis_resources and not config.experimental_xmap_spmd_lowering:
if axis_resources and not maps._SPMD_LOWERING.value:
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
y = jnp.arange(35).reshape((5, 7)) % 3

View File

@ -20,8 +20,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "50baaece6db8e61ada5db1f49f19a9d7dc1fa297"
XLA_SHA256 = "c09aefa72d931909ddac059c6313008e209b9a363050dddbc2d3d348fda881a7"
XLA_COMMIT = "500c965b04709f15008c26b46df2f8406279d730"
XLA_SHA256 = "c0f98502d28e30f2905b56e8893b635e2081ad4ddb57009117fa0dc83bdf6d8e"
def repo():
tf_http_archive(