Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Merge branch 'google:main' into patch-1

This commit is contained in: commit d69e810e90
14  .bazelrc
@@ -72,6 +72,20 @@ build:cuda --@xla//xla/python:enable_gpu=true
+build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
 build:cuda --define=xla_python_enable_gpu=true
 
+# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
+# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
+# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
+# packages.
+# This has pros and cons:
+# * pro: we'll ignore other CUDA installations, which has frequently confused
+#   users in the past. By setting RPATH, we'll always use the NVIDIA pip
+#   packages if they are installed.
+# * con: the user cannot override the CUDA installation location
+#   via LD_LIBRARY_PATH, if the nvidia-... pip packages are installed. This is
+#   acceptable, because the workaround is "remove the nvidia-..." pip packages.
+# The list of CUDA pip packages that JAX depends on are present in setup.py.
+build:cuda --linkopt=-Wl,--disable-new-dtags
+
 build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
 build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
 build:rocm --@xla//xla/python:enable_gpu=true
 
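For readers unfamiliar with the distinction the comment relies on: ld.so consults RPATH before LD_LIBRARY_PATH, and RUNPATH after it, so forcing the old-style RPATH tag is what lets the pip-installed CUDA libraries win. A minimal sketch for checking which tag a built extension actually carries; it assumes a Linux host with binutils' readelf on PATH, and the .so path is a placeholder:

import subprocess

def dynamic_path_tags(so_path: str) -> list[str]:
    # readelf -d dumps the dynamic section; RPATH/RUNPATH entries show up
    # as lines containing "(RPATH)" or "(RUNPATH)".
    out = subprocess.run(["readelf", "-d", so_path],
                         capture_output=True, text=True, check=True).stdout
    return [line.strip() for line in out.splitlines() if "PATH)" in line]

# Placeholder path; with --disable-new-dtags we would expect an (RPATH)
# entry pointing at $ORIGIN-relative nvidia/... package directories.
print(dynamic_path_tags("/path/to/site-packages/jaxlib/xla_extension.so"))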
3  .github/workflows/cloud-tpu-ci-nightly.yml (vendored)
@@ -29,6 +29,7 @@ jobs:
     name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
     env:
       DATE_12_WEEKS_AGO: $(python3 -c "import datetime; print(datetime.date.today() - datetime.timedelta(weeks=12))")
+      ENABLE_PJRT_COMPATIBILITY: false
     runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
     timeout-minutes: 60
     defaults:
@@ -59,6 +60,8 @@ jobs:
           pip install requests
 
         elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
+          env:
+          ENABLE_PJRT_COMPATIBILITY: true
           pip install .
           pip install --pre jaxlib \
             -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
@@ -25,6 +25,13 @@ Remember to align the itemized text with the first line of an item within a list
 
+# jaxlib 0.4.19
+
+* Changes
+  * jaxlib will now always prefer pip-installed NVIDIA CUDA libraries
+    (nvidia-... packages) over any other CUDA installation if they are
+    installed, including installations named in `LD_LIBRARY_PATH`. If this
+    causes problems and the intent is to use a system-installed CUDA, the fix is
+    to remove the pip installed CUDA library packages.
 
 # jax 0.4.18 (Oct 6, 2023)
 
 # jaxlib 0.4.18 (Oct 6, 2023)
 
@@ -116,8 +116,7 @@ def _nan_check_posthook(fun, args, kwargs, output):
   fun._cache_miss(*args, **kwargs)[0]  # probably won't return
 
 def _update_debug_special_global(_):
-  if (config.config._read("jax_debug_nans") or
-      config.config._read("jax_debug_infs")):
+  if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
     jax_jit.global_state().post_hook = _nan_check_posthook
   else:
     jax_jit.global_state().post_hook = None
@@ -234,6 +234,7 @@ def pure_callback(
       may behave in unexpected ways, particularly under transformation.
     result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
       whose structure matches the expected output of the callback function at runtime.
+      :class:`jax.ShapeDtypeStruct` is often used to define leaf values.
     *args: arguments to be passed to the callback function
     sharding: optional sharding that specifies the device from which the callback should
       be invoked.
@@ -480,6 +481,7 @@ def io_callback(
       more efficient execution.
     result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
       whose structure matches the expected output of the callback function at runtime.
+      :class:`jax.ShapeDtypeStruct` is often used to define leaf values.
     *args: arguments to be passed to the callback function
     sharding: optional sharding that specifies the device from which the callback should
       be invoked.
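The sentence added in both docstrings points at jax.ShapeDtypeStruct. A short usage sketch of pure_callback with a ShapeDtypeStruct result spec (the NumPy host function is just an illustration):

import jax
import jax.numpy as jnp
import numpy as np

def host_log1p(x):
    # Runs outside the traced computation, on concrete NumPy values.
    return np.log1p(x)

@jax.jit
def f(x):
    # Describe the callback's output without calling it: same shape/dtype as x.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log1p, out_spec, x)

print(f(jnp.arange(4.0)))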
@@ -305,7 +305,7 @@ def compile_or_get_cached(
   try:
     cache_key = compilation_cache.get_cache_key(
         computation, devices, compile_options, backend,
-        config.config.jax_use_original_compilation_cache_key_generation,
+        config.use_original_compilation_cache_key_generation.value,
     )
   except xc._xla.XlaRuntimeError as ex:
     logger.error("compile_or_get_cached: unable to generate cache key, "
@@ -324,7 +324,7 @@ def compile_or_get_cached(
 
   # TODO(b/293308239) Instrument metrics for new cache savings and cache hit
   # rate after it is enabled.
-  if config.config.jax_use_original_compilation_cache_key_generation:
+  if config.use_original_compilation_cache_key_generation.value:
     # TODO(b/293308239) Remove metrics for the original cache after the new
     # compilation cache key implementation is fully rolled out.
     monitoring.record_event('/jax/compilation_cache/cache_hits_original')
@@ -598,6 +598,16 @@ class NameSpace:
 
 
 config = Config()
+
+_read = config._read
+update = config.update
+define_bool_state = config.define_bool_state
+define_enum_state = config.define_enum_state
+define_int_state = config.define_int_state
+define_float_state = config.define_float_state
+define_string_state = config.define_string_state
+parse_flags_with_absl = config.parse_flags_with_absl
+
 flags = config
 FLAGS = flags.FLAGS
 
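These module-level aliases support the migration visible throughout this commit: flag definitions now return holder objects that are read with .value, instead of attribute lookups on a global config.config. A simplified model of that holder pattern, not the actual jax implementation:

class FlagHolder:
    # Stands in for the object returned by define_bool_state and friends.
    def __init__(self, name: str, default):
        self.name = name
        self._value = default

    @property
    def value(self):
        # Callers read the current setting via .value.
        return self._value

def define_bool_state(name: str, default: bool, help: str) -> FlagHolder:
    # The real helper also registers the flag for absl parsing and
    # supports context-manager overrides.
    return FlagHolder(name, default)

ENABLE_FOO = define_bool_state("jax_example_foo", False, help="demo flag")
if ENABLE_FOO.value:
    pass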
@@ -63,8 +63,8 @@ class State:
     if local_device_ids:
       visible_devices = ','.join(str(x) for x in local_device_ids)  # type: ignore[union-attr]
       logger.info('JAX distributed initialized with visible devices: %s', visible_devices)
-      config.config.update("jax_cuda_visible_devices", visible_devices)
-      config.config.update("jax_rocm_visible_devices", visible_devices)
+      config.update("jax_cuda_visible_devices", visible_devices)
+      config.update("jax_rocm_visible_devices", visible_devices)
 
     self.process_id = process_id
 
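For context on where local_device_ids comes from: it is a parameter of jax.distributed.initialize. A hedged sketch of the call site; the coordinator address and process counts are placeholders:

import jax

# Each process claims a subset of the local accelerators; ids are local.
# Internally this feeds the config.update calls shown in the hunk above
# (jax_cuda_visible_devices / jax_rocm_visible_devices).
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # placeholder
    num_processes=2,                      # placeholder
    process_id=0,                         # placeholder
    local_device_ids=[0, 1],
)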
@@ -1184,7 +1184,10 @@ def lower_jaxpr_to_fun(
     dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
     # A lowering context just for function body entry/exit code.
     entry_lowering_ctx = LoweringRuleContext(
-        ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values)
+        module_context=ctx, primitive=None,
+        avals_in=[], avals_out=None,
+        tokens_in=TokenSet.create([]), tokens_out=None,
+        axis_size_env=None, dim_var_values=dim_var_values)
     if not use_sharding_annotations and ir_arg_shardings is not None:
       flat_args = [
         a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
@@ -1414,7 +1417,8 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       tokens_in = tokens.subset(effects)
       avals_in = map(aval, eqn.invars)
       rule_ctx = LoweringRuleContext(
-          module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in,
+          module_context=eqn_ctx, primitive=eqn.primitive,
+          avals_in=avals_in,
           avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
           tokens_out=None, dim_var_values=dim_var_values)
       if config.dynamic_shapes.value:
@@ -1973,12 +1977,14 @@ def cache_lowering(f):
   The lowering will be emitted out-of-line in a separate function, together with
   a call to that function. If the same primitive is called with the same shapes
   and parameters, a new call to the original function will be added, without
-  emitting a new function.
+  emitting a new function. We allow for different lowering for the same
+  primitive for different platforms in the same module.
   """
   @functools.wraps(f)
   def cached_lowering(ctx, *args, **params):
     assert ctx.primitive is not None
-    key = (ctx.primitive, tuple(ctx.avals_in), tuple(ctx.avals_out),
+    key = (f, ctx.primitive,
+           tuple(ctx.avals_in), tuple(ctx.avals_out),
            tuple(params.items()))
     try:
       func = ctx.module_context.cached_primitive_lowerings.get(key)
@@ -2503,7 +2509,8 @@ def custom_call(
 
 
 def reduce_window(
-    ctx: LoweringRuleContext, *,
+    ctx: LoweringRuleContext,
+    *,
     # Base name to be used for the reducer function
     reducer_name: str,
     # Compute the reducer body given the reducer.
@@ -26,6 +26,7 @@ import numpy as np
 from jax import lax
 from jax import numpy as jnp
 
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import effects
@@ -40,7 +41,6 @@ from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
                                _ensure_index_tuple, donation_vector,
                                shaped_abstractify, check_callable)
 from jax._src.array import ArrayImpl
-from jax._src.config import config
 from jax._src.errors import JAXTypeError
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -670,8 +670,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
                 name=name),
       source_info_util.new_source_info(), resource_env, {})
   jaxpr = plan.subst_axes_with_resources(jaxpr)
-  use_spmd_lowering = config.experimental_xmap_spmd_lowering
-  ensure_fixed_sharding = config.experimental_xmap_ensure_fixed_sharding
+  use_spmd_lowering = _SPMD_LOWERING.value
+  ensure_fixed_sharding = _ENSURE_FIXED_SHARDING.value
   if use_spmd_lowering and ensure_fixed_sharding:
     jaxpr = _fix_inferred_spmd_sharding(jaxpr, resource_env)
 
@@ -686,7 +686,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
     mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
     mesh = resource_env.physical_mesh
     tiling_method: pxla.TilingMethod
-    if config.experimental_xmap_spmd_lowering_manual:
+    if _SPMD_LOWERING_MANUAL.value:
       manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
       tiling_method = pxla.TileManual(manual_mesh_axes)
     else:
@@ -1284,7 +1284,7 @@ batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap
 
 def _xmap_lowering_rule(ctx, *args, **kwargs):
   if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
-    if config.experimental_xmap_spmd_lowering_manual:
+    if _SPMD_LOWERING_MANUAL.value:
       return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs)
     else:
       return _xmap_lowering_rule_spmd(ctx, *args, **kwargs)
@@ -1404,7 +1404,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
     for dim, dim_extra_axis in enumerate(extra):
       if dim_extra_axis is None: continue
       assert dim_extra_axis not in axes
-      assert not config.jax_enable_checks or all(v != dim for v in axes.values())
+      assert not config.enable_checks.value or all(v != dim for v in axes.values())
       axes[dim_extra_axis] = dim
   add_spmd_axes(mesh_in_axes, spmd_in_axes)
   add_spmd_axes(mesh_out_axes, spmd_out_axes)
@@ -1839,38 +1839,34 @@ def _clear_compilation_cache(_):
 
 def _ensure_spmd_and(f):
   def update(v):
-    if v and not config.experimental_xmap_spmd_lowering:
+    if v and not _SPMD_LOWERING.value:
       raise RuntimeError("This flag requires enabling the experimental_xmap_spmd_lowering flag")
     return f(v)
   return update
 
 
-try:
-  config.define_bool_state(
-      name="experimental_xmap_spmd_lowering",
-      default=False,
-      help=("When set, multi-device xmap computations will be compiled through "
-            "the XLA SPMD partitioner instead of explicit cross-replica collectives. "
-            "Not supported on CPU!"),
-      update_global_hook=_clear_compilation_cache,
-      update_thread_local_hook=_thread_local_flag_unsupported)
-  config.define_bool_state(
-      name="experimental_xmap_spmd_lowering_manual",
-      default=False,
-      help=("When set, multi-device xmap computations will be compiled using "
-            "the MANUAL partitioning feature of the XLA SPMD partitioner instead of "
-            "sharding constraints on vectorized code. "
-            "Requires experimental_xmap_spmd_lowering!"),
-      update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
-      update_thread_local_hook=_thread_local_flag_unsupported)
-  config.define_bool_state(
-      name="experimental_xmap_ensure_fixed_sharding",
-      default=False,
-      help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will "
-            "try to limit the flexibility of the automated SPMD partitioner heuristics "
-            "by emitting additional sharding annotations for program intermediates."),
-      update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
-      update_thread_local_hook=_thread_local_flag_unsupported)
-except Exception:
-  raise ImportError("jax.experimental.maps has to be imported before JAX flags "
-                    "are parsed")
+_SPMD_LOWERING = config.define_bool_state(
+    name="experimental_xmap_spmd_lowering",
+    default=False,
+    help=("When set, multi-device xmap computations will be compiled through "
+          "the XLA SPMD partitioner instead of explicit cross-replica collectives. "
+          "Not supported on CPU!"),
+    update_global_hook=_clear_compilation_cache,
+    update_thread_local_hook=_thread_local_flag_unsupported)
+_SPMD_LOWERING_MANUAL = config.define_bool_state(
+    name="experimental_xmap_spmd_lowering_manual",
+    default=False,
+    help=("When set, multi-device xmap computations will be compiled using "
+          "the MANUAL partitioning feature of the XLA SPMD partitioner instead of "
+          "sharding constraints on vectorized code. "
+          "Requires experimental_xmap_spmd_lowering!"),
+    update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
+    update_thread_local_hook=_thread_local_flag_unsupported)
+_ENSURE_FIXED_SHARDING = config.define_bool_state(
+    name="experimental_xmap_ensure_fixed_sharding",
+    default=False,
+    help=("When set and `experimental_xmap_spmd_lowering` is enabled, the lowering will "
+          "try to limit the flexibility of the automated SPMD partitioner heuristics "
+          "by emitting additional sharding annotations for program intermediates."),
+    update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
+    update_thread_local_hook=_thread_local_flag_unsupported)
@@ -21,9 +21,9 @@ import warnings
 
 import numpy as np
 
-from jax import config
 from jax import lax
 
+from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import util
@@ -91,11 +91,12 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
 
   if not dtypes.safe_to_cast(y, x):
     # TODO(jakevdp): change this to an error after the deprecation period.
-    warnings.warn("scatter inputs have incompatible types: cannot safely cast "
-                  f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)} "
-                  f"with jax_numpy_dtype_promotion={config.jax_numpy_dtype_promotion!r}. "
-                  "In future JAX releases this will result in an error.",
-                  FutureWarning)
+    warnings.warn(
+        "scatter inputs have incompatible types: cannot safely cast value "
+        f"from dtype={lax.dtype(y)} to dtype={lax.dtype(x)} with "
+        f"jax_numpy_dtype_promotion={config.numpy_dtype_promotion.value!r}. "
+        "In future JAX releases this will result in an error.",
+        FutureWarning)
 
   idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
   indexer = jnp._index_to_gather(jnp.shape(x), idx,
@@ -789,44 +789,37 @@ def _dot_general_lowering_rule(
     return vector.ShapeCastOp(out_type, red).result
 
   if lhs_dims == (1,):
-    lhs_dim_attr = ir.Attribute.parse("affine_map<(i, j, k) -> (i, k)>")
+    transpose_lhs = False
   elif lhs_dims == (0,):
-    lhs_dim_attr = ir.Attribute.parse("affine_map<(i, j, k) -> (k, i)>")
+    transpose_lhs = True
   else:
     raise NotImplementedError
   if rhs_dims == (0,):
-    rhs_dim_attr = ir.Attribute.parse("affine_map<(i, j, k) -> (k, j)>")
+    transpose_rhs = False
   elif rhs_dims == (1,):
-    rhs_dim_attr = ir.Attribute.parse("affine_map<(i, j, k) -> (j, k)>")
-    out_tile = arith.ConstantOp(
-        out_type, ir.DenseElementsAttr.get_splat(out_type, val)
-    )
-    op = vector.ContractionOp(
-        out_type,
-        x,
-        y,
-        out_tile,
-        indexing_maps=ir.ArrayAttr.get([
-            lhs_dim_attr,
-            rhs_dim_attr,
-            ir.Attribute.parse("affine_map<(i, j, k) -> (i, j)>"),
-        ]),
-        iterator_types=ir.ArrayAttr.get([
-            ir.Attribute.parse("#vector.iterator_type<parallel>"),
-            ir.Attribute.parse("#vector.iterator_type<parallel>"),
-            ir.Attribute.parse("#vector.iterator_type<reduction>"),
-        ]),
-    )
+    transpose_rhs = True
   else:
     raise NotImplementedError
+  if precision is not None:
+    if precision[0] != precision[1]:
+      raise NotImplementedError("Per-operand dot precision unsupported")
+    precision = precision[0]
   if precision is None or precision == lax.Precision.DEFAULT:
-    pass  # That's the default in Mosaic.
+    precision_attr = None  # That's the default in Mosaic.
   elif precision == lax.Precision.HIGHEST:
-    op.attributes["precision"] = ir.Attribute.parse(
+    precision_attr = ir.Attribute.parse(
         "#tpu.contract_precision<fp32>"
     )
   else:
     raise NotImplementedError(f"Unsupported dot precision: {precision}")
+  out_tile = arith.ConstantOp(
+      out_type, ir.DenseElementsAttr.get_splat(out_type, val)
+  )
+  op = tpu.MatmulOp(
+      out_type, x, y, out_tile,
+      transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs,
+      precision=precision_attr
+  )
+  return op.result
 
 
@@ -28,6 +28,7 @@ from jax._src import ad_checkpoint
 from jax._src import ad_util
 from jax._src import api_util
 from jax._src import core as jax_core
+from jax._src import custom_derivatives
 from jax._src import linear_util as lu
 from jax._src import pjit
 from jax._src import state
@@ -1183,6 +1184,9 @@ def _closed_call_lowering_rule(
 
 
 triton_lowering_rules[jax_core.closed_call_p] = _closed_call_lowering_rule
+triton_lowering_rules[custom_derivatives.custom_jvp_call_p] = (
+    _closed_call_lowering_rule
+)
 
 
 def _remat_lowering_rule(ctx: TritonLoweringRuleContext, *args, jaxpr, **_):
@@ -42,16 +42,13 @@ from jax._src.interpreters import mlir
 from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
 from jax._src import api
 from jax._src import pjit as pjit_lib
-from jax._src import config as jax_config
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes as _dtypes
 from jax._src import monitoring
 from jax._src import stages
 from jax._src.interpreters import pxla
-from jax._src.config import (bool_env, config,
-                             raise_persistent_cache_errors,
-                             persistent_cache_min_compile_time_secs)
 from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
 from jax._src.util import unzip2
 from jax._src.public_test_util import (  # noqa: F401
@@ -64,18 +61,18 @@ from jax._src import xla_bridge
 # jax.test_util. Functionality appearing here is for internal use only, and
 # may be changed or removed at any time and without any deprecation cycle.
 
-_TEST_DUT = jax_config.DEFINE_string(
+_TEST_DUT = config.DEFINE_string(
     'jax_test_dut', '',
     help=
     'Describes the device under test in case special consideration is required.'
 )
 
-_NUM_GENERATED_CASES = jax_config.DEFINE_integer(
+_NUM_GENERATED_CASES = config.DEFINE_integer(
     'jax_num_generated_cases',
     int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
    help='Number of generated cases to test')
 
-_MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer(
+_MAX_CASES_SAMPLING_RETRIES = config.DEFINE_integer(
     'max_cases_sampling_retries',
     int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
     'Number of times a failed test sample should be retried. '
@@ -83,25 +80,25 @@ _MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer(
     'sampling process is terminated.'
 )
 
-_SKIP_SLOW_TESTS = jax_config.DEFINE_bool(
+_SKIP_SLOW_TESTS = config.DEFINE_bool(
     'jax_skip_slow_tests',
-    bool_env('JAX_SKIP_SLOW_TESTS', False),
+    config.bool_env('JAX_SKIP_SLOW_TESTS', False),
     help='Skip tests marked as slow (> 5 sec).'
 )
 
-_TEST_TARGETS = jax_config.DEFINE_string(
+_TEST_TARGETS = config.DEFINE_string(
     'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
     'Regular expression specifying which tests to run, called via re.search on '
     'the test name. If empty or unspecified, run all tests.'
 )
-_EXCLUDE_TEST_TARGETS = jax_config.DEFINE_string(
+_EXCLUDE_TEST_TARGETS = config.DEFINE_string(
     'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
     'Regular expression specifying which tests NOT to run, called via re.search '
     'on the test name. If empty or unspecified, run all tests.'
 )
-TEST_WITH_PERSISTENT_COMPILATION_CACHE = jax_config.DEFINE_bool(
+TEST_WITH_PERSISTENT_COMPILATION_CACHE = config.DEFINE_bool(
     'jax_test_with_persistent_compilation_cache',
-    bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
+    config.bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
     help='If enabled, the persistent compilation cache will be enabled for all '
     'test cases. This can be used to increase compilation cache coverage.')
 
@@ -299,7 +296,7 @@ def supported_dtypes():
            np.uint8, np.uint16, np.uint32, np.uint64,
            _dtypes.bfloat16, np.float16, np.float32, np.float64,
            np.complex64, np.complex128}
-  if not config.x64_enabled:
+  if not config.enable_x64.value:
    types -= {np.uint64, np.int64, np.float64, np.complex128}
   return types
 
@@ -532,7 +529,7 @@ def rand_fullrange(rng, standardize_nans=False):
     # leads to overflows in this case; sample from signed ints instead.
     if dtype == np.uint64:
       vals = vals.astype(np.int64)
-    elif dtype == np.uint32 and not config.x64_enabled:
+    elif dtype == np.uint32 and not config.enable_x64.value:
       vals = vals.astype(np.int32)
     vals = vals.reshape(shape)
     # Non-standard NaNs cause errors in numpy equality assertions.
@@ -915,8 +912,8 @@ class JaxTestCase(parameterized.TestCase):
     if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
       cls._compilation_cache_exit_stack = ExitStack()
       stack = cls._compilation_cache_exit_stack
-      stack.enter_context(raise_persistent_cache_errors(True))
-      stack.enter_context(persistent_cache_min_compile_time_secs(0))
+      stack.enter_context(config.raise_persistent_cache_errors(True))
+      stack.enter_context(config.persistent_cache_min_compile_time_secs(0))
 
       tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
       compilation_cache.initialize_cache(tmp_dir)
@@ -963,7 +960,7 @@ class JaxTestCase(parameterized.TestCase):
       self.assertDtypesMatch(x, y)
 
   def assertDtypesMatch(self, x, y, *, canonicalize_dtypes=True):
-    if not config.x64_enabled and canonicalize_dtypes:
+    if not config.enable_x64.value and canonicalize_dtypes:
       self.assertEqual(_dtypes.canonicalize_dtype(_dtype(x), allow_extended_dtype=True),
                        _dtypes.canonicalize_dtype(_dtype(y), allow_extended_dtype=True))
     else:
@@ -1133,26 +1130,6 @@ def with_and_without_mesh(f):
     ('Mesh', (('x', 2),), (('i', 'x'),))
   ))(with_mesh_from_kwargs(f))
 
-old_spmd_lowering_flag = None
-def set_spmd_lowering_flag(val: bool):
-  global old_spmd_lowering_flag
-  old_spmd_lowering_flag = config.experimental_xmap_spmd_lowering
-  config.update('experimental_xmap_spmd_lowering', val)
-
-def restore_spmd_lowering_flag():
-  if old_spmd_lowering_flag is None: return
-  config.update('experimental_xmap_spmd_lowering', old_spmd_lowering_flag)
-
-old_spmd_manual_lowering_flag = None
-def set_spmd_manual_lowering_flag(val: bool):
-  global old_spmd_manual_lowering_flag
-  old_spmd_manual_lowering_flag = config.experimental_xmap_spmd_lowering_manual
-  config.update('experimental_xmap_spmd_lowering_manual', val)
-
-def restore_spmd_manual_lowering_flag():
-  if old_spmd_manual_lowering_flag is None: return
-  config.update('experimental_xmap_spmd_lowering_manual', old_spmd_manual_lowering_flag)
-
 def create_global_mesh(mesh_shape, axis_names):
   size = math.prod(mesh_shape)
   if len(jax.devices()) < size:
@@ -38,7 +38,7 @@ from jaxlib.mlir.dialects import stablehlo
 from jaxlib.mlir.passmanager import PassManager
 import numpy as np
 
-mosaic_use_cpp_passes = config.config.define_bool_state(
+_MOSAIC_USE_CPP_PASSES = config.define_bool_state(
     name="mosaic_use_cpp_passes",
     default=False,
     help=(
@@ -54,13 +54,13 @@ tpu = tpu_mosaic.tpu
 apply_vector_layout = tpu_mosaic.apply_vector_layout
 infer_memref_layout = tpu_mosaic.infer_memref_layout
 
-mosaic_allow_hlo = config.config.define_bool_state(
+_MOSAIC_ALLOW_HLO = config.define_bool_state(
     name="jax_mosaic_allow_hlo",
     default=False,
     help="Allow hlo dialects in Mosaic",
 )
 
-mosaic_dump_mlir = config.config.define_bool_state(
+_MOSAIC_DUMP_MLIR = config.define_bool_state(
     name="jax_mosaic_dump_mlir",
     default=False,
     help="Print mlir module after each pass",
@@ -243,7 +243,7 @@ def _lower_tpu_kernel(
   )
   dump_mlir(module, "initial module")
 
-  if mosaic_allow_hlo.value:
+  if _MOSAIC_ALLOW_HLO.value:
     # Run hlo dialect conversion: hlo -> linalg -> vector.
     pipeline = [
         "hlo-legalize-to-arithmetic",
@@ -255,7 +255,7 @@ def _lower_tpu_kernel(
   )
   dump_mlir(module, "after hlo conversion module")
 
-  if mosaic_use_cpp_passes.value:
+  if _MOSAIC_USE_CPP_PASSES.value:
     pipeline = [
         (
            f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
@@ -278,7 +278,7 @@ def _lower_tpu_kernel(
   module.operation.verify()
   dump_mlir(module, "after infer vector layout pass")
 
-  if mosaic_use_cpp_passes.value:
+  if _MOSAIC_USE_CPP_PASSES.value:
     pipeline = [
         (
            "func.func(tpu-apply-vector-layout{sublane-count=8"
@@ -418,6 +418,6 @@ def _lowered_as_tpu_kernel(
 
 def dump_mlir(module: ir.Module, msg: str):
   """A helper function to print mlir module with a message."""
-  if mosaic_dump_mlir.value:
+  if _MOSAIC_DUMP_MLIR.value:
     print(f"[jax_mosaic_dump_mlir] {msg}")
     print(module)
148  jax/core.py
@@ -15,8 +15,6 @@
 # Note: import <name> as <name> is required for names to be exported.
 # See PEP 484 & https://github.com/google/jax/issues/7570
 
-from __future__ import annotations
-
 from jax._src.core import (
   AbstractToken as AbstractToken,
   AbstractValue as AbstractValue,
@@ -27,7 +25,6 @@ from jax._src.core import (
   ConcreteArray as ConcreteArray,
   ConcretizationTypeError as ConcretizationTypeError,
   DShapedArray as DShapedArray,
-  DimSize as DimSize,
   DropVar as DropVar,
   Effect as Effect,
   Effects as Effects,
@@ -50,7 +47,6 @@ from jax._src.core import (
   OutputType as OutputType,
   ParamDict as ParamDict,
   Primitive as Primitive,
-  Shape as Shape,
   ShapedArray as ShapedArray,
   Sublevel as Sublevel,
   TRACER_LEAK_DEBUGGER_WARNING as TRACER_LEAK_DEBUGGER_WARNING,
@@ -60,15 +56,11 @@ from jax._src.core import (
   TraceStack as TraceStack,
   TraceState as TraceState,
   Tracer as Tracer,
-  TracerArrayConversionError as TracerArrayConversionError,
-  TracerIntegerConversionError as TracerIntegerConversionError,
-  UnexpectedTracerError as UnexpectedTracerError,
   UnshapedArray as UnshapedArray,
   Value as Value,
   Var as Var,
   abstract_token as abstract_token,
   apply_todos as apply_todos,
-  as_hashable_function as as_hashable_function,
   as_named_shape as as_named_shape,
   aval_mapping_handlers as aval_mapping_handlers,
   axis_frame as axis_frame,
@@ -82,7 +74,6 @@ from jax._src.core import (
   check_type as check_type,
   check_valid_jaxtype as check_valid_jaxtype,
   closed_call_p as closed_call_p,
-  collections as collections,
   concrete_aval as concrete_aval,
   concrete_or_error as concrete_or_error,
   concretization_function_error as concretization_function_error,
@@ -92,7 +83,6 @@ from jax._src.core import (
   definitely_equal as definitely_equal,  # TODO(necula): remove this API
   dimension_as_value as dimension_as_value,  # TODO(necula): remove this API
   do_subst_axis_names_jaxpr as do_subst_axis_names_jaxpr,
-  dtypes as dtypes,
   ensure_compile_time_eval as ensure_compile_time_eval,
   escaped_tracer_error as escaped_tracer_error,
   eval_context as eval_context,
@@ -114,13 +104,10 @@ from jax._src.core import (
   lattice_join as lattice_join,
   leaked_tracer_error as leaked_tracer_error,
   literalable_types as literalable_types,
-  lu as lu,
-  map as map,
   map_bind as map_bind,
   map_bind_with_continuation as map_bind_with_continuation,
   mapped_aval as mapped_aval,
   maybe_find_leaked_tracers as maybe_find_leaked_tracers,
-  namedtuple as namedtuple,
   new_base_main as new_base_main,
   new_jaxpr_eqn as new_jaxpr_eqn,
   new_main as new_main,
@@ -128,8 +115,6 @@ from jax._src.core import (
   no_axis_name as no_axis_name,
   no_effects as no_effects,
   outfeed_primitives as outfeed_primitives,
-  partial as partial,
-  pp as pp,
   pp_aval as pp_aval,
   pp_eqn as pp_eqn,
   pp_eqn_rules as pp_eqn_rules,
@@ -150,11 +135,7 @@ from jax._src.core import (
   raise_as_much_as_possible as raise_as_much_as_possible,
   raise_to_shaped as raise_to_shaped,
   raise_to_shaped_mappings as raise_to_shaped_mappings,
-  ref as ref,
   reset_trace_state as reset_trace_state,
-  safe_map as safe_map,
-  safe_zip as safe_zip,
-  source_info_util as source_info_util,
   stash_axis_env as stash_axis_env,
   str_eqn_compact as str_eqn_compact,
   subjaxprs as subjaxprs,
@@ -165,12 +146,8 @@ from jax._src.core import (
   substitute_vars_in_output_ty as substitute_vars_in_output_ty,
   thread_local_state as thread_local_state,
   token as token,
-  total_ordering as total_ordering,
   trace_state_clean as trace_state_clean,
-  traceback_util as traceback_util,
   traverse_jaxpr_params as traverse_jaxpr_params,
-  tuple_delete as tuple_delete,
-  tuple_insert as tuple_insert,
   typecheck as typecheck,
   typecompat as typecompat,
   typematch as typematch,
@@ -178,7 +155,130 @@ from jax._src.core import (
   used_axis_names as used_axis_names,
   used_axis_names_jaxpr as used_axis_names_jaxpr,
   valid_jaxtype as valid_jaxtype,
-  zip as zip,
 )
 
 symbolic_equal_dim = definitely_equal  # TODO(necula): remove this API
 
+from jax._src import core as _src_core
+_deprecations = {
+    # Added Oct 11, 2023:
+    "DimSize": (
+        "jax.core.DimSize is deprecated. Use DimSize = int | Any.",
+        _src_core.DimSize,
+    ),
+    "Shape": (
+        "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
+        _src_core.Shape,
+    ),
+    "TracerArrayConversionError": (
+        "jax.core.TracerArrayConversionError is deprecated. Use jax.errors.TracerArrayConversionError",
+        _src_core.TracerArrayConversionError,
+    ),
+    "TracerIntegerConversionError": (
+        "jax.core.TracerIntegerConversionError is deprecated. Use jax.errors.TracerIntegerConversionError",
+        _src_core.TracerIntegerConversionError,
+    ),
+    "UnexpectedTracerError": (
+        "jax.core.UnexpectedTracerError is deprecated. Use jax.errors.UnexpectedTracerError",
+        _src_core.UnexpectedTracerError,
+    ),
+    "as_hashable_function": (
+        "jax.core.as_hashable_function is deprecated. Use jax.util.as_hashable_function directly.",
+        _src_core.as_hashable_function,
+    ),
+    "collections": (
+        "jax.core.collections is deprecated. Use the collections module directly.",
+        _src_core.collections,
+    ),
+    "dtypes": (
+        "jax.core.dtypes is deprecated. Use jax.dtypes directly.",
+        _src_core.dtypes,
+    ),
+    "lu": (
+        "jax.core.lu is deprecated. Use lu = jax.extend.linear_util",
+        _src_core.lu,
+    ),
+    "map": (
+        "jax.core.map is deprecated. Use the built-in map function.",
+        _src_core.map,
+    ),
+    "namedtuple": (
+        "jax.core.namedtuple is deprecated. Use collections.namedtuple directly.",
+        _src_core.namedtuple,
+    ),
+    "partial": (
+        "jax.core.partial is deprecated. Use functools.partial directly.",
+        _src_core.partial,
+    ),
+    "pp": (
+        "jax.core.pp is deprecated. jax._src.pretty_printer is a non-public API.",
+        _src_core.pp,
+    ),
+    "ref": (
+        "jax.core.ref is deprecated. Use weakref.ref directly.",
+        _src_core.ref,
+    ),
+    "safe_map": (
+        "jax.core.safe_map is deprecated. Use jax.util.safe_map directly.",
+        _src_core.safe_map,
+    ),
+    "safe_zip": (
+        "jax.core.safe_zip is deprecated. Use jax.util.safe_zip directly.",
+        _src_core.safe_zip,
+    ),
+    "source_info_util": (
+        "jax.core.source_info_util is deprecated. jax._src.source_info_util is a non-public API.",
+        _src_core.source_info_util,
+    ),
+    "total_ordering": (
+        "jax.core.total_ordering is deprecated. Use functools.total_ordering directly.",
+        _src_core.total_ordering,
+    ),
+    "traceback_util": (
+        "jax.core.traceback_util is deprecated. jax._src.traceback_util is a non-public API.",
+        _src_core.traceback_util,
+    ),
+    "tuple_delete": (
+        "jax.core.tuple_delete is deprecated. Use tuple_delete = lambda t, i: (*t[:i], *t[i+1:])",
+        _src_core.tuple_delete,
+    ),
+    "tuple_insert": (
+        "jax.core.tuple_insert is deprecated. Use tuple_insert = lambda t, v, i: (*t[:i], v, *t[i:])",
+        _src_core.tuple_insert,
+    ),
+    "zip": (
+        "jax.core.zip is deprecated. Use the built-in zip function.",
+        _src_core.zip,
+    ),
+}
+
+import typing
+if typing.TYPE_CHECKING:
+  DimSize = _src_core.DimSize
+  Shape = _src_core.Shape
+  TracerArrayConversionError = _src_core.TracerArrayConversionError
+  TracerIntegerConversionError = _src_core.TracerIntegerConversionError
+  UnexpectedTracerError = _src_core.UnexpectedTracerError
+  as_hashable_function = _src_core.as_hashable_function
+  collections = _src_core.collections
+  dtypes = _src_core.dtypes
+  lu = _src_core.lu
+  map = _src_core.map
+  namedtuple = _src_core.namedtuple
+  partial = _src_core.partial
+  pp = _src_core.pp
+  ref = _src_core.ref
+  safe_map = _src_core.safe_map
+  safe_zip = _src_core.safe_zip
+  source_info_util = _src_core.source_info_util
+  total_ordering = _src_core.total_ordering
+  traceback_util = _src_core.traceback_util
+  tuple_delete = _src_core.tuple_delete
+  tuple_insert = _src_core.tuple_insert
+  zip = _src_core.zip
+else:
+  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
+  __getattr__ = _deprecation_getattr(__name__, _deprecations)
+  del _deprecation_getattr
+del typing
+del _src_core
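The _deprecations table above is consumed through a module-level __getattr__ (PEP 562), so deprecated names keep working while emitting a warning. A minimal sketch of the mechanism, independent of jax's actual deprecation_getattr helper:

import warnings

_deprecations = {
    # name -> (warning message, replacement value)
    "old_name": ("mymodule.old_name is deprecated; use mymodule.new_name.", 42),
}

def __getattr__(name):
    # Called only when normal module attribute lookup fails.
    if name in _deprecations:
        message, value = _deprecations[name]
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")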
@@ -28,9 +28,9 @@ from absl import logging
 import numpy as np
 
 import jax
-from jax import config
 from jax import sharding
 
+from jax._src import config
 from jax._src import core
 from jax._src import dispatch
 from jax._src.interpreters import mlir
@@ -116,6 +116,8 @@ class DisabledSafetyCheck:
 minimum_supported_serialization_version = 6
 maximum_supported_serialization_version = 8
 
+Sharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
+
 @dataclasses.dataclass(frozen=True)
 class Exported:
   """A JAX function lowered to StableHLO.
@@ -133,9 +135,8 @@ class Exported:
       expressions in the shapes, with dimension variables among those in
       `in_avals`.
     in_shardings: the flattened input shardings. Only for the inputs that are
-      specified in `module_kept_var_idx`. If `None` then it is equivalent
-      to unspecified shardings.
-    out_shardings: the flattened output shardings, as long as `in_avals`.
+      specified in `module_kept_var_idx`.
+    out_shardings: the flattened output shardings, as long as `out_avals`.
     lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
       'cuda', 'rocm'. See below for the calling convention for when
       there are multiple lowering platforms.
@@ -226,8 +227,8 @@ class Exported:
   out_tree: tree_util.PyTreeDef
   out_avals: tuple[core.AbstractValue, ...]
 
-  in_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
-  out_shardings: Optional[tuple[Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue], ...]]
+  in_shardings: tuple[Sharding, ...]
+  out_shardings: tuple[Sharding, ...]
   lowering_platform: str  # For backwards compatibility
   lowering_platforms: tuple[str, ...]
   disabled_checks: Sequence[DisabledSafetyCheck]
@@ -386,7 +387,7 @@ def export(fun_jax: Callable,
       exported = jax_export.export(f_jax)(*args, **kwargs)
   """
   fun_name = getattr(fun_jax, "__name__", "unknown")
-  version = config.jax_serialization_version
+  version = config.jax_serialization_version.value
   if (version < minimum_supported_serialization_version or
       version > maximum_supported_serialization_version):
     raise ValueError(
@@ -826,14 +827,64 @@ def _check_module(mod: ir.Module, *,
           "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-lowering-supports-only-select-custom-calls")
     raise ValueError(msg)
 
+def expand_in_shardings(in_shardings: tuple[Sharding, ...],
+                        module_kept_var_idx: Sequence[int],
+                        nr_inputs: int) -> tuple[Sharding, ...]:
+  """Expands in_shardings with unspecified shardings for inputs not kept.
+
+  Assumes in_shardings corresponds to module_kept_var_idx.
+  """
+  assert len(in_shardings) == len(module_kept_var_idx)
+  assert nr_inputs >= len(module_kept_var_idx)
+  all_in_shardings: list[Sharding] = [sharding_impls.UNSPECIFIED] * nr_inputs
+  for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings):
+    all_in_shardings[idx] = in_s
+  return tuple(all_in_shardings)
+
+# TODO(yashkatariya, necula): remove this function once we relax the checks
+# in the jit front-end.
+def canonical_shardings(
+    in_shardings: Sequence[Sharding],
+    out_shardings: Sequence[Sharding]
+    ) -> tuple[Union[pxla.UnspecifiedValue,
+                     Sequence[sharding.XLACompatibleSharding]],
+               Union[pxla.UnspecifiedValue,
+                     Sequence[sharding.XLACompatibleSharding]]]:
+  """Prepares canonical in_ and out_shardings for a jit invocation.
+
+  The pjit front-end is picky about what in- and out-shardings it accepts,
+  e.g., if all are unspecified then the whole sharding should be the
+  sharding_impls.UNSPECIFIED object, otherwise the unspecified shardings are
+  replaced with the replicated sharding.
+  """
+  # Prepare a replicated sharding, search in both the input and output shardings
+  specified_shardings = [
+      s for s in itertools.chain(in_shardings, out_shardings)
+      if not sharding_impls.is_unspecified(s)]
+  if specified_shardings:
+    in_s = specified_shardings[0]  # pjit will enforce that all have same devices
+    assert isinstance(in_s, sharding.XLACompatibleSharding)
+    replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
+  else:
+    replicated_s = None
+
+  def canonicalize(
+      ss: Sequence[Sharding]) -> Union[pxla.UnspecifiedValue,
+                                       Sequence[sharding.XLACompatibleSharding]]:
+    if all(sharding_impls.is_unspecified(s) for s in ss):
+      return sharding_impls.UNSPECIFIED
+    return tuple(
+        s if not sharding_impls.is_unspecified(s) else replicated_s
+        for s in ss)
+  return (canonicalize(in_shardings), canonicalize(out_shardings))
+
 def _get_vjp_fun(primal_fun: Callable, *,
                  in_tree: tree_util.PyTreeDef,
                  in_avals: Sequence[core.AbstractValue],
                  out_avals: Sequence[core.AbstractValue],
                  module_kept_var_idx: tuple[int, ...],
-                 in_shardings,
-                 out_shardings,
+                 in_shardings: tuple[Sharding, ...],
+                 out_shardings: tuple[Sharding, ...],
                  apply_jit: bool
                  ) -> tuple[Callable, Sequence[core.AbstractValue]]:
   # Since jax.vjp does not handle kwargs, it is easier to do all the work
@@ -855,36 +906,11 @@ def _get_vjp_fun(primal_fun: Callable, *,
       itertools.chain(in_avals,
                       map(lambda a: a.at_least_vspace(), out_avals)))
 
-  # Expand in_shardings to all in_avals even not kept ones.
-  all_in_shardings = [sharding_impls.UNSPECIFIED] * len(in_avals)
-  for idx, in_s in zip(sorted(module_kept_var_idx),
-                       in_shardings):  # type: ignore
-    all_in_shardings[idx] = in_s  # type: ignore
-  all_shardings = all_in_shardings + list(out_shardings)  # type: ignore
-  # Cannot mix unspecified and specified shardings. Make the unspecified
-  # ones replicated.
-  specified_shardings = [
-      s for s in all_shardings if not sharding_impls.is_unspecified(s)]
-
-  vjp_in_shardings: Any  # The primal inputs followed by output cotangents
-  vjp_out_shardings: Any  # The primal output cotangents
-  if 0 == len(specified_shardings):
-    vjp_in_shardings = sharding_impls.UNSPECIFIED
-    vjp_out_shardings = sharding_impls.UNSPECIFIED
-  else:
-    if len(specified_shardings) < len(all_shardings):
-      # There are some specified, but not all; pjit front-end does not like
-      in_s = specified_shardings[0]  # pjit will enforce that all have same devices
-      assert isinstance(in_s, sharding.XLACompatibleSharding)
-      replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
-      all_shardings = [
-          s if not sharding_impls.is_unspecified(s) else replicated_s
-          for s in all_shardings]
-
-    vjp_in_shardings = tuple(all_shardings)
-    vjp_out_shardings = tuple(all_shardings[:len(in_avals)])
-    if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings):
-      vjp_out_shardings = sharding_impls.UNSPECIFIED
+  all_in_shardings = expand_in_shardings(in_shardings,
+                                         module_kept_var_idx, len(in_avals))
+  vjp_in_shardings, vjp_out_shardings = canonical_shardings(
+      tuple(itertools.chain(all_in_shardings, out_shardings)),
+      all_in_shardings)
 
   if apply_jit:
     return pjit.pjit(fun_vjp_jax,
@@ -1037,6 +1063,13 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
   if exported.uses_shape_polymorphism:
     ctx.module_context.shape_poly_state.uses_dim_vars = True
 
+  # Apply in_shardings
+  all_in_shardings = expand_in_shardings(exported.in_shardings,
+                                         exported.module_kept_var_idx,
+                                         len(args))
+  args = tuple(
+      wrap_with_sharding(ctx, exported, x, x_aval, x_sharding)
+      for x, x_aval, x_sharding in zip(args, ctx.avals_in, all_in_shardings))
   submodule = ir.Module.parse(exported.mlir_module())
   symtab = ir.SymbolTable(submodule.operation)
   # The called function may have been exported with polymorphic shapes and called
@@ -1072,12 +1105,44 @@
       kept_args)
   # The ctx.avals_out already contain the abstract values refined by
   # _call_exported_abstract_eval.
-  return tuple(
+  results = tuple(
       convert_shape(out, out_aval, refined_out_aval)
       for out, out_aval, refined_out_aval in zip(call.results, exported.out_avals, ctx.avals_out))
+  # Apply out_shardings
+  results = tuple(
+      wrap_with_sharding(ctx, exported, x, x_aval, x_sharding)
+      for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings)
+  )
+  return results
 
 
 for _p in ("cpu", "tpu", "cuda", "rocm"):
   mlir.register_lowering(call_exported_p,
                          functools.partial(_call_exported_lowering, platform=_p),
                          platform=_p)
 
+def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
+                       exported: Exported,
+                       x: ir.Value,
+                       x_aval: core.AbstractValue,
+                       x_sharding: Sharding) -> ir.Value:
+  if sharding_impls.is_unspecified(x_sharding):
+    return x
+  axis_context = ctx.module_context.axis_context
+  if isinstance(axis_context, sharding_impls.ShardingContext):
+    ctx_device_assignment = axis_context.device_assignment
+  elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
+    ctx_device_assignment = list(axis_context.mesh.devices.flat)
+  else:
+    raise NotImplementedError(type(axis_context))
+  assert isinstance(x_sharding, sharding_impls.XLACompatibleSharding)
+  sharding_device_assignment = x_sharding._device_assignment
+  if len(ctx_device_assignment) != len(sharding_device_assignment):
+    raise NotImplementedError(
+        f"Exported module {exported.fun_name} was lowered for "
+        f"{len(sharding_device_assignment)} devices and is called in a context with "
+        f"{len(ctx_device_assignment)} devices"
+    )
+  return mlir.wrap_with_sharding_op(
+      ctx, x, x_aval,
+      x_sharding._to_xla_hlo_sharding(x_aval.ndim).to_proto())
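To make the behavior of the new expand_in_shardings helper above concrete: inputs dropped by dead-code elimination get UNSPECIFIED, while kept inputs receive their recorded shardings in index order. A toy rerun of the same logic, with strings standing in for sharding objects:

UNSPECIFIED = "unspecified"  # stands in for sharding_impls.UNSPECIFIED

def expand_in_shardings(in_shardings, module_kept_var_idx, nr_inputs):
    all_in_shardings = [UNSPECIFIED] * nr_inputs
    for idx, in_s in zip(sorted(module_kept_var_idx), in_shardings):
        all_in_shardings[idx] = in_s
    return tuple(all_in_shardings)

# Inputs 1 and 3 were kept by the module; inputs 0 and 2 get UNSPECIFIED.
print(expand_in_shardings(("s_a", "s_b"), (1, 3), 4))
# -> ('unspecified', 's_a', 'unspecified', 's_b')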
@@ -48,9 +48,9 @@ import numpy as np
 import opt_einsum
 
 import jax
-from jax import config
 from jax.interpreters import xla
 
+from jax._src import config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import effects
@@ -651,7 +651,7 @@ class _DimExpr():
       raise InconclusiveDimensionOperation("")
     remainder = 0
 
-    if config.jax_enable_checks:
+    if config.enable_checks.value:
       assert self == divisor * quotient + remainder
     return quotient, remainder
   except InconclusiveDimensionOperation:
@@ -17,14 +17,14 @@ import os
 from absl import flags
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
 
+from jax._src import config
 from jax._src import test_util as jtu
 
 from jax.experimental.jax2tf.examples import saved_model_main
 from jax.experimental.jax2tf.tests import tf_test_util
 
-jax.config.parse_flags_with_absl()
+config.parse_flags_with_absl()
 FLAGS = flags.FLAGS
 
 
@@ -26,7 +26,6 @@ from absl.testing import absltest, parameterized
 import numpy as np
 
 import jax
-from jax import config
 from jax import lax
 from jax.experimental.export import export
 from jax.experimental.jax2tf.tests import back_compat_test_util as bctu
@@ -60,6 +59,7 @@ import jax.numpy as jnp
 from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
+from jax._src import config
 from jax._src import test_util as jtu
 
 config.parse_flags_with_absl()
|
||||
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
|
||||
for dtype_name in ("f32", "f64", "c64", "c128"))
|
||||
def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"):
|
||||
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
@ -179,7 +179,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
|
||||
for dtype_name in ("f32", "f64", "c64", "c128"))
|
||||
def test_cpu_eig_lapack_geev(self, dtype_name="f32"):
|
||||
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
@ -271,7 +271,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
for dtype_name in ("f32", "f64", "c64", "c128"))
|
||||
def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"):
|
||||
# For lax.linalg.eigh
|
||||
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
@ -327,7 +327,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
for dtype_name in ("f32", "f64", "c64", "c128"))
|
||||
def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"):
|
||||
# For lax.linalg.qr
|
||||
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
@ -395,7 +395,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
for dtype_name in ("f32", "f64", "c64", "c128"))
|
||||
def test_cpu_lu_lapack_getrf(self, dtype_name:str):
|
||||
# For lax.linalg.lu on CPU.
|
||||
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
c64=np.complex64, c128=np.complex128)[dtype_name]
|
||||
@ -480,7 +480,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
for dtype_name in ("f32", "f64", "c64", "c128")])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_cpu_schur_lapack_gees(self, dtype_name="f32"):
|
||||
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
@ -509,7 +509,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
for dtype_name in ("f32", "f64", "c64", "c128"))
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_cpu_svd_lapack_gesdd(self, dtype_name="f32"):
|
||||
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
@ -534,7 +534,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
for dtype_name in ("f32", "f64", "c64", "c128")])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_cpu_triangular_solve_blas_trsm(self, dtype_name="f32"):
|
||||
if not config.jax_enable_x64 and dtype_name in ["f64", "c128"]:
|
||||
if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
|
||||
dtype = dict(f32=np.float32, f64=np.float64,
|
||||
@ -669,16 +669,11 @@ class CompatTest(bctu.CompatTestBase):
|
||||
# replace this strict check with something else.
|
||||
data = self.load_testdata(stablehlo_dynamic_rng_bit_generator.data_2023_06_17)
|
||||
|
||||
prev_default_prng_impl = jax.config.jax_default_prng_impl
|
||||
try:
|
||||
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
|
||||
|
||||
with config.default_prng_impl("unsafe_rbg"):
|
||||
self.run_one_test(
|
||||
func, data, polymorphic_shapes=(None, "b0, b1"),
|
||||
# Recent serializations also include shape_assertion, tested with dynamic_top_k
|
||||
expect_current_custom_calls=["stablehlo.dynamic_rng_bit_generator", "shape_assertion"])
|
||||
finally:
|
||||
jax.config.update("jax_default_prng_impl", prev_default_prng_impl)
|
||||
|
||||
def test_stablehlo_dynamic_top_k(self):
|
||||
# stablehlo.dynamic_top_k is used temporarily for a top_k with dynamism
|
||||
|
@@ -20,9 +20,9 @@ from typing import Any, Callable, Optional, Union
 import jax
 from jax import lax
 from jax import numpy as jnp
-from jax import config
-from jax._src import test_util as jtu
+from jax._src import config
 from jax._src import dtypes
+from jax._src import test_util as jtu
 from jax.experimental.jax2tf.tests import primitive_harness
 import numpy as np
 
@@ -114,7 +114,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
     """Checks if this limitation is enabled for dtype and device and mode."""
     native_serialization_mask = (
         Jax2TfLimitation.FOR_NATIVE
-        if config.jax2tf_default_native_serialization
+        if config.jax2tf_default_native_serialization.value
         else Jax2TfLimitation.FOR_NON_NATIVE)
     return ((mode is None or mode in self.modes) and
             (self.native_serialization & native_serialization_mask) and
@ -19,7 +19,6 @@ import contextlib
import math
import os
import re
from typing import Callable, Optional
import unittest

from absl import logging
@ -31,20 +30,17 @@ from jax import dtypes
from jax import lax
from jax import numpy as jnp
from jax import sharding
from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import hlo
import jax._src.xla_bridge
from jax import config
from jax._src import xla_bridge as xb
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map
from jax.experimental import pjit
from jax.interpreters import mlir
from jax.sharding import PartitionSpec as P

import numpy as np
@ -233,7 +229,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
with_function=[True, False],
)
def test_converts_64bit(self, dtype=np.int64, with_function=False):
if not config.jax_enable_x64:
if not config.enable_x64.value:
self.skipTest("requires x64 mode")
big_const = np.full((5,), 2 ** 33, dtype=dtype)
self.ConvertAndCompare(jnp.sin, big_const)
@ -249,7 +245,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

def test_64bit_behavior_enable_x64_readme(self):
# Tests some of the examples from the README
if not config.jax_enable_x64:
if not config.enable_x64.value:
self.skipTest("requires x64 mode")

# JAX and TF have different default float types if JAX_ENABLE_X64=1
@ -266,7 +262,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

def test_64bit_behavior_not_enable_x64_readme(self):
# Tests some of the examples from the README
if config.jax_enable_x64:
if config.enable_x64.value:
self.skipTest("requires not x64 mode")

# JAX and TF have same default float types if JAX_ENABLE_X64=0
@ -817,7 +813,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
arg = np.array(3.)
f_tf = jax2tf.convert(jax.grad(remat_f))
f_tf_hlo = self.TfToHlo(f_tf, arg)
if jax.config.jax_remat_opt_barrier:
if config.remat_opt_barrier.value:
self.assertRegex(f_tf_hlo, r"opt-barrier")
else:
self.assertRegex(f_tf_hlo,
@ -849,7 +845,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
f_tf = jax2tf.convert(f_jax)
# for native serialization the HLO we get from TF is constant-folded, so this
# test fails.
if not config.jax2tf_default_native_serialization:
if not config.jax2tf_default_native_serialization.value:
self.assertIn("sine(", self.TfToHlo(f_tf))

def test_convert_of_nested_independent_jit(self):
@ -954,7 +950,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

out = jax2tf.convert(caller_jax, with_gradient=False)(2.)
return out
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
else:
graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def())
@ -984,7 +980,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_shared_constants(self):
# Check that the constants are shared properly in converted functions
# See https://github.com/google/jax/issues/7992.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const = np.random.uniform(size=256).astype(np.float32) # A shared constant
def f(x):
@ -996,7 +992,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
def test_shared_constants_under_cond(self):
# Check that the constants are shared properly in converted functions
# See https://github.com/google/jax/issues/7992.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const_size = 512
const = np.random.uniform(size=const_size).astype(np.float32) # A shared constant
@ -1012,7 +1008,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

def test_shared_constants_under_scan(self):
# See https://github.com/google/jax/issues/7992.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const_size = 512
const = np.random.uniform(size=const_size).astype(np.float32) # A shared constant
@ -1031,7 +1027,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):

def test_shared_constants_under_jit(self):
# We do not share constants under jit.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")
const = np.random.uniform(size=(16, 16)).astype(np.float32) # A shared constant
@jax.jit
@ -1047,7 +1043,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
# randint has the property that the TF lowering of the randbits_p
# primitive generates constants that did not exist in the Jaxpr. As such
# it has created new errors related to the sharing of the constants.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
raise unittest.SkipTest("shared constants tests not interesting for native serialization")

key = jax.random.PRNGKey(42)
@ -1296,7 +1292,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

jax_comp = jax.xla_computation(f_while)(x)
backend = jax._src.xla_bridge.get_backend()
backend = xb.get_backend()
modules = backend.compile(jax_comp).hlo_modules()
jax_opt_hlo = modules[0].to_string()
print(f"JAX OPT HLO = {jax_opt_hlo}")
@ -1342,7 +1338,8 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
self.fail(f"{op.name} does not start with {scope_name}.")

def test_name_scope_polymorphic(self):
if config.jax2tf_default_native_serialization and not config.jax_dynamic_shapes:
if (config.jax2tf_default_native_serialization.value and
not config.dynamic_shapes.value):
self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")

def func_jax(x, y):
@ -28,9 +28,10 @@ import unittest

from absl.testing import absltest

import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src import maps # Needed for config flags.
from jax import config

import numpy as np

@ -139,7 +140,7 @@ class JaxPrimitiveTest(jtu.JaxTestCase):
raise unittest.SkipTest("Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation")
# The CPU/GPU have more supported types than TPU.
self.assertEqual("cpu", jtu.device_under_test(), "The documentation can be generated only on CPU")
self.assertTrue(config.x64_enabled, "The documentation must be generated with JAX_ENABLE_X64=1")
self.assertTrue(config.enable_x64.value, "The documentation must be generated with JAX_ENABLE_X64=1")

with open(os.path.join(os.path.dirname(__file__),
'../g3doc/jax_primitives_coverage.md.template')) as f:
@ -29,14 +29,6 @@ from jax.experimental.jax2tf.tests import primitive_harness
def make_disjunction_regexp(*parts: str) -> re.Pattern[str]:
return re.compile("(" + "|".join(parts) + ")")


# TODO(necula): Failures to be investigated (on multiple platforms)
_known_failures = make_disjunction_regexp(
"cumsum_",
"cumprod_",
)


# TODO(necula): Failures to be investigated (on GPU).
_known_failures_gpu = make_disjunction_regexp(
# Failures due to failure to export custom call targets for GPU, these
@ -85,13 +77,8 @@ class PrimitiveTest(jtu.JaxTestCase):
#one_containing="",
)
def test_prim(self, harness: primitive_harness.Harness):
if (
_known_failures.search(harness.fullname)
or (
jtu.device_under_test() == "gpu"
and _known_failures_gpu.search(harness.fullname)
)
):
if (jtu.device_under_test() == "gpu"
and _known_failures_gpu.search(harness.fullname)):
self.skipTest("failure to be investigated")

func_jax = harness.dyn_fun
@ -108,13 +95,7 @@ class PrimitiveTest(jtu.JaxTestCase):
for d in self.__class__.devices
if d.platform not in unimplemented_platforms
]
logging.info(
"Using devices %s",
[
(str(d), d.platform, d.device_kind, d.client.platform)
for d in devices
],
)
logging.info("Using devices %s", [str(d) for d in devices])
# lowering_platforms uses "cuda" instead of "gpu"
lowering_platforms: list[str] = [
p if p != "gpu" else "cuda"
@ -67,8 +67,7 @@ from jax._src import random as jax_random
# then the test file has to import jtu first (to define the flags) which is not
# desired if the test file is outside of this project (we don't want a
# dependency on jtu outside of jax repo).
jax.config.parse_flags_with_absl()
FLAGS = config.FLAGS
config.parse_flags_with_absl()

Rng = Any # A random number generator
DType = Any
@ -72,7 +72,7 @@ from jax._src.interpreters import xla
import numpy as np
import tensorflow as tf # type: ignore[import]

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()

# Import after parsing flags
from jax.experimental.jax2tf.tests import tf_test_util
@ -38,6 +38,7 @@ from jax import lax
import jax.numpy as jnp
from jax import random
from jax import tree_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import util
@ -51,9 +52,6 @@ from jax.experimental.jax2tf.tests import tf_test_util

import tensorflow as tf # type: ignore[import]

from jax import config
from jax._src.config import numpy_dtype_promotion

config.parse_flags_with_absl()

# Import after parsing flags
@ -723,7 +721,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
expected_output_signature=(
# for native serialization we cannot refine the inferred shape of the
# output if the input is more specific than polymorphic_shapes.
tf.TensorSpec([2, 3]) if not config.jax2tf_default_native_serialization
tf.TensorSpec([2, 3]) if not config.jax2tf_default_native_serialization.value
else tf.TensorSpec([2, None])))

check_shape_poly(self,
@ -1300,7 +1298,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):

# for native serialization we cannot refine the inferred shape of the
# output if the input is more specific than polymorphic_shapes.
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0]))
self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1]))
else:
@ -1424,9 +1422,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):

def test_prng(self):
# The PRNG implementation uses opaque types, test shape polymorphism
try:
prev_custom_prng = config.jax_enable_custom_prng
config.update("jax_enable_custom_prng", True)
with config.enable_custom_prng(True):

def f_jax(x): # x: f32[b1, b2]
key = random.PRNGKey(123) # key: key<fry>[]
@ -1452,8 +1448,6 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
check_shape_poly(self, f_jax,
arg_descriptors=[RandArg((3, 4), _f32)],
polymorphic_shapes=["b1, b2"])
finally:
config.update("jax_enable_custom_prng", prev_custom_prng)

def test_saved_model(self):
f_jax = jnp.sin
@ -1788,7 +1782,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# should be a symbolic dimension
self.assertTrue(isinstance(res, int) or hasattr(res, "dimension_as_value"))

if config.jax_enable_x64:
if config.enable_x64.value:
# Outside jax2tf, x.shape[0] is a Python (64-bit) integer and for most
# operations here JAX is not involved at all because the other operand
# is a Python or NumPy constant. So the result will be 64-bits. But under
@ -1833,7 +1827,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
self.assertTrue(d1.aval.weak_type), d1
return d0 + np.array(5., dtype=np.float32) + d1 + x[0]

with numpy_dtype_promotion("strict"):
with config.numpy_dtype_promotion("strict"):
# strict type promotion is sensitive to weak_types
check_shape_poly(self,
f_jax,
@ -2236,7 +2230,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=[poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for poly in ["b, ...", "b, w, w"]
for left in ([True, False] if dtype == np.float32 else [True])
@ -2523,7 +2517,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
lambda x, full_matrices: lax.linalg.qr(x, full_matrices=full_matrices),
arg_descriptors=[RandArg(shape, dtype), StaticArg(full_matrices)],
polymorphic_shapes=[poly],
tol=(None if config.jax2tf_default_native_serialization else 1e-5))
tol=(None if config.jax2tf_default_native_serialization.value else 1e-5))
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
# m and n must be static for now
for shape, poly, full_matrices in [
@ -2793,7 +2787,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=[poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for compute_schur_vectors in [True, False]
for (shape, poly) in [
@ -2914,7 +2908,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
polymorphic_shapes=[a_poly, b_poly],
# In non-native serialization, we cannot check exact match,
# we ought to check the invariants of the result.
check_result=config.jax2tf_default_native_serialization)
check_result=config.jax2tf_default_native_serialization.value)
for dtype in [np.float32, np.float64, np.complex64, np.complex128]
for (left_side, a_shape, b_shape, a_poly, b_poly) in [
(True, (3, 4, 4), (3, 4, 5), "b, ...", "b, ..."),
@ -3065,14 +3059,14 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
)
def test_harness(self, harness: PolyHarness):
if harness.expect_error == expect_error_associative_scan and (
not config.jax2tf_default_native_serialization
not config.jax2tf_default_native_serialization.value
or jtu.test_device_matches(["tpu"])
):
harness.expect_error = (None, None)

# Exclude some harnesses that are known to fail for native serialization
# FOR NATIVE SERIALIZATION
if config.jax2tf_default_native_serialization:
if config.jax2tf_default_native_serialization.value:
if not harness.enable_xla:
raise unittest.SkipTest("disabled for native_serialization and enable_xla=False")

@ -3123,7 +3117,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
"native serialization with shape polymorphism not implemented for window_reductions on CPU and GPU")

# FOR GRAPH SERIALIZATION
if not config.jax2tf_default_native_serialization:
if not config.jax2tf_default_native_serialization.value:
if ("random_gamma_threefry_non_partitionable" in harness.fullname and
jtu.test_device_matches(["cpu"])):
harness.tol = 1e-6
@ -30,8 +30,11 @@ import unittest
from absl.testing import absltest

import jax
from jax._src import compiler
from jax._src import config
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax import lax
from jax.experimental import jax2tf
from jax.experimental import pjit
@ -40,19 +43,18 @@ from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
from jax._src import compiler
from jax._src import xla_bridge

import numpy as np

import tensorflow as tf # type: ignore[import]

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()

# Must come after initializing the flags
from jax.experimental.jax2tf.tests import tf_test_util

prev_xla_flags = None
prev_spmd_lowering_flag = None

topology = None

@ -74,7 +76,9 @@ def setUpModule():
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
jtu.set_spmd_lowering_flag(True)
global prev_spmd_lowering_flag
prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)


def tearDownModule():
@ -83,7 +87,7 @@ def tearDownModule():
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag)


class ShardingTest(tf_test_util.JaxToTfTestCase):
@ -29,7 +29,6 @@ import jax
from jax import lax
from jax import tree_util
from jax import vmap
from jax import config
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import (
nfold_vmap, _count_stored_elements,
@ -41,6 +40,7 @@ from jax._src.interpreters import mlir
import jax.numpy as jnp
from jax.util import safe_zip, unzip2, split_list
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src.interpreters import ad
@ -784,7 +784,7 @@ def _bcoo_dot_general_fallback(data, indices, spinfo):
def _bcoo_dot_general_gpu_impl(lhs_data, lhs_indices, rhs, *,
dimension_numbers, preferred_element_type,
lhs_spinfo):
if not config.jax_bcoo_cusparse_lowering:
if not config.bcoo_cusparse_lowering.value:
return _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers,
preferred_element_type=preferred_element_type,
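
The sparse hunk above keeps the same control flow and only renames the flag read. Roughly, the gating pattern looks like this (the fallback name _bcoo_dot_general_impl comes from the hunk; argument lists are abbreviated):

from jax._src import config

def _bcoo_dot_general_gpu_impl(*args, **kwargs):
  # Fall back to the default implementation when the cuSPARSE lowering
  # is disabled via the jax_bcoo_cusparse_lowering option.
  if not config.bcoo_cusparse_lowering.value:
    return _bcoo_dot_general_impl(*args, **kwargs)
  ...  # cuSPARSE path
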
@ -25,7 +25,6 @@ import numpy as np

import jax
import jax.numpy as jnp
from jax import config
from jax import lax
from jax import tree_util
from jax.experimental.sparse._base import JAXSparse
@ -37,6 +36,7 @@ from jax.experimental.sparse.util import (
from jax.util import split_list, safe_zip

from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src.lax.lax import DotDimensionNumbers, _dot_general_batch_dim_nums
@ -623,7 +623,7 @@ def _bcsr_dot_general_gpu_lowering(
ctx, lhs_data, lhs_indices, lhs_indptr, rhs, *, dimension_numbers,
preferred_element_type, lhs_spinfo: SparseInfo):

if not config.jax_bcoo_cusparse_lowering:
if not config.bcoo_cusparse_lowering.value:
return _bcsr_dot_general_default_lowering(
ctx, lhs_data, lhs_indices, lhs_indptr, rhs,
dimension_numbers=dimension_numbers,
2
setup.py
@ -144,7 +144,7 @@ setup(
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA adds version constraints, add a version constraint
# here.
"nvidia-nvjitlink=cu12>=12.2",
"nvidia-nvjitlink-cu12>=12.2",
],

# Target that does not depend on the CUDA pip wheels, for those who want
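
The setup.py fix above corrects a malformed requirement string ('=' where '-' belongs in the distribution name). For reference, a minimal sketch of how such a pin sits in a setuptools extras list; everything except the nvidia-nvjitlink-cu12 pin is illustrative:

from setuptools import setup

setup(
    name="example-package",  # illustrative
    extras_require={
        "cuda12_pip": [
            # the '-cu12' suffix is part of the distribution name
            "nvidia-nvjitlink-cu12>=12.2",
        ],
    },
)
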
@ -40,14 +40,13 @@ import weakref
from absl import logging
from absl.testing import absltest, parameterized
import jax
from jax import config
from jax import custom_derivatives as custom_derivatives_public
from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit
from jax import lax
from jax import tree_util
from jax._src import api, api_util, dtypes, lib
from jax._src import array
from jax._src import config as config_internal
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
@ -76,7 +75,6 @@ from jax.sharding import PartitionSpec as P
import numpy as np

config.parse_flags_with_absl()
FLAGS = config.FLAGS


def _check_instance(self, x):
@ -233,7 +231,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
assert len(side) == 2 # but should still cache

f(one, two, z=np.zeros(3)) # doesn't crash
if config.x64_enabled:
if config.enable_x64.value:
# In the above call, three is of a new type (int64), thus it should
# trigger a new compilation.
assert len(side) == 3
@ -879,7 +877,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
# TODO(frostig): remove `wrap` once we always enable_custom_prng
def wrap(arr):
arr = np.array(arr, dtype=np.uint32)
if config.jax_enable_custom_prng:
if config.enable_custom_prng.value:
return prng.random_wrap(arr, impl=jax.random.default_prng_impl())
else:
return arr
@ -1057,7 +1055,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
compiled = lowered.compile()
self.assertAllClose(compiled(1.), 2.)
self.assertEqual(lowered.in_avals, compiled.in_avals)
expected_dtype = np.float64 if config.x64_enabled else np.float32
expected_dtype = np.float64 if config.enable_x64.value else np.float32
for obj in [lowered, compiled]:
self.assertEqual(
obj.in_avals,
@ -2913,7 +2911,7 @@ class APITest(jtu.JaxTestCase):

def test_dtype_warning(self):
# cf. issue #1230
if config.x64_enabled:
if config.enable_x64.value:
raise unittest.SkipTest("test only applies when x64 is disabled")

def check_warning(warn, nowarn):
@ -4029,25 +4027,15 @@ class APITest(jtu.JaxTestCase):
def test_dot_precision_flag(self):
x = jnp.zeros((2, 2))

prev_val = config._read("jax_default_matmul_precision")
try:
config.FLAGS.jax_default_matmul_precision = "tensorfloat32"
with config.default_matmul_precision("tensorfloat32"):
jnp.dot(x, x) # doesn't crash
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
finally:
config.FLAGS.jax_default_matmul_precision = prev_val
self.assertIn('Precision.HIGH', str(jaxpr))
self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))

prev_val = config._read("jax_default_matmul_precision")
try:
config.update('jax_default_matmul_precision','tensorfloat32')
with config.default_matmul_precision("tensorfloat32"):
jnp.dot(x, x) # doesn't crash
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
finally:
config.update('jax_default_matmul_precision', prev_val)
self.assertIn('Precision.HIGH', str(jaxpr))
self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))

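For context on the retrace tests that follow: changing a trace-relevant option such as the matmul precision invalidates the jit cache, so the function traces again. A condensed sketch under that assumption (the counter is illustrative):

import jax
import jax.numpy as jnp

num_traces = 0

@jax.jit
def f(x):
  global num_traces
  num_traces += 1  # counts how many times tracing runs
  return jnp.dot(x, x)

x = jnp.zeros((2, 2))
f(x)  # first trace
with jax.default_matmul_precision("tensorfloat32"):
  f(x)  # retraces, because the precision setting is part of the cache key
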
def test_dot_precision_forces_retrace(self):
num_traces = 0
@ -4067,7 +4055,7 @@ class APITest(jtu.JaxTestCase):

for f in [f_jit, f_cond]:
# Use _read() to read the flag value rather than threadlocal value.
precision = config._read('jax_default_matmul_precision')
precision = config._read("jax_default_matmul_precision")
try:
num_traces = 0
x = jnp.zeros((2, 2))
@ -4078,7 +4066,7 @@ class APITest(jtu.JaxTestCase):
with jax.default_matmul_precision("tensorfloat32"):
f(x)
self.assertEqual(num_traces, 2)
FLAGS.jax_default_matmul_precision = "float32"
config.update("jax_default_matmul_precision", "float32")
f(x)
self.assertGreaterEqual(num_traces, 2)
nt = num_traces
@ -4087,7 +4075,7 @@ class APITest(jtu.JaxTestCase):
f(x)
self.assertEqual(num_traces, nt + 1)
finally:
FLAGS.jax_default_matmul_precision = precision
config.update("jax_default_matmul_precision", precision)

def test_backward_pass_ref_dropping(self):
refs = []
@ -4239,9 +4227,9 @@ class APITest(jtu.JaxTestCase):

for f in [f_jit, f_cond]:
# Use _read() to read the flag value rather than threadlocal value.
allow_promotion = config._read('jax_numpy_rank_promotion')
allow_promotion = config._read("jax_numpy_rank_promotion")
try:
FLAGS.jax_numpy_rank_promotion = "allow"
config.update("jax_numpy_rank_promotion", "allow")
num_traces = 0
@jax.jit
def f(x):
@ -4256,7 +4244,7 @@ class APITest(jtu.JaxTestCase):
with jax.numpy_rank_promotion("warn"):
f(x)
self.assertEqual(num_traces, 2)
FLAGS.jax_numpy_rank_promotion = "raise"
config.update("jax_numpy_rank_promotion", "raise")
f(x)
self.assertGreaterEqual(num_traces, 2)
nt = num_traces
@ -4265,7 +4253,7 @@ class APITest(jtu.JaxTestCase):
f(x)
self.assertEqual(num_traces, nt + 1)
finally:
FLAGS.jax_numpy_rank_promotion = allow_promotion
config.update("jax_numpy_rank_promotion", allow_promotion)

def test_grad_negative_argnums(self):
def f(x, y):
@ -4409,7 +4397,7 @@ class APITest(jtu.JaxTestCase):

inp = jnp.arange(8)

with config_internal.log_compiles(True):
with config.log_compiles(True):
with self.assertLogs(level='WARNING') as cm:
add(inp)
jax.clear_caches()
@ -6217,7 +6205,7 @@ class JaxprTest(jtu.JaxTestCase):
def test_elide_trivial_convert_element_types(self):
# since we apply convert_element_type to a numpy.ndarray, the primitive is
# still bound and thus would appear in the jaxpr if we didn't clean it up
if config.x64_enabled:
if config.enable_x64.value:
x = np.arange(3, dtype='float64')
else:
x = np.arange(3, dtype='float32')
@ -7464,7 +7452,7 @@ class CustomJVPTest(jtu.JaxTestCase):

def test_custom_jvp_implicit_broadcasting(self):
# https://github.com/google/jax/issues/6357
if config.x64_enabled:
if config.enable_x64.value:
raise unittest.SkipTest("test only applies when x64 is disabled")

@jax.custom_jvp
@ -17,9 +17,9 @@ import unittest
from absl.testing import absltest

import jax
from jax import config
import jax.dlpack
import jax.numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version

@ -115,7 +115,8 @@ class DLPackTest(jtu.JaxTestCase):
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testTensorFlowToJax(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
if (not config.enable_x64.value and
dtype in [jnp.int64, jnp.uint64, jnp.float64]):
raise self.skipTest("x64 types are disabled by jax_enable_x64")
if (jtu.test_device_matches(["gpu"]) and
not tf.config.list_physical_devices("GPU")):
@ -138,8 +139,8 @@ class DLPackTest(jtu.JaxTestCase):
)
@unittest.skipIf(not tf, "Test requires TensorFlow")
def testJaxToTensorFlow(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64,
jnp.float64]:
if (not config.enable_x64.value and
dtype in [jnp.int64, jnp.uint64, jnp.float64]):
self.skipTest("x64 types are disabled by jax_enable_x64")
if (jtu.test_device_matches(["gpu"]) and
not tf.config.list_physical_devices("GPU")):
@ -159,7 +160,7 @@ class DLPackTest(jtu.JaxTestCase):
# See https://github.com/google/jax/issues/11895
x = jax.dlpack.from_dlpack(
tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64)))
dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
self.assertEqual(x.dtype, dtype_expected)

@jtu.sample_product(
@ -35,7 +35,6 @@ from jax._src.lib import xla_extension_version


config.parse_flags_with_absl()
FLAGS = config.FLAGS


class CacheKeyTest(jtu.JaxTestCase):
@ -21,15 +21,15 @@ import numpy as np

import jax
from jax import lax
import jax._src.test_util as jtu
from jax._src.lib import xla_extension
from jax import config
from jax.experimental import checkify
from jax.experimental import pjit
from jax.sharding import NamedSharding
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError
from jax._src.lib import xla_extension
import jax.numpy as jnp

config.parse_flags_with_absl()
@ -1305,12 +1305,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
class LowerableChecksTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
self.prev = config.jax_experimental_unsafe_xla_runtime_errors
config.update("jax_experimental_unsafe_xla_runtime_errors", True)

def tearDown(self):
config.update("jax_experimental_unsafe_xla_runtime_errors", self.prev)
super().tearDown()
self.enter_context(config.xla_runtime_errors(True))

@jtu.run_on_devices("cpu", "gpu")
def test_jit(self):
@ -23,18 +23,15 @@ import warnings

from absl.testing import absltest
import jax
from jax import config
from jax import jit
from jax import lax
from jax import pmap
from jax._src import compilation_cache as cc
from jax._src import compiler
from jax._src import config
from jax._src import monitoring
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.config import persistent_cache_min_compile_time_secs
from jax._src.config import raise_persistent_cache_errors
from jax._src.config import use_original_compilation_cache_key_generation
from jax._src.lib import xla_client
from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
@ -43,7 +40,6 @@ import numpy as np


config.parse_flags_with_absl()
FLAGS = config.FLAGS

FAKE_COMPILE_TIME = 10
_counts = Counter() # Map event name to count
@ -228,9 +224,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
f = jit(lambda x: x * x)

with raise_persistent_cache_errors(False), mock.patch.object(
cc._cache.__class__, "put"
) as mock_put, warnings.catch_warnings(record=True) as w:
with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._cache.__class__, "put") as mock_put,
warnings.catch_warnings(record=True) as w,
):
mock_put.side_effect = RuntimeError("test error")
self.assertEqual(f(2), 4)
self.assertLen(w, 1)
@ -247,9 +245,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
f = jit(lambda x: x * x)

with raise_persistent_cache_errors(False), mock.patch.object(
cc._cache.__class__, "get"
) as mock_get, warnings.catch_warnings(record=True) as w:
with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._cache.__class__, "get") as mock_get,
warnings.catch_warnings(record=True) as w,
):
mock_get.side_effect = RuntimeError("test error")
self.assertEqual(f(2), 4)
if len(w) > 1:
@ -264,8 +264,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
)

def test_min_compile_time(self):
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
):
cc.initialize_cache(tmpdir)

@ -282,8 +283,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertEqual(files_in_cache, 1)

def test_cache_saving_metric(self):
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2):
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
):
cc.initialize_cache(tmpdir)

durations = Counter() # Map metric name to time duration.
@ -346,8 +349,10 @@ class CompilationCacheTest(jtu.JaxTestCase):

def test_cache_misses_metric(self):
previous_counts = Counter(_counts)
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2):
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
):
cc.initialize_cache(tmpdir)

# Mock time to create a long compilation time and make cache misses.
@ -362,8 +367,11 @@ class CompilationCacheTest(jtu.JaxTestCase):

def test_cache_hits_original_metric(self):
previous_counts = Counter(_counts)
with tempfile.TemporaryDirectory() as tmpdir, persistent_cache_min_compile_time_secs(
2), use_original_compilation_cache_key_generation(True):
with (
tempfile.TemporaryDirectory() as tmpdir,
config.persistent_cache_min_compile_time_secs(2),
config.use_original_compilation_cache_key_generation(True),
):
cc.initialize_cache(tmpdir)

# Mock time to create a long compilation time, cache saved.
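
The cache-test hunks above also reflow stacked context managers into the parenthesized with form (accepted by CPython's newer parser; documented for Python 3.10+). A minimal sketch with the option used in the hunks; the body is elided:

import tempfile
from jax._src import config

with (
    tempfile.TemporaryDirectory() as tmpdir,
    config.persistent_cache_min_compile_time_secs(2),
):
  ...  # initialize and exercise the compilation cache under tmpdir
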
@ -24,17 +24,14 @@ from absl.testing import parameterized
import numpy as np

import jax
from jax._src import dtypes
from jax import numpy as jnp

from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal

from jax._src.config import config
config.parse_flags_with_absl()

FLAGS = config.FLAGS

bool_dtypes = [np.dtype('bool')]

np_signed_dtypes = [np.dtype('int8'), np.dtype('int16'), np.dtype('int32'),
@ -102,7 +99,7 @@ class DtypesTest(jtu.JaxTestCase):
True: _EXPECTED_CANONICALIZE_X64,
False: _EXPECTED_CANONICALIZE_X32,
}
for in_dtype, expected_dtype in expected[config.x64_enabled].items():
for in_dtype, expected_dtype in expected[config.enable_x64.value].items():
self.assertEqual(dtypes.canonicalize_dtype(in_dtype), expected_dtype)

@parameterized.named_parameters(
@ -371,7 +368,7 @@ class DtypesTest(jtu.JaxTestCase):
dtypes.dtype(None)

def testDefaultDtypes(self):
precision = config.jax_default_dtype_bits
precision = config.default_dtype_bits.value
assert precision in ['32', '64']
self.assertEqual(dtypes.bool_, np.bool_)
self.assertEqual(dtypes.int_, np.int32 if precision == '32' else np.int64)
@ -450,7 +447,7 @@ class TestPromotionTables(jtu.JaxTestCase):
# Note: * here refers to weakly-typed values
typecodes = \
['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*']
if config.x64_enabled:
if config.enable_x64.value:
expected = [
['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*'],
['u1','u1','u2','u4','u8','i2','i2','i4','i8','bf','f2','f4','f8','c4','c8','u1','f*','c*'],
@ -32,7 +32,6 @@ from jax._src import core
from jax._src import test_util as jtu

config.parse_flags_with_absl()
FLAGS = config.FLAGS


@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
@ -22,16 +22,14 @@ from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, lax
from jax._src import config as jax_config
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import traceback_util


from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


def get_exception(etype, f):
@ -43,7 +41,7 @@ def get_exception(etype, f):

def check_filtered_stack_trace(test, etype, f, frame_patterns=(),
filter_mode="remove_frames"):
with jax_config.traceback_filtering(filter_mode):
with config.traceback_filtering(filter_mode):
test.assertRaises(etype, f)
e = get_exception(etype, f)
c = e.__cause__
@ -11,21 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import contextlib
import functools
import logging
import math
import re
from typing import Optional, Sequence
from typing import Sequence
import unittest

from absl.testing import absltest
import jax
from jax import numpy as jnp
from jax import tree_util
from jax.config import config
from jax.experimental.export import export
from jax.experimental import pjit
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
@ -38,6 +43,17 @@ import numpy as np

config.parse_flags_with_absl()

prev_xla_flags = None
def setUpModule():
global prev_xla_flags
# This will control the CPU devices. On TPU we always have 2 devices.
prev_xla_flags = jtu.set_host_platform_device_count(2)

# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
prev_xla_flags()


# A primitive for testing multi-platform lowering. Takes one argument and
# adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
_testing_multi_platform_p = core.Primitive("testing_multi_platform")
@ -72,28 +88,39 @@ for platform in ["cpu", "tpu", "rocm"]:
def _testing_multi_platform_func(x):
return _testing_multi_platform_p.bind(x)

def _testing_multi_platform_fun_expected(x):
return x + _testing_multi_platform_to_add[xb.canonicalize_platform(jtu.device_under_test())]

def _testing_multi_platform_fun_expected(x,
platform: str | None = None):
return x + _testing_multi_platform_to_add[
xb.canonicalize_platform(platform or jtu.device_under_test())
]

class JaxExportTest(jtu.JaxTestCase):

def override_serialization_version(self, version_override: int):
version = config.jax_serialization_version
version = config.jax_serialization_version.value
if version != version_override:
self.addCleanup(functools.partial(config.update,
"jax_serialization_version",
version_override))
config.update("jax_serialization_version", version_override)
self.enter_context(config.jax_serialization_version(version_override))
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version)
config.jax_serialization_version.value)

@classmethod
def setUpClass(cls):
# Find the available platforms
cls.platforms = []
for backend in ["cpu", "gpu", "tpu"]:
try:
jax.devices(backend)
except RuntimeError:
continue
cls.platforms.append(backend)
super(JaxExportTest, cls).setUpClass()

def setUp(self):
super().setUp()
# Run tests with the maximum supported version by default
self.override_serialization_version(
export.maximum_supported_serialization_version)
export.maximum_supported_serialization_version)

def test_basic_export_only(self):
def my_fun(x):
@ -351,7 +378,7 @@ class JaxExportTest(jtu.JaxTestCase):
export.poly_spec((3, 4), np.float32, "w, h"))
# Peek at the module
module_str = exp.mlir_module()
self.assertEqual(config.jax_serialization_version >= 7,
self.assertEqual(config.jax_serialization_version.value >= 7,
"shape_assertion" in module_str)
self.assertIn("jax.uses_shape_polymorphism = true",
module_str)
@ -586,7 +613,7 @@ class JaxExportTest(jtu.JaxTestCase):
)),
])
def test_shape_constraints_errors(self, *,
shape, poly_spec: str, expect_error: Optional[str] = None):
shape, poly_spec: str, expect_error: str | None = None):
def f_jax(x): # x: f32[a + 2*b, a, a + b + c]
return 0.

@ -600,49 +627,170 @@ class JaxExportTest(jtu.JaxTestCase):
export.poly_spec(x.shape, x.dtype, poly_spec))
export.call_exported(exp)(x)

def test_with_sharding(self):
nr_devices = 2
if len(jax.devices()) < nr_devices:
self.skipTest("Need at least 2 devices")
export_devices = jax.devices()[0:nr_devices]
export_mesh = Mesh(export_devices, axis_names=("x",))
a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))
@functools.partial(
jax.jit,
in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),),
out_shardings=jax.sharding.NamedSharding(export_mesh, P(None, "x")))
def f_jax(b): # b: f32[16 // DEVICES, 4]
return b * 2.

res_native = f_jax(a)
exp = export.export(f_jax)(a)

run_devices = export_devices[::-1] # We can use other devices
run_mesh = Mesh(run_devices, "y")
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P()))

expected_re = re.compile(
# The top-level input is replicated
r"func.func .* @main\(%arg0: tensor<16x4xf32> {mhlo.sharding = \"{replicated}\"}\).*"
# We apply the in_shardings for f_jax
r".*custom_call @Sharding\(%arg0\) {mhlo.sharding = \"{devices=\[2,1\]<=\[2\]}\"}.*"
r"%1 = .*call @call_exported_f_jax.*"
# We apply the out_shardings for f_jax
r".*custom_call @Sharding\(%1\) {mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
re.DOTALL)
hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text()
self.assertRegex(hlo, expected_re)

res_exported = export.call_exported(exp)(a_device)
self.assertAllClose(res_native, res_exported)

# Test error reporting
with self.assertRaisesRegex(
NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
_ = export.call_exported(exp)(a)

with self.assertRaisesRegex(
NotImplementedError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
_ = jax.jit(
export.call_exported(exp),
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
)(a)

@jtu.parameterized_filterable(
kwargs=[
dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
in_shardings=in_shardings, out_shardings=out_shardings)
for in_shardings in ("missing", None, "P")
for out_shardings in ("missing", None, "P")
])
def test_grad_with_sharding(self, in_shardings="P", out_shardings=None):
if len(jax.devices()) < 2:
self.skipTest("Test requires at least 2 devices")
x_shape = (10, 20)
x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
def f_jax(x): # x: f32[10,20] -> f32[20,10]
return jnp.sin(x.T)

pjit_kwargs = {}
if in_shardings != "missing":
pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
if out_shardings != "missing":
pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
f_jax = pjit.pjit(f_jax, **pjit_kwargs)

with Mesh(jax.devices()[:2], "x"):
exp = export.export(f_jax)(x)
exp_vjp = exp.vjp()

vjp_module_str = str(exp_vjp.mlir_module())

if in_shardings == "P":
primal_in_sharding = "{devices=[1,2]<=[2]}"
else:
primal_in_sharding = "{replicated}"
if out_shardings == "P":
primal_out_sharding = "{devices=[2,1]<=[2]}"
else:
primal_out_sharding = "{replicated}"

main = re.search(
r"func.func public @main\(%arg0: tensor<10x20xf32> {mhlo.sharding = \"([^\"]+)\""
r".*%arg1: tensor<20x10xf32> {mhlo.sharding = \"([^\"]+)\""
# result
r".* -> \(tensor<10x20xf32>.*mhlo.sharding = \"([^\"]+)\"",
vjp_module_str)
self.assertEqual(
main.groups(),
(primal_in_sharding, primal_out_sharding, primal_in_sharding))

# Custom calls for the primal input shape
primal_in_calls = re.findall(
r"custom_call @Sharding.* {mhlo.sharding = \"(.+)\"} : .*tensor<10x20xf32>",
vjp_module_str)
self.assertTrue(
all(s == primal_in_sharding for s in primal_in_calls),
primal_in_calls
)

# Custom calls for the primal output shape
primal_out_calls = re.findall(
r"custom_call @Sharding.* {mhlo.sharding = \"(.+)\"} : .*tensor<20x10xf32>",
vjp_module_str)
self.assertTrue(
all(s == primal_out_sharding for s in primal_out_calls),
primal_in_calls
)

def test_multi_platform(self):
if jtu.test_device_matches(["gpu"]):
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
x = np.arange(5, dtype=np.float32)
x = np.arange(8, dtype=np.float32)
exp = export.export(_testing_multi_platform_func,
lowering_platforms=('cpu', 'tpu'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
lowering_platforms=("cpu", "tpu", "cuda"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
module_str = str(exp.mlir_module())
expected_main_re = (
r"@main\("
r"%arg0: tensor<i..> {jax.platform_index = true}.*, "
r"%arg1: tensor<5xf32>.* ->")
r"%arg1: tensor<8xf32>.* ->")
self.assertRegex(module_str, expected_main_re)

self.assertIn("jax.uses_shape_polymorphism = true",
module_str)

res = export.call_exported(exp)(x)
self.assertAllClose(res, _testing_multi_platform_fun_expected(x))
# Call with argument placed on different platforms
for platform in self.__class__.platforms:
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp)(x_device)
self.assertAllClose(
res_exp,
_testing_multi_platform_fun_expected(x, platform=platform))

def test_multi_platform_nested(self):
if jtu.test_device_matches(["tpu"]):
# The outer export is not applicable to TPU
raise unittest.SkipTest("Not intended for running on TPU")
x = np.arange(5, dtype=np.float32)
exp = export.export(_testing_multi_platform_func,
lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda'))
lowering_platforms=("cpu", "tpu", "cuda"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))

# Now serialize the call to the exported using a different sequence of
# lowering platforms, but included in the lowering platforms for the
# nested exported.
exp2 = export.export(export.call_exported(exp),
lowering_platforms=('cpu', 'cuda'))(x)
res2 = export.call_exported(exp2)(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
lowering_platforms=("cpu", "cuda"))(x)
# Call with argument placed on different platforms
for platform in self.__class__.platforms:
if platform == "tpu": continue
x_device = jax.device_put(x, jax.devices(platform)[0])
res_exp = export.call_exported(exp2)(x_device)
self.assertAllClose(
res_exp,
_testing_multi_platform_fun_expected(x, platform=platform))

def test_multi_platform_nested_inside_single_platform_export(self):
x = np.arange(5, dtype=np.float32)
exp = export.export(_testing_multi_platform_func,
lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda'))
lowering_platforms=("cpu", "tpu", "cuda"))(x)
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))

# Now serialize the call for the current platform.
exp2 = export.export(export.call_exported(exp))(x)
@ -657,7 +805,7 @@ class JaxExportTest(jtu.JaxTestCase):
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
exp = export.export(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)),
lowering_platforms=('cpu', 'tpu'))(
lowering_platforms=("cpu", "tpu"))(
export.poly_spec((5, 6), np.float32, "b1, b2")
)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
@ -668,6 +816,31 @@ class JaxExportTest(jtu.JaxTestCase):
res2 = export.call_exported(exp2)(x)
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))

def test_multi_platform_and_sharding(self):
export_devices = jax.devices()[0:2]
export_mesh = Mesh(export_devices, axis_names=("x",))
a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))
@functools.partial(
jax.jit,
in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),),
out_shardings=jax.sharding.NamedSharding(export_mesh, P(None, "x")))
def f_jax(b): # b: f32[16 // DEVICES, 4]
return b * 2.

res_native = f_jax(a)
exp = export.export(f_jax,
lowering_platforms=("cpu", "tpu", "cuda"))(a)

# Call with argument placed on different platforms
for platform in self.__class__.platforms:
run_devices = jax.devices(platform)[0:len(export_devices)]
if len(run_devices) != len(export_devices):
continue
run_mesh = Mesh(run_devices, ("x",))
a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None))
res_exp = export.call_exported(exp)(a_device)
self.assertArraysAllClose(res_native, res_exp)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
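
A condensed view of the multi-platform export round trip exercised by the tests above, using only names from this test file (the argument is illustrative):

import numpy as np
from jax.experimental.export import export

x = np.arange(8, dtype=np.float32)
# Lower once for several platforms...
exp = export.export(_testing_multi_platform_func,
                    lowering_platforms=("cpu", "tpu", "cuda"))(x)
# ...then call the same artifact on whichever platform it was lowered for.
res = export.call_exported(exp)(x)
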
@ -23,11 +23,11 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import numpy as jnp
from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.numpy.util import promote_dtypes_complex

from jax import config
config.parse_flags_with_absl()

FFT_NORMS = [None, "ortho", "forward", "backward"]
@ -129,7 +129,7 @@ class FftTest(jtu.JaxTestCase):

@parameterized.parameters((np.float32,), (np.float64,))
def testLaxIrfftDoesNotMutateInputs(self, dtype):
if dtype == np.float64 and not config.x64_enabled:
if dtype == np.float64 and not config.enable_x64.value:
raise self.skipTest("float64 requires jax_enable_x64=true")
x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]],
dtype=dtypes.to_complex_dtype(dtype))
@ -162,7 +162,7 @@ class FftTest(jtu.JaxTestCase):
tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker)
# Test gradient for differentiable types.
if (config.x64_enabled and
if (config.enable_x64.value and
dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
# TODO(skye): can we be more precise?
tol = 0.15
@ -28,17 +28,17 @@ from absl.testing import absltest

import jax
from jax import ad_checkpoint
from jax._src import core
from jax import config
from jax import dtypes
from jax.experimental import host_callback as hcb
from jax.sharding import PartitionSpec as P
from jax.experimental import pjit
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax.experimental import host_callback as hcb
from jax.experimental import pjit
from jax.sharding import PartitionSpec as P
from jax._src import core
from jax._src import xla_bridge
from jax._src import test_util as jtu
from jax._src.lib import xla_client

xops = xla_client.ops
@ -46,7 +46,6 @@ xops = xla_client.ops
import numpy as np

config.parse_flags_with_absl()
FLAGS = config.FLAGS


class _TestingOutputStream:
@ -333,7 +332,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
( 6.00 9.00 )""")

def test_tap_eval_exception(self):
if not FLAGS.jax_host_callback_outfeed:
if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall")
# Simulate a tap error
def tap_err(*args, **kwargs):
@ -818,7 +817,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
self.assertEqual(100, count)

def test_tap_jit_tap_exception(self):
if not FLAGS.jax_host_callback_outfeed:
if not hcb._HOST_CALLBACK_OUTFEED.value:
raise SkipTest("TODO: implement error handling for customcall")
# Simulate a tap error
def tap_err(*args, **kwargs):
@ -1541,7 +1540,7 @@ class HostCallbackTapTest(jtu.JaxTestCase):
@jtu.sample_product(device_index=[0, 1])
def test_tap_pjit(self, device_index=0):
if (device_index != 0 and
not FLAGS.jax_host_callback_outfeed and
not hcb._HOST_CALLBACK_OUTFEED.value and
jtu.test_device_matches(["cpu"])):
# See comment in host_callback.py.
raise SkipTest("device_index works only with outfeed on CPU")
@ -39,7 +39,6 @@ except ImportError:
tf = None

config.parse_flags_with_absl()
FLAGS = config.FLAGS


def call_tf_no_ad(tf_fun: Callable, arg, *, result_shape):
@ -17,14 +17,14 @@ import inspect
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import core
from jax._src import api_util
from jax._src.interpreters import pxla
from jax import dtypes
from jax._src import lib as jaxlib
from jax import numpy as jnp
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import lib as jaxlib
from jax._src import test_util as jtu
from jax import config
from jax._src.interpreters import pxla
import numpy as np

config.parse_flags_with_absl()
@ -136,7 +136,7 @@ class JaxJitTest(jtu.JaxTestCase):
self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)

# Complex
if not (config.x64_enabled and jtu.test_device_matches(["tpu"])):
if not (config.enable_x64.value and jtu.test_device_matches(["tpu"])):
# No TPU support for complex128.
res = np.asarray(_cpp_device_put(1 + 1j, device))
self.assertEqual(res, 1 + 1j)
@ -145,7 +145,7 @@ class JaxJitTest(jtu.JaxTestCase):

def test_arg_signature_of_value(self):
"""Tests the C++ code-path."""
jax_enable_x64 = config.x64_enabled
jax_enable_x64 = config.enable_x64.value

# 1. Numpy scalar types
for dtype in jtu.supported_dtypes():
@ -19,20 +19,20 @@ import warnings
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax import config
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import dispatch
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -297,8 +297,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.old_x64 = config.jax_enable_x64
|
||||
config.update('jax_enable_x64', False)
|
||||
self.enter_context(config.enable_x64(False))
|
||||
self._old_lowering = mlir._lowerings[effect_p]
|
||||
def _effect_lowering(ctx, *, effect):
|
||||
if effects.ordered_effects.contains(effect):
|
||||
@ -315,7 +314,6 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dispatch.runtime_tokens.clear()
|
||||
config.update('jax_enable_x64', self.old_x64)
|
||||
mlir.register_lowering(effect_p, self._old_lowering)
|
||||
|
||||
def test_can_lower_lowerable_effect(self):
|
||||
|
@ -20,10 +20,10 @@ from absl.testing import absltest

import jax
from jax import jit, make_jaxpr, numpy as jnp
from jax._src import config
from jax._src import jaxpr_util
from jax._src.lib import xla_client
from jax._src import test_util as jtu
from jax import config
from jax._src.lib import xla_client


config.parse_flags_with_absl()

@ -67,15 +67,15 @@ class JaxprStatsTest(jtu.JaxTestCase):

hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)

t = '64' if config.x64_enabled else '32'
t = '64' if config.enable_x64.value else '32'
shapes = [
f'add :: float{t}[]',
f'sin :: float{t}[]',
f'cos :: float{t}[]',
f'reduce_sum :: float{t}[]',
f'concatenate :: float{t}[2]',
f'pjit :: float{t}[] *',
]
shapes.append(f'pjit :: float{t}[] *')
for k in shapes:
self.assertEqual(hist[k], 1)

@ -33,7 +33,6 @@ from jax.test_util import check_grads

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


compatible_shapes = [[(3,)],

@ -30,13 +30,13 @@ from jax import lax
from jax import numpy as jnp
from jax import ops

from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import util
from jax._src.util import NumpyComplexWarning
from jax._src.lax import lax as lax_internal
from jax._src.util import NumpyComplexWarning

from jax import config
config.parse_flags_with_absl()

# We disable the whitespace continuation check in this file because otherwise it

@ -60,7 +60,7 @@ class IndexSpec(typing.NamedTuple):

def check_grads(f, args, order, atol=None, rtol=None, eps=None):
# TODO(mattjj,dougalm): add higher-order check
default_tol = 1e-6 if config.x64_enabled else 1e-2
default_tol = 1e-6 if config.enable_x64.value else 1e-2
atol = atol or default_tol
rtol = rtol or default_tol
eps = eps or default_tol

@ -31,12 +31,11 @@ import jax.ops
from jax import lax
from jax import numpy as jnp

from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes

@ -612,7 +611,7 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
np.issubdtype(shift_dtype, np.signedinteger)
has_32 = any(np.iinfo(d).bits == 32 for d in dtypes)
promoting_to_64 = has_32 and signed_mix
if promoting_to_64 and not config.x64_enabled:
if promoting_to_64 and not config.enable_x64.value:
self.skipTest("np.right_shift/left_shift promoting to int64"
"differs from jnp in 32 bit mode.")

@ -31,8 +31,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.util import NumpyComplexWarning

jax.config.parse_flags_with_absl()
FLAGS = config.FLAGS
config.parse_flags_with_absl()

numpy_version = jtu.numpy_version()

@ -43,17 +43,16 @@ from jax import numpy as jnp
from jax import tree_util
from jax.test_util import check_grads

from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import safe_zip, NumpyComplexWarning
from jax._src import array

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

numpy_version = jtu.numpy_version()

@ -2333,7 +2332,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testFrexp(self, shape, dtype, rng_factory):
# integer types are converted to float64 in numpy's implementation
if (dtype not in [jnp.bfloat16, np.float16, np.float32]
and not config.x64_enabled):
and not config.enable_x64.value):
self.skipTest("Only run float64 testcase when float64 is enabled.")
rng = rng_factory(self.rng())
args_maker = lambda: [rng(shape, dtype)]

@ -2405,7 +2404,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v)
self.assertEqual(out_int32.dtype, np.int32)

if config.x64_enabled:
if config.enable_x64.value:
out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v)
self.assertEqual(out_int64.dtype, np.int64)
else:

@ -3160,7 +3159,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
jnp.array(val)

# explicit uint64 should work
if config.x64_enabled:
if config.enable_x64.value:
self.assertEqual(np.uint64(val), jnp.array(val, dtype='uint64'))

def testArrayFromList(self):

@ -3534,7 +3533,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
# Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs.
shape=[(0,), (32,), (2, 16)],
a_dtype=all_dtypes,
dtype=(*all_dtypes, None) if config.x64_enabled else all_dtypes,
dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes,
)
def testView(self, shape, a_dtype, dtype):
if jtu.test_device_matches(["tpu"]):

@ -4645,7 +4644,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
endpoint, base, dtype):
if (dtype in int_dtypes and
jtu.test_device_matches(["gpu", "tpu"]) and
not config.x64_enabled):
not config.enable_x64.value):
raise unittest.SkipTest("GPUx32 truncated exponentiation"
" doesn't exactly match other platforms.")
rng = jtu.rand_default(self.rng())

@ -4724,29 +4723,17 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
check_dtypes=False, atol=tol, rtol=tol)

def testDisableNumpyRankPromotionBroadcasting(self):
try:
prev_flag = config._read('jax_numpy_rank_promotion')
FLAGS.jax_numpy_rank_promotion = "allow"
with jax.numpy_rank_promotion('allow'):
jnp.ones(2) + jnp.ones((1, 2))  # works just fine
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag

try:
prev_flag = config._read('jax_numpy_rank_promotion')
FLAGS.jax_numpy_rank_promotion = "raise"
with jax.numpy_rank_promotion('raise'):
self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2)))
jnp.ones(2) + 3  # don't want to raise for scalars
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag

try:
prev_flag = config._read('jax_numpy_rank_promotion')
FLAGS.jax_numpy_rank_promotion = "warn"
with jax.numpy_rank_promotion('warn'):
self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on "
r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2)))
jnp.ones(2) + 3  # don't want to warn for scalars
finally:
FLAGS.jax_numpy_rank_promotion = prev_flag

@unittest.skip("Test fails on CI, perhaps due to JIT caching")
def testDisableNumpyRankPromotionBroadcastingDecorator(self):

@ -5077,7 +5064,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol)

def testDefaultDtypes(self):
precision = config.jax_default_dtype_bits
precision = config.default_dtype_bits.value
assert precision in ['32', '64']
self.assertEqual(jnp.bool_, np.bool_)
self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64)
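Note on the rank-promotion hunk above: the try/finally blocks that mutated `FLAGS.jax_numpy_rank_promotion` are replaced with the public `jax.numpy_rank_promotion` context manager, which scopes the setting to the `with` block and restores it on exit even if the body raises. A small usage sketch (inputs chosen only for illustration):

import jax
import jax.numpy as jnp

with jax.numpy_rank_promotion('allow'):
  jnp.ones(2) + jnp.ones((1, 2))  # implicit rank promotion is permitted

with jax.numpy_rank_promotion('raise'):
  jnp.ones(2) + 3  # scalars are still fine
  # jnp.ones(2) + jnp.ones((1, 2)) would raise a ValueError here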
@ -26,7 +26,6 @@ from jax._src.numpy.ufunc_api import get_if_single_primitive

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


def scalar_add(x, y):

@ -31,7 +31,7 @@ from jax._src import config
from jax._src import dtypes
from jax._src import test_util as jtu

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()


float_types = jtu.dtypes.floating

@ -28,7 +28,6 @@ from jax.scipy import special as lsp_special

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]

@ -23,7 +23,6 @@ from absl.testing import absltest

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS


linear_sizes = [16, 97, 128]

@ -37,7 +37,6 @@ from jax.scipy import cluster as lsp_cluster

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
compatible_shapes = [[(), ()],

@ -34,21 +34,21 @@ from jax.test_util import check_grads
from jax import tree_util
import jax.util

from jax.interpreters import xla
from jax._src.interpreters import mlir
from jax.interpreters import batching
from jax.interpreters import xla
from jax._src import array
from jax._src import config
from jax._src import dtypes
from jax._src.interpreters import pxla
from jax._src import test_util as jtu
from jax._src import lax_reference
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.internal_test_util import lax_test_util
from jax._src.util import NumpyComplexWarning

from jax import config
config.parse_flags_with_absl()


@ -96,7 +96,7 @@ class LaxTest(jtu.JaxTestCase):
dtype=rec.dtypes)
for rec in lax_test_util.lax_ops()))
def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol):
if (not config.x64_enabled and op_name == "nextafter"
if (not config.enable_x64.value and op_name == "nextafter"
and dtype == np.float64):
raise SkipTest("64-bit mode disabled")
rng = rng_factory(self.rng())

@ -293,7 +293,7 @@ class LaxTest(jtu.JaxTestCase):
)
@jax.default_matmul_precision("float32")
def testConvPreferredElement(self, lhs_shape, rhs_shape, dtype, preferred_element_type):
if (not config.x64_enabled and
if (not config.enable_x64.value and
(dtype == np.float64 or preferred_element_type == np.float64
or dtype == np.int64 or preferred_element_type == np.int64
or dtype == np.complex128 or preferred_element_type == np.complex128)):

@ -1033,7 +1033,7 @@ class LaxTest(jtu.JaxTestCase):
)
def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype,
preferred_element_type):
if (not config.x64_enabled and
if (not config.enable_x64.value and
(dtype == np.float64 or preferred_element_type == np.float64
or dtype == np.int64 or preferred_element_type == np.int64)):
raise SkipTest("64-bit mode disabled")

@ -1682,7 +1682,7 @@ class LaxTest(jtu.JaxTestCase):
],
)
def testReduce(self, op, reference_op, init_val, shape, dtype, dims, primitive):
if not config.x64_enabled and dtype in (np.float64, np.int64, np.uint64):
if not config.enable_x64.value and dtype in (np.float64, np.int64, np.uint64):
raise SkipTest("x64 mode is disabled.")
def reference_fun(operand):
if hasattr(reference_op, "reduce"):

@ -2622,7 +2622,7 @@ class LaxTest(jtu.JaxTestCase):
def testRngBitGenerator(self):
# This test covers the original behavior of lax.rng_bit_generator, which
# required x64=True, and only checks shapes and jit invariance.
if not config.x64_enabled:
if not config.enable_x64.value:
raise SkipTest("RngBitGenerator requires 64bit key")

key = np.array((1, 2)).astype(np.uint64)

@ -29,8 +29,6 @@ from jax._src import util
from jax import config
config.parse_flags_with_absl()

FLAGS = config.FLAGS

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

@ -36,7 +36,6 @@ from jax._src.util import safe_map, safe_zip

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

@ -29,13 +29,12 @@ from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.numpy.util import promote_dtypes_inexact

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

scipy_version = tuple(map(int, scipy.version.version.split('.')[:3]))

@ -1619,7 +1618,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype):
if ((rdtype in [np.float64, np.complex128]
or cdtype in [np.float64, np.complex128])
and not config.x64_enabled):
and not config.enable_x64.value):
self.skipTest("Only run float64 testcase when float64 is enabled.")

int_types_excl_i8 = set(int_types) - {np.int8}

@ -1640,7 +1639,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
@jtu.skip_on_devices("rocm")
def testToeplitzSymmetricConstruction(self, shape, dtype):
if (dtype in [np.float64, np.complex128]
and not config.x64_enabled):
and not config.enable_x64.value):
self.skipTest("Only run float64 testcase when float64 is enabled.")

int_types_excl_i8 = set(int_types) - {np.int8}

@ -28,7 +28,7 @@ import jax._src.test_util as jtu
# parsing to work correctly with bazel (otherwise we could avoid importing
# absltest/absl logging altogether).
from absl.testing import absltest
jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()


@contextlib.contextmanager

@ -27,7 +27,7 @@ from jax import numpy as jnp

from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

npr.seed(0)


@ -29,10 +29,11 @@ import jax
from jax import config
from jax._src import core
from jax._src import distributed
import jax.numpy as jnp
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import util
from jax.experimental import pjit
import jax.numpy as jnp

try:
import portpicker

@ -258,7 +259,7 @@ class SlurmMultiNodeGpuTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self.xmap_spmd_lowering_enabled = jax.config.experimental_xmap_spmd_lowering
self.xmap_spmd_lowering_enabled = maps._SPMD_LOWERING.value
jax.config.update("experimental_xmap_spmd_lowering", True)

def tearDown(self):

@ -34,7 +34,7 @@ from jax import random
import jax
import jax.numpy as jnp

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()


class NNFunctionsTest(jtu.JaxTestCase):

@ -25,12 +25,12 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax import random
from jax._src import config
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import state
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.pallas.pallas_call import _trace_to_jaxpr
from jax.config import config
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
from jax.experimental import pallas as pl

@ -139,6 +139,7 @@ class PallasTest(parameterized.TestCase):
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)


class PallasCallTest(PallasTest):

def test_add_one(self):

@ -213,8 +214,9 @@ class PallasCallTest(PallasTest):
out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32),
grid=1)
def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref):
mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])[:, None, None]
o_ref[...] = jnp.where(mask, x_ref[in_idx_ref[()]], 0)
mask = (jnp.arange(o_ref.shape[0]) == out_idx_ref[()])
o_ref[...] = jnp.where(jax.lax.broadcast_in_dim(mask, (4, 2, 2), (0,)),
x_ref[in_idx_ref[()]], 0)

x = jnp.arange(7 * 2 * 2.).reshape(7, 2, 2)
for ii in range(7):

@ -225,7 +227,6 @@ class PallasCallTest(PallasTest):
np.testing.assert_allclose(out[oi], x[ii])
np.testing.assert_allclose(out[oi + 1:], jnp.zeros_like(out[oi + 1:]))


@parameterized.parameters(*[
((), (2,), ()),
((1,), (2,), (0,)),

@ -712,9 +713,32 @@ class PallasCallTest(PallasTest):
np.testing.assert_allclose(lock, 0)
np.testing.assert_allclose(count, num_threads)

def test_custom_jvp_call(self):
@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
def softmax(x, axis=-1):
unnormalized = jnp.exp(x - jnp.max(x, axis, keepdims=True))
return unnormalized / jnp.sum(unnormalized, axis, keepdims=True)

@softmax.defjvp
def softmax_jvp(axis, primals, tangents):
(x,), (x_dot,) = primals, tangents
y = softmax(x, axis)
return y, y * (x_dot - (y * x_dot).sum(axis, keepdims=True))

m, n = 16, 32
x = random.normal(random.PRNGKey(0), (m, n))

@functools.partial(self.pallas_call, out_shape=x, grid=1)
def softmax_kernel(x_ref, y_ref):
y_ref[:] = softmax(x_ref[:])

np.testing.assert_allclose(softmax_kernel(x), jax.nn.softmax(x), atol=1e-7)


class PallasCallInterpreterTest(PallasCallTest):
INTERPRET = True


class PallasControlFlowTest(PallasTest):

def setUp(self):

@ -727,9 +751,7 @@ class PallasControlFlowTest(PallasTest):
# fori_loop handles i64 index variables, i.e. error: 'scf.for' op along
# control flow edge from Region #0 to Region #0: source type #0
# 'tensor<4xf64>' should match input type #0 'tensor<4xf32>'
orig_val = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", True)
try:
with config.enable_x64(True):
@functools.partial(self.pallas_call,
out_shape=jax.ShapeDtypeStruct((4,), jnp.float64),
grid=1,

@ -744,8 +766,6 @@ class PallasControlFlowTest(PallasTest):

np.testing.assert_allclose(np.arange(1, 5.) * 3,
f(jnp.arange(1, 5., dtype=jnp.float64)))
finally:
jax.config.update("jax_enable_x64", orig_val)

def test_cond_simple(self):
arg = jnp.float32(0.)
@ -31,6 +31,7 @@ import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import config
from jax._src import maps
from jax._src import test_util as jtu
from jax import dtypes
from jax import stages

@ -60,9 +61,10 @@ from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2, safe_zip

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()

prev_xla_flags = None
prev_spmd_lowering_flag = None


def setUpModule():

@ -75,7 +77,9 @@ def setUpModule():
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
jtu.set_spmd_lowering_flag(True)
global prev_spmd_lowering_flag
prev_spmd_lowering_flag = maps._SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)

def tearDownModule():
if prev_xla_flags is None:

@ -83,8 +87,7 @@ def tearDownModule():
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()

jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', prev_spmd_lowering_flag)


def create_array(global_shape, global_mesh, mesh_axes, global_data=None,

@ -32,30 +32,28 @@ from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax import lax
from jax._src.lax import parallel
from jax._src import api as src_api
from jax import random
from jax._src import core
from jax._src.lib import xla_extension_version
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import config as jax_config
from jax import lax
from jax import random
from jax import tree_util
from jax.ad_checkpoint import checkpoint as new_checkpoint
import jax.numpy as jnp
from jax._src import api as src_api
from jax._src import array
from jax._src import core
from jax._src import config
from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension
from jax._src.util import safe_map, safe_zip
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src import array
from jax._src.sharding_impls import PmapSharding
from jax.ad_checkpoint import checkpoint as new_checkpoint
from jax._src.lax import parallel
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import safe_map, safe_zip

from jax import config
config.parse_flags_with_absl()

prev_xla_flags = None

@ -125,7 +123,7 @@ def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
if devices is None:
devices = jax.devices()

pmap_sharding = PmapSharding(np.array(devices), sharding_spec)
pmap_sharding = jax.sharding.PmapSharding(np.array(devices), sharding_spec)

return array.make_array_from_callback(
input_shape, pmap_sharding, lambda idx: input_data[idx]), input_data

@ -188,7 +186,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# the default order of pmap for single-host jobs.
device_order = jax.devices()
pmap_sharding = pmap(lambda x: x)(np.arange(jax.device_count())).sharding
if jax.config.jax_pmap_shmap_merge:
if config.pmap_shmap_merge.value:
self.assertListEqual(device_order, pmap_sharding._device_assignment)
else:
self.assertListEqual(device_order, pmap_sharding.devices.tolist())

@ -1283,7 +1281,7 @@ class PythonPmapTest(jtu.JaxTestCase):
expected = np.repeat(3, device_count)
self.assertAllClose(ans, expected, check_dtypes=False)

if not config.jax_disable_jit:
if not config.disable_jit.value:
f = self.pmap(lambda x: (x, 3))
x = np.arange(device_count)
with jtu.assert_num_jit_and_pmap_compilations(1):

@ -1307,7 +1305,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# Test that 'ans' was properly replicated across devices.
ans_devices = ans.sharding._device_assignment
# TODO(mattjj,sharadmv): fix physical layout with eager pmap, remove 'if'
if not config.jax_disable_jit:
if not config.disable_jit.value:
self.assertEqual(ans_devices, tuple(devices))

def testPmapConstantError(self):

@ -1371,7 +1369,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertTrue(ans.sharding == expected_sharded.sharding)

def testNestedPmapConstantError(self):
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("error test doesn't apply with disable_jit")
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2 + 1, 3)

@ -1837,7 +1835,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitOfPmapWarningMessage(self):
device_count = jax.device_count()

if device_count == 1 or config.jax_disable_jit:
if device_count == 1 or config.disable_jit.value:
raise SkipTest("test requires at least two devices")

def foo(x): return x

@ -1853,7 +1851,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitOfPmapOutputSharding(self):
device_count = jax.device_count()

if device_count == 1 or config.jax_disable_jit:
if device_count == 1 or config.disable_jit.value:
raise SkipTest("test requires at least two devices")

@jax.jit

@ -1872,7 +1870,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitOfPmapLowerHasReplicaAttributes(self):
device_count = jax.device_count()

if device_count == 1 or config.jax_disable_jit:
if device_count == 1 or config.disable_jit.value:
raise SkipTest("test requires at least two devices")

@jax.jit

@ -2037,7 +2035,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def test_grad_of_pmap_compilation_caching(self, axis_size):
if len(jax.local_devices()) < axis_size:
raise SkipTest("too few devices for test")
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("caching doesn't apply with jit disabled")

@jax.pmap

@ -2059,7 +2057,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(count[0], 0)  # cache hits on fwd and bwd

def testSizeOverflow(self):
if config.jax_disable_jit:
if config.disable_jit.value:
# TODO(sharadmv, mattjj): investigate and fix this issue
raise SkipTest("OOMs in eager mode")
x = jnp.arange(1)

@ -2171,7 +2169,7 @@ class CppPmapTest(PythonPmapTest):

@property
def pmap(self):
if jax.config.jax_pmap_shmap_merge:
if config.pmap_shmap_merge.value:
return src_api.pmap
return src_api._cpp_pmap

@ -2206,7 +2204,7 @@ class CppPmapTest(PythonPmapTest):
pmaped_f(inputs)
self.assertEqual(pmaped_f._cache_size, 1)

jax_config.update_thread_local_jit_state()
config.update_thread_local_jit_state()

pmaped_f(inputs)
self.assertEqual(pmaped_f._cache_size, 1)

@ -2511,7 +2509,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
f(jnp.ones(jax.device_count() + 1))

def testBadAxisSizeErrorNested(self):
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("error doesn't apply when jit is disabled")
f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')),
axis_name='j'),

@ -2526,7 +2524,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNestedPmaps(self):
if jax.device_count() % 2 != 0:
raise SkipTest
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("disable_jit requires num devices to equal axis size")

# Devices specified in outer pmap are OK

@ -2545,7 +2543,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNestedPmapsBools(self):
if jax.device_count() % 2 != 0:
raise SkipTest
if config.jax_disable_jit:
if config.disable_jit.value:
raise SkipTest("disable_jit requires num devices to equal axis size")

# Devices specified in outer pmap are OK

@ -3150,7 +3148,7 @@ class ArrayPmapTest(jtu.JaxTestCase):
self.assertArraysEqual(w, jnp.cos(jnp.sin(x) ** 2))

def test_same_out_sharding_id(self):
if config.jax_disable_jit:
if config.disable_jit.value:
self.skipTest('Skip this under eager pmap mode.')
shape = (jax.device_count(), 2)
arr, inp_data = create_input_array_for_pmap(shape)

@ -3215,8 +3213,8 @@ class EagerPmapMixin:

def setUp(self):
super().setUp()
self.eager_pmap_enabled = config.jax_eager_pmap
self.jit_disabled = config.jax_disable_jit
self.eager_pmap_enabled = config.eager_pmap.value
self.jit_disabled = config.disable_jit.value
config.update('jax_disable_jit', True)
config.update('jax_eager_pmap', True)
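Note on the EagerPmapMixin hunk above: where a context manager does not fit because the override must span the whole test, the new pattern reads the current values from the typed holders (`config.eager_pmap.value`, `config.disable_jit.value`) in setUp and restores them with `config.update` in tearDown. A hedged sketch of that save/override/restore shape (class names are illustrative, not part of the commit):

import unittest
from jax._src import config

class DisableJitMixin:
  def setUp(self):
    super().setUp()
    self._old_disable_jit = config.disable_jit.value  # save current value
    config.update('jax_disable_jit', True)            # override for the test

  def tearDown(self):
    config.update('jax_disable_jit', self._old_disable_jit)  # restore
    super().tearDown()

class MyEagerTest(DisableJitMixin, unittest.TestCase):
  def test_runs_eagerly(self):
    self.assertTrue(config.disable_jit.value)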
@ -26,13 +26,13 @@ from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax.experimental import io_callback
from jax.experimental import maps
from jax.experimental import pjit
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map

@ -40,7 +40,7 @@ import jax.numpy as jnp
from jax.sharding import Mesh
import numpy as np

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()


def _format_multiline(text):

@ -632,8 +632,10 @@ class PureCallbackTest(jtu.JaxTestCase):
if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
raise unittest.SkipTest('Manual partitioning needed for pure_callback')

jtu.set_spmd_lowering_flag(True)
jtu.set_spmd_manual_lowering_flag(True)
spmd_lowering = maps._SPMD_LOWERING.value
spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
try:
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

@ -659,8 +661,11 @@ class PureCallbackTest(jtu.JaxTestCase):
out, np.sin(np.arange(jax.local_device_count()))
)
finally:
jtu.restore_spmd_manual_lowering_flag()
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', spmd_lowering)
config.update(
'experimental_xmap_spmd_lowering_manual',
spmd_manual_lowering,
)

def test_cant_take_grad_of_pure_callback(self):


@ -17,13 +17,13 @@ import unittest
from absl.testing import absltest

import jax
from jax import config
import jax.dlpack
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
import jax.numpy as jnp
from jax._src import test_util as jtu

config.parse_flags_with_absl()

@ -69,8 +69,11 @@ class DLPackTest(jtu.JaxTestCase):

@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
def testJaxToTorch(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
jnp.complex128]:
if not config.enable_x64.value and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,
]:
self.skipTest("x64 types are disabled by jax_enable_x64")
rng = jtu.rand_default(self.rng())
np = rng(shape, dtype)

@ -89,7 +92,7 @@ class DLPackTest(jtu.JaxTestCase):
if xla_extension_version < 186:
self.skipTest("Need xla_extension_version >= 186")

if not config.x64_enabled and dtype in [
if not config.enable_x64.value and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,

@ -113,13 +116,16 @@ class DLPackTest(jtu.JaxTestCase):
# See https://github.com/google/jax/issues/11895
x = jax.dlpack.from_dlpack(
torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64)))
dtype_expected = jnp.int64 if config.x64_enabled else jnp.int32
dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
self.assertEqual(x.dtype, dtype_expected)

@jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
def testTorchToJax(self, shape, dtype):
if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
jnp.complex128]:
if not config.enable_x64.value and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,
]:
self.skipTest("x64 types are disabled by jax_enable_x64")

rng = jtu.rand_default(self.rng())

@ -142,8 +148,11 @@ class DLPackTest(jtu.JaxTestCase):
if xla_extension_version < 191:
self.skipTest("Need xla_extension_version >= 191")

if not config.x64_enabled and dtype in [jnp.int64, jnp.float64,
jnp.complex128]:
if not config.enable_x64.value and dtype in [
jnp.int64,
jnp.float64,
jnp.complex128,
]:
self.skipTest("x64 types are disabled by jax_enable_x64")

rng = jtu.rand_default(self.rng())

@ -16,18 +16,18 @@
import functools

import jax
from jax import config
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg
from jax._src.lax import qdwh
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax import qdwh

from absl.testing import absltest


config.parse_flags_with_absl()
_JAX_ENABLE_X64_QDWH = config.x64_enabled
_JAX_ENABLE_X64_QDWH = config.enable_x64.value

# Input matrix data type for QdwhTest.
_QDWH_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64_QDWH else np.float32

@ -37,7 +37,7 @@ from jax import vmap

from jax._src import prng as prng_internal

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()

float_dtypes = jtu.dtypes.all_floating
complex_dtypes = jtu.dtypes.complex

@ -40,7 +40,7 @@ from jax.interpreters import xla
from jax._src import random as jax_random
from jax._src import prng as prng_internal

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()


PRNG_IMPLS = list(prng_internal.prngs.items())

@ -43,7 +43,7 @@ import jax.numpy as jnp

from jax.experimental.shard_map import shard_map

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

@ -1382,7 +1382,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
return jtu.create_global_mesh(tuple(mesh_shape.values()), tuple(mesh_shape))

@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
def test_eager_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)

@ -1391,7 +1391,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
self.assertAllClose(expected, out, check_dtypes=False)

@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
def test_jit_against_ref(self, fun, mesh, _, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)

@ -1401,7 +1401,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):

@parameterized.named_parameters(
(name + f'_check_rep={check_rep}', *params, check_rep)
for (name, *params) in sample(config.FLAGS.jax_num_generated_cases, sample_shmap)
for (name, *params) in sample(jtu._NUM_GENERATED_CASES.value, sample_shmap)
for check_rep in [True, False]
)
@jax.default_matmul_precision("float32")

@ -1414,7 +1414,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)

@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
sample(jtu._NUM_GENERATED_CASES.value, sample_shmap))
@jax.default_matmul_precision("float32")
def test_grads_closure(self, fun, mesh, jit, in_specs, out_specs, args, _):
mesh = self.make_mesh(mesh)

@ -1433,7 +1433,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
jtu.check_grads(f, (0.2, *closed_over_args), order=2, atol=1e-2, rtol=1e-2)

@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases,
sample(jtu._NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap(self, bdims, fun, mesh, jit, in_specs, out_specs, args, ref):
mesh = self.make_mesh(mesh)

@ -1456,7 +1456,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)

@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases,
sample(jtu._NUM_GENERATED_CASES.value,
partial(sample_shmap_batched, 5)))
def test_vmap_closure(self, bdims, fun, mesh, jit, in_specs, out_specs, args, _):
mesh = self.make_mesh(mesh)

@ -41,7 +41,6 @@ from jax.util import split_list
import numpy as np

config.parse_flags_with_absl()
FLAGS = config.FLAGS

COMPATIBLE_SHAPE_PAIRS = [
[(), ()],

@ -46,7 +46,6 @@ import numpy as np
import scipy.sparse

config.parse_flags_with_absl()
FLAGS = config.FLAGS

all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex


@ -48,7 +48,7 @@ from jax._src.state.primitives import (get_p, swap_p, addupdate_p,
from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
AccumEffect, AbstractRef)

jax.config.parse_flags_with_absl()
config.parse_flags_with_absl()

class StatePrimitivesTest(jtu.JaxTestCase):

@ -831,7 +831,7 @@ if CAN_USE_HYPOTHESIS:

@hp.given(get_vmap_params())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_get_vmap(self, get_vmap_param: GetVmapParams):

indexed_dims = get_vmap_param.vmap_index_param.index_param.indexed_dims

@ -870,7 +870,7 @@ if CAN_USE_HYPOTHESIS:

@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_set_vmap(self, set_vmap_param: SetVmapParams):
if jtu.test_device_matches(["gpu"]):
self.skipTest("Scatter is nondeterministic on GPU")

@ -915,7 +915,7 @@ if CAN_USE_HYPOTHESIS:

@hp.given(set_vmap_params())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_addupdate_vmap(self, set_vmap_param: SetVmapParams):

indexed_dims = set_vmap_param.vmap_index_param.index_param.indexed_dims

@ -1538,7 +1538,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_jvp(self, data):

spec = data.draw(func_spec())

@ -1563,7 +1563,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_linearize(self, data):

spec = data.draw(func_spec())

@ -1589,7 +1589,7 @@ if CAN_USE_HYPOTHESIS:
@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
max_examples=jtu._NUM_GENERATED_CASES.value)
def test_vjp(self, data):

spec = data.draw(func_spec())

@ -16,18 +16,18 @@
import functools

import jax
from jax import config
import jax.numpy as jnp
import numpy as np
import scipy.linalg as osp_linalg
from jax._src.lax import svd
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lax import svd

from absl.testing import absltest


config.parse_flags_with_absl()
_JAX_ENABLE_X64 = config.x64_enabled
_JAX_ENABLE_X64 = config.enable_x64.value

# Input matrix data type for SvdTest.
_SVD_TEST_DTYPE = np.float64 if _JAX_ENABLE_X64 else np.float32

@ -23,7 +23,6 @@ from jax._src import util
from jax import config
from jax._src.util import weakref_lru_cache
config.parse_flags_with_absl()
FLAGS = config.FLAGS

try:
from jax._src.lib import utils as jaxlib_utils

@ -20,17 +20,16 @@ import warnings

from absl import logging
from absl.testing import absltest

from jax._src import compiler
from jax._src import config as jax_config
from jax._src import config
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.config import config
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_client as xc

config.parse_flags_with_absl()
FLAGS = config.FLAGS

mock = absltest.mock

@ -65,7 +64,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
# --jax_xla_profile_version takes precedence.
jax_flag_profile = 1
another_profile = 2
with jax_config.jax_xla_profile_version(jax_flag_profile):
with config.jax_xla_profile_version(jax_flag_profile):
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda: another_profile):
self.assertEqual(

@ -283,10 +282,10 @@ class GetBackendTest(jtu.JaxTestCase):
fail_quietly=False, experimental=experimental)

def setUp(self):
super().setUp()
self._orig_factories = xb._backend_factories
xb._backend_factories = {}
self._orig_jax_platforms = config._read("jax_platforms")
config.FLAGS.jax_platforms = ""
self.enter_context(config.jax_platforms(""))
self._save_backend_state()
self._reset_backend_state()

@ -294,8 +293,8 @@ class GetBackendTest(jtu.JaxTestCase):
self._register_factory("cpu", 0)

def tearDown(self):
super().tearDown()
xb._backend_factories = self._orig_factories
config.FLAGS.jax_platforms = self._orig_jax_platforms
self._restore_backend_state()

def _save_backend_state(self):

@ -380,10 +379,7 @@ class GetBackendTest(jtu.JaxTestCase):
self._register_factory("platform_A", 20, assert_used_at_most_once=True)
self._register_factory("platform_B", 10, assert_used_at_most_once=True)

orig_jax_platforms = config._read("jax_platforms")
try:
config.FLAGS.jax_platforms = "cpu,platform_A"

with config.jax_platforms("cpu,platform_A"):
backend = xb.get_backend()
self.assertEqual(backend.platform, "cpu")
# Only specified backends initialized.

@ -395,10 +391,6 @@ class GetBackendTest(jtu.JaxTestCase):
with self.assertRaisesRegex(RuntimeError, "Unknown backend platform_B"):
backend = xb.get_backend("platform_B")

finally:
config.FLAGS.jax_platforms = orig_jax_platforms


def test_experimental_warning(self):
self._register_factory("platform_A", 20, experimental=True)
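Note on the GetBackendTest hunks above: `config.jax_platforms` is likewise used as a context manager, so the platform-list override is undone even when the block raises, replacing both the manual FLAGS save in setUp and the try/finally in the test body. A short sketch mirroring the test (internal jax._src APIs, subject to change):

from jax._src import config
from jax._src import xla_bridge as xb

with config.jax_platforms("cpu"):
  backend = xb.get_backend()
  assert backend.platform == "cpu"  # only the named platforms are considered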
@ -33,24 +33,24 @@ import jax.scipy as jscipy
from jax._src import test_util as jtu
from jax import vmap
from jax import lax
from jax._src import core
from jax._src.core import NamedShape
from jax.experimental import maps
from jax._src import array
from jax._src.sharding_impls import NamedSharding
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.experimental.maps import xmap, serial_loop, SerialLoop
from jax.ad_checkpoint import checkpoint
from jax.errors import JAXTypeError
from jax._src.nn import initializers as nn_initializers
from jax.experimental.maps import xmap, serial_loop, SerialLoop
from jax.experimental.pjit import pjit
from jax.interpreters import batching
from jax.sharding import PartitionSpec as P
from jax._src import array
from jax._src import core
from jax._src import maps
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.util import unzip2
from jax._src.core import NamedShape
from jax._src.lax import parallel as lax_parallel
from jax._src.lax.parallel import pgather
from jax.interpreters import batching
from jax.ad_checkpoint import checkpoint
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.nn import initializers as nn_initializers
from jax._src.sharding_impls import NamedSharding
from jax._src.util import unzip2

from jax import config
config.parse_flags_with_absl()

@ -246,10 +246,11 @@ class XMapTestCase(jtu.BufferDonationTestCase):
class SPMDTestMixin:
def setUp(self):
super().setUp()
jtu.set_spmd_lowering_flag(True)
self.spmd_lowering = maps._SPMD_LOWERING.value
config.update('experimental_xmap_spmd_lowering', True)

def tearDown(self):
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)


class ManualSPMDTestMixin:

@ -257,12 +258,14 @@ class ManualSPMDTestMixin:
if not hasattr(xla_client.OpSharding.Type, "MANUAL"):
raise SkipTest
super().setUp()
jtu.set_spmd_lowering_flag(True)
jtu.set_spmd_manual_lowering_flag(True)
self.spmd_lowering = maps._SPMD_LOWERING.value
self.spmd_manual_lowering = maps._SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)

def tearDown(self):
jtu.restore_spmd_manual_lowering_flag()
jtu.restore_spmd_lowering_flag()
config.update('experimental_xmap_spmd_lowering', self.spmd_lowering)
config.update('experimental_xmap_spmd_lowering_manual', self.spmd_manual_lowering)


@jtu.pytest_mark_if_available('multiaccelerator')

@ -433,7 +436,7 @@ class XMapTest(XMapTestCase):
m_size = math.prod([2] + [2] * (len(mesh) - 2))
self.assertListEqual(y_op_sharding.tile_assignment_dimensions(),
[2, 1, 1, m_size])
if config.experimental_xmap_spmd_lowering:
if maps._SPMD_LOWERING.value:
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
# Make sure that there are non-partial sharding specs in the HLO
if xla_extension_version >= 180:

@ -746,7 +749,7 @@ class XMapTest(XMapTestCase):
axis_resources={'i': 'x'})
x = jnp.arange(4, dtype=jnp.float32).reshape((2, 2))
hlo = f.lower(x).as_text(dialect='stablehlo')
if config.experimental_xmap_spmd_lowering:
if maps._SPMD_LOWERING.value:
self.assertIn("mhlo.num_partitions = 2", hlo)
self.assertIn("mhlo.num_replicas = 1", hlo)
else:

@ -1201,7 +1204,7 @@ class NewPrimitiveTest(XMapTestCase):

@jtu.with_and_without_mesh
def testGather(self, mesh, axis_resources):
if axis_resources and not config.experimental_xmap_spmd_lowering:
if axis_resources and not maps._SPMD_LOWERING.value:
raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
y = jnp.arange(35).reshape((5, 7)) % 3
4
third_party/xla/workspace.bzl
vendored
@ -20,8 +20,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.

XLA_COMMIT = "50baaece6db8e61ada5db1f49f19a9d7dc1fa297"
XLA_SHA256 = "c09aefa72d931909ddac059c6313008e209b9a363050dddbc2d3d348fda881a7"
XLA_COMMIT = "500c965b04709f15008c26b46df2f8406279d730"
XLA_SHA256 = "c0f98502d28e30f2905b56e8893b635e2081ad4ddb57009117fa0dc83bdf6d8e"

def repo():
tf_http_archive(