diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml index 6143d1887..d354ffcf3 100644 --- a/.github/workflows/tsan.yaml +++ b/.github/workflows/tsan.yaml @@ -198,4 +198,8 @@ jobs: --test_output=errors \ --local_test_jobs=32 \ --test_timeout=600 \ + --config=resultstore \ + --spawn_strategy=local \ + --remote_cache=remotebuildexecution.googleapis.com \ + --remote_instance_name=projects/tensorflow-testing/instances/default_instance \ //tests:cpu_tests diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 1b63bec02..d3ddac4ed 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -264,10 +264,52 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "UEObolTqw4pp" + }, "source": [ "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", "\n", + "The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host.\n", + "\n", + "To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aKNeOHTJnqmS", + "outputId": "847c53ec-8b2e-4be0-f993-7fde7d77c0f2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pinned_host\n", + "device\n" + ] + } + ], + "source": [ + "s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host')\n", + "s_dev = s_host.with_memory_kind('device')\n", + "arr_host = jax.device_put(arr, s_host)\n", + "arr_dev = jax.device_put(arr, s_dev)\n", + "print(arr_host.sharding.memory_kind)\n", + "print(arr_dev.sharding.memory_kind)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jDHYnVqHwaST" + }, + "source": [ "## 1. Automatic parallelism via `jit`\n", "\n", "Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.\n", @@ -354,10 +396,98 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "Q4N5mrr9i_ki" + }, "source": [ "The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on.\n", "\n", + "### 1.1 Sharding transformation between memory types\n", + "\n", + "The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array.\n", + "\n", + "#### Example 1: Pinned host to device memory\n", + "\n", + "In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory." 
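
To complement the `memory_kind` discussion above, here is a minimal sketch (an editorial aside, not part of the notebook or this patch) of how you might check which memory spaces a device actually exposes before building a sharding like `s_host`. It assumes a recent jaxlib in which devices provide the memories API (`Device.default_memory` and `Device.addressable_memories`); the notebook's Example 1 follows below.

```python
import jax

# List the memory spaces the first device exposes. "device" is the default
# kind; "pinned_host" is the host-resident kind used for s_host above.
dev = jax.devices()[0]
print(dev.default_memory().kind)
print([m.kind for m in dev.addressable_memories()])
```
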
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PXu3MhafyRHo", + "outputId": "7bc6821f-a4a9-4cf8-8b21-e279d516d27b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", + " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", + " [16. 17. 18. 19. 20. 21. 22. 23.]\n", + " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", + "device\n" + ] + } + ], + "source": [ + "f = jax.jit(lambda x: x, out_shardings=s_dev)\n", + "out_dev = f(arr_host)\n", + "print(out_dev)\n", + "print(out_dev.sharding.memory_kind)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LuYFqpcBySiX" + }, + "source": [ + "#### Example 2: Device to pinned_host memory\n", + "\n", + "In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qLsgNlKfybRw", + "outputId": "a16448b9-7e39-408f-b200-505f65ad4464" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0. 1. 2. 3. 4. 5. 6. 7.]\n", + " [ 8. 9. 10. 11. 12. 13. 14. 15.]\n", + " [16. 17. 18. 19. 20. 21. 22. 23.]\n", + " [24. 25. 26. 27. 28. 29. 30. 31.]]\n", + "pinned_host\n" + ] + } + ], + "source": [ + "g = jax.jit(lambda x: x, out_shardings=s_host)\n", + "out_host = g(arr_dev)\n", + "print(out_host)\n", + "print(out_host.sharding.memory_kind)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7BGD31-owaSU" + }, + "source": [ "## 2. Semi-automated sharding with constraints\n", "\n", "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 14eb968eb..b05eb8d5f 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -90,8 +90,31 @@ print(arr_sharded) jax.debug.visualize_array_sharding(arr_sharded) ``` ++++ {"id": "UEObolTqw4pp"} + The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device. +The {class}`~jax.sharding.NamedSharding` includes a parameter called `memory_kind`. This parameter determines the type of memory to be used and defaults to `device`. You can set this parameter to `pinned_host` if you prefer to place it on the host. + +To create a new sharding that only differs from an existing sharding in terms of its memory kind, you can use the `with_memory_kind` method on the existing sharding. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: aKNeOHTJnqmS +outputId: 847c53ec-8b2e-4be0-f993-7fde7d77c0f2 +--- +s_host = jax.NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') +s_dev = s_host.with_memory_kind('device') +arr_host = jax.device_put(arr, s_host) +arr_dev = jax.device_put(arr, s_dev) +print(arr_host.sharding.memory_kind) +print(arr_dev.sharding.memory_kind) +``` + ++++ {"id": "jDHYnVqHwaST"} + ## 1. 
Automatic parallelism via `jit` Once you have sharded data, the easiest way to do parallel computation is to simply pass the data to a {func}`jax.jit`-compiled function! In JAX, you need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications. @@ -129,8 +152,52 @@ jax.debug.visualize_array_sharding(result) print(result) ``` ++++ {"id": "Q4N5mrr9i_ki"} + The result is partially replicated: that is, the first two elements of the array are replicated on devices `0` and `6`, the second on `1` and `7`, and so on. +### 1.1 Sharding transformation between memory types + +The output sharding of a {func}`jax.jit` function can differ from the input sharding if you specify the output sharding using the `out_shardings` parameter. Specifically, the `memory_kind` of the output can be different from that of the input array. + +#### Example 1: Pinned host to device memory + +In the example below, the {func}`jax.jit` function `f` takes an array sharded in `pinned_host` memory and generates an array in `device` memory. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: PXu3MhafyRHo +outputId: 7bc6821f-a4a9-4cf8-8b21-e279d516d27b +--- +f = jax.jit(lambda x: x, out_shardings=s_dev) +out_dev = f(arr_host) +print(out_dev) +print(out_dev.sharding.memory_kind) +``` + ++++ {"id": "LuYFqpcBySiX"} + +#### Example 2: Device to pinned_host memory + +In the example below, the {func}`jax.jit` function `g` takes an array sharded in `device` memory and generates an array in `pinned_host` memory. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: qLsgNlKfybRw +outputId: a16448b9-7e39-408f-b200-505f65ad4464 +--- +g = jax.jit(lambda x: x, out_shardings=s_host) +out_host = g(arr_dev) +print(out_host) +print(out_host.sharding.memory_kind) +``` + ++++ {"id": "7BGD31-owaSU"} + ## 2. Semi-automated sharding with constraints If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. diff --git a/jax/BUILD b/jax/BUILD index 1225993ad..28d32d834 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -499,6 +499,7 @@ pytype_strict_library( ":traceback_util", ":typing", ":util", + "//jax/_src/lib", ] + py_deps("ml_dtypes") + py_deps("numpy"), ) diff --git a/jax/_src/api.py b/jax/_src/api.py index 581b2b512..bc4c8c9f0 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -99,6 +99,7 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +@api_boundary def _nan_check_posthook(fun, args, kwargs, output): """Hook function called by the C++ jit/pmap to perform NaN checking.""" buffers = [] @@ -108,12 +109,18 @@ def _nan_check_posthook(fun, args, kwargs, output): try: dispatch.check_special(pjit.pjit_p.name, buffers) - except FloatingPointError: - # compiled_fun can only raise in this case + except dispatch.InternalFloatingPointError as e: assert config.debug_nans.value or config.debug_infs.value - print("Invalid nan value encountered in the output of a C++-jit/pmap " - "function. 
Calling the de-optimized version.") - fun._cache_miss(*args, **kwargs)[0] # probably won't return + if hasattr(fun, '_fun'): + f = fun._fun + if getattr(f, '_apply_primitive', False): + raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None + # compiled_fun can only raise in this case + dispatch.maybe_recursive_nan_check(e, f, args, kwargs) + raise AssertionError("Unreachable") from e + else: + # TODO(emilyaf): Shouldn't need this fallback. + raise def _update_debug_special_global(_): if config._read("jax_debug_nans") or config._read("jax_debug_infs"): @@ -1574,11 +1581,14 @@ def _cpp_pmap( execute: Callable | None = None with core.take_current_trace() as trace: - if isinstance(trace, core.EvalTrace): - execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params) - out = execute(*p.flat_args) - else: - out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params) + try: + if isinstance(trace, core.EvalTrace): + execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params) + out = execute(*p.flat_args) + else: + out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params) + except dispatch.InternalFloatingPointError as e: + raise FloatingPointError(f'Invalid value ({e.ty}) encountered in parallel computation.') out_tree, out_flat = p.out_tree, out out_pytree_def = out_tree() @@ -1629,6 +1639,7 @@ def _cpp_pmap( _pmap_cache_clears.add(cpp_mapped_f) pmap_f = wraps(fun)(cpp_mapped_f) + pmap_f._fun = fun @api_boundary def lower(*args, **kwargs): @@ -1674,6 +1685,7 @@ def _cpp_pmap( _pmap_cache_clears = weakref.WeakSet() # type: ignore +@api_boundary def jvp( fun: Callable, primals, tangents, has_aux: bool = False ) -> tuple[Any, ...]: @@ -1878,6 +1890,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args): return apply_flat_fun_nokwargs(fun, io_tree, py_args) +@api_boundary def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): if len(py_args_) != 1: msg = (f"The function returned by `jax.vjp` applied to {name} was called " @@ -1937,6 +1950,7 @@ def vjp(fun: Callable[..., tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]: ... +@api_boundary def vjp( fun: Callable, *primals, has_aux: bool = False, reduce_axes=() ) -> tuple[Any, Callable] | tuple[Any, Callable, Any]: @@ -2225,6 +2239,18 @@ def _infer_src_sharding(src, x) -> Sharding | None: return None +@lru_cache(maxsize=2048) +def _check_string_compatible_sharding(s): + """Checks if target devices are compatible with string arrays.""" + if isinstance(s, xc.Device) and s.device_kind == "cpu": + return + if (isinstance(s, Sharding) + and s._internal_device_list[0].device_kind == "cpu"): + return + raise TypeError( + "String arrays can only be sharded to CPU devices. Received" + f" unsupported device or sharding: {s}") + # TODO(yashkatariya): Generalize check_compatible_aval (maybe renamed) and use # that to check if shardings are compatible with the input. @lru_cache(maxsize=2048) @@ -2235,6 +2261,10 @@ def _check_sharding(aval, s): "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`," " `jax.Device`, `Layout` or a pytree of these values. 
Received" f" invalid value: {s}") + + if isinstance(aval, core.ShapedArray) and dtypes.is_string_dtype(aval.dtype): + _check_string_compatible_sharding(s) + if isinstance(s, Sharding): if isinstance(aval, core.AbstractToken): aval = core.get_token_aval() diff --git a/jax/_src/core.py b/jax/_src/core.py index 2aefe1544..c1a833c3c 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1472,11 +1472,14 @@ Value = Any def valid_jaxtype(x) -> bool: try: - abstractify(x) + aval = abstractify(x) except TypeError: return False else: - return True + if hasattr(aval, "dtype") and dtypes.is_string_dtype(aval.dtype): + return False + else: + return True def check_valid_jaxtype(x): if not valid_jaxtype(x): diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 7dea452c8..ea9408dbd 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -25,7 +25,7 @@ import itertools import logging import threading import time -from typing import Any, NamedTuple +from typing import Any, Callable, NamedTuple import jax from jax._src import api @@ -100,6 +100,7 @@ def xla_primitive_callable(prim: core.Primitive, **params): return prim.bind(*args, **params) prim_fun.__name__ = prim.name prim_fun.__qualname__ = prim.name + prim_fun._apply_primitive = True return api.jit(prim_fun) @@ -321,15 +322,52 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None: def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: if dtypes.issubdtype(dtype, np.inexact): if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))): - raise FloatingPointError(f"invalid value (nan) encountered in {name}") + raise InternalFloatingPointError(name, "nan") if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))): - raise FloatingPointError(f"invalid value (inf) encountered in {name}") + raise InternalFloatingPointError(name, "inf") class CopySemantics(enum.Enum): ALIAS = enum.auto() COPY = enum.auto() DONATE = enum.auto() +class InternalFloatingPointError(Exception): + name: str + ty: str + + def __init__(self, name: str, ty: str): + self.name = name + self.ty = ty + +def maybe_recursive_nan_check(e: Exception, fun: Callable, args, kwargs, +) -> None: # always raises an exception + print("Invalid nan value encountered in the output of a jax.jit " + "function. Calling the de-optimized version.") + try: + _ = fun(*args, **kwargs) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + _raise_no_nan_in_deoptimized(e) + +def _raise_no_nan_in_deoptimized(e) -> None: + msg = (f"{str(e)}. Because " + "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " + "de-optimized function (i.e., the function as if the `jit` " + "decorator were removed) was called in an attempt to get a more " + "precise error message. However, the de-optimized function did not " + "produce invalid values during its execution. This behavior can " + "result from `jit` optimizations causing the invalid value to be " + "produced. It may also arise from having nan/inf literals as " + "inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " + "\n\n" + "It may be possible to avoid the invalid value by removing the " + "`jit` decorator, at the cost of losing optimizations. 
" + "\n\n" + "If you see this error, consider opening a bug report at " + "https://github.com/jax-ml/jax.") + raise FloatingPointError(msg) from None + def _identity_fn(x): return x diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f8f3d8597..953381781 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -33,6 +33,7 @@ import ml_dtypes import numpy as np from jax._src import config +from jax._src.lib import xla_extension_version from jax._src.typing import Array, DType, DTypeLike from jax._src.util import set_module, StrictABC @@ -486,18 +487,37 @@ _complex_types: list[JAXType] = [ np.dtype('complex64'), np.dtype('complex128'), ] -_jax_types = _bool_types + _int_types + _float_types + _complex_types -_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types} +# We add the StringDType only to `_jax_dtype_set` but not to `_jax_types` and +# `_dtype_kinds`. This is because, in spite of a very similar sounding name, +# `_jax_types` is only meant for the promotion related logic, and StringDType +# does not participate in promotions at the moment. Similarly, `_dtype_kinds` is +# only meant for the `jnp.isdtype` and we want to be conservative and not allow +# StringDType to be used in there. +_string_types: list[JAXType] = [] +if hasattr(np.dtypes, 'StringDType') and xla_extension_version >= 311: + _string_types: list[JAXType] = [np.dtypes.StringDType()] # type: ignore + +_jax_dtype_set = { + float0, + *_bool_types, + *_int_types, + *_float_types, + *_complex_types, + *_string_types, +} + +_jax_types = (_bool_types + _int_types + _float_types + _complex_types) + _dtype_kinds: dict[str, set] = { - 'bool': {*_bool_types}, - 'signed integer': {*_signed_types}, - 'unsigned integer': {*_unsigned_types}, - 'integral': {*_signed_types, *_unsigned_types}, - 'real floating': {*_float_types}, - 'complex floating': {*_complex_types}, - 'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types}, + 'bool': {*_bool_types}, + 'signed integer': {*_signed_types}, + 'unsigned integer': {*_unsigned_types}, + 'integral': {*_signed_types, *_unsigned_types}, + 'real floating': {*_float_types}, + 'complex floating': {*_complex_types}, + 'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types}, } @@ -870,8 +890,14 @@ def check_user_dtype_supported(dtype, fun_name=None): uint2, uint4 ] - if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: - msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" + if ( + np_dtype.kind not in 'biufcT' + and not is_custom_dtype + and not dtype == float0 + ): + msg = ( + f'JAX only supports number, bool, and string dtypes, got dtype {dtype}' + ) msg += f" in {fun_name}" if fun_name else "" raise TypeError(msg) if dtype is not None and np_dtype != canonicalize_dtype(np_dtype): @@ -949,3 +975,7 @@ def short_dtype_name(dtype) -> str: else: return (dtype.name.replace('float', 'f').replace('uint' , 'u') .replace('int' , 'i').replace('complex', 'c')) + + +def is_string_dtype(dtype: DTypeLike | None) -> bool: + return dtype in _string_types diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 8d0f0e7ca..ccee33eb5 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -22,6 +22,7 @@ from functools import partial from typing import Any from jax._src import config +from jax._src import dispatch from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe from jax.tree_util import (tree_flatten, 
tree_unflatten, @@ -360,8 +361,15 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack, cts_out = get_primitive_transpose(eqn.primitive)( params, call_jaxpr, invals, cts_in, cts_in_avals) else: - cts_out = get_primitive_transpose(eqn.primitive)( - cts_in, *invals, **eqn.params) + try: + cts_out = get_primitive_transpose(eqn.primitive)( + cts_in, *invals, **eqn.params) + except (FloatingPointError, ZeroDivisionError) as e: + msg = "When differentiating the code at the top of the callstack:" + if msg not in e.args[0]: + e.args = e.args[0] + f'\n{msg}', + e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}', + raise e from None cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out # FIXME: Some invars correspond to primals! map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out) @@ -1003,7 +1011,20 @@ def map_transpose(primitive, params, call_jaxpr, args, ct, _): if update_params: new_params = update_params(new_params, map(is_undefined_primal, args), [type(x) is not Zero for x in ct]) - out_flat = primitive.bind(fun, *all_args, **new_params) + + try: + out_flat = primitive.bind(fun, *all_args, **new_params) + except dispatch.InternalFloatingPointError as e: + print("Invalid nan value encountered in the backward pass of a jax.jit " + "function. Calling the de-optimized backward pass.") + try: + _ = backward_pass(call_jaxpr, None, {}, args, ct) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + # If control reaches this line, we got a NaN on the output of `compiled` + # but not `fun.call_wrapped` on the same arguments. Let's tell the user. + dispatch._raise_no_nan_in_deoptimized(e) arg_cts = tree_unflatten(out_tree(), out_flat) # The freevars are being fanned out (not mapped). During transpose the diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index e00b5b3e8..4da1f23bf 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -57,6 +57,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax.lax import (PrecisionLike,_array_copy, _sort_le_comparator, _sort_lt_comparator) from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.numpy import util @@ -5474,6 +5475,39 @@ def _supports_buffer_protocol(obj): return True +def _make_string_array( + object: np.ndarray, + dtype: DTypeLike | None = None, + ndmin: int = 0, + device: xc.Device | Sharding | None = None, +) -> Array: + if xla_extension_version < 311: + raise TypeError( + "String arrays are not supported in JAX before XLA extension version" + " 311." + ) + if not isinstance(object, np.ndarray): + raise TypeError( + "Currently, string arrays can only be made from NumPy" + f" arrays. Got: {type(object)}." + ) + if dtype is not None and ( + dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype) + ): + raise TypeError( + f"Cannot make an array with dtype {dtype} from an object with dtype" + f" {object.dtype}." + ) + if ndmin > object.ndim: + raise TypeError( + f"ndmin {ndmin} cannot be greater than object's ndims" + f" {object.ndim} for string arrays." + ) + + # Just do a device_put since XLA does not support string as a data type. 
+ return jax.device_put(x=object, device=device) + + @export def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, @@ -5567,6 +5601,15 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, # Keep the output uncommitted. return jax.device_put(object) + # String arrays need separate handling because XLA does not support string + # as a data type. + if dtypes.is_string_dtype(dtype) or ( + hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype) + ): + return _make_string_array( + object=object, dtype=dtype, ndmin=ndmin, device=device + ) + # For Python scalar literals, call coerce_to_array to catch any overflow # errors. We don't use dtypes.is_python_scalar because we don't want this # triggering for traced values. We do this here because it matters whether or diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index faf70200d..c9cefcd72 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -222,6 +222,10 @@ def _python_pjit_helper(fun: Callable, jit_info: PjitInfo, *args, **kwargs): f"Argument '{name}' of shape {aval.str_short()} of type" f' {type(arg)} is not a valid JAX type.') from e raise AssertionError("Unreachable") from e + except dispatch.InternalFloatingPointError as e: + if getattr(fun, '_apply_primitive', False): + raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None + dispatch.maybe_recursive_nan_check(e, fun, args, kwargs) if p.attrs_tracked: num_states_out = sum(end_tree.num_leaves for _, end_tree, _ in p.attrs_tracked) @@ -1700,33 +1704,7 @@ def _pjit_call_impl_python( ("out_layouts", out_layouts), ("abstract args", map(core.abstractify, args)), ("fingerprint", fingerprint)) - try: - return compiled.unsafe_call(*args), compiled, pgle_profiler - except FloatingPointError as e: - assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case - - if len(jaxpr.eqns) > 1: - _ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return - - # If control reaches this line, we got a NaN on the output of `compiled` - # but not `fun.call_wrapped` on the same arguments. Let's tell the user. - msg = (f"{str(e)}. Because " - "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " - "de-optimized function (i.e., the function as if the `jit` " - "decorator were removed) was called in an attempt to get a more " - "precise error message. However, the de-optimized function did not " - "produce invalid values during its execution. This behavior can " - "result from `jit` optimizations causing the invalid value to be " - "produced. It may also arise from having nan/inf constants as " - "outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " - "\n\n" - "It may be possible to avoid the invalid value by removing the " - "`jit` decorator, at the cost of losing optimizations. 
" - "\n\n" - "If you see this error, consider opening a bug report at " - "https://github.com/jax-ml/jax.") - raise FloatingPointError(msg) - + return compiled.unsafe_call(*args), compiled, pgle_profiler @weakref_lru_cache def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts, @@ -2404,19 +2382,31 @@ def _pjit_transpose(cts_in, *primals_in, transpose_in_layouts = (None,) * len(attrs_tracked) + transpose_in_layouts transpose_out_layouts = (None,) * len(attrs_tracked) + transpose_out_layouts - nz_cts_out = pjit_p.bind( - *primals_and_nz_cts_in, - jaxpr=transpose_jaxpr, - in_shardings=transpose_in_shardings, - out_shardings=transpose_out_shardings, - in_layouts=transpose_in_layouts, - out_layouts=transpose_out_layouts, - resource_env=resource_env, - donated_invars=(False,) * len(primals_and_nz_cts_in), - name=name, - keep_unused=keep_unused, - inline=inline, - compiler_options_kvs=compiler_options_kvs) + try: + nz_cts_out = pjit_p.bind( + *primals_and_nz_cts_in, + jaxpr=transpose_jaxpr, + in_shardings=transpose_in_shardings, + out_shardings=transpose_out_shardings, + in_layouts=transpose_in_layouts, + out_layouts=transpose_out_layouts, + resource_env=resource_env, + donated_invars=(False,) * len(primals_and_nz_cts_in), + name=name, + keep_unused=keep_unused, + inline=inline, + compiler_options_kvs=compiler_options_kvs) + except dispatch.InternalFloatingPointError as e: + print("Invalid nan value encountered in the backward pass of a jax.jit " + "function. Calling the de-optimized backward pass.") + try: + _ = ad.closed_backward_pass(jaxpr, None, primals_in, cts_in) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None # great + else: + # If control reaches this line, we got a NaN on the output of `compiled` + # but not `fun.call_wrapped` on the same arguments. Let's tell the user. 
+ dispatch._raise_no_nan_in_deoptimized(e) if attrs_tracked: final_states, nz_cts_out = split_list(nz_cts_out, [len(init_states)]) diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD index d60b4c333..85ad90326 100644 --- a/jax/experimental/jax2tf/BUILD +++ b/jax/experimental/jax2tf/BUILD @@ -30,7 +30,6 @@ package( py_library( name = "jax2tf", srcs = ["__init__.py"], - srcs_version = "PY3", visibility = ["//visibility:public"], deps = [":jax2tf_internal"], ) @@ -42,7 +41,6 @@ py_library( "impl_no_xla.py", "jax2tf.py", ], - srcs_version = "PY3", # TODO: b/255503696: enable pytype tags = ["pytype_unchecked_annotations"], visibility = jax_visibility("jax2tf_internal"), diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD index 3417c1abf..d166f1308 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD @@ -24,7 +24,6 @@ package( py_library( name = "back_compat_testdata", srcs = glob(["*.py"]), - srcs_version = "PY3", deps = [ "//third_party/py/numpy", "//third_party/py/typing_extensions", diff --git a/jax/experimental/jax2tf/tests/flax_models/BUILD b/jax/experimental/jax2tf/tests/flax_models/BUILD index 331d4ab8e..c028f13a4 100644 --- a/jax/experimental/jax2tf/tests/flax_models/BUILD +++ b/jax/experimental/jax2tf/tests/flax_models/BUILD @@ -28,7 +28,6 @@ package( py_library( name = "flax_models", srcs = glob(["*.py"]), - srcs_version = "PY3", deps = [ "//jax", "//third_party/py/flax:core", diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 11538a368..32897fc82 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -874,6 +874,15 @@ def _match(mesh, check_rep, pspec, x): def _rem_singleton(x): return jnp.squeeze(x, axis=0) def _add_singleton(x): return jnp.expand_dims(x, axis=0) +def _maybe_check_special(outs): + if not config.debug_nans.value and not config.debug_infs.value: return + bufs = [s.data for leaf in tree_leaves(outs) + for s in getattr(leaf, 'addressable_shards', [])] + try: + dispatch.check_special('shard_map', bufs) + except dispatch.InternalFloatingPointError as e: + raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None + class ShardMapTrace(core.Trace): __slots__ = ("mesh", "check", "context_mesh") @@ -902,9 +911,10 @@ class ShardMapTrace(core.Trace): out_vals = eager_rule(self.mesh, *in_vals, **params) else: f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) - with (core.eval_context(), jax.disable_jit(False), - set_abstract_mesh(self.context_mesh)): + with (core.eval_context(), jax.disable_jit(False), jax.debug_nans(False), + jax.debug_infs(False), set_abstract_mesh(self.context_mesh)): out_vals = jax.jit(f)(*in_vals) + _maybe_check_special(out_vals) rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() if prim.multiple_results: @@ -1700,10 +1710,21 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, def new_out_names_thunk(): return tuple(names for names, nz in zip(in_names, nz_arg_cts()) if nz) - out_flat = shard_map_p.bind( - fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, - auto=auto) + try: + out_flat = shard_map_p.bind( + fun_trans_flat, *all_args, mesh=mesh, 
in_names=tuple(new_in_names), + out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, + auto=auto) + except (FloatingPointError, ZeroDivisionError) as e: + print("Invalid nan value encountered in the backward pass of a shard_map " + "function. Calling the de-optimized backward pass.") + try: + _ = fun_trans.call_wrapped(out_cts, args) + except (FloatingPointError, ZeroDivisionError) as e2: + raise e2 from None + else: + dispatch._raise_no_nan_in_deoptimized(e) + return tree_unflatten(out_tree(), out_flat) ad.primitive_transposes[shard_map_p] = _shard_map_transpose diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index 1be1fa1c9..a931b0a30 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -27,6 +27,7 @@ from jax.interpreters import mlir from jax.interpreters import xla from jax._src import core +from jax._src import ffi from jax._src.interpreters import ad from jax._src.lib import gpu_solver @@ -533,10 +534,14 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder): def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder): - data_aval, _, _, _, = ctx.avals_in - return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices, - indptr, b, tol, reorder) - + # TODO(danfm): remove after JAX 0.5.1 release. + if hasattr(gpu_solver, "cuda_csrlsvqr"): + data_aval, _, _, _, = ctx.avals_in + return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices, + indptr, b, tol, reorder) + return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")( + ctx, data, indices, indptr, b, tol=np.float64(tol), + reorder=np.int32(reorder)) def _spsolve_cpu_lowering(ctx, data, indices, indptr, b, tol, reorder): del tol, reorder diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 092affce8..29f5cd8be 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -230,6 +230,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", @@ -251,6 +252,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 2c0517fdf..2770f5c1b 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -50,6 +50,8 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA", + CsrlsvqrFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", OrgqrFfi); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index c74d9a147..357a38eec 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -486,6 +486,8 @@ nb::dict Registrations() { #ifdef JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi); + dict[JAX_GPU_PREFIX "solver_csrlsvqr_ffi"] = + EncapsulateFfiHandler(CsrlsvqrFfi); #endif // JAX_GPU_CUDA return dict; diff --git 
a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc index d93d049d4..f43941321 100644 --- a/jaxlib/gpu/solver_interface.cc +++ b/jaxlib/gpu/solver_interface.cc @@ -20,6 +20,10 @@ limitations under the License. #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" +#ifdef JAX_GPU_CUDA +#include "third_party/gpus/cuda/include/cusolverSp.h" +#endif + namespace jax { namespace JAX_GPU_NAMESPACE { namespace solver { @@ -315,6 +319,23 @@ JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched); JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched); #undef JAX_GPU_DEFINE_GESVDJ_BATCHED +#define JAX_GPU_DEFINE_CSRLSVQR(Type, Scalar, Name) \ + template <> \ + absl::Status Csrlsvqr( \ + cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \ + const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \ + const Type *b, double tol, int reorder, Type *x, int *singularity) { \ + return JAX_AS_STATUS(Name(handle, n, nnz, matdesc, csrValA, csrRowPtrA, \ + csrColIndA, b, static_cast(tol), \ + reorder, x, singularity)); \ + } + +JAX_GPU_DEFINE_CSRLSVQR(float, float, cusolverSpScsrlsvqr); +JAX_GPU_DEFINE_CSRLSVQR(double, double, cusolverSpDcsrlsvqr); +JAX_GPU_DEFINE_CSRLSVQR(gpuComplex, float, cusolverSpCcsrlsvqr); +JAX_GPU_DEFINE_CSRLSVQR(gpuDoubleComplex, double, cusolverSpZcsrlsvqr); +#undef JAX_GPU_DEFINE_CSRLSVQR + #endif // JAX_GPU_CUDA // Symmetric tridiagonal reduction: sytrd diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h index e84a688a6..fa11f3d0e 100644 --- a/jaxlib/gpu/solver_interface.h +++ b/jaxlib/gpu/solver_interface.h @@ -23,6 +23,10 @@ limitations under the License. #include "absl/strings/str_format.h" #include "jaxlib/gpu/vendor.h" +#ifdef JAX_GPU_CUDA +#include "third_party/gpus/cuda/include/cusolverSp.h" +#endif + namespace jax { namespace JAX_GPU_NAMESPACE { namespace solver { @@ -206,6 +210,13 @@ JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBatchedBufferSize); JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched); #undef JAX_GPU_SOLVER_GesvdjBatched_ARGS +#define JAX_GPU_SOLVER_Csrlsvqr_ARGS(Type, ...) \ + cusolverSpHandle_t handle, int n, int nnz, cusparseMatDescr_t matdesc, \ + const Type *csrValA, const int *csrRowPtrA, const int *csrColIndA, \ + const Type *b, double tol, int reorder, Type *x, int *singularity +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Csrlsvqr); +#undef JAX_GPU_SOLVER_Csrlsvqr_ARGS + #endif // JAX_GPU_CUDA // Symmetric tridiagonal reduction: sytrd diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 7e6f14ed4..eb45163a2 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -41,6 +41,10 @@ limitations under the License. #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" +#ifdef JAX_GPU_CUDA +#include "third_party/gpus/cuda/include/cusolverSp.h" +#endif + #define JAX_FFI_RETURN_IF_GPU_ERROR(...) 
\ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) @@ -1013,6 +1017,82 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch, .Ret>() // info ); +// csrlsvqr: Linear system solve via Sparse QR + +template +ffi::Error CsrlsvqrImpl(int64_t n, int64_t nnz, double tol, int reorder, + gpuStream_t stream, ffi::AnyBuffer csrValA, + ffi::Buffer csrColIndA, + ffi::Buffer csrRowPtrA, ffi::AnyBuffer b, + ffi::Result x) { + FFI_ASSIGN_OR_RETURN(auto handle, SpSolverHandlePool::Borrow(stream)); + + FFI_ASSIGN_OR_RETURN(auto int_n, MaybeCastNoOverflow(n)); + FFI_ASSIGN_OR_RETURN(auto int_nnz, MaybeCastNoOverflow(nnz)); + + cusparseMatDescr_t matdesc = nullptr; + JAX_FFI_RETURN_IF_GPU_ERROR(cusparseCreateMatDescr(&matdesc)); + JAX_FFI_RETURN_IF_GPU_ERROR( + cusparseSetMatType(matdesc, CUSPARSE_MATRIX_TYPE_GENERAL)); + JAX_FFI_RETURN_IF_GPU_ERROR( + cusparseSetMatIndexBase(matdesc, CUSPARSE_INDEX_BASE_ZERO)); + + auto* csrValA_data = static_cast(csrValA.untyped_data()); + auto* csrColIndA_data = csrColIndA.typed_data(); + auto* csrRowPtrA_data = csrRowPtrA.typed_data(); + auto* b_data = static_cast(b.untyped_data()); + auto* x_data = static_cast(x->untyped_data()); + + int singularity = -1; + auto result = solver::Csrlsvqr( + handle.get(), int_n, int_nnz, matdesc, csrValA_data, csrRowPtrA_data, + csrColIndA_data, b_data, tol, reorder, x_data, &singularity); + cusparseDestroyMatDescr(matdesc); + FFI_RETURN_IF_ERROR_STATUS(result); + + if (singularity >= 0) { + return ffi::Error(ffi::ErrorCode::kInternal, + "Singular matrix in linear solve."); + } + + return ffi::Error::Success(); +} + +ffi::Error CsrlsvqrDispatch(gpuStream_t stream, int reorder, double tol, + ffi::AnyBuffer csrValA, + ffi::Buffer csrColIndA, + ffi::Buffer csrRowPtrA, ffi::AnyBuffer b, + ffi::Result x) { + auto dataType = csrValA.element_type(); + if (dataType != b.element_type() || dataType != x->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to csrlsvqr must have the same element type"); + } + int64_t n = b.element_count(); + int64_t nnz = csrValA.element_count(); + FFI_RETURN_IF_ERROR( + CheckShape(csrColIndA.dimensions(), nnz, "csrColIndA", "csrlsvqr")); + FFI_RETURN_IF_ERROR( + CheckShape(csrRowPtrA.dimensions(), n + 1, "csrColPtrA", "csrlsvqr")); + FFI_RETURN_IF_ERROR(CheckShape(x->dimensions(), n, "x", "csrlsvqr")); + SOLVER_DISPATCH_IMPL(CsrlsvqrImpl, n, nnz, tol, reorder, stream, csrValA, + csrColIndA, csrRowPtrA, b, x); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in csrlsvqr", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CsrlsvqrFfi, CsrlsvqrDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("reorder") // reorder + .Attr("tol") // tol + .Arg() // csrValA + .Arg>() // csrColIndA + .Arg>() // csrRowPtrA + .Arg() // b + .Ret() // x +); + #endif // JAX_GPU_CUDA // Symmetric tridiagonal reduction: sytrd diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 2f9494d7f..8e90a310e 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -40,6 +40,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(SytrdFfi); #ifdef JAX_GPU_CUDA XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CsrlsvqrFfi); #endif // JAX_GPU_CUDA } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 68bfa7b59..1b4cecdf0 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -12,17 +12,10 @@ # See the License for the specific language governing permissions 
and # limitations under the License. -from functools import partial import importlib -import jaxlib.mlir.ir as ir - -import numpy as np - from jaxlib import xla_client -from .hlo_helpers import custom_call - try: from .cuda import _blas as _cublas # pytype: disable=import-error except ImportError: @@ -129,27 +122,3 @@ def has_magma(): if _hiphybrid: return _hiphybrid.has_magma() return False - -def _csrlsvqr_hlo(platform, gpu_solver, dtype, data, - indices, indptr, b, tol, reorder): - """Sparse solver via QR decomposition. CUDA only.""" - b_type = ir.RankedTensorType(b.type) - data_type = ir.RankedTensorType(data.type) - - n = b_type.shape[0] - nnz = data_type.shape[0] - opaque = gpu_solver.build_csrlsvqr_descriptor( - np.dtype(dtype), n, nnz, reorder, tol - ) - - out = custom_call( - f"{platform}solver_csrlsvqr", # call_target_name - result_types=[b.type], - operands=[data, indptr, indices, b], - backend_config=opaque, # backend_config - operand_layouts=[(0,), (0,), (0,), (0,)], # operand_layouts - result_layouts=[(0,)] # result_layouts - ).results - return out - -cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver) diff --git a/tests/BUILD b/tests/BUILD index 6aca3372c..e1c055ed5 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1607,6 +1607,11 @@ jax_py_test( ], ) +jax_multiplatform_test( + name = "string_array_test", + srcs = ["string_array_test.py"], +) + jax_multiplatform_test( name = "cudnn_fusion_test", srcs = ["cudnn_fusion_test.py"], @@ -1642,6 +1647,7 @@ exports_files( "shard_map_test.py", "transfer_guard_test.py", "layout_test.py", + "string_array_test.py", ], visibility = jax_test_file_visibility, ) diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 4573f542c..c0cef5084 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -24,6 +24,8 @@ from jax._src import api from jax._src import test_util as jtu from jax import numpy as jnp from jax.experimental import pjit +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P jax.config.parse_flags_with_absl() @@ -75,7 +77,6 @@ class DebugNaNsTest(jtu.JaxTestCase): @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION) def testCallDeoptimized(self, jit): - raise SkipTest("re-enable once we handle contexts properly") # TODO(dougalm) @jit def f(x): return jax.lax.cond( @@ -89,6 +90,25 @@ class DebugNaNsTest(jtu.JaxTestCase): with self.assertRaisesRegex(FloatingPointError, msg): f(1) + def testShardMap(self): + mesh = jax.make_mesh((1,), ('x',)) + f = shard_map(lambda x: 0. / x, mesh=mesh, in_specs=(P('x')), out_specs=P('x')) + # For the Cpp pmap, the first execution always goes through Python. 
+ f(jnp.array([1.])) + + with self.assertRaisesRegex( + FloatingPointError, + r"Invalid value \(nan\) encountered in sharded computation"): + ans = f(jnp.array([0.])) + ans.block_until_ready() + + if jax.device_count() >= 2: + with self.assertRaisesRegex( + FloatingPointError, + r"Invalid value \(nan\) encountered in sharded computation"): + ans = f(jnp.array([1., 0.])) + ans.block_until_ready() + def testPmap(self): pmap_funcs = [api._cpp_pmap] @@ -99,17 +119,47 @@ class DebugNaNsTest(jtu.JaxTestCase): with self.assertRaisesRegex( FloatingPointError, - r"invalid value \(nan\) encountered in parallel computation"): + r"invalid value \(nan\) encountered in div"): ans = f(jnp.array([0.])) ans.block_until_ready() if jax.device_count() >= 2: with self.assertRaisesRegex( FloatingPointError, - r"invalid value \(nan\) encountered in parallel computation"): + r"Invalid value \(nan\) encountered in parallel computation"): ans = f(jnp.array([1., 0.])) ans.block_until_ready() + def testGradPmap(self): + @jax.jit + def f(x): + y = x**2 + return jnp.log(y) + + _, f_vjp = jax.vjp(jax.pmap(f), jnp.zeros([1])) + + with self.assertRaisesRegex( + FloatingPointError, + r"invalid value \(nan\) encountered in mul\nWhen differentiating"): + ans, = f_vjp(jnp.ones([1])) + ans.block_until_ready() + + def testGradShardMap(self): + @jax.jit + def f(x): + y = x**2 + return jnp.log(y) + + mesh = jax.make_mesh((1,), ('x',)) + shmap_f = shard_map(f, mesh=mesh, in_specs=(P('x')), out_specs=P('x')) + _, f_vjp = jax.vjp(shmap_f, jnp.zeros([1])) + + with self.assertRaisesRegex( + FloatingPointError, + r"invalid value \(nan\) encountered in mul\nWhen differentiating"): + ans, = f_vjp(jnp.ones([1])) + ans.block_until_ready() + def testPmapNoNaN(self): ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.])) ans.block_until_ready() @@ -163,17 +213,23 @@ class DebugNaNsTest(jtu.JaxTestCase): with self.assertRaisesRegex( FloatingPointError, - r"invalid value \(nan\) encountered in jit\(true_divide\)"): + r"invalid value \(nan\) encountered in div"): f(inp, inp) - # TODO(yashkatariya): Fix this and make true_divide appear in the name again. - # Instead of `f` showing up in the error, the name should be of the - # primitive (true_divide) in this case. with self.assertRaisesRegex( FloatingPointError, - r"invalid value \(nan\) encountered in jit\(f\)"): + r"invalid value \(nan\) encountered in div"): jax.jit(f)(inp, inp) + def testDebugNansInput(self): + + @jax.jit + def f(x): + return x * 3. 
+ + with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*input"): + f(np.nan) + @jtu.with_config(jax_debug_infs=True) class DebugInfsTest(jtu.JaxTestCase): @@ -233,7 +289,7 @@ class DebugInfsTest(jtu.JaxTestCase): y = x + 2 # avoid trivial dispatch path by adding some eqn return jnp.nan, y - with self.assertRaisesRegex(FloatingPointError, "de-optimized"): + with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*literal"): with jax.debug_nans(True): f(3) diff --git a/tests/errors_test.py b/tests/errors_test.py index 7dfc4e51a..25f29cfee 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -335,6 +335,47 @@ class FilteredTracebackTest(jtu.JaxTestCase): ('bwd_err', 'g = err(g)'), ('err', 'assert False')], filter_mode=filter_mode) + def test_jvp(self, filter_mode): + def err(_): + assert False + return () + + def f(): + p = (1.,) + t = (0.,) + return jax.jvp(err, p, t) + + check_filtered_stack_trace(self, AssertionError, f, [ + ('f', 'return jax.jvp(err, p, t)'), + ('err', 'assert False')], filter_mode=filter_mode) + + def test_vjp(self, filter_mode): + def err(_): + assert False + return () + + def f(): + x = 1. + return jax.vjp(err, x)[0] + + check_filtered_stack_trace(self, AssertionError, f, [ + ('f', 'return jax.vjp(err, x)[0]'), + ('err', 'assert False')], filter_mode=filter_mode) + + def test_debug_nans(self, filter_mode): + @jax.jit + def f(x): + return 0. / x + + f(2.) + def g(): + return f(0.) + + with jax.debug_nans(True): + check_filtered_stack_trace(self, ZeroDivisionError, g, [ + ('g', 'return f(0.)'), + ('f', 'return 0. / x')], filter_mode=filter_mode) + def test_cause_chain(self, filter_mode): @jit def inner(x): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 62b0fc994..00d8f8dc0 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3758,8 +3758,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self.assertIsNot(x, y) def testArrayUnsupportedDtypeError(self): - with self.assertRaisesRegex(TypeError, - "JAX only supports number and bool dtypes.*"): + with self.assertRaisesRegex( + TypeError, 'JAX only supports number, bool, and string dtypes.*' + ): jnp.array(3, [('a','.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "46f8cf03902c0af58468e2258f9438788e7f4c97" -XLA_SHA256 = "0c391b0a8433d26bfc93e5bee775f7eb629b811a42222ce2b4c7449044a5bc0d" +XLA_COMMIT = "85eccd2ed9f2afd956ab17afd31480a042f07f92" +XLA_SHA256 = "ed853428d3f92aeb3a0cabd564f2373309b4784cd6f90db74ccc2d2ae735984f" def repo(): tf_http_archive(