Merge pull request #26279 from MichaelHudgins:tsan-resultstore

PiperOrigin-RevId: 723918760
Michael Hudgins 2025-02-06 14:55:57 +00:00
commit 2e808f2836
30 changed files with 917 additions and 129 deletions

@ -198,4 +198,8 @@ jobs:
--test_output=errors \
--local_test_jobs=32 \
--test_timeout=600 \
--config=resultstore \
--spawn_strategy=local \
--remote_cache=remotebuildexecution.googleapis.com \
--remote_instance_name=projects/tensorflow-testing/instances/default_instance \
//tests:cpu_tests

@ -264,10 +264,52 @@
},
{
"cell_type": "markdown",
-"metadata": {},
+"metadata": {
+"id": "UEObolTqw4pp"
+},
"source": [
"The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
"\n",
"The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n",
"\n",
"To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aKNeOHTJnqmS",
"outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pinned_host\n",
"device\n"
]
}
],
"source": [
"s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n",
"s_dev = s_host.with_memory_kind('device')\n",
"arr_host = jax.device_put(arr, s_host)\n",
"arr_dev = jax.device_put(arr, s_dev)\n",
"print(arr_host.sharding.memory_kind)\n",
"print(arr_dev.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDHYnVqHwaST"
},
"source": [
"## 1. Automatic parallelism via `jit`\n", "## 1. Automatic parallelism via `jit`\n",
"\n", "\n",
"Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n", "Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n",
@ -354,10 +396,98 @@
},
{
"cell_type": "markdown",
-"metadata": {},
+"metadata": {
+"id": "Q4N5mrr9i_ki"
+},
"source": [
"The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n",
"\n",
"### 1.1 Sharding transformation between memory types\n",
"\n",
"The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n",
"\n",
"#### Example 1: Pinned host to device memory\n",
"\n",
"In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PXu3MhafyRHo",
"outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
"device\n"
]
}
],
"source": [
"f = jax.jit(lambda x: x, out_shardings=s_dev)\n",
"out_dev = f(arr_host)\n",
"print(out_dev)\n",
"print(out_dev.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LuYFqpcBySiX"
},
"source": [
"#### Example 2: Device to pinned_host memory\n",
"\n",
"In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qLsgNlKfybRw",
"outputId": "a16448b9-7e39-408f-b200-505f65ad4464"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
"pinned_host\n"
]
}
],
"source": [
"g = jax.jit(lambda x: x, out_shardings=s_host)\n",
"out_host = g(arr_dev)\n",
"print(out_host)\n",
"print(out_host.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7BGD31-owaSU"
},
"source": [
"## 2. Semi-automated sharding with constraints\n", "## 2. Semi-automated sharding with constraints\n",
"\n", "\n",
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",

@ -90,8 +90,31 @@ print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```
+++ {"id": "UEObolTqw4pp"}
The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.
The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set it to `pinned_host` if you prefer to place the array in host memory.
To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding.
```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: aKNeOHTJnqmS
outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2
---
s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
s_dev = s_host.with_memory_kind('device')
arr_host = jax.device_put(arr, s_host)
arr_dev = jax.device_put(arr, s_dev)
print(arr_host.sharding.memory_kind)
print(arr_dev.sharding.memory_kind)
```
+++ {"id": "jDHYnVqHwaST"}
## 1. Automatic parallelism via `jit`
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
@ -129,8 +152,52 @@ jax.debug.visualize_array_sharding(result)
print(result)
```
+++ {"id": "Q4N5mrr9i_ki"}
The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.
### 1.1 Sharding transformation between memory types
The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.
#### Example 1: Pinned host to device memory
In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory.
```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: PXu3MhafyRHo
outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b
---
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_dev = f(arr_host)
print(out_dev)
print(out_dev.sharding.memory_kind)
```
+++ {"id": "LuYFqpcBySiX"}
#### Example 2: Device to pinned_host memory
In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory.
```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: qLsgNlKfybRw
outputId: a16448b9-7e39-408f-b200-505f65ad4464
---
g = jax.jit(lambda x: x, out_shardings=s_host)
out_host = g(arr_dev)
print(out_host)
print(out_host.sharding.memory_kind)
```
+++ {"id": "7BGD31-owaSU"}
## 2. Semi-automated sharding with constraints
If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constrains how intermediate values and outputs are distributed.

@ -499,6 +499,7 @@ pytype_strict_library(
":traceback_util", ":traceback_util",
":typing", ":typing",
":util", ":util",
"//jax/_src/lib",
] + py_deps("ml_dtypes") + py_deps("numpy"), ] + py_deps("ml_dtypes") + py_deps("numpy"),
) )

@ -99,6 +99,7 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip zip, unsafe_zip = safe_zip, zip
@api_boundary
def _nan_check_posthook(fun, args, kwargs, output): def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking.""" """Hook function called by the C++ jit/pmap to perform NaN checking."""
buffers = [] buffers = []
@ -108,12 +109,18 @@ def _nan_check_posthook(fun, args, kwargs, output):
try: try:
dispatch.check_special(pjit.pjit_p.name, buffers) dispatch.check_special(pjit.pjit_p.name, buffers)
except FloatingPointError: except dispatch.InternalFloatingPointError as e:
# compiled_fun can only raise in this case
assert config.debug_nans.value or config.debug_infs.value assert config.debug_nans.value or config.debug_infs.value
print("Invalid nan value encountered in the output of a C++-jit/pmap " if hasattr(fun, '_fun'):
"function. Calling the de-optimized version.") f = fun._fun
fun._cache_miss(*args, **kwargs)[0] # probably won't return if getattr(f, '_apply_primitive', False):
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
# compiled_fun can only raise in this case
dispatch.maybe_recursive_nan_check(e, f, args, kwargs)
raise AssertionError("Unreachable") from e
else:
# TODO(emilyaf): Shouldn't need this fallback.
raise
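For context, an editorial sketch (not part of the diff) of the user-facing behavior this hook supports: with `jax_debug_nans` enabled, a NaN in the output of a C++-jit function is re-checked through the de-optimized path and reported against the offending primitive.

```python
import jax

with jax.debug_nans(True):
    try:
        jax.jit(lambda x: 0. / x)(0.)
    except FloatingPointError as err:
        print(err)  # e.g. "invalid value (nan) encountered in div"
```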
def _update_debug_special_global(_): def _update_debug_special_global(_):
if config._read("jax_debug_nans") or config._read("jax_debug_infs"): if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
@ -1574,11 +1581,14 @@ def _cpp_pmap(
execute: Callable | None = None execute: Callable | None = None
with core.take_current_trace() as trace: with core.take_current_trace() as trace:
if isinstance(trace, core.EvalTrace): try:
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params) if isinstance(trace, core.EvalTrace):
out = execute(*p.flat_args) execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
else: out = execute(*p.flat_args)
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params) else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
except dispatch.InternalFloatingPointError as e:
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.')
out_tree, out_flat = p.out_tree, out out_tree, out_flat = p.out_tree, out
out_pytree_def = out_tree() out_pytree_def = out_tree()
@ -1629,6 +1639,7 @@ def _cpp_pmap(
_pmap_cache_clears.add(cpp_mapped_f) _pmap_cache_clears.add(cpp_mapped_f)
pmap_f = wraps(fun)(cpp_mapped_f) pmap_f = wraps(fun)(cpp_mapped_f)
pmap_f._fun = fun
@api_boundary @api_boundary
def lower(*args, **kwargs): def lower(*args, **kwargs):
@ -1674,6 +1685,7 @@ def _cpp_pmap(
_pmap_cache_clears = weakref.WeakSet() # type: ignore _pmap_cache_clears = weakref.WeakSet() # type: ignore
@api_boundary
def jvp( def jvp(
fun: Callable, primals, tangents, has_aux: bool = False fun: Callable, primals, tangents, has_aux: bool = False
) -> tuple[Any, ...]: ) -> tuple[Any, ...]:
@ -1878,6 +1890,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
return apply_flat_fun_nokwargs(fun, io_tree, py_args) return apply_flat_fun_nokwargs(fun, io_tree, py_args)
@api_boundary
def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_):
if len(py_args_) != 1: if len(py_args_) != 1:
msg = (f"The function returned by `jax.vjp` applied to {name} was called " msg = (f"The function returned by `jax.vjp` applied to {name} was called "
@ -1937,6 +1950,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
has_aux: Literal[True], has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]: reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
... ...
@api_boundary
def vjp( def vjp(
fun: Callable, *primals, has_aux: bool = False, reduce_axes=() fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]: ) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
@ -2225,6 +2239,18 @@ def _infer_src_sharding(src, x) -> Sharding | None:
return None return None
@lru_cache(maxsize=2048)
def _check_string_compatible_sharding(s):
"""Checks if target devices are compatible with string arrays."""
if isinstance(s, xc.Device) and s.device_kind == "cpu":
return
if (isinstance(s, Sharding)
and s._internal_device_list[0].device_kind == "cpu"):
return
raise TypeError(
"String arrays can only be sharded to CPU devices. Received"
f" unsupported device or sharding: {s}")
# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use # TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use
# that to check if shardings are compatible with the input. # that to check if shardings are compatible with the input.
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
@ -2235,6 +2261,10 @@ def _check_sharding(aval, s):
"`jax.device_put` only accepts `None`, `jax.sharding.Sharding`," "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
" `jax.Device`, `Layout` or a pytree of these values. Received" " `jax.Device`, `Layout` or a pytree of these values. Received"
f" invalid value: {s}") f" invalid value: {s}")
if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype):
_check_string_compatible_sharding(s)
if isinstance(s, Sharding): if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken): if isinstance(aval, core.AbstractToken):
aval = core.get_token_aval() aval = core.get_token_aval()

@ -1472,11 +1472,14 @@ Value = Any
def valid_jaxtype(x) -> bool: def valid_jaxtype(x) -> bool:
try: try:
abstractify(x) aval = abstractify(x)
except TypeError: except TypeError:
return False return False
else: else:
return True if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype):
return False
else:
return True
def check_valid_jaxtype(x): def check_valid_jaxtype(x):
if not valid_jaxtype(x): if not valid_jaxtype(x):

@ -25,7 +25,7 @@ import itertools
import logging import logging
import threading import threading
import time import time
from typing import Any, NamedTuple from typing import Any, Callable, NamedTuple
import jax import jax
from jax._src import api from jax._src import api
@ -100,6 +100,7 @@ def xla_primitive_callable(prim: core.Primitive, **params):
return prim.bind(*args, **params) return prim.bind(*args, **params)
prim_fun.__name__ = prim.name prim_fun.__name__ = prim.name
prim_fun.__qualname__ = prim.name prim_fun.__qualname__ = prim.name
prim_fun._apply_primitive = True
return api.jit(prim_fun) return api.jit(prim_fun)
@ -321,15 +322,52 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
if dtypes.issubdtype(dtype, np.inexact): if dtypes.issubdtype(dtype, np.inexact):
if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))): if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
raise FloatingPointError(f"invalid value (nan) encountered in {name}") raise InternalFloatingPointError(name, "nan")
if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))): if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
raise FloatingPointError(f"invalid value (inf) encountered in {name}") raise InternalFloatingPointError(name, "inf")
class CopySemantics(enum.Enum): class CopySemantics(enum.Enum):
ALIAS = enum.auto() ALIAS = enum.auto()
COPY = enum.auto() COPY = enum.auto()
DONATE = enum.auto() DONATE = enum.auto()
class InternalFloatingPointError(Exception):
name: str
ty: str
def __init__(self, name: str, ty: str):
self.name = name
self.ty = ty
def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs,
) -> None: # always raises an exception
print("Invalid nan value encountered in the output of a jax.jit "
"function. Calling the de-optimized version.")
try:
_ = fun(*args, **kwargs)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
_raise_no_nan_in_deoptimized(e)
def _raise_no_nan_in_deoptimized(e) -> None:
msg = (f"{str(e)}. Because "
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not "
"produce invalid values during its execution. This behavior can "
"result from `jit` optimizations causing the invalid value to be "
"produced. It may also arise from having nan/inf literals as "
"inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
"\n\n"
"It may be possible to avoid the invalid value by removing the "
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
"https://github.com/jax-ml/jax.")
raise FloatingPointError(msg) from None
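An editorial sketch of when this branch is reached (compare the updated DebugInfsTest below): the compiled output contains a NaN, but the de-optimized rerun does not reproduce it because the NaN is a literal, so the longer explanatory message above is raised.

```python
import jax
import jax.numpy as jnp

with jax.debug_nans(True):
    try:
        jax.jit(lambda x: (jnp.nan, x + 2))(3)
    except FloatingPointError as err:
        # The message explains that the de-optimized function did not
        # produce invalid values and points at nan/inf literals.
        print("de-optimized function did not" in str(err))  # True
```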
def _identity_fn(x): def _identity_fn(x):
return x return x

@ -33,6 +33,7 @@ import ml_dtypes
import numpy as np import numpy as np
from jax._src import config from jax._src import config
from jax._src.lib import xla_extension_version
from jax._src.typing import Array, DType, DTypeLike from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC from jax._src.util import set_module, StrictABC
@ -486,18 +487,37 @@ _complex_types: list[JAXType] = [
np.dtype('complex64'), np.dtype('complex64'),
np.dtype('complex128'), np.dtype('complex128'),
] ]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}
# We add the StringDType only to `_jax_dtype_set` but not to `_jax_types` and
# `_dtype_kinds`. This is because, in spite of a very similar sounding name,
# `_jax_types` is only meant for the promotion related logic, and StringDType
# does not participate in promotions at the moment. Similarly, `_dtype_kinds` is
# only meant for the `jnp.isdtype` and we want to be conservative and not allow
# StringDType to be used in there.
_string_types: list[JAXType] = []
if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 311:
_string_types: list[JAXType] = [np.dtypes.StringDType()] # type: ignore
_jax_dtype_set = {
float0,
*_bool_types,
*_int_types,
*_float_types,
*_complex_types,
*_string_types,
}
_jax_types = (_bool_types + _int_types + _float_types + _complex_types)
_dtype_kinds: dict[str, set] = { _dtype_kinds: dict[str, set] = {
'bool': {*_bool_types}, 'bool': {*_bool_types},
'signed integer': {*_signed_types}, 'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types}, 'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types}, 'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types}, 'real floating': {*_float_types},
'complex floating': {*_complex_types}, 'complex floating': {*_complex_types},
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types}, 'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
} }
@ -870,8 +890,14 @@ def check_user_dtype_supported(dtype, fun_name=None):
uint2, uint2,
uint4 uint4
] ]
if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: if (
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" np_dtype.kind not in 'biufcT'
and not is_custom_dtype
and not dtype == float0
):
msg = (
f'JAX only supports number, bool, and string dtypes, got dtype {dtype}'
)
msg += f" in {fun_name}" if fun_name else "" msg += f" in {fun_name}" if fun_name else ""
raise TypeError(msg) raise TypeError(msg)
if dtype is not None and np_dtype != canonicalize_dtype(np_dtype): if dtype is not None and np_dtype != canonicalize_dtype(np_dtype):
@ -949,3 +975,7 @@ def short_dtype_name(dtype) -> str:
else: else:
return (dtype.name.replace('float', 'f').replace('uint' , 'u') return (dtype.name.replace('float', 'f').replace('uint' , 'u')
.replace('int' , 'i').replace('complex', 'c')) .replace('int' , 'i').replace('complex', 'c'))
def is_string_dtype(dtype: DTypeLike | None) -> bool:
return dtype in _string_types

@ -22,6 +22,7 @@ from functools import partial
from typing import Any from typing import Any
from jax._src import config from jax._src import config
from jax._src import dispatch
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_unflatten, from jax.tree_util import (tree_flatten, tree_unflatten,
@ -360,8 +361,15 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
cts_out = get_primitive_transpose(eqn.primitive)( cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in, cts_in_avals) params, call_jaxpr, invals, cts_in, cts_in_avals)
else: else:
cts_out = get_primitive_transpose(eqn.primitive)( try:
cts_in, *invals, **eqn.params) cts_out = get_primitive_transpose(eqn.primitive)(
cts_in, *invals, **eqn.params)
except (FloatingPointError, ZeroDivisionError) as e:
msg = "When differentiating the code at the top of the callstack:"
if msg not in e.args[0]:
e.args = e.args[0] + f'\n{msg}',
e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}',
raise e from None
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
# FIXME: Some invars correspond to primals! # FIXME: Some invars correspond to primals!
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out) map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
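A hedged sketch of the resulting user experience (compare testGradPmap/testGradShardMap below): with `jax_debug_nans` on, a NaN produced while transposing an equation is now annotated with the differentiation site.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.log(x ** 2)

with jax.debug_nans(True):
    _, f_vjp = jax.vjp(f, jnp.zeros([1]))
    try:
        f_vjp(jnp.ones([1]))
    except FloatingPointError as err:
        # e.g. "invalid value (nan) encountered in mul\nWhen differentiating
        # the code at the top of the callstack: ..."
        print(err)
```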
@ -1003,7 +1011,20 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
if update_params: if update_params:
new_params = update_params(new_params, map(is_undefined_primal, args), new_params = update_params(new_params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct]) [type(x) is not Zero for x in ct])
out_flat = primitive.bind(fun, *all_args, **new_params)
try:
out_flat = primitive.bind(fun, *all_args, **new_params)
except dispatch.InternalFloatingPointError as e:
print("Invalid nan value encountered in the backward pass of a jax.jit "
"function. Calling the de-optimized backward pass.")
try:
_ = backward_pass(call_jaxpr, None, {}, args, ct)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
# If control reaches this line, we got a NaN on the output of `compiled`
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
dispatch._raise_no_nan_in_deoptimized(e)
arg_cts = tree_unflatten(out_tree(), out_flat) arg_cts = tree_unflatten(out_tree(), out_flat)
# The freevars are being fanned out (not mapped). During transpose the # The freevars are being fanned out (not mapped). During transpose the

@ -57,6 +57,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import (PrecisionLike,_array_copy, from jax._src.lax.lax import (PrecisionLike,_array_copy,
_sort_le_comparator, _sort_lt_comparator) _sort_le_comparator, _sort_lt_comparator)
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.numpy import reductions from jax._src.numpy import reductions
from jax._src.numpy import ufuncs from jax._src.numpy import ufuncs
from jax._src.numpy import util from jax._src.numpy import util
@ -5474,6 +5475,39 @@ def _supports_buffer_protocol(obj):
return True return True
def _make_string_array(
object: np.ndarray,
dtype: DTypeLike | None = None,
ndmin: int = 0,
device: xc.Device | Sharding | None = None,
) -> Array:
if xla_extension_version < 311:
raise TypeError(
"String arrays are not supported in JAX before XLA extension version"
" 311."
)
if not isinstance(object, np.ndarray):
raise TypeError(
"Currently, string arrays can only be made from NumPy"
f" arrays. Got: {type(object)}."
)
if dtype is not None and (
dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype)
):
raise TypeError(
f"Cannot make an array with dtype {dtype} from an object with dtype"
f" {object.dtype}."
)
if ndmin > object.ndim:
raise TypeError(
f"ndmin {ndmin} cannot be greater than object's ndims"
f" {object.ndim} for string arrays."
)
# Just do a device_put since XLA does not support string as a data type.
return jax.device_put(x=object, device=device)
@export @export
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0, order: str | None = "K", ndmin: int = 0,
@ -5567,6 +5601,15 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted. # Keep the output uncommitted.
return jax.device_put(object) return jax.device_put(object)
# String arrays need separate handling because XLA does not support string
# as a data type.
if dtypes.is_string_dtype(dtype) or (
hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype)
):
return _make_string_array(
object=object, dtype=dtype, ndmin=ndmin, device=device
)
# For Python scalar literals, call coerce_to_array to catch any overflow # For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this # errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or # triggering for traced values. We do this here because it matters whether or
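An editorial sketch of the user-facing behavior this enables (it mirrors the new string_array_test and assumes NumPy's `StringDType` plus a recent jaxlib; CPU only):

```python
import jax
import jax.numpy as jnp
import numpy as np

data = np.array(["abcd", "efgh"], dtype=np.dtypes.StringDType())
arr = jnp.asarray(data, device=jax.devices("cpu")[0])  # stored via device_put
back = jax.device_get(arr)
print(back, back.dtype)  # ['abcd' 'efgh'] StringDType()
```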

@ -222,6 +222,10 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
f"Argument '{name}' of shape {aval.str_short()} of type" f"Argument '{name}' of shape {aval.str_short()} of type"
f' {type(arg)} is not a valid JAX type.') from e f' {type(arg)} is not a valid JAX type.') from e
raise AssertionError("Unreachable") from e raise AssertionError("Unreachable") from e
except dispatch.InternalFloatingPointError as e:
if getattr(fun, '_apply_primitive', False):
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None
dispatch.maybe_recursive_nan_check(e, fun, args, kwargs)
if p.attrs_tracked: if p.attrs_tracked:
num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked) num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked)
@ -1700,33 +1704,7 @@ def _pjit_call_impl_python(
("out_layouts", out_layouts), ("out_layouts", out_layouts),
("abstract args", map(core.abstractify, args)), ("abstract args", map(core.abstractify, args)),
("fingerprint", fingerprint)) ("fingerprint", fingerprint))
try: return compiled.unsafe_call(*args), compiled, pgle_profiler
return compiled.unsafe_call(*args), compiled, pgle_profiler
except FloatingPointError as e:
assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case
if len(jaxpr.eqns) > 1:
_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return
# If control reaches this line, we got a NaN on the output of `compiled`
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
msg = (f"{str(e)}. Because "
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not "
"produce invalid values during its execution. This behavior can "
"result from `jit` optimizations causing the invalid value to be "
"produced. It may also arise from having nan/inf constants as "
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
"\n\n"
"It may be possible to avoid the invalid value by removing the "
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
"https://github.com/jax-ml/jax.")
raise FloatingPointError(msg)
@weakref_lru_cache @weakref_lru_cache
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
@ -2404,19 +2382,31 @@ def _pjit_transpose(cts_in, *primals_in,
transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts
transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts
nz_cts_out = pjit_p.bind( try:
*primals_and_nz_cts_in, nz_cts_out = pjit_p.bind(
jaxpr=transpose_jaxpr, *primals_and_nz_cts_in,
in_shardings=transpose_in_shardings, jaxpr=transpose_jaxpr,
out_shardings=transpose_out_shardings, in_shardings=transpose_in_shardings,
in_layouts=transpose_in_layouts, out_shardings=transpose_out_shardings,
out_layouts=transpose_out_layouts, in_layouts=transpose_in_layouts,
resource_env=resource_env, out_layouts=transpose_out_layouts,
donated_invars=(False,) * len(primals_and_nz_cts_in), resource_env=resource_env,
name=name, donated_invars=(False,) * len(primals_and_nz_cts_in),
keep_unused=keep_unused, name=name,
inline=inline, keep_unused=keep_unused,
compiler_options_kvs=compiler_options_kvs) inline=inline,
compiler_options_kvs=compiler_options_kvs)
except dispatch.InternalFloatingPointError as e:
print("Invalid nan value encountered in the backward pass of a jax.jit "
"function. Calling the de-optimized backward pass.")
try:
_ = ad.closed_backward_pass(jaxpr, None, primals_in, cts_in)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None # great
else:
# If control reaches this line, we got a NaN on the output of `compiled`
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
dispatch._raise_no_nan_in_deoptimized(e)
if attrs_tracked: if attrs_tracked:
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])

@ -30,7 +30,6 @@ package(
py_library( py_library(
name = "jax2tf", name = "jax2tf",
srcs = ["__init__.py"], srcs = ["__init__.py"],
srcs_version = "PY3",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":jax2tf_internal"], deps = [":jax2tf_internal"],
) )
@ -42,7 +41,6 @@ py_library(
"impl_no_xla.py", "impl_no_xla.py",
"jax2tf.py", "jax2tf.py",
], ],
srcs_version = "PY3",
# TODO: b/255503696: enable pytype # TODO: b/255503696: enable pytype
tags = ["pytype_unchecked_annotations"], tags = ["pytype_unchecked_annotations"],
visibility = jax_visibility("jax2tf_internal"), visibility = jax_visibility("jax2tf_internal"),

@ -24,7 +24,6 @@ package(
py_library( py_library(
name = "back_compat_testdata", name = "back_compat_testdata",
srcs = glob(["*.py"]), srcs = glob(["*.py"]),
srcs_version = "PY3",
deps = [ deps = [
"//third_party/py/numpy", "//third_party/py/numpy",
"//third_party/py/typing_extensions", "//third_party/py/typing_extensions",

@ -28,7 +28,6 @@ package(
py_library( py_library(
name = "flax_models", name = "flax_models",
srcs = glob(["*.py"]), srcs = glob(["*.py"]),
srcs_version = "PY3",
deps = [ deps = [
"//jax", "//jax",
"//third_party/py/flax:core", "//third_party/py/flax:core",

@ -874,6 +874,15 @@ def _match(mesh, check_rep, pspec, x):
def _rem_singleton(x): return jnp.squeeze(x, axis=0) def _rem_singleton(x): return jnp.squeeze(x, axis=0)
def _add_singleton(x): return jnp.expand_dims(x, axis=0) def _add_singleton(x): return jnp.expand_dims(x, axis=0)
def _maybe_check_special(outs):
if not config.debug_nans.value and not config.debug_infs.value: return
bufs = [s.data for leaf in tree_leaves(outs)
for s in getattr(leaf, 'addressable_shards', [])]
try:
dispatch.check_special('shard_map', bufs)
except dispatch.InternalFloatingPointError as e:
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None
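A hedged sketch of the behavior this helper adds (it mirrors the new testShardMap below): under `jax_debug_nans`, a NaN produced inside an eager `shard_map` now surfaces as a `FloatingPointError` instead of passing through silently.

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((1,), ('x',))
f = shard_map(lambda x: 0. / x, mesh=mesh, in_specs=P('x'), out_specs=P('x'))

with jax.debug_nans(True):
    f(jnp.array([1.]))  # fine
    f(jnp.array([0.]))  # FloatingPointError: Invalid value (nan) encountered
                        # in sharded computation.
```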
class ShardMapTrace(core.Trace): class ShardMapTrace(core.Trace):
__slots__ = ("mesh", "check", "context_mesh") __slots__ = ("mesh", "check", "context_mesh")
@ -902,9 +911,10 @@ class ShardMapTrace(core.Trace):
out_vals = eager_rule(self.mesh, *in_vals, **params) out_vals = eager_rule(self.mesh, *in_vals, **params)
else: else:
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
with (core.eval_context(), jax.disable_jit(False), with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False),
set_abstract_mesh(self.context_mesh)): jax.debug_infs(False), set_abstract_mesh(self.context_mesh)):
out_vals = jax.jit(f)(*in_vals) out_vals = jax.jit(f)(*in_vals)
_maybe_check_special(out_vals)
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
if prim.multiple_results: if prim.multiple_results:
@ -1700,10 +1710,21 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
def new_out_names_thunk(): def new_out_names_thunk():
return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz) return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz)
out_flat = shard_map_p.bind( try:
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), out_flat = shard_map_p.bind(
out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
auto=auto) out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite,
auto=auto)
except (FloatingPointError, ZeroDivisionError) as e:
print("Invalid nan value encountered in the backward pass of a shard_map "
"function. Calling the de-optimized backward pass.")
try:
_ = fun_trans.call_wrapped(out_cts, args)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
dispatch._raise_no_nan_in_deoptimized(e)
return tree_unflatten(out_tree(), out_flat) return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose ad.primitive_transposes[shard_map_p] = _shard_map_transpose

@ -27,6 +27,7 @@ from jax.interpreters import mlir
from jax.interpreters import xla from jax.interpreters import xla
from jax._src import core from jax._src import core
from jax._src import ffi
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.lib import gpu_solver from jax._src.lib import gpu_solver
@ -533,10 +534,14 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder):
def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder): def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder):
data_aval, _, _, _, = ctx.avals_in # TODO(danfm): remove after JAX 0.5.1 release.
return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices, if hasattr(gpu_solver, "cuda_csrlsvqr"):
indptr, b, tol, reorder) data_aval, _, _, _, = ctx.avals_in
return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices,
indptr, b, tol, reorder)
return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")(
ctx, data, indices, indptr, b, tol=np.float64(tol),
reorder=np.int32(reorder))
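For orientation, an editorial sketch (not part of the diff): the public entry point that reaches this lowering is `jax.experimental.sparse.linalg.spsolve`; assuming its `(data, indices, indptr, b, tol, reorder)` signature, it can be fed the raw CSR buffers.

```python
import jax.numpy as jnp
from jax.experimental import sparse
from jax.experimental.sparse.linalg import spsolve

A = jnp.array([[2., 0.],
               [0., 4.]])
b = jnp.array([2., 8.])
A_csr = sparse.CSR.fromdense(A)
x = spsolve(A_csr.data, A_csr.indices, A_csr.indptr, b, tol=1e-6)
print(x)  # approximately [1., 2.]
```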
def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder): def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder):
del tol, reorder del tol, reorder

@ -230,6 +230,7 @@ cc_library(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusolver",
@ -251,6 +252,7 @@ cc_library(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:ffi", "@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudart",

@ -50,6 +50,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA",
GeqrfFfi); GeqrfFfi);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA",
CsrlsvqrFfi);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
OrgqrFfi); OrgqrFfi);

@ -486,6 +486,8 @@ nb::dict Registrations() {
#ifdef JAX_GPU_CUDA #ifdef JAX_GPU_CUDA
dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi); dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi);
dict[JAX_GPU_PREFIX "solver_csrlsvqr_ffi"] =
EncapsulateFfiHandler(CsrlsvqrFfi);
#endif // JAX_GPU_CUDA #endif // JAX_GPU_CUDA
return dict; return dict;

@ -20,6 +20,10 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h" #include "jaxlib/gpu/vendor.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"
#endif
namespace jax { namespace jax {
namespace JAX_GPU_NAMESPACE { namespace JAX_GPU_NAMESPACE {
namespace solver { namespace solver {
@ -315,6 +319,23 @@ JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched);
JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched); JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched);
#undef JAX_GPU_DEFINE_GESVDJ_BATCHED #undef JAX_GPU_DEFINE_GESVDJ_BATCHED
#define JAX_GPU_DEFINE_CSRLSVQR(Type, Scalar, Name) \
template <> \
absl::Status Csrlsvqr<Type>( \
cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \
const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \
const Type *b, double tol, int reorder, Type *x, int *singularity) { \
return JAX_AS_STATUS(Name(handle, n, nnz, matdesc, csrValA, csrRowPtrA, \
csrColIndA, b, static_cast<Scalar>(tol), \
reorder, x, singularity)); \
}
JAX_GPU_DEFINE_CSRLSVQR(float, float, cusolverSpScsrlsvqr);
JAX_GPU_DEFINE_CSRLSVQR(double, double, cusolverSpDcsrlsvqr);
JAX_GPU_DEFINE_CSRLSVQR(gpuComplex, float, cusolverSpCcsrlsvqr);
JAX_GPU_DEFINE_CSRLSVQR(gpuDoubleComplex, double, cusolverSpZcsrlsvqr);
#undef JAX_GPU_DEFINE_CSRLSVQR
#endif // JAX_GPU_CUDA #endif // JAX_GPU_CUDA
// Symmetric tridiagonal reduction: sytrd // Symmetric tridiagonal reduction: sytrd

@ -23,6 +23,10 @@ limitations under the License.
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "jaxlib/gpu/vendor.h" #include "jaxlib/gpu/vendor.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"
#endif
namespace jax { namespace jax {
namespace JAX_GPU_NAMESPACE { namespace JAX_GPU_NAMESPACE {
namespace solver { namespace solver {
@ -206,6 +210,13 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBatchedBufferSize);
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched); JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched);
#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS #undef JAX_GPU_SOLVER_GesvdjBatched_ARGS
#define JAX_GPU_SOLVER_Csrlsvqr_ARGS(Type, ...) \
cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \
const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \
const Type *b, double tol, int reorder, Type *x, int *singularity
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Csrlsvqr);
#undef JAX_GPU_SOLVER_Csrlsvqr_ARGS
#endif // JAX_GPU_CUDA #endif // JAX_GPU_CUDA
// Symmetric tridiagonal reduction: sytrd // Symmetric tridiagonal reduction: sytrd

@ -41,6 +41,10 @@ limitations under the License.
#include "jaxlib/gpu/vendor.h" #include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h" #include "xla/ffi/api/ffi.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"
#endif
#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \ #define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
@ -1013,6 +1017,82 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch,
.Ret<ffi::Buffer<ffi::S32>>() // info .Ret<ffi::Buffer<ffi::S32>>() // info
); );
// csrlsvqr: Linear system solve via Sparse QR
template <typename T>
ffi::Error CsrlsvqrImpl(int64_t n, int64_t nnz, double tol, int reorder,
gpuStream_t stream, ffi::AnyBuffer csrValA,
ffi::Buffer<ffi::S32> csrColIndA,
ffi::Buffer<ffi::S32> csrRowPtrA, ffi::AnyBuffer b,
ffi::Result<ffi::AnyBuffer> x) {
FFI_ASSIGN_OR_RETURN(auto handle, SpSolverHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(auto int_n, MaybeCastNoOverflow<int>(n));
FFI_ASSIGN_OR_RETURN(auto int_nnz, MaybeCastNoOverflow<int>(nnz));
cusparseMatDescr_t matdesc = nullptr;
JAX_FFI_RETURN_IF_GPU_ERROR(cusparseCreateMatDescr(&matdesc));
JAX_FFI_RETURN_IF_GPU_ERROR(
cusparseSetMatType(matdesc, CUSPARSE_MATRIX_TYPE_GENERAL));
JAX_FFI_RETURN_IF_GPU_ERROR(
cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO));
auto* csrValA_data = static_cast<T*>(csrValA.untyped_data());
auto* csrColIndA_data = csrColIndA.typed_data();
auto* csrRowPtrA_data = csrRowPtrA.typed_data();
auto* b_data = static_cast<T*>(b.untyped_data());
auto* x_data = static_cast<T*>(x->untyped_data());
int singularity = -1;
auto result = solver::Csrlsvqr<T>(
handle.get(), int_n, int_nnz, matdesc, csrValA_data, csrRowPtrA_data,
csrColIndA_data, b_data, tol, reorder, x_data, &singularity);
cusparseDestroyMatDescr(matdesc);
FFI_RETURN_IF_ERROR_STATUS(result);
if (singularity >= 0) {
return ffi::Error(ffi::ErrorCode::kInternal,
"Singular matrix in linear solve.");
}
return ffi::Error::Success();
}
ffi::Error CsrlsvqrDispatch(gpuStream_t stream, int reorder, double tol,
ffi::AnyBuffer csrValA,
ffi::Buffer<ffi::S32> csrColIndA,
ffi::Buffer<ffi::S32> csrRowPtrA, ffi::AnyBuffer b,
ffi::Result<ffi::AnyBuffer> x) {
auto dataType = csrValA.element_type();
if (dataType != b.element_type() || dataType != x->element_type()) {
return ffi::Error::InvalidArgument(
"The inputs and outputs to csrlsvqr must have the same element type");
}
int64_t n = b.element_count();
int64_t nnz = csrValA.element_count();
FFI_RETURN_IF_ERROR(
CheckShape(csrColIndA.dimensions(), nnz, "csrColIndA", "csrlsvqr"));
FFI_RETURN_IF_ERROR(
CheckShape(csrRowPtrA.dimensions(), n + 1, "csrColPtrA", "csrlsvqr"));
FFI_RETURN_IF_ERROR(CheckShape(x->dimensions(), n, "x", "csrlsvqr"));
SOLVER_DISPATCH_IMPL(CsrlsvqrImpl, n, nnz, tol, reorder, stream, csrValA,
csrColIndA, csrRowPtrA, b, x);
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in csrlsvqr", absl::FormatStreamed(dataType)));
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(CsrlsvqrFfi, CsrlsvqrDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Attr<int>("reorder") // reorder
.Attr<double>("tol") // tol
.Arg<ffi::AnyBuffer>() // csrValA
.Arg<ffi::Buffer<ffi::S32>>() // csrColIndA
.Arg<ffi::Buffer<ffi::S32>>() // csrRowPtrA
.Arg<ffi::AnyBuffer>() // b
.Ret<ffi::AnyBuffer>() // x
);
#endif // JAX_GPU_CUDA #endif // JAX_GPU_CUDA
// Symmetric tridiagonal reduction: sytrd // Symmetric tridiagonal reduction: sytrd

@ -40,6 +40,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(SytrdFfi);
#ifdef JAX_GPU_CUDA #ifdef JAX_GPU_CUDA
XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrlsvqrFfi);
#endif // JAX_GPU_CUDA #endif // JAX_GPU_CUDA
} // namespace JAX_GPU_NAMESPACE } // namespace JAX_GPU_NAMESPACE

@ -12,17 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import importlib import importlib
import jaxlib.mlir.ir as ir
import numpy as np
from jaxlib import xla_client from jaxlib import xla_client
from .hlo_helpers import custom_call
try: try:
from .cuda import _blas as _cublas # pytype: disable=import-error from .cuda import _blas as _cublas # pytype: disable=import-error
except ImportError: except ImportError:
@ -129,27 +122,3 @@ def has_magma():
if _hiphybrid: if _hiphybrid:
return _hiphybrid.has_magma() return _hiphybrid.has_magma()
return False return False
def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
indices, indptr, b, tol, reorder):
"""Sparse solver via QR decomposition. CUDA only."""
b_type = ir.RankedTensorType(b.type)
data_type = ir.RankedTensorType(data.type)
n = b_type.shape[0]
nnz = data_type.shape[0]
opaque = gpu_solver.build_csrlsvqr_descriptor(
np.dtype(dtype), n, nnz, reorder, tol
)
out = custom_call(
f"{platform}solver_csrlsvqr", # call_target_name
result_types=[b.type],
operands=[data, indptr, indices, b],
backend_config=opaque, # backend_config
operand_layouts=[(0,), (0,), (0,), (0,)], # operand_layouts
result_layouts=[(0,)] # result_layouts
).results
return out
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)

@ -1607,6 +1607,11 @@ jax_py_test(
], ],
) )
jax_multiplatform_test(
name = "string_array_test",
srcs = ["string_array_test.py"],
)
jax_multiplatform_test( jax_multiplatform_test(
name = "cudnn_fusion_test", name = "cudnn_fusion_test",
srcs = ["cudnn_fusion_test.py"], srcs = ["cudnn_fusion_test.py"],
@ -1642,6 +1647,7 @@ exports_files(
"shard_map_test.py", "shard_map_test.py",
"transfer_guard_test.py", "transfer_guard_test.py",
"layout_test.py", "layout_test.py",
"string_array_test.py",
], ],
visibility = jax_test_file_visibility, visibility = jax_test_file_visibility,
) )

@ -24,6 +24,8 @@ from jax._src import api
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax import numpy as jnp from jax import numpy as jnp
from jax.experimental import pjit from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P
jax.config.parse_flags_with_absl() jax.config.parse_flags_with_absl()
@ -75,7 +77,6 @@ class DebugNaNsTest(jtu.JaxTestCase):
@jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION) @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
def testCallDeoptimized(self, jit): def testCallDeoptimized(self, jit):
raise SkipTest("re-enable once we handle contexts properly") # TODO(dougalm)
@jit @jit
def f(x): def f(x):
return jax.lax.cond( return jax.lax.cond(
@ -89,6 +90,25 @@ class DebugNaNsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(FloatingPointError, msg): with self.assertRaisesRegex(FloatingPointError, msg):
f(1) f(1)
def testShardMap(self):
mesh = jax.make_mesh((1,), ('x',))
f = shard_map(lambda x: 0. / x, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
# For the Cpp pmap, the first execution always goes through Python.
f(jnp.array([1.]))
with self.assertRaisesRegex(
FloatingPointError,
r"Invalid value \(nan\) encountered in sharded computation"):
ans = f(jnp.array([0.]))
ans.block_until_ready()
if jax.device_count() >= 2:
with self.assertRaisesRegex(
FloatingPointError,
r"Invalid value \(nan\) encountered in sharded computation"):
ans = f(jnp.array([1., 0.]))
ans.block_until_ready()
def testPmap(self): def testPmap(self):
pmap_funcs = [api._cpp_pmap] pmap_funcs = [api._cpp_pmap]
@ -99,17 +119,47 @@ class DebugNaNsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
FloatingPointError,
-r"invalid value \(nan\) encountered in parallel computation"):
+r"invalid value \(nan\) encountered in div"):
ans = f(jnp.array([0.]))
ans.block_until_ready()
if jax.device_count() >= 2:
with self.assertRaisesRegex(
FloatingPointError,
-r"invalid value \(nan\) encountered in parallel computation"):
+r"Invalid value \(nan\) encountered in parallel computation"):
ans = f(jnp.array([1., 0.]))
ans.block_until_ready()
def testGradPmap(self):
@jax.jit
def f(x):
y = x**2
return jnp.log(y)
_, f_vjp = jax.vjp(jax.pmap(f), jnp.zeros([1]))
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
ans, = f_vjp(jnp.ones([1]))
ans.block_until_ready()
def testGradShardMap(self):
@jax.jit
def f(x):
y = x**2
return jnp.log(y)
mesh = jax.make_mesh((1,), ('x',))
shmap_f = shard_map(f, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
_, f_vjp = jax.vjp(shmap_f, jnp.zeros([1]))
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
ans, = f_vjp(jnp.ones([1]))
ans.block_until_ready()
def testPmapNoNaN(self): def testPmapNoNaN(self):
ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.])) ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
ans.block_until_ready() ans.block_until_ready()
@ -163,17 +213,23 @@ class DebugNaNsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
FloatingPointError,
-r"invalid value \(nan\) encountered in jit\(true_divide\)"):
+r"invalid value \(nan\) encountered in div"):
f(inp, inp)
+# TODO(yashkatariya): Fix this and make true_divide appear in the name again.
+# Instead of `f` showing up in the error, the name should be of the
+# primitive (true_divide) in this case.
with self.assertRaisesRegex(
FloatingPointError,
-r"invalid value \(nan\) encountered in jit\(f\)"):
+r"invalid value \(nan\) encountered in div"):
jax.jit(f)(inp, inp)
def testDebugNansInput(self):
@jax.jit
def f(x):
return x * 3.
with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*input"):
f(np.nan)
@jtu.with_config(jax_debug_infs=True) @jtu.with_config(jax_debug_infs=True)
class DebugInfsTest(jtu.JaxTestCase): class DebugInfsTest(jtu.JaxTestCase):
@ -233,7 +289,7 @@ class DebugInfsTest(jtu.JaxTestCase):
y = x + 2 # avoid trivial dispatch path by adding some eqn
return jnp.nan, y
-with self.assertRaisesRegex(FloatingPointError, "de-optimized"):
+with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*literal"):
with jax.debug_nans(True):
f(3)

@ -335,6 +335,47 @@ class FilteredTracebackTest(jtu.JaxTestCase):
('bwd_err', 'g = err(g)'), ('bwd_err', 'g = err(g)'),
('err', 'assert False')], filter_mode=filter_mode) ('err', 'assert False')], filter_mode=filter_mode)
def test_jvp(self, filter_mode):
def err(_):
assert False
return ()
def f():
p = (1.,)
t = (0.,)
return jax.jvp(err, p, t)
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return jax.jvp(err, p, t)'),
('err', 'assert False')], filter_mode=filter_mode)
def test_vjp(self, filter_mode):
def err(_):
assert False
return ()
def f():
x = 1.
return jax.vjp(err, x)[0]
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return jax.vjp(err, x)[0]'),
('err', 'assert False')], filter_mode=filter_mode)
def test_debug_nans(self, filter_mode):
@jax.jit
def f(x):
return 0. / x
f(2.)
def g():
return f(0.)
with jax.debug_nans(True):
check_filtered_stack_trace(self, ZeroDivisionError, g, [
('g', 'return f(0.)'),
('f', 'return 0. / x')], filter_mode=filter_mode)
def test_cause_chain(self, filter_mode): def test_cause_chain(self, filter_mode):
@jit @jit
def inner(x): def inner(x):

@ -3758,8 +3758,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertIsNot(x, y) self.assertIsNot(x, y)
def testArrayUnsupportedDtypeError(self):
-with self.assertRaisesRegex(TypeError,
-"JAX only supports number and bool dtypes.*"):
+with self.assertRaisesRegex(
+TypeError, 'JAX only supports number, bool, and string dtypes.*'
+):
jnp.array(3, [('a','<i4'),('b','<i4')])
def testArrayFromInteger(self): def testArrayFromInteger(self):

tests/string_array_test.py Normal file
@ -0,0 +1,217 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import numpy as np
config.parse_flags_with_absl()
jtu.request_cpu_devices(2)
class StringArrayTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_extension_version < 311:
self.skipTest(
"Skipping this test because the current XLA extension version:"
f" {xla_extension_version} is older than 309, the oldest version with"
" string array support."
)
if not hasattr(np.dtypes, "StringDType"):
self.skipTest(
"Skipping this test because the numpy.dtype.StringDType is not"
" available."
)
def make_test_string_array(self, device=None):
"""Makes and returns a simple 2x1 string array on the first CPU device."""
if device is None:
cpu_devices = jax.devices("cpu")
if len(cpu_devices) < 1:
self.skipTest(
"Skipping this test because no CPU devices are available."
)
device = cpu_devices[0]
numpy_string_array = np.array(
["abcd", "efgh"], dtype=np.dtypes.StringDType() # type: ignore
)
jax_string_array = jax.device_put(numpy_string_array, device=device)
jax_string_array.block_until_ready()
return jax_string_array
@parameterized.named_parameters(
("asarray", True),
("device_put", False),
)
@jtu.run_on_devices("cpu")
def test_single_device_array(self, asarray):
cpu_devices = jax.devices("cpu")
if len(cpu_devices) < 1:
self.skipTest("Skipping this test because no CPU devices are available.")
numpy_string_array = np.array(
["abcdefghijklmnopqrstuvwxyz", "cba"], dtype=np.dtypes.StringDType() # type: ignore
)
if asarray:
jax_string_array = jnp.asarray(numpy_string_array, device=cpu_devices[0])
else:
jax_string_array = jax.device_put(
numpy_string_array, device=cpu_devices[0]
)
jax_string_array.block_until_ready()
array_read_back = jax.device_get(jax_string_array)
self.assertEqual(array_read_back.dtype, np.dtypes.StringDType()) # type: ignore
np.testing.assert_array_equal(array_read_back, numpy_string_array)
@parameterized.named_parameters(
("asarray", True),
("device_put", False),
)
@jtu.run_on_devices("cpu")
def test_multi_device_array(self, asarray):
cpu_devices = jax.devices("cpu")
if len(cpu_devices) < 2:
self.skipTest(
f"Skipping this test because only {len(cpu_devices)} host"
" devices are available. Need at least 2."
)
numpy_string_array = np.array(
[["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType() # type: ignore
)
mesh = jax.sharding.Mesh(np.array(cpu_devices).reshape((2, 1)), ("x", "y"))
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("x", "y")
)
if asarray:
jax_string_array = jnp.asarray(numpy_string_array, device=sharding)
else:
jax_string_array = jax.device_put(numpy_string_array, device=sharding)
jax_string_array.block_until_ready()
array_read_back = jax.device_get(jax_string_array)
self.assertEqual(array_read_back.dtype, np.dtypes.StringDType()) # type: ignore
np.testing.assert_array_equal(array_read_back, numpy_string_array)
@jtu.run_on_devices("cpu")
def test_dtype_conversions(self):
cpu_devices = jax.devices("cpu")
if len(cpu_devices) < 1:
self.skipTest("Skipping this test because no CPU devices are available.")
# Explicitly specifying the dtype should work with StringDType numpy arrays.
numpy_string_array = np.array(
["abcd", "efgh"], dtype=np.dtypes.StringDType() # type: ignore
)
jax_string_array = jnp.asarray(
numpy_string_array,
device=cpu_devices[0],
dtype=np.dtypes.StringDType(),
) # type: ignore
jax_string_array.block_until_ready()
# Cannot make a non-StringDType array from a StringDType numpy array.
with self.assertRaisesRegex(
TypeError,
r"Cannot make an array with dtype bfloat16 from an object with dtype"
r" StringDType.*",
):
jnp.asarray(
numpy_string_array,
device=cpu_devices[0],
dtype=jnp.bfloat16,
)
# Cannot make a StringDType array from a numeric numpy array.
numpy_int_array = np.arange(2, dtype=np.int32)
with self.assertRaisesRegex(
TypeError,
r"Cannot make an array with dtype StringDType.*from an object with"
r" dtype int32.",
):
jnp.asarray(
numpy_int_array,
device=cpu_devices[0],
dtype=np.dtypes.StringDType(), # type: ignore
)
@parameterized.named_parameters(
("asarray", True),
("device_put", False),
)
@jtu.skip_on_devices("cpu")
def test_string_array_cannot_be_non_cpu_devices(self, asarray):
devices = jax.devices()
if len(devices) < 1:
self.skipTest("Skipping this test because no devices are available.")
numpy_string_array = np.array(
["abcdefghijklmnopqrstuvwxyz", "cba"], dtype=np.dtypes.StringDType() # type: ignore
)
with self.assertRaisesRegex(
TypeError, "String arrays can only be sharded to CPU devices"
):
if asarray:
jax_string_array = jnp.asarray(numpy_string_array, device=devices[0])
else:
jax_string_array = jax.device_put(numpy_string_array, device=devices[0])
jax_string_array.block_until_ready()
def test_jit_fails_with_string_arrays(self):
f = jax.jit(lambda x: x)
input_array = self.make_test_string_array()
self.assertRaisesRegex(
TypeError,
r"Argument.*is not a valid JAX type.",
lambda: f(input_array),
)
def test_grad_fails_with_string_arrays(self):
f = jax.grad(lambda x: x)
input_array = self.make_test_string_array()
self.assertRaisesRegex(
TypeError,
r"Argument.*is not a valid JAX type.",
lambda: f(input_array),
)
def test_vmap_without_jit_works_with_string_arrays(self):
f = jax.vmap(lambda x: x)
input_array = self.make_test_string_array()
output_array = f(input_array)
self.assertEqual(output_array.dtype, input_array.dtype)
np.testing.assert_array_equal(output_array, input_array)
def test_vmap_with_jit_fails_with_string_arrays(self):
f = jax.vmap(lambda x: x + jnp.arange(2))
input_array = self.make_test_string_array()
self.assertRaisesRegex(
ValueError,
r".*StringDType.*is not a valid dtype",
lambda: f(input_array),
)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
-XLA_COMMIT = "46f8cf03902c0af58468e2258f9438788e7f4c97"
-XLA_SHA256 = "0c391b0a8433d26bfc93e5bee775f7eb629b811a42222ce2b4c7449044a5bc0d"
+XLA_COMMIT = "85eccd2ed9f2afd956ab17afd31480a042f07f92"
+XLA_SHA256 = "ed853428d3f92aeb3a0cabd564f2373309b4784cd6f90db74ccc2d2ae735984f"
def repo():
tf_http_archive(