Move contents of jax.lib to jax._src.lib.

Add shim libraries that continue to export from jax.lib the functions other code appears to use in practice.

PiperOrigin-RevId: 398471863
Peter Hawkins 2021-09-23 06:33:25 -07:00 committed by jax authors
parent 94cd1ea0a2
commit 2c2f4033cc
69 changed files with 1035 additions and 975 deletions
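As a rough sketch of the shim pattern (based on the re-exports visible in the updated jax/lib/__init__.py at the end of this diff; other names would be re-exported the same way):

    # jax/lib/__init__.py shim: keeps `from jax.lib import ...` imports working
    # by re-exporting the implementation that now lives under jax._src.lib.
    from jax._src.lib import (
        xla_client as xla_client,
        xla_extension as xla_extension,
    )
    from . import xla_bridge as xla_bridge  # jax/lib/xla_bridge.py is itself a shim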

View File

@ -505,7 +505,7 @@
"colab": {}
},
"source": [
"from jax.lib import xla_bridge\n",
"from jax._src.lib import xla_bridge\n",
"device_count = xla_bridge.device_count()\n",
"\n",
"def send_right(x, axis_name):\n",

View File

@ -2087,8 +2087,8 @@
},
"outputs": [],
"source": [
"from jax.lib import xla_bridge as xb\n",
"from jax.lib import xla_client as xc\n",
"from jax._src.lib import xla_bridge as xb\n",
"from jax._src.lib import xla_client as xc\n",
"xe = xc._xla\n",
"xops = xc._xla.ops\n",
"\n",

View File

@ -1544,8 +1544,8 @@ class IDHashable:
Next, we'll define the evaluation rule for `xla_call`:
```{code-cell}
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops

View File

@ -1481,8 +1481,8 @@ class IDHashable:
# Next, we'll define the evaluation rule for `xla_call`:
# +
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops

View File

@ -149,7 +149,7 @@
" def pp(v):\n",
" \"\"\"Print certain values more succinctly\"\"\"\n",
" vtype = str(type(v))\n",
" if \"jax.lib.xla_bridge._JaxComputationBuilder\" in vtype:\n",
" if \"jax._src.lib.xla_bridge._JaxComputationBuilder\" in vtype:\n",
" return \"<JaxComputationBuilder>\"\n",
" elif \"jaxlib.xla_extension.XlaOp\" in vtype:\n",
" return \"<XlaOp at 0x{:x}>\".format(id(v))\n",
@ -614,7 +614,7 @@
},
"outputs": [],
"source": [
"from jax.lib import xla_client\n",
"from jax._src.lib import xla_client\n",
"@trace(\"multiply_add_xla_translation\")\n",
"def multiply_add_xla_translation(c, xc, yc, zc):\n",
" \"\"\"The compilation to XLA of the primitive.\n",

View File

@ -118,7 +118,7 @@ def trace(name):
def pp(v):
"""Print certain values more succinctly"""
vtype = str(type(v))
if "jax.lib.xla_bridge._JaxComputationBuilder" in vtype:
if "jax._src.lib.xla_bridge._JaxComputationBuilder" in vtype:
return "<JaxComputationBuilder>"
elif "jaxlib.xla_extension.XlaOp" in vtype:
return "<XlaOp at 0x{:x}>".format(id(v))
@ -354,7 +354,7 @@ for most of them. However, XLA includes a `CustomCall` operation that can be use
```{code-cell} ipython3
:id: FYQWSSjKJaWP
from jax.lib import xla_client
from jax._src.lib import xla_client
@trace("multiply_add_xla_translation")
def multiply_add_xla_translation(c, xc, yc, zc):
"""The compilation to XLA of the primitive.

View File

@ -74,7 +74,7 @@
"import numpy as np\n",
"\n",
"# We only need to import JAX's xla_client, not all of JAX.\n",
"from jax.lib import xla_client as xc\n",
"from jax._src.lib import xla_client as xc\n",
"xops = xc.ops\n",
"\n",
"# Plotting\n",

View File

@ -66,7 +66,7 @@ https://github.com/google/jax/blob/main/jax/interpreters/xla.py
import numpy as np
# We only need to import JAX's xla_client, not all of JAX.
from jax.lib import xla_client as xc
from jax._src.lib import xla_client as xc
xops = xc.ops
# Plotting

View File

@ -27,9 +27,9 @@ import time
import numpy as np
import numpy.random as npr
import jax
from jax import jit, grad, pmap
from jax.scipy.special import logsumexp
from jax.lib import xla_bridge
from jax.tree_util import tree_map
from jax import lax
import jax.numpy as jnp
@ -77,7 +77,7 @@ if __name__ == "__main__":
# For this manual SPMD example, we get the number of devices (e.g. GPUs or
# TPU cores) that we're using, and use it to reshape data minibatches.
num_devices = xla_bridge.device_count()
num_devices = jax.device_count()
def data_stream():
rng = npr.RandomState(0)
while True:

View File

@ -130,6 +130,8 @@ from . import random as random
from . import tree_util as tree_util
from . import util as util
import jax.lib # TODO(phawkins): remove this export.
def _init():
from . import numpy as numpy # side-effecting import sets up operator overloads

View File

@ -40,7 +40,6 @@ from contextlib import contextmanager, ExitStack
import jax
from .. import core
from .. import lib
from .. import linear_util as lu
from . import dtypes
from ..core import eval_jaxpr
@ -57,13 +56,13 @@ from ..tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
treedef_is_leaf, treedef_children, Partial)
from .util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, wrap_name, cache, wraps, HashableFunction)
from ..lib import jax_jit
from ..lib import version
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..lib import pmap_lib
from jax._src.lib import jax_jit
from jax._src.lib import version
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
# Unused imports to be exported
from ..lib.xla_bridge import (device_count, local_device_count, devices,
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index, process_count,
host_id, host_ids, host_count, default_backend)
from ..core import ConcreteArray, ShapedArray, raise_to_shaped

View File

@ -24,8 +24,8 @@ import threading
from typing import Any, List, Callable, NamedTuple, Optional
import warnings
from jax import lib
from jax.lib import jax_jit
from jax._src import lib
from jax._src.lib import jax_jit
def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.

View File

@ -15,8 +15,8 @@
from jax import core
from jax import numpy as jnp
from jax.interpreters import xla
from jax.lib import xla_client
from jax.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_bridge
SUPPORTED_DTYPES = set([jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,

View File

@ -27,7 +27,7 @@ import numpy as np
from jax._src import util
from jax._src.config import flags, config
from jax.lib import xla_client
from jax._src.lib import xla_client
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)

View File

@ -43,8 +43,8 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
split_list, cache, extend_name_stack)

View File

@ -23,10 +23,10 @@ from jax.interpreters import xla
from jax._src.util import prod
from jax._src import dtypes
from jax import lax
from jax.lib import xla_client
from jax.interpreters import ad
from jax.interpreters import batching
from jax.lib import pocketfft
from jax._src.lib import xla_client
from jax._src.lib import pocketfft
xops = xla_client.ops

View File

@ -50,9 +50,10 @@ from jax.interpreters import masking
from jax._src.util import (cache, safe_zip, prod, safe_map, canonicalize_axis,
split_list)
from jax.tree_util import tree_map
from jax.lib import pytree
from jax.lib import xla_bridge
from jax.lib import xla_client
import jax._src.lib
from jax._src.lib import pytree
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
xb = xla_bridge
xc = xla_client
@ -2631,7 +2632,7 @@ ad.defjvp2(rsqrt_p,
# TODO(phawkins): remove the fallback translation rule after the minimum jaxlib
# is 0.1.70 or newer.
if jax.lib._xla_extension_version >= 28:
if jax._src.lib._xla_extension_version >= 28:
_cbrt_translation_rule = None
else:
def _cbrt_translation_rule(c, x):

View File

@ -33,15 +33,15 @@ from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype, _broadcasting_select)
from jax._src.lax import lax as lax_internal
from jax.lib import lapack
from jax._src.lib import lapack
from jax.lib import cuda_linalg
from jax.lib import cusolver
from jax.lib import cusparse
from jax.lib import rocsolver
from jax._src.lib import cuda_linalg
from jax._src.lib import cusolver
from jax._src.lib import cusparse
from jax._src.lib import rocsolver
from jax.lib import xla_client
from jax.lib import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.lib import xla_bridge as xb
xops = xla_client.ops

View File

@ -23,7 +23,6 @@ import warnings
import numpy as np
from jax import core
from jax._src import dtypes
from jax import tree_util
from . import lax
from jax.core import ShapedArray, AxisName, raise_to_shaped
@ -31,10 +30,11 @@ from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
from jax.lib import xla_client as xc
from jax.lib import xla_bridge as xb
from jax._src import dtypes
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.numpy import lax_numpy
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
xops = xc.ops

122 jax/_src/lib/__init__.py Normal file
View File

@ -0,0 +1,122 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This module is largely a wrapper around `jaxlib` that performs version
# checking on import.
import platform
import os
import warnings
from typing import Optional
__all__ = [
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
'pocketfft', 'pytree', 'tpu_driver_client', 'version', 'xla_client'
]
# First, before attempting to import jaxlib, warn about experimental machine
# configurations.
if platform.system() == "Darwin" and platform.machine() == "arm64":
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
"Please see https://github.com/google/jax/issues/5501 in the "
"event of problems.")
try:
import jaxlib
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
'jax requires jaxlib to be installed. See '
'https://github.com/google/jax#installation for installation instructions.'
) from err
from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
try:
from jaxlib import version as jaxlib_version
except Exception as err:
# jaxlib is too old to have version number.
msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
raise ImportError(msg) from err
version = tuple(int(x) for x in jaxlib_version.__version__.split('.'))
_minimum_jaxlib_version = tuple(int(x) for x in _minimum_jaxlib_version_str.split('.'))
# Check the jaxlib version before importing anything else from jaxlib.
def _check_jaxlib_version():
if version < _minimum_jaxlib_version:
msg = (f'jaxlib is version {jaxlib_version.__version__}, '
f'but this version of jax requires version {_minimum_jaxlib_version_str}.')
if version == (0, 1, 23):
msg += ('\n\nA common cause of this error is that you installed jaxlib '
'using pip, but your version of pip is too old to support '
'manylinux2010 wheels. Try running:\n\n'
'pip install --upgrade pip\n'
'pip install --upgrade jax jaxlib\n')
raise ValueError(msg)
_check_jaxlib_version()
from jaxlib import cpu_feature_guard
cpu_feature_guard.check_cpu_features()
from jaxlib import xla_client
from jaxlib import lapack
from jaxlib import pocketfft
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
try:
from jaxlib import cusolver # pytype: disable=import-error
except ImportError:
cusolver = None
try:
from jaxlib import cusparse # pytype: disable=import-error
except ImportError:
cusparse = None
try:
from jaxlib import rocsolver # pytype: disable=import-error
except ImportError:
rocsolver = None
try:
from jaxlib import cuda_prng # pytype: disable=import-error
except ImportError:
cuda_prng = None
try:
from jaxlib import cuda_linalg # pytype: disable=import-error
except ImportError:
cuda_linalg = None
# Jaxlib code is split between the Jax and the Tensorflow repositories.
# Only for the internal usage of the JAX developers, we expose a version
# number that can be used to perform changes without breaking the main
# branch on the Jax github.
_xla_extension_version = getattr(xla_client, '_version', 0)
try:
from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error
except:
tpu_driver_client = None # type: ignore
cuda_path: Optional[str]
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
if not os.path.isdir(cuda_path):
cuda_path = None
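For internal callers, the import pattern after this move is sketched below; the names are module attributes defined above, and the None check mirrors the optional-import fallbacks (the branch bodies are placeholders):

    from jax._src.lib import xla_client, cusolver, version

    xops = xla_client.ops  # XLA op builders, as used throughout this change

    # Optional jaxlib submodules are bound to None when jaxlib was built
    # without them (e.g. a CPU-only jaxlib has no cusolver).
    if cusolver is None:
        ...  # fall back to a non-CUDA code path

    # `version` is a tuple such as (0, 1, 71), used for jaxlib version gates.
    if version >= (0, 1, 72):
        ...  # newer-jaxlib-only behavior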

595 jax/_src/lib/xla_bridge.py Normal file
View File

@ -0,0 +1,595 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface and utility functions to XLA.
This module wraps the XLA client(s) and builders to standardize their interfaces
and provide some automatic type mapping logic for converting between Numpy and
XLA. There are also a handful of related casting utilities.
"""
from functools import partial, lru_cache
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
from absl import logging
# Disable "WARNING: Logging before flag parsing goes to stderr." message
logging._warn_preinit_stderr = 0
import jax._src.lib
from jax._src.config import flags, bool_env
from . import tpu_driver_client
from . import xla_client
from jax._src import util, traceback_util
from jax._src import dtypes
import numpy as np
import threading
traceback_util.register_exclusion(__file__)
xops = xla_client.ops
FLAGS = flags.FLAGS
# TODO(phawkins): Remove jax_xla_backend.
flags.DEFINE_string(
'jax_xla_backend', '',
'jax_xla_backend is an alias for jax_platform_name. If both are '
'provided, --jax_xla_backend takes priority. Prefer --jax_platform_name.')
flags.DEFINE_string(
'jax_backend_target', '',
'Either "local" or "rpc:address" to connect to a remote service target.')
flags.DEFINE_string(
'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Platform name for XLA. The default is to attempt to use a GPU or TPU if '
'available, but fall back to CPU otherwise. To set the platform manually, '
'pass "cpu" for CPU, "gpu" for GPU, etc. If intending to use CPU, '
'setting the platform name to "cpu" can silence warnings that appear with '
'the default setting.')
flags.DEFINE_bool(
'jax_disable_most_optimizations',
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
def get_compile_options(
num_replicas: int,
num_partitions: int,
device_assignment=None,
use_spmd_partitioning: bool = True,
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional tuple of integers indicating the assignment of
logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
compile_options.num_partitions = num_partitions
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
if device_assignment is not None:
logging.vlog(
2,
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
num_replicas, num_partitions, device_assignment)
device_assignment = np.array(device_assignment)
# Allow 1D device assignment if num_partitions is 1.
if (device_assignment.ndim == 1) and (num_partitions == 1):
device_assignment = device_assignment[:, None]
if num_replicas != device_assignment.shape[0]:
msg = 'device_assignment does not match num_replicas: {} vs {}.'
raise ValueError(msg.format(device_assignment, num_replicas))
if num_partitions != device_assignment.shape[1]:
msg = 'device_assignment does not match num_partitions: {} vs {}.'
raise ValueError(msg.format(device_assignment, num_partitions))
device_assignment = xla_client.DeviceAssignment.create(device_assignment)
assert device_assignment.replica_count() == num_replicas
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment
debug_options = compile_options.executable_build_options.debug_options
if jax._src.lib.cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = jax._src.lib.cuda_path
if FLAGS.jax_disable_most_optimizations:
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
return compile_options
# Backends
def _make_tpu_driver_client():
if tpu_driver_client is None:
logging.info("Remote TPU is not linked into jax; skipping remote TPU.")
return None
if FLAGS.jax_backend_target is None:
logging.info("No --jax_backend_target was provided; skipping remote TPU.")
return None
return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target)
def tpu_client_timer_callback(timer_secs: float):
def _log_warning():
warnings.warn(
f'TPU backend initialization is taking more than {timer_secs} seconds. '
'Did you run your code on all TPU hosts? '
'See https://jax.readthedocs.io/en/latest/multi_process.html '
'for more information.')
# Will log a warning after `timer_secs`.
t = threading.Timer(timer_secs, _log_warning)
t.start()
try:
client = xla_client.make_tpu_client()
finally:
t.cancel()
return client
# Backends, in increasing order of preference.
# We have no particular opinion about how "backends" relate to "devices". For
# example, there could be multiple backends that provide the same kind of
# device.
_backend_factories = {}
def register_backend_factory(name, factory, *, priority=0):
_backend_factories[name] = (factory, priority)
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=True),
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory(
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
_default_backend = None
_backends_initialized = False
_backends : Dict[str, Any] = {}
_backends_errors : Dict[str, str] = {}
_backend_lock = threading.Lock()
def backends():
global _backends_initialized
global _backends
global _backends_errors
global _default_backend
with _backend_lock:
if _backends_initialized:
return _backends
_backends_initialized = True
default_priority = -1000
for name, (factory, priority) in _backend_factories.items():
logging.vlog(1, "Initializing backend '%s'" % name)
try:
backend = factory()
if backend is not None:
if backend.device_count() > 0:
_backends[name] = backend
util.distributed_debug_log(("Initialized backend", backend.platform),
("process_index", backend.process_index()),
("device_count", backend.device_count()),
("local_devices", backend.local_devices()))
logging.vlog(1, "Backend '%s' initialized" % name)
if priority > default_priority:
_default_backend = backend
default_priority = priority
except Exception as err:
if name in ('cpu', 'interpreter'):
# We always expect the CPU and interpreter backends to initialize
# successfully.
raise
else:
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
logging.info("Unable to initialize backend '%s': %s" % (name, err))
_backends_errors[name] = str(err)
continue
if _default_backend.platform == "cpu" and FLAGS.jax_platform_name != 'cpu':
logging.warning('No GPU/TPU found, falling back to CPU. '
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
return _backends
def _get_backend_uncached(platform=None):
# TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
# 'backend' values are handled
if not isinstance(platform, (type(None), str)):
return platform
bs = backends()
platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
or None)
if platform is not None:
backend = bs.get(platform, None)
if backend is None:
if platform in _backends_errors:
raise RuntimeError(f"Requested backend {platform}, but it failed "
f"to initialize: {_backends_errors[platform]}")
raise RuntimeError(f"Unknown backend {platform}")
return backend
else:
return _default_backend
@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
def get_backend(platform=None):
return _get_backend_uncached(platform)
def get_device_backend(device=None):
"""Returns the Backend associated with `device`, or the default Backend."""
if device is not None:
return device.client
return get_backend()
def device_count(backend: Optional[str] = None) -> int:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`jax.local_device_count`.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
Number of devices.
"""
return int(get_backend(backend).device_count())
def local_device_count(backend: Optional[str] = None) -> int:
"""Returns the number of devices addressable by this process."""
return int(get_backend(backend).local_device_count())
def devices(backend: Optional[str] = None) -> List[xla_client.Device]:
"""Returns a list of all devices for a given backend.
Each device is represented by a subclass of :class:`Device` (e.g.
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by
comparing :meth:`Device.process_index` to the value returned by
:py:func:`jax.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
otherwise ``'cpu'``.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of Device subclasses.
"""
return get_backend(backend).devices()
def default_backend() -> str:
"""Returns the platform name of the default XLA backend."""
return get_backend(None).platform
def local_devices(process_index: Optional[int] = None,
backend: Optional[str] = None,
host_id: Optional[int] = None) -> List[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
process_index: the integer index of the process. Process indices can be
retrieved via ``range(jax.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of Device subclasses.
"""
if host_id is not None:
warnings.warn(
"The argument to jax.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code.")
process_index = host_id
if process_index is None:
process_index = get_backend(backend).process_index()
if not (0 <= process_index < process_count()):
raise ValueError(f"Unknown process_index {process_index}")
return [d for d in devices(backend) if d.process_index == process_index]
def process_index(backend: Optional[str] = None) -> int:
"""Returns the integer process index of this process.
On most platforms, this will always be 0. This will vary on multi-process
platforms though.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
Integer process index.
"""
return get_backend(backend).process_index()
# TODO: remove this sometime after jax 0.2.13 is released
def host_id(backend=None):
warnings.warn(
"jax.host_id has been renamed to jax.process_index. This alias "
"will eventually be removed; please update your code.")
return process_index(backend)
def process_count(backend: Optional[str] = None) -> int:
"""Returns the number of JAX processes associated with the backend."""
return max(d.process_index for d in devices(backend)) + 1
# TODO: remove this sometime after jax 0.2.13 is released
def host_count(backend=None):
warnings.warn(
"jax.host_count has been renamed to jax.process_count. This alias "
"will eventually be removed; please update your code.")
return process_count(backend)
# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(backend=None):
warnings.warn(
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
"instead. jax.host_ids will eventually be removed; please update your "
"code.")
return list(range(process_count(backend)))
### utility functions
@util.memoize
def dtype_to_etype(dtype):
"""Convert from dtype to canonical etype (reading config.x64_enabled)."""
return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))
@util.memoize
def supported_numpy_dtypes():
return {dtypes.canonicalize_dtype(dtype)
for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}
# TODO(mattjj,frostig): try to remove this function
def normalize_to_xla_dtypes(val):
"""Normalize dtypes in a value."""
if hasattr(val, '__array__') or np.isscalar(val):
return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
elif isinstance(val, (tuple, list)):
return tuple(normalize_to_xla_dtypes(x) for x in val)
raise TypeError('Can\'t convert to XLA: {}'.format(val))
def _numpy_array_constant(builder, value, canonicalize_types=True):
if canonicalize_types:
value = normalize_to_xla_dtypes(value)
return [xops.ConstantLiteral(builder, value)]
def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
if replicated is None:
replicated = []
elif isinstance(replicated, bool):
replicated = [replicated] * shape.leaf_count()
return xops.Parameter(builder, num,
shape.with_major_to_minor_layout_if_absent(), name,
replicated)
def constant_general(builder, py_val, canonicalize_types=True):
"""Translate a general constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a list of xla ops.
"""
for t in type(py_val).mro():
handler = _constant_handlers.get(t)
if handler: return handler(builder, py_val, canonicalize_types)
if hasattr(py_val, '__jax_array__'):
return constant(builder, py_val.__jax_array__(), canonicalize_types)
raise TypeError("No constant handler for type: {}".format(type(py_val)))
def constant(builder, py_val, canonicalize_types=True):
"""Translate constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant, either a ComputationDataHandle or None
"""
const = constant_general(builder, py_val, canonicalize_types=canonicalize_types)
assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
return const[0]
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# _sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension divided (i.e. the total
# number of shards is the product), or None to indicate the array should be
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
None,
Tuple[Union[Tuple[int, ...], None], ...]]
def _sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.
See
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
"""
proto = xla_client.OpSharding()
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
assert all(s is None or isinstance(s, tuple) for s in sharding)
return tuple_sharding_proto(list(map(_sharding_to_proto, sharding))) # type: ignore
if sharding is None:
proto.type = xla_client.OpSharding.Type.REPLICATED
else:
proto.type = xla_client.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding)
proto.tile_assignment_devices = list(range(np.product(sharding)))
return proto
def tuple_sharding_proto(elems):
proto = xla_client.OpSharding()
assert all(isinstance(e, type(proto)) for e in elems)
proto.type = xla_client.OpSharding.Type.TUPLE
proto.tuple_shardings = elems
return proto
def set_sharding_proto(builder, op, sharding_proto):
"""Uses CustomCall to annotate a value as sharded."""
# "Sharding" is a built-in custom call target that acts like an identity
# function, and is used to attach an OpSharding to.
return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
builder, b"Sharding", [op], builder.get_shape(op))
def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()
def set_sharding(builder, op, sharding: SpatialSharding):
"""Uses CustomCall to annotate a value as sharded."""
return set_sharding_proto(builder, op, _sharding_to_proto(sharding))
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
def make_computation_builder(name):
return xla_client.XlaBuilder(name)
def register_constant_handler(type_, handler_fun):
_constant_handlers[type_] = handler_fun
_constant_handlers: Dict[type, Callable] = {}
def _ndarray_constant_handler(c, val, canonicalize_types=True):
"""Constant handler for ndarray literals, handling zero-size strides.
This function essentially calls _numpy_array_constant(val) except it has
special handling of arrays with any strides of size zero: for those, it
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
to avoid staging in large literals that might arise from np.zeros or np.ones
or the output of lax.broadcast (which uses np.broadcast_to which in turn
uses size-zero strides).
Args:
c: an XlaBuilder
val: an ndarray.
Returns:
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
staged into the XLA Computation.
"""
# TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
if dtypes.result_type(val) == dtypes.float0:
return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
elif np.any(np.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = np.where(np.equal(0, val.strides))
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
for ax in range(val.ndim))]
xla_val = xops.Broadcast(
_numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
np.take(val.shape, zero_stride_axes))
permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
return [xops.Transpose(xla_val, permutation)]
else:
return _numpy_array_constant(c, val, canonicalize_types)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
def _scalar_constant_handler(c, val, canonicalize_types=True):
return _numpy_array_constant(c, val, canonicalize_types)
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64,
np.bool_, np.longlong,
xla_client.bfloat16]:
register_constant_handler(scalar_type, _scalar_constant_handler)
# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
if hasattr(np, "float128"):
register_constant_handler(np.float128, _scalar_constant_handler)
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
return _numpy_array_constant(c, dtype.type(val))
for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

View File

@ -17,7 +17,7 @@ import operator
import numpy as np
from jax import lax
from jax.lib import xla_client
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from .util import _wraps
from . import lax_numpy as jnp

View File

@ -23,13 +23,13 @@ from jax import lax
from jax import core
from jax import numpy as jnp
from jax import tree_util
from jax._src.api import jit, vmap
from jax.config import config
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax.lib import cuda_prng
from jax.interpreters import batching
from jax.interpreters import xla
from jax._src.api import jit, vmap
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import cuda_prng
from jax._src.pprint_util import pp, vcat
from jax._src.util import prod

View File

@ -18,8 +18,8 @@ import threading
from typing import Callable, Optional
import warnings
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
def start_server(port: int):

View File

@ -29,7 +29,7 @@ from jax.core import NamedShape
from jax._src.api import jit, vmap
from jax._src.numpy.lax_numpy import (_arraylike, _check_arraylike,
_constant_like, _convert_and_clip_integer)
from jax.lib import xla_bridge
from jax._src.lib import xla_bridge
from jax.numpy.linalg import cholesky, svd, eigh
from jax.interpreters import ad
from jax.interpreters import batching

View File

@ -19,7 +19,7 @@ import threading
from typing import Optional, Iterator
import jax.version
from jax.lib import xla_client
from jax._src.lib import xla_client
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)

View File

@ -17,7 +17,7 @@ import traceback
import types
import jax
from jax.lib import xla_extension
from jax._src.lib import xla_extension
from jax._src import util
_exclude_paths = [__file__, util.__file__]

View File

@ -19,7 +19,7 @@ import operator as op
from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, Type,
TypeVar, overload, TYPE_CHECKING)
from ..lib import pytree
from jax._src.lib import pytree
from .._src.util import safe_zip, unzip2

View File

@ -17,7 +17,8 @@ import re
import jax
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
from jax.lib import xla_client
import jax._src.lib
from jax._src.lib import xla_client
from absl import logging
from typing import Optional
@ -85,7 +86,7 @@ def get_cache_key(xla_computation, compile_options, backend) -> str:
_hash_compile_options(hash_obj, compile_options)
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing compile_options: {hash_obj.digest().hex()}")
hash_obj.update(bytes(jax.lib.version))
hash_obj.update(bytes(jax._src.lib.version))
if logging.vlog_is_on(1):
logging.vlog(1, f"get_cache_key hash after serializing jax_lib version: {hash_obj.digest().hex()}")
_hash_platform(hash_obj, backend)
@ -110,7 +111,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
hash_obj.update(compile_options_obj.device_assignment.serialize())
def _hash_executable_build_options(hash_obj, executable_obj):
if jax.lib.version >= (0, 1, 72):
if jax._src.lib.version >= (0, 1, 72):
expected_options = 31
else:
expected_options = 30
@ -126,7 +127,7 @@ def _hash_executable_build_options(hash_obj, executable_obj):
if executable_obj.device_assignment is not None:
hash_obj.update(executable_obj.device_assignment.serialize())
_hash_bool(hash_obj, executable_obj.use_spmd_partitioning)
if jax.lib.version >= (0, 1, 72):
if jax._src.lib.version >= (0, 1, 72):
_hash_bool(hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output)
def _hash_debug_options(hash_obj, debug_obj):

View File

@ -600,8 +600,8 @@ core.pytype_aval_mappings[BoundedInt] = _abstractify_bdint
# XLA lowering
from jax.interpreters import xla
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops

View File

@ -456,15 +456,15 @@ from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
from jax.experimental import pjit
from jax.lib import pytree
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax.lib import xla_extension
from jax.interpreters import ad, xla, batching, masking, pxla
from jax.interpreters import partial_eval as pe
from jax._src import pprint_util as ppu
from jax._src import source_info_util
from jax._src import util
from jax._src.lib import pytree
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
import numpy as np

View File

@ -35,7 +35,7 @@ from jax._src import util
from jax._src import ad_util
from jax._src.lax.lax import _device_put_raw
from jax.interpreters import xla
from jax.lib import xla_client
from jax._src.lib import xla_client
from . import jax2tf as jax2tf_internal
import numpy as np

View File

@ -44,7 +44,7 @@ from jax.interpreters import ad
from jax.interpreters import pxla
from jax.interpreters import sharded_jit
from jax.interpreters import xla
from jax.lib import xla_client
from jax._src.lib import xla_client
from . import shape_poly

View File

@ -29,6 +29,7 @@ from jax.config import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
from jax._src import source_info_util
import jax._src.lib.xla_bridge
import numpy as np
import tensorflow as tf # type: ignore[import]
@ -972,7 +973,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
jax_comp = jax.xla_computation(f_while)(x)
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
modules = backend.compile(jax_comp).hlo_modules()
jax_opt_hlo = modules[0].to_string()
print(f"JAX OPT HLO = {jax_opt_hlo}")

View File

@ -54,7 +54,7 @@ from jax import numpy as jnp
from jax._src.lax import control_flow as lax_control_flow
from jax.interpreters import xla
from jax.lib import xla_client
from jax._src.lib import xla_client
import numpy as np
@ -1515,7 +1515,7 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
lax.linalg.qr,
[RandArg(shape, dtype),
StaticArg(full_matrices)],
# See jax.lib.lapack.geqrf for the list of compatible types
# See jax._src.lib.lapack.geqrf for the list of compatible types
jax_unimplemented=[
Limitation(
"unimplemented",

View File

@ -31,6 +31,7 @@ from jax.experimental.jax2tf.tests import tf_test_util
from jax.interpreters import sharded_jit
from jax.interpreters.sharded_jit import PartitionSpec as P
import jax.numpy as jnp
import jax._src.lib.xla_bridge
import numpy as np
@ -80,12 +81,12 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
self.AssertShardingAnnotations("JAX before optimizations", jax_hlo, expected)
if jtu.device_under_test() == "tpu":
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
num_replicas = 1
device_assignment = np.arange(num_partitions * num_replicas)
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
use_spmd_partitioning = num_partitions > 1
compile_options = jax.lib.xla_bridge.get_compile_options(
compile_options = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,

View File

@ -30,6 +30,7 @@ from jax.config import config
from jax.experimental import jax2tf
from jax.interpreters import masking
from jax._src import util
import jax._src.lib.xla_bridge
import numpy as np
import tensorflow as tf # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
@ -307,7 +308,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
stage="hlo")
logging.info(f"[{self._testMethodName}] TF NON OPT HLO\n{tf_hlo}")
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
modules = backend.compile(jax_comp).hlo_modules()
jax_opt_hlo = modules[0].to_string()
logging.info(f"[{self._testMethodName}] "

View File

@ -39,8 +39,8 @@ from ..interpreters import pxla
from ..interpreters import xla
from ..interpreters import batching
from ..interpreters import ad
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from .._src.util import (safe_map, safe_zip, HashableFunction,
as_hashable_function, unzip2, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name)

View File

@ -35,8 +35,8 @@ from ..interpreters import xla
from ..interpreters import batching
from ..interpreters import partial_eval as pe
from ..interpreters.sharded_jit import PartitionSpec
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from ..tree_util import tree_map, tree_flatten, tree_unflatten
from .._src.util import (extend_name_stack, HashableFunction, safe_zip,
wrap_name, wraps, distributed_debug_log,

View File

@ -44,9 +44,9 @@ from jax import vmap
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax._src.lib import cusparse
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
import jax.numpy as jnp
import numpy as np
from jax.interpreters import ad

View File

@ -51,10 +51,10 @@ from .._src.util import (unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
from ..errors import JAXTypeError
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..lib import pmap_lib
from ..lib import _xla_extension_version
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib import _xla_extension_version
from ..tree_util import tree_flatten, tree_map
from . import batching
from . import partial_eval as pe

View File

@ -25,8 +25,8 @@ from . import partial_eval as pe
from . import pxla
from . import xla
from .. import linear_util as lu
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from .._src.api_util import argnums_partial, flatten_axes, flatten_fun, _ensure_index_tuple
from ..tree_util import tree_flatten, tree_unflatten
from .._src.util import (extend_name_stack, wrap_name, wraps, safe_zip,

View File

@ -40,8 +40,8 @@ from jax._src.pprint_util import pp
from .._src.util import (partialmethod, cache, prod, unzip2,
extend_name_stack, wrap_name, safe_zip, safe_map,
partition_list)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from . import partial_eval as pe
from . import ad
from . import masking

View File

@ -12,111 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# This module is largely a wrapper around `jaxlib` that performs version
# checking on import.
import platform
import os
import warnings
from typing import Optional
__all__ = [
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
'pocketfft', 'pytree', 'tpu_driver_client', 'version', 'xla_client'
]
# First, before attempting to import jaxlib, warn about experimental machine
# configurations.
if platform.system() == "Darwin" and platform.machine() == "arm64":
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
"Please see https://github.com/google/jax/issues/5501 in the "
"event of problems.")
try:
import jaxlib
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
'jax requires jaxlib to be installed. See '
'https://github.com/google/jax#installation for installation instructions.'
) from err
from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
try:
from jaxlib import version as jaxlib_version
except Exception as err:
# jaxlib is too old to have version number.
msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
raise ImportError(msg) from err
version = tuple(int(x) for x in jaxlib_version.__version__.split('.'))
_minimum_jaxlib_version = tuple(int(x) for x in _minimum_jaxlib_version_str.split('.'))
# Check the jaxlib version before importing anything else from jaxlib.
def _check_jaxlib_version():
if version < _minimum_jaxlib_version:
msg = (f'jaxlib is version {jaxlib_version.__version__}, '
f'but this version of jax requires version {_minimum_jaxlib_version_str}.')
if version == (0, 1, 23):
msg += ('\n\nA common cause of this error is that you installed jaxlib '
'using pip, but your version of pip is too old to support '
'manylinux2010 wheels. Try running:\n\n'
'pip install --upgrade pip\n'
'pip install --upgrade jax jaxlib\n')
raise ValueError(msg)
_check_jaxlib_version()
from jaxlib import cpu_feature_guard
cpu_feature_guard.check_cpu_features()
from jaxlib import xla_client
from jaxlib import lapack
from jaxlib import pocketfft
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
try:
from jaxlib import cusolver # pytype: disable=import-error
except ImportError:
cusolver = None
try:
from jaxlib import cusparse # pytype: disable=import-error
except ImportError:
cusparse = None
try:
from jaxlib import rocsolver # pytype: disable=import-error
except ImportError:
rocsolver = None
try:
from jaxlib import cuda_prng # pytype: disable=import-error
except ImportError:
cuda_prng = None
try:
from jaxlib import cuda_linalg # pytype: disable=import-error
except ImportError:
cuda_linalg = None
# Jaxlib code is split between the Jax and the Tensorflow repositories.
# Only for the internal usage of the JAX developers, we expose a version
# number that can be used to perform changes without breaking the main
# branch on the Jax github.
_xla_extension_version = getattr(xla_client, '_version', 0)
try:
from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error
except:
tpu_driver_client = None # type: ignore
cuda_path: Optional[str]
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
if not os.path.isdir(cuda_path):
cuda_path = None
# flake8: noqa: F401
from jax._src.lib import (
xla_client as xla_client,
xla_extension as xla_extension,
)
from . import xla_bridge as xla_bridge

View File

@ -12,583 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface and utility functions to XLA.
This module wraps the XLA client(s) and builders to standardize their interfaces
and provide some automatic type mapping logic for converting between Numpy and
XLA. There are also a handful of related casting utilities.
"""
from functools import partial, lru_cache
import os
from typing import Callable, Dict, List, Optional, Tuple, Union
import warnings
from absl import logging
# Disable "WARNING: Logging before flag parsing goes to stderr." message
logging._warn_preinit_stderr = 0
import jax.lib
from .._src.config import flags, bool_env
from . import tpu_driver_client
from . import xla_client
from jax._src import util, traceback_util
from jax._src import dtypes
import numpy as np
import threading
traceback_util.register_exclusion(__file__)
xops = xla_client.ops
FLAGS = flags.FLAGS
# TODO(phawkins): Remove jax_xla_backend.
flags.DEFINE_string(
'jax_xla_backend', '',
'jax_xla_backend is an alias for jax_platform_name. If both are '
'provided, --jax_xla_backend takes priority. Prefer --jax_platform_name.')
flags.DEFINE_string(
'jax_backend_target', '',
'Either "local" or "rpc:address" to connect to a remote service target.')
flags.DEFINE_string(
'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Platform name for XLA. The default is to attempt to use a GPU or TPU if '
'available, but fall back to CPU otherwise. To set the platform manually, '
'pass "cpu" for CPU, "gpu" for GPU, etc. If intending to use CPU, '
'setting the platform name to "cpu" can silence warnings that appear with '
'the default setting.')
flags.DEFINE_bool(
'jax_disable_most_optimizations',
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
def get_compile_options(
num_replicas: int,
num_partitions: int,
device_assignment=None,
use_spmd_partitioning: bool = True,
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional tuple of integers indicating the assignment of
logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
compile_options.num_partitions = num_partitions
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
if device_assignment is not None:
logging.vlog(
2,
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
num_replicas, num_partitions, device_assignment)
device_assignment = np.array(device_assignment)
# Allow 1D device assignment if num_partitions is 1.
if (device_assignment.ndim == 1) and (num_partitions == 1):
device_assignment = device_assignment[:, None]
if num_replicas != device_assignment.shape[0]:
msg = 'device_assignment does not match num_replicas: {} vs {}.'
raise ValueError(msg.format(device_assignment, num_replicas))
if num_partitions != device_assignment.shape[1]:
msg = 'device_assignment does not match num_partitions: {} vs {}.'
raise ValueError(msg.format(device_assignment, num_partitions))
device_assignment = xla_client.DeviceAssignment.create(device_assignment)
assert device_assignment.replica_count() == num_replicas
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment
debug_options = compile_options.executable_build_options.debug_options
if jax.lib.cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = jax.lib.cuda_path
if FLAGS.jax_disable_most_optimizations:
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
return compile_options
# Backends
def _make_tpu_driver_client():
if tpu_driver_client is None:
logging.info("Remote TPU is not linked into jax; skipping remote TPU.")
return None
if FLAGS.jax_backend_target is None:
logging.info("No --jax_backend_target was provided; skipping remote TPU.")
return None
return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target)
def tpu_client_timer_callback(timer_secs: float):
def _log_warning():
warnings.warn(
f'TPU backend initialization is taking more than {timer_secs} seconds. '
'Did you run your code on all TPU hosts? '
'See https://jax.readthedocs.io/en/latest/multi_process.html '
'for more information.')
# Will log a warning after `timer_secs`.
t = threading.Timer(timer_secs, _log_warning)
t.start()
try:
client = xla_client.make_tpu_client()
finally:
t.cancel()
return client
# Backends, in increasing order of preference.
# We have no particular opinion about how "backends" relate to "devices". For
# example, there could be multiple backends that provide the same kind of
# device.
_backend_factories = {}
def register_backend_factory(name, factory, *, priority=0):
_backend_factories[name] = (factory, priority)
register_backend_factory('interpreter', xla_client.make_interpreter_client,
priority=-100)
register_backend_factory('cpu',
partial(xla_client.make_cpu_client, use_tfrt=True),
priority=0)
register_backend_factory('tpu_driver', _make_tpu_driver_client,
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory(
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
_default_backend = None
_backends = None
_backends_errors = None
_backend_lock = threading.Lock()
def backends():
global _backends
global _backends_errors
global _default_backend
with _backend_lock:
if _backends is not None:
return _backends
default_priority = -1000
_backends = {}
_backends_errors = {}
for name, (factory, priority) in _backend_factories.items():
logging.vlog(1, "Initializing backend '%s'" % name)
try:
backend = factory()
if backend is not None:
if backend.device_count() > 0:
_backends[name] = backend
util.distributed_debug_log(("Initialized backend", backend.platform),
("process_index", backend.process_index()),
("device_count", backend.device_count()),
("local_devices", backend.local_devices()))
logging.vlog(1, "Backend '%s' initialized" % name)
if priority > default_priority:
_default_backend = backend
default_priority = priority
except Exception as err:
if name in ('cpu', 'interpreter'):
# We always expect the CPU and interpreter backends to initialize
# successfully.
raise
else:
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
logging.info("Unable to initialize backend '%s': %s" % (name, err))
_backends_errors[name] = str(err)
continue
if _default_backend.platform == "cpu" and FLAGS.jax_platform_name != 'cpu':
logging.warning('No GPU/TPU found, falling back to CPU. '
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
return _backends
def _get_backend_uncached(platform=None):
# TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
# 'backend' values are handled
if not isinstance(platform, (type(None), str)):
return platform
bs = backends()
platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
or None)
if platform is not None:
backend = bs.get(platform, None)
if backend is None:
if platform in _backends_errors:
raise RuntimeError(f"Requested backend {platform}, but it failed "
f"to initialize: {_backends_errors[platform]}")
raise RuntimeError(f"Unknown backend {platform}")
return backend
else:
return _default_backend
@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
def get_backend(platform=None):
return _get_backend_uncached(platform)
def get_device_backend(device=None):
"""Returns the Backend associated with `device`, or the default Backend."""
if device is not None:
return device.client
return get_backend()
def device_count(backend: Optional[str] = None) -> int:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`jax.local_device_count`.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
Number of devices.
"""
return int(get_backend(backend).device_count())
def local_device_count(backend: Optional[str] = None) -> int:
"""Returns the number of devices addressable by this process."""
return int(get_backend(backend).local_device_count())
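# Illustrative values (assumed setups, not asserted by this module): on a
# single-process machine with 8 GPUs, device_count() and local_device_count()
# both return 8; on a 4-process TPU job with 8 chips per process,
# device_count() returns 32 while local_device_count() returns 8.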
def devices(backend: Optional[str] = None) -> List[xla_client.Device]:
"""Returns a list of all devices for a given backend.
Each device is represented by a subclass of :class:`Device` (e.g.
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by
comparing :meth:`Device.process_index` to the value returned by
:py:func:`jax.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
otherwise ``'cpu'``.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of Device subclasses.
"""
return get_backend(backend).devices()
def default_backend() -> str:
"""Returns the platform name of the default XLA backend."""
return get_backend(None).platform
def local_devices(process_index: Optional[int] = None,
backend: Optional[str] = None,
host_id: Optional[int] = None) -> List[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
process_index: the integer index of the process. Process indices can be
retrieved via ``range(jax.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
List of Device subclasses.
"""
if host_id is not None:
warnings.warn(
"The argument to jax.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code.")
process_index = host_id
if process_index is None:
process_index = get_backend(backend).process_index()
if not (0 <= process_index < process_count()):
raise ValueError(f"Unknown process_index {process_index}")
return [d for d in devices(backend) if d.process_index == process_index]
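# Illustrative usage (assuming the requested backend initialized successfully):
#   all_devs = devices()        # every device known to the default backend
#   mine = local_devices()      # only the devices owned by this process
#   assert all(d.process_index == process_index() for d in mine)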
def process_index(backend: Optional[str] = None) -> int:
"""Returns the integer process index of this process.
On most platforms, this will always be 0; on multi-process platforms each
process is assigned a distinct index.
Args:
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
Returns:
Integer process index.
"""
return get_backend(backend).process_index()
# TODO: remove this sometime after jax 0.2.13 is released
def host_id(backend=None):
warnings.warn(
"jax.host_id has been renamed to jax.process_index. This alias "
"will eventually be removed; please update your code.")
return process_index(backend)
def process_count(backend: Optional[str] = None) -> int:
"""Returns the number of JAX processes associated with the backend."""
return max(d.process_index for d in devices(backend)) + 1
# TODO: remove this sometime after jax 0.2.13 is released
def host_count(backend=None):
warnings.warn(
"jax.host_count has been renamed to jax.process_count. This alias "
"will eventually be removed; please update your code.")
return process_count(backend)
# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(backend=None):
warnings.warn(
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
"instead. jax.host_ids will eventually be removed; please update your "
"code.")
return list(range(process_count(backend)))
### utility functions
@util.memoize
def dtype_to_etype(dtype):
"""Convert from dtype to canonical etype (reading config.x64_enabled)."""
return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))
@util.memoize
def supported_numpy_dtypes():
return {dtypes.canonicalize_dtype(dtype)
for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}
# TODO(mattjj,frostig): try to remove this function
def normalize_to_xla_dtypes(val):
"""Normalize dtypes in a value."""
if hasattr(val, '__array__') or np.isscalar(val):
return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
elif isinstance(val, (tuple, list)):
return tuple(normalize_to_xla_dtypes(x) for x in val)
raise TypeError('Can\'t convert to XLA: {}'.format(val))
def _numpy_array_constant(builder, value, canonicalize_types=True):
if canonicalize_types:
value = normalize_to_xla_dtypes(value)
return [xops.ConstantLiteral(builder, value)]
def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
if replicated is None:
replicated = []
elif isinstance(replicated, bool):
replicated = [replicated] * shape.leaf_count()
return xops.Parameter(builder, num,
shape.with_major_to_minor_layout_if_absent(), name,
replicated)
def constant_general(builder, py_val, canonicalize_types=True):
"""Translate a general constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a list of XLA ops.
"""
for t in type(py_val).mro():
handler = _constant_handlers.get(t)
if handler: return handler(builder, py_val, canonicalize_types)
if hasattr(py_val, '__jax_array__'):
return constant(builder, py_val.__jax_array__(), canonicalize_types)
raise TypeError("No constant handler for type: {}".format(type(py_val)))
def constant(builder, py_val, canonicalize_types=True):
"""Translate constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a single XlaOp.
"""
const = constant_general(builder, py_val, canonicalize_types=canonicalize_types)
assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
return const[0]
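# Illustrative dispatch (example values only): constant(builder, np.float32(1.0))
# resolves to the np.float32 scalar handler registered below, while
# constant(builder, 3) uses the Python-scalar handlers derived from
# dtypes.python_scalar_dtypes.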
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# _sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension is divided (i.e. the total
# number of shards is the product), or None to indicate the array should be
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
None,
Tuple[Union[Tuple[int, ...], None], ...]]
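# Illustrative SpatialSharding values (examples only): a rank-2 array split
# 2 ways along its first dimension and left whole along the second is (2, 1);
# a replicated array is None; a two-element tuple output combining the two is
# ((2, 1), None).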
def _sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.
See
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
"""
proto = xla_client.OpSharding()
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
assert all(s is None or isinstance(s, tuple) for s in sharding)
return tuple_sharding_proto(list(map(_sharding_to_proto, sharding))) # type: ignore
if sharding is None:
proto.type = xla_client.OpSharding.Type.REPLICATED
else:
proto.type = xla_client.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding)
proto.tile_assignment_devices = list(range(np.product(sharding)))
return proto
def tuple_sharding_proto(elems):
proto = xla_client.OpSharding()
assert all(isinstance(e, type(proto)) for e in elems)
proto.type = xla_client.OpSharding.Type.TUPLE
proto.tuple_shardings = elems
return proto
def set_sharding_proto(builder, op, sharding_proto):
"""Uses CustomCall to annotate a value as sharded."""
# "Sharding" is a built-in custom call target that acts like an identity
# function; it exists so that an OpSharding annotation can be attached to it.
return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
builder, b"Sharding", [op], builder.get_shape(op))
def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()
def set_sharding(builder, op, sharding: SpatialSharding):
"""Uses CustomCall to annotate a value as sharded."""
return set_sharding_proto(builder, op, _sharding_to_proto(sharding))
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
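# Illustrative use (operand names are hypothetical): annotating the result of
# an Add as split 2 ways along its leading dimension could look like
#   with_sharding(builder, (2, 1), xops.Add, lhs, rhs)
# which sets the builder's sharding just for that op and then clears it.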
def make_computation_builder(name):
return xla_client.XlaBuilder(name)
def register_constant_handler(type_, handler_fun):
_constant_handlers[type_] = handler_fun
_constant_handlers: Dict[type, Callable] = {}
def _ndarray_constant_handler(c, val, canonicalize_types=True):
"""Constant handler for ndarray literals, handling zero-size strides.
This function essentially calls _numpy_array_constant(val) except it has
special handling of arrays with any strides of size zero: for those, it
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
to avoid staging in large literals that might arise from np.zeros or np.ones
or the output of lax.broadcast (which uses np.broadcast_to which in turn
uses size-zero strides).
Args:
c: an XlaBuilder
val: an ndarray.
Returns:
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
staged into the XLA Computation.
"""
# TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
if dtypes.result_type(val) == dtypes.float0:
return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
elif np.any(np.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = np.where(np.equal(0, val.strides))
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
for ax in range(val.ndim))]
xla_val = xops.Broadcast(
_numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
np.take(val.shape, zero_stride_axes))
permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
return [xops.Transpose(xla_val, permutation)]
else:
return _numpy_array_constant(c, val, canonicalize_types)
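# Illustrative zero-stride case (assuming a 64-bit integer dtype):
# np.broadcast_to(np.arange(3), (4, 3)) has strides (0, 8), so the handler
# stages only the collapsed length-3 row as a literal and rebuilds the (4, 3)
# array with Broadcast/Transpose instead of embedding all 12 elements.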
register_constant_handler(np.ndarray, _ndarray_constant_handler)
def _scalar_constant_handler(c, val, canonicalize_types=True):
return _numpy_array_constant(c, val, canonicalize_types)
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64,
np.bool_, np.longlong,
xla_client.bfloat16]:
register_constant_handler(scalar_type, _scalar_constant_handler)
# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
if hasattr(np, "float128"):
register_constant_handler(np.float128, _scalar_constant_handler)
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
return _numpy_array_constant(c, dtype.type(val))
for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
# flake8: noqa: F401
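# Shim module: the implementations now live in jax._src.lib.xla_bridge; this
# file re-exports the names that external code appears to use via jax.lib.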
from jax._src.lib.xla_bridge import (
constant as constant,
default_backend as default_backend,
device_count as device_count,
get_backend as get_backend,
get_compile_options as get_compile_options,
local_device_count as local_device_count,
process_index as process_index,
register_constant_handler as register_constant_handler,
xla_client as xla_client,
_backends as _backends,
_python_scalar_handler as _python_scalar_handler,
)

View File

@ -37,7 +37,7 @@ from . import lax
from ._src.config import flags, bool_env, config
from ._src.util import prod, unzip2
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from jax._src.lib import xla_bridge
from .interpreters import xla
from .experimental.maps import mesh

View File

@ -72,7 +72,7 @@ from absl import app
from absl import flags
import jax
import jax.numpy as jnp
from jax.lib import xla_client
from jax._src.lib import xla_client
FLAGS = flags.FLAGS

View File

@ -17,7 +17,7 @@ cusparse wrappers for performing sparse matrix computations in JAX
import numpy as np
from jax.lib import xla_client
from jax._src.lib import xla_client
try:
from . import _cusparse

View File

@ -47,7 +47,8 @@ from jax.interpreters import ad
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax.lib import xla_bridge as xb
import jax._src.lib
from jax._src.lib import xla_client
from jax import test_util as jtu
from jax import tree_util
from jax import linear_util as lu
@ -199,7 +200,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
assert len(side) == 3
def test_jit_device(self):
device = xb.devices()[-1]
device = jax.devices()[-1]
x = self.jit(lambda x: x, device=device)(3.)
self.assertIsInstance(x, xla.DeviceArray)
self.assertEqual(x.device_buffer.device(), device)
@ -663,7 +664,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
np.testing.assert_allclose(f_pruned(*args), 3)
self.assertEqual(count[0], 1)
@unittest.skipIf(jax.lib._xla_extension_version <= 36,
@unittest.skipIf(jax._src.lib._xla_extension_version <= 36,
"Test requires jaxlib 0.1.71")
def testBuffersAreFreedPromptly(self):
# Regression test for a bug where garbage collection was delayed too long
@ -1757,7 +1758,7 @@ class APITest(jtu.JaxTestCase):
param_shapes = c.program_shape().parameter_shapes()
self.assertEqual(len(param_shapes), 1)
self.assertEqual(param_shapes[0].xla_element_type(),
xb.xla_client.PrimitiveType.TUPLE)
xla_client.PrimitiveType.TUPLE)
def test_xla_computation_duck_typing(self):
def foo(x, y, z):
@ -1774,7 +1775,7 @@ class APITest(jtu.JaxTestCase):
param_shapes = c.program_shape().parameter_shapes()
self.assertEqual(len(param_shapes), 1)
self.assertEqual(param_shapes[0].xla_element_type(),
xb.xla_client.PrimitiveType.TUPLE)
xla_client.PrimitiveType.TUPLE)
def test_staging_out_multi_replica(self):
def f(x):

View File

@ -19,7 +19,7 @@ from absl.testing import absltest, parameterized
import jax
from jax.config import config
import jax.dlpack
from jax.lib import xla_bridge, xla_client
from jax._src.lib import xla_bridge, xla_client
import jax.numpy as jnp
from jax import test_util as jtu

View File

@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from functools import partial
import hashlib
import os
import random
import tempfile
import unittest
from unittest import SkipTest
from absl.testing import absltest
from jax.experimental import PartitionSpec as P
from jax.experimental.compilation_cache import compilation_cache as cc
from jax.experimental.maps import xmap
@ -23,12 +29,8 @@ import jax
from jax import jit, lax, pmap
from jax._src.util import prod
import jax.test_util as jtu
import jax._src.lib
import numpy as np
import os
import random
import tempfile
import unittest
from unittest import SkipTest
from jax.config import config
config.parse_flags_with_absl()
@ -40,16 +42,16 @@ class CompilationCacheTest(jtu.JaxTestCase):
super().setUp()
if jtu.device_under_test() != "tpu":
raise SkipTest("serialize executable only works on TPU")
if jax.lib.xla_bridge.get_backend().runtime_type == "tfrt":
if jax._src.lib.xla_bridge.get_backend().runtime_type == "tfrt":
raise SkipTest("the new TFRT runtime does not support serialization")
def tearDown(self):
super().tearDown()
cc._cache = None
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
@unittest.skipIf(jax._src.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
def test_compile_options(self):
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
compile_options_not_filled = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
compile_options_filled = self.filled_compile_options()
filled_hash1 = self.get_hashed_value(cc._hash_compile_options, compile_options_filled)
@ -59,7 +61,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(filled_hash1, not_filled_hash3)
def test_executable_build_options(self):
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
compile_options_not_filled = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
compile_options_filled = self.filled_compile_options()
filled_hash1 = self.get_hashed_value(cc._hash_executable_build_options,
@ -72,7 +74,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(filled_hash1, not_filled_hash3)
def test_debug_options(self):
compile_options = jax.lib.xla_bridge.get_compile_options(
compile_options = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
hash1 = self.get_hashed_value(cc._hash_debug_options,
compile_options.executable_build_options.debug_options)
@ -84,11 +86,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
self.assertNotEqual(hash1, hash3)
def test_hash_platform(self):
hash1 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
hash2 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
hash1 = self.get_hashed_value(cc._hash_platform, jax._src.lib.xla_bridge.get_backend())
hash2 = self.get_hashed_value(cc._hash_platform, jax._src.lib.xla_bridge.get_backend())
self.assertEqual(hash1, hash2)
if jax.lib.xla_bridge.get_backend().platform != "cpu":
cpu_backend = jax.lib.xla_bridge.get_backend("cpu")
if jax._src.lib.xla_bridge.get_backend().platform != "cpu":
cpu_backend = jax._src.lib.xla_bridge.get_backend("cpu")
hash3 = self.get_hashed_value(cc._hash_platform, cpu_backend)
self.assertNotEqual(hash1, hash3)
@ -115,27 +117,27 @@ class CompilationCacheTest(jtu.JaxTestCase):
def test_same_hash_key(self):
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
compile_options = jax.lib.xla_bridge.get_compile_options(
compile_options = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
self.assertEqual(cc.get_cache_key(computation, compile_options, backend),
cc.get_cache_key(computation, compile_options, backend))
def test_different_hash_key(self):
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
compile_options_not_filled = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
compile_options_filled = self.filled_compile_options()
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled, backend),
cc.get_cache_key(computation, compile_options_filled, backend))
def test_different_computations(self):
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
compile_options = jax.lib.xla_bridge.get_compile_options(
compile_options = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
self.assertNotEqual(cc.get_cache_key(computation1, compile_options, backend),
cc.get_cache_key(computation2, compile_options, backend))
@ -143,9 +145,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
compile_options = jax.lib.xla_bridge.get_compile_options(
compile_options = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
self.assertEqual(cc.get_executable(computation, compile_options, backend), None)
def test_diff_executables(self):
@ -153,9 +155,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
compile_options = jax.lib.xla_bridge.get_compile_options(
compile_options = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
executable1 = backend.compile(computation1, compile_options)
executable2 = backend.compile(computation2, compile_options)
cc.put_executable(computation1, compile_options, executable1, backend)
@ -167,15 +169,15 @@ class CompilationCacheTest(jtu.JaxTestCase):
with tempfile.TemporaryDirectory() as tmpdir:
cc.initialize_cache(tmpdir)
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
compile_options = jax.lib.xla_bridge.get_compile_options(
compile_options = jax._src.lib.xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = jax.lib.xla_bridge.get_backend()
backend = jax._src.lib.xla_bridge.get_backend()
executable = backend.compile(computation, compile_options)
cc.put_executable(computation, compile_options, executable, backend)
deserialized_executable = cc.get_executable(computation, compile_options, backend)
inputs_to_executable = (np.array(1, dtype=np.int32), np.array(2, dtype=np.int32))
expected = jax.lib.xla_client.execute_with_python_values(executable, inputs_to_executable, backend)
actual = jax.lib.xla_client.execute_with_python_values(deserialized_executable, inputs_to_executable, backend)
expected = jax._src.lib.xla_client.execute_with_python_values(executable, inputs_to_executable, backend)
actual = jax._src.lib.xla_client.execute_with_python_values(deserialized_executable, inputs_to_executable, backend)
self.assertEqual(expected, actual)
def test_pmap(self):
@ -257,15 +259,15 @@ class CompilationCacheTest(jtu.JaxTestCase):
return debug_options_obj
def filled_compile_options(self):
compile_options = jax.lib.xla_client.CompileOptions()
compile_options = jax._src.lib.xla_client.CompileOptions()
compile_options.num_replicas = 1
compile_options.num_partitions = 1
shape = jax.lib.xla_client.Shape.array_shape(np.dtype(np.float32), [2])
shape = jax._src.lib.xla_client.Shape.array_shape(np.dtype(np.float32), [2])
shape_array = [shape, shape]
compile_options.argument_layouts = shape_array
compile_options.executable_build_options.result_layout = shape
device_assignment = jax.lib.xla_client.DeviceAssignment.create(np.ndarray(shape=(2,2)))
device_assignment = jax._src.lib.xla_client.DeviceAssignment.create(np.ndarray(shape=(2,2)))
compile_options.device_assignment = device_assignment
compile_options.executable_build_options.device_assignment = device_assignment
return compile_options

View File

@ -20,7 +20,7 @@ from jax import test_util as jtu
import jax.numpy as jnp
from jax import core, jit, lax, make_jaxpr
from jax.interpreters import xla
from jax.lib import xla_bridge, xla_client
from jax._src.lib import xla_bridge, xla_client
xops = xla_client.ops
xb = xla_bridge

View File

@ -24,6 +24,7 @@ from jax._src import api
from jax import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit
import jax._src.lib
from jax.config import config
config.parse_flags_with_absl()
@ -97,7 +98,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
def testPmap(self):
pmap_funcs = [api._python_pmap]
if jax.lib._xla_extension_version >= 36:
if jax._src.lib._xla_extension_version >= 36:
pmap_funcs.append(api._cpp_pmap)
for pmap in pmap_funcs:

View File

@ -16,6 +16,7 @@ import unittest
from absl.testing import absltest
import jax
import jax._src.lib.xla_bridge
from jax.config import config
import jax.test_util as jtu
@ -28,7 +29,7 @@ class HeapProfilerTest(unittest.TestCase):
# not check functional correctness.
def testBasics(self):
client = jax.lib.xla_bridge.get_backend()
client = jax._src.lib.xla_bridge.get_backend()
_ = client.heap_profile()
a = jax.device_put(1)

View File

@ -37,7 +37,7 @@ from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax import tree_util
from jax.lib import xla_bridge
from jax._src.lib import xla_bridge
import numpy as np
@ -162,7 +162,7 @@ def helper_set_hlo_dump():
def helper_print_optimized_hlo(fun, *args):
backend = jax.lib.xla_bridge.get_backend()
backend = xla_bridge.get_backend()
c = jax.xla_computation(fun)(*args)
print(re.sub(r", metadata.*", "",
backend.compile(c).hlo_modules()[0].to_string()))
@ -177,13 +177,13 @@ def helper_log_ir(name,
jax_comp = jax.xla_computation(f_jax)(*args)
print(f"HLO[{name}]: {jax_comp.as_hlo_text()}")
backend = jax.lib.xla_bridge.get_backend()
backend = xla_bridge.get_backend()
if num_partitions is not None:
num_replicas = 1
device_assignment = np.arange(num_partitions * num_replicas)
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
use_spmd_partitioning = num_partitions > 1
compile_options = jax.lib.xla_bridge.get_compile_options(
compile_options = xla_bridge.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,

View File

@ -20,7 +20,7 @@ import jax
from jax import lax, numpy as jnp
from jax import config
from jax.experimental import host_callback as hcb
from jax.lib import xla_client
from jax._src.lib import xla_client
import jax.test_util as jtu
import numpy as np

View File

@ -19,7 +19,7 @@ from absl.testing import parameterized
import jax
from jax._src import api
from jax import dtypes
from jax import lib as jaxlib
from jax._src import lib as jaxlib
from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config

View File

@ -13,7 +13,7 @@
# limitations under the License.
from absl.testing import absltest
from jax.lib import xla_client
from jax._src.lib import xla_client
import jax.numpy as jnp
from jax.tools.jax_to_hlo import jax_to_hlo
from jax import test_util as jtu

View File

@ -36,7 +36,6 @@ from jax import test_util as jtu
from jax import tree_util
from jax._src.util import unzip2
from jax.experimental import maps
from jax.lib import xla_bridge
from jax.interpreters import xla
import jax.numpy as jnp # scan tests use numpy
import jax.scipy as jsp
@ -1744,7 +1743,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash
def testIssue804(self):
num_devices = xla_bridge.device_count()
num_devices = jax.device_count()
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash

View File

@ -27,7 +27,7 @@ import jax
from jax import dtypes
from jax import lax
from jax import test_util as jtu
from jax.lib import xla_client
from jax._src.lib import xla_client
from jax._src.util import safe_map, safe_zip
from lax_test import LAX_OPS

View File

@ -25,7 +25,6 @@ from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.lib
from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp

View File

@ -22,7 +22,7 @@ import jax
import jax.numpy as jnp
from jax import lax
from jax import test_util as jtu
from jax.lib import xla_bridge
from jax._src.lib import xla_bridge
from jax.interpreters import xla
from jax.config import config

View File

@ -27,6 +27,7 @@ import jax
from jax import numpy as jnp
from jax.config import config
from jax import test_util as jtu
import jax._src.lib
config.parse_flags_with_absl()
@ -34,7 +35,7 @@ config.parse_flags_with_absl()
class CloudpickleTest(jtu.JaxTestCase):
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@unittest.skipIf(jax.lib._xla_extension_version < 31,
@unittest.skipIf(jax._src.lib._xla_extension_version < 31,
"Requires jaxlib 0.1.71")
def testPickleOfJittedFunctions(self):
@ -55,7 +56,7 @@ class CloudpickleTest(jtu.JaxTestCase):
self.assertEqual(expected, actual)
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@unittest.skipIf(jax.lib._xla_extension_version < 39,
@unittest.skipIf(jax._src.lib._xla_extension_version < 39,
"Requires jaxlib 0.1.72")
def testPickleOfPmappedFunctions(self):

View File

@ -33,7 +33,7 @@ from jax.experimental.maps import xmap, mesh
from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.lib import xla_client
from jax._src.lib import xla_client
from jax._src.util import prod, curry
from jax.config import config

View File

@ -39,7 +39,8 @@ from jax import random
from jax.core import ShapedArray
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax.lib import xla_bridge
import jax._src.lib
from jax._src.lib import xla_bridge
from jax._src.util import prod, safe_map
from jax.interpreters import pxla
from jax.interpreters import xla
@ -113,7 +114,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return src_api._python_pmap
def _getMeshShape(self, device_mesh_shape):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if any(size == -1 for size in device_mesh_shape):
try:
return np.arange(device_count).reshape(device_mesh_shape).shape
@ -130,7 +131,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testBasic(self):
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
@ -140,7 +141,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testMean(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.broadcast_to(np.mean(x, 0), x.shape)
@ -150,16 +151,16 @@ class PythonPmapTest(jtu.JaxTestCase):
def testGather(self):
f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x] * xla_bridge.device_count())
expected = np.array([x] * jax.device_count())
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherTiled(self):
f = self.pmap(lambda x: lax.all_gather(x, 'i', tiled=True), axis_name='i')
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x] * device_count).reshape(device_count, -1)
@ -169,7 +170,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testReduceScatter(self):
f = self.pmap(lambda x: lax.psum_scatter(x, 'i'), axis_name='i')
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.sum(x, axis=0)
@ -180,7 +181,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testReduceScatterTiled(self):
f = self.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, 4 * device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.sum(x, axis=0)
@ -191,7 +192,7 @@ class PythonPmapTest(jtu.JaxTestCase):
expected[i * scatter_len:(i + 1) * scatter_len])
def testReduceScatterReplicaGroupsTiled(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest
axis_index_groups = [[i for i in range(jax.device_count()) if i % 2 == 0],
@ -227,7 +228,7 @@ class PythonPmapTest(jtu.JaxTestCase):
np_transpose = tree_f(np.transpose)
np_rotate = tree_f(lambda x: np.concatenate([x[-1:], x[:-1]]))
n = xla_bridge.device_count()
n = jax.device_count()
x = {'a': np.arange(1 * n * n, 2 * n * n).reshape([n, n]),
'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]),
'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])}
@ -260,7 +261,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testComplexPsum(self):
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4 * 2)
shape = (jax.device_count(), 4 * 2)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape).view(np.complex64)
expected = x - np.sum(x, 0)
@ -274,7 +275,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testAllToAll(self, split_axis, concat_axis):
pmap_in_axis = 0
shape = (xla_bridge.device_count(),) * 3
shape = (jax.device_count(),) * 3
x = np.arange(np.prod(shape)).reshape(shape)
@partial(self.pmap, axis_name='i')
@ -293,7 +294,7 @@ class PythonPmapTest(jtu.JaxTestCase):
for split_axis, concat_axis in it.product(range(2), range(2)))
@ignore_slow_all_to_all_warning()
def testAllToAllSplitAxis(self, split_axis, concat_axis):
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
pmap_in_axis = 0
shape = (4, 4, 4)
@ -322,7 +323,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def sum_and_broadcast(x, axis):
return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis)
shape = (xla_bridge.device_count(), 1, 4)
shape = (jax.device_count(), 1, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
@ -330,7 +331,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testMismatchedAxisSizes(self):
n = xla_bridge.device_count()
n = jax.device_count()
f = self.pmap(lambda x, y: x + y)
self.assertRaisesRegex(
ValueError,
@ -359,7 +360,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x, y: x, in_axes=(None, 0))
g = self.pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0))
mesh_shape = (xla_bridge.device_count(),)
mesh_shape = (jax.device_count(),)
shape = mesh_shape + (4,)
x = np.array(3., dtype=np.float32)
y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
@ -383,7 +384,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testReplicate(self):
base = np.array([3.,4.], dtype=np.float32)
num_devices = xla_bridge.device_count()
num_devices = jax.device_count()
replicated = pxla.replicate(base, num_devices, num_devices, in_axis=None)
self.assertAllClose(base, replicated)
self.assertEmpty([a for a in replicated.sharding_spec.mesh_mapping
@ -415,7 +416,7 @@ class PythonPmapTest(jtu.JaxTestCase):
_, jvp = linearize(f, x)
return jvp(jnp.ones_like(x))
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.cos(x)
@ -429,7 +430,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def f(x):
return jnp.sin(x)
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
@ -456,7 +457,7 @@ class PythonPmapTest(jtu.JaxTestCase):
fun = lambda x: jnp.sum(jvp(jnp.sin, (x,), (jnp.ones_like(x),))[1])
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(splitjvp(x)))(x)
@ -472,7 +473,7 @@ class PythonPmapTest(jtu.JaxTestCase):
tot = jnp.sum(5. * jnp.cos(x) * jnp.sin(y))
return tot * jnp.ones_like(x) # broadcast to map like pjit does
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = 4 + x
ans = grad(lambda x, y: jnp.sum(g(x, y)))(x, y)
@ -517,7 +518,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
# test that we can pass in and out ShardedDeviceArrays
@ -560,7 +561,7 @@ class PythonPmapTest(jtu.JaxTestCase):
[(1,1), (1,)], [(1,), (1,1)], [(1,), ()], [(4,7), (2,2,7)]
])
def testShardedDeviceArrayReshape(self, in_shape, out_shape):
if xla_bridge.device_count() < max(in_shape[:1] + out_shape[:1]):
if jax.device_count() < max(in_shape[:1] + out_shape[:1]):
raise SkipTest("not enough devices")
x = np.arange(prod(in_shape)).reshape(in_shape)
@ -575,7 +576,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def sum_and_broadcast(x, axis):
return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis)
device_count = xla_bridge.device_count()
device_count = jax.device_count()
num_pairs, ragged = divmod(device_count, 2)
if num_pairs > 1 and not ragged:
shape = (num_pairs, 2, 4)
@ -588,7 +589,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsumConstantReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest
axis_index_groups = np.arange(replicas).reshape(
@ -606,7 +607,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def testPsumUnevenReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas <= 2:
raise SkipTest("Test expected devices greater than 2.")
axis_index_groups = [[0,1], np.arange(2,replicas)]
@ -627,7 +628,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsumReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest
axis_index_groups = np.arange(replicas).reshape(
@ -649,7 +650,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("Test expected an even number of devices greater than 1.")
@ -674,7 +675,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherReplicaGroupsInterleaved(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("Test expected an even number of devices greater than 1.")
@ -707,7 +708,7 @@ class PythonPmapTest(jtu.JaxTestCase):
jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
def testNestedPmapReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 4 != 0:
raise SkipTest
axis_index_groups = np.arange(replicas // 2).reshape(
@ -762,7 +763,7 @@ class PythonPmapTest(jtu.JaxTestCase):
((0, 1, 2, 3, 4, 5, 6, 7,),)) # order doesn't matter
def testCollectivePermute(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
rotation = [(i, (i + 1) % device_count) for i in range(device_count)]
f = lambda x: lax.ppermute(x, perm=rotation, axis_name='i')
f = self.pmap(f, 'i')
@ -774,7 +775,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu")
def testCollectivePermuteGrad(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shift_right = [(i, (i + 1)) for i in range(device_count - 1)]
f = lambda x: lax.ppermute(x, perm=shift_right, axis_name='i')
y = np.pi + np.arange(device_count, dtype=np.float32)
@ -786,7 +787,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testCollectivePermuteCyclicGrad(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shift_right = [(i, (i + 1) % device_count) for i in range(device_count)]
f = lambda x: lax.ppermute(x, perm=shift_right, axis_name='i')
y = np.pi + np.arange(device_count, dtype=np.float32)
@ -801,7 +802,7 @@ class PythonPmapTest(jtu.JaxTestCase):
jtu.check_grads(g, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2)
def testCollectivePermuteCyclicWithPShuffle(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
values = np.arange(device_count)
shift_right = [(i - 1) % device_count for i in range(device_count)]
f = lambda x: lax.pshuffle(x, perm=shift_right, axis_name='i')
@ -810,7 +811,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPShuffleWithBadPerm(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
bad_perm = list(range(device_count))
bad_perm[0] = 1
f = lambda x: lax.pshuffle(x, perm=bad_perm, axis_name='i')
@ -821,7 +822,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testPpermuteWithZipObject(self):
# https://github.com/google/jax/issues/1703
num_devices = xla_bridge.device_count()
num_devices = jax.device_count()
perm = [num_devices - 1] + list(range(num_devices - 1))
f = self.pmap(lambda x: lax.ppermute(x, "i", zip(perm, range(num_devices))), "i")
result = f(jnp.arange(num_devices, dtype=jnp.float32))
@ -833,7 +834,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# to run a rule 30 simulation: https://en.wikipedia.org/wiki/Rule_30
# Halo exchange should be useful in spatially-sharded convolutions and in
# other simulations.
device_count = xla_bridge.device_count()
device_count = jax.device_count()
def send_right(x, axis_name):
left_perm = [(i, (i + 1) % device_count) for i in range(device_count)]
@ -900,7 +901,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testReduceMax(self):
f = self.pmap(lambda x: x - lax.pmax(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.max(x, 0)
@ -910,7 +911,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testReduceMin(self):
f = self.pmap(lambda x: x - lax.pmin(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.min(x, 0)
@ -918,7 +919,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testDeviceCountError(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: x)
x = jnp.arange(device_count + 1)
@ -933,7 +934,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
def testPmapConstant(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: 3)
x = jnp.arange(device_count)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
@ -949,10 +950,10 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPmapConstantDevices(self):
if xla_bridge.device_count() == 1:
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
devices = xla_bridge.devices()[:-1]
devices = jax.devices()[:-1]
shuffle(devices)
f = self.pmap(lambda x: 3, devices=devices)
x = jnp.arange(len(devices))
@ -966,7 +967,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual([b.device() for b in ans.device_buffers], devices)
def testPmapConstantError(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: 3)
x = jnp.arange(device_count + 1)
self.assertRaisesRegex(
@ -976,18 +977,18 @@ class PythonPmapTest(jtu.JaxTestCase):
lambda: f(x))
# TODO(mattjj): test error message with explicit devices
# f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
# f = pmap(lambda x: 3, devices=[jax.devices()[0]])
# x = jnp.arange(2)
# self.assertRaisesRegex(
# ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
# r"local devices are available.", lambda: f(x))
def testNestedPmapConstant(self):
if xla_bridge.device_count() == 1:
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, xla_bridge.device_count() // 2, 3)
shape = (2, jax.device_count() // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
@ -1009,10 +1010,10 @@ class PythonPmapTest(jtu.JaxTestCase):
def testNestedPmapConstantDevices(self):
raise SkipTest("Nested pmaps with devices not yet implemented")
if xla_bridge.device_count() < 6:
if jax.device_count() < 6:
raise SkipTest("this test requires >= 6 devices")
devices = xla_bridge.devices()[:-2]
devices = jax.devices()[:-2]
shuffle(devices)
f = self.pmap(self.pmap(lambda x: 3), devices=devices)
shape = (2, len(devices) // 2, 3)
@ -1030,7 +1031,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testNestedPmapConstantError(self):
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, xla_bridge.device_count() // 2 + 1, 3)
shape = (2, jax.device_count() // 2 + 1, 3)
x = jnp.arange(prod(shape)).reshape(shape)
self.assertRaisesRegex(
ValueError,
@ -1039,9 +1040,9 @@ class PythonPmapTest(jtu.JaxTestCase):
lambda: f(x))
# TODO(mattjj): check error message with explicit devices
# if xla_bridge.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
# shape = (2, xla_bridge.device_count() // 2, 3)
# if jax.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=jax.devices()[:-1])
# shape = (2, jax.device_count() // 2, 3)
# x = jnp.arange(prod(shape)).reshape(shape)
# self.assertRaisesRegex(
# ValueError,
@ -1050,7 +1051,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# lambda: f(x))
def testCollectiveConstant(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: lax.psum(1, 'i'), 'i')
x = jnp.arange(device_count)
ans = f(x)
@ -1058,7 +1059,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testCollectiveConstantNested(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
@partial(self.pmap, axis_name='i')
def f(x):
@ -1083,7 +1084,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(c.ravel()[0], device_count * 1)
def testAxisIndex(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: x + lax.axis_index('i'), 'i')
x = jnp.ones(device_count)
ans = f(x)
@ -1091,7 +1092,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testAxisIndexNestedPmap(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count < 4:
raise SkipTest("test requires at least four devices")
f = lambda axis: self.pmap(self.pmap(lambda x: x + lax.axis_index(axis), 'j'), 'i')
@ -1101,7 +1102,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(f('i')(x), expected_j.T, check_dtypes=False)
def testAxisIndexNd(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count < 4:
raise SkipTest("test requires at least four devices")
f = lambda axes: self.pmap(self.pmap(lambda x: x + lax.axis_index(axes), 'j'), 'i')
@ -1116,13 +1117,13 @@ class PythonPmapTest(jtu.JaxTestCase):
def body(carry, i):
return carry + i + lax.axis_index('i'), None
return lax.scan(body, 0, x)[0]
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, 10)
self.assertAllClose(f(jnp.ones(shape, dtype=int)),
(np.arange(device_count) + 1) * 10)
def testVmapOfPmap(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f0 = lambda x: x
f1 = self.pmap(f0, axis_name='i')
ax = np.random.randn(2, device_count, 50, 60)
@ -1130,7 +1131,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ax, bx, check_dtypes=False)
def testVmapOfPmap2(self):
N_DEVICES = xla_bridge.device_count()
N_DEVICES = jax.device_count()
keys = random.split(random.PRNGKey(1), 13) # [13, 2]
@self.pmap
@ -1150,7 +1151,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testVmapOfPmap3(self):
# https://github.com/google/jax/issues/3399
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count < 2:
raise SkipTest("test requires at least two devices")
@ -1172,7 +1173,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testVmapOfPmapNonLeadingAxis(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f0 = lambda x: x
f1 = self.pmap(f0, axis_name='i')
ax = np.random.randn(device_count, 2, 50, 60)
@ -1180,7 +1181,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ax, bx, check_dtypes=False)
def testVmapOfPmapTuple(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f0 = lambda *x: x
f1 = self.pmap(f0, axis_name='i')
@ -1201,7 +1202,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testPswapaxes(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, 3, device_count, 5)
x = np.arange(prod(shape)).reshape(shape)
@ -1211,7 +1212,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testGradOfPswapaxes(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, 1, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
w = np.arange(device_count, dtype=np.float32)
@ -1235,7 +1236,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# This is essentially like splitting the number of rows in the input in two
# groups of rows, and swapping the two inner axes (axis=1 and axis=2), which
# is exactly what the test case checks.
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2)
@ -1256,7 +1257,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testGradOfAllToAllReplicaGroups(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2, 1)
@ -1279,13 +1280,13 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(fn(x, w), expected, check_dtypes=False)
def testReshardInput(self):
if xla_bridge.device_count() < 6:
if jax.device_count() < 6:
raise SkipTest("testReshardInput requires 6 devices")
# Manually construct a ShardedDeviceArray with the wrong sharding for the
# subsequent pmap
shard_shape = (3,2)
shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
bufs = pxla.device_put(shard, xla_bridge.devices()[:4], replicate=True)
bufs = pxla.device_put(shard, jax.devices()[:4], replicate=True)
aval = ShapedArray((6,4), shard.dtype)
sharding_spec = pxla.ShardingSpec(
sharding=map(pxla.Chunked, ([2], [2])),
@ -1298,7 +1299,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapBatchMatmul(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
ans = soft_pmap(jnp.dot, 'i')(xs, ys)
@ -1307,7 +1308,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapBatchMatmulJit(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
ans = soft_pmap(jit(jnp.dot), 'i')(xs, ys)
@ -1316,7 +1317,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapPsumConstant(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(_):
return lax.psum(1, 'i')
ans = soft_pmap(f, 'i')(jnp.ones(n))
@ -1325,7 +1326,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapPsum(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(x):
return x / lax.psum(x, 'i')
ans = soft_pmap(f, 'i')(jnp.ones(n))
@ -1334,7 +1335,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapAxisIndex(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(x):
return x * lax.axis_index('i')
ans = soft_pmap(f, 'i')(2 * jnp.ones(n))
@ -1343,7 +1344,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapOfJit(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(x):
return 3 * x
ans = soft_pmap(jit(f), 'i')(np.arange(n))
@ -1353,7 +1354,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapNested(self):
raise SkipTest("not implemented") # TODO(mattjj): re-implement
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
@partial(soft_pmap, axis_name='i')
@partial(soft_pmap, axis_name='j')
@ -1368,7 +1369,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testGradOfSoftPmap(self):
raise SkipTest("not implemented") # TODO(mattjj): re-implement
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
@partial(soft_pmap, axis_name='i')
def f(x):
@ -1380,7 +1381,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapDevicePersistence(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (2 * 2 * device_count, 2, 3)
# check that we can maintain device persistence across calls
@ -1393,7 +1394,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testSoftPmapAllToAll(self):
raise SkipTest("the underlying code here is broken") # TODO(mattjj)
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(x):
return lax.all_to_all(x, 'i', 0, 0)
ans = soft_pmap(f, 'i')(jnp.arange(n ** 2).reshape(n, n))
@ -1401,7 +1402,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testShardedDeviceArrayBlockUntilReady(self):
x = np.arange(xla_bridge.device_count())
x = np.arange(jax.device_count())
x = self.pmap(lambda x: x)(x)
x.block_until_ready() # doesn't crash
@ -1409,7 +1410,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitPmapComposition(self):
f = lambda x: x - lax.psum(x, 'i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
@ -1435,7 +1436,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_jit_of_pmap_warning()
def testIssue1065(self):
# from https://github.com/google/jax/issues/1065
device_count = xla_bridge.device_count()
device_count = jax.device_count()
def multi_step_pmap(state, count):
@partial(self.pmap, axis_name='x')
@ -1455,7 +1456,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = f(x)
@ -1471,7 +1472,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# replica.
raise SkipTest("need eager multi-replica support")
# test came from https://github.com/google/jax/issues/1369
nrep = xla_bridge.device_count()
nrep = jax.device_count()
def pmvm(a, b):
a = a.reshape((nrep, -1, a.shape[1]))
@ -1497,7 +1498,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return sum(args_list)
vals = list(range(500))
ndevices = xla_bridge.device_count()
ndevices = jax.device_count()
self.assertAllClose(f(jnp.array([vals] * ndevices)),
jnp.array([sum(vals)] * ndevices))
@ -1577,7 +1578,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testPsumOnBooleanDtype(self):
# https://github.com/google/jax/issues/3123
n = xla_bridge.device_count()
n = jax.device_count()
if n > 1:
x = jnp.array([True, False])
@ -1608,7 +1609,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertTrue(w() is None)
def testJitOfPmapWarningMessage(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count == 1:
raise SkipTest("test requires at least two devices")
@ -1651,7 +1652,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def test_issue_1062(self):
# code from https://github.com/google/jax/issues/1062 @shoyer
# this tests, among other things, whether ShardedDeviceTuple constants work
device_count = xla_bridge.device_count()
device_count = jax.device_count()
@jit
def multi_step(state, count):
@ -1708,7 +1709,7 @@ class PythonPmapTest(jtu.JaxTestCase):
for axis in range(len(shape))
)
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
if xla_bridge.device_count() < shape[axis]:
if jax.device_count() < shape[axis]:
raise SkipTest(f"test requires at least {shape[axis]} devices")
if (jtu.device_under_test() == 'cpu' and
np.issubdtype(dtype, np.floating) and
@ -1733,7 +1734,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@partial(self.pmap, axis_name='i')
def func(_):
return jax.lax.psum(dtype(0), axis_name='i')
unused_arg = jnp.arange(xla_bridge.device_count())
unused_arg = jnp.arange(jax.device_count())
out_dtype = func(unused_arg).dtype
self.assertEqual(out_dtype, dtype)
@ -1746,7 +1747,7 @@ class PythonPmapTest(jtu.JaxTestCase):
y = lax.cond(True, jax.pmap(identity), jax.pmap(identity), x)
return y
cond_of_pmap(jnp.zeros((xla_bridge.device_count(), 2)))
cond_of_pmap(jnp.zeros((jax.device_count(), 2)))
def test_static_argnum_on_method(self):
@ -1763,7 +1764,7 @@ class CppPmapTest(PythonPmapTest):
@property
def pmap(self):
if jax.lib._xla_extension_version >= 38:
if jax._src.lib._xla_extension_version >= 38:
return src_api._cpp_pmap
else:
return src_api._python_pmap
@ -1787,7 +1788,7 @@ class VmapOfPmapTest(jtu.JaxTestCase):
)))
def testVmapOfPmap(self, shapes, vmap_in_axes, pmap_in_axes, vmap_out_axes, pmap_out_axes):
vmapped_size = 3
pmapped_size = xla_bridge.device_count()
pmapped_size = jax.device_count()
rng = jtu.rand_default(self.rng())
@ -1824,7 +1825,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return x + collective(x.dot(y), ('i', 'j'))
return f
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
y = f(jax.pmap, jax.pmap)(x, x)
@ -1842,7 +1843,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return x + jax.lax.ppermute(x.dot(y), 'i', perm)
return f
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
self.assertAllClose(f(jax.pmap)(x, x), f(jax.vmap)(x, x))
@ -1938,7 +1939,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testAllToAllMultipleAxesVsVmap(self, axes, split_axis, concat_axis):
raise SkipTest("multi-axis all_to_all broken after #4835") # TODO(mattjj,apaszke)
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
def f(x):
@ -1957,7 +1958,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return jax.lax.all_gather(x, 'i')
return f
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x))
@ -1967,19 +1968,19 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testAllDevices(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
devices=xla_bridge.devices())
shape = (xla_bridge.device_count(), 4)
devices=jax.devices())
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = f(x)
self.assertAllClose(ans, expected)
def testOneDevice(self):
if xla_bridge.device_count() == 1:
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
d0 = xla_bridge.devices()[0]
d1 = xla_bridge.devices()[1]
d0 = jax.devices()[0]
d1 = jax.devices()[1]
f = lambda x: jnp.dot(x, x.T)
f0 = pmap(f, devices=[d0])
f1 = pmap(f, devices=[d1])
@ -1992,18 +1993,18 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNoDevicesError(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[])
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
with self.assertRaisesRegex(
ValueError, "'devices' argument to pmap must be non-empty, or None."):
f(x)
def testBadAxisSizeError(self):
if xla_bridge.device_count() == 1:
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
devices=xla_bridge.devices())
devices=jax.devices())
with self.assertRaisesRegex(
ValueError, r"Leading axis size of input to pmapped function must "
r"equal the number of local devices passed to pmap. Got axis_size=1, "
@ -2014,7 +2015,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
ValueError, r"Leading axis size of input to pmapped function must "
r"equal the number of local devices passed to pmap. Got axis_size=\d, "
r"num_local_devices=\d."):
f(jnp.ones(xla_bridge.device_count() + 1))
f(jnp.ones(jax.device_count() + 1))
def testBadAxisSizeErrorNested(self):
f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')),
@ -2028,18 +2029,18 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
f(jnp.ones((1, 4)))
def testNestedPmaps(self):
if xla_bridge.device_count() % 2 != 0:
if jax.device_count() % 2 != 0:
raise SkipTest
# Devices specified in outer pmap are OK
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
@partial(pmap, axis_name='i', devices=jax.devices())
def foo(x):
@partial(pmap, axis_name='j')
def bar(y):
return lax.psum(y, 'j')
return bar(x)
x = jnp.ones((xla_bridge.device_count() // 2, 2))
x = jnp.ones((jax.device_count() // 2, 2))
ans = foo(x)
expected = x * 2
self.assertAllClose(ans, expected)
@ -2048,7 +2049,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
# Devices specified in inner pmap not OK
@partial(pmap, axis_name='i')
def foo(x):
@partial(pmap, axis_name='j', devices=xla_bridge.devices())
@partial(pmap, axis_name='j', devices=jax.devices())
def bar(y):
return lax.psum(y, 'j')
return bar(x)
@ -2056,17 +2057,17 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Nested pmap with explicit devices argument."):
foo(jnp.ones((xla_bridge.device_count(), 1)))
foo(jnp.ones((jax.device_count(), 1)))
def testJitInPmap(self):
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
@partial(pmap, axis_name='i', devices=jax.devices())
def foo(x):
@jit
def bar(y):
return y + 1
return lax.psum(bar(x), 'i')
ndevices = xla_bridge.device_count()
ndevices = jax.device_count()
ans = foo(jnp.ones((ndevices, 1)))
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices * 2
self.assertAllClose(ans, expected)
@ -2075,22 +2076,22 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testPmapInJit(self):
@jit
def foo(x):
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
@partial(pmap, axis_name='i', devices=jax.devices())
def bar(y):
return lax.psum(y, 'i')
return bar(x)
ndevices = xla_bridge.device_count()
ndevices = jax.device_count()
ans = foo(jnp.ones((ndevices, 1)))
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices
self.assertAllClose(ans, expected)
def testGradBasic(self):
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
@partial(pmap, axis_name='i', devices=jax.devices())
def f(x):
return jnp.sin(x)
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
@ -2101,7 +2102,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@partial(pmap, axis_name='i', static_broadcasted_argnums=1)
def f(x, y):
return jnp.sin(x + y())
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = lambda: 3.
@ -2113,9 +2114,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@partial(pmap, in_axes=(1, 2))
def f(x, y):
return jnp.sin(x + y)
xshape = (2, xla_bridge.device_count(), 4)
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
yshape = (2, 4, xla_bridge.device_count())
yshape = (2, 4, jax.device_count())
y = np.arange(prod(yshape)).reshape(yshape)
self.assertAllClose(f(x, y),
@ -2126,9 +2127,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
return jnp.sin(x + y + z)
fp = pmap(f, in_axes=(1, 2, None))
fv = vmap(f, in_axes=(1, 2, None))
xshape = (5, xla_bridge.device_count(), 7)
xshape = (5, jax.device_count(), 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, 7, xla_bridge.device_count())
yshape = (5, 7, jax.device_count())
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
zshape = (5, 7)
z = np.arange(prod(zshape), dtype=np.float32).reshape(zshape)
@ -2145,7 +2146,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@partial(pmap, in_axes=(1, None), out_axes=(2, None))
def f(x, y):
return jnp.sin(x + y), y * 2
xshape = (2, xla_bridge.device_count(), 4)
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
yshape = (2, 4)
y = np.arange(prod(yshape)).reshape(yshape)
@ -2158,7 +2159,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@partial(pmap, out_axes={'a': 0})
def f(x):
return {'a': x}
device_count = xla_bridge.device_count()
device_count = jax.device_count()
x = jnp.arange(device_count)
tree_util.tree_multimap(self.assertAllClose, f(x), {'a': x})
@ -2172,7 +2173,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def f(x, y, z):
return jnp.sin(x + y) * z
pmapped_size = xla_bridge.device_count()
pmapped_size = jax.device_count()
mapped_shapes = [(3, 4), (3, 1), (1, 4)]
arg_shapes = map(partial(add_bdim, pmapped_size), in_axes, mapped_shapes)
rng = jtu.rand_default(self.rng())
@ -2192,7 +2193,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
xshape = (5, 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, xla_bridge.device_count(), 7)
yshape = (5, jax.device_count(), 7)
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
self.assertAllClose(jax.grad(mk_case(pmap))(x, y),
jax.grad(mk_case(vmap))(x, y))
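
All of the pmap test updates above follow one pattern: device counts and device lists now come from the public `jax.device_count()` / `jax.devices()` API instead of the private `xla_bridge` module. A minimal sketch of that pattern, with illustrative shapes and a `psum` body not taken from any single test:

```python
import numpy as np
import jax
from jax import lax, pmap

ndevices = jax.device_count()            # public API; no jax.lib import needed
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
         devices=jax.devices())          # public device list

x = np.arange(ndevices * 4, dtype=np.float32).reshape((ndevices, 4))
np.testing.assert_allclose(f(x), x - x.sum(0))  # each row minus the cross-device sum
```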


@ -25,8 +25,8 @@ from jax import dtypes
from jax.experimental import sparse
from jax.experimental.sparse.ops import _bcoo_nse
from jax import lax
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax._src.lib import cusparse
from jax._src.lib import xla_bridge
from jax import jit
from jax import test_util as jtu
from jax import xla


@ -25,7 +25,7 @@ from jax import tree_util
from jax._src.tree_util import _process_pytree
from jax import flatten_util
import jax.numpy as jnp
import jax._src.lib
def _dummy_func(*args, **kwargs):
return
@ -206,7 +206,7 @@ class TreeTest(jtu.JaxTestCase):
def testTreedefTupleFromChildren(self):
# https://github.com/google/jax/issues/7377
    # TODO(frostig): remove after the minimum jaxlib is 0.1.70 or newer.
if jax.lib._xla_extension_version < 29:
if jax._src.lib._xla_extension_version < 29:
self.skipTest("fixed in future jaxlib")
tree = ((1, 2, (3, 4)), (5,))
leaves, treedef1 = tree_util.tree_flatten(tree)
@ -324,7 +324,7 @@ class TreeTest(jtu.JaxTestCase):
self.assertRegex(str(treedef), correct_string)
def testTreeDefWithEmptyDictStringRepresentation(self):
if jax.lib._xla_extension_version < 35:
if jax._src.lib._xla_extension_version < 35:
self.skipTest("fixed in future jaxlib")
self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")


@ -17,8 +17,8 @@ import warnings
from absl.testing import absltest
from jax import test_util as jtu
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
mock = absltest.mock
@ -79,8 +79,8 @@ class XlaBridgeTest(absltest.TestCase):
msg = str(w[-1].message)
self.assertIn("Did you run your code on all TPU hosts?", msg)
with mock.patch(
"jax.lib.xla_client.make_tpu_client", side_effect=_mock_tpu_client):
with mock.patch("jax._src.lib.xla_client.make_tpu_client",
side_effect=_mock_tpu_client):
xb.tpu_client_timer_callback(0.01)
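
The mock target here moves to `jax._src.lib.xla_client`, while the commit description also mentions shim libraries so that existing `from jax.lib import ...` statements elsewhere keep working. The shim itself is not shown in this hunk; a hypothetical `jax/lib/__init__.py` along these lines would cover the names used in these tests:

```python
# Hypothetical shim: re-export the private implementations under the old
# public path. Only names appearing in this diff are listed; the real shim
# may export more (or fewer) names.
from jax._src.lib import xla_bridge as xla_bridge
from jax._src.lib import xla_client as xla_client
from jax._src.lib import _xla_extension_version as _xla_extension_version
```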


@ -41,7 +41,7 @@ from jax.core import NamedShape, JaxprTypeError
from jax.experimental import maps
from jax.experimental.maps import Mesh, mesh, xmap, serial_loop, SerialLoop
from jax.errors import JAXTypeError
from jax.lib import xla_bridge
from jax._src.lib import xla_bridge
from jax._src.util import curry, unzip2, split_list, prod
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lax.parallel import pgather