Mirror of https://github.com/ROCm/jax.git, synced 2025-04-14 10:56:06 +00:00
Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
This commit is contained in:
commit 2e808f2836

.github/workflows/tsan.yaml (vendored): 4 changes
@ -198,4 +198,8 @@ jobs:
--test_output=errors \
--local_test_jobs=32 \
--test_timeout=600 \
--config=resultstore \
--spawn_strategy=local \
--remote_cache=remotebuildexecution.googleapis.com \
--remote_instance_name=projects/tensorflow-testing/instances/default_instance \
//tests:cpu_tests
@ -264,10 +264,52 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "UEObolTqw4pp"
|
||||
},
|
||||
"source": [
|
||||
"The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
|
||||
"\n",
|
||||
"The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n",
|
||||
"\n",
|
||||
"To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "aKNeOHTJnqmS",
|
||||
"outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"pinned_host\n",
|
||||
"device\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n",
|
||||
"s_dev = s_host.with_memory_kind('device')\n",
|
||||
"arr_host = jax.device_put(arr, s_host)\n",
|
||||
"arr_dev = jax.device_put(arr, s_dev)\n",
|
||||
"print(arr_host.sharding.memory_kind)\n",
|
||||
"print(arr_dev.sharding.memory_kind)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "jDHYnVqHwaST"
|
||||
},
|
||||
"source": [
|
||||
"## 1. Automatic parallelism via `jit`\n",
|
||||
"\n",
|
||||
"Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n",
|
||||
@ -354,10 +396,98 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"id": "Q4N5mrr9i_ki"
|
||||
},
|
||||
"source": [
|
||||
"The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n",
|
||||
"\n",
|
||||
"### 1.1 Sharding transformation between memory types\n",
|
||||
"\n",
|
||||
"The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n",
|
||||
"\n",
|
||||
"#### Example 1: Pinned host to device memory\n",
|
||||
"\n",
|
||||
"In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "PXu3MhafyRHo",
|
||||
"outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
|
||||
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
|
||||
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
|
||||
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
|
||||
"device\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"f = jax.jit(lambda x: x, out_shardings=s_dev)\n",
|
||||
"out_dev = f(arr_host)\n",
|
||||
"print(out_dev)\n",
|
||||
"print(out_dev.sharding.memory_kind)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "LuYFqpcBySiX"
|
||||
},
|
||||
"source": [
|
||||
"#### Example 2: Device to pinned_host memory\n",
|
||||
"\n",
|
||||
"In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "qLsgNlKfybRw",
|
||||
"outputId": "a16448b9-7e39-408f-b200-505f65ad4464"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n",
|
||||
" [ 8. 9. 10. 11. 12. 13. 14. 15.]\n",
|
||||
" [16. 17. 18. 19. 20. 21. 22. 23.]\n",
|
||||
" [24. 25. 26. 27. 28. 29. 30. 31.]]\n",
|
||||
"pinned_host\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"g = jax.jit(lambda x: x, out_shardings=s_host)\n",
|
||||
"out_host = g(arr_dev)\n",
|
||||
"print(out_host)\n",
|
||||
"print(out_host.sharding.memory_kind)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "7BGD31-owaSU"
|
||||
},
|
||||
"source": [
|
||||
"## 2. Semi-automated sharding with constraints\n",
|
||||
"\n",
|
||||
"If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
|
||||
|
@ -90,8 +90,31 @@ print(arr_sharded)
jax.debug.visualize_array_sharding(arr_sharded)
```

+++ {"id": "UEObolTqw4pp"}

The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the devices.

The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place the data in host memory.

To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding.

```{code-cell}
---
colab:
  base_uri: https://localhost:8080/
id: aKNeOHTJnqmS
outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2
---
s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')
s_dev = s_host.with_memory_kind('device')
arr_host = jax.device_put(arr, s_host)
arr_dev = jax.device_put(arr, s_dev)
print(arr_host.sharding.memory_kind)
print(arr_dev.sharding.memory_kind)
```

+++ {"id": "jDHYnVqHwaST"}

## 1. Automatic parallelism via `jit`

Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you only need to specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile the inter-device communications.
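A minimal sketch of this pattern, assuming the sharded array `arr_sharded` from the setup cells above; the element-wise function is illustrative, not part of this diff:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f_elementwise(x):
    # Any element-wise computation works here; the compiler partitions it
    # per-device and keeps the input's sharding on the output.
    return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)  # arr_sharded: assumed from the earlier cells
print(result.sharding)               # the output shards match the input shards
```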
@ -129,8 +152,52 @@ jax.debug.visualize_array_sharding(result)
print(result)
```

+++ {"id": "Q4N5mrr9i_ki"}

The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the next two on `1` and `7`, and so on.

### 1.1 Sharding transformation between memory types

The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.

#### Example 1: Pinned host to device memory

In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory.

```{code-cell}
---
colab:
  base_uri: https://localhost:8080/
id: PXu3MhafyRHo
outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b
---
f = jax.jit(lambda x: x, out_shardings=s_dev)
out_dev = f(arr_host)
print(out_dev)
print(out_dev.sharding.memory_kind)
```

+++ {"id": "LuYFqpcBySiX"}

#### Example 2: Device to pinned_host memory

In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory.

```{code-cell}
---
colab:
  base_uri: https://localhost:8080/
id: qLsgNlKfybRw
outputId: a16448b9-7e39-408f-b200-505f65ad4464
---
g = jax.jit(lambda x: x, out_shardings=s_host)
out_host = g(arr_dev)
print(out_host)
print(out_host.sharding.memory_kind)
```

+++ {"id": "7BGD31-owaSU"}

## 2. Semi-automated sharding with constraints

If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constrains the distribution of intermediate values and outputs.
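A minimal sketch of applying such a constraint, assuming `mesh`, `P`, and `arr_sharded` from the cells above; the reduction inside `f_contract` is illustrative:

```python
import jax

@jax.jit
def f_contract(x):
    summed = x.sum(axis=0)
    # Ask the compiler to distribute the intermediate/output along mesh axis 'x'.
    sharding = jax.NamedSharding(mesh, P('x'))  # mesh, P: assumed from earlier cells
    return jax.lax.with_sharding_constraint(summed, sharding)

out = f_contract(arr_sharded)
print(out.sharding)
```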
@ -499,6 +499,7 @@ pytype_strict_library(
|
||||
":traceback_util",
|
||||
":typing",
|
||||
":util",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("ml_dtypes") + py_deps("numpy"),
|
||||
)
|
||||
|
||||
|
@ -99,6 +99,7 @@ map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
|
||||
@api_boundary
|
||||
def _nan_check_posthook(fun, args, kwargs, output):
|
||||
"""Hook function called by the C++ jit/pmap to perform NaN checking."""
|
||||
buffers = []
|
||||
@ -108,12 +109,18 @@ def _nan_check_posthook(fun, args, kwargs, output):
|
||||
|
||||
try:
|
||||
dispatch.check_special(pjit.pjit_p.name, buffers)
|
||||
except FloatingPointError:
|
||||
# compiled_fun can only raise in this case
|
||||
except dispatch.InternalFloatingPointError as e:
|
||||
assert config.debug_nans.value or config.debug_infs.value
|
||||
print("Invalid nan value encountered in the output of a C++-jit/pmap "
|
||||
"function. Calling the de-optimized version.")
|
||||
fun._cache_miss(*args, **kwargs)[0] # probably won't return
|
||||
if hasattr(fun, '_fun'):
|
||||
f = fun._fun
|
||||
if getattr(f, '_apply_primitive', False):
|
||||
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
|
||||
# compiled_fun can only raise in this case
|
||||
dispatch.maybe_recursive_nan_check(e, f, args, kwargs)
|
||||
raise AssertionError("Unreachable") from e
|
||||
else:
|
||||
# TODO(emilyaf): Shouldn't need this fallback.
|
||||
raise
|
||||
|
||||
def _update_debug_special_global(_):
|
||||
if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
|
||||
@ -1574,11 +1581,14 @@ def _cpp_pmap(
|
||||
|
||||
execute: Callable | None = None
|
||||
with core.take_current_trace() as trace:
|
||||
if isinstance(trace, core.EvalTrace):
|
||||
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
|
||||
out = execute(*p.flat_args)
|
||||
else:
|
||||
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
|
||||
try:
|
||||
if isinstance(trace, core.EvalTrace):
|
||||
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
|
||||
out = execute(*p.flat_args)
|
||||
else:
|
||||
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
|
||||
except dispatch.InternalFloatingPointError as e:
|
||||
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.')
|
||||
|
||||
out_tree, out_flat = p.out_tree, out
|
||||
out_pytree_def = out_tree()
|
||||
@ -1629,6 +1639,7 @@ def _cpp_pmap(
|
||||
_pmap_cache_clears.add(cpp_mapped_f)
|
||||
|
||||
pmap_f = wraps(fun)(cpp_mapped_f)
|
||||
pmap_f._fun = fun
|
||||
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs):
|
||||
@ -1674,6 +1685,7 @@ def _cpp_pmap(
|
||||
_pmap_cache_clears = weakref.WeakSet() # type: ignore
|
||||
|
||||
|
||||
@api_boundary
|
||||
def jvp(
|
||||
fun: Callable, primals, tangents, has_aux: bool = False
|
||||
) -> tuple[Any, ...]:
|
||||
@ -1878,6 +1890,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
|
||||
|
||||
return apply_flat_fun_nokwargs(fun, io_tree, py_args)
|
||||
|
||||
@api_boundary
|
||||
def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_):
|
||||
if len(py_args_) != 1:
|
||||
msg = (f"The function returned by `jax.vjp` applied to {name} was called "
|
||||
@ -1937,6 +1950,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
|
||||
has_aux: Literal[True],
|
||||
reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
|
||||
...
|
||||
@api_boundary
|
||||
def vjp(
|
||||
fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
|
||||
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
|
||||
@ -2225,6 +2239,18 @@ def _infer_src_sharding(src, x) -> Sharding | None:
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _check_string_compatible_sharding(s):
|
||||
"""Checks if target devices are compatible with string arrays."""
|
||||
if isinstance(s, xc.Device) and s.device_kind == "cpu":
|
||||
return
|
||||
if (isinstance(s, Sharding)
|
||||
and s._internal_device_list[0].device_kind == "cpu"):
|
||||
return
|
||||
raise TypeError(
|
||||
"String arrays can only be sharded to CPU devices. Received"
|
||||
f" unsupported device or sharding: {s}")
|
||||
|
||||
# TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use
|
||||
# that to check if shardings are compatible with the input.
|
||||
@lru_cache(maxsize=2048)
|
||||
@ -2235,6 +2261,10 @@ def _check_sharding(aval, s):
|
||||
"`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
|
||||
" `jax.Device`, `Layout` or a pytree of these values. Received"
|
||||
f" invalid value: {s}")
|
||||
|
||||
if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype):
|
||||
_check_string_compatible_sharding(s)
|
||||
|
||||
if isinstance(s, Sharding):
|
||||
if isinstance(aval, core.AbstractToken):
|
||||
aval = core.get_token_aval()
|
||||
|
@ -1472,11 +1472,14 @@ Value = Any
|
||||
|
||||
def valid_jaxtype(x) -> bool:
|
||||
try:
|
||||
abstractify(x)
|
||||
aval = abstractify(x)
|
||||
except TypeError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def check_valid_jaxtype(x):
|
||||
if not valid_jaxtype(x):
|
||||
|
@ -25,7 +25,7 @@ import itertools
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, NamedTuple
|
||||
from typing import Any, Callable, NamedTuple
|
||||
|
||||
import jax
|
||||
from jax._src import api
|
||||
@ -100,6 +100,7 @@ def xla_primitive_callable(prim: core.Primitive, **params):
|
||||
return prim.bind(*args, **params)
|
||||
prim_fun.__name__ = prim.name
|
||||
prim_fun.__qualname__ = prim.name
|
||||
prim_fun._apply_primitive = True
|
||||
return api.jit(prim_fun)
|
||||
|
||||
|
||||
@ -321,15 +322,52 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
|
||||
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
|
||||
if dtypes.issubdtype(dtype, np.inexact):
|
||||
if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
|
||||
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
|
||||
raise InternalFloatingPointError(name, "nan")
|
||||
if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
|
||||
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
|
||||
raise InternalFloatingPointError(name, "inf")
|
||||
|
||||
class CopySemantics(enum.Enum):
|
||||
ALIAS = enum.auto()
|
||||
COPY = enum.auto()
|
||||
DONATE = enum.auto()
|
||||
|
||||
class InternalFloatingPointError(Exception):
|
||||
name: str
|
||||
ty: str
|
||||
|
||||
def __init__(self, name: str, ty: str):
|
||||
self.name = name
|
||||
self.ty = ty
|
||||
|
||||
def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs,
|
||||
) -> None: # always raises an exception
|
||||
print("Invalid nan value encountered in the output of a jax.jit "
|
||||
"function. Calling the de-optimized version.")
|
||||
try:
|
||||
_ = fun(*args, **kwargs)
|
||||
except (FloatingPointError, ZeroDivisionError) as e2:
|
||||
raise e2 from None
|
||||
else:
|
||||
_raise_no_nan_in_deoptimized(e)
|
||||
|
||||
def _raise_no_nan_in_deoptimized(e) -> None:
|
||||
msg = (f"{str(e)}. Because "
|
||||
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
|
||||
"de-optimized function (i.e., the function as if the `jit` "
|
||||
"decorator were removed) was called in an attempt to get a more "
|
||||
"precise error message. However, the de-optimized function did not "
|
||||
"produce invalid values during its execution. This behavior can "
|
||||
"result from `jit` optimizations causing the invalid value to be "
|
||||
"produced. It may also arise from having nan/inf literals as "
|
||||
"inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
|
||||
"\n\n"
|
||||
"It may be possible to avoid the invalid value by removing the "
|
||||
"`jit` decorator, at the cost of losing optimizations. "
|
||||
"\n\n"
|
||||
"If you see this error, consider opening a bug report at "
|
||||
"https://github.com/jax-ml/jax.")
|
||||
raise FloatingPointError(msg) from None
|
||||
|
||||
def _identity_fn(x):
|
||||
return x
|
||||
|
||||
|
@ -33,6 +33,7 @@ import ml_dtypes
|
||||
import numpy as np
|
||||
|
||||
from jax._src import config
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.typing import Array, DType, DTypeLike
|
||||
from jax._src.util import set_module, StrictABC
|
||||
|
||||
@ -486,18 +487,37 @@ _complex_types: list[JAXType] = [
|
||||
np.dtype('complex64'),
|
||||
np.dtype('complex128'),
|
||||
]
|
||||
_jax_types = _bool_types + _int_types + _float_types + _complex_types
|
||||
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}
|
||||
|
||||
|
||||
# We add the StringDType only to `_jax_dtype_set` but not to `_jax_types` and
|
||||
# `_dtype_kinds`. This is because, in spite of a very similar sounding name,
|
||||
# `_jax_types` is only meant for the promotion related logic, and StringDType
|
||||
# does not participate in promotions at the moment. Similarly, `_dtype_kinds` is
|
||||
# only meant for the `jnp.isdtype` and we want to be conservative and not allow
|
||||
# StringDType to be used in there.
|
||||
_string_types: list[JAXType] = []
|
||||
if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 311:
|
||||
_string_types: list[JAXType] = [np.dtypes.StringDType()] # type: ignore
|
||||
|
||||
_jax_dtype_set = {
|
||||
float0,
|
||||
*_bool_types,
|
||||
*_int_types,
|
||||
*_float_types,
|
||||
*_complex_types,
|
||||
*_string_types,
|
||||
}
|
||||
|
||||
_jax_types = (_bool_types + _int_types + _float_types + _complex_types)
|
||||
|
||||
_dtype_kinds: dict[str, set] = {
|
||||
'bool': {*_bool_types},
|
||||
'signed integer': {*_signed_types},
|
||||
'unsigned integer': {*_unsigned_types},
|
||||
'integral': {*_signed_types, *_unsigned_types},
|
||||
'real floating': {*_float_types},
|
||||
'complex floating': {*_complex_types},
|
||||
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
|
||||
'bool': {*_bool_types},
|
||||
'signed integer': {*_signed_types},
|
||||
'unsigned integer': {*_unsigned_types},
|
||||
'integral': {*_signed_types, *_unsigned_types},
|
||||
'real floating': {*_float_types},
|
||||
'complex floating': {*_complex_types},
|
||||
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
|
||||
}
|
||||
|
||||
|
||||
@ -870,8 +890,14 @@ def check_user_dtype_supported(dtype, fun_name=None):
|
||||
uint2,
|
||||
uint4
|
||||
]
|
||||
if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0:
|
||||
msg = f"JAX only supports number and bool dtypes, got dtype {dtype}"
|
||||
if (
|
||||
np_dtype.kind not in 'biufcT'
|
||||
and not is_custom_dtype
|
||||
and not dtype == float0
|
||||
):
|
||||
msg = (
|
||||
f'JAX only supports number, bool, and string dtypes, got dtype {dtype}'
|
||||
)
|
||||
msg += f" in {fun_name}" if fun_name else ""
|
||||
raise TypeError(msg)
|
||||
if dtype is not None and np_dtype != canonicalize_dtype(np_dtype):
|
||||
@ -949,3 +975,7 @@ def short_dtype_name(dtype) -> str:
|
||||
else:
|
||||
return (dtype.name.replace('float', 'f').replace('uint' , 'u')
|
||||
.replace('int' , 'i').replace('complex', 'c'))
|
||||
|
||||
|
||||
def is_string_dtype(dtype: DTypeLike | None) -> bool:
|
||||
return dtype in _string_types
|
||||
|
@ -22,6 +22,7 @@ from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import dispatch
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten,
|
||||
@ -360,8 +361,15 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
params, call_jaxpr, invals, cts_in, cts_in_avals)
|
||||
else:
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
cts_in, *invals, **eqn.params)
|
||||
try:
|
||||
cts_out = get_primitive_transpose(eqn.primitive)(
|
||||
cts_in, *invals, **eqn.params)
|
||||
except (FloatingPointError, ZeroDivisionError) as e:
|
||||
msg = "When differentiating the code at the top of the callstack:"
|
||||
if msg not in e.args[0]:
|
||||
e.args = e.args[0] + f'\n{msg}',
|
||||
e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}',
|
||||
raise e from None
|
||||
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
|
||||
# FIXME: Some invars correspond to primals!
|
||||
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
|
||||
@ -1003,7 +1011,20 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _):
|
||||
if update_params:
|
||||
new_params = update_params(new_params, map(is_undefined_primal, args),
|
||||
[type(x) is not Zero for x in ct])
|
||||
out_flat = primitive.bind(fun, *all_args, **new_params)
|
||||
|
||||
try:
|
||||
out_flat = primitive.bind(fun, *all_args, **new_params)
|
||||
except dispatch.InternalFloatingPointError as e:
|
||||
print("Invalid nan value encountered in the backward pass of a jax.jit "
|
||||
"function. Calling the de-optimized backward pass.")
|
||||
try:
|
||||
_ = backward_pass(call_jaxpr, None, {}, args, ct)
|
||||
except (FloatingPointError, ZeroDivisionError) as e2:
|
||||
raise e2 from None
|
||||
else:
|
||||
# If control reaches this line, we got a NaN on the output of `compiled`
|
||||
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
|
||||
dispatch._raise_no_nan_in_deoptimized(e)
|
||||
arg_cts = tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
# The freevars are being fanned out (not mapped). During transpose the
|
||||
|
@ -57,6 +57,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax.lax import (PrecisionLike,_array_copy,
|
||||
_sort_le_comparator, _sort_lt_comparator)
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy import util
|
||||
@ -5474,6 +5475,39 @@ def _supports_buffer_protocol(obj):
|
||||
return True
|
||||
|
||||
|
||||
def _make_string_array(
|
||||
object: np.ndarray,
|
||||
dtype: DTypeLike | None = None,
|
||||
ndmin: int = 0,
|
||||
device: xc.Device | Sharding | None = None,
|
||||
) -> Array:
|
||||
if xla_extension_version < 311:
|
||||
raise TypeError(
|
||||
"String arrays are not supported in JAX before XLA extension version"
|
||||
" 311."
|
||||
)
|
||||
if not isinstance(object, np.ndarray):
|
||||
raise TypeError(
|
||||
"Currently, string arrays can only be made from NumPy"
|
||||
f" arrays. Got: {type(object)}."
|
||||
)
|
||||
if dtype is not None and (
|
||||
dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype)
|
||||
):
|
||||
raise TypeError(
|
||||
f"Cannot make an array with dtype {dtype} from an object with dtype"
|
||||
f" {object.dtype}."
|
||||
)
|
||||
if ndmin > object.ndim:
|
||||
raise TypeError(
|
||||
f"ndmin {ndmin} cannot be greater than object's ndims"
|
||||
f" {object.ndim} for string arrays."
|
||||
)
|
||||
|
||||
# Just do a device_put since XLA does not support string as a data type.
|
||||
return jax.device_put(x=object, device=device)
|
||||
|
||||
|
||||
@export
|
||||
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
order: str | None = "K", ndmin: int = 0,
|
||||
@ -5567,6 +5601,15 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
# Keep the output uncommitted.
|
||||
return jax.device_put(object)
|
||||
|
||||
# String arrays need separate handling because XLA does not support string
|
||||
# as a data type.
|
||||
if dtypes.is_string_dtype(dtype) or (
|
||||
hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype)
|
||||
):
|
||||
return _make_string_array(
|
||||
object=object, dtype=dtype, ndmin=ndmin, device=device
|
||||
)
|
||||
|
||||
# For Python scalar literals, call coerce_to_array to catch any overflow
|
||||
# errors. We don't use dtypes.is_python_scalar because we don't want this
|
||||
# triggering for traced values. We do this here because it matters whether or
|
||||
|
@ -222,6 +222,10 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs):
|
||||
f"Argument '{name}' of shape {aval.str_short()} of type"
|
||||
f' {type(arg)} is not a valid JAX type.') from e
|
||||
raise AssertionError("Unreachable") from e
|
||||
except dispatch.InternalFloatingPointError as e:
|
||||
if getattr(fun, '_apply_primitive', False):
|
||||
raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None
|
||||
dispatch.maybe_recursive_nan_check(e, fun, args, kwargs)
|
||||
|
||||
if p.attrs_tracked:
|
||||
num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked)
|
||||
@ -1700,33 +1704,7 @@ def _pjit_call_impl_python(
|
||||
("out_layouts", out_layouts),
|
||||
("abstract args", map(core.abstractify, args)),
|
||||
("fingerprint", fingerprint))
|
||||
try:
|
||||
return compiled.unsafe_call(*args), compiled, pgle_profiler
|
||||
except FloatingPointError as e:
|
||||
assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case
|
||||
|
||||
if len(jaxpr.eqns) > 1:
|
||||
_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return
|
||||
|
||||
# If control reaches this line, we got a NaN on the output of `compiled`
|
||||
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
|
||||
msg = (f"{str(e)}. Because "
|
||||
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
|
||||
"de-optimized function (i.e., the function as if the `jit` "
|
||||
"decorator were removed) was called in an attempt to get a more "
|
||||
"precise error message. However, the de-optimized function did not "
|
||||
"produce invalid values during its execution. This behavior can "
|
||||
"result from `jit` optimizations causing the invalid value to be "
|
||||
"produced. It may also arise from having nan/inf constants as "
|
||||
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
|
||||
"\n\n"
|
||||
"It may be possible to avoid the invalid value by removing the "
|
||||
"`jit` decorator, at the cost of losing optimizations. "
|
||||
"\n\n"
|
||||
"If you see this error, consider opening a bug report at "
|
||||
"https://github.com/jax-ml/jax.")
|
||||
raise FloatingPointError(msg)
|
||||
|
||||
return compiled.unsafe_call(*args), compiled, pgle_profiler
|
||||
|
||||
@weakref_lru_cache
|
||||
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
|
||||
@ -2404,19 +2382,31 @@ def _pjit_transpose(cts_in, *primals_in,
|
||||
transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts
|
||||
transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts
|
||||
|
||||
nz_cts_out = pjit_p.bind(
|
||||
*primals_and_nz_cts_in,
|
||||
jaxpr=transpose_jaxpr,
|
||||
in_shardings=transpose_in_shardings,
|
||||
out_shardings=transpose_out_shardings,
|
||||
in_layouts=transpose_in_layouts,
|
||||
out_layouts=transpose_out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=(False,) * len(primals_and_nz_cts_in),
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
compiler_options_kvs=compiler_options_kvs)
|
||||
try:
|
||||
nz_cts_out = pjit_p.bind(
|
||||
*primals_and_nz_cts_in,
|
||||
jaxpr=transpose_jaxpr,
|
||||
in_shardings=transpose_in_shardings,
|
||||
out_shardings=transpose_out_shardings,
|
||||
in_layouts=transpose_in_layouts,
|
||||
out_layouts=transpose_out_layouts,
|
||||
resource_env=resource_env,
|
||||
donated_invars=(False,) * len(primals_and_nz_cts_in),
|
||||
name=name,
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
compiler_options_kvs=compiler_options_kvs)
|
||||
except dispatch.InternalFloatingPointError as e:
|
||||
print("Invalid nan value encountered in the backward pass of a jax.jit "
|
||||
"function. Calling the de-optimized backward pass.")
|
||||
try:
|
||||
_ = ad.closed_backward_pass(jaxpr, None, primals_in, cts_in)
|
||||
except (FloatingPointError, ZeroDivisionError) as e2:
|
||||
raise e2 from None # great
|
||||
else:
|
||||
# If control reaches this line, we got a NaN on the output of `compiled`
|
||||
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
|
||||
dispatch._raise_no_nan_in_deoptimized(e)
|
||||
|
||||
if attrs_tracked:
|
||||
final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)])
|
||||
|
@ -30,7 +30,6 @@ package(
|
||||
py_library(
|
||||
name = "jax2tf",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":jax2tf_internal"],
|
||||
)
|
||||
@ -42,7 +41,6 @@ py_library(
|
||||
"impl_no_xla.py",
|
||||
"jax2tf.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
# TODO: b/255503696: enable pytype
|
||||
tags = ["pytype_unchecked_annotations"],
|
||||
visibility = jax_visibility("jax2tf_internal"),
|
||||
|
@ -24,7 +24,6 @@ package(
|
||||
py_library(
|
||||
name = "back_compat_testdata",
|
||||
srcs = glob(["*.py"]),
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//third_party/py/numpy",
|
||||
"//third_party/py/typing_extensions",
|
||||
|
@ -28,7 +28,6 @@ package(
|
||||
py_library(
|
||||
name = "flax_models",
|
||||
srcs = glob(["*.py"]),
|
||||
srcs_version = "PY3",
|
||||
deps = [
|
||||
"//jax",
|
||||
"//third_party/py/flax:core",
|
||||
|
@ -874,6 +874,15 @@ def _match(mesh, check_rep, pspec, x):
|
||||
def _rem_singleton(x): return jnp.squeeze(x, axis=0)
|
||||
def _add_singleton(x): return jnp.expand_dims(x, axis=0)
|
||||
|
||||
def _maybe_check_special(outs):
|
||||
if not config.debug_nans.value and not config.debug_infs.value: return
|
||||
bufs = [s.data for leaf in tree_leaves(outs)
|
||||
for s in getattr(leaf, 'addressable_shards', [])]
|
||||
try:
|
||||
dispatch.check_special('shard_map', bufs)
|
||||
except dispatch.InternalFloatingPointError as e:
|
||||
raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None
|
||||
|
||||
class ShardMapTrace(core.Trace):
|
||||
__slots__ = ("mesh", "check", "context_mesh")
|
||||
|
||||
@ -902,9 +911,10 @@ class ShardMapTrace(core.Trace):
|
||||
out_vals = eager_rule(self.mesh, *in_vals, **params)
|
||||
else:
|
||||
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
|
||||
with (core.eval_context(), jax.disable_jit(False),
|
||||
set_abstract_mesh(self.context_mesh)):
|
||||
with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False),
|
||||
jax.debug_infs(False), set_abstract_mesh(self.context_mesh)):
|
||||
out_vals = jax.jit(f)(*in_vals)
|
||||
_maybe_check_special(out_vals)
|
||||
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
|
||||
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
|
||||
if prim.multiple_results:
|
||||
@ -1700,10 +1710,21 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
|
||||
def new_out_names_thunk():
|
||||
return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz)
|
||||
|
||||
out_flat = shard_map_p.bind(
|
||||
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
|
||||
out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite,
|
||||
auto=auto)
|
||||
try:
|
||||
out_flat = shard_map_p.bind(
|
||||
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
|
||||
out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite,
|
||||
auto=auto)
|
||||
except (FloatingPointError, ZeroDivisionError) as e:
|
||||
print("Invalid nan value encountered in the backward pass of a shard_map "
|
||||
"function. Calling the de-optimized backward pass.")
|
||||
try:
|
||||
_ = fun_trans.call_wrapped(out_cts, args)
|
||||
except (FloatingPointError, ZeroDivisionError) as e2:
|
||||
raise e2 from None
|
||||
else:
|
||||
dispatch._raise_no_nan_in_deoptimized(e)
|
||||
|
||||
return tree_unflatten(out_tree(), out_flat)
|
||||
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
|
||||
|
||||
|
@ -27,6 +27,7 @@ from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import ffi
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lib import gpu_solver
|
||||
|
||||
@ -533,10 +534,14 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder):
|
||||
|
||||
|
||||
def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder):
|
||||
data_aval, _, _, _, = ctx.avals_in
|
||||
return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices,
|
||||
indptr, b, tol, reorder)
|
||||
|
||||
# TODO(danfm): remove after JAX 0.5.1 release.
|
||||
if hasattr(gpu_solver, "cuda_csrlsvqr"):
|
||||
data_aval, _, _, _, = ctx.avals_in
|
||||
return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices,
|
||||
indptr, b, tol, reorder)
|
||||
return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")(
|
||||
ctx, data, indices, indptr, b, tol=np.float64(tol),
|
||||
reorder=np.int32(reorder))
|
||||
|
||||
def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder):
|
||||
del tol, reorder
|
||||
|
@ -230,6 +230,7 @@ cc_library(
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cusolver",
|
||||
@ -251,6 +252,7 @@ cc_library(
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/tsl/cuda:cublas",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
|
@ -50,6 +50,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA");
|
||||
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA",
|
||||
GeqrfFfi);
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
|
||||
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA",
|
||||
CsrlsvqrFfi);
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
|
||||
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
|
||||
OrgqrFfi);
|
||||
|
@ -486,6 +486,8 @@ nb::dict Registrations() {
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi);
|
||||
dict[JAX_GPU_PREFIX "solver_csrlsvqr_ffi"] =
|
||||
EncapsulateFfiHandler(CsrlsvqrFfi);
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
return dict;
|
||||
|
@ -20,6 +20,10 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
#endif
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace solver {
|
||||
@ -315,6 +319,23 @@ JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched);
|
||||
JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched);
|
||||
#undef JAX_GPU_DEFINE_GESVDJ_BATCHED
|
||||
|
||||
#define JAX_GPU_DEFINE_CSRLSVQR(Type, Scalar, Name) \
|
||||
template <> \
|
||||
absl::Status Csrlsvqr<Type>( \
|
||||
cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \
|
||||
const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \
|
||||
const Type *b, double tol, int reorder, Type *x, int *singularity) { \
|
||||
return JAX_AS_STATUS(Name(handle, n, nnz, matdesc, csrValA, csrRowPtrA, \
|
||||
csrColIndA, b, static_cast<Scalar>(tol), \
|
||||
reorder, x, singularity)); \
|
||||
}
|
||||
|
||||
JAX_GPU_DEFINE_CSRLSVQR(float, float, cusolverSpScsrlsvqr);
|
||||
JAX_GPU_DEFINE_CSRLSVQR(double, double, cusolverSpDcsrlsvqr);
|
||||
JAX_GPU_DEFINE_CSRLSVQR(gpuComplex, float, cusolverSpCcsrlsvqr);
|
||||
JAX_GPU_DEFINE_CSRLSVQR(gpuDoubleComplex, double, cusolverSpZcsrlsvqr);
|
||||
#undef JAX_GPU_DEFINE_CSRLSVQR
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
// Symmetric tridiagonal reduction: sytrd
|
||||
|
@ -23,6 +23,10 @@ limitations under the License.
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
#endif
|
||||
|
||||
namespace jax {
|
||||
namespace JAX_GPU_NAMESPACE {
|
||||
namespace solver {
|
||||
@ -206,6 +210,13 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr<int>, GesvdjBatchedBufferSize);
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched);
|
||||
#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS
|
||||
|
||||
#define JAX_GPU_SOLVER_Csrlsvqr_ARGS(Type, ...) \
|
||||
cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \
|
||||
const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \
|
||||
const Type *b, double tol, int reorder, Type *x, int *singularity
|
||||
JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Csrlsvqr);
|
||||
#undef JAX_GPU_SOLVER_Csrlsvqr_ARGS
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
// Symmetric tridiagonal reduction: sytrd
|
||||
|
@ -41,6 +41,10 @@ limitations under the License.
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
#include "third_party/gpus/cuda/include/cusolverSp.h"
|
||||
#endif
|
||||
|
||||
#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__))
|
||||
|
||||
@ -1013,6 +1017,82 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch,
|
||||
.Ret<ffi::Buffer<ffi::S32>>() // info
|
||||
);
|
||||
|
||||
// csrlsvqr: Linear system solve via Sparse QR
|
||||
|
||||
template <typename T>
|
||||
ffi::Error CsrlsvqrImpl(int64_t n, int64_t nnz, double tol, int reorder,
|
||||
gpuStream_t stream, ffi::AnyBuffer csrValA,
|
||||
ffi::Buffer<ffi::S32> csrColIndA,
|
||||
ffi::Buffer<ffi::S32> csrRowPtrA, ffi::AnyBuffer b,
|
||||
ffi::Result<ffi::AnyBuffer> x) {
|
||||
FFI_ASSIGN_OR_RETURN(auto handle, SpSolverHandlePool::Borrow(stream));
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(auto int_n, MaybeCastNoOverflow<int>(n));
|
||||
FFI_ASSIGN_OR_RETURN(auto int_nnz, MaybeCastNoOverflow<int>(nnz));
|
||||
|
||||
cusparseMatDescr_t matdesc = nullptr;
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(cusparseCreateMatDescr(&matdesc));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(
|
||||
cusparseSetMatType(matdesc, CUSPARSE_MATRIX_TYPE_GENERAL));
|
||||
JAX_FFI_RETURN_IF_GPU_ERROR(
|
||||
cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO));
|
||||
|
||||
auto* csrValA_data = static_cast<T*>(csrValA.untyped_data());
|
||||
auto* csrColIndA_data = csrColIndA.typed_data();
|
||||
auto* csrRowPtrA_data = csrRowPtrA.typed_data();
|
||||
auto* b_data = static_cast<T*>(b.untyped_data());
|
||||
auto* x_data = static_cast<T*>(x->untyped_data());
|
||||
|
||||
int singularity = -1;
|
||||
auto result = solver::Csrlsvqr<T>(
|
||||
handle.get(), int_n, int_nnz, matdesc, csrValA_data, csrRowPtrA_data,
|
||||
csrColIndA_data, b_data, tol, reorder, x_data, &singularity);
|
||||
cusparseDestroyMatDescr(matdesc);
|
||||
FFI_RETURN_IF_ERROR_STATUS(result);
|
||||
|
||||
if (singularity >= 0) {
|
||||
return ffi::Error(ffi::ErrorCode::kInternal,
|
||||
"Singular matrix in linear solve.");
|
||||
}
|
||||
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
ffi::Error CsrlsvqrDispatch(gpuStream_t stream, int reorder, double tol,
|
||||
ffi::AnyBuffer csrValA,
|
||||
ffi::Buffer<ffi::S32> csrColIndA,
|
||||
ffi::Buffer<ffi::S32> csrRowPtrA, ffi::AnyBuffer b,
|
||||
ffi::Result<ffi::AnyBuffer> x) {
|
||||
auto dataType = csrValA.element_type();
|
||||
if (dataType != b.element_type() || dataType != x->element_type()) {
|
||||
return ffi::Error::InvalidArgument(
|
||||
"The inputs and outputs to csrlsvqr must have the same element type");
|
||||
}
|
||||
int64_t n = b.element_count();
|
||||
int64_t nnz = csrValA.element_count();
|
||||
FFI_RETURN_IF_ERROR(
|
||||
CheckShape(csrColIndA.dimensions(), nnz, "csrColIndA", "csrlsvqr"));
|
||||
FFI_RETURN_IF_ERROR(
|
||||
CheckShape(csrRowPtrA.dimensions(), n + 1, "csrRowPtrA", "csrlsvqr"));
|
||||
FFI_RETURN_IF_ERROR(CheckShape(x->dimensions(), n, "x", "csrlsvqr"));
|
||||
SOLVER_DISPATCH_IMPL(CsrlsvqrImpl, n, nnz, tol, reorder, stream, csrValA,
|
||||
csrColIndA, csrRowPtrA, b, x);
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in csrlsvqr", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(CsrlsvqrFfi, CsrlsvqrDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
.Ctx<ffi::PlatformStream<gpuStream_t>>()
|
||||
.Attr<int>("reorder") // reorder
|
||||
.Attr<double>("tol") // tol
|
||||
.Arg<ffi::AnyBuffer>() // csrValA
|
||||
.Arg<ffi::Buffer<ffi::S32>>() // csrColIndA
|
||||
.Arg<ffi::Buffer<ffi::S32>>() // csrRowPtrA
|
||||
.Arg<ffi::AnyBuffer>() // b
|
||||
.Ret<ffi::AnyBuffer>() // x
|
||||
);
|
||||
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
// Symmetric tridiagonal reduction: sytrd
|
||||
|
@ -40,6 +40,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(SytrdFfi);
|
||||
|
||||
#ifdef JAX_GPU_CUDA
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrlsvqrFfi);
|
||||
#endif // JAX_GPU_CUDA
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
|
@ -12,17 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import importlib
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jaxlib import xla_client
|
||||
|
||||
from .hlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _blas as _cublas # pytype: disable=import-error
|
||||
except ImportError:
|
||||
@ -129,27 +122,3 @@ def has_magma():
|
||||
if _hiphybrid:
|
||||
return _hiphybrid.has_magma()
|
||||
return False
|
||||
|
||||
def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
|
||||
indices, indptr, b, tol, reorder):
|
||||
"""Sparse solver via QR decomposition. CUDA only."""
|
||||
b_type = ir.RankedTensorType(b.type)
|
||||
data_type = ir.RankedTensorType(data.type)
|
||||
|
||||
n = b_type.shape[0]
|
||||
nnz = data_type.shape[0]
|
||||
opaque = gpu_solver.build_csrlsvqr_descriptor(
|
||||
np.dtype(dtype), n, nnz, reorder, tol
|
||||
)
|
||||
|
||||
out = custom_call(
|
||||
f"{platform}solver_csrlsvqr", # call_target_name
|
||||
result_types=[b.type],
|
||||
operands=[data, indptr, indices, b],
|
||||
backend_config=opaque, # backend_config
|
||||
operand_layouts=[(0,), (0,), (0,), (0,)], # operand_layouts
|
||||
result_layouts=[(0,)] # result_layouts
|
||||
).results
|
||||
return out
|
||||
|
||||
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)
|
||||
|
@ -1607,6 +1607,11 @@ jax_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "string_array_test",
|
||||
srcs = ["string_array_test.py"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "cudnn_fusion_test",
|
||||
srcs = ["cudnn_fusion_test.py"],
|
||||
@ -1642,6 +1647,7 @@ exports_files(
|
||||
"shard_map_test.py",
|
||||
"transfer_guard_test.py",
|
||||
"layout_test.py",
|
||||
"string_array_test.py",
|
||||
],
|
||||
visibility = jax_test_file_visibility,
|
||||
)
|
||||
|
@ -24,6 +24,8 @@ from jax._src import api
|
||||
from jax._src import test_util as jtu
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
@ -75,7 +77,6 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
|
||||
def testCallDeoptimized(self, jit):
|
||||
raise SkipTest("re-enable once we handle contexts properly") # TODO(dougalm)
|
||||
@jit
|
||||
def f(x):
|
||||
return jax.lax.cond(
|
||||
@ -89,6 +90,25 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(FloatingPointError, msg):
|
||||
f(1)
|
||||
|
||||
def testShardMap(self):
|
||||
mesh = jax.make_mesh((1,), ('x',))
|
||||
f = shard_map(lambda x: 0. / x, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
|
||||
# For the Cpp pmap, the first execution always goes through Python.
|
||||
f(jnp.array([1.]))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"Invalid value \(nan\) encountered in sharded computation"):
|
||||
ans = f(jnp.array([0.]))
|
||||
ans.block_until_ready()
|
||||
|
||||
if jax.device_count() >= 2:
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"Invalid value \(nan\) encountered in sharded computation"):
|
||||
ans = f(jnp.array([1., 0.]))
|
||||
ans.block_until_ready()
|
||||
|
||||
def testPmap(self):
|
||||
pmap_funcs = [api._cpp_pmap]
|
||||
|
||||
@ -99,17 +119,47 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"invalid value \(nan\) encountered in parallel computation"):
|
||||
r"invalid value \(nan\) encountered in div"):
|
||||
ans = f(jnp.array([0.]))
|
||||
ans.block_until_ready()
|
||||
|
||||
if jax.device_count() >= 2:
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"invalid value \(nan\) encountered in parallel computation"):
|
||||
r"Invalid value \(nan\) encountered in parallel computation"):
|
||||
ans = f(jnp.array([1., 0.]))
|
||||
ans.block_until_ready()
|
||||
|
||||
def testGradPmap(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x**2
|
||||
return jnp.log(y)
|
||||
|
||||
_, f_vjp = jax.vjp(jax.pmap(f), jnp.zeros([1]))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
|
||||
ans, = f_vjp(jnp.ones([1]))
|
||||
ans.block_until_ready()
|
||||
|
||||
def testGradShardMap(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
y = x**2
|
||||
return jnp.log(y)
|
||||
|
||||
mesh = jax.make_mesh((1,), ('x',))
|
||||
shmap_f = shard_map(f, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
|
||||
_, f_vjp = jax.vjp(shmap_f, jnp.zeros([1]))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
|
||||
ans, = f_vjp(jnp.ones([1]))
|
||||
ans.block_until_ready()
|
||||
|
||||
def testPmapNoNaN(self):
|
||||
ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
|
||||
ans.block_until_ready()
|
||||
@ -163,17 +213,23 @@ class DebugNaNsTest(jtu.JaxTestCase):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"invalid value \(nan\) encountered in jit\(true_divide\)"):
|
||||
r"invalid value \(nan\) encountered in div"):
|
||||
f(inp, inp)
|
||||
|
||||
# TODO(yashkatariya): Fix this and make true_divide appear in the name again.
|
||||
# Instead of `f` showing up in the error, the name should be of the
|
||||
# primitive (true_divide) in this case.
|
||||
with self.assertRaisesRegex(
|
||||
FloatingPointError,
|
||||
r"invalid value \(nan\) encountered in jit\(f\)"):
|
||||
r"invalid value \(nan\) encountered in div"):
|
||||
jax.jit(f)(inp, inp)
|
||||
|
||||
def testDebugNansInput(self):
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return x * 3.
|
||||
|
||||
with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*input"):
|
||||
f(np.nan)
|
||||
|
||||
|
||||
@jtu.with_config(jax_debug_infs=True)
|
||||
class DebugInfsTest(jtu.JaxTestCase):
|
||||
@ -233,7 +289,7 @@ class DebugInfsTest(jtu.JaxTestCase):
|
||||
y = x + 2 # avoid trivial dispatch path by adding some eqn
|
||||
return jnp.nan, y
|
||||
|
||||
with self.assertRaisesRegex(FloatingPointError, "de-optimized"):
|
||||
with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*literal"):
|
||||
with jax.debug_nans(True):
|
||||
f(3)
|
||||
|
||||
|
@ -335,6 +335,47 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
('bwd_err', 'g = err(g)'),
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_jvp(self, filter_mode):
|
||||
def err(_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
p = (1.,)
|
||||
t = (0.,)
|
||||
return jax.jvp(err, p, t)
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return jax.jvp(err, p, t)'),
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_vjp(self, filter_mode):
|
||||
def err(_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
x = 1.
|
||||
return jax.vjp(err, x)[0]
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return jax.vjp(err, x)[0]'),
|
||||
('err', 'assert False')], filter_mode=filter_mode)
|
||||
|
||||
def test_debug_nans(self, filter_mode):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return 0. / x
|
||||
|
||||
f(2.)
|
||||
def g():
|
||||
return f(0.)
|
||||
|
||||
with jax.debug_nans(True):
|
||||
check_filtered_stack_trace(self, ZeroDivisionError, g, [
|
||||
('g', 'return f(0.)'),
|
||||
('f', 'return 0. / x')], filter_mode=filter_mode)
|
||||
|
||||
def test_cause_chain(self, filter_mode):
|
||||
@jit
|
||||
def inner(x):
|
||||
|
@ -3758,8 +3758,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertIsNot(x, y)
|
||||
|
||||
def testArrayUnsupportedDtypeError(self):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"JAX only supports number and bool dtypes.*"):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'JAX only supports number, bool, and string dtypes.*'
|
||||
):
|
||||
jnp.array(3, [('a','<i4'),('b','<i4')])
|
||||
|
||||
def testArrayFromInteger(self):
|
||||
|
tests/string_array_test.py (new file): 217 lines
@ -0,0 +1,217 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
jtu.request_cpu_devices(2)
|
||||
|
||||
|
||||
class StringArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if xla_extension_version < 311:
|
||||
self.skipTest(
|
||||
"Skipping this test because the current XLA extension version:"
|
||||
f" {xla_extension_version} is older than 309, the oldest version with"
|
||||
" string array support."
|
||||
)
|
||||
if not hasattr(np.dtypes, "StringDType"):
|
||||
self.skipTest(
|
||||
"Skipping this test because the numpy.dtype.StringDType is not"
|
||||
" available."
|
||||
)
|
||||
|
||||
def make_test_string_array(self, device=None):
|
||||
"""Makes and returns a simple 2x1 string array on the first CPU device."""
|
||||
if device is None:
|
||||
cpu_devices = jax.devices("cpu")
|
||||
if len(cpu_devices) < 1:
|
||||
self.skipTest(
|
||||
"Skipping this test because no CPU devices are available."
|
||||
)
|
||||
device = cpu_devices[0]
|
||||
|
||||
numpy_string_array = np.array(
|
||||
["abcd", "efgh"], dtype=np.dtypes.StringDType() # type: ignore
|
||||
)
|
||||
jax_string_array = jax.device_put(numpy_string_array, device=device)
|
||||
jax_string_array.block_until_ready()
|
||||
return jax_string_array
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("asarray", True),
|
||||
("device_put", False),
|
||||
)
|
||||
@jtu.run_on_devices("cpu")
|
||||
def test_single_device_array(self, asarray):
|
||||
cpu_devices = jax.devices("cpu")
|
||||
if len(cpu_devices) < 1:
|
||||
self.skipTest("Skipping this test because no CPU devices are available.")
|
||||
|
||||
numpy_string_array = np.array(
|
||||
["abcdefghijklmnopqrstuvwxyz", "cba"], dtype=np.dtypes.StringDType() # type: ignore
|
||||
)
|
||||
if asarray:
|
||||
jax_string_array = jnp.asarray(numpy_string_array, device=cpu_devices[0])
|
||||
else:
|
||||
jax_string_array = jax.device_put(
|
||||
numpy_string_array, device=cpu_devices[0]
|
||||
)
|
||||
jax_string_array.block_until_ready()
|
||||
|
||||
array_read_back = jax.device_get(jax_string_array)
|
||||
self.assertEqual(array_read_back.dtype, np.dtypes.StringDType()) # type: ignore
|
||||
np.testing.assert_array_equal(array_read_back, numpy_string_array)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("asarray", True),
|
||||
("device_put", False),
|
||||
)
|
||||
@jtu.run_on_devices("cpu")
|
||||
def test_multi_device_array(self, asarray):
|
||||
cpu_devices = jax.devices("cpu")
|
||||
if len(cpu_devices) < 2:
|
||||
self.skipTest(
|
||||
f"Skipping this test because only {len(cpu_devices)} host"
|
||||
" devices are available. Need at least 2."
|
||||
)
|
||||
|
||||
numpy_string_array = np.array(
|
||||
[["abcd", "efgh"], ["ijkl", "mnop"]], dtype=np.dtypes.StringDType() # type: ignore
|
||||
)
|
||||
mesh = jax.sharding.Mesh(np.array(cpu_devices).reshape((2, 1)), ("x", "y"))
|
||||
sharding = jax.sharding.NamedSharding(
|
||||
mesh, jax.sharding.PartitionSpec("x", "y")
|
||||
)
|
||||
|
||||
if asarray:
|
||||
jax_string_array = jnp.asarray(numpy_string_array, device=sharding)
|
||||
else:
|
||||
jax_string_array = jax.device_put(numpy_string_array, device=sharding)
|
||||
jax_string_array.block_until_ready()
|
||||
|
||||
array_read_back = jax.device_get(jax_string_array)
|
||||
self.assertEqual(array_read_back.dtype, np.dtypes.StringDType()) # type: ignore
|
||||
np.testing.assert_array_equal(array_read_back, numpy_string_array)
|
||||
|
||||
@jtu.run_on_devices("cpu")
|
||||
def test_dtype_conversions(self):
|
||||
cpu_devices = jax.devices("cpu")
|
||||
if len(cpu_devices) < 1:
|
||||
self.skipTest("Skipping this test because no CPU devices are available.")
|
||||
|
||||
# Explicitly specifying the dtype should work with StringDType numpy arrays.
|
||||
numpy_string_array = np.array(
|
||||
["abcd", "efgh"], dtype=np.dtypes.StringDType() # type: ignore
|
||||
)
|
||||
jax_string_array = jnp.asarray(
|
||||
numpy_string_array,
|
||||
device=cpu_devices[0],
|
||||
dtype=np.dtypes.StringDType(),
|
||||
) # type: ignore
|
||||
jax_string_array.block_until_ready()
|
||||
|
||||
# Cannot make a non-StringDType array from a StringDType numpy array.
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Cannot make an array with dtype bfloat16 from an object with dtype"
|
||||
r" StringDType.*",
|
||||
):
|
||||
jnp.asarray(
|
||||
numpy_string_array,
|
||||
device=cpu_devices[0],
|
||||
dtype=jnp.bfloat16,
|
||||
)
|
||||
|
||||
# Cannot make a StringDType array from a numeric numpy array.
|
||||
numpy_int_array = np.arange(2, dtype=np.int32)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Cannot make an array with dtype StringDType.*from an object with"
|
||||
r" dtype int32.",
|
||||
):
|
||||
jnp.asarray(
|
||||
numpy_int_array,
|
||||
device=cpu_devices[0],
|
||||
dtype=np.dtypes.StringDType(), # type: ignore
|
||||
)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("asarray", True),
|
||||
("device_put", False),
|
||||
)
|
||||
@jtu.skip_on_devices("cpu")
|
||||
def test_string_array_cannot_be_non_cpu_devices(self, asarray):
|
||||
devices = jax.devices()
|
||||
if len(devices) < 1:
|
||||
self.skipTest("Skipping this test because no devices are available.")
|
||||
|
||||
numpy_string_array = np.array(
|
||||
["abcdefghijklmnopqrstuvwxyz", "cba"], dtype=np.dtypes.StringDType() # type: ignore
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "String arrays can only be sharded to CPU devices"
|
||||
):
|
||||
if asarray:
|
||||
jax_string_array = jnp.asarray(numpy_string_array, device=devices[0])
|
||||
else:
|
||||
jax_string_array = jax.device_put(numpy_string_array, device=devices[0])
|
||||
jax_string_array.block_until_ready()
|
||||
|
||||
def test_jit_fails_with_string_arrays(self):
|
||||
f = jax.jit(lambda x: x)
|
||||
input_array = self.make_test_string_array()
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Argument.*is not a valid JAX type.",
|
||||
lambda: f(input_array),
|
||||
)
|
||||
|
||||
def test_grad_fails_with_string_arrays(self):
|
||||
f = jax.grad(lambda x: x)
|
||||
input_array = self.make_test_string_array()
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"Argument.*is not a valid JAX type.",
|
||||
lambda: f(input_array),
|
||||
)
|
||||
|
||||
def test_vmap_without_jit_works_with_string_arrays(self):
|
||||
f = jax.vmap(lambda x: x)
|
||||
input_array = self.make_test_string_array()
|
||||
output_array = f(input_array)
|
||||
self.assertEqual(output_array.dtype, input_array.dtype)
|
||||
np.testing.assert_array_equal(output_array, input_array)
|
||||
|
||||
def test_vmap_with_jit_fails_with_string_arrays(self):
|
||||
f = jax.vmap(lambda x: x + jnp.arange(2))
|
||||
input_array = self.make_test_string_array()
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r".*StringDType.*is not a valid dtype",
|
||||
lambda: f(input_array),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
third_party/xla/workspace.bzl (vendored): 4 changes
@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result.

XLA_COMMIT = "46f8cf03902c0af58468e2258f9438788e7f4c97"
XLA_SHA256 = "0c391b0a8433d26bfc93e5bee775f7eb629b811a42222ce2b4c7449044a5bc0d"
XLA_COMMIT = "85eccd2ed9f2afd956ab17afd31480a042f07f92"
XLA_SHA256 = "ed853428d3f92aeb3a0cabd564f2373309b4784cd6f90db74ccc2d2ae735984f"

def repo():
    tf_http_archive(