Merge pull request #26279 from MichaelHudgins:tsan-resultstore

PiperOrigin-RevId: 723918760
This commit is contained in:
Michael Hudgins 2025-02-06 14:55:57 +00:00
commit 2e808f2836
30 changed files with 917 additions and 129 deletions

View File

@ -198,4 +198,8 @@ jobs:
--test_output=errors \
--local_test_jobs=32 \
--test_timeout=600 \
--config=resultstore \
--spawn_strategy=local \
--remote_cache=remotebuildexecution.googleapis.com \
--remote_instance_name=projects/tensorflow-testing/instances/default_instance \
//tests:cpu_tests

View File

@ -264,10 +264,52 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "UEObolTqw4pp"
},
"source": [
"The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
"\n",
"The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n",
"\n",
"To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aKNeOHTJnqmS",
"outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pinned_host\n",
"device\n"
]
}
],
"source": [
"s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n",
"s_dev = s_host.with_memory_kind('device')\n",
"arr_host = jax.device_put(arr, s_host)\n",
"arr_dev = jax.device_put(arr, s_dev)\n",
"print(arr_host.sharding.memory_kind)\n",
"print(arr_dev.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDHYnVqHwaST"
},
"source": [
"## 1. Automatic parallelism via `jit`\n",
"\n",
"Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n",
@ -354,10 +396,98 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "Q4N5mrr9i_ki"
},
"source": [
"The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n",
"\n",
"### 1.1 Sharding transformation between memory types\n",
"\n",
"The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n",
"\n",
"#### Example 1: Pinned host to device memory\n",
"\n",
"In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PXu3MhafyRHo",
"outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
"device\n"
]
}
],
"source": [
"f = jax.jit(lambda x: x, out_shardings=s_dev)\n",
"out_dev = f(arr_host)\n",
"print(out_dev)\n",
"print(out_dev.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LuYFqpcBySiX"
},
"source": [
"#### Example 2: Device to pinned_host memory\n",
"\n",
"In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qLsgNlKfybRw",
"outputId": "a16448b9-7e39-408f-b200-505f65ad4464"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
"pinned_host\n"
]
}
],
"source": [
"g = jax.jit(lambda x: x, out_shardings=s_host)\n",
"out_host = g(arr_dev)\n",
"print(out_host)\n",
"print(out_host.sharding.memory_kind)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7BGD31-owaSU"
},
"source": [
"## 2. Semi-automated sharding with constraints\n",
"\n",
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",

View File

@ -90,8 +90,31 @@ print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```
+++ {"id": "UEObolTqw4pp"}
The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.
The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.
To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding.
```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: aKNeOHTJnqmS
outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2
---
s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
s_dev = s_host.with_memory_kind('device')
arr_host = jax.device_put(arr, s_host)
arr_dev = jax.device_put(arr, s_dev)
print(arr_host.sharding.memory_kind)
print(arr_dev.sharding.memory_kind)
```
+++ {"id": "jDHYnVqHwaST"}
## 1. Automatic parallelism via `jit`
Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
@ -129,8 +152,52 @@ jax.debug.visualize_array_sharding(result)
print(result)
```
+++ {"id": "Q4N5mrr9i_ki"}
The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.
### 1.1 Sharding transformation between memory types
The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.
#### Example 1: Pinned host to device memory
In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory.
```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: PXu3MhafyRHo
outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b
---
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_dev = f(arr_host)
print(out_dev)
print(out_dev.sharding.memory_kind)
```
+++ {"id": "LuYFqpcBySiX"}
#### Example 2: Device to pinned_host memory
In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory.
```{code-cell}
---
colab:
base_uri: https://localhost:8080/
id: qLsgNlKfybRw
outputId: a16448b9-7e39-408f-b200-505f65ad4464
---
g = jax.jit(lambda x: x, out_shardings=s_host)
out_host = g(arr_dev)
print(out_host)
print(out_host.sharding.memory_kind)
```
+++ {"id": "7BGD31-owaSU"}
## 2. Semi-automated sharding with constraints
If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.

View File

@ -499,6 +499,7 @@ pytype_strict_library(
":traceback_util",
":typing",
":util",
"//jax/_src/lib",
] + py_deps("ml_dtypes") + py_deps("numpy"),
)

View File

@ -99,6 +99,7 @@ map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@api_boundary
def _nan_check_posthook(fun, args, kwargs, output):
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
buffers = []
@ -108,12 +109,18 @@ def _nan_check_posthook(fun, args, kwargs, output):
try:
dispatch.check_special(pjit.pjit_p.name, buffers)
except FloatingPointError:
# compiled_fun can only raise in this case
except dispatch.InternalFloatingPointError as e:
assert config.debug_nans.value or config.debug_infs.value
print("Invalid nan value encountered in the output of a C++-jit/pmap "
"function. Calling the de-optimized version.")
fun._cache_miss(*args, **kwargs)[0] # probably won't return
if hasattr(fun, '_fun'):
f = fun._fun
if getattr(f, '_apply_primitive', False):
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
# compiled_fun can only raise in this case
dispatch.maybe_recursive_nan_check(e, f, args, kwargs)
raise AssertionError("Unreachable") from e
else:
# TODO(emilyaf): Shouldn't need this fallback.
raise
def _update_debug_special_global(_):
if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
@ -1574,11 +1581,14 @@ def _cpp_pmap(
execute: Callable | None = None
with core.take_current_trace() as trace:
if isinstance(trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*p.flat_args)
else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
try:
if isinstance(trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*p.flat_args)
else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
except dispatch.InternalFloatingPointError as e:
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.')
out_tree, out_flat = p.out_tree, out
out_pytree_def = out_tree()
@ -1629,6 +1639,7 @@ def _cpp_pmap(
_pmap_cache_clears.add(cpp_mapped_f)
pmap_f = wraps(fun)(cpp_mapped_f)
pmap_f._fun = fun
@api_boundary
def lower(*args, **kwargs):
@ -1674,6 +1685,7 @@ def _cpp_pmap(
_pmap_cache_clears = weakref.WeakSet() # type: ignore
@api_boundary
def jvp(
fun: Callable, primals, tangents, has_aux: bool = False
) -> tuple[Any, ...]:
@ -1878,6 +1890,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
return apply_flat_fun_nokwargs(fun, io_tree, py_args)
@api_boundary
def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_):
if len(py_args_) != 1:
msg = (f"The function returned by `jax.vjp` applied to {name} was called "
@ -1937,6 +1950,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
has_aux: Literal[True],
reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
...
@api_boundary
def vjp(
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
@ -2225,6 +2239,18 @@ def _infer_src_sharding(src, x) -> Sharding | None:
return None
@lru_cache(maxsize=2048)
def _check_string_compatible_sharding(s):
"""Checks if target devices are compatible with string arrays."""
if isinstance(s, xc.Device) and s.device_kind == "cpu":
return
if (isinstance(s, Sharding)
and s._internal_device_list[0].device_kind == "cpu"):
return
raise TypeError(
"String arrays can only be sharded to CPU devices. Received"
f" unsupported device or sharding: {s}")
# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use
# that to check if shardings are compatible with the input.
@lru_cache(maxsize=2048)
@ -2235,6 +2261,10 @@ def _check_sharding(aval, s):
"`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
" `jax.Device`, `Layout` or a pytree of these values. Received"
f" invalid value: {s}")
if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype):
_check_string_compatible_sharding(s)
if isinstance(s, Sharding):
if isinstance(aval, core.AbstractToken):
aval = core.get_token_aval()

View File

@ -1472,11 +1472,14 @@ Value = Any
def valid_jaxtype(x) -> bool:
try:
abstractify(x)
aval = abstractify(x)
except TypeError:
return False
else:
return True
if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype):
return False
else:
return True
def check_valid_jaxtype(x):
if not valid_jaxtype(x):

View File

@ -25,7 +25,7 @@ import itertools
import logging
import threading
import time
from typing import Any, NamedTuple
from typing import Any, Callable, NamedTuple
import jax
from jax._src import api
@ -100,6 +100,7 @@ def xla_primitive_callable(prim: core.Primitive, **params):
return prim.bind(*args, **params)
prim_fun.__name__ = prim.name
prim_fun.__qualname__ = prim.name
prim_fun._apply_primitive = True
return api.jit(prim_fun)
@ -321,15 +322,52 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
if dtypes.issubdtype(dtype, np.inexact):
if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
raise InternalFloatingPointError(name, "nan")
if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
raise InternalFloatingPointError(name, "inf")
class CopySemantics(enum.Enum):
ALIAS = enum.auto()
COPY = enum.auto()
DONATE = enum.auto()
class InternalFloatingPointError(Exception):
name: str
ty: str
def __init__(self, name: str, ty: str):
self.name = name
self.ty = ty
def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs,
) -> None: # always raises an exception
print("Invalid nan value encountered in the output of a jax.jit "
"function. Calling the de-optimized version.")
try:
_ = fun(*args, **kwargs)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
_raise_no_nan_in_deoptimized(e)
def _raise_no_nan_in_deoptimized(e) -> None:
msg = (f"{str(e)}. Because "
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not "
"produce invalid values during its execution. This behavior can "
"result from `jit` optimizations causing the invalid value to be "
"produced. It may also arise from having nan/inf literals as "
"inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
"\n\n"
"It may be possible to avoid the invalid value by removing the "
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
"https://github.com/jax-ml/jax.")
raise FloatingPointError(msg) from None
def _identity_fn(x):
return x

View File

@ -33,6 +33,7 @@ import ml_dtypes
import numpy as np
from jax._src import config
from jax._src.lib import xla_extension_version
from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC
@ -486,18 +487,37 @@ _complex_types: list[JAXType] = [
np.dtype('complex64'),
np.dtype('complex128'),
]
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}
# We add the StringDType only to `_jax_dtype_set` but not to `_jax_types` and
# `_dtype_kinds`. This is because, in spite of a very similar sounding name,
# `_jax_types` is only meant for the promotion related logic, and StringDType
# does not participate in promotions at the moment. Similarly, `_dtype_kinds` is
# only meant for the `jnp.isdtype` and we want to be conservative and not allow
# StringDType to be used in there.
_string_types: list[JAXType] = []
if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 311:
_string_types: list[JAXType] = [np.dtypes.StringDType()] # type: ignore
_jax_dtype_set = {
float0,
*_bool_types,
*_int_types,
*_float_types,
*_complex_types,
*_string_types,
}
_jax_types = (_bool_types + _int_types + _float_types + _complex_types)
_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
}
@ -870,8 +890,14 @@ def check_user_dtype_supported(dtype, fun_name=None):
uint2,
uint4
]
if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0:
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
if (
np_dtype.kind not in 'biufcT'
and not is_custom_dtype
and not dtype == float0
):
msg = (
f'JAX only supports number, bool, and string dtypes, got dtype {dtype}'
)
msg += f" in {fun_name}" if fun_name else ""
raise TypeError(msg)
if dtype is not None and np_dtype != canonicalize_dtype(np_dtype):
@ -949,3 +975,7 @@ def short_dtype_name(dtype) -> str:
else:
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
.replace('int' , 'i').replace('complex', 'c'))
def is_string_dtype(dtype: DTypeLike | None) -> bool:
return dtype in _string_types

View File

@ -22,6 +22,7 @@ from functools import partial
from typing import Any
from jax._src import config
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_unflatten,
@ -360,8 +361,15 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
cts_out = get_primitive_transpose(eqn.primitive)(
params, call_jaxpr, invals, cts_in, cts_in_avals)
else:
cts_out = get_primitive_transpose(eqn.primitive)(
cts_in, *invals, **eqn.params)
try:
cts_out = get_primitive_transpose(eqn.primitive)(
cts_in, *invals, **eqn.params)
except (FloatingPointError, ZeroDivisionError) as e:
msg = "When differentiating the code at the top of the callstack:"
if msg not in e.args[0]:
e.args = e.args[0] + f'\n{msg}',
e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}',
raise e from None
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
# FIXME: Some invars correspond to primals!
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
@ -1003,7 +1011,20 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
if update_params:
new_params = update_params(new_params, map(is_undefined_primal, args),
[type(x) is not Zero for x in ct])
out_flat = primitive.bind(fun, *all_args, **new_params)
try:
out_flat = primitive.bind(fun, *all_args, **new_params)
except dispatch.InternalFloatingPointError as e:
print("Invalid nan value encountered in the backward pass of a jax.jit "
"function. Calling the de-optimized backward pass.")
try:
_ = backward_pass(call_jaxpr, None, {}, args, ct)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
# If control reaches this line, we got a NaN on the output of `compiled`
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
dispatch._raise_no_nan_in_deoptimized(e)
arg_cts = tree_unflatten(out_tree(), out_flat)
# The freevars are being fanned out (not mapped). During transpose the

View File

@ -57,6 +57,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax.lax import (PrecisionLike,_array_copy,
_sort_le_comparator, _sort_lt_comparator)
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.numpy import util
@ -5474,6 +5475,39 @@ def _supports_buffer_protocol(obj):
return True
def _make_string_array(
object: np.ndarray,
dtype: DTypeLike | None = None,
ndmin: int = 0,
device: xc.Device | Sharding | None = None,
) -> Array:
if xla_extension_version < 311:
raise TypeError(
"String arrays are not supported in JAX before XLA extension version"
" 311."
)
if not isinstance(object, np.ndarray):
raise TypeError(
"Currently, string arrays can only be made from NumPy"
f" arrays. Got: {type(object)}."
)
if dtype is not None and (
dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype)
):
raise TypeError(
f"Cannot make an array with dtype {dtype} from an object with dtype"
f" {object.dtype}."
)
if ndmin > object.ndim:
raise TypeError(
f"ndmin {ndmin} cannot be greater than object's ndims"
f" {object.ndim} for string arrays."
)
# Just do a device_put since XLA does not support string as a data type.
return jax.device_put(x=object, device=device)
@export
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
order: str | None = "K", ndmin: int = 0,
@ -5567,6 +5601,15 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted.
return jax.device_put(object)
# String arrays need separate handling because XLA does not support string
# as a data type.
if dtypes.is_string_dtype(dtype) or (
hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype)
):
return _make_string_array(
object=object, dtype=dtype, ndmin=ndmin, device=device
)
# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or

View File

@ -222,6 +222,10 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
f"Argument '{name}' of shape {aval.str_short()} of type"
f' {type(arg)} is not a valid JAX type.') from e
raise AssertionError("Unreachable") from e
except dispatch.InternalFloatingPointError as e:
if getattr(fun, '_apply_primitive', False):
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None
dispatch.maybe_recursive_nan_check(e, fun, args, kwargs)
if p.attrs_tracked:
num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked)
@ -1700,33 +1704,7 @@ def _pjit_call_impl_python(
("out_layouts", out_layouts),
("abstract args", map(core.abstractify, args)),
("fingerprint", fingerprint))
try:
return compiled.unsafe_call(*args), compiled, pgle_profiler
except FloatingPointError as e:
assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case
if len(jaxpr.eqns) > 1:
_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return
# If control reaches this line, we got a NaN on the output of `compiled`
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
msg = (f"{str(e)}. Because "
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
"precise error message. However, the de-optimized function did not "
"produce invalid values during its execution. This behavior can "
"result from `jit` optimizations causing the invalid value to be "
"produced. It may also arise from having nan/inf constants as "
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
"\n\n"
"It may be possible to avoid the invalid value by removing the "
"`jit` decorator, at the cost of losing optimizations. "
"\n\n"
"If you see this error, consider opening a bug report at "
"https://github.com/jax-ml/jax.")
raise FloatingPointError(msg)
return compiled.unsafe_call(*args), compiled, pgle_profiler
@weakref_lru_cache
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
@ -2404,19 +2382,31 @@ def _pjit_transpose(cts_in, *primals_in,
transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts
transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts
nz_cts_out = pjit_p.bind(
*primals_and_nz_cts_in,
jaxpr=transpose_jaxpr,
in_shardings=transpose_in_shardings,
out_shardings=transpose_out_shardings,
in_layouts=transpose_in_layouts,
out_layouts=transpose_out_layouts,
resource_env=resource_env,
donated_invars=(False,) * len(primals_and_nz_cts_in),
name=name,
keep_unused=keep_unused,
inline=inline,
compiler_options_kvs=compiler_options_kvs)
try:
nz_cts_out = pjit_p.bind(
*primals_and_nz_cts_in,
jaxpr=transpose_jaxpr,
in_shardings=transpose_in_shardings,
out_shardings=transpose_out_shardings,
in_layouts=transpose_in_layouts,
out_layouts=transpose_out_layouts,
resource_env=resource_env,
donated_invars=(False,) * len(primals_and_nz_cts_in),
name=name,
keep_unused=keep_unused,
inline=inline,
compiler_options_kvs=compiler_options_kvs)
except dispatch.InternalFloatingPointError as e:
print("Invalid nan value encountered in the backward pass of a jax.jit "
"function. Calling the de-optimized backward pass.")
try:
_ = ad.closed_backward_pass(jaxpr, None, primals_in, cts_in)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None # great
else:
# If control reaches this line, we got a NaN on the output of `compiled`
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
dispatch._raise_no_nan_in_deoptimized(e)
if attrs_tracked:
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])

View File

@ -30,7 +30,6 @@ package(
py_library(
name = "jax2tf",
srcs = ["__init__.py"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [":jax2tf_internal"],
)
@ -42,7 +41,6 @@ py_library(
"impl_no_xla.py",
"jax2tf.py",
],
srcs_version = "PY3",
# TODO: b/255503696: enable pytype
tags = ["pytype_unchecked_annotations"],
visibility = jax_visibility("jax2tf_internal"),

View File

@ -24,7 +24,6 @@ package(
py_library(
name = "back_compat_testdata",
srcs = glob(["*.py"]),
srcs_version = "PY3",
deps = [
"//third_party/py/numpy",
"//third_party/py/typing_extensions",

View File

@ -28,7 +28,6 @@ package(
py_library(
name = "flax_models",
srcs = glob(["*.py"]),
srcs_version = "PY3",
deps = [
"//jax",
"//third_party/py/flax:core",

View File

@ -874,6 +874,15 @@ def _match(mesh, check_rep, pspec, x):
def _rem_singleton(x): return jnp.squeeze(x, axis=0)
def _add_singleton(x): return jnp.expand_dims(x, axis=0)
def _maybe_check_special(outs):
if not config.debug_nans.value and not config.debug_infs.value: return
bufs = [s.data for leaf in tree_leaves(outs)
for s in getattr(leaf, 'addressable_shards', [])]
try:
dispatch.check_special('shard_map', bufs)
except dispatch.InternalFloatingPointError as e:
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None
class ShardMapTrace(core.Trace):
__slots__ = ("mesh", "check", "context_mesh")
@ -902,9 +911,10 @@ class ShardMapTrace(core.Trace):
out_vals = eager_rule(self.mesh, *in_vals, **params)
else:
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
with (core.eval_context(), jax.disable_jit(False),
set_abstract_mesh(self.context_mesh)):
with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False),
jax.debug_infs(False), set_abstract_mesh(self.context_mesh)):
out_vals = jax.jit(f)(*in_vals)
_maybe_check_special(out_vals)
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
if prim.multiple_results:
@ -1700,10 +1710,21 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
def new_out_names_thunk():
return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz)
out_flat = shard_map_p.bind(
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite,
auto=auto)
try:
out_flat = shard_map_p.bind(
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite,
auto=auto)
except (FloatingPointError, ZeroDivisionError) as e:
print("Invalid nan value encountered in the backward pass of a shard_map "
"function. Calling the de-optimized backward pass.")
try:
_ = fun_trans.call_wrapped(out_cts, args)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
dispatch._raise_no_nan_in_deoptimized(e)
return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose

View File

@ -27,6 +27,7 @@ from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src import core
from jax._src import ffi
from jax._src.interpreters import ad
from jax._src.lib import gpu_solver
@ -533,10 +534,14 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder):
def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder):
data_aval, _, _, _, = ctx.avals_in
return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices,
indptr, b, tol, reorder)
# TODO(danfm): remove after JAX 0.5.1 release.
if hasattr(gpu_solver, "cuda_csrlsvqr"):
data_aval, _, _, _, = ctx.avals_in
return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices,
indptr, b, tol, reorder)
return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")(
ctx, data, indices, indptr, b, tol=np.float64(tol),
reorder=np.int32(reorder))
def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder):
del tol, reorder

View File

@ -230,6 +230,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",
"@xla//xla/tsl/cuda:cusolver",
@ -251,6 +252,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cublas",
"@xla//xla/tsl/cuda:cudart",

View File

@ -50,6 +50,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA",
GeqrfFfi);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA",
CsrlsvqrFfi);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
OrgqrFfi);

View File

@ -486,6 +486,8 @@ nb::dict Registrations() {
#ifdef JAX_GPU_CUDA
dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi);
dict[JAX_GPU_PREFIX "solver_csrlsvqr_ffi"] =
EncapsulateFfiHandler(CsrlsvqrFfi);
#endif // JAX_GPU_CUDA
return dict;

View File

@ -20,6 +20,10 @@ limitations under the License.
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/vendor.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"
#endif
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace solver {
@ -315,6 +319,23 @@ JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched);
JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched);
#undef JAX_GPU_DEFINE_GESVDJ_BATCHED
#define JAX_GPU_DEFINE_CSRLSVQR(Type, Scalar, Name) \
template <> \
absl::Status Csrlsvqr<Type>( \
cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \
const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \
const Type *b, double tol, int reorder, Type *x, int *singularity) { \
return JAX_AS_STATUS(Name(handle, n, nnz, matdesc, csrValA, csrRowPtrA, \
csrColIndA, b, static_cast<Scalar>(tol), \
reorder, x, singularity)); \
}
JAX_GPU_DEFINE_CSRLSVQR(float, float, cusolverSpScsrlsvqr);
JAX_GPU_DEFINE_CSRLSVQR(double, double, cusolverSpDcsrlsvqr);
JAX_GPU_DEFINE_CSRLSVQR(gpuComplex, float, cusolverSpCcsrlsvqr);
JAX_GPU_DEFINE_CSRLSVQR(gpuDoubleComplex, double, cusolverSpZcsrlsvqr);
#undef JAX_GPU_DEFINE_CSRLSVQR
#endif // JAX_GPU_CUDA
// Symmetric tridiagonal reduction: sytrd

View File

@ -23,6 +23,10 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "jaxlib/gpu/vendor.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"
#endif
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace solver {
@ -206,6 +210,13 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBatchedBufferSize);
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched);
#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS
#define JAX_GPU_SOLVER_Csrlsvqr_ARGS(Type, ...) \
cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \
const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \
const Type *b, double tol, int reorder, Type *x, int *singularity
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Csrlsvqr);
#undef JAX_GPU_SOLVER_Csrlsvqr_ARGS
#endif // JAX_GPU_CUDA
// Symmetric tridiagonal reduction: sytrd

View File

@ -41,6 +41,10 @@ limitations under the License.
#include "jaxlib/gpu/vendor.h"
#include "xla/ffi/api/ffi.h"
#ifdef JAX_GPU_CUDA
#include "third_party/gpus/cuda/include/cusolverSp.h"
#endif
#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
@ -1013,6 +1017,82 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch,
.Ret<ffi::Buffer<ffi::S32>>() // info
);
// csrlsvqr: Linear system solve via Sparse QR
template <typename T>
ffi::Error CsrlsvqrImpl(int64_t n, int64_t nnz, double tol, int reorder,
gpuStream_t stream, ffi::AnyBuffer csrValA,
ffi::Buffer<ffi::S32> csrColIndA,
ffi::Buffer<ffi::S32> csrRowPtrA, ffi::AnyBuffer b,
ffi::Result<ffi::AnyBuffer> x) {
FFI_ASSIGN_OR_RETURN(auto handle, SpSolverHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(auto int_n, MaybeCastNoOverflow<int>(n));
FFI_ASSIGN_OR_RETURN(auto int_nnz, MaybeCastNoOverflow<int>(nnz));
cusparseMatDescr_t matdesc = nullptr;
JAX_FFI_RETURN_IF_GPU_ERROR(cusparseCreateMatDescr(&matdesc));
JAX_FFI_RETURN_IF_GPU_ERROR(
cusparseSetMatType(matdesc, CUSPARSE_MATRIX_TYPE_GENERAL));
JAX_FFI_RETURN_IF_GPU_ERROR(
cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO));
auto* csrValA_data = static_cast<T*>(csrValA.untyped_data());
auto* csrColIndA_data = csrColIndA.typed_data();
auto* csrRowPtrA_data = csrRowPtrA.typed_data();
auto* b_data = static_cast<T*>(b.untyped_data());
auto* x_data = static_cast<T*>(x->untyped_data());
int singularity = -1;
auto result = solver::Csrlsvqr<T>(
handle.get(), int_n, int_nnz, matdesc, csrValA_data, csrRowPtrA_data,
csrColIndA_data, b_data, tol, reorder, x_data, &singularity);
cusparseDestroyMatDescr(matdesc);
FFI_RETURN_IF_ERROR_STATUS(result);
if (singularity >= 0) {
return ffi::Error(ffi::ErrorCode::kInternal,
"Singular matrix in linear solve.");
}
return ffi::Error::Success();
}
ffi::Error CsrlsvqrDispatch(gpuStream_t stream, int reorder, double tol,
ffi::AnyBuffer csrValA,
ffi::Buffer<ffi::S32> csrColIndA,
ffi::Buffer<ffi::S32> csrRowPtrA, ffi::AnyBuffer b,
ffi::Result<ffi::AnyBuffer> x) {
auto dataType = csrValA.element_type();
if (dataType != b.element_type() || dataType != x->element_type()) {
return ffi::Error::InvalidArgument(
"The inputs and outputs to csrlsvqr must have the same element type");
}
int64_t n = b.element_count();
int64_t nnz = csrValA.element_count();
FFI_RETURN_IF_ERROR(
CheckShape(csrColIndA.dimensions(), nnz, "csrColIndA", "csrlsvqr"));
FFI_RETURN_IF_ERROR(
CheckShape(csrRowPtrA.dimensions(), n + 1, "csrColPtrA", "csrlsvqr"));
FFI_RETURN_IF_ERROR(CheckShape(x->dimensions(), n, "x", "csrlsvqr"));
SOLVER_DISPATCH_IMPL(CsrlsvqrImpl, n, nnz, tol, reorder, stream, csrValA,
csrColIndA, csrRowPtrA, b, x);
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in csrlsvqr", absl::FormatStreamed(dataType)));
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(CsrlsvqrFfi, CsrlsvqrDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Attr<int>("reorder") // reorder
.Attr<double>("tol") // tol
.Arg<ffi::AnyBuffer>() // csrValA
.Arg<ffi::Buffer<ffi::S32>>() // csrColIndA
.Arg<ffi::Buffer<ffi::S32>>() // csrRowPtrA
.Arg<ffi::AnyBuffer>() // b
.Ret<ffi::AnyBuffer>() // x
);
#endif // JAX_GPU_CUDA
// Symmetric tridiagonal reduction: sytrd

View File

@ -40,6 +40,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(SytrdFfi);
#ifdef JAX_GPU_CUDA
XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrlsvqrFfi);
#endif // JAX_GPU_CUDA
} // namespace JAX_GPU_NAMESPACE

View File

@ -12,17 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import importlib
import jaxlib.mlir.ir as ir
import numpy as np
from jaxlib import xla_client
from .hlo_helpers import custom_call
try:
from .cuda import _blas as _cublas # pytype: disable=import-error
except ImportError:
@ -129,27 +122,3 @@ def has_magma():
if _hiphybrid:
return _hiphybrid.has_magma()
return False
def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
indices, indptr, b, tol, reorder):
"""Sparse solver via QR decomposition. CUDA only."""
b_type = ir.RankedTensorType(b.type)
data_type = ir.RankedTensorType(data.type)
n = b_type.shape[0]
nnz = data_type.shape[0]
opaque = gpu_solver.build_csrlsvqr_descriptor(
np.dtype(dtype), n, nnz, reorder, tol
)
out = custom_call(
f"{platform}solver_csrlsvqr", # call_target_name
result_types=[b.type],
operands=[data, indptr, indices, b],
backend_config=opaque, # backend_config
operand_layouts=[(0,), (0,), (0,), (0,)], # operand_layouts
result_layouts=[(0,)] # result_layouts
).results
return out
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)

View File

@ -1607,6 +1607,11 @@ jax_py_test(
],
)
jax_multiplatform_test(
name = "string_array_test",
srcs = ["string_array_test.py"],
)
jax_multiplatform_test(
name = "cudnn_fusion_test",
srcs = ["cudnn_fusion_test.py"],
@ -1642,6 +1647,7 @@ exports_files(
"shard_map_test.py",
"transfer_guard_test.py",
"layout_test.py",
"string_array_test.py",
],
visibility = jax_test_file_visibility,
)

View File

@ -24,6 +24,8 @@ from jax._src import api
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P
jax.config.parse_flags_with_absl()
@ -75,7 +77,6 @@ class DebugNaNsTest(jtu.JaxTestCase):
@jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
def testCallDeoptimized(self, jit):
raise SkipTest("re-enable once we handle contexts properly") # TODO(dougalm)
@jit
def f(x):
return jax.lax.cond(
@ -89,6 +90,25 @@ class DebugNaNsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(FloatingPointError, msg):
f(1)
def testShardMap(self):
mesh = jax.make_mesh((1,), ('x',))
f = shard_map(lambda x: 0. / x, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
# For the Cpp pmap, the first execution always goes through Python.
f(jnp.array([1.]))
with self.assertRaisesRegex(
FloatingPointError,
r"Invalid value \(nan\) encountered in sharded computation"):
ans = f(jnp.array([0.]))
ans.block_until_ready()
if jax.device_count() >= 2:
with self.assertRaisesRegex(
FloatingPointError,
r"Invalid value \(nan\) encountered in sharded computation"):
ans = f(jnp.array([1., 0.]))
ans.block_until_ready()
def testPmap(self):
pmap_funcs = [api._cpp_pmap]
@ -99,17 +119,47 @@ class DebugNaNsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in parallel computation"):
r"invalid value \(nan\) encountered in div"):
ans = f(jnp.array([0.]))
ans.block_until_ready()
if jax.device_count() >= 2:
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in parallel computation"):
r"Invalid value \(nan\) encountered in parallel computation"):
ans = f(jnp.array([1., 0.]))
ans.block_until_ready()
def testGradPmap(self):
@jax.jit
def f(x):
y = x**2
return jnp.log(y)
_, f_vjp = jax.vjp(jax.pmap(f), jnp.zeros([1]))
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
ans, = f_vjp(jnp.ones([1]))
ans.block_until_ready()
def testGradShardMap(self):
@jax.jit
def f(x):
y = x**2
return jnp.log(y)
mesh = jax.make_mesh((1,), ('x',))
shmap_f = shard_map(f, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
_, f_vjp = jax.vjp(shmap_f, jnp.zeros([1]))
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
ans, = f_vjp(jnp.ones([1]))
ans.block_until_ready()
def testPmapNoNaN(self):
ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
ans.block_until_ready()
@ -163,17 +213,23 @@ class DebugNaNsTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in jit\(true_divide\)"):
r"invalid value \(nan\) encountered in div"):
f(inp, inp)
# TODO(yashkatariya): Fix this and make true_divide appear in the name again.
# Instead of `f` showing up in the error, the name should be of the
# primitive (true_divide) in this case.
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in jit\(f\)"):
r"invalid value \(nan\) encountered in div"):
jax.jit(f)(inp, inp)
def testDebugNansInput(self):
@jax.jit
def f(x):
return x * 3.
with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*input"):
f(np.nan)
@jtu.with_config(jax_debug_infs=True)
class DebugInfsTest(jtu.JaxTestCase):
@ -233,7 +289,7 @@ class DebugInfsTest(jtu.JaxTestCase):
y = x + 2 # avoid trivial dispatch path by adding some eqn
return jnp.nan, y
with self.assertRaisesRegex(FloatingPointError, "de-optimized"):
with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*literal"):
with jax.debug_nans(True):
f(3)

View File

@ -335,6 +335,47 @@ class FilteredTracebackTest(jtu.JaxTestCase):
('bwd_err', 'g = err(g)'),
('err', 'assert False')], filter_mode=filter_mode)
def test_jvp(self, filter_mode):
def err(_):
assert False
return ()
def f():
p = (1.,)
t = (0.,)
return jax.jvp(err, p, t)
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return jax.jvp(err, p, t)'),
('err', 'assert False')], filter_mode=filter_mode)
def test_vjp(self, filter_mode):
def err(_):
assert False
return ()
def f():
x = 1.
return jax.vjp(err, x)[0]
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return jax.vjp(err, x)[0]'),
('err', 'assert False')], filter_mode=filter_mode)
def test_debug_nans(self, filter_mode):
@jax.jit
def f(x):
return 0. / x
f(2.)
def g():
return f(0.)
with jax.debug_nans(True):
check_filtered_stack_trace(self, ZeroDivisionError, g, [
('g', 'return f(0.)'),
('f', 'return 0. / x')], filter_mode=filter_mode)
def test_cause_chain(self, filter_mode):
@jit
def inner(x):

View File

@ -3758,8 +3758,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertIsNot(x, y)
def testArrayUnsupportedDtypeError(self):
with self.assertRaisesRegex(TypeError,
"JAX only supports number and bool dtypes.*"):
with self.assertRaisesRegex(
TypeError, 'JAX only supports number, bool, and string dtypes.*'
):
jnp.array(3, [('a','<i4'),('b','<i4')])
def testArrayFromInteger(self):

217
tests/string_array_test.py Normal file
View File

@ -0,0 +1,217 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
import numpy as np
config.parse_flags_with_absl()
jtu.request_cpu_devices(2)
class StringArrayTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if xla_extension_version < 311:
self.skipTest(
"Skipping this test because the current XLA extension version:"
f" {xla_extension_version} is older than 309, the oldest version with"
" string array support."
)
if not hasattr(np.dtypes, "StringDType"):
self.skipTest(
"Skipping this test because the numpy.dtype.StringDType is not"
" available."
)
def make_test_string_array(self, device=None):
"""Makes and returns a simple 2x1 string array on the first CPU device."""
if device is None:
cpu_devices = jax.devices("cpu")
if len(cpu_devices) < 1:
self.skipTest(
"Skipping this test because no CPU devices are available."
)
device = cpu_devices[0]
numpy_string_array = np.array(
["abcd", "efgh"], dtype=np.dtypes.StringDType() # type: ignore
)
jax_string_array = jax.device_put(numpy_string_array, device=device)
jax_string_array.block_until_ready()
return jax_string_array
@parameterized.named_parameters(
("asarray", True),
("device_put", False),
)
@jtu.run_on_devices("cpu")
def test_single_device_array(self, asarray):
cpu_devices = jax.devices("cpu")
if len(cpu_devices) < 1:
self.skipTest("Skipping this test because no CPU devices are available.")
numpy_string_array = np.array(
["abcdefghijklmnopqrstuvwxyz", "cba"], dtype=np.dtypes.StringDType() # type: ignore
)
if asarray:
jax_string_array = jnp.asarray(numpy_string_array, device=cpu_devices[0])
else:
jax_string_array = jax.device_put(
numpy_string_array, device=cpu_devices[0]
)
jax_string_array.block_until_ready()
array_read_back = jax.device_get(jax_string_array)
self.assertEqual(array_read_back.dtype, np.dtypes.StringDType()) # type: ignore
np.testing.assert_array_equal(array_read_back, numpy_string_array)
@parameterized.named_parameters(
("asarray", True),
("device_put", False),
)
@jtu.run_on_devices("cpu")
def test_multi_device_array(self, asarray):
cpu_devices = jax.devices("cpu")
if len(cpu_devices) < 2:
self.skipTest(
f"Skipping this test because only {len(cpu_devices)} host"
" devices are available. Need at least 2."
)
numpy_string_array = np.array(
[["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType() # type: ignore
)
mesh = jax.sharding.Mesh(np.array(cpu_devices).reshape((2, 1)), ("x", "y"))
sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("x", "y")
)
if asarray:
jax_string_array = jnp.asarray(numpy_string_array, device=sharding)
else:
jax_string_array = jax.device_put(numpy_string_array, device=sharding)
jax_string_array.block_until_ready()
array_read_back = jax.device_get(jax_string_array)
self.assertEqual(array_read_back.dtype, np.dtypes.StringDType()) # type: ignore
np.testing.assert_array_equal(array_read_back, numpy_string_array)
@jtu.run_on_devices("cpu")
def test_dtype_conversions(self):
cpu_devices = jax.devices("cpu")
if len(cpu_devices) < 1:
self.skipTest("Skipping this test because no CPU devices are available.")
# Explicitly specifying the dtype should work with StringDType numpy arrays.
numpy_string_array = np.array(
["abcd", "efgh"], dtype=np.dtypes.StringDType() # type: ignore
)
jax_string_array = jnp.asarray(
numpy_string_array,
device=cpu_devices[0],
dtype=np.dtypes.StringDType(),
) # type: ignore
jax_string_array.block_until_ready()
# Cannot make a non-StringDType array from a StringDType numpy array.
with self.assertRaisesRegex(
TypeError,
r"Cannot make an array with dtype bfloat16 from an object with dtype"
r" StringDType.*",
):
jnp.asarray(
numpy_string_array,
device=cpu_devices[0],
dtype=jnp.bfloat16,
)
# Cannot make a StringDType array from a numeric numpy array.
numpy_int_array = np.arange(2, dtype=np.int32)
with self.assertRaisesRegex(
TypeError,
r"Cannot make an array with dtype StringDType.*from an object with"
r" dtype int32.",
):
jnp.asarray(
numpy_int_array,
device=cpu_devices[0],
dtype=np.dtypes.StringDType(), # type: ignore
)
@parameterized.named_parameters(
("asarray", True),
("device_put", False),
)
@jtu.skip_on_devices("cpu")
def test_string_array_cannot_be_non_cpu_devices(self, asarray):
devices = jax.devices()
if len(devices) < 1:
self.skipTest("Skipping this test because no devices are available.")
numpy_string_array = np.array(
["abcdefghijklmnopqrstuvwxyz", "cba"], dtype=np.dtypes.StringDType() # type: ignore
)
with self.assertRaisesRegex(
TypeError, "String arrays can only be sharded to CPU devices"
):
if asarray:
jax_string_array = jnp.asarray(numpy_string_array, device=devices[0])
else:
jax_string_array = jax.device_put(numpy_string_array, device=devices[0])
jax_string_array.block_until_ready()
def test_jit_fails_with_string_arrays(self):
f = jax.jit(lambda x: x)
input_array = self.make_test_string_array()
self.assertRaisesRegex(
TypeError,
r"Argument.*is not a valid JAX type.",
lambda: f(input_array),
)
def test_grad_fails_with_string_arrays(self):
f = jax.grad(lambda x: x)
input_array = self.make_test_string_array()
self.assertRaisesRegex(
TypeError,
r"Argument.*is not a valid JAX type.",
lambda: f(input_array),
)
def test_vmap_without_jit_works_with_string_arrays(self):
f = jax.vmap(lambda x: x)
input_array = self.make_test_string_array()
output_array = f(input_array)
self.assertEqual(output_array.dtype, input_array.dtype)
np.testing.assert_array_equal(output_array, input_array)
def test_vmap_with_jit_fails_with_string_arrays(self):
f = jax.vmap(lambda x: x + jnp.arange(2))
input_array = self.make_test_string_array()
self.assertRaisesRegex(
ValueError,
r".*StringDType.*is not a valid dtype",
lambda: f(input_array),
)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.
XLA_COMMIT = "46f8cf03902c0af58468e2258f9438788e7f4c97"
XLA_SHA256 = "0c391b0a8433d26bfc93e5bee775f7eb629b811a42222ce2b4c7449044a5bc0d"
XLA_COMMIT = "85eccd2ed9f2afd956ab17afd31480a042f07f92"
XLA_SHA256 = "ed853428d3f92aeb3a0cabd564f2373309b4784cd6f90db74ccc2d2ae735984f"
def repo():
tf_http_archive(