Move contents of jax.lib to jax._src.lib.

Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863

This commit is contained in:
parent 94cd1ea0a2
commit 2c2f4033cc
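The `jax.lib` package is kept behind as a thin shim, so existing `from jax.lib import ...` statements keep working while the real code moves to `jax._src.lib`. A minimal sketch of that re-export pattern, mirroring the replacement `jax/lib/__init__.py` shown near the end of this diff:

# jax/lib/__init__.py -- backward-compatibility shim (sketch; the actual file is in
# the last hunks of this diff). All real code now lives in jax._src.lib.
# flake8: noqa: F401
from jax._src.lib import (
  xla_client as xla_client,
  xla_extension as xla_extension,
)
from . import xla_bridge as xla_bridge  # jax/lib/xla_bridge.py is itself reduced to a shim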
@@ -505,7 +505,7 @@
 "colab": {}
 },
 "source": [
-"from jax.lib import xla_bridge\n",
+"from jax._src.lib import xla_bridge\n",
 "device_count = xla_bridge.device_count()\n",
 "\n",
 "def send_right(x, axis_name):\n",
@@ -2087,8 +2087,8 @@
 },
 "outputs": [],
 "source": [
-"from jax.lib import xla_bridge as xb\n",
-"from jax.lib import xla_client as xc\n",
+"from jax._src.lib import xla_bridge as xb\n",
+"from jax._src.lib import xla_client as xc\n",
 "xe = xc._xla\n",
 "xops = xc._xla.ops\n",
 "\n",
@@ -1544,8 +1544,8 @@ class IDHashable:
 Next, we'll define the evaluation rule for `xla_call`:
 
 ```{code-cell}
-from jax.lib import xla_bridge as xb
-from jax.lib import xla_client as xc
+from jax._src.lib import xla_bridge as xb
+from jax._src.lib import xla_client as xc
 xe = xc._xla
 xops = xc._xla.ops
 
@@ -1481,8 +1481,8 @@ class IDHashable:
 # Next, we'll define the evaluation rule for `xla_call`:
 
 # +
-from jax.lib import xla_bridge as xb
-from jax.lib import xla_client as xc
+from jax._src.lib import xla_bridge as xb
+from jax._src.lib import xla_client as xc
 xe = xc._xla
 xops = xc._xla.ops
 
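For orientation (not part of the commit), here is a rough sketch of how these handles get used together when lowering by hand. Everything except `xc.Shape.array_shape`, which is assumed from jaxlib's `xla_client`, is a helper defined in the new `jax/_src/lib/xla_bridge.py` further down this diff:

import numpy as np
from jax._src.lib import xla_bridge as xb, xla_client as xc
xops = xc._xla.ops

c = xb.make_computation_builder("add_one")      # wraps xla_client.XlaBuilder(name)
x = xb.parameter(c, 0, xc.Shape.array_shape(np.dtype(np.float32), ()))
out = xops.Add(x, xb.constant(c, np.float32(1.0)))
compiled = xb.get_backend().compile(c.build(out))  # executable on the default backend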
@@ -149,7 +149,7 @@
 " def pp(v):\n",
 " \"\"\"Print certain values more succinctly\"\"\"\n",
 " vtype = str(type(v))\n",
-" if \"jax.lib.xla_bridge._JaxComputationBuilder\" in vtype:\n",
+" if \"jax._src.lib.xla_bridge._JaxComputationBuilder\" in vtype:\n",
 " return \"<JaxComputationBuilder>\"\n",
 " elif \"jaxlib.xla_extension.XlaOp\" in vtype:\n",
 " return \"<XlaOp at 0x{:x}>\".format(id(v))\n",
@@ -614,7 +614,7 @@
 },
 "outputs": [],
 "source": [
-"from jax.lib import xla_client\n",
+"from jax._src.lib import xla_client\n",
 "@trace(\"multiply_add_xla_translation\")\n",
 "def multiply_add_xla_translation(c, xc, yc, zc):\n",
 " \"\"\"The compilation to XLA of the primitive.\n",
@@ -118,7 +118,7 @@ def trace(name):
 def pp(v):
   """Print certain values more succinctly"""
   vtype = str(type(v))
-  if "jax.lib.xla_bridge._JaxComputationBuilder" in vtype:
+  if "jax._src.lib.xla_bridge._JaxComputationBuilder" in vtype:
     return "<JaxComputationBuilder>"
   elif "jaxlib.xla_extension.XlaOp" in vtype:
     return "<XlaOp at 0x{:x}>".format(id(v))
@@ -354,7 +354,7 @@ for most of them. However, XLA includes a `CustomCall` operation that can be use
 ```{code-cell} ipython3
 :id: FYQWSSjKJaWP
 
-from jax.lib import xla_client
+from jax._src.lib import xla_client
 @trace("multiply_add_xla_translation")
 def multiply_add_xla_translation(c, xc, yc, zc):
   """The compilation to XLA of the primitive.
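As the context line above notes, XLA's `CustomCall` op is the escape hatch for behavior XLA does not provide natively. For reference, a one-line sketch of how the new `jax/_src/lib/xla_bridge.py` later in this diff emits such a call for the built-in "Sharding" target (`builder` and `operand` are assumed names):

# Identity-style custom call used only to attach sharding metadata to `operand`
# (pattern taken from set_sharding_proto in jax/_src/lib/xla_bridge.py below).
out = xops.CustomCall(builder, b"Sharding", [operand], builder.get_shape(operand))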
@@ -74,7 +74,7 @@
 "import numpy as np\n",
 "\n",
 "# We only need to import JAX's xla_client, not all of JAX.\n",
-"from jax.lib import xla_client as xc\n",
+"from jax._src.lib import xla_client as xc\n",
 "xops = xc.ops\n",
 "\n",
 "# Plotting\n",
@@ -66,7 +66,7 @@ https://github.com/google/jax/blob/main/jax/interpreters/xla.py
 import numpy as np
 
 # We only need to import JAX's xla_client, not all of JAX.
-from jax.lib import xla_client as xc
+from jax._src.lib import xla_client as xc
 xops = xc.ops
 
 # Plotting
@@ -27,9 +27,9 @@ import time
 import numpy as np
 import numpy.random as npr
 
+import jax
 from jax import jit, grad, pmap
 from jax.scipy.special import logsumexp
-from jax.lib import xla_bridge
 from jax.tree_util import tree_map
 from jax import lax
 import jax.numpy as jnp
@@ -77,7 +77,7 @@ if __name__ == "__main__":
 
   # For this manual SPMD example, we get the number of devices (e.g. GPUs or
   # TPU cores) that we're using, and use it to reshape data minibatches.
-  num_devices = xla_bridge.device_count()
+  num_devices = jax.device_count()
   def data_stream():
     rng = npr.RandomState(0)
     while True:
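The example now counts devices through the public API instead of the private bridge module. A short sketch of the surrounding usage (the array shapes are illustrative, not taken from the example):

import jax
import numpy as np

num_devices = jax.device_count()              # public replacement for xla_bridge.device_count()
batch = np.zeros((128, 784), dtype=np.float32)
# Give the batch a leading device axis so jax.pmap can split it across devices.
sharded_batch = batch.reshape((num_devices, -1) + batch.shape[1:])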
@@ -130,6 +130,8 @@ from . import random as random
 from . import tree_util as tree_util
 from . import util as util
 
+import jax.lib  # TODO(phawkins): remove this export.
+
 def _init():
   from . import numpy as numpy  # side-effecting import sets up operator overloads
 
@@ -40,7 +40,6 @@ from contextlib import contextmanager, ExitStack
 
 import jax
 from .. import core
-from .. import lib
 from .. import linear_util as lu
 from . import dtypes
 from ..core import eval_jaxpr
@@ -57,13 +56,13 @@ from ..tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
                          treedef_is_leaf, treedef_children, Partial)
 from .util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
                    extend_name_stack, wrap_name, cache, wraps, HashableFunction)
-from ..lib import jax_jit
-from ..lib import version
-from ..lib import xla_bridge as xb
-from ..lib import xla_client as xc
-from ..lib import pmap_lib
+from jax._src.lib import jax_jit
+from jax._src.lib import version
+from jax._src.lib import xla_bridge as xb
+from jax._src.lib import xla_client as xc
+from jax._src.lib import pmap_lib
 # Unused imports to be exported
-from ..lib.xla_bridge import (device_count, local_device_count, devices,
+from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
                               local_devices, process_index, process_count,
                               host_id, host_ids, host_count, default_backend)
 from ..core import ConcreteArray, ShapedArray, raise_to_shaped
@@ -24,8 +24,8 @@ import threading
 from typing import Any, List, Callable, NamedTuple, Optional
 import warnings
 
-from jax import lib
-from jax.lib import jax_jit
+from jax._src import lib
+from jax._src.lib import jax_jit
 
 def bool_env(varname: str, default: bool) -> bool:
   """Read an environment variable and interpret it as a boolean.
@@ -15,8 +15,8 @@
 from jax import core
 from jax import numpy as jnp
 from jax.interpreters import xla
-from jax.lib import xla_client
-from jax.lib import xla_bridge
+from jax._src.lib import xla_client
+from jax._src.lib import xla_bridge
 
 SUPPORTED_DTYPES = set([jnp.int8, jnp.int16, jnp.int32, jnp.int64,
                         jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
@@ -27,7 +27,7 @@ import numpy as np
 
 from jax._src import util
 from jax._src.config import flags, config
-from jax.lib import xla_client
+from jax._src.lib import xla_client
 
 from jax._src import traceback_util
 traceback_util.register_exclusion(__file__)
@@ -43,8 +43,8 @@ from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
 from jax.interpreters import batching
 from jax.interpreters import masking
-from jax.lib import xla_bridge as xb
-from jax.lib import xla_client
+from jax._src.lib import xla_bridge as xb
+from jax._src.lib import xla_client
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (unzip2, unzip3, safe_map, safe_zip,
                            split_list, cache, extend_name_stack)
@@ -23,10 +23,10 @@ from jax.interpreters import xla
 from jax._src.util import prod
 from jax._src import dtypes
 from jax import lax
-from jax.lib import xla_client
 from jax.interpreters import ad
 from jax.interpreters import batching
-from jax.lib import pocketfft
+from jax._src.lib import xla_client
+from jax._src.lib import pocketfft
 
 xops = xla_client.ops
 
@@ -50,9 +50,10 @@ from jax.interpreters import masking
 from jax._src.util import (cache, safe_zip, prod, safe_map, canonicalize_axis,
                            split_list)
 from jax.tree_util import tree_map
-from jax.lib import pytree
-from jax.lib import xla_bridge
-from jax.lib import xla_client
+import jax._src.lib
+from jax._src.lib import pytree
+from jax._src.lib import xla_bridge
+from jax._src.lib import xla_client
 
 xb = xla_bridge
 xc = xla_client
@@ -2631,7 +2632,7 @@ ad.defjvp2(rsqrt_p,
 
 # TODO(phawkins): remove the fallback translation rule after the minimum jaxlib
 # is 0.1.70 or newer.
-if jax.lib._xla_extension_version >= 28:
+if jax._src.lib._xla_extension_version >= 28:
   _cbrt_translation_rule = None
 else:
   def _cbrt_translation_rule(c, x):
@@ -33,15 +33,15 @@ from jax._src.lax.lax import (
   standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
   _input_dtype, _broadcasting_select)
 from jax._src.lax import lax as lax_internal
-from jax.lib import lapack
+from jax._src.lib import lapack
 
-from jax.lib import cuda_linalg
-from jax.lib import cusolver
-from jax.lib import cusparse
-from jax.lib import rocsolver
+from jax._src.lib import cuda_linalg
+from jax._src.lib import cusolver
+from jax._src.lib import cusparse
+from jax._src.lib import rocsolver
 
-from jax.lib import xla_client
-from jax.lib import xla_bridge as xb
+from jax._src.lib import xla_client
+from jax._src.lib import xla_bridge as xb
 
 xops = xla_client.ops
 
@@ -23,7 +23,6 @@ import warnings
 import numpy as np
 
 from jax import core
-from jax._src import dtypes
 from jax import tree_util
 from . import lax
 from jax.core import ShapedArray, AxisName, raise_to_shaped
@@ -31,10 +30,11 @@ from jax.interpreters import ad
 from jax.interpreters import xla
 from jax.interpreters import pxla
 from jax.interpreters import batching
-from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
-from jax.lib import xla_client as xc
-from jax.lib import xla_bridge as xb
+from jax._src import dtypes
+from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_bridge as xb
 from jax._src.numpy import lax_numpy
+from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, moveaxis
 
 xops = xc.ops
 
new file (122 lines): jax/_src/lib/__init__.py
@@ -0,0 +1,122 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This module is largely a wrapper around `jaxlib` that performs version
|
||||
# checking on import.
|
||||
|
||||
import platform
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
__all__ = [
|
||||
'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
|
||||
'pocketfft', 'pytree', 'tpu_driver_client', 'version', 'xla_client'
|
||||
]
|
||||
|
||||
# First, before attempting to import jaxlib, warn about experimental machine
|
||||
# configurations.
|
||||
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||
warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
|
||||
"Please see https://github.com/google/jax/issues/5501 in the "
|
||||
"event of problems.")
|
||||
|
||||
try:
|
||||
import jaxlib
|
||||
except ModuleNotFoundError as err:
|
||||
raise ModuleNotFoundError(
|
||||
'jax requires jaxlib to be installed. See '
|
||||
'https://github.com/google/jax#installation for installation instructions.'
|
||||
) from err
|
||||
|
||||
from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
|
||||
try:
|
||||
from jaxlib import version as jaxlib_version
|
||||
except Exception as err:
|
||||
# jaxlib is too old to have version number.
|
||||
msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
|
||||
raise ImportError(msg) from err
|
||||
|
||||
version = tuple(int(x) for x in jaxlib_version.__version__.split('.'))
|
||||
_minimum_jaxlib_version = tuple(int(x) for x in _minimum_jaxlib_version_str.split('.'))
|
||||
|
||||
# Check the jaxlib version before importing anything else from jaxlib.
|
||||
def _check_jaxlib_version():
|
||||
if version < _minimum_jaxlib_version:
|
||||
msg = (f'jaxlib is version {jaxlib_version.__version__}, '
|
||||
f'but this version of jax requires version {_minimum_jaxlib_version_str}.')
|
||||
|
||||
if version == (0, 1, 23):
|
||||
msg += ('\n\nA common cause of this error is that you installed jaxlib '
|
||||
'using pip, but your version of pip is too old to support '
|
||||
'manylinux2010 wheels. Try running:\n\n'
|
||||
'pip install --upgrade pip\n'
|
||||
'pip install --upgrade jax jaxlib\n')
|
||||
raise ValueError(msg)
|
||||
|
||||
_check_jaxlib_version()
|
||||
|
||||
from jaxlib import cpu_feature_guard
|
||||
cpu_feature_guard.check_cpu_features()
|
||||
|
||||
from jaxlib import xla_client
|
||||
from jaxlib import lapack
|
||||
from jaxlib import pocketfft
|
||||
|
||||
xla_extension = xla_client._xla
|
||||
pytree = xla_client._xla.pytree
|
||||
jax_jit = xla_client._xla.jax_jit
|
||||
pmap_lib = xla_client._xla.pmap_lib
|
||||
|
||||
try:
|
||||
from jaxlib import cusolver # pytype: disable=import-error
|
||||
except ImportError:
|
||||
cusolver = None
|
||||
|
||||
try:
|
||||
from jaxlib import cusparse # pytype: disable=import-error
|
||||
except ImportError:
|
||||
cusparse = None
|
||||
|
||||
try:
|
||||
from jaxlib import rocsolver # pytype: disable=import-error
|
||||
except ImportError:
|
||||
rocsolver = None
|
||||
|
||||
try:
|
||||
from jaxlib import cuda_prng # pytype: disable=import-error
|
||||
except ImportError:
|
||||
cuda_prng = None
|
||||
|
||||
try:
|
||||
from jaxlib import cuda_linalg # pytype: disable=import-error
|
||||
except ImportError:
|
||||
cuda_linalg = None
|
||||
|
||||
# Jaxlib code is split between the Jax and the Tensorflow repositories.
|
||||
# Only for the internal usage of the JAX developers, we expose a version
|
||||
# number that can be used to perform changes without breaking the main
|
||||
# branch on the Jax github.
|
||||
_xla_extension_version = getattr(xla_client, '_version', 0)
|
||||
|
||||
try:
|
||||
from jaxlib import tpu_client as tpu_driver_client # pytype: disable=import-error
|
||||
except:
|
||||
tpu_driver_client = None # type: ignore
|
||||
|
||||
|
||||
cuda_path: Optional[str]
|
||||
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
|
||||
if not os.path.isdir(cuda_path):
|
||||
cuda_path = None
|
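The `version` tuple and `_xla_extension_version` exported above are what the rest of JAX gates on. A condensed sketch of that pattern, assembled from the lax.py and compilation-cache hunks elsewhere in this commit:

from jax._src import lib

# Prefer jaxlib's native lowering once the bundled XLA extension is new enough.
if lib._xla_extension_version >= 28:
  _cbrt_translation_rule = None
# Feature-detect on the installed jaxlib release itself.
if lib.version >= (0, 1, 72):
  expected_options = 31
else:
  expected_options = 30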
new file (595 lines): jax/_src/lib/xla_bridge.py
@@ -0,0 +1,595 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Interface and utility functions to XLA.
|
||||
|
||||
This module wraps the XLA client(s) and builders to standardize their interfaces
|
||||
and provide some automatic type mapping logic for converting between Numpy and
|
||||
XLA. There are also a handful of related casting utilities.
|
||||
"""
|
||||
|
||||
|
||||
from functools import partial, lru_cache
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
# Disable "WARNING: Logging before flag parsing goes to stderr." message
|
||||
logging._warn_preinit_stderr = 0
|
||||
|
||||
import jax._src.lib
|
||||
from jax._src.config import flags, bool_env
|
||||
from . import tpu_driver_client
|
||||
from . import xla_client
|
||||
from jax._src import util, traceback_util
|
||||
from jax._src import dtypes
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
xops = xla_client.ops
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# TODO(phawkins): Remove jax_xla_backend.
|
||||
flags.DEFINE_string(
|
||||
'jax_xla_backend', '',
|
||||
'jax_xla_backend is an alias for jax_platform_name. If both are '
|
||||
'provided, --jax_xla_backend takes priority. Prefer --jax_platform_name.')
|
||||
flags.DEFINE_string(
|
||||
'jax_backend_target', '',
|
||||
'Either "local" or "rpc:address" to connect to a remote service target.')
|
||||
flags.DEFINE_string(
|
||||
'jax_platform_name',
|
||||
os.getenv('JAX_PLATFORM_NAME', '').lower(),
|
||||
'Platform name for XLA. The default is to attempt to use a GPU or TPU if '
|
||||
'available, but fall back to CPU otherwise. To set the platform manually, '
|
||||
'pass "cpu" for CPU, "gpu" for GPU, etc. If intending to use CPU, '
|
||||
'setting the platform name to "cpu" can silence warnings that appear with '
|
||||
'the default setting.')
|
||||
flags.DEFINE_bool(
|
||||
'jax_disable_most_optimizations',
|
||||
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||
'Try not to do much optimization work. This can be useful if the cost of '
|
||||
'optimization is greater than that of running a less-optimized program.')
|
||||
|
||||
def get_compile_options(
|
||||
num_replicas: int,
|
||||
num_partitions: int,
|
||||
device_assignment=None,
|
||||
use_spmd_partitioning: bool = True,
|
||||
) -> xla_client.CompileOptions:
|
||||
"""Returns the compile options to use, as derived from flag values.
|
||||
|
||||
Args:
|
||||
num_replicas: Number of replicas for which to compile.
|
||||
num_partitions: Number of partitions for which to compile.
|
||||
device_assignment: Optional tuple of integers indicating the assignment of
|
||||
logical replicas to physical devices (default inherited from
|
||||
xla_client.CompileOptions). Must be consistent with `num_replicas` and
|
||||
`num_partitions`.
|
||||
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
|
||||
partitioning in XLA.
|
||||
"""
|
||||
compile_options = xla_client.CompileOptions()
|
||||
compile_options.num_replicas = num_replicas
|
||||
compile_options.num_partitions = num_partitions
|
||||
build_options = compile_options.executable_build_options
|
||||
build_options.use_spmd_partitioning = use_spmd_partitioning
|
||||
if device_assignment is not None:
|
||||
logging.vlog(
|
||||
2,
|
||||
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
|
||||
num_replicas, num_partitions, device_assignment)
|
||||
device_assignment = np.array(device_assignment)
|
||||
|
||||
# Allow 1D device assignment if num_partitions is 1.
|
||||
if (device_assignment.ndim == 1) and (num_partitions == 1):
|
||||
device_assignment = device_assignment[:, None]
|
||||
|
||||
if num_replicas != device_assignment.shape[0]:
|
||||
msg = 'device_assignment does not match num_replicas: {} vs {}.'
|
||||
raise ValueError(msg.format(device_assignment, num_replicas))
|
||||
|
||||
if num_partitions != device_assignment.shape[1]:
|
||||
msg = 'device_assignment does not match num_partitions: {} vs {}.'
|
||||
raise ValueError(msg.format(device_assignment, num_partitions))
|
||||
|
||||
device_assignment = xla_client.DeviceAssignment.create(device_assignment)
|
||||
assert device_assignment.replica_count() == num_replicas
|
||||
assert device_assignment.computation_count() == num_partitions
|
||||
compile_options.device_assignment = device_assignment
|
||||
|
||||
debug_options = compile_options.executable_build_options.debug_options
|
||||
if jax._src.lib.cuda_path is not None:
|
||||
debug_options.xla_gpu_cuda_data_dir = jax._src.lib.cuda_path
|
||||
|
||||
if FLAGS.jax_disable_most_optimizations:
|
||||
|
||||
debug_options.xla_backend_optimization_level = 0
|
||||
debug_options.xla_llvm_disable_expensive_passes = True
|
||||
debug_options.xla_test_all_input_layouts = False
|
||||
|
||||
return compile_options
|
||||
|
||||
|
||||
# Backends
|
||||
|
||||
def _make_tpu_driver_client():
|
||||
if tpu_driver_client is None:
|
||||
logging.info("Remote TPU is not linked into jax; skipping remote TPU.")
|
||||
return None
|
||||
if FLAGS.jax_backend_target is None:
|
||||
logging.info("No --jax_backend_target was provided; skipping remote TPU.")
|
||||
return None
|
||||
return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target)
|
||||
|
||||
|
||||
def tpu_client_timer_callback(timer_secs: float):
|
||||
def _log_warning():
|
||||
warnings.warn(
|
||||
f'TPU backend initialization is taking more than {timer_secs} seconds. '
|
||||
'Did you run your code on all TPU hosts? '
|
||||
'See https://jax.readthedocs.io/en/latest/multi_process.html '
|
||||
'for more information.')
|
||||
|
||||
# Will log a warning after `timer_secs`.
|
||||
t = threading.Timer(timer_secs, _log_warning)
|
||||
t.start()
|
||||
|
||||
try:
|
||||
client = xla_client.make_tpu_client()
|
||||
finally:
|
||||
t.cancel()
|
||||
|
||||
return client
|
||||
|
||||
|
||||
# Backends, in increasing order of preference.
|
||||
# We have no particular opinion about how "backends" relate to "devices". For
|
||||
# example, there could be multiple backends that provide the same kind of
|
||||
# device.
|
||||
_backend_factories = {}
|
||||
|
||||
def register_backend_factory(name, factory, *, priority=0):
|
||||
_backend_factories[name] = (factory, priority)
|
||||
|
||||
|
||||
register_backend_factory('interpreter', xla_client.make_interpreter_client,
|
||||
priority=-100)
|
||||
register_backend_factory('cpu',
|
||||
partial(xla_client.make_cpu_client, use_tfrt=True),
|
||||
priority=0)
|
||||
register_backend_factory('tpu_driver', _make_tpu_driver_client,
|
||||
priority=100)
|
||||
register_backend_factory('gpu', xla_client.make_gpu_client,
|
||||
priority=200)
|
||||
register_backend_factory(
|
||||
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
|
||||
|
||||
_default_backend = None
|
||||
_backends_initialized = False
|
||||
_backends : Dict[str, Any] = {}
|
||||
_backends_errors : Dict[str, str] = {}
|
||||
_backend_lock = threading.Lock()
|
||||
|
||||
|
||||
def backends():
|
||||
global _backends_initialized
|
||||
global _backends
|
||||
global _backends_errors
|
||||
global _default_backend
|
||||
|
||||
with _backend_lock:
|
||||
if _backends_initialized:
|
||||
return _backends
|
||||
|
||||
_backends_initialized = True
|
||||
default_priority = -1000
|
||||
for name, (factory, priority) in _backend_factories.items():
|
||||
logging.vlog(1, "Initializing backend '%s'" % name)
|
||||
try:
|
||||
backend = factory()
|
||||
if backend is not None:
|
||||
if backend.device_count() > 0:
|
||||
_backends[name] = backend
|
||||
util.distributed_debug_log(("Initialized backend", backend.platform),
|
||||
("process_index", backend.process_index()),
|
||||
("device_count", backend.device_count()),
|
||||
("local_devices", backend.local_devices()))
|
||||
logging.vlog(1, "Backend '%s' initialized" % name)
|
||||
if priority > default_priority:
|
||||
_default_backend = backend
|
||||
default_priority = priority
|
||||
except Exception as err:
|
||||
if name in ('cpu', 'interpreter'):
|
||||
# We always expect the CPU and interpreter backends to initialize
|
||||
# successfully.
|
||||
raise
|
||||
else:
|
||||
# If the backend isn't built into the binary, or if it has no devices,
|
||||
# we expect a RuntimeError.
|
||||
logging.info("Unable to initialize backend '%s': %s" % (name, err))
|
||||
_backends_errors[name] = str(err)
|
||||
continue
|
||||
if _default_backend.platform == "cpu" and FLAGS.jax_platform_name != 'cpu':
|
||||
logging.warning('No GPU/TPU found, falling back to CPU. '
|
||||
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
|
||||
return _backends
|
||||
|
||||
|
||||
def _get_backend_uncached(platform=None):
|
||||
# TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
|
||||
# 'backend' values are handled
|
||||
if not isinstance(platform, (type(None), str)):
|
||||
return platform
|
||||
|
||||
bs = backends()
|
||||
platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
|
||||
or None)
|
||||
if platform is not None:
|
||||
backend = bs.get(platform, None)
|
||||
if backend is None:
|
||||
if platform in _backends_errors:
|
||||
raise RuntimeError(f"Requested backend {platform}, but it failed "
|
||||
f"to initialize: {_backends_errors[platform]}")
|
||||
raise RuntimeError(f"Unknown backend {platform}")
|
||||
return backend
|
||||
else:
|
||||
return _default_backend
|
||||
|
||||
|
||||
@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
|
||||
def get_backend(platform=None):
|
||||
return _get_backend_uncached(platform)
|
||||
|
||||
|
||||
def get_device_backend(device=None):
|
||||
"""Returns the Backend associated with `device`, or the default Backend."""
|
||||
if device is not None:
|
||||
return device.client
|
||||
return get_backend()
|
||||
|
||||
|
||||
def device_count(backend: Optional[str] = None) -> int:
|
||||
"""Returns the total number of devices.
|
||||
|
||||
On most platforms, this is the same as :py:func:`jax.local_device_count`.
|
||||
However, on multi-process platforms where different devices are associated
|
||||
with different processes, this will return the total number of devices across
|
||||
all processes.
|
||||
|
||||
Args:
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
Number of devices.
|
||||
|
||||
"""
|
||||
return int(get_backend(backend).device_count())
|
||||
|
||||
|
||||
def local_device_count(backend: Optional[str] = None) -> int:
|
||||
"""Returns the number of devices addressable by this process."""
|
||||
return int(get_backend(backend).local_device_count())
|
||||
|
||||
|
||||
def devices(backend: Optional[str] = None) -> List[xla_client.Device]:
|
||||
"""Returns a list of all devices for a given backend.
|
||||
|
||||
Each device is represented by a subclass of :class:`Device` (e.g.
|
||||
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
|
||||
equal to ``device_count(backend)``. Local devices can be identified by
|
||||
comparing :meth:`Device.process_index` to the value returned by
|
||||
:py:func:`jax.process_index`.
|
||||
|
||||
If ``backend`` is ``None``, returns all the devices from the default backend.
|
||||
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
|
||||
otherwise ``'cpu'``.
|
||||
|
||||
Args:
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
List of Device subclasses.
|
||||
"""
|
||||
return get_backend(backend).devices()
|
||||
|
||||
|
||||
def default_backend() -> str:
|
||||
"""Returns the platform name of the default XLA backend."""
|
||||
return get_backend(None).platform
|
||||
|
||||
|
||||
def local_devices(process_index: Optional[int] = None,
|
||||
backend: Optional[str] = None,
|
||||
host_id: Optional[int] = None) -> List[xla_client.Device]:
|
||||
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
|
||||
|
||||
If ``process_index`` is ``None``, returns devices local to this process.
|
||||
|
||||
Args:
|
||||
process_index: the integer index of the process. Process indices can be
|
||||
retrieved via ``len(jax.process_count())``.
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
List of Device subclasses.
|
||||
"""
|
||||
if host_id is not None:
|
||||
warnings.warn(
|
||||
"The argument to jax.local_devices has been renamed from `host_id` to "
|
||||
"`process_index`. This alias will eventually be removed; please update "
|
||||
"your code.")
|
||||
process_index = host_id
|
||||
if process_index is None:
|
||||
process_index = get_backend(backend).process_index()
|
||||
if not (0 <= process_index < process_count()):
|
||||
raise ValueError(f"Unknown process_index {process_index}")
|
||||
return [d for d in devices(backend) if d.process_index == process_index]
|
||||
|
||||
|
||||
def process_index(backend: Optional[str] = None) -> int:
|
||||
"""Returns the integer process index of this process.
|
||||
|
||||
On most platforms, this will always be 0. This will vary on multi-process
|
||||
platforms though.
|
||||
|
||||
Args:
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
Integer process index.
|
||||
"""
|
||||
return get_backend(backend).process_index()
|
||||
|
||||
|
||||
# TODO: remove this sometime after jax 0.2.13 is released
|
||||
def host_id(backend=None):
|
||||
warnings.warn(
|
||||
"jax.host_id has been renamed to jax.process_index. This alias "
|
||||
"will eventually be removed; please update your code.")
|
||||
return process_index(backend)
|
||||
|
||||
|
||||
def process_count(backend: Optional[str] = None) -> int:
|
||||
"""Returns the number of JAX processes associated with the backend."""
|
||||
return max(d.process_index for d in devices(backend)) + 1
|
||||
|
||||
|
||||
# TODO: remove this sometime after jax 0.2.13 is released
|
||||
def host_count(backend=None):
|
||||
warnings.warn(
|
||||
"jax.host_count has been renamed to jax.process_count. This alias "
|
||||
"will eventually be removed; please update your code.")
|
||||
return process_count(backend)
|
||||
|
||||
|
||||
# TODO: remove this sometime after jax 0.2.13 is released
|
||||
def host_ids(backend=None):
|
||||
warnings.warn(
|
||||
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
|
||||
"instead. jax.host_ids will eventually be removed; please update your "
|
||||
"code.")
|
||||
return list(range(process_count(backend)))
|
||||
|
||||
|
||||
### utility functions
|
||||
|
||||
@util.memoize
|
||||
def dtype_to_etype(dtype):
|
||||
"""Convert from dtype to canonical etype (reading config.x64_enabled)."""
|
||||
return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
|
||||
@util.memoize
|
||||
def supported_numpy_dtypes():
|
||||
return {dtypes.canonicalize_dtype(dtype)
|
||||
for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}
|
||||
|
||||
|
||||
# TODO(mattjj,frostig): try to remove this function
|
||||
def normalize_to_xla_dtypes(val):
|
||||
"""Normalize dtypes in a value."""
|
||||
if hasattr(val, '__array__') or np.isscalar(val):
|
||||
return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
|
||||
elif isinstance(val, (tuple, list)):
|
||||
return tuple(normalize_to_xla_dtypes(x) for x in val)
|
||||
raise TypeError('Can\'t convert to XLA: {}'.format(val))
|
||||
|
||||
def _numpy_array_constant(builder, value, canonicalize_types=True):
|
||||
if canonicalize_types:
|
||||
value = normalize_to_xla_dtypes(value)
|
||||
return [xops.ConstantLiteral(builder, value)]
|
||||
|
||||
def parameter(builder, num, shape, name=None, replicated=None):
|
||||
if name is None:
|
||||
name = ''
|
||||
if replicated is None:
|
||||
replicated = []
|
||||
elif isinstance(replicated, bool):
|
||||
replicated = [replicated] * shape.leaf_count()
|
||||
|
||||
return xops.Parameter(builder, num,
|
||||
shape.with_major_to_minor_layout_if_absent(), name,
|
||||
replicated)
|
||||
|
||||
|
||||
def constant_general(builder, py_val, canonicalize_types=True):
|
||||
"""Translate a general constant `py_val` to a constant, canonicalizing its dtype.
|
||||
|
||||
Args:
|
||||
py_val: a Python value to be translated to a constant.
|
||||
|
||||
Returns:
|
||||
A representation of the constant as a list of xla ops.
|
||||
"""
|
||||
for t in type(py_val).mro():
|
||||
handler = _constant_handlers.get(t)
|
||||
if handler: return handler(builder, py_val, canonicalize_types)
|
||||
if hasattr(py_val, '__jax_array__'):
|
||||
return constant(builder, py_val.__jax_array__(), canonicalize_types)
|
||||
raise TypeError("No constant handler for type: {}".format(type(py_val)))
|
||||
|
||||
def constant(builder, py_val, canonicalize_types=True):
|
||||
"""Translate constant `py_val` to a constant, canonicalizing its dtype.
|
||||
|
||||
Args:
|
||||
py_val: a Python value to be translated to a constant.
|
||||
|
||||
Returns:
|
||||
A representation of the constant, either a ComputationDataHandle or None
|
||||
"""
|
||||
const = constant_general(builder, py_val, canonicalize_types=canonicalize_types)
|
||||
assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
|
||||
return const[0]
|
||||
|
||||
# HLO instructions optionally can be annotated to say how the output should be
|
||||
# spatially partitioned (represented in XLA as OpSharding protos, see
|
||||
# _sharding_to_proto). For array outputs, the annotation is either an int per
|
||||
# dimension specifying the number of ways that dimension divided (i.e. the total
|
||||
# number of shards is the product), or None to indicate the array should be
|
||||
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
|
||||
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
|
||||
# checkers don't support recursive types), so we only represent one level of
|
||||
# nesting in this type definition.
|
||||
SpatialSharding = Union[Tuple[int, ...],
|
||||
None,
|
||||
Tuple[Union[Tuple[int, ...], None], ...]]
|
||||
|
||||
def _sharding_to_proto(sharding: SpatialSharding):
|
||||
"""Converts a SpatialSharding to an OpSharding.
|
||||
|
||||
See
|
||||
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
|
||||
for details on the OpSharding proto.
|
||||
"""
|
||||
proto = xla_client.OpSharding()
|
||||
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
|
||||
assert all(s is None or isinstance(s, tuple) for s in sharding)
|
||||
return tuple_sharding_proto(list(map(_sharding_to_proto, sharding))) # type: ignore
|
||||
|
||||
if sharding is None:
|
||||
proto.type = xla_client.OpSharding.Type.REPLICATED
|
||||
else:
|
||||
proto.type = xla_client.OpSharding.Type.OTHER
|
||||
proto.tile_assignment_dimensions = list(sharding)
|
||||
proto.tile_assignment_devices = list(range(np.product(sharding)))
|
||||
return proto
|
||||
|
||||
def tuple_sharding_proto(elems):
|
||||
proto = xla_client.OpSharding()
|
||||
assert all(isinstance(e, type(proto)) for e in elems)
|
||||
proto.type = xla_client.OpSharding.Type.TUPLE
|
||||
proto.tuple_shardings = elems
|
||||
return proto
|
||||
|
||||
def set_sharding_proto(builder, op, sharding_proto):
|
||||
"""Uses CustomCall to annotate a value as sharded."""
|
||||
# "Sharding" is a built-in custom call target that acts like an identity
|
||||
# function, and is used to attach an OpSharding to.
|
||||
return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
|
||||
builder, b"Sharding", [op], builder.get_shape(op))
|
||||
|
||||
def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
|
||||
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
|
||||
builder.set_sharding(sharding_proto)
|
||||
try:
|
||||
return op_fn(*args, **kwargs)
|
||||
finally:
|
||||
builder.clear_sharding()
|
||||
|
||||
def set_sharding(builder, op, sharding: SpatialSharding):
|
||||
"""Uses CustomCall to annotate a value as sharded."""
|
||||
return set_sharding_proto(builder, op, _sharding_to_proto(sharding))
|
||||
|
||||
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
|
||||
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
|
||||
return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
|
||||
|
||||
def make_computation_builder(name):
|
||||
return xla_client.XlaBuilder(name)
|
||||
|
||||
|
||||
def register_constant_handler(type_, handler_fun):
|
||||
_constant_handlers[type_] = handler_fun
|
||||
_constant_handlers: Dict[type, Callable] = {}
|
||||
|
||||
|
||||
def _ndarray_constant_handler(c, val, canonicalize_types=True):
|
||||
"""Constant handler for ndarray literals, handling zero-size strides.
|
||||
|
||||
This function essentially calls _numpy_array_constant(val) except it has
|
||||
special handling of arrays with any strides of size zero: for those, it
|
||||
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
|
||||
to avoid staging in large literals that might arise from np.zeros or np.ones
|
||||
or the output of lax.broadcast (which uses np.broadcast_to which in turn
|
||||
uses size-zero strides).
|
||||
|
||||
Args:
|
||||
c: an XlaBuilder
|
||||
val: an ndarray.
|
||||
|
||||
Returns:
|
||||
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
|
||||
staged into the XLA Computation.
|
||||
"""
|
||||
# TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
|
||||
if dtypes.result_type(val) == dtypes.float0:
|
||||
return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
|
||||
elif np.any(np.equal(0, val.strides)) and val.size > 0:
|
||||
zero_stride_axes, = np.where(np.equal(0, val.strides))
|
||||
other_axes, = np.where(np.not_equal(0, val.strides))
|
||||
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
|
||||
for ax in range(val.ndim))]
|
||||
xla_val = xops.Broadcast(
|
||||
_numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
|
||||
np.take(val.shape, zero_stride_axes))
|
||||
permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
|
||||
return [xops.Transpose(xla_val, permutation)]
|
||||
else:
|
||||
return _numpy_array_constant(c, val, canonicalize_types)
|
||||
register_constant_handler(np.ndarray, _ndarray_constant_handler)
|
||||
|
||||
|
||||
def _scalar_constant_handler(c, val, canonicalize_types=True):
|
||||
return _numpy_array_constant(c, val, canonicalize_types)
|
||||
|
||||
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
np.float16, np.float32, np.float64,
|
||||
np.bool_, np.longlong,
|
||||
xla_client.bfloat16]:
|
||||
register_constant_handler(scalar_type, _scalar_constant_handler)
|
||||
|
||||
# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
|
||||
if hasattr(np, "float128"):
|
||||
register_constant_handler(np.float128, _scalar_constant_handler)
|
||||
|
||||
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
|
||||
return _numpy_array_constant(c, dtype.type(val))
|
||||
|
||||
for ptype, dtype in dtypes.python_scalar_dtypes.items():
|
||||
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
|
@ -17,7 +17,7 @@ import operator
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import safe_zip
|
||||
from .util import _wraps
|
||||
from . import lax_numpy as jnp
|
||||
|
@ -23,13 +23,13 @@ from jax import lax
|
||||
from jax import core
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax._src.api import jit, vmap
|
||||
from jax.config import config
|
||||
from jax.lib import xla_bridge
|
||||
from jax.lib import xla_client
|
||||
from jax.lib import cuda_prng
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import xla
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import cuda_prng
|
||||
from jax._src.pprint_util import pp, vcat
|
||||
from jax._src.util import prod
|
||||
|
||||
|
@ -18,8 +18,8 @@ import threading
|
||||
from typing import Callable, Optional
|
||||
import warnings
|
||||
|
||||
from jax.lib import xla_bridge
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
|
||||
def start_server(port: int):
|
||||
|
@ -29,7 +29,7 @@ from jax.core import NamedShape
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.numpy.lax_numpy import (_arraylike, _check_arraylike,
|
||||
_constant_like, _convert_and_clip_integer)
|
||||
from jax.lib import xla_bridge
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax.numpy.linalg import cholesky, svd, eigh
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -19,7 +19,7 @@ import threading
|
||||
from typing import Optional, Iterator
|
||||
|
||||
import jax.version
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
@ -17,7 +17,7 @@ import traceback
|
||||
import types
|
||||
|
||||
import jax
|
||||
from jax.lib import xla_extension
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src import util
|
||||
|
||||
_exclude_paths = [__file__, util.__file__]
|
||||
|
@ -19,7 +19,7 @@ import operator as op
|
||||
from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, Type,
|
||||
TypeVar, overload, TYPE_CHECKING)
|
||||
|
||||
from ..lib import pytree
|
||||
from jax._src.lib import pytree
|
||||
|
||||
from .._src.util import safe_zip, unzip2
|
||||
|
||||
|
@ -17,7 +17,8 @@ import re
|
||||
|
||||
import jax
|
||||
from jax.experimental.compilation_cache.file_system_cache import FileSystemCache
|
||||
from jax.lib import xla_client
|
||||
import jax._src.lib
|
||||
from jax._src.lib import xla_client
|
||||
from absl import logging
|
||||
from typing import Optional
|
||||
|
||||
@ -85,7 +86,7 @@ def get_cache_key(xla_computation, compile_options, backend) -> str:
|
||||
_hash_compile_options(hash_obj, compile_options)
|
||||
if logging.vlog_is_on(1):
|
||||
logging.vlog(1, f"get_cache_key hash after serializing compile_options: {hash_obj.digest().hex()}")
|
||||
hash_obj.update(bytes(jax.lib.version))
|
||||
hash_obj.update(bytes(jax._src.lib.version))
|
||||
if logging.vlog_is_on(1):
|
||||
logging.vlog(1, f"get_cache_key hash after serializing jax_lib version: {hash_obj.digest().hex()}")
|
||||
_hash_platform(hash_obj, backend)
|
||||
@ -110,7 +111,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
hash_obj.update(compile_options_obj.device_assignment.serialize())
|
||||
|
||||
def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
if jax.lib.version >= (0, 1, 72):
|
||||
if jax._src.lib.version >= (0, 1, 72):
|
||||
expected_options = 31
|
||||
else:
|
||||
expected_options = 30
|
||||
@ -126,7 +127,7 @@ def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
if executable_obj.device_assignment is not None:
|
||||
hash_obj.update(executable_obj.device_assignment.serialize())
|
||||
_hash_bool(hash_obj, executable_obj.use_spmd_partitioning)
|
||||
if jax.lib.version >= (0, 1, 72):
|
||||
if jax._src.lib.version >= (0, 1, 72):
|
||||
_hash_bool(hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output)
|
||||
|
||||
def _hash_debug_options(hash_obj, debug_obj):
|
||||
|
@ -600,8 +600,8 @@ core.pytype_aval_mappings[BoundedInt] = _abstractify_bdint
|
||||
# XLA lowering
|
||||
|
||||
from jax.interpreters import xla
|
||||
from jax.lib import xla_bridge as xb
|
||||
from jax.lib import xla_client as xc
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
xe = xc._xla
|
||||
xops = xc._xla.ops
|
||||
|
||||
|
@ -456,15 +456,15 @@ from jax import custom_derivatives
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax.experimental import pjit
|
||||
from jax.lib import pytree
|
||||
from jax.lib import xla_bridge as xb
|
||||
from jax.lib import xla_client
|
||||
from jax.lib import xla_extension
|
||||
from jax.interpreters import ad, xla, batching, masking, pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src import pprint_util as ppu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -35,7 +35,7 @@ from jax._src import util
|
||||
from jax._src import ad_util
|
||||
from jax._src.lax.lax import _device_put_raw
|
||||
from jax.interpreters import xla
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
from . import jax2tf as jax2tf_internal
|
||||
|
||||
import numpy as np
|
||||
|
@ -44,7 +44,7 @@ from jax.interpreters import ad
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import sharded_jit
|
||||
from jax.interpreters import xla
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
from . import shape_poly
|
||||
|
||||
|
@ -29,6 +29,7 @@ from jax.config import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax._src import source_info_util
|
||||
import jax._src.lib.xla_bridge
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
@ -972,7 +973,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||
|
||||
jax_comp = jax.xla_computation(f_while)(x)
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
modules = backend.compile(jax_comp).hlo_modules()
|
||||
jax_opt_hlo = modules[0].to_string()
|
||||
print(f"JAX OPT HLO = {jax_opt_hlo}")
|
||||
|
@ -54,7 +54,7 @@ from jax import numpy as jnp
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1515,7 +1515,7 @@ for dtype in jtu.dtypes.all_floating + jtu.dtypes.complex:
|
||||
lax.linalg.qr,
|
||||
[RandArg(shape, dtype),
|
||||
StaticArg(full_matrices)],
|
||||
# See jax.lib.lapack.geqrf for the list of compatible types
|
||||
# See jax._src.lib.lapack.geqrf for the list of compatible types
|
||||
jax_unimplemented=[
|
||||
Limitation(
|
||||
"unimplemented",
|
||||
|
@ -31,6 +31,7 @@ from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax.interpreters import sharded_jit
|
||||
from jax.interpreters.sharded_jit import PartitionSpec as P
|
||||
import jax.numpy as jnp
|
||||
import jax._src.lib.xla_bridge
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -80,12 +81,12 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
||||
self.AssertShardingAnnotations("JAX before optimizations", jax_hlo, expected)
|
||||
|
||||
if jtu.device_under_test() == "tpu":
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
num_replicas = 1
|
||||
device_assignment = np.arange(num_partitions * num_replicas)
|
||||
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
|
||||
use_spmd_partitioning = num_partitions > 1
|
||||
compile_options = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=num_replicas,
|
||||
num_partitions=num_partitions,
|
||||
device_assignment=device_assignment,
|
||||
|
@ -30,6 +30,7 @@ from jax.config import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.interpreters import masking
|
||||
from jax._src import util
|
||||
import jax._src.lib.xla_bridge
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
|
||||
@ -307,7 +308,7 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
stage="hlo")
|
||||
logging.info(f"[{self._testMethodName}] TF NON OPT HLO\n{tf_hlo}")
|
||||
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
modules = backend.compile(jax_comp).hlo_modules()
|
||||
jax_opt_hlo = modules[0].to_string()
|
||||
logging.info(f"[{self._testMethodName}] "
|
||||
|
@ -39,8 +39,8 @@ from ..interpreters import pxla
|
||||
from ..interpreters import xla
|
||||
from ..interpreters import batching
|
||||
from ..interpreters import ad
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from .._src.util import (safe_map, safe_zip, HashableFunction,
|
||||
as_hashable_function, unzip2, distributed_debug_log,
|
||||
tuple_insert, moveaxis, split_list, wrap_name)
|
||||
|
@ -35,8 +35,8 @@ from ..interpreters import xla
|
||||
from ..interpreters import batching
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters.sharded_jit import PartitionSpec
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from ..tree_util import tree_map, tree_flatten, tree_unflatten
|
||||
from .._src.util import (extend_name_stack, HashableFunction, safe_zip,
|
||||
wrap_name, wraps, distributed_debug_log,
|
||||
|
@ -44,9 +44,9 @@ from jax import vmap
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.lib import cusparse
|
||||
from jax.lib import xla_bridge
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import cusparse
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.interpreters import ad
|
||||
|
@ -51,10 +51,10 @@ from .._src.util import (unzip3, prod, safe_map, safe_zip,
|
||||
extend_name_stack, wrap_name, assert_unreachable,
|
||||
tuple_insert, tuple_delete, distributed_debug_log)
|
||||
from ..errors import JAXTypeError
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from ..lib import pmap_lib
|
||||
from ..lib import _xla_extension_version
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.lib import _xla_extension_version
|
||||
from ..tree_util import tree_flatten, tree_map
|
||||
from . import batching
|
||||
from . import partial_eval as pe
|
||||
|
@ -25,8 +25,8 @@ from . import partial_eval as pe
|
||||
from . import pxla
|
||||
from . import xla
|
||||
from .. import linear_util as lu
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from .._src.api_util import argnums_partial, flatten_axes, flatten_fun, _ensure_index_tuple
|
||||
from ..tree_util import tree_flatten, tree_unflatten
|
||||
from .._src.util import (extend_name_stack, wrap_name, wraps, safe_zip,
|
||||
|
@ -40,8 +40,8 @@ from jax._src.pprint_util import pp
|
||||
from .._src.util import (partialmethod, cache, prod, unzip2,
|
||||
extend_name_stack, wrap_name, safe_zip, safe_map,
|
||||
partition_list)
|
||||
from ..lib import xla_bridge as xb
|
||||
from ..lib import xla_client as xc
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from . import partial_eval as pe
|
||||
from . import ad
|
||||
from . import masking
|
||||
|
@ -12,111 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# This module is largely a wrapper around `jaxlib` that performs version
# checking on import.

import platform
import os
import warnings
from typing import Optional

__all__ = [
  'cuda_linalg', 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
  'pocketfft', 'pytree', 'tpu_driver_client', 'version', 'xla_client'
]

# First, before attempting to import jaxlib, warn about experimental machine
# configurations.
if platform.system() == "Darwin" and platform.machine() == "arm64":
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
                "Please see https://github.com/google/jax/issues/5501 in the "
                "event of problems.")

try:
  import jaxlib
except ModuleNotFoundError as err:
  raise ModuleNotFoundError(
    'jax requires jaxlib to be installed. See '
    'https://github.com/google/jax#installation for installation instructions.'
  ) from err

from jax.version import _minimum_jaxlib_version as _minimum_jaxlib_version_str
try:
  from jaxlib import version as jaxlib_version
except Exception as err:
  # jaxlib is too old to have version number.
  msg = f'This version of jax requires jaxlib version >= {_minimum_jaxlib_version_str}.'
  raise ImportError(msg) from err

version = tuple(int(x) for x in jaxlib_version.__version__.split('.'))
_minimum_jaxlib_version = tuple(int(x) for x in _minimum_jaxlib_version_str.split('.'))

# Check the jaxlib version before importing anything else from jaxlib.
def _check_jaxlib_version():
  if version < _minimum_jaxlib_version:
    msg = (f'jaxlib is version {jaxlib_version.__version__}, '
           f'but this version of jax requires version {_minimum_jaxlib_version_str}.')

    if version == (0, 1, 23):
      msg += ('\n\nA common cause of this error is that you installed jaxlib '
              'using pip, but your version of pip is too old to support '
              'manylinux2010 wheels. Try running:\n\n'
              'pip install --upgrade pip\n'
              'pip install --upgrade jax jaxlib\n')
    raise ValueError(msg)

_check_jaxlib_version()

from jaxlib import cpu_feature_guard
cpu_feature_guard.check_cpu_features()

from jaxlib import xla_client
from jaxlib import lapack
from jaxlib import pocketfft

xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib

try:
  from jaxlib import cusolver  # pytype: disable=import-error
except ImportError:
  cusolver = None

try:
  from jaxlib import cusparse  # pytype: disable=import-error
except ImportError:
  cusparse = None

try:
  from jaxlib import rocsolver  # pytype: disable=import-error
except ImportError:
  rocsolver = None

try:
  from jaxlib import cuda_prng  # pytype: disable=import-error
except ImportError:
  cuda_prng = None

try:
  from jaxlib import cuda_linalg  # pytype: disable=import-error
except ImportError:
  cuda_linalg = None

# Jaxlib code is split between the Jax and the Tensorflow repositories.
# Only for the internal usage of the JAX developers, we expose a version
# number that can be used to perform changes without breaking the main
# branch on the Jax github.
_xla_extension_version = getattr(xla_client, '_version', 0)

try:
  from jaxlib import tpu_client as tpu_driver_client  # pytype: disable=import-error
except:
  tpu_driver_client = None  # type: ignore


cuda_path: Optional[str]
cuda_path = os.path.join(os.path.dirname(jaxlib.__file__), "cuda")
if not os.path.isdir(cuda_path):
  cuda_path = None
# flake8: noqa: F401
from jax._src.lib import (
  xla_client as xla_client,
  xla_extension as xla_extension,
)
from . import xla_bridge as xla_bridge
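With the shim above in place, the public `jax.lib` names should keep resolving to the relocated `jax._src.lib` module. A minimal sketch of what downstream code can expect (the alias names below are illustrative, not part of the diff):

# Hedged sketch: jax/lib/__init__.py now only re-exports from jax._src.lib,
# so both import paths should refer to the same module object.
from jax.lib import xla_client as public_xc        # via the shim
from jax._src.lib import xla_client as private_xc  # new canonical location

assert public_xc is private_xc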
@ -12,583 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Interface and utility functions to XLA.

This module wraps the XLA client(s) and builders to standardize their interfaces
and provide some automatic type mapping logic for converting between Numpy and
XLA. There are also a handful of related casting utilities.
"""


from functools import partial, lru_cache
import os
from typing import Callable, Dict, List, Optional, Tuple, Union
import warnings

from absl import logging
# Disable "WARNING: Logging before flag parsing goes to stderr." message
logging._warn_preinit_stderr = 0

import jax.lib
from .._src.config import flags, bool_env
from . import tpu_driver_client
from . import xla_client
from jax._src import util, traceback_util
from jax._src import dtypes
import numpy as np
import threading

traceback_util.register_exclusion(__file__)


xops = xla_client.ops

FLAGS = flags.FLAGS

# TODO(phawkins): Remove jax_xla_backend.
flags.DEFINE_string(
    'jax_xla_backend', '',
    'jax_xla_backend is an alias for jax_platform_name. If both are '
    'provided, --jax_xla_backend takes priority. Prefer --jax_platform_name.')
flags.DEFINE_string(
    'jax_backend_target', '',
    'Either "local" or "rpc:address" to connect to a remote service target.')
flags.DEFINE_string(
    'jax_platform_name',
    os.getenv('JAX_PLATFORM_NAME', '').lower(),
    'Platform name for XLA. The default is to attempt to use a GPU or TPU if '
    'available, but fall back to CPU otherwise. To set the platform manually, '
    'pass "cpu" for CPU, "gpu" for GPU, etc. If intending to use CPU, '
    'setting the platform name to "cpu" can silence warnings that appear with '
    'the default setting.')
flags.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')

def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
) -> xla_client.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional tuple of integers indicating the assignment of
      logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.
  """
  compile_options = xla_client.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  if device_assignment is not None:
    logging.vlog(
        2,
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    device_assignment = xla_client.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  debug_options = compile_options.executable_build_options.debug_options
  if jax.lib.cuda_path is not None:
    debug_options.xla_gpu_cuda_data_dir = jax.lib.cuda_path

  if FLAGS.jax_disable_most_optimizations:

    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  return compile_options

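A minimal usage sketch of `get_compile_options` as defined above (the replica count and device ids are illustrative; a 1-D assignment is accepted here only because `num_partitions` is 1):

# Hedged sketch: CompileOptions for 2 replicas x 1 partition on devices 0 and 1.
opts = get_compile_options(
    num_replicas=2,
    num_partitions=1,
    device_assignment=(0, 1),   # reshaped internally to shape (2, 1)
    use_spmd_partitioning=False,
)
assert opts.num_replicas == 2
assert opts.num_partitions == 1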
# Backends
|
||||
|
||||
def _make_tpu_driver_client():
|
||||
if tpu_driver_client is None:
|
||||
logging.info("Remote TPU is not linked into jax; skipping remote TPU.")
|
||||
return None
|
||||
if FLAGS.jax_backend_target is None:
|
||||
logging.info("No --jax_backend_target was provided; skipping remote TPU.")
|
||||
return None
|
||||
return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target)
|
||||
|
||||
|
||||
def tpu_client_timer_callback(timer_secs: float):
|
||||
def _log_warning():
|
||||
warnings.warn(
|
||||
f'TPU backend initialization is taking more than {timer_secs} seconds. '
|
||||
'Did you run your code on all TPU hosts? '
|
||||
'See https://jax.readthedocs.io/en/latest/multi_process.html '
|
||||
'for more information.')
|
||||
|
||||
# Will log a warning after `timer_secs`.
|
||||
t = threading.Timer(timer_secs, _log_warning)
|
||||
t.start()
|
||||
|
||||
try:
|
||||
client = xla_client.make_tpu_client()
|
||||
finally:
|
||||
t.cancel()
|
||||
|
||||
return client
|
||||
|
||||
|
||||
# Backends, in increasing order of preference.
|
||||
# We have no particular opinion about how "backends" relate to "devices". For
|
||||
# example, there could be multiple backends that provide the same kind of
|
||||
# device.
|
||||
_backend_factories = {}
|
||||
|
||||
def register_backend_factory(name, factory, *, priority=0):
|
||||
_backend_factories[name] = (factory, priority)
|
||||
|
||||
|
||||
register_backend_factory('interpreter', xla_client.make_interpreter_client,
|
||||
priority=-100)
|
||||
register_backend_factory('cpu',
|
||||
partial(xla_client.make_cpu_client, use_tfrt=True),
|
||||
priority=0)
|
||||
register_backend_factory('tpu_driver', _make_tpu_driver_client,
|
||||
priority=100)
|
||||
register_backend_factory('gpu', xla_client.make_gpu_client,
|
||||
priority=200)
|
||||
register_backend_factory(
|
||||
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
|
||||
|
||||
_default_backend = None
|
||||
_backends = None
|
||||
_backends_errors = None
|
||||
_backend_lock = threading.Lock()
|
||||
|
||||
|
||||
def backends():
|
||||
global _backends
|
||||
global _backends_errors
|
||||
global _default_backend
|
||||
|
||||
with _backend_lock:
|
||||
if _backends is not None:
|
||||
return _backends
|
||||
|
||||
default_priority = -1000
|
||||
_backends = {}
|
||||
_backends_errors = {}
|
||||
for name, (factory, priority) in _backend_factories.items():
|
||||
logging.vlog(1, "Initializing backend '%s'" % name)
|
||||
try:
|
||||
backend = factory()
|
||||
if backend is not None:
|
||||
if backend.device_count() > 0:
|
||||
_backends[name] = backend
|
||||
util.distributed_debug_log(("Initialized backend", backend.platform),
|
||||
("process_index", backend.process_index()),
|
||||
("device_count", backend.device_count()),
|
||||
("local_devices", backend.local_devices()))
|
||||
logging.vlog(1, "Backend '%s' initialized" % name)
|
||||
if priority > default_priority:
|
||||
_default_backend = backend
|
||||
default_priority = priority
|
||||
except Exception as err:
|
||||
if name in ('cpu', 'interpreter'):
|
||||
# We always expect the CPU and interpreter backends to initialize
|
||||
# successfully.
|
||||
raise
|
||||
else:
|
||||
# If the backend isn't built into the binary, or if it has no devices,
|
||||
# we expect a RuntimeError.
|
||||
logging.info("Unable to initialize backend '%s': %s" % (name, err))
|
||||
_backends_errors[name] = str(err)
|
||||
continue
|
||||
if _default_backend.platform == "cpu" and FLAGS.jax_platform_name != 'cpu':
|
||||
logging.warning('No GPU/TPU found, falling back to CPU. '
|
||||
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
|
||||
return _backends
|
||||
|
||||
|
||||
def _get_backend_uncached(platform=None):
|
||||
# TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
|
||||
# 'backend' values are handled
|
||||
if not isinstance(platform, (type(None), str)):
|
||||
return platform
|
||||
|
||||
bs = backends()
|
||||
platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
|
||||
or None)
|
||||
if platform is not None:
|
||||
backend = bs.get(platform, None)
|
||||
if backend is None:
|
||||
if platform in _backends_errors:
|
||||
raise RuntimeError(f"Requested backend {platform}, but it failed "
|
||||
f"to initialize: {_backends_errors[platform]}")
|
||||
raise RuntimeError(f"Unknown backend {platform}")
|
||||
return backend
|
||||
else:
|
||||
return _default_backend
|
||||
|
||||
|
||||
@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
|
||||
def get_backend(platform=None):
|
||||
return _get_backend_uncached(platform)
|
||||
|
||||
|
||||
def get_device_backend(device=None):
|
||||
"""Returns the Backend associated with `device`, or the default Backend."""
|
||||
if device is not None:
|
||||
return device.client
|
||||
return get_backend()
|
||||
|
||||
|
||||
def device_count(backend: Optional[str] = None) -> int:
|
||||
"""Returns the total number of devices.
|
||||
|
||||
On most platforms, this is the same as :py:func:`jax.local_device_count`.
|
||||
However, on multi-process platforms where different devices are associated
|
||||
with different processes, this will return the total number of devices across
|
||||
all processes.
|
||||
|
||||
Args:
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
Number of devices.
|
||||
|
||||
"""
|
||||
return int(get_backend(backend).device_count())
|
||||
|
||||
|
||||
def local_device_count(backend: Optional[str] = None) -> int:
|
||||
"""Returns the number of devices addressable by this process."""
|
||||
return int(get_backend(backend).local_device_count())
|
||||
|
||||
|
||||
def devices(backend: Optional[str] = None) -> List[xla_client.Device]:
|
||||
"""Returns a list of all devices for a given backend.
|
||||
|
||||
Each device is represented by a subclass of :class:`Device` (e.g.
|
||||
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
|
||||
equal to ``device_count(backend)``. Local devices can be identified by
|
||||
comparing :meth:`Device.process_index` to the value returned by
|
||||
:py:func:`jax.process_index`.
|
||||
|
||||
If ``backend`` is ``None``, returns all the devices from the default backend.
|
||||
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
|
||||
otherwise ``'cpu'``.
|
||||
|
||||
Args:
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
List of Device subclasses.
|
||||
"""
|
||||
return get_backend(backend).devices()
|
||||
|
||||
|
||||
def default_backend() -> str:
|
||||
"""Returns the platform name of the default XLA backend."""
|
||||
return get_backend(None).platform
|
||||
|
||||
|
||||
def local_devices(process_index: Optional[int] = None,
|
||||
backend: Optional[str] = None,
|
||||
host_id: Optional[int] = None) -> List[xla_client.Device]:
|
||||
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
|
||||
|
||||
If ``process_index`` is ``None``, returns devices local to this process.
|
||||
|
||||
Args:
|
||||
process_index: the integer index of the process. Process indices can be
|
||||
retrieved via ``len(jax.process_count())``.
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
List of Device subclasses.
|
||||
"""
|
||||
if host_id is not None:
|
||||
warnings.warn(
|
||||
"The argument to jax.local_devices has been renamed from `host_id` to "
|
||||
"`process_index`. This alias will eventually be removed; please update "
|
||||
"your code.")
|
||||
process_index = host_id
|
||||
if process_index is None:
|
||||
process_index = get_backend(backend).process_index()
|
||||
if not (0 <= process_index < process_count()):
|
||||
raise ValueError(f"Unknown process_index {process_index}")
|
||||
return [d for d in devices(backend) if d.process_index == process_index]
|
||||
|
||||
|
||||
def process_index(backend: Optional[str] = None) -> int:
|
||||
"""Returns the integer process index of this process.
|
||||
|
||||
On most platforms, this will always be 0. This will vary on multi-process
|
||||
platforms though.
|
||||
|
||||
Args:
|
||||
backend: This is an experimental feature and the API is likely to change.
|
||||
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
|
||||
``'tpu'``.
|
||||
|
||||
Returns:
|
||||
Integer process index.
|
||||
"""
|
||||
return get_backend(backend).process_index()
|
||||
|
||||
|
||||
# TODO: remove this sometime after jax 0.2.13 is released
|
||||
def host_id(backend=None):
|
||||
warnings.warn(
|
||||
"jax.host_id has been renamed to jax.process_index. This alias "
|
||||
"will eventually be removed; please update your code.")
|
||||
return process_index(backend)
|
||||
|
||||
|
||||
def process_count(backend: Optional[str] = None) -> int:
|
||||
"""Returns the number of JAX processes associated with the backend."""
|
||||
return max(d.process_index for d in devices(backend)) + 1
|
||||
|
||||
|
||||
# TODO: remove this sometime after jax 0.2.13 is released
|
||||
def host_count(backend=None):
|
||||
warnings.warn(
|
||||
"jax.host_count has been renamed to jax.process_count. This alias "
|
||||
"will eventually be removed; please update your code.")
|
||||
return process_count(backend)
|
||||
|
||||
|
||||
# TODO: remove this sometime after jax 0.2.13 is released
|
||||
def host_ids(backend=None):
|
||||
warnings.warn(
|
||||
"jax.host_ids has been deprecated; please use range(jax.process_count()) "
|
||||
"instead. jax.host_ids will eventually be removed; please update your "
|
||||
"code.")
|
||||
return list(range(process_count(backend)))
|
||||
|
||||
|
||||
### utility functions
|
||||
|
||||
@util.memoize
|
||||
def dtype_to_etype(dtype):
|
||||
"""Convert from dtype to canonical etype (reading config.x64_enabled)."""
|
||||
return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))
|
||||
|
||||
|
||||
@util.memoize
|
||||
def supported_numpy_dtypes():
|
||||
return {dtypes.canonicalize_dtype(dtype)
|
||||
for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}
|
||||
|
||||
|
||||
# TODO(mattjj,frostig): try to remove this function
|
||||
def normalize_to_xla_dtypes(val):
|
||||
"""Normalize dtypes in a value."""
|
||||
if hasattr(val, '__array__') or np.isscalar(val):
|
||||
return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
|
||||
elif isinstance(val, (tuple, list)):
|
||||
return tuple(normalize_to_xla_dtypes(x) for x in val)
|
||||
raise TypeError('Can\'t convert to XLA: {}'.format(val))
|
||||
|
||||
def _numpy_array_constant(builder, value, canonicalize_types=True):
|
||||
if canonicalize_types:
|
||||
value = normalize_to_xla_dtypes(value)
|
||||
return [xops.ConstantLiteral(builder, value)]
|
||||
|
||||
def parameter(builder, num, shape, name=None, replicated=None):
|
||||
if name is None:
|
||||
name = ''
|
||||
if replicated is None:
|
||||
replicated = []
|
||||
elif isinstance(replicated, bool):
|
||||
replicated = [replicated] * shape.leaf_count()
|
||||
|
||||
return xops.Parameter(builder, num,
|
||||
shape.with_major_to_minor_layout_if_absent(), name,
|
||||
replicated)
|
||||
|
||||
|
||||
def constant_general(builder, py_val, canonicalize_types=True):
|
||||
"""Translate a general constant `py_val` to a constant, canonicalizing its dtype.
|
||||
|
||||
Args:
|
||||
py_val: a Python value to be translated to a constant.
|
||||
|
||||
Returns:
|
||||
A representation of the constant as a list of xla ops.
|
||||
"""
|
||||
for t in type(py_val).mro():
|
||||
handler = _constant_handlers.get(t)
|
||||
if handler: return handler(builder, py_val, canonicalize_types)
|
||||
if hasattr(py_val, '__jax_array__'):
|
||||
return constant(builder, py_val.__jax_array__(), canonicalize_types)
|
||||
raise TypeError("No constant handler for type: {}".format(type(py_val)))
|
||||
|
||||
def constant(builder, py_val, canonicalize_types=True):
|
||||
"""Translate constant `py_val` to a constant, canonicalizing its dtype.
|
||||
|
||||
Args:
|
||||
py_val: a Python value to be translated to a constant.
|
||||
|
||||
Returns:
|
||||
A representation of the constant, either a ComputationDataHandle or None
|
||||
"""
|
||||
const = constant_general(builder, py_val, canonicalize_types=canonicalize_types)
|
||||
assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
|
||||
return const[0]
|
||||
|
||||
# HLO instructions optionally can be annotated to say how the output should be
|
||||
# spatially partitioned (represented in XLA as OpSharding protos, see
|
||||
# _sharding_to_proto). For array outputs, the annotation is either an int per
|
||||
# dimension specifying the number of ways that dimension divided (i.e. the total
|
||||
# number of shards is the product), or None to indicate the array should be
|
||||
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
|
||||
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
|
||||
# checkers don't support recursive types), so we only represent one level of
|
||||
# nesting in this type definition.
|
||||
SpatialSharding = Union[Tuple[int, ...],
|
||||
None,
|
||||
Tuple[Union[Tuple[int, ...], None], ...]]
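To make the encoding described in the comment above concrete, a small illustrative value (the variable name is only for exposition):

# Hedged sketch of SpatialSharding values per the comment above:
#   (2, 1)          - a 2-D array output split 2 ways along dim 0, unsplit on dim 1
#   None            - an output that should be replicated
#   ((2, 1), None)  - a tuple output: first element tiled, second replicated
example_sharding: SpatialSharding = ((2, 1), None)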
|
||||
|
||||
def _sharding_to_proto(sharding: SpatialSharding):
|
||||
"""Converts a SpatialSharding to an OpSharding.
|
||||
|
||||
See
|
||||
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
|
||||
for details on the OpSharding proto.
|
||||
"""
|
||||
proto = xla_client.OpSharding()
|
||||
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
|
||||
assert all(s is None or isinstance(s, tuple) for s in sharding)
|
||||
return tuple_sharding_proto(list(map(_sharding_to_proto, sharding))) # type: ignore
|
||||
|
||||
if sharding is None:
|
||||
proto.type = xla_client.OpSharding.Type.REPLICATED
|
||||
else:
|
||||
proto.type = xla_client.OpSharding.Type.OTHER
|
||||
proto.tile_assignment_dimensions = list(sharding)
|
||||
proto.tile_assignment_devices = list(range(np.product(sharding)))
|
||||
return proto
|
||||
|
||||
def tuple_sharding_proto(elems):
|
||||
proto = xla_client.OpSharding()
|
||||
assert all(isinstance(e, type(proto)) for e in elems)
|
||||
proto.type = xla_client.OpSharding.Type.TUPLE
|
||||
proto.tuple_shardings = elems
|
||||
return proto
|
||||
|
||||
def set_sharding_proto(builder, op, sharding_proto):
|
||||
"""Uses CustomCall to annotate a value as sharded."""
|
||||
# "Sharding" is a built-in custom call target that acts like an identity
|
||||
# function, and is used to attach an OpSharding to.
|
||||
return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
|
||||
builder, b"Sharding", [op], builder.get_shape(op))
|
||||
|
||||
def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
|
||||
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
|
||||
builder.set_sharding(sharding_proto)
|
||||
try:
|
||||
return op_fn(*args, **kwargs)
|
||||
finally:
|
||||
builder.clear_sharding()
|
||||
|
||||
def set_sharding(builder, op, sharding: SpatialSharding):
|
||||
"""Uses CustomCall to annotate a value as sharded."""
|
||||
return set_sharding_proto(builder, op, _sharding_to_proto(sharding))
|
||||
|
||||
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
|
||||
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
|
||||
return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
|
||||
|
||||
def make_computation_builder(name):
|
||||
return xla_client.XlaBuilder(name)
|
||||
|
||||
|
||||
def register_constant_handler(type_, handler_fun):
|
||||
_constant_handlers[type_] = handler_fun
|
||||
_constant_handlers: Dict[type, Callable] = {}
|
||||
|
||||
|
||||
def _ndarray_constant_handler(c, val, canonicalize_types=True):
|
||||
"""Constant handler for ndarray literals, handling zero-size strides.
|
||||
|
||||
This function essentially calls _numpy_array_constant(val) except it has
|
||||
special handling of arrays with any strides of size zero: for those, it
|
||||
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
|
||||
to avoid staging in large literals that might arise from np.zeros or np.ones
|
||||
or the output of lax.broadcast (which uses np.broadcast_to which in turn
|
||||
uses size-zero strides).
|
||||
|
||||
Args:
|
||||
c: an XlaBuilder
|
||||
val: an ndarray.
|
||||
|
||||
Returns:
|
||||
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
|
||||
staged into the XLA Computation.
|
||||
"""
|
||||
# TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
|
||||
if dtypes.result_type(val) == dtypes.float0:
|
||||
return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
|
||||
elif np.any(np.equal(0, val.strides)) and val.size > 0:
|
||||
zero_stride_axes, = np.where(np.equal(0, val.strides))
|
||||
other_axes, = np.where(np.not_equal(0, val.strides))
|
||||
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
|
||||
for ax in range(val.ndim))]
|
||||
xla_val = xops.Broadcast(
|
||||
_numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
|
||||
np.take(val.shape, zero_stride_axes))
|
||||
permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
|
||||
return [xops.Transpose(xla_val, permutation)]
|
||||
else:
|
||||
return _numpy_array_constant(c, val, canonicalize_types)
|
||||
register_constant_handler(np.ndarray, _ndarray_constant_handler)
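A small sketch of the zero-stride case the handler above guards against (this relies only on NumPy behavior):

# np.broadcast_to produces an array with zero strides; the handler collapses
# those axes and re-broadcasts in XLA rather than staging the full literal.
big = np.broadcast_to(np.float32(1.0), (1024, 1024))
assert 0 in big.strides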
|
||||
|
||||
|
||||
def _scalar_constant_handler(c, val, canonicalize_types=True):
|
||||
return _numpy_array_constant(c, val, canonicalize_types)
|
||||
|
||||
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
np.float16, np.float32, np.float64,
|
||||
np.bool_, np.longlong,
|
||||
xla_client.bfloat16]:
|
||||
register_constant_handler(scalar_type, _scalar_constant_handler)
|
||||
|
||||
# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
|
||||
if hasattr(np, "float128"):
|
||||
register_constant_handler(np.float128, _scalar_constant_handler)
|
||||
|
||||
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
|
||||
return _numpy_array_constant(c, dtype.type(val))
|
||||
|
||||
for ptype, dtype in dtypes.python_scalar_dtypes.items():
|
||||
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
|
||||
# flake8: noqa: F401
|
||||
from jax._src.lib.xla_bridge import (
|
||||
constant as constant,
|
||||
default_backend as default_backend,
|
||||
device_count as device_count,
|
||||
get_backend as get_backend,
|
||||
get_compile_options as get_compile_options,
|
||||
local_device_count as local_device_count,
|
||||
process_index as process_index,
|
||||
register_constant_handler as register_constant_handler,
|
||||
xla_client as xla_client,
|
||||
_backends as _backends,
|
||||
_python_scalar_handler as _python_scalar_handler,
|
||||
)
|
||||
|
@ -37,7 +37,7 @@ from . import lax
from ._src.config import flags, bool_env, config
from ._src.util import prod, unzip2
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from jax._src.lib import xla_bridge
from .interpreters import xla
from .experimental.maps import mesh


@ -72,7 +72,7 @@ from absl import app
from absl import flags
import jax
import jax.numpy as jnp
from jax.lib import xla_client
from jax._src.lib import xla_client

FLAGS = flags.FLAGS


@ -17,7 +17,7 @@ cusparse wrappers for performing sparse matrix computations in JAX

import numpy as np

from jax.lib import xla_client
from jax._src.lib import xla_client

try:
  from . import _cusparse
@ -47,7 +47,8 @@ from jax.interpreters import ad
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters.sharded_jit import PartitionSpec as P
|
||||
from jax.lib import xla_bridge as xb
|
||||
import jax._src.lib
|
||||
from jax._src.lib import xla_client
|
||||
from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax import linear_util as lu
|
||||
@ -199,7 +200,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
assert len(side) == 3
|
||||
|
||||
def test_jit_device(self):
|
||||
device = xb.devices()[-1]
|
||||
device = jax.devices()[-1]
|
||||
x = self.jit(lambda x: x, device=device)(3.)
|
||||
self.assertIsInstance(x, xla.DeviceArray)
|
||||
self.assertEqual(x.device_buffer.device(), device)
|
||||
@ -663,7 +664,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
np.testing.assert_allclose(f_pruned(*args), 3)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
@unittest.skipIf(jax.lib._xla_extension_version <= 36,
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version <= 36,
|
||||
"Test requires jaxlib 0.1.71")
|
||||
def testBuffersAreFreedPromptly(self):
|
||||
# Regression test for a bug where garbage collection was delayed too long
|
||||
@ -1757,7 +1758,7 @@ class APITest(jtu.JaxTestCase):
|
||||
param_shapes = c.program_shape().parameter_shapes()
|
||||
self.assertEqual(len(param_shapes), 1)
|
||||
self.assertEqual(param_shapes[0].xla_element_type(),
|
||||
xb.xla_client.PrimitiveType.TUPLE)
|
||||
xla_client.PrimitiveType.TUPLE)
|
||||
|
||||
def test_xla_computation_duck_typing(self):
|
||||
def foo(x, y, z):
|
||||
@ -1774,7 +1775,7 @@ class APITest(jtu.JaxTestCase):
|
||||
param_shapes = c.program_shape().parameter_shapes()
|
||||
self.assertEqual(len(param_shapes), 1)
|
||||
self.assertEqual(param_shapes[0].xla_element_type(),
|
||||
xb.xla_client.PrimitiveType.TUPLE)
|
||||
xla_client.PrimitiveType.TUPLE)
|
||||
|
||||
def test_staging_out_multi_replica(self):
|
||||
def f(x):
|
||||
|
@ -19,7 +19,7 @@ from absl.testing import absltest, parameterized
|
||||
import jax
|
||||
from jax.config import config
|
||||
import jax.dlpack
|
||||
from jax.lib import xla_bridge, xla_client
|
||||
from jax._src.lib import xla_bridge, xla_client
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
|
||||
|
@ -12,9 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from functools import partial
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import SkipTest
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
from jax.experimental.maps import xmap
|
||||
@ -23,12 +29,8 @@ import jax
|
||||
from jax import jit, lax, pmap
|
||||
from jax._src.util import prod
|
||||
import jax.test_util as jtu
|
||||
import jax._src.lib
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import SkipTest
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -40,16 +42,16 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
if jtu.device_under_test() != "tpu":
|
||||
raise SkipTest("serialize executable only works on TPU")
|
||||
if jax.lib.xla_bridge.get_backend().runtime_type == "tfrt":
|
||||
if jax._src.lib.xla_bridge.get_backend().runtime_type == "tfrt":
|
||||
raise SkipTest("the new TFRT runtime does not support serialization")
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
cc._cache = None
|
||||
|
||||
@unittest.skipIf(jax.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
|
||||
@unittest.skipIf(jax._src.lib.version < (0, 1, 68), "fails with earlier jaxlibs")
|
||||
def test_compile_options(self):
|
||||
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options_not_filled = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
compile_options_filled = self.filled_compile_options()
|
||||
filled_hash1 = self.get_hashed_value(cc._hash_compile_options, compile_options_filled)
|
||||
@ -59,7 +61,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(filled_hash1, not_filled_hash3)
|
||||
|
||||
def test_executable_build_options(self):
|
||||
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options_not_filled = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
compile_options_filled = self.filled_compile_options()
|
||||
filled_hash1 = self.get_hashed_value(cc._hash_executable_build_options,
|
||||
@ -72,7 +74,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(filled_hash1, not_filled_hash3)
|
||||
|
||||
def test_debug_options(self):
|
||||
compile_options = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
hash1 = self.get_hashed_value(cc._hash_debug_options,
|
||||
compile_options.executable_build_options.debug_options)
|
||||
@ -84,11 +86,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(hash1, hash3)
|
||||
|
||||
def test_hash_platform(self):
|
||||
hash1 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
|
||||
hash2 = self.get_hashed_value(cc._hash_platform, jax.lib.xla_bridge.get_backend())
|
||||
hash1 = self.get_hashed_value(cc._hash_platform, jax._src.lib.xla_bridge.get_backend())
|
||||
hash2 = self.get_hashed_value(cc._hash_platform, jax._src.lib.xla_bridge.get_backend())
|
||||
self.assertEqual(hash1, hash2)
|
||||
if jax.lib.xla_bridge.get_backend().platform != "cpu":
|
||||
cpu_backend = jax.lib.xla_bridge.get_backend("cpu")
|
||||
if jax._src.lib.xla_bridge.get_backend().platform != "cpu":
|
||||
cpu_backend = jax._src.lib.xla_bridge.get_backend("cpu")
|
||||
hash3 = self.get_hashed_value(cc._hash_platform, cpu_backend)
|
||||
self.assertNotEqual(hash1, hash3)
|
||||
|
||||
@ -115,27 +117,27 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
|
||||
def test_same_hash_key(self):
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
compile_options = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
self.assertEqual(cc.get_cache_key(computation, compile_options, backend),
|
||||
cc.get_cache_key(computation, compile_options, backend))
|
||||
|
||||
def test_different_hash_key(self):
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options_not_filled = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
compile_options_filled = self.filled_compile_options()
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
self.assertNotEqual(cc.get_cache_key(computation, compile_options_not_filled, backend),
|
||||
cc.get_cache_key(computation, compile_options_filled, backend))
|
||||
|
||||
def test_different_computations(self):
|
||||
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
|
||||
compile_options = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
self.assertNotEqual(cc.get_cache_key(computation1, compile_options, backend),
|
||||
cc.get_cache_key(computation2, compile_options, backend))
|
||||
|
||||
@ -143,9 +145,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
compile_options = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
self.assertEqual(cc.get_executable(computation, compile_options, backend), None)
|
||||
|
||||
def test_diff_executables(self):
|
||||
@ -153,9 +155,9 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
|
||||
compile_options = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
executable1 = backend.compile(computation1, compile_options)
|
||||
executable2 = backend.compile(computation2, compile_options)
|
||||
cc.put_executable(computation1, compile_options, executable1, backend)
|
||||
@ -167,15 +169,15 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
|
||||
compile_options = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = jax._src.lib.xla_bridge.get_backend()
|
||||
executable = backend.compile(computation, compile_options)
|
||||
cc.put_executable(computation, compile_options, executable, backend)
|
||||
deserialized_executable = cc.get_executable(computation, compile_options, backend)
|
||||
inputs_to_executable = (np.array(1, dtype=np.int32), np.array(2, dtype=np.int32))
|
||||
expected = jax.lib.xla_client.execute_with_python_values(executable, inputs_to_executable, backend)
|
||||
actual = jax.lib.xla_client.execute_with_python_values(deserialized_executable, inputs_to_executable, backend)
|
||||
expected = jax._src.lib.xla_client.execute_with_python_values(executable, inputs_to_executable, backend)
|
||||
actual = jax._src.lib.xla_client.execute_with_python_values(deserialized_executable, inputs_to_executable, backend)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_pmap(self):
|
||||
@ -257,15 +259,15 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
return debug_options_obj
|
||||
|
||||
def filled_compile_options(self):
|
||||
compile_options = jax.lib.xla_client.CompileOptions()
|
||||
compile_options = jax._src.lib.xla_client.CompileOptions()
|
||||
compile_options.num_replicas = 1
|
||||
compile_options.num_partitions = 1
|
||||
shape = jax.lib.xla_client.Shape.array_shape(np.dtype(np.float32), [2])
|
||||
shape = jax._src.lib.xla_client.Shape.array_shape(np.dtype(np.float32), [2])
|
||||
shape_array = [shape, shape]
|
||||
compile_options.argument_layouts = shape_array
|
||||
compile_options.executable_build_options.result_layout = shape
|
||||
|
||||
device_assignment = jax.lib.xla_client.DeviceAssignment.create(np.ndarray(shape=(2,2)))
|
||||
device_assignment = jax._src.lib.xla_client.DeviceAssignment.create(np.ndarray(shape=(2,2)))
|
||||
compile_options.device_assignment = device_assignment
|
||||
compile_options.executable_build_options.device_assignment = device_assignment
|
||||
return compile_options
|
||||
|
@ -20,7 +20,7 @@ from jax import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
from jax import core, jit, lax, make_jaxpr
|
||||
from jax.interpreters import xla
|
||||
from jax.lib import xla_bridge, xla_client
|
||||
from jax._src.lib import xla_bridge, xla_client
|
||||
xops = xla_client.ops
|
||||
xb = xla_bridge
|
||||
|
||||
|
@ -24,6 +24,7 @@ from jax._src import api
|
||||
from jax import test_util as jtu
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import pjit
|
||||
import jax._src.lib
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -97,7 +98,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
|
||||
def testPmap(self):
|
||||
pmap_funcs = [api._python_pmap]
|
||||
if jax.lib._xla_extension_version >= 36:
|
||||
if jax._src.lib._xla_extension_version >= 36:
|
||||
pmap_funcs.append(api._cpp_pmap)
|
||||
|
||||
for pmap in pmap_funcs:
|
||||
|
@ -16,6 +16,7 @@ import unittest
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax._src.lib.xla_bridge
|
||||
from jax.config import config
|
||||
import jax.test_util as jtu
|
||||
|
||||
@ -28,7 +29,7 @@ class HeapProfilerTest(unittest.TestCase):
|
||||
# not check functional correctness.
|
||||
|
||||
def testBasics(self):
|
||||
client = jax.lib.xla_bridge.get_backend()
|
||||
client = jax._src.lib.xla_bridge.get_backend()
|
||||
_ = client.heap_profile()
|
||||
|
||||
a = jax.device_put(1)
|
||||
|
@ -37,7 +37,7 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax.lib import xla_bridge
|
||||
from jax._src.lib import xla_bridge
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -162,7 +162,7 @@ def helper_set_hlo_dump():
|
||||
|
||||
|
||||
def helper_print_optimized_hlo(fun, *args):
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = xla_bridge.get_backend()
|
||||
c = jax.xla_computation(fun)(*args)
|
||||
print(re.sub(r", metadata.*", "",
|
||||
backend.compile(c).hlo_modules()[0].to_string()))
|
||||
@ -177,13 +177,13 @@ def helper_log_ir(name,
|
||||
jax_comp = jax.xla_computation(f_jax)(*args)
|
||||
print(f"HLO[{name}]: {jax_comp.as_hlo_text()}")
|
||||
|
||||
backend = jax.lib.xla_bridge.get_backend()
|
||||
backend = xla_bridge.get_backend()
|
||||
if num_partitions is not None:
|
||||
num_replicas = 1
|
||||
device_assignment = np.arange(num_partitions * num_replicas)
|
||||
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
|
||||
use_spmd_partitioning = num_partitions > 1
|
||||
compile_options = jax.lib.xla_bridge.get_compile_options(
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=num_replicas,
|
||||
num_partitions=num_partitions,
|
||||
device_assignment=device_assignment,
|
||||
|
@ -20,7 +20,7 @@ import jax
|
||||
from jax import lax, numpy as jnp
|
||||
from jax import config
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
import jax.test_util as jtu
|
||||
import numpy as np
|
||||
|
||||
|
@ -19,7 +19,7 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax import dtypes
|
||||
from jax import lib as jaxlib
|
||||
from jax._src import lib as jaxlib
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
import jax.numpy as jnp
|
||||
from jax.tools.jax_to_hlo import jax_to_hlo
|
||||
from jax import test_util as jtu
|
||||
|
@ -36,7 +36,6 @@ from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src.util import unzip2
|
||||
from jax.experimental import maps
|
||||
from jax.lib import xla_bridge
|
||||
from jax.interpreters import xla
|
||||
import jax.numpy as jnp # scan tests use numpy
|
||||
import jax.scipy as jsp
|
||||
@ -1744,7 +1743,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
jax.jit(jax.jacfwd(loop, argnums=(0,)))(arg) # doesn't crash
|
||||
|
||||
def testIssue804(self):
|
||||
num_devices = xla_bridge.device_count()
|
||||
num_devices = jax.device_count()
|
||||
f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
|
||||
jax.pmap(f, axis_name="i")(jnp.ones((num_devices, 4))) # doesn't crash
|
||||
|
||||
|
@ -27,7 +27,7 @@ import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
|
||||
from lax_test import LAX_OPS
|
||||
|
@ -25,7 +25,6 @@ from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import jax
|
||||
import jax.lib
|
||||
from jax import jit, grad, jvp, vmap
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
|
@ -22,7 +22,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax.lib import xla_bridge
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax.config import config
|
||||
|
@ -27,6 +27,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax import test_util as jtu
|
||||
import jax._src.lib
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -34,7 +35,7 @@ config.parse_flags_with_absl()
|
||||
class CloudpickleTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
|
||||
@unittest.skipIf(jax.lib._xla_extension_version < 31,
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version < 31,
|
||||
"Requires jaxlib 0.1.71")
|
||||
def testPickleOfJittedFunctions(self):
|
||||
|
||||
@ -55,7 +56,7 @@ class CloudpickleTest(jtu.JaxTestCase):
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
|
||||
@unittest.skipIf(jax.lib._xla_extension_version < 39,
|
||||
@unittest.skipIf(jax._src.lib._xla_extension_version < 39,
|
||||
"Requires jaxlib 0.1.72")
|
||||
def testPickleOfPmappedFunctions(self):
|
||||
|
||||
|
@ -33,7 +33,7 @@ from jax.experimental.maps import xmap, mesh
|
||||
from jax.experimental.pjit import pjit, pjit_p, with_sharding_constraint, SpecSync
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.lib import xla_client
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import prod, curry
|
||||
|
||||
from jax.config import config
|
||||
|
@ -39,7 +39,8 @@ from jax import random
|
||||
from jax.core import ShapedArray
|
||||
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
linearize, device_put)
|
||||
from jax.lib import xla_bridge
|
||||
import jax._src.lib
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.util import prod, safe_map
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
@ -113,7 +114,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
return src_api._python_pmap
|
||||
|
||||
def _getMeshShape(self, device_mesh_shape):
|
||||
device_count = xla_bridge.device_count()
|
||||
device_count = jax.device_count()
|
||||
if any(size == -1 for size in device_mesh_shape):
|
||||
try:
|
||||
return np.arange(device_count).reshape(device_mesh_shape).shape
|
||||
@ -130,7 +131,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def testBasic(self):
|
||||
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
|
||||
|
||||
shape = (xla_bridge.device_count(), 4)
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = x - np.sum(x, 0)
|
||||
|
||||
@ -140,7 +141,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def testMean(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
|
||||
shape = (xla_bridge.device_count(), 4)
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = x - np.broadcast_to(np.mean(x, 0), x.shape)
|
||||
|
||||
@ -150,16 +151,16 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def testGather(self):
|
||||
f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')
|
||||
|
||||
shape = (xla_bridge.device_count(), 4)
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = np.array([x] * xla_bridge.device_count())
|
||||
expected = np.array([x] * jax.device_count())
|
||||
ans = f(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testGatherTiled(self):
|
||||
f = self.pmap(lambda x: lax.all_gather(x, 'i', tiled=True), axis_name='i')
|
||||
|
||||
device_count = xla_bridge.device_count()
|
||||
device_count = jax.device_count()
|
||||
shape = (device_count, 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = np.array([x] * device_count).reshape(device_count, -1)
|
||||
@ -169,7 +170,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def testReduceScatter(self):
|
||||
f = self.pmap(lambda x: lax.psum_scatter(x, 'i'), axis_name='i')
|
||||
|
||||
device_count = xla_bridge.device_count()
|
||||
device_count = jax.device_count()
|
||||
shape = (device_count, device_count)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = np.sum(x, axis=0)
|
||||
@ -180,7 +181,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def testReduceScatterTiled(self):
|
||||
f = self.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')
|
||||
|
||||
device_count = xla_bridge.device_count()
|
||||
device_count = jax.device_count()
|
||||
shape = (device_count, 4 * device_count)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = np.sum(x, axis=0)
|
||||
@ -191,7 +192,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
expected[i * scatter_len:(i + 1) * scatter_len])
|
||||
|
||||
def testReduceScatterReplicaGroupsTiled(self):
|
||||
replicas = xla_bridge.device_count()
|
||||
replicas = jax.device_count()
|
||||
if replicas % 2 != 0:
|
||||
raise SkipTest
|
||||
axis_index_groups = [[i for i in range(jax.device_count()) if i % 2 == 0],
|
||||
@ -227,7 +228,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
np_transpose = tree_f(np.transpose)
|
||||
np_rotate = tree_f(lambda x: np.concatenate([x[-1:], x[:-1]]))
|
||||
|
||||
n = xla_bridge.device_count()
|
||||
n = jax.device_count()
|
||||
x = {'a': np.arange(1 * n * n, 2 * n * n).reshape([n, n]),
|
||||
'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]),
|
||||
'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])}
|
||||
@ -260,7 +261,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def testComplexPsum(self):
|
||||
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
|
||||
|
||||
shape = (xla_bridge.device_count(), 4 * 2)
|
||||
shape = (jax.device_count(), 4 * 2)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape).view(np.complex64)
|
||||
expected = x - np.sum(x, 0)
|
||||
|
||||
@ -274,7 +275,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testAllToAll(self, split_axis, concat_axis):
|
||||
pmap_in_axis = 0
|
||||
shape = (xla_bridge.device_count(),) * 3
|
||||
shape = (jax.device_count(),) * 3
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
|
||||
@partial(self.pmap, axis_name='i')
|
||||
@ -293,7 +294,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
for split_axis, concat_axis in it.product(range(2), range(2)))
|
||||
@ignore_slow_all_to_all_warning()
|
||||
def testAllToAllSplitAxis(self, split_axis, concat_axis):
|
||||
if xla_bridge.device_count() < 4:
|
||||
if jax.device_count() < 4:
|
||||
raise SkipTest("test requires at least four devices")
|
||||
pmap_in_axis = 0
|
||||
shape = (4, 4, 4)
|
||||
@ -322,7 +323,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def sum_and_broadcast(x, axis):
|
||||
return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis)
|
||||
|
||||
shape = (xla_bridge.device_count(), 1, 4)
|
||||
shape = (jax.device_count(), 1, 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
|
||||
ans = f(x)
|
||||
@ -330,7 +331,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testMismatchedAxisSizes(self):
|
||||
n = xla_bridge.device_count()
|
||||
n = jax.device_count()
|
||||
f = self.pmap(lambda x, y: x + y)
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -359,7 +360,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = self.pmap(lambda x, y: x, in_axes=(None, 0))
|
||||
g = self.pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0))
|
||||
|
||||
mesh_shape = (xla_bridge.device_count(),)
|
||||
mesh_shape = (jax.device_count(),)
|
||||
shape = mesh_shape + (4,)
|
||||
x = np.array(3., dtype=np.float32)
|
||||
y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
@ -383,7 +384,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
def testReplicate(self):
|
||||
base = np.array([3.,4.], dtype=np.float32)
|
||||
num_devices = xla_bridge.device_count()
|
||||
num_devices = jax.device_count()
|
||||
replicated = pxla.replicate(base, num_devices, num_devices, in_axis=None)
|
||||
self.assertAllClose(base, replicated)
|
||||
self.assertEmpty([a for a in replicated.sharding_spec.mesh_mapping
|
||||
@ -415,7 +416,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
_, jvp = linearize(f, x)
|
||||
return jvp(jnp.ones_like(x))
|
||||
|
||||
shape = (xla_bridge.device_count(), 4)
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = np.cos(x)
|
||||
|
||||
@ -429,7 +430,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
shape = (xla_bridge.device_count(), 4)
|
||||
shape = (jax.device_count(), 4)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
@ -456,7 +457,7 @@ class PythonPmapTest(jtu.JaxTestCase):
fun = lambda x: jnp.sum(jvp(jnp.sin, (x,), (jnp.ones_like(x),))[1])
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(splitjvp(x)))(x)
@ -472,7 +473,7 @@ class PythonPmapTest(jtu.JaxTestCase):
tot = jnp.sum(5. * jnp.cos(x) * jnp.sin(y))
return tot * jnp.ones_like(x) # broadcast to map like pjit does
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = 4 + x
ans = grad(lambda x, y: jnp.sum(g(x, y)))(x, y)
@ -517,7 +518,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
# test that we can pass in and out ShardedDeviceArrays
@ -560,7 +561,7 @@ class PythonPmapTest(jtu.JaxTestCase):
[(1,1), (1,)], [(1,), (1,1)], [(1,), ()], [(4,7), (2,2,7)]
])
def testShardedDeviceArrayReshape(self, in_shape, out_shape):
if xla_bridge.device_count() < max(in_shape[:1] + out_shape[:1]):
if jax.device_count() < max(in_shape[:1] + out_shape[:1]):
raise SkipTest("not enough devices")
x = np.arange(prod(in_shape)).reshape(in_shape)
@ -575,7 +576,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def sum_and_broadcast(x, axis):
return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis)
device_count = xla_bridge.device_count()
device_count = jax.device_count()
num_pairs, ragged = divmod(device_count, 2)
if num_pairs > 1 and not ragged:
shape = (num_pairs, 2, 4)
@ -588,7 +589,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsumConstantReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest
axis_index_groups = np.arange(replicas).reshape(
@ -606,7 +607,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def testPsumUnevenReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas <= 2:
raise SkipTest("Test expected devices greater than 2.")
axis_index_groups = [[0,1], np.arange(2,replicas)]
@ -627,7 +628,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsumReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest
axis_index_groups = np.arange(replicas).reshape(
@ -649,7 +650,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("Test expected an even number of devices greater than 1.")
@ -674,7 +675,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherReplicaGroupsInterleaved(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("Test expected an even number of devices greater than 1.")
@ -707,7 +708,7 @@ class PythonPmapTest(jtu.JaxTestCase):
jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
def testNestedPmapReplicaGroups(self):
replicas = xla_bridge.device_count()
replicas = jax.device_count()
if replicas % 4 != 0:
raise SkipTest
axis_index_groups = np.arange(replicas // 2).reshape(
@ -762,7 +763,7 @@ class PythonPmapTest(jtu.JaxTestCase):
((0, 1, 2, 3, 4, 5, 6, 7,),)) # order doesn't matter
def testCollectivePermute(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
rotation = [(i, (i + 1) % device_count) for i in range(device_count)]
f = lambda x: lax.ppermute(x, perm=rotation, axis_name='i')
f = self.pmap(f, 'i')
@ -774,7 +775,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@jtu.skip_on_devices("cpu")
def testCollectivePermuteGrad(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shift_right = [(i, (i + 1)) for i in range(device_count - 1)]
f = lambda x: lax.ppermute(x, perm=shift_right, axis_name='i')
y = np.pi + np.arange(device_count, dtype=np.float32)
@ -786,7 +787,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testCollectivePermuteCyclicGrad(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shift_right = [(i, (i + 1) % device_count) for i in range(device_count)]
f = lambda x: lax.ppermute(x, perm=shift_right, axis_name='i')
y = np.pi + np.arange(device_count, dtype=np.float32)
@ -801,7 +802,7 @@ class PythonPmapTest(jtu.JaxTestCase):
jtu.check_grads(g, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2)
def testCollectivePermuteCyclicWithPShuffle(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
values = np.arange(device_count)
shift_right = [(i - 1) % device_count for i in range(device_count)]
f = lambda x: lax.pshuffle(x, perm=shift_right, axis_name='i')
@ -810,7 +811,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPShuffleWithBadPerm(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
bad_perm = list(range(device_count))
bad_perm[0] = 1
f = lambda x: lax.pshuffle(x, perm=bad_perm, axis_name='i')
@ -821,7 +822,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testPpermuteWithZipObject(self):
# https://github.com/google/jax/issues/1703
num_devices = xla_bridge.device_count()
num_devices = jax.device_count()
perm = [num_devices - 1] + list(range(num_devices - 1))
f = self.pmap(lambda x: lax.ppermute(x, "i", zip(perm, range(num_devices))), "i")
result = f(jnp.arange(num_devices, dtype=jnp.float32))
@ -833,7 +834,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# to run a rule 30 simulation: https://en.wikipedia.org/wiki/Rule_30
# Halo exchange should be useful in spatially-sharded convolutions and in
# other simulations.
device_count = xla_bridge.device_count()
device_count = jax.device_count()
def send_right(x, axis_name):
left_perm = [(i, (i + 1) % device_count) for i in range(device_count)]
@ -900,7 +901,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testReduceMax(self):
f = self.pmap(lambda x: x - lax.pmax(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.max(x, 0)
@ -910,7 +911,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testReduceMin(self):
f = self.pmap(lambda x: x - lax.pmin(x, 'i'), axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.min(x, 0)
@ -918,7 +919,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testDeviceCountError(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: x)
x = jnp.arange(device_count + 1)
@ -933,7 +934,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
def testPmapConstant(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: 3)
x = jnp.arange(device_count)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
@ -949,10 +950,10 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testPmapConstantDevices(self):
if xla_bridge.device_count() == 1:
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
devices = xla_bridge.devices()[:-1]
devices = jax.devices()[:-1]
shuffle(devices)
f = self.pmap(lambda x: 3, devices=devices)
x = jnp.arange(len(devices))
@ -966,7 +967,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual([b.device() for b in ans.device_buffers], devices)
def testPmapConstantError(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: 3)
x = jnp.arange(device_count + 1)
self.assertRaisesRegex(
@ -976,18 +977,18 @@ class PythonPmapTest(jtu.JaxTestCase):
lambda: f(x))
# TODO(mattjj): test error message with explicit devices
# f = pmap(lambda x: 3, devices=[xla_bridge.devices()[0]])
# f = pmap(lambda x: 3, devices=[jax.devices()[0]])
# x = jnp.arange(2)
# self.assertRaisesRegex(
# ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
# r"local devices are available.", lambda: f(x))
def testNestedPmapConstant(self):
if xla_bridge.device_count() == 1:
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, xla_bridge.device_count() // 2, 3)
shape = (2, jax.device_count() // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
@ -1009,10 +1010,10 @@ class PythonPmapTest(jtu.JaxTestCase):
def testNestedPmapConstantDevices(self):
raise SkipTest("Nested pmaps with devices not yet implemented")
if xla_bridge.device_count() < 6:
if jax.device_count() < 6:
raise SkipTest("this test requires >= 6 devices")
devices = xla_bridge.devices()[:-2]
devices = jax.devices()[:-2]
shuffle(devices)
f = self.pmap(self.pmap(lambda x: 3), devices=devices)
shape = (2, len(devices) // 2, 3)
@ -1030,7 +1031,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testNestedPmapConstantError(self):
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, xla_bridge.device_count() // 2 + 1, 3)
shape = (2, jax.device_count() // 2 + 1, 3)
x = jnp.arange(prod(shape)).reshape(shape)
self.assertRaisesRegex(
ValueError,
@ -1039,9 +1040,9 @@ class PythonPmapTest(jtu.JaxTestCase):
lambda: f(x))
# TODO(mattjj): check error message with explicit devices
# if xla_bridge.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=xla_bridge.devices()[:-1])
# shape = (2, xla_bridge.device_count() // 2, 3)
# if jax.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=jax.devices()[:-1])
# shape = (2, jax.device_count() // 2, 3)
# x = jnp.arange(prod(shape)).reshape(shape)
# self.assertRaisesRegex(
# ValueError,
@ -1050,7 +1051,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# lambda: f(x))
def testCollectiveConstant(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: lax.psum(1, 'i'), 'i')
x = jnp.arange(device_count)
ans = f(x)
@ -1058,7 +1059,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testCollectiveConstantNested(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
@partial(self.pmap, axis_name='i')
def f(x):
@ -1083,7 +1084,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertEqual(c.ravel()[0], device_count * 1)
def testAxisIndex(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f = self.pmap(lambda x: x + lax.axis_index('i'), 'i')
x = jnp.ones(device_count)
ans = f(x)
@ -1091,7 +1092,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testAxisIndexNestedPmap(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count < 4:
raise SkipTest("test requires at least four devices")
f = lambda axis: self.pmap(self.pmap(lambda x: x + lax.axis_index(axis), 'j'), 'i')
@ -1101,7 +1102,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(f('i')(x), expected_j.T, check_dtypes=False)
def testAxisIndexNd(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count < 4:
raise SkipTest("test requires at least four devices")
f = lambda axes: self.pmap(self.pmap(lambda x: x + lax.axis_index(axes), 'j'), 'i')
@ -1116,13 +1117,13 @@ class PythonPmapTest(jtu.JaxTestCase):
def body(carry, i):
return carry + i + lax.axis_index('i'), None
return lax.scan(body, 0, x)[0]
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, 10)
self.assertAllClose(f(jnp.ones(shape, dtype=int)),
(np.arange(device_count) + 1) * 10)
def testVmapOfPmap(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f0 = lambda x: x
f1 = self.pmap(f0, axis_name='i')
ax = np.random.randn(2, device_count, 50, 60)
@ -1130,7 +1131,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ax, bx, check_dtypes=False)
def testVmapOfPmap2(self):
N_DEVICES = xla_bridge.device_count()
N_DEVICES = jax.device_count()
keys = random.split(random.PRNGKey(1), 13) # [13, 2]
@self.pmap
@ -1150,7 +1151,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testVmapOfPmap3(self):
# https://github.com/google/jax/issues/3399
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count < 2:
raise SkipTest("test requires at least two devices")
@ -1172,7 +1173,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testVmapOfPmapNonLeadingAxis(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f0 = lambda x: x
f1 = self.pmap(f0, axis_name='i')
ax = np.random.randn(device_count, 2, 50, 60)
@ -1180,7 +1181,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ax, bx, check_dtypes=False)
def testVmapOfPmapTuple(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
f0 = lambda *x: x
f1 = self.pmap(f0, axis_name='i')
@ -1201,7 +1202,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testPswapaxes(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, 3, device_count, 5)
x = np.arange(prod(shape)).reshape(shape)
@ -1211,7 +1212,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testGradOfPswapaxes(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (device_count, 1, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
w = np.arange(device_count, dtype=np.float32)
@ -1235,7 +1236,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# This is essentially like splitting the number of rows in the input in two
# groups of rows, and swaping the two inner axes (axis=1 and axis=2), which
# is exactly what the test case checks.
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2)
@ -1256,7 +1257,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testGradOfAllToAllReplicaGroups(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2, 1)
@ -1279,13 +1280,13 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(fn(x, w), expected, check_dtypes=False)
def testReshardInput(self):
if xla_bridge.device_count() < 6:
if jax.device_count() < 6:
raise SkipTest("testReshardInput requires 6 devices")
# Manually construct a ShardedDeviceArray with the wrong sharding for the
# subsequent pmap
shard_shape = (3,2)
shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
bufs = pxla.device_put(shard, xla_bridge.devices()[:4], replicate=True)
bufs = pxla.device_put(shard, jax.devices()[:4], replicate=True)
aval = ShapedArray((6,4), shard.dtype)
sharding_spec = pxla.ShardingSpec(
sharding=map(pxla.Chunked, ([2], [2])),
@ -1298,7 +1299,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapBatchMatmul(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
ans = soft_pmap(jnp.dot, 'i')(xs, ys)
@ -1307,7 +1308,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapBatchMatmulJit(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
ans = soft_pmap(jit(jnp.dot), 'i')(xs, ys)
@ -1316,7 +1317,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapPsumConstant(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(_):
return lax.psum(1, 'i')
ans = soft_pmap(f, 'i')(jnp.ones(n))
@ -1325,7 +1326,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapPsum(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(x):
return x / lax.psum(x, 'i')
ans = soft_pmap(f, 'i')(jnp.ones(n))
@ -1334,7 +1335,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapAxisIndex(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(x):
return x * lax.axis_index('i')
ans = soft_pmap(f, 'i')(2 * jnp.ones(n))
@ -1343,7 +1344,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapOfJit(self):
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(x):
return 3 * x
ans = soft_pmap(jit(f), 'i')(np.arange(n))
@ -1353,7 +1354,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapNested(self):
raise SkipTest("not implemented") # TODO(mattjj): re-implement
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
@partial(soft_pmap, axis_name='i')
@partial(soft_pmap, axis_name='j')
@ -1368,7 +1369,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testGradOfSoftPmap(self):
raise SkipTest("not implemented") # TODO(mattjj): re-implement
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
@partial(soft_pmap, axis_name='i')
def f(x):
@ -1380,7 +1381,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_xmap_warning()
def testSoftPmapDevicePersistence(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
shape = (2 * 2 * device_count, 2, 3)
# check that we can maintain device persistence across calls
@ -1393,7 +1394,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testSoftPmapAllToAll(self):
raise SkipTest("the underlying code here is broken") # TODO(mattjj)
n = 4 * xla_bridge.device_count()
n = 4 * jax.device_count()
def f(x):
return lax.all_to_all(x, 'i', 0, 0)
ans = soft_pmap(f, 'i')(jnp.arange(n ** 2).reshape(n, n))
@ -1401,7 +1402,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
def testShardedDeviceArrayBlockUntilReady(self):
x = np.arange(xla_bridge.device_count())
x = np.arange(jax.device_count())
x = self.pmap(lambda x: x)(x)
x.block_until_ready() # doesn't crash
@ -1409,7 +1410,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testJitPmapComposition(self):
f = lambda x: x - lax.psum(x, 'i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
@ -1435,7 +1436,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@ignore_jit_of_pmap_warning()
def testIssue1065(self):
# from https://github.com/google/jax/issues/1065
device_count = xla_bridge.device_count()
device_count = jax.device_count()
def multi_step_pmap(state, count):
@partial(self.pmap, axis_name='x')
@ -1455,7 +1456,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = f(x)
@ -1471,7 +1472,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# replica.
raise SkipTest("need eager multi-replica support")
# test came from https://github.com/google/jax/issues/1369
nrep = xla_bridge.device_count()
nrep = jax.device_count()
def pmvm(a, b):
a = a.reshape((nrep, -1, a.shape[1]))
@ -1497,7 +1498,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return sum(args_list)
vals = list(range(500))
ndevices = xla_bridge.device_count()
ndevices = jax.device_count()
self.assertAllClose(f(jnp.array([vals] * ndevices)),
jnp.array([sum(vals)] * ndevices))
@ -1577,7 +1578,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testPsumOnBooleanDtype(self):
# https://github.com/google/jax/issues/3123
n = xla_bridge.device_count()
n = jax.device_count()
if n > 1:
x = jnp.array([True, False])
@ -1608,7 +1609,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertTrue(w() is None)
def testJitOfPmapWarningMessage(self):
device_count = xla_bridge.device_count()
device_count = jax.device_count()
if device_count == 1:
raise SkipTest("test requires at least two devices")
@ -1651,7 +1652,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def test_issue_1062(self):
# code from https://github.com/google/jax/issues/1062 @shoyer
# this tests, among other things, whether ShardedDeviceTuple constants work
device_count = xla_bridge.device_count()
device_count = jax.device_count()
@jit
def multi_step(state, count):
@ -1708,7 +1709,7 @@ class PythonPmapTest(jtu.JaxTestCase):
for axis in range(len(shape))
)
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
if xla_bridge.device_count() < shape[axis]:
if jax.device_count() < shape[axis]:
raise SkipTest(f"test requires at least {shape[axis]} devices")
if (jtu.device_under_test() == 'cpu' and
np.issubdtype(dtype, np.floating) and
@ -1733,7 +1734,7 @@ class PythonPmapTest(jtu.JaxTestCase):
@partial(self.pmap, axis_name='i')
def func(_):
return jax.lax.psum(dtype(0), axis_name='i')
unused_arg = jnp.arange(xla_bridge.device_count())
unused_arg = jnp.arange(jax.device_count())
out_dtype = func(unused_arg).dtype
self.assertEqual(out_dtype, dtype)
@ -1746,7 +1747,7 @@ class PythonPmapTest(jtu.JaxTestCase):
y = lax.cond(True, jax.pmap(identity), jax.pmap(identity), x)
return y
cond_of_pmap(jnp.zeros((xla_bridge.device_count(), 2)))
cond_of_pmap(jnp.zeros((jax.device_count(), 2)))
def test_static_argnum_on_method(self):
@ -1763,7 +1764,7 @@ class CppPmapTest(PythonPmapTest):
@property
def pmap(self):
if jax.lib._xla_extension_version >= 38:
if jax._src.lib._xla_extension_version >= 38:
return src_api._cpp_pmap
else:
return src_api._python_pmap
@ -1787,7 +1788,7 @@ class VmapOfPmapTest(jtu.JaxTestCase):
)))
def testVmapOfPmap(self, shapes, vmap_in_axes, pmap_in_axes, vmap_out_axes, pmap_out_axes):
vmapped_size = 3
pmapped_size = xla_bridge.device_count()
pmapped_size = jax.device_count()
rng = jtu.rand_default(self.rng())
@ -1824,7 +1825,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return x + collective(x.dot(y), ('i', 'j'))
return f
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
y = f(jax.pmap, jax.pmap)(x, x)
@ -1842,7 +1843,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return x + jax.lax.ppermute(x.dot(y), 'i', perm)
return f
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
self.assertAllClose(f(jax.pmap)(x, x), f(jax.vmap)(x, x))
@ -1938,7 +1939,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
@ignore_slow_all_to_all_warning()
def testAllToAllMultipleAxesVsVmap(self, axes, split_axis, concat_axis):
raise SkipTest("multi-axis all_to_all broken after #4835") # TODO(mattjj,apaszke)
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
def f(x):
@ -1957,7 +1958,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return jax.lax.all_gather(x, 'i')
return f
if xla_bridge.device_count() < 4:
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x))
@ -1967,19 +1968,19 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testAllDevices(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
devices=xla_bridge.devices())
shape = (xla_bridge.device_count(), 4)
devices=jax.devices())
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = f(x)
self.assertAllClose(ans, expected)
def testOneDevice(self):
if xla_bridge.device_count() == 1:
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
d0 = xla_bridge.devices()[0]
d1 = xla_bridge.devices()[1]
d0 = jax.devices()[0]
d1 = jax.devices()[1]
f = lambda x: jnp.dot(x, x.T)
f0 = pmap(f, devices=[d0])
f1 = pmap(f, devices=[d1])
@ -1992,18 +1993,18 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNoDevicesError(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[])
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
with self.assertRaisesRegex(
ValueError, "'devices' argument to pmap must be non-empty, or None."):
f(x)
def testBadAxisSizeError(self):
if xla_bridge.device_count() == 1:
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
devices=xla_bridge.devices())
devices=jax.devices())
with self.assertRaisesRegex(
ValueError, r"Leading axis size of input to pmapped function must "
r"equal the number of local devices passed to pmap. Got axis_size=1, "
@ -2014,7 +2015,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
ValueError, r"Leading axis size of input to pmapped function must "
r"equal the number of local devices passed to pmap. Got axis_size=\d, "
r"num_local_devices=\d."):
f(jnp.ones(xla_bridge.device_count() + 1))
f(jnp.ones(jax.device_count() + 1))
def testBadAxisSizeErrorNested(self):
f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')),
@ -2028,18 +2029,18 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
f(jnp.ones((1, 4)))
def testNestedPmaps(self):
if xla_bridge.device_count() % 2 != 0:
if jax.device_count() % 2 != 0:
raise SkipTest
# Devices specified in outer pmap are OK
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
@partial(pmap, axis_name='i', devices=jax.devices())
def foo(x):
@partial(pmap, axis_name='j')
def bar(y):
return lax.psum(y, 'j')
return bar(x)
x = jnp.ones((xla_bridge.device_count() // 2, 2))
x = jnp.ones((jax.device_count() // 2, 2))
ans = foo(x)
expected = x * 2
self.assertAllClose(ans, expected)
@ -2048,7 +2049,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
# Devices specified in inner pmap not OK
@partial(pmap, axis_name='i')
def foo(x):
@partial(pmap, axis_name='j', devices=xla_bridge.devices())
@partial(pmap, axis_name='j', devices=jax.devices())
def bar(y):
return lax.psum(y, 'j')
return bar(x)
@ -2056,17 +2057,17 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
ValueError,
"Nested pmap with explicit devices argument."):
foo(jnp.ones((xla_bridge.device_count(), 1)))
foo(jnp.ones((jax.device_count(), 1)))
def testJitInPmap(self):
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
@partial(pmap, axis_name='i', devices=jax.devices())
def foo(x):
@jit
def bar(y):
return y + 1
return lax.psum(bar(x), 'i')
ndevices = xla_bridge.device_count()
ndevices = jax.device_count()
ans = foo(jnp.ones((ndevices, 1)))
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices * 2
self.assertAllClose(ans, expected)
@ -2075,22 +2076,22 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testPmapInJit(self):
@jit
def foo(x):
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
@partial(pmap, axis_name='i', devices=jax.devices())
def bar(y):
return lax.psum(y, 'i')
return bar(x)
ndevices = xla_bridge.device_count()
ndevices = jax.device_count()
ans = foo(jnp.ones((ndevices, 1)))
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices
self.assertAllClose(ans, expected)
def testGradBasic(self):
@partial(pmap, axis_name='i', devices=xla_bridge.devices())
@partial(pmap, axis_name='i', devices=jax.devices())
def f(x):
return jnp.sin(x)
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
@ -2101,7 +2102,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@partial(pmap, axis_name='i', static_broadcasted_argnums=1)
def f(x, y):
return jnp.sin(x + y())
shape = (xla_bridge.device_count(), 4)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = lambda: 3.
@ -2113,9 +2114,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@partial(pmap, in_axes=(1, 2))
def f(x, y):
return jnp.sin(x + y)
xshape = (2, xla_bridge.device_count(), 4)
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
yshape = (2, 4, xla_bridge.device_count())
yshape = (2, 4, jax.device_count())
y = np.arange(prod(yshape)).reshape(yshape)
self.assertAllClose(f(x, y),
@ -2126,9 +2127,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
return jnp.sin(x + y + z)
fp = pmap(f, in_axes=(1, 2, None))
fv = vmap(f, in_axes=(1, 2, None))
xshape = (5, xla_bridge.device_count(), 7)
xshape = (5, jax.device_count(), 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, 7, xla_bridge.device_count())
yshape = (5, 7, jax.device_count())
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
zshape = (5, 7)
z = np.arange(prod(zshape), dtype=np.float32).reshape(zshape)
@ -2145,7 +2146,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@partial(pmap, in_axes=(1, None), out_axes=(2, None))
def f(x, y):
return jnp.sin(x + y), y * 2
xshape = (2, xla_bridge.device_count(), 4)
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
yshape = (2, 4)
y = np.arange(prod(yshape)).reshape(yshape)
@ -2158,7 +2159,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
@partial(pmap, out_axes={'a': 0})
def f(x):
return {'a': x}
device_count = xla_bridge.device_count()
device_count = jax.device_count()
x = jnp.arange(device_count)
tree_util.tree_multimap(self.assertAllClose, f(x), {'a': x})
@ -2172,7 +2173,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def f(x, y, z):
return jnp.sin(x + y) * z
pmapped_size = xla_bridge.device_count()
pmapped_size = jax.device_count()
mapped_shapes = [(3, 4), (3, 1), (1, 4)]
arg_shapes = map(partial(add_bdim, pmapped_size), in_axes, mapped_shapes)
rng = jtu.rand_default(self.rng())
@ -2192,7 +2193,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
xshape = (5, 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, xla_bridge.device_count(), 7)
yshape = (5, jax.device_count(), 7)
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
self.assertAllClose(jax.grad(mk_case(pmap))(x, y),
jax.grad(mk_case(vmap))(x, y))
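The pmap test changes above all apply one pattern: device queries that previously went through the now-private `xla_bridge` module use JAX's public top-level API instead (`jax.device_count()`, `jax.devices()`). A minimal sketch of that style follows; the array shape and the axis name `'i'` are illustrative only and not taken from any particular test:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Public API replaces jax.lib.xla_bridge.device_count() / .devices().
n = jax.device_count()
devices = jax.devices()

# pmap requires the leading axis to match the local device count.
x = jnp.arange(n * 4.0).reshape((n, 4))
out = jax.pmap(lambda v: v - lax.psum(v, 'i'), axis_name='i')(x)
```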
@ -25,8 +25,8 @@ from jax import dtypes
from jax.experimental import sparse
from jax.experimental.sparse.ops import _bcoo_nse
from jax import lax
from jax.lib import cusparse
from jax.lib import xla_bridge
from jax._src.lib import cusparse
from jax._src.lib import xla_bridge
from jax import jit
from jax import test_util as jtu
from jax import xla
@ -25,7 +25,7 @@ from jax import tree_util
from jax._src.tree_util import _process_pytree
from jax import flatten_util
import jax.numpy as jnp
import jax._src.lib
def _dummy_func(*args, **kwargs):
return
@ -206,7 +206,7 @@ class TreeTest(jtu.JaxTestCase):
def testTreedefTupleFromChildren(self):
# https://github.com/google/jax/issues/7377
# TODO(frostig): remove after the minimum jaxlib is is 0.1.70 or newer.
if jax.lib._xla_extension_version < 29:
if jax._src.lib._xla_extension_version < 29:
self.skipTest("fixed in future jaxlib")
tree = ((1, 2, (3, 4)), (5,))
leaves, treedef1 = tree_util.tree_flatten(tree)
@ -324,7 +324,7 @@ class TreeTest(jtu.JaxTestCase):
self.assertRegex(str(treedef), correct_string)
def testTreeDefWithEmptyDictStringRepresentation(self):
if jax.lib._xla_extension_version < 35:
if jax._src.lib._xla_extension_version < 35:
self.skipTest("fixed in future jaxlib")
self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
@ -17,8 +17,8 @@ import warnings
from absl.testing import absltest
from jax import test_util as jtu
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
mock = absltest.mock
@ -79,8 +79,8 @@ class XlaBridgeTest(absltest.TestCase):
msg = str(w[-1].message)
self.assertIn("Did you run your code on all TPU hosts?", msg)
with mock.patch(
"jax.lib.xla_client.make_tpu_client", side_effect=_mock_tpu_client):
with mock.patch("jax._src.lib.xla_client.make_tpu_client",
side_effect=_mock_tpu_client):
xb.tpu_client_timer_callback(0.01)
@ -41,7 +41,7 @@ from jax.core import NamedShape, JaxprTypeError
from jax.experimental import maps
from jax.experimental.maps import Mesh, mesh, xmap, serial_loop, SerialLoop
from jax.errors import JAXTypeError
from jax.lib import xla_bridge
from jax._src.lib import xla_bridge
from jax._src.util import curry, unzip2, split_list, prod
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lax.parallel import pgather
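In the remaining test files, imports of `xla_bridge`, `xla_client`, and `cusparse` move from `jax.lib` to `jax._src.lib`, and jaxlib capability checks read `_xla_extension_version` from the same private module. A hedged sketch of that version-guard pattern follows; the threshold `29` is simply the value used in the tree_util test above, and `_xla_extension_version` is a private attribute, so the code defaults to `0` if it is absent:

```python
import jax._src.lib  # private module; jax.lib keeps shim re-exports for older callers

# Branch on the installed jaxlib's extension version, as the tests above do.
extension_version = getattr(jax._src.lib, "_xla_extension_version", 0)
if extension_version < 29:
    print("older jaxlib detected: skip or fall back for behaviour fixed in newer releases")
```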