[sharding_in_types] Add an out_type argument to einsum and dot_general to allow
specifying the output type. Right now it only accepts a NamedSharding, but in the
future we can allow a polymorphic type: jax.ShapeDtypeStruct | Sharding | Layout.
PiperOrigin-RevId: 688399552
This commit is contained in:
parent 5d3cac6603
commit ebb75db8a5
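A usage sketch of the new argument, based on the test_einsum_with_out_type test
added in this change. The Mesh construction below is illustrative (it assumes at
least 4 devices; the tests use a jtu helper instead), and the feature targets the
experimental sharding-in-types mode:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))

@jax.jit
def f(x, y):
  # Request the sharding of the einsum result; only NamedSharding is accepted for now.
  return jnp.einsum('xy,yz->xz', x, y,
                    out_type=NamedSharding(mesh, P('x', None)))

out = f(arr1, arr2)  # out.sharding == NamedSharding(mesh, P('x', None))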
@@ -1040,7 +1040,8 @@ DotDimensionNumbers = tuple[tuple[Sequence[int], Sequence[int]],
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
-                preferred_element_type: DTypeLike | None = None) -> Array:
+                preferred_element_type: DTypeLike | None = None,
+                out_type=None) -> Array:
   """General dot product/contraction operator.
 
   Wraps XLA's `DotGeneral
@@ -1086,6 +1087,10 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
     by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
     non-contracting/non-batch dimensions.
   """
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError(
+        '`out_type` argument of `dot_general` only supports NamedSharding '
+        'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
            api_util._ensure_index_tuple(rhs_contract))
@@ -1097,7 +1102,8 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   return dot_general_p.bind(lhs, rhs,
                             dimension_numbers=(cdims, bdims),
                             precision=canonicalize_precision(precision),
-                            preferred_element_type=preferred_element_type)
+                            preferred_element_type=preferred_element_type,
+                            out_type=out_type)
 
 
 def ragged_dot(
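The same request can be made at the lax level once this change is in; a minimal
sketch reusing mesh, arr1 and arr2 from the einsum sketch above (the
dimension_numbers shown are the ordinary matmul contraction and are illustrative,
not part of this diff):

from jax import lax

@jax.jit
def matmul(a, b):
  # Contract lhs dim 1 with rhs dim 0, no batch dims; request the output sharding.
  return lax.dot_general(a, b, (((1,), (0,)), ((), ())),
                         out_type=NamedSharding(mesh, P('x', None)))

out = matmul(arr1, arr2)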
@@ -3002,7 +3008,11 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
     operand = hlo.real(operand)
     aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
-  return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
+  out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
+  if config.sharding_in_types.value:
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return [out]
 
 mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
 
@@ -3164,7 +3174,8 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
 
 
 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None):
+                            preferred_element_type: DTypeLike | None,
+                            out_type):
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
              for d in (lhs_contracting, lhs_batch)):
@@ -3241,12 +3252,16 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
     raise TypeError(msg)
 
 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None):
+                               preferred_element_type: DTypeLike | None,
+                               out_type):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
         f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
 
+  if out_type is not None:
+    return out_type
+
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
   rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
@@ -3280,7 +3295,8 @@ def tuple_delete(tup, idx):
 
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None):
+                            preferred_element_type: DTypeLike | None,
+                            out_type):
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
   # primitive_util.h's HigherPrecisionType, e.g.
@@ -3327,7 +3343,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
 
 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               swap_ans=False):
+                               out_type, swap_ans=False):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3347,12 +3363,14 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
   return x_bar
 
 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None):
+                               preferred_element_type: DTypeLike | None,
+                               out_type):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
     g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-    preferred_element_type=preferred_element_type, swap_ans=True)
+    preferred_element_type=preferred_element_type, out_type=out_type,
+    swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
   return y_bar
@@ -3366,6 +3384,7 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
+    out_type,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3395,12 +3414,16 @@ def _dot_batch_rule(
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
+  if out_type is not None:
+    raise NotImplementedError("vmap with out_type is not supported. "
+                              "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
       rhs,
       new_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
+      out_type=out_type,
   )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
@@ -3570,7 +3593,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       platform: str = "default"):
+                       out_type, platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -3658,6 +3681,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
+    if out_type is not None:
+      assert aval_out.sharding == out_type
     out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
     result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
   if accumulation_aval.dtype != aval_out.dtype:
@@ -3711,12 +3736,15 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
   return (m, n)
 
 def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
-                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
+                           precision, preferred_element_type: DTypeLike | None,
+                           **_) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
     raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
   # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
-  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
-                                 precision=precision, preferred_element_type=preferred_element_type)
+  return _dot_general_dtype_rule(
+      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+      precision=precision, preferred_element_type=preferred_element_type,
+      out_type=None)
 
 
 def _ragged_dot_jvp_rule(
@@ -3855,6 +3883,7 @@ def _ragged_dot_batch_rule(
     *,
     precision,
     preferred_element_type: DTypeLike | None,
+    out_type,
     **_,
 ):
   invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
@@ -3868,6 +3897,7 @@ def _ragged_dot_batch_rule(
       dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision,
       preferred_element_type=preferred_element_type,
+      out_type=out_type,
   )
@@ -67,10 +67,10 @@ from jax._src.typing import (
   DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar,
 )
 from jax._src.util import (
-    NumpyComplexWarning,
-    canonicalize_axis as _canonicalize_axis,
+    NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
     ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
-from jax.sharding import Sharding, SingleDeviceSharding
+from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
+                          PartitionSpec as P)
 from jax.tree_util import tree_flatten, tree_leaves, tree_map
 import numpy as np
 import opt_einsum
@@ -8955,6 +8955,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
+    out_type=None,
 ) -> Array: ...
 
 @overload
@@ -8967,6 +8968,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
+    out_type=None,
 ) -> Array: ...
 
 def einsum(
@@ -8977,6 +8979,7 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
+    out_type=None,
 ) -> Array:
   """Einstein summation
 
@@ -9208,11 +9211,11 @@ def einsum(
 
   contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
 
-  einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
+  einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
   if spec is not None:
     einsum = jax.named_call(einsum, name=spec)
   return einsum(operands, contractions, precision,
-                preferred_element_type, _dot_general)
+                preferred_element_type, _dot_general, out_type)
 
 
 # Enable other modules to override einsum_contact_path.
@@ -9311,7 +9314,12 @@ def _einsum(
     precision,
     preferred_element_type,
     _dot_general=lax.dot_general,
+    out_type=None,
 ):
+  if out_type is not None and not isinstance(out_type, NamedSharding):
+    raise NotImplementedError(
+      "`out_type` argument of `einsum` only supports NamedSharding instances."
+      " Please file a bug if this is not enough for your use case.")
   dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
   operands = list(map(asarray, operands))
   if preferred_element_type is None:
@@ -9434,12 +9442,21 @@ def _einsum(
       if names == result_names:
         dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
         operand = _dot_general(rhs, lhs, dimension_numbers, precision,
-                               preferred_element_type=preferred_element_type)
+                               preferred_element_type=preferred_element_type,
+                               out_type=out_type)
       else:
         names = batch_names_str + remaining_lhs_names + remaining_rhs_names
+        if (config.sharding_in_types.value and out_type is not None and
+            names != result_names):
+          spec = out_type.spec
+          inverse_spec = tuple(spec[result_names.index(name)] for name in names)
+          dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec))
+        else:
+          dot_general_out_type = out_type  # type: ignore
         dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
         operand = _dot_general(lhs, rhs, dimension_numbers, precision,
-                               preferred_element_type=preferred_element_type)
+                               preferred_element_type=preferred_element_type,
+                               out_type=dot_general_out_type)
     else:
       raise NotImplementedError  # if this is actually reachable, open an issue!
 
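The inverse_spec logic above maps the user-supplied output spec from the final
einsum axis order onto the intermediate dot_general result, whose axis order can
differ; a later transpose restores the requested order. A small illustration using
the 'btd,dhq->bhtq' case from the test_einsum_inverse test added below (plain
Python mirroring the local variables of _einsum, not an executable excerpt):

result_names = 'bhtq'          # axis order of the final einsum output
names = 'bthq'                 # axis order of the intermediate dot_general result
spec = ('x', None, 'y', None)  # PartitionSpec entries requested via out_type
inverse_spec = tuple(spec[result_names.index(n)] for n in names)
assert inverse_spec == ('x', 'y', None, None)  # same sharding, permuted to match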
@@ -9452,7 +9469,8 @@ def _einsum(
     operand = lax.transpose(operand, perm)
     operands.append(operand)  # used in next iteration
 
-  return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type)
+  return lax_internal._convert_element_type(operands[0], preferred_element_type,
+                                            output_weak_type)
 
 
 @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
@@ -2089,10 +2089,11 @@ def _dot_general_lowering(
     b,
     *,
     dimension_numbers,
+    out_type,
     precision,
     preferred_element_type,
 ):
-  del preferred_element_type  # Unused.
+  del preferred_element_type, out_type  # Unused.
   ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
   assert batch_dims == ((), ())
@@ -2180,7 +2180,7 @@ def _conv_general_dilated(lhs, rhs, *,
 tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
 
 
-def _dot_general(lhs, rhs, *, dimension_numbers,
+def _dot_general(lhs, rhs, *, dimension_numbers, out_type,
                  precision: lax_internal.CanonicalPrecision,
                  preferred_element_type: DType | None,
                  _in_avals: Sequence[core.ShapedArray],
@@ -606,8 +606,11 @@ mlir.register_lowering(bcoo_transpose_p, mlir.lower_fun(
 
 bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
 
-def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers,
-                     precision: None = None, preferred_element_type: None = None) -> BCOO | Array:
+def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
+                     dimension_numbers: DotDimensionNumbers,
+                     precision: None = None,
+                     preferred_element_type: None = None,
+                     out_type=None) -> BCOO | Array:
   """A general contraction operation.
 
   Args:
@@ -625,7 +628,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers:
     the result will be dense, of type ndarray.
   """
   # TODO(jakevdp) make use of these?
-  del precision  # unused
+  del precision, out_type  # unused
   if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
     shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                          dimension_numbers)
@@ -1051,7 +1054,8 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
   indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
   kwds = {'dimension_numbers': dimension_numbers,
           'precision': None,
-          'preferred_element_type': None}
+          'preferred_element_type': None,
+          'out_type': None}
   A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
   return A, B, indices
@@ -462,7 +462,8 @@ bcsr_dot_general_p = core.Primitive('bcsr_dot_general')
 def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      precision: None = None,
-                     preferred_element_type: None = None) -> Array:
+                     preferred_element_type: None = None,
+                     out_type=None) -> Array:
   """A general contraction operation.
 
   Args:
@@ -479,7 +480,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
     are sparse, the result will be sparse, of type BCSR. If either input is
     dense, the result will be dense, of type ndarray.
   """
-  del precision  # unused
+  del precision, out_type  # unused
   if isinstance(rhs, (np.ndarray, jax.Array)):
     if isinstance(lhs, (np.ndarray, jax.Array)):
       return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
@@ -111,4 +111,4 @@ def _dot_general_validated_shape(
   rhs = core.ShapedArray(rhs_shape, np.float32)
   return _dot_general_shape_rule(
       lhs, rhs, dimension_numbers=dimension_numbers,
-      precision=None, preferred_element_type=None)
+      precision=None, preferred_element_type=None, out_type=None)
@@ -4945,6 +4945,60 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     _, out = g(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
 
+  def test_einsum_with_out_type(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    arr1 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
+    arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P('y', 'x')))
+
+    @jax.jit
+    def f(x, y):
+      out = jnp.einsum('xy,yz->xz', x, y,
+                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+      self.assertEqual(out.sharding.spec, P('x', None))
+      return out
+
+    out = f(arr1, arr2)
+    self.assertArraysEqual(out, np_inp @ np_inp.T)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+
+    lowered_text = f.lower(arr1, arr2).as_text()
+    self.assertIn('@Sharding', lowered_text)
+
+    @jax.jit
+    def g(x, y):
+      out = jnp.einsum('xy,yz->xz', x, y,
+                       out_type=NamedSharding(x.sharding.mesh, P('x', None)))
+      self.assertEqual(out.sharding.spec, P('x', None))
+      return out
+
+    arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
+    arr4 = jax.device_put(np_inp.T, NamedSharding(mesh, P('x', 'y')))
+    out2 = g(arr3, arr4)
+    self.assertArraysEqual(out2, np_inp @ np_inp.T)
+    self.assertEqual(out2.sharding, NamedSharding(mesh, P('x', None)))
+
+  def test_einsum_inverse(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(64)
+
+    @jax.jit
+    def h(x, y):
+      s = NamedSharding(x.sharding.mesh, P('x', None, 'y', None))
+      out = jnp.einsum('btd,dhq->bhtq', x, y, out_type=s)
+      self.assertEqual(out.sharding.spec, s.spec)
+      return out
+
+    arr1 = jax.device_put(np_inp.reshape(8, 4, 2),
+                          NamedSharding(mesh, P('x', 'y', None)))
+    arr2 = jax.device_put(np_inp.reshape(2, 4, 8),
+                          NamedSharding(mesh, P(None, 'x', 'y')))
+    out = h(arr1, arr2)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y', None)))
+
+    lowered_text = h.lower(arr1, arr2).as_text()
+    self.assertIn('@Sharding', lowered_text)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):