Mirror of https://github.com/ROCm/jax.git
Update minimum jaxlib version to 0.3.14.
parent 44bd311ae7 · commit 0b4b0ba072
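Raising the minimum jaxlib to 0.3.14 makes every guard on `jaxlib.version`, `xla_client._version`, `xla_extension_version`, or `mlir_api_version` that targets releases below that floor statically decidable, so the dead branches are deleted throughout: version-gated lowerings become unconditional registrations, the per-vendor `cuda_*`/`hip_*` fallback modules give way to the unified `gpu_*` modules, and the `sparse_apis`/`solver_apis` shims disappear. A minimal sketch of the guard pattern being retired (the names here are illustrative stand-ins, not code from this diff):

# `jaxlib_version` stands in for jax._src.lib.version, the installed
# jaxlib release expressed as a tuple of ints.
jaxlib_version = (0, 3, 14)

def register_gpu_lowering():
  print("registered")

# Before: newer-API code paths were gated on the installed jaxlib.
if jaxlib_version >= (0, 3, 11):
  register_gpu_lowering()

# After: with 0.3.14 as the minimum, the test is always true, so the
# registration simply runs unconditionally, as in the hunks below.
register_gpu_lowering()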
@@ -59,7 +59,6 @@ from jax._src.lib import jax_jit
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
-from jax._src.lib import xla_extension_version
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import broadcast_prefix
 from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
@@ -585,9 +584,8 @@ def _cpp_jit(
       return _BackendAndDeviceInfo(default_device, committed_to_device)

     jitted_f_kwargs = {}
-    if xla_extension_version >= 71:
-      jitted_f_kwargs["has_explicit_device"] = (
-          device is not None or backend is not None)
+    jitted_f_kwargs["has_explicit_device"] = (
+        device is not None or backend is not None)
     cpp_jitted_f = jax_jit.jit(
         fun,
         cache_miss,
@@ -109,9 +109,8 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
   return result
 mlir.register_lowering(debug_callback_p, debug_callback_lowering,
                        platform="cpu")
-if jaxlib.version >= (0, 3, 11):
-  mlir.register_lowering(
-      debug_callback_p, debug_callback_lowering, platform="gpu")
+mlir.register_lowering(
+    debug_callback_p, debug_callback_lowering, platform="gpu")
 if jaxlib.version >= (0, 3, 15):
   mlir.register_lowering(
       debug_callback_p, debug_callback_lowering, platform="tpu")
@@ -66,22 +66,14 @@ class State:
     if self.service is not None:
       raise RuntimeError('distributed.initialize should only be called once.')
     logging.info('Starting JAX distributed service on %s', coordinator_address)
-    if xla_client._version >= 72:
-      self.service = xla_extension.get_distributed_runtime_service(
-          coordinator_address, num_processes, config.jax_coordination_service)
-    else:
-      self.service = xla_extension.get_distributed_runtime_service(
-          coordinator_address, num_processes)
+    self.service = xla_extension.get_distributed_runtime_service(
+        coordinator_address, num_processes, config.jax_coordination_service)

     if self.client is not None:
       raise RuntimeError('distributed.initialize should only be called once.')

-    if xla_client._version >= 72:
-      self.client = xla_extension.get_distributed_runtime_client(
-          coordinator_address, process_id, config.jax_coordination_service)
-    else:
-      self.client = xla_extension.get_distributed_runtime_client(
-          coordinator_address, process_id)
+    self.client = xla_extension.get_distributed_runtime_client(
+        coordinator_address, process_id, config.jax_coordination_service)
     logging.info('Connecting to JAX distributed service on %s', coordinator_address)
     self.client.connect()

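The surviving constructor calls above back `jax.distributed.initialize`, which each participating process calls once before running any computation. A hedged usage sketch; the address, process count, and id values are placeholders:

import jax

# Run once per process; process 0 hosts the coordination service.
jax.distributed.initialize(
    coordinator_address="192.168.0.1:1234",  # placeholder host:port
    num_processes=2,
    process_id=0)  # each process passes its own id in [0, num_processes)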
@@ -132,8 +132,6 @@ def approx_max_k(operand: Array,
   >>> db = jax.numpy.array(np.random.rand(1024, 64))
   >>> dot_products, neighbors = mips(qy, db, k=10)
   """
-  if xc._version < 45:
-    aggregate_to_topk = True
   return approx_top_k_p.bind(
       operand,
       k=k,
@@ -197,8 +195,6 @@ def approx_min_k(operand: Array,
   ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
   arithmetics and produces the same set of neighbors.
   """
-  if xc._version < 45:
-    aggregate_to_topk = True
   return approx_top_k_p.bind(
       operand,
       k=k,
@@ -225,13 +221,10 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
                      dims[reduction_dimension], k))
   if not dtypes.issubdtype(operand.dtype, np.floating):
     raise ValueError('operand must be a floating type')
-  if xc._version >= 45:
-    reduction_input_size = dims[reduction_dimension]
-    dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
-        reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
-        reduction_input_size_override)[0]
-  else:
-    dims[reduction_dimension] = k
+  reduction_input_size = dims[reduction_dimension]
+  dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
+      reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
+      reduction_input_size_override)[0]
   return (operand.update(
       shape=dims, dtype=operand.dtype, weak_type=operand.weak_type),
           operand.update(shape=dims, dtype=np.dtype(np.int32)))
@@ -3070,17 +3070,10 @@ masking.masking_rules[pad_p] = _pad_masking_rule

 def _pad_lower(ctx, x, padding_value, *, padding_config):
   low, high, interior = util.unzip3(padding_config)
-  if jax._src.lib.mlir_api_version < 15:
-    aval_out, = ctx.avals_out
-    return mhlo.PadOp(mlir.aval_to_ir_type(aval_out), x, padding_value,
-                      mlir.dense_int_elements(low),
-                      mlir.dense_int_elements(high),
-                      mlir.dense_int_elements(interior)).results
-  else:
-    return mhlo.PadOp(x, padding_value,
-                      mlir.dense_int_elements(low),
-                      mlir.dense_int_elements(high),
-                      mlir.dense_int_elements(interior)).results
+  return mhlo.PadOp(x, padding_value,
+                    mlir.dense_int_elements(low),
+                    mlir.dense_int_elements(high),
+                    mlir.dense_int_elements(interior)).results
 mlir.register_lowering(pad_p, _pad_lower)


@@ -3817,13 +3810,8 @@ masking.defvectorized(reduce_precision_p)

 def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
   aval_out, = ctx.avals_out
-  if jax._src.lib.mlir_api_version >= 21:
-    return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
-                                  mlir.i32_attr(mantissa_bits)).results
-  else:
-    return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
-                                  mlir.i32_attr(exponent_bits),
-                                  mlir.i32_attr(mantissa_bits)).results
+  return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
+                                mlir.i32_attr(mantissa_bits)).results

 mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)

@@ -4059,12 +4047,9 @@ top_k_p = Primitive('top_k')
 top_k_p.multiple_results = True
 top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
 top_k_p.def_abstract_eval(_top_k_abstract_eval)
-if jax._src.lib.mlir_api_version >= 16:
-  def _top_k_lower(ctx, operand, k):
-    return chlo.TopKOp(operand, mlir.i64_attr(k)).results
-  mlir.register_lowering(top_k_p, _top_k_lower)
-else:
-  xla.register_translation(top_k_p, _top_k_translation_rule)
+def _top_k_lower(ctx, operand, k):
+  return chlo.TopKOp(operand, mlir.i64_attr(k)).results
+mlir.register_lowering(top_k_p, _top_k_lower)
 ad.primitive_jvps[top_k_p] = _top_k_jvp
 batching.primitive_batchers[top_k_p] = _top_k_batch_rule

@@ -4315,9 +4300,7 @@ def _rng_bit_generator_lowering(
     key = mhlo.BitcastConvertOp(
         ir.RankedTensorType.get([2], u64_type),
         mhlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
-  algorithm_attr = (
-      _rng_algorithm(algorithm) if jax._src.lib.mlir_api_version >= 14
-      else mlir.i32_attr(algorithm))
+  algorithm_attr = _rng_algorithm(algorithm)
   out_key, out_vals = mhlo.RngBitGeneratorOp(
       key.type,
       ir.RankedTensorType.get(shape, rbg_etype),
@@ -46,11 +46,6 @@ from jax._src.lib import gpu_linalg
 from jax._src.lib import gpu_solver
 from jax._src.lib import gpu_sparse

-from jax._src.lib import cuda_linalg
-from jax._src.lib import hip_linalg
-from jax._src.lib import sparse_apis
-from jax._src.lib import solver_apis
-
 from jax._src.lib import xla_client

 from jax._src.lib.mlir import ir
@@ -414,12 +409,7 @@ ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
 batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule

 def _cholesky_lowering(ctx, x):
-  if jax._src.lib.mlir_api_version < 18:
-    aval, = ctx.avals_out
-    return mhlo.CholeskyOp(mlir.aval_to_ir_type(aval), x,
-                           lower=ir.BoolAttr.get(True)).results
-  else:
-    return mhlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
+  return mhlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results

 mlir.register_lowering(cholesky_p, _cholesky_lowering)

@@ -442,22 +432,14 @@ mlir.register_lowering(
     cholesky_p,
     partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo),
     platform='cpu')

-if gpu_solver is not None:
-  mlir.register_lowering(
-      cholesky_p,
-      partial(_cholesky_cpu_gpu_lowering, gpu_solver.cuda_potrf),
-      platform='cuda')
-  mlir.register_lowering(
-      cholesky_p,
-      partial(_cholesky_cpu_gpu_lowering, gpu_solver.rocm_potrf),
-      platform='rocm')
-
-if solver_apis is not None:
-  mlir.register_lowering(
-      cholesky_p,
-      partial(_cholesky_cpu_gpu_lowering, solver_apis.potrf_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    cholesky_p,
+    partial(_cholesky_cpu_gpu_lowering, gpu_solver.cuda_potrf),
+    platform='cuda')
+mlir.register_lowering(
+    cholesky_p,
+    partial(_cholesky_cpu_gpu_lowering, gpu_solver.rocm_potrf),
+    platform='rocm')

 # Asymmetric eigendecomposition

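The same rewrite repeats for each linear-algebra primitive below: the `if gpu_solver is not None:` and `if solver_apis is not None:` conditionals collapse into one unconditional registration per platform. A toy sketch of platform-keyed lowering dispatch; the registry here is a stand-in for the internal one in jax.interpreters.mlir, and the kernels are dummies:

from functools import partial

# Stand-in registry keyed by (primitive, platform), mimicking how
# mlir.register_lowering(prim, rule, platform=...) stores rules.
_lowerings = {}

def register_lowering(prim, rule, platform=None):
  _lowerings[(prim, platform)] = rule

def _cholesky_lowering(potrf_kernel, x):
  # The platform-specific kernel is bound in with functools.partial.
  return potrf_kernel(x)

register_lowering("cholesky_p",
                  partial(_cholesky_lowering, lambda x: ("cuda_potrf", x)),
                  platform="cuda")
register_lowering("cholesky_p",
                  partial(_cholesky_lowering, lambda x: ("rocm_potrf", x)),
                  platform="rocm")

print(_lowerings[("cholesky_p", "cuda")]([[4.0]]))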
@@ -756,12 +738,6 @@ if gpu_solver is not None:
       eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd),
       platform='rocm')

-
-if solver_apis is not None:
-  mlir.register_lowering(
-      eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo),
-      platform='gpu')
-
 mlir.register_lowering(
     eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
     platform='tpu')
@@ -948,21 +924,14 @@ def _triangular_solve_gpu_lower(
       ir.BoolAttr.get(unit_diagonal),
       mhlo.TransposeAttr.get(transpose)).results

-if gpu_solver is not None:
-  mlir.register_lowering(
-      triangular_solve_p,
-      partial(_triangular_solve_gpu_lower, gpu_solver.cuda_trsm),
-      platform='cuda')
-  mlir.register_lowering(
-      triangular_solve_p,
-      partial(_triangular_solve_gpu_lower, gpu_solver.rocm_trsm),
-      platform='rocm')
-
-if solver_apis is not None:
-  mlir.register_lowering(
-      triangular_solve_p,
-      partial(_triangular_solve_gpu_lower, solver_apis.trsm_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    triangular_solve_p,
+    partial(_triangular_solve_gpu_lower, gpu_solver.cuda_trsm),
+    platform='cuda')
+mlir.register_lowering(
+    triangular_solve_p,
+    partial(_triangular_solve_gpu_lower, gpu_solver.rocm_trsm),
+    platform='rocm')


 # Support operation for LU decomposition: Transformation of the pivots returned
@@ -1052,31 +1021,16 @@ batching.primitive_batchers[lu_pivots_to_permutation_p] = (
 mlir.register_lowering(
     lu_pivots_to_permutation_p,
     mlir.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False))

-if gpu_linalg:
-  mlir.register_lowering(
-      lu_pivots_to_permutation_p,
-      partial(_lu_pivots_to_permutation_gpu_lowering,
-              gpu_linalg.cuda_lu_pivots_to_permutation),
-      platform='cuda')
-  mlir.register_lowering(
-      lu_pivots_to_permutation_p,
-      partial(_lu_pivots_to_permutation_gpu_lowering,
-              gpu_linalg.hip_lu_pivots_to_permutation),
-      platform='rocm')
-
-
-if cuda_linalg:
-  mlir.register_lowering(lu_pivots_to_permutation_p,
-                         partial(_lu_pivots_to_permutation_gpu_lowering,
-                                 cuda_linalg.lu_pivots_to_permutation_mhlo),
-                         platform='cuda')
-
-if hip_linalg:
-  mlir.register_lowering(lu_pivots_to_permutation_p,
-                         partial(_lu_pivots_to_permutation_gpu_lowering,
-                                 hip_linalg.lu_pivots_to_permutation_mhlo),
-                         platform='rocm')
+mlir.register_lowering(
+    lu_pivots_to_permutation_p,
+    partial(_lu_pivots_to_permutation_gpu_lowering,
+            gpu_linalg.cuda_lu_pivots_to_permutation),
+    platform='cuda')
+mlir.register_lowering(
+    lu_pivots_to_permutation_p,
+    partial(_lu_pivots_to_permutation_gpu_lowering,
+            gpu_linalg.hip_lu_pivots_to_permutation),
+    platform='rocm')


 # LU decomposition
@@ -1273,18 +1227,12 @@ mlir.register_lowering(lu_p,
                        partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo),
                        platform='cpu')

-if gpu_solver is not None:
-  mlir.register_lowering(
-      lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf),
-      platform='cuda')
-  mlir.register_lowering(
-      lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf),
-      platform='rocm')
-
-if solver_apis is not None:
-  mlir.register_lowering(
-      lu_p, partial(_lu_cpu_gpu_lowering, solver_apis.getrf_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf),
+    platform='cuda')
+mlir.register_lowering(
+    lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf),
+    platform='rocm')

 xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu')

@@ -1418,25 +1366,16 @@ xla.register_translation(geqrf_p, _geqrf_translation_rule)
 mlir.register_lowering(
     geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo, None),
     platform='cpu')
-if gpu_solver is not None:
-  # TODO(phawkins): make cuda_geqrf_batched and rocm_geqrf_unbatched
-  # unconditional when jaxlib 0.3.11 is the minimum.
-  mlir.register_lowering(
-      geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
-              getattr(gpu_solver, 'cuda_geqrf_batched', None)),
-      platform='cuda')
-  mlir.register_lowering(
-      geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
-              getattr(gpu_solver, 'rocm_geqrf_batched', None)),
-      platform='rocm')
-
-if solver_apis is not None:
-  mlir.register_lowering(
-      geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, solver_apis.geqrf_mhlo, None),
-      platform='gpu')
+mlir.register_lowering(
+    geqrf_p,
+    partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
+            gpu_solver.cuda_geqrf_batched),
+    platform='cuda')
+mlir.register_lowering(
+    geqrf_p,
+    partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
+            gpu_solver.rocm_geqrf_batched),
+    platform='rocm')


 # orgqr: product of elementary Householder reflectors
@@ -1505,21 +1444,14 @@ xla.register_translation(orgqr_p, _orgqr_translation_rule)
 mlir.register_lowering(
     orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo),
     platform='cpu')
-if gpu_solver is not None:
-  mlir.register_lowering(
-      orgqr_p,
-      partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
-      platform='cuda')
-  mlir.register_lowering(
-      orgqr_p,
-      partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
-      platform='rocm')
-
-if solver_apis is not None:
-  mlir.register_lowering(
-      orgqr_p,
-      partial(_orgqr_cpu_gpu_lowering, solver_apis.orgqr_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    orgqr_p,
+    partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
+    platform='cuda')
+mlir.register_lowering(
+    orgqr_p,
+    partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
+    platform='rocm')


 def _qr_impl(operand, *, full_matrices):
@@ -1601,11 +1533,6 @@ qr_p.multiple_results = True
 qr_p.def_impl(_qr_impl)
 qr_p.def_abstract_eval(_qr_abstract_eval)

-# Older jaxlibs didn't expose geqrf and orgqr as separate XLA operations.
-# TODO(phawkins): remove after minimum jaxlib version is > 0.3.10.
-if jax._src.lib.xla_extension_version < 69:
-  xla.register_translation(qr_p, _qr_translation_rule, platform="tpu")
-
 ad.primitive_jvps[qr_p] = qr_jvp_rule
 batching.primitive_batchers[qr_p] = _qr_batching_rule

@@ -1795,19 +1722,12 @@ batching.primitive_batchers[svd_p] = _svd_batching_rule
 mlir.register_lowering(
     svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
     platform='cpu')

-if gpu_solver is not None:
-  mlir.register_lowering(
-      svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
-      platform='cuda')
-  mlir.register_lowering(
-      svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd),
-      platform='rocm')
-
-if solver_apis is not None:
-  mlir.register_lowering(
-      svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
+    platform='cuda')
+mlir.register_lowering(
+    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd),
+    platform='rocm')

 mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)

@@ -1821,20 +1741,15 @@ tridiagonal_solve_p.def_impl(
     functools.partial(xla.apply_primitive, tridiagonal_solve_p))
 tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
 # TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
-if sparse_apis and hasattr(sparse_apis, "gtsv2"):
-  mlir.register_lowering(tridiagonal_solve_p,
-                         _tridiagonal_solve_gpu_lowering,
-                         platform='gpu')

-if gpu_sparse:
-  mlir.register_lowering(
-      tridiagonal_solve_p,
-      partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2),
-      platform='cuda')
-  mlir.register_lowering(
-      tridiagonal_solve_p,
-      partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2),
-      platform='rocm')
+mlir.register_lowering(
+    tridiagonal_solve_p,
+    partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2),
+    platform='cuda')
+mlir.register_lowering(
+    tridiagonal_solve_p,
+    partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2),
+    platform='rocm')


 def _tridiagonal_solve_jax(dl, d, du, b, **kw):
@@ -1966,17 +1881,10 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
   operand_aval, = ctx.avals_in
   batch_dims = operand_aval.shape[:-2]

-  # TODO(jakevdp): remove this try/except when minimum jaxlib >= 0.3.8
-  try:
-    gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
-                                   jobvs=compute_schur_vectors,
-                                   sort=sort_eig_vals,
-                                   select=select_callable)
-  except TypeError:  # API for jaxlib <= 0.3.7
-    gees_result = lapack.gees_mhlo(operand,  # pytype: disable=missing-parameter
-                                   jobvs=compute_schur_vectors,
-                                   sort=sort_eig_vals,
-                                   select=select_callable)
+  gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
+                                 jobvs=compute_schur_vectors,
+                                 sort=sort_eig_vals,
+                                 select=select_callable)
   # Number of return values depends on value of sort_eig_vals.
   T, vs, *_, info = gees_result

@@ -721,12 +721,8 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
       _replica_groups(ctx.module_context.axis_env, named_axes,
                       axis_index_groups))
   def all_reduce(x_dtype, x):
-    if jax._src.lib.mlir_api_version >= 17:
-      op = mhlo.AllReduceOp(
-          x.type, x, replica_groups=replica_groups, channel_handle=None)
-    else:
-      op = mhlo.AllReduceOp(
-          x, replica_groups=replica_groups, channel_handle=None)
+    op = mhlo.AllReduceOp(
+        x.type, x, replica_groups=replica_groups, channel_handle=None)
     scalar_aval = core.ShapedArray((), x_dtype)
     scalar_type = mlir.aval_to_ir_type(scalar_aval)
     reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
@@ -741,22 +737,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
     mhlo.ReturnOp(util.flatten(out_nodes))
     return op.result

-  outs = []
-  for aval, x in zip(ctx.avals_in, args):
-    # TODO(phawkins): remove this special case when jaxlib > 0.3.10 is the
-    # minimum.
-    if (jax._src.lib.xla_extension_version < 75 and prim is lax.add_p
-        and dtypes.issubdtype(aval.dtype, np.complexfloating)):
-      real_dtype = np.finfo(aval.dtype).dtype
-      outs.append(
-          mhlo.ComplexOp(
-              all_reduce(real_dtype,
-                         mhlo.RealOp(x).result),
-              all_reduce(real_dtype,
-                         mhlo.ImagOp(x).result)).result)
-    else:
-      outs.append(all_reduce(aval.dtype, x))
-  return outs
+  return [all_reduce(aval.dtype, x) for aval, x in zip(ctx.avals_in, args)]


 def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
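The deleted loop special-cased complex all-reduce on older jaxlibs by splitting each value into real and imaginary parts; the identity that decomposition relied on, checked in NumPy:

import numpy as np

# Per-device complex values (illustrative).
xs = [np.array([1 + 2j, 3 - 1j]), np.array([0.5 + 0.5j, -1 + 4j])]

# A complex sum equals complex(sum of real parts, sum of imaginary parts),
# which is what the removed RealOp/ImagOp/ComplexOp path computed.
direct = sum(xs)
decomposed = sum(np.real(x) for x in xs) + 1j * sum(np.imag(x) for x in xs)
assert np.allclose(direct, decomposed)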
@@ -947,17 +928,10 @@ def _all_to_all_lowering(ctx, x, *,
     split_count = len(replica_groups[0])
     if not all(split_count == len(g) for g in replica_groups):
       raise ValueError('Replica groups must be equally sized')
-    if jax._src.lib.mlir_api_version < 19:
-      return mhlo.AllToAllOp(mlir.aval_to_ir_type(ctx.avals_out[0]),
-                             x, mlir.i64_attr(split_axis),
-                             mlir.i64_attr(concat_axis),
-                             mlir.i64_attr(split_count),
-                             _replica_groups_mhlo(replica_groups)).results
-    else:
-      return mhlo.AllToAllOp(x, mlir.i64_attr(split_axis),
-                             mlir.i64_attr(concat_axis),
-                             mlir.i64_attr(split_count),
-                             _replica_groups_mhlo(replica_groups)).results
+    return mhlo.AllToAllOp(x, mlir.i64_attr(split_axis),
+                           mlir.i64_attr(concat_axis),
+                           mlir.i64_attr(split_count),
+                           _replica_groups_mhlo(replica_groups)).results
   else:
     warnings.warn(
         "all_to_all (and pswapaxes) are only implemented properly for TPUs and GPUs (if "
@@ -1941,12 +1941,9 @@ def _scatter_lower(ctx, operand, indices, updates, *,
       inserted_window_dims=list(dnums.inserted_window_dims),
       scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
       index_vector_dim=len(ctx.avals_in[1].shape) - 1)
-  if jax._src.lib.mlir_api_version >= 20:
-    result = mlir.aval_to_ir_types(aval_out)
-    operand = [operand]
-    updates = [updates]
-  else:
-    result = mlir.aval_to_ir_type(aval_out)
+  result = mlir.aval_to_ir_types(aval_out)
+  operand = [operand]
+  updates = [updates]
   op = mhlo.ScatterOp(
       result,
       operand,
@@ -2001,19 +1998,12 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
       scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
       index_vector_dim=len(ctx.avals_in[1].shape) - 1)
   real_dtype = _real_dtype(aval_out.dtype)
-  if jax._src.lib.mlir_api_version >= 20:
-    operand_type_part = mlir.aval_to_ir_types(
-        core.ShapedArray(aval_out.shape, real_dtype))
-  else:
-    operand_type_part = mlir.aval_to_ir_type(
-        core.ShapedArray(aval_out.shape, real_dtype))
+  operand_type_part = mlir.aval_to_ir_types(
+      core.ShapedArray(aval_out.shape, real_dtype))

   def _scatter(operand_part, updates_part):
-    # If the MLIR api supports variadic scatter, we make a variadic scatter op
-    # with arity 1
-    if jax._src.lib.mlir_api_version >= 20:
-      operand_part = [operand_part]
-      updates_part = [updates_part]
+    operand_part = [operand_part]
+    updates_part = [updates_part]

     scatter = mhlo.ScatterOp(
         operand_type_part,
@@ -706,16 +706,10 @@ def _select_and_gather_add_lowering(
   def pack(a, b):
     a_dims = ir.RankedTensorType(a.type).shape
     b_dims = ir.RankedTensorType(b.type).shape
-    if jax._src.lib.mlir_api_version >= 21:
-      a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
-      b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
-    else:
-      a = mhlo.ReducePrecisionOp(a.type, a, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
-      b = mhlo.ReducePrecisionOp(b.type, b, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
+    a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
+                               mantissa_bits=mlir.i32_attr(nmant))
+    b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
+                               mantissa_bits=mlir.i32_attr(nmant))
     a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
     b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
     b = mhlo.ShiftRightLogicalOp(
@@ -113,71 +113,10 @@ pytree = xla_client._xla.pytree
 jax_jit = xla_client._xla.jax_jit
 pmap_lib = xla_client._xla.pmap_lib

-# TODO(phawkins): make gpu_... unconditional after jaxlib >= 0.3.11
-# becomes the minimum; remove cuda_... and hip_....
-
-try:
-  import jaxlib.cusolver as cusolver  # pytype: disable=import-error
-except ImportError:
-  cusolver = None
-
-try:
-  import jaxlib.hipsolver as hipsolver  # pytype: disable=import-error
-except ImportError:
-  hipsolver = None
-
-try:
-  import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error
-except ImportError:
-  gpu_solver = None
-
-try:
-  import jaxlib.cusparse as cusparse  # pytype: disable=import-error
-except ImportError:
-  cusparse = None
-
-try:
-  import jaxlib.hipsparse as hipsparse  # pytype: disable=import-error
-except ImportError:
-  hipsparse = None
-
-try:
-  import jaxlib.gpu_sparse as gpu_sparse  # pytype: disable=import-error
-except ImportError:
-  gpu_sparse = None
-
-sparse_apis = cusparse or hipsparse or None
-solver_apis = cusolver or hipsolver or None
-
-try:
-  import jaxlib.cuda_prng as cuda_prng  # pytype: disable=import-error
-except ImportError:
-  cuda_prng = None
-
-try:
-  import jaxlib.hip_prng as hip_prng  # pytype: disable=import-error
-except ImportError:
-  hip_prng = None
-
-try:
-  import jaxlib.gpu_prng as gpu_prng  # pytype: disable=import-error
-except ImportError:
-  gpu_prng = None
-
-try:
-  import jaxlib.cuda_linalg as cuda_linalg  # pytype: disable=import-error
-except ImportError:
-  cuda_linalg = None
-
-try:
-  import jaxlib.hip_linalg as hip_linalg  # pytype: disable=import-error
-except ImportError:
-  hip_linalg = None
-
-try:
-  import jaxlib.gpu_linalg as gpu_linalg  # pytype: disable=import-error
-except ImportError:
-  gpu_linalg = None
+import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error
+import jaxlib.gpu_sparse as gpu_sparse  # pytype: disable=import-error
+import jaxlib.gpu_prng as gpu_prng  # pytype: disable=import-error
+import jaxlib.gpu_linalg as gpu_linalg  # pytype: disable=import-error

 # Jaxlib code is split between the Jax and the Tensorflow repositories.
 # Only for the internal usage of the JAX developers, we expose a version
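The removed block probed for per-vendor modules with try/except ImportError and synthesized the `solver_apis`/`sparse_apis` aliases from whichever loaded; with jaxlib >= 0.3.14 the unified `gpu_*` modules always exist, so plain imports suffice. The retired probing pattern, isolated (this sketch needs jaxlib installed to actually run):

# Old pattern (removed): tolerate jaxlibs that only ship vendor modules.
try:
  import jaxlib.cusolver as cusolver  # pytype: disable=import-error
except ImportError:
  cusolver = None
try:
  import jaxlib.hipsolver as hipsolver  # pytype: disable=import-error
except ImportError:
  hipsolver = None
solver_apis = cusolver or hipsolver or None  # whichever vendor module loaded

# New pattern: jaxlib >= 0.3.14 always provides the unified module.
import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error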
@@ -219,15 +219,12 @@ def make_gpu_client(platform_name=None):
       platform_name=platform_name)

 if hasattr(xla_client, "make_gpu_client"):
-  if xla_client._version >= 65:
-    register_backend_factory(
-        'cuda', partial(make_gpu_client, platform_name='cuda'),
-        priority=200)
-    register_backend_factory(
-        'rocm', partial(make_gpu_client, platform_name='rocm'),
-        priority=200)
-  else:
-    register_backend_factory('gpu', make_gpu_client, priority=200)
+  register_backend_factory(
+      'cuda', partial(make_gpu_client, platform_name='cuda'),
+      priority=200)
+  register_backend_factory(
+      'rocm', partial(make_gpu_client, platform_name='rocm'),
+      priority=200)

 if hasattr(xla_client, "make_tpu_client"):
   register_backend_factory(
@@ -310,16 +307,7 @@ def backends():
   for platform, priority in platforms_and_priorites:
     try:
       backend = _init_backend(platform)
-
-      if platform == "gpu" and xla_client._version <= 64:
-        # TODO(phawkins): remove this special handling when jaxlib v0.3.11
-        # is the minimum.
-        if "rocm" in backend.platform_version:
-          _backends["rocm"] = backend
-        else:
-          _backends["cuda"] = backend
-      else:
-        _backends[platform] = backend
+      _backends[platform] = backend

       if priority > default_priority:
         _default_backend = backend
@@ -38,11 +38,7 @@ from jax._src.numpy.lax_numpy import (
 import jax._src.pretty_printer as pp
 from jax._src.util import canonicalize_axis, prod

-# TODO(phawkins): make gpu_prng unconditional after jaxlib >= 0.3.11
-# becomes the minimum; remove cuda_prng and hip_prng.
 from jax._src.lib import gpu_prng
-from jax._src.lib import cuda_prng
-from jax._src.lib import hip_prng


 UINT_DTYPES = {
@@ -416,27 +412,14 @@ mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
 mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
     partial(_threefry2x32_lowering, use_rolled_loops=True),
     multiple_results=True), platform='cpu')

-if gpu_prng:
-  mlir.register_lowering(
-      threefry2x32_p,
-      partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32),
-      platform='cuda')
-  mlir.register_lowering(
-      threefry2x32_p,
-      partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32),
-      platform='rocm')
-
-if cuda_prng:
-  mlir.register_lowering(
-      threefry2x32_p,
-      partial(_threefry2x32_gpu_lowering, cuda_prng.threefry2x32_lowering),
-      platform='cuda')
-if hip_prng:
-  mlir.register_lowering(
-      threefry2x32_p,
-      partial(_threefry2x32_gpu_lowering, hip_prng.threefry2x32_lowering),
-      platform='rocm')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32),
+    platform='cuda')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32),
+    platform='rocm')


 @partial(jit, inline=True)
@@ -830,10 +830,8 @@ def _initialize_jax_jit_thread_local_state():
         dynamic_trace_state=copy)


-# TODO(phawkins): remove after minimum jaxlib version is > 0.3.11
-if lib.xla_extension_version >= 70:
-  jax_jit.set_thread_local_state_initialization_callback(
-      _initialize_jax_jit_thread_local_state)
+jax_jit.set_thread_local_state_initialization_callback(
+    _initialize_jax_jit_thread_local_state)

 def trace_state_clean() -> bool:
   trace_state = thread_local_state.trace_state
@@ -111,10 +111,7 @@ def _hash_computation(hash_obj, xla_computation):
   hash_obj.update(scrubbed_hlo)

 def _hash_compile_options(hash_obj, compile_options_obj):
-  if xla_client._version >= 68:  # Remove when minimum jaxlib version >= 0.3.11
-    expected_num_compile_options = 32
-  else:
-    expected_num_compile_options = 31
+  expected_num_compile_options = 32
   assert len(dir(compile_options_obj)) == expected_num_compile_options, (
       f"Unexpected number of CompileOption fields: "
       f"{len(dir(compile_options_obj))}. This likely means that an extra "
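The assert pins the number of `CompileOptions` attributes so that a newly added field cannot silently escape the cache key. A self-contained sketch of that defensive hashing style; the stub class and field list below are illustrative, not the real xla_client object:

import hashlib

class CompileOptionsStub:
  """Stub with a fixed, known set of cache-relevant fields."""
  num_replicas = 1
  num_partitions = 1
  profile_version = 0

_HASHED_FIELDS = ("num_replicas", "num_partitions", "profile_version")

def hash_options(opts):
  # A field-count assert like the one above would catch a field that was
  # added to the options object but not to _HASHED_FIELDS.
  h = hashlib.sha256()
  for field in _HASHED_FIELDS:
    h.update(str(getattr(opts, field)).encode())
  return h.hexdigest()

print(hash_options(CompileOptionsStub()))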
@@ -128,8 +125,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
   _hash_bool(hash_obj, compile_options_obj.tuple_arguments)
   _hash_int(hash_obj, compile_options_obj.num_replicas)
   _hash_int(hash_obj, compile_options_obj.num_partitions)
-  if xla_client._version >= 68:  # Remove when minimum jaxlib version >= 0.3.11
-    _hash_int(hash_obj, compile_options_obj.profile_version)
+  _hash_int(hash_obj, compile_options_obj.profile_version)
   if compile_options_obj.device_assignment is not None:
     hash_obj.update(compile_options_obj.device_assignment.serialize())

@@ -1932,12 +1932,6 @@ def _ensure_spmd_and(f):
     return f(v)
   return update

-def _ensure_supports_manual_and(f):
-  def update(v):
-    if v and not hasattr(xc.OpSharding.Type, "MANUAL"):
-      raise RuntimeError("This flag requires a version of jaxlib that supports MANUAL sharding type")
-    return f(v)
-  return update

 try:
   config.define_bool_state(
@@ -1955,7 +1949,7 @@ try:
       "the MANUAL partitioning feature of the XLA SPMD partitioner instead of "
       "sharding constraints on vectorized code. "
       "Requires experimental_xmap_spmd_lowering!"),
-    update_global_hook=_ensure_supports_manual_and(_ensure_spmd_and(_clear_compilation_cache)),
+    update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
     update_thread_local_hook=_thread_local_flag_unsupported)
   config.define_bool_state(
     name="experimental_xmap_ensure_fixed_sharding",
@@ -46,7 +46,6 @@ from jax._src.lib.mlir.dialects import mhlo
 from jax._src.numpy.setops import _unique

 from jax._src.lib import gpu_sparse
-from jax._src.lib import sparse_apis

 Dtype = Any
 Shape = Tuple[int, ...]
@@ -924,26 +923,18 @@ batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
 mlir.register_lowering(
     bcoo_dot_general_p, _bcoo_dot_general_default_lowering)

-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(bcoo_dot_general_p,
-                           partial(_bcoo_dot_general_gpu_lowering,
-                                   gpu_sparse.cuda_coo_matvec,
-                                   gpu_sparse.cuda_coo_matmat),
-                           platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(bcoo_dot_general_p,
-                           partial(_bcoo_dot_general_gpu_lowering,
-                                   gpu_sparse.rocm_coo_matvec,
-                                   gpu_sparse.rocm_coo_matmat),
-                           platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(bcoo_dot_general_p,
-                         partial(_bcoo_dot_general_gpu_lowering,
-                                 sparse_apis.coo_matvec_mhlo,
-                                 sparse_apis.coo_matmat_mhlo),
-                         platform='gpu')
+if gpu_sparse.cuda_is_supported:
+  mlir.register_lowering(bcoo_dot_general_p,
+                         partial(_bcoo_dot_general_gpu_lowering,
+                                 gpu_sparse.cuda_coo_matvec,
+                                 gpu_sparse.cuda_coo_matmat),
+                         platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(bcoo_dot_general_p,
+                         partial(_bcoo_dot_general_gpu_lowering,
+                                 gpu_sparse.rocm_coo_matvec,
+                                 gpu_sparse.rocm_coo_matmat),
+                         platform='rocm')

 #----------------------------------------------------------------------
 # bcoo_dot_general_sampled
@@ -31,7 +31,6 @@ from jax import tree_util
 from jax._src.lax.lax import _const
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.lib import gpu_sparse
-from jax._src.lib import sparse_apis
 from jax._src.numpy.lax_numpy import _promote_dtypes
 import jax.numpy as jnp

@@ -230,23 +229,16 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo):
 ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
 ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
 mlir.register_lowering(coo_todense_p, _coo_todense_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        coo_todense_p,
-        partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        coo_todense_p,
-        partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense),
-        platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      coo_todense_p,
-      partial(_coo_todense_gpu_lowering, sparse_apis.coo_todense_mhlo),
-      platform='gpu')
+if gpu_sparse.cuda_is_supported:
+  mlir.register_lowering(
+      coo_todense_p,
+      partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      coo_todense_p,
+      partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense),
+      platform='rocm')

 #--------------------------------------------------------------------
 # coo_fromdense
@@ -356,23 +348,16 @@ ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose

 mlir.register_lowering(coo_fromdense_p, _coo_fromdense_lowering)

-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        coo_fromdense_p,
-        partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        coo_fromdense_p,
-        partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense),
-        platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      coo_fromdense_p,
-      partial(_coo_fromdense_gpu_lowering, sparse_apis.coo_fromdense_mhlo),
-      platform='gpu')
+if gpu_sparse.cuda_is_supported:
+  mlir.register_lowering(
+      coo_fromdense_p,
+      partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      coo_fromdense_p,
+      partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense),
+      platform='rocm')

 #--------------------------------------------------------------------
 # coo_matvec
@@ -486,23 +471,17 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose):
 ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
 ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
 mlir.register_lowering(coo_matvec_p, _coo_matvec_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        coo_matvec_p,
-        partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        coo_matvec_p,
-        partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec),
-        platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      coo_matvec_p,
-      partial(_coo_matvec_gpu_lowering, sparse_apis.coo_matvec_mhlo),
-      platform='gpu')
+if gpu_sparse.cuda_is_supported:
+  mlir.register_lowering(
+      coo_matvec_p,
+      partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      coo_matvec_p,
+      partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec),
+      platform='rocm')


 #--------------------------------------------------------------------
 # coo_matmat
@@ -611,20 +590,13 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose):
 ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
 ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
 mlir.register_lowering(coo_matmat_p, _coo_matmat_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        coo_matmat_p,
-        partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        coo_matmat_p,
-        partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat),
-        platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      coo_matmat_p,
-      partial(_coo_matmat_gpu_lowering, sparse_apis.coo_matmat_mhlo),
-      platform='gpu')
+if gpu_sparse.cuda_is_supported:
+  mlir.register_lowering(
+      coo_matmat_p,
+      partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      coo_matmat_p,
+      partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat),
+      platform='rocm')
@@ -31,7 +31,6 @@ from jax import lax
 from jax import tree_util
 from jax._src.lax.lax import _const
 from jax._src.lib import gpu_sparse
-from jax._src.lib import sparse_apis
 from jax._src.numpy.lax_numpy import _promote_dtypes
 import jax.numpy as jnp

@@ -241,23 +240,17 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
 ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
 ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
 mlir.register_lowering(csr_todense_p, _csr_todense_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        csr_todense_p,
-        partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        csr_todense_p,
-        partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense),
-        platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      csr_todense_p,
-      partial(_csr_todense_gpu_lowering, sparse_apis.csr_todense_mhlo),
-      platform='gpu')
+if gpu_sparse.cuda_is_supported:
+  mlir.register_lowering(
+      csr_todense_p,
+      partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      csr_todense_p,
+      partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense),
+      platform='rocm')


 #--------------------------------------------------------------------
 # csr_fromdense
@@ -349,23 +342,16 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
 ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
 ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
 mlir.register_lowering(csr_fromdense_p, _csr_fromdense_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        csr_fromdense_p,
-        partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        csr_fromdense_p,
-        partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense),
-        platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      csr_fromdense_p,
-      partial(_csr_fromdense_gpu_lowering, sparse_apis.csr_fromdense_mhlo),
-      platform='gpu')
+if gpu_sparse.cuda_is_supported:
+  mlir.register_lowering(
+      csr_fromdense_p,
+      partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      csr_fromdense_p,
+      partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense),
+      platform='rocm')

 #--------------------------------------------------------------------
 # csr_matvec
@@ -446,23 +432,16 @@ ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
 ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
 mlir.register_lowering(csr_matvec_p, _csr_matvec_lowering)

-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        csr_matvec_p,
-        partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        csr_matvec_p,
-        partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec),
-        platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      csr_matvec_p,
-      partial(_csr_matvec_gpu_lowering, sparse_apis.csr_matvec_mhlo),
-      platform='gpu')
+if gpu_sparse.cuda_is_supported:
+  mlir.register_lowering(
+      csr_matvec_p,
+      partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      csr_matvec_p,
+      partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec),
+      platform='rocm')


 #--------------------------------------------------------------------
@@ -555,9 +534,3 @@ if gpu_sparse:
       csr_matmat_p,
       partial(_csr_matmat_gpu_lowering, gpu_sparse.rocm_csr_matmat),
       platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      csr_matmat_p,
-      partial(_csr_matmat_gpu_lowering, sparse_apis.csr_matmat_mhlo),
-      platform='gpu')
@@ -226,10 +226,7 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types
     x = x.view(np.uint16)
   x = np.ascontiguousarray(x)
   attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
-  if jax._src.lib.mlir_api_version < 21:
-    return (mhlo.ConstOp(attr).result,)
-  else:
-    return (mhlo.ConstantOp(attr).result,)
+  return (mhlo.ConstantOp(attr).result,)


@@ -1371,9 +1368,6 @@ def emit_python_callback(
     has_side_effect: bool) -> Tuple[List[ir.Value], Any, Any]:
   """Creates an MHLO `CustomCallOp` that calls back to the provided function."""
   platform = ctx.module_context.platform
-  if platform in {"cuda", "rocm"} and jax._src.lib.version < (0, 3, 11):
-    raise ValueError(
-        "`EmitPythonCallback` on CUDA only supported on jaxlib >= 0.3.11")
   if platform in {"tpu"} and jax._src.lib.version < (0, 3, 15):
     raise ValueError(
         "`EmitPythonCallback` on TPU only supported on jaxlib >= 0.3.15")
@@ -1732,13 +1732,8 @@ def _mhlo_shard(aval, axis_env, xs, in_axis):
     idxs.insert(in_axis, _unravel_index_mhlo(axis_env))
     dims_unsqueezed = dims.copy()
     dims_unsqueezed.insert(in_axis, 1)
-    if jax._src.lib.mlir_api_version < 13:
-      dynamic_slice_result = mhlo.DynamicSliceOp(
-          mlir.aval_to_ir_type(aval.update(shape=dims_unsqueezed)),
-          x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
-    else:
-      dynamic_slice_result = mhlo.DynamicSliceOp(
-          x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
+    dynamic_slice_result = mhlo.DynamicSliceOp(
+        x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
     return [
         mhlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
     ]
@@ -1764,13 +1759,8 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
     padded = mlir.full_like_aval(0, padded_aval)
     zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
     idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
-    if jax._src.lib.mlir_api_version < 9:
-      broadcast_result = mhlo.BroadcastOp(
-          mlir.aval_to_ir_type(aval.update(shape=[1] + dims)), x,
-          mlir.dense_int_elements([1])).result
-    else:
-      broadcast_result = mhlo.BroadcastOp(
-          x, mlir.dense_int_elements([1])).result
+    broadcast_result = mhlo.BroadcastOp(
+        x, mlir.dense_int_elements([1])).result
     padded = mhlo.DynamicUpdateSliceOp(
         padded.type, padded, broadcast_result, idxs).result
     replica_groups = mlir.dense_int_elements(
@@ -16,7 +16,7 @@
 # eval()-ed by setup.py, so it should not have any dependencies.

 __version__ = "0.3.15"
-_minimum_jaxlib_version = "0.3.10"
+_minimum_jaxlib_version = "0.3.14"

 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
@@ -15,7 +15,6 @@
 # Shims that allow the XLA CPU backend to call scipy-provided LAPACK kernels
 # via CustomCallWithLayout.

-import jax
 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.mhlo as mhlo

@@ -29,54 +28,17 @@ for _name, _value in _lapack.registrations().items():
   xla_client.register_custom_call_target(_name, _value, platform="cpu")


-if xla_client._version >= 64:
-  def _mhlo_u8(x):
-    if jax._src.lib.mlir_api_version < 21:
-      return mhlo.ConstOp(
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.uint8),
-              type=ir.IntegerType.get_unsigned(8))).result
-    else:
-      return mhlo.ConstantOp(
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.uint8),
-              type=ir.IntegerType.get_unsigned(8))).result
-
-  def _mhlo_s32(x):
-    if jax._src.lib.mlir_api_version < 21:
-      return mhlo.ConstOp(
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.int32),
-              type=ir.IntegerType.get_signless(32))).result
-    else:
-      return mhlo.ConstantOp(
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.int32),
-              type=ir.IntegerType.get_signless(32))).result
-else:
-  def _mhlo_u8(x):
-    typ = ir.RankedTensorType.get([], ir.IntegerType.get_unsigned(8))
-    if jax._src.lib.mlir_api_version < 21:
-      return mhlo.ConstOp(
-          typ,
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.uint8), type=typ.element_type)).result
-    else:
-      return mhlo.ConstantOp(
-          typ,
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.uint8), type=typ.element_type)).result
-
-  def _mhlo_s32(x):
-    typ = ir.RankedTensorType.get([], ir.IntegerType.get_signless(32))
-    if jax._src.lib.mlir_api_version < 21:
-      return mhlo.ConstOp(typ,
-                          ir.DenseElementsAttr.get(np.array(
-                              x, dtype=np.int32))).result
-    else:
-      return mhlo.ConstantOp(
-          typ, ir.DenseElementsAttr.get(np.array(x, dtype=np.int32))).result
+def _mhlo_u8(x):
+  return mhlo.ConstantOp(
+      ir.DenseElementsAttr.get(
+          np.array(x, dtype=np.uint8),
+          type=ir.IntegerType.get_unsigned(8))).result
+
+def _mhlo_s32(x):
+  return mhlo.ConstantOp(
+      ir.DenseElementsAttr.get(
+          np.array(x, dtype=np.int32),
+          type=ir.IntegerType.get_signless(32))).result

 # TODO(phawkins): it would be nice to avoid duplicating code for each type.

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import jax
 from typing import List

 import jaxlib.mlir.ir as ir
@@ -129,51 +128,17 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
     raise ValueError(f"Unknown output type {out_dtype}")

   if 0 in a_type.shape or 0 in out_shape:
-    if xla_client._version >= 64:
-      if jax._src.lib.mlir_api_version < 21:
-        zero = mhlo.ConstOp(
-            ir.DenseElementsAttr.get(
-                np.array(0, dtype=out_dtype), type=out_type))
-      else:
-        zero = mhlo.ConstantOp(
-            ir.DenseElementsAttr.get(
-                np.array(0, dtype=out_dtype), type=out_type))
-    else:
-      if jax._src.lib.mlir_api_version < 21:
-        zero = mhlo.ConstOp(
-            ir.RankedTensorType.get([], out_type),
-            ir.DenseElementsAttr.get(
-                np.array(0, dtype=out_dtype), type=out_type))
-      else:
-        zero = mhlo.ConstantOp(
-            ir.RankedTensorType.get([], out_type),
-            ir.DenseElementsAttr.get(
-                np.array(0, dtype=out_dtype), type=out_type))
+    zero = mhlo.ConstantOp(
+        ir.DenseElementsAttr.get(
+            np.array(0, dtype=out_dtype), type=out_type))
     return mhlo.BroadcastOp(
         zero,
         ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result

   u8_type = ir.IntegerType.get_unsigned(8)
-  if xla_client._version >= 64:
-    if jax._src.lib.mlir_api_version < 21:
-      descriptor = mhlo.ConstOp(
-          ir.DenseElementsAttr.get(
-              np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
-    else:
-      descriptor = mhlo.ConstantOp(
-          ir.DenseElementsAttr.get(
-              np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
-  else:
-    if jax._src.lib.mlir_api_version < 21:
-      descriptor = mhlo.ConstOp(
-          ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
-          ir.DenseElementsAttr.get(
-              np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
-    else:
-      descriptor = mhlo.ConstantOp(
-          ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
-          ir.DenseElementsAttr.get(
-              np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
+  descriptor = mhlo.ConstantOp(
+      ir.DenseElementsAttr.get(
+          np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
   layout = tuple(range(n - 1, -1, -1))
   return custom_call(
       "pocketfft",
@@ -654,9 +654,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
         ".*while trying to hash an object of type "
        "<class 'numpy\\.ndarray'>, 1. The error was:\nTypeError: "
        "unhashable type: 'numpy\\.ndarray'")
-    # Typo was fixed in newer jaxlib
-    if jax._src.lib.xla_extension_version < 66:
-      msg = msg.replace('occurred', 'occured')

     with self.assertRaisesRegex(ValueError, msg):
       jitted_f(1, np.asarray(1))
@@ -9167,13 +9164,11 @@ class DynamicShapeTest(jtu.JaxTestCase):
     self.assertAllClose(ans, expected, check_dtypes=False)

   @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
-  @unittest.skipIf(jaxlib_version < (0, 3, 11), "test requires jaxlib>=0.3.11")
   def test_jit_of_broadcast(self):
     x = jax.jit(jnp.ones)(3)
     self.assertAllClose(x, jnp.ones(3))

   @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
-  @unittest.skipIf(jaxlib_version < (0, 3, 11), "test requires jaxlib>=0.3.11")
   def test_jit_of_broadcast2(self):
     x = jax.jit(lambda n: jnp.ones(2 * n))(3)
     self.assertAllClose(x, jnp.ones(2 * 3))
@@ -49,13 +49,9 @@ def setUpModule():
 def tearDownModule():
   prev_xla_flags()

-# TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum
-# version is >= 0.3.11
 # TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum
 # version is >= 0.3.15
 disabled_backends = []
-if jaxlib.version < (0, 3, 11):
-  disabled_backends.append("gpu")
 if jaxlib.version < (0, 3, 15):
   disabled_backends.append("tpu")

@@ -58,13 +58,9 @@ def setUpModule():
 def tearDownModule():
   prev_xla_flags()

-# TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum
-# version is >= 0.3.11
 # TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum
 # version is >= 0.3.15
 disabled_backends = []
-if jaxlib.version < (0, 3, 11):
-  disabled_backends.append("gpu")
 if jaxlib.version < (0, 3, 15):
   disabled_backends.append("tpu")

@@ -35,8 +35,6 @@ except ImportError:
 config.parse_flags_with_absl()


-@unittest.skipIf(jax._src.lib.xla_extension_version < 73,
-                 "Test requires jaxlib 0.3.12 or newer.")
 @unittest.skipIf(not portpicker, "Test requires portpicker")
 class DistributedTest(jtu.JaxTestCase):

@@ -59,13 +59,9 @@ lcf.allowed_effects.add('while')
 lcf.allowed_effects.add('while1')
 lcf.allowed_effects.add('while2')

-# TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum
-# version is >= 0.3.11
 # TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum
 # version is >= 0.3.15
 disabled_backends = []
-if jaxlib.version < (0, 3, 11):
-  disabled_backends.append('gpu')
 if jaxlib.version < (0, 3, 15):
   disabled_backends.append('tpu')

@@ -183,9 +183,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
                  else ["lu", "qr"])
   ))
   def testSlogdet(self, shape, dtype, method):
-    if (method == 'qr' and jax._src.lib.xla_extension_version < 69 and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest('qr decomposition is not supported.')
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
     slogdet = partial(jnp.linalg.slogdet, method=method)
@@ -1436,9 +1436,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
                           s2_shape, s3_shape, s4_shape):
     # Disable on SE runtime type because XLA sharding propagation is not
     # supported.
-    if xla_client._version < 74 or xla_bridge.get_backend().runtime_type == 'se':
-      raise unittest.SkipTest('Needs xla_extension_version >= 74 or '
-                              'TFRT runtime.')
+    if xla_bridge.get_backend().runtime_type == 'se':
+      raise unittest.SkipTest('Needs TFRT runtime.')
     global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
     global_input_shape = (8, 2)

@@ -2014,10 +2014,6 @@ class CppPmapTest(PythonPmapTest):
     self.assertEqual(f._cache_size(), size+1)

   def test_cache_hits_across_threads(self):
-    # TODO(phawkins): remove after minimum jaxlib version is > 0.3.11
-    if xla_bridge.xla_client._version < 70:
-      raise unittest.SkipTest("This test requires jaxlib version >= 0.3.11")
-
     f = lambda x: x+1
     inputs = np.zeros([jax.device_count()], dtype=np.float32)
     pmaped_f = self.pmap(f)
@@ -2035,10 +2031,6 @@ class CppPmapTest(PythonPmapTest):
     self.assertEqual(pmaped_f._cache_size, 1)

   def test_cache_uses_jax_key(self):
-    # TODO(parkers): remove after minimum jaxlib version is > 0.3.11
-    if xla_bridge.xla_client._version < 74:
-      raise unittest.SkipTest("This test requires jaxlib version >= 0.3.11")
-
     f = lambda x: x+1
     inputs = np.zeros([jax.device_count()], dtype=np.float32)
     pmaped_f = self.pmap(f)
@@ -34,7 +34,6 @@ from jax.experimental.sparse import bcoo as sparse_bcoo
 from jax.experimental.sparse.bcoo import BCOOInfo
 from jax import lax
 from jax._src.lib import gpu_sparse
-from jax._src.lib import sparse_apis
 from jax._src.lib import xla_bridge
 from jax import jit
 from jax import tree_util
@@ -57,11 +56,8 @@ MATMUL_TOL = {
     np.complex128: 1E-10,
 }

-if gpu_sparse:
-  GPU_LOWERING_ENABLED = gpu_sparse and (gpu_sparse.cuda_is_supported or
-                                         gpu_sparse.rocm_is_supported)
-else:
-  GPU_LOWERING_ENABLED = (sparse_apis and sparse_apis.is_supported)
+GPU_LOWERING_ENABLED = gpu_sparse and (gpu_sparse.cuda_is_supported or
+                                       gpu_sparse.rocm_is_supported)

 class BcooDotGeneralProperties(NamedTuple):
   lhs_shape: Tuple[int]
@@ -516,24 +512,15 @@ class cuSparseTest(jtu.JaxTestCase):
       cuda_version = None if version == "<unknown>" else int(
           version.split()[-1])
       if cuda_version is None or cuda_version < 11000:
-        if gpu_sparse:
-          self.assertFalse(gpu_sparse and gpu_sparse.cuda_is_supported)
-        else:
-          self.assertFalse(sparse_apis and sparse_apis.is_supported)
+        self.assertFalse(gpu_sparse and gpu_sparse.cuda_is_supported)
         self.assertNotIn(sparse.csr_todense_p,
                          mlir._platform_specific_lowerings["cuda"])
       else:
-        if gpu_sparse:
-          self.assertTrue(gpu_sparse and gpu_sparse.cuda_is_supported)
-        else:
-          self.assertTrue(sparse_apis and sparse_apis.is_supported)
+        self.assertTrue(gpu_sparse and gpu_sparse.cuda_is_supported)
         self.assertIn(sparse.csr_todense_p,
                       mlir._platform_specific_lowerings["cuda"])
     else:
-      if gpu_sparse:
-        self.assertTrue(gpu_sparse and gpu_sparse.rocm_is_supported)
-      else:
-        self.assertTrue(sparse_apis and sparse_apis.is_supported)
+      self.assertTrue(gpu_sparse and gpu_sparse.rocm_is_supported)
       self.assertIn(sparse.csr_todense_p,
                     mlir._platform_specific_lowerings["rocm"])

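For reference, the CUDA check in this test parses the numeric suffix of the backend's `platform_version` string; a standalone sketch with an illustrative value:

# "cuda 11040" is an illustrative platform_version string; "<unknown>"
# means the backend could not report a parsable version.
version = "cuda 11040"
cuda_version = None if version == "<unknown>" else int(version.split()[-1])
assert cuda_version == 11040 and cuda_version >= 11000  # CUDA >= 11.0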