Update minimum jaxlib version to 0.3.14.

Peter Hawkins 2022-07-08 00:21:16 +00:00
parent 44bd311ae7
commit 0b4b0ba072
32 changed files with 231 additions and 699 deletions
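
Every hunk below removes the same kind of guard: a branch on the installed jaxlib (via jaxlib.version, xla_client._version, xla_extension_version, or mlir_api_version) whose fallback is dead once jaxlib 0.3.14 is the minimum. A sketch of the pattern, with hypothetical new_path/old_path helpers standing in for the kept and deleted branches:

from jax._src import lib  # lib.version is the installed jaxlib version tuple

def new_path():  # stands in for the branch this commit keeps
  print("modern jaxlib code path")

def old_path():  # stands in for the fallback this commit deletes
  print("compatibility path for old jaxlib")

# Before: branch on the runtime jaxlib version.
if lib.version >= (0, 3, 14):
  new_path()
else:
  old_path()  # unreachable once 0.3.14 is the minimum

# After: the guard and the dead branch are gone.
new_path()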

View File

@@ -59,7 +59,6 @@ from jax._src.lib import jax_jit
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
-from jax._src.lib import xla_extension_version
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import broadcast_prefix
 from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
@@ -585,9 +584,8 @@ def _cpp_jit(
     return _BackendAndDeviceInfo(default_device, committed_to_device)
   jitted_f_kwargs = {}
-  if xla_extension_version >= 71:
-    jitted_f_kwargs["has_explicit_device"] = (
-        device is not None or backend is not None)
+  jitted_f_kwargs["has_explicit_device"] = (
+      device is not None or backend is not None)
   cpp_jitted_f = jax_jit.jit(
       fun,
       cache_miss,

View File

@@ -109,9 +109,8 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
   return result
 mlir.register_lowering(debug_callback_p, debug_callback_lowering,
                        platform="cpu")
-if jaxlib.version >= (0, 3, 11):
-  mlir.register_lowering(
-      debug_callback_p, debug_callback_lowering, platform="gpu")
+mlir.register_lowering(
+    debug_callback_p, debug_callback_lowering, platform="gpu")
 if jaxlib.version >= (0, 3, 15):
   mlir.register_lowering(
       debug_callback_p, debug_callback_lowering, platform="tpu")
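
For orientation, debug_callback_p backs jax._src.debugging.debug_callback, which stages a host-side Python call into compiled code. A minimal usage sketch, assuming the debug_callback(callback, *args) signature from this module:

import jax
import jax.numpy as jnp
from jax._src import debugging

def report(x):
  print("saw value:", x)

@jax.jit
def f(x):
  # Staged host callback; lowered through the registrations shown above
  # (cpu and gpu unconditionally, tpu only with jaxlib >= 0.3.15).
  debugging.debug_callback(report, x)
  return x + 1

f(jnp.arange(4.))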

View File

@@ -66,22 +66,14 @@ class State:
     if self.service is not None:
       raise RuntimeError('distributed.initialize should only be called once.')
     logging.info('Starting JAX distributed service on %s', coordinator_address)
-    if xla_client._version >= 72:
-      self.service = xla_extension.get_distributed_runtime_service(
-          coordinator_address, num_processes, config.jax_coordination_service)
-    else:
-      self.service = xla_extension.get_distributed_runtime_service(
-          coordinator_address, num_processes)
+    self.service = xla_extension.get_distributed_runtime_service(
+        coordinator_address, num_processes, config.jax_coordination_service)
     if self.client is not None:
       raise RuntimeError('distributed.initialize should only be called once.')
-    if xla_client._version >= 72:
-      self.client = xla_extension.get_distributed_runtime_client(
-          coordinator_address, process_id, config.jax_coordination_service)
-    else:
-      self.client = xla_extension.get_distributed_runtime_client(
-          coordinator_address, process_id)
+    self.client = xla_extension.get_distributed_runtime_client(
+        coordinator_address, process_id, config.jax_coordination_service)
     logging.info('Connecting to JAX distributed service on %s', coordinator_address)
     self.client.connect()
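
These branches back jax.distributed.initialize, which starts the coordination service once on process 0 and connects every process to it. A minimal two-process sketch (address and indices illustrative):

import jax

# Run on each participating process; process 0 also hosts the service.
jax.distributed.initialize(coordinator_address="10.0.0.1:1234",
                           num_processes=2,
                           process_id=0)  # pass this process's own index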

View File

@@ -132,8 +132,6 @@ def approx_max_k(operand: Array,
     >>> db = jax.numpy.array(np.random.rand(1024, 64))
     >>> dot_products, neighbors = mips(qy, db, k=10)
   """
-  if xc._version < 45:
-    aggregate_to_topk = True
   return approx_top_k_p.bind(
       operand,
       k=k,
@@ -197,8 +195,6 @@ def approx_min_k(operand: Array,
     ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
     arithmetics and produces the same set of neighbors.
   """
-  if xc._version < 45:
-    aggregate_to_topk = True
   return approx_top_k_p.bind(
       operand,
       k=k,
@@ -225,13 +221,10 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
                      dims[reduction_dimension], k))
   if not dtypes.issubdtype(operand.dtype, np.floating):
     raise ValueError('operand must be a floating type')
-  if xc._version >= 45:
-    reduction_input_size = dims[reduction_dimension]
-    dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
-        reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
-        reduction_input_size_override)[0]
-  else:
-    dims[reduction_dimension] = k
+  reduction_input_size = dims[reduction_dimension]
+  dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
+      reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
+      reduction_input_size_override)[0]
   return (operand.update(
       shape=dims, dtype=operand.dtype, weak_type=operand.weak_type),
           operand.update(shape=dims, dtype=np.dtype(np.int32)))

View File

@@ -3070,17 +3070,10 @@ masking.masking_rules[pad_p] = _pad_masking_rule
 def _pad_lower(ctx, x, padding_value, *, padding_config):
   low, high, interior = util.unzip3(padding_config)
-  if jax._src.lib.mlir_api_version < 15:
-    aval_out, = ctx.avals_out
-    return mhlo.PadOp(mlir.aval_to_ir_type(aval_out), x, padding_value,
-                      mlir.dense_int_elements(low),
-                      mlir.dense_int_elements(high),
-                      mlir.dense_int_elements(interior)).results
-  else:
-    return mhlo.PadOp(x, padding_value,
-                      mlir.dense_int_elements(low),
-                      mlir.dense_int_elements(high),
-                      mlir.dense_int_elements(interior)).results
+  return mhlo.PadOp(x, padding_value,
+                    mlir.dense_int_elements(low),
+                    mlir.dense_int_elements(high),
+                    mlir.dense_int_elements(interior)).results
 mlir.register_lowering(pad_p, _pad_lower)
@@ -3817,13 +3810,8 @@ masking.defvectorized(reduce_precision_p)
 def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
   aval_out, = ctx.avals_out
-  if jax._src.lib.mlir_api_version >= 21:
-    return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
-                                  mlir.i32_attr(mantissa_bits)).results
-  else:
-    return mhlo.ReducePrecisionOp(mlir.aval_to_ir_type(aval_out), operand,
-                                  mlir.i32_attr(exponent_bits),
-                                  mlir.i32_attr(mantissa_bits)).results
+  return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
+                                mlir.i32_attr(mantissa_bits)).results
 mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
@@ -4059,12 +4047,9 @@ top_k_p = Primitive('top_k')
 top_k_p.multiple_results = True
 top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
 top_k_p.def_abstract_eval(_top_k_abstract_eval)
-if jax._src.lib.mlir_api_version >= 16:
-  def _top_k_lower(ctx, operand, k):
-    return chlo.TopKOp(operand, mlir.i64_attr(k)).results
-  mlir.register_lowering(top_k_p, _top_k_lower)
-else:
-  xla.register_translation(top_k_p, _top_k_translation_rule)
+def _top_k_lower(ctx, operand, k):
+  return chlo.TopKOp(operand, mlir.i64_attr(k)).results
+mlir.register_lowering(top_k_p, _top_k_lower)
 ad.primitive_jvps[top_k_p] = _top_k_jvp
 batching.primitive_batchers[top_k_p] = _top_k_batch_rule
@@ -4315,9 +4300,7 @@ def _rng_bit_generator_lowering(
     key = mhlo.BitcastConvertOp(
         ir.RankedTensorType.get([2], u64_type),
         mhlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
-  algorithm_attr = (
-      _rng_algorithm(algorithm) if jax._src.lib.mlir_api_version >= 14
-      else mlir.i32_attr(algorithm))
+  algorithm_attr = _rng_algorithm(algorithm)
   out_key, out_vals = mhlo.RngBitGeneratorOp(
       key.type,
       ir.RankedTensorType.get(shape, rbg_etype),

View File

@@ -46,11 +46,6 @@ from jax._src.lib import gpu_linalg
 from jax._src.lib import gpu_solver
 from jax._src.lib import gpu_sparse
-from jax._src.lib import cuda_linalg
-from jax._src.lib import hip_linalg
-from jax._src.lib import sparse_apis
-from jax._src.lib import solver_apis
 from jax._src.lib import xla_client
 from jax._src.lib.mlir import ir
@@ -414,12 +409,7 @@ ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
 batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule
 def _cholesky_lowering(ctx, x):
-  if jax._src.lib.mlir_api_version < 18:
-    aval, = ctx.avals_out
-    return mhlo.CholeskyOp(mlir.aval_to_ir_type(aval), x,
-                           lower=ir.BoolAttr.get(True)).results
-  else:
-    return mhlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
+  return mhlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
 mlir.register_lowering(cholesky_p, _cholesky_lowering)
@@ -442,22 +432,14 @@ mlir.register_lowering(
     cholesky_p,
     partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo),
     platform='cpu')
-if gpu_solver is not None:
-  mlir.register_lowering(
-      cholesky_p,
-      partial(_cholesky_cpu_gpu_lowering, gpu_solver.cuda_potrf),
-      platform='cuda')
-  mlir.register_lowering(
-      cholesky_p,
-      partial(_cholesky_cpu_gpu_lowering, gpu_solver.rocm_potrf),
-      platform='rocm')
-if solver_apis is not None:
-  mlir.register_lowering(
-      cholesky_p,
-      partial(_cholesky_cpu_gpu_lowering, solver_apis.potrf_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    cholesky_p,
+    partial(_cholesky_cpu_gpu_lowering, gpu_solver.cuda_potrf),
+    platform='cuda')
+mlir.register_lowering(
+    cholesky_p,
+    partial(_cholesky_cpu_gpu_lowering, gpu_solver.rocm_potrf),
+    platform='rocm')
 # Asymmetric eigendecomposition
@@ -756,12 +738,6 @@ if gpu_solver is not None:
       eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd),
      platform='rocm')
-if solver_apis is not None:
-  mlir.register_lowering(
-      eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo),
-      platform='gpu')
 mlir.register_lowering(
     eigh_p, mlir.lower_fun(_eigh_tpu_impl, multiple_results=True),
     platform='tpu')
@@ -948,21 +924,14 @@ def _triangular_solve_gpu_lower(
       ir.BoolAttr.get(unit_diagonal),
       mhlo.TransposeAttr.get(transpose)).results
-if gpu_solver is not None:
-  mlir.register_lowering(
-      triangular_solve_p,
-      partial(_triangular_solve_gpu_lower, gpu_solver.cuda_trsm),
-      platform='cuda')
-  mlir.register_lowering(
-      triangular_solve_p,
-      partial(_triangular_solve_gpu_lower, gpu_solver.rocm_trsm),
-      platform='rocm')
-if solver_apis is not None:
-  mlir.register_lowering(
-      triangular_solve_p,
-      partial(_triangular_solve_gpu_lower, solver_apis.trsm_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    triangular_solve_p,
+    partial(_triangular_solve_gpu_lower, gpu_solver.cuda_trsm),
+    platform='cuda')
+mlir.register_lowering(
+    triangular_solve_p,
+    partial(_triangular_solve_gpu_lower, gpu_solver.rocm_trsm),
+    platform='rocm')
 # Support operation for LU decomposition: Transformation of the pivots returned
@@ -1052,31 +1021,16 @@ batching.primitive_batchers[lu_pivots_to_permutation_p] = (
 mlir.register_lowering(
     lu_pivots_to_permutation_p,
     mlir.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False))
-if gpu_linalg:
-  mlir.register_lowering(
-      lu_pivots_to_permutation_p,
-      partial(_lu_pivots_to_permutation_gpu_lowering,
-              gpu_linalg.cuda_lu_pivots_to_permutation),
-      platform='cuda')
-  mlir.register_lowering(
-      lu_pivots_to_permutation_p,
-      partial(_lu_pivots_to_permutation_gpu_lowering,
-              gpu_linalg.hip_lu_pivots_to_permutation),
-      platform='rocm')
-if cuda_linalg:
-  mlir.register_lowering(lu_pivots_to_permutation_p,
-                         partial(_lu_pivots_to_permutation_gpu_lowering,
-                                 cuda_linalg.lu_pivots_to_permutation_mhlo),
-                         platform='cuda')
-if hip_linalg:
-  mlir.register_lowering(lu_pivots_to_permutation_p,
-                         partial(_lu_pivots_to_permutation_gpu_lowering,
-                                 hip_linalg.lu_pivots_to_permutation_mhlo),
-                         platform='rocm')
+mlir.register_lowering(
+    lu_pivots_to_permutation_p,
+    partial(_lu_pivots_to_permutation_gpu_lowering,
+            gpu_linalg.cuda_lu_pivots_to_permutation),
+    platform='cuda')
+mlir.register_lowering(
+    lu_pivots_to_permutation_p,
+    partial(_lu_pivots_to_permutation_gpu_lowering,
+            gpu_linalg.hip_lu_pivots_to_permutation),
+    platform='rocm')
 # LU decomposition
@@ -1273,18 +1227,12 @@ mlir.register_lowering(lu_p,
                        partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo),
                        platform='cpu')
-if gpu_solver is not None:
-  mlir.register_lowering(
-      lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf),
-      platform='cuda')
-  mlir.register_lowering(
-      lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf),
-      platform='rocm')
-if solver_apis is not None:
-  mlir.register_lowering(
-      lu_p, partial(_lu_cpu_gpu_lowering, solver_apis.getrf_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf),
+    platform='cuda')
+mlir.register_lowering(
+    lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf),
+    platform='rocm')
 xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu')
@@ -1418,25 +1366,16 @@ xla.register_translation(geqrf_p, _geqrf_translation_rule)
 mlir.register_lowering(
     geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo, None),
     platform='cpu')
-if gpu_solver is not None:
-  # TODO(phawkins): make cuda_geqrf_batched and rocm_geqrf_unbatched
-  # unconditional when jaxlib 0.3.11 is the minimum.
-  mlir.register_lowering(
-      geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
-              getattr(gpu_solver, 'cuda_geqrf_batched', None)),
-      platform='cuda')
-  mlir.register_lowering(
-      geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
-              getattr(gpu_solver, 'rocm_geqrf_batched', None)),
-      platform='rocm')
-if solver_apis is not None:
-  mlir.register_lowering(
-      geqrf_p,
-      partial(_geqrf_cpu_gpu_lowering, solver_apis.geqrf_mhlo, None),
-      platform='gpu')
+mlir.register_lowering(
+    geqrf_p,
+    partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
+            gpu_solver.cuda_geqrf_batched),
+    platform='cuda')
+mlir.register_lowering(
+    geqrf_p,
+    partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
+            gpu_solver.rocm_geqrf_batched),
+    platform='rocm')
 # orgqr: product of elementary Householder reflectors
@@ -1505,21 +1444,14 @@ xla.register_translation(orgqr_p, _orgqr_translation_rule)
 mlir.register_lowering(
     orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo),
     platform='cpu')
-if gpu_solver is not None:
-  mlir.register_lowering(
-      orgqr_p,
-      partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
-      platform='cuda')
-  mlir.register_lowering(
-      orgqr_p,
-      partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
-      platform='rocm')
-if solver_apis is not None:
-  mlir.register_lowering(
-      orgqr_p,
-      partial(_orgqr_cpu_gpu_lowering, solver_apis.orgqr_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    orgqr_p,
+    partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
+    platform='cuda')
+mlir.register_lowering(
+    orgqr_p,
+    partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
+    platform='rocm')
 def _qr_impl(operand, *, full_matrices):
@@ -1601,11 +1533,6 @@ qr_p.multiple_results = True
 qr_p.def_impl(_qr_impl)
 qr_p.def_abstract_eval(_qr_abstract_eval)
-# Older jaxlibs didn't expose geqrf and orgqr as separate XLA operations.
-# TODO(phawkins): remove after minimum jaxlib version is > 0.3.10.
-if jax._src.lib.xla_extension_version < 69:
-  xla.register_translation(qr_p, _qr_translation_rule, platform="tpu")
 ad.primitive_jvps[qr_p] = qr_jvp_rule
 batching.primitive_batchers[qr_p] = _qr_batching_rule
@@ -1795,19 +1722,12 @@ batching.primitive_batchers[svd_p] = _svd_batching_rule
 mlir.register_lowering(
     svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
     platform='cpu')
-if gpu_solver is not None:
-  mlir.register_lowering(
-      svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
-      platform='cuda')
-  mlir.register_lowering(
-      svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd),
-      platform='rocm')
-if solver_apis is not None:
-  mlir.register_lowering(
-      svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
-      platform='gpu')
+mlir.register_lowering(
+    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
+    platform='cuda')
+mlir.register_lowering(
+    svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.rocm_gesvd),
+    platform='rocm')
 mlir.register_lowering(svd_p, _svd_tpu_lowering_rule)
@@ -1821,20 +1741,15 @@ tridiagonal_solve_p.def_impl(
     functools.partial(xla.apply_primitive, tridiagonal_solve_p))
 tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
 # TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
-if sparse_apis and hasattr(sparse_apis, "gtsv2"):
-  mlir.register_lowering(tridiagonal_solve_p,
-                         _tridiagonal_solve_gpu_lowering,
-                         platform='gpu')
-if gpu_sparse:
-  mlir.register_lowering(
-      tridiagonal_solve_p,
-      partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2),
-      platform='cuda')
-  mlir.register_lowering(
-      tridiagonal_solve_p,
-      partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2),
-      platform='rocm')
+mlir.register_lowering(
+    tridiagonal_solve_p,
+    partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.cuda_gtsv2),
+    platform='cuda')
+mlir.register_lowering(
+    tridiagonal_solve_p,
+    partial(_tridiagonal_solve_gpu_lowering, gpu_sparse.rocm_gtsv2),
+    platform='rocm')
 def _tridiagonal_solve_jax(dl, d, du, b, **kw):
@@ -1966,17 +1881,10 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
   operand_aval, = ctx.avals_in
   batch_dims = operand_aval.shape[:-2]
-  # TODO(jakevdp): remove this try/except when minimum jaxlib >= 0.3.8
-  try:
-    gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
-                                   jobvs=compute_schur_vectors,
-                                   sort=sort_eig_vals,
-                                   select=select_callable)
-  except TypeError:  # API for jaxlib <= 0.3.7
-    gees_result = lapack.gees_mhlo(operand,  # pytype: disable=missing-parameter
-                                   jobvs=compute_schur_vectors,
-                                   sort=sort_eig_vals,
-                                   select=select_callable)
+  gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
+                                 jobvs=compute_schur_vectors,
+                                 sort=sort_eig_vals,
+                                 select=select_callable)
   # Number of return values depends on value of sort_eig_vals.
   T, vs, *_, info = gees_result

View File

@@ -721,12 +721,8 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
       _replica_groups(ctx.module_context.axis_env, named_axes,
                       axis_index_groups))
   def all_reduce(x_dtype, x):
-    if jax._src.lib.mlir_api_version >= 17:
-      op = mhlo.AllReduceOp(
-          x.type, x, replica_groups=replica_groups, channel_handle=None)
-    else:
-      op = mhlo.AllReduceOp(
-          x, replica_groups=replica_groups, channel_handle=None)
+    op = mhlo.AllReduceOp(
+        x.type, x, replica_groups=replica_groups, channel_handle=None)
     scalar_aval = core.ShapedArray((), x_dtype)
     scalar_type = mlir.aval_to_ir_type(scalar_aval)
     reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
@@ -741,22 +737,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
     mhlo.ReturnOp(util.flatten(out_nodes))
     return op.result
-  outs = []
-  for aval, x in zip(ctx.avals_in, args):
-    # TODO(phawkins): remove this special case when jaxlib > 0.3.10 is the
-    # minimum.
-    if (jax._src.lib.xla_extension_version < 75 and prim is lax.add_p
-        and dtypes.issubdtype(aval.dtype, np.complexfloating)):
-      real_dtype = np.finfo(aval.dtype).dtype
-      outs.append(
-          mhlo.ComplexOp(
-              all_reduce(real_dtype,
-                         mhlo.RealOp(x).result),
-              all_reduce(real_dtype,
-                         mhlo.ImagOp(x).result)).result)
-    else:
-      outs.append(all_reduce(aval.dtype, x))
-  return outs
+  return [all_reduce(aval.dtype, x) for aval, x in zip(ctx.avals_in, args)]
 def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
@@ -947,17 +928,10 @@ def _all_to_all_lowering(ctx, x, *,
     split_count = len(replica_groups[0])
     if not all(split_count == len(g) for g in replica_groups):
       raise ValueError('Replica groups must be equally sized')
-    if jax._src.lib.mlir_api_version < 19:
-      return mhlo.AllToAllOp(mlir.aval_to_ir_type(ctx.avals_out[0]),
-                             x, mlir.i64_attr(split_axis),
-                             mlir.i64_attr(concat_axis),
-                             mlir.i64_attr(split_count),
-                             _replica_groups_mhlo(replica_groups)).results
-    else:
-      return mhlo.AllToAllOp(x, mlir.i64_attr(split_axis),
-                             mlir.i64_attr(concat_axis),
-                             mlir.i64_attr(split_count),
-                             _replica_groups_mhlo(replica_groups)).results
+    return mhlo.AllToAllOp(x, mlir.i64_attr(split_axis),
+                           mlir.i64_attr(concat_axis),
+                           mlir.i64_attr(split_count),
+                           _replica_groups_mhlo(replica_groups)).results
   else:
     warnings.warn(
         "all_to_all (and pswapaxes) are only implemented properly for TPUs and GPUs (if "

View File

@@ -1941,12 +1941,9 @@ def _scatter_lower(ctx, operand, indices, updates, *,
       inserted_window_dims=list(dnums.inserted_window_dims),
       scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
      index_vector_dim=len(ctx.avals_in[1].shape) - 1)
-  if jax._src.lib.mlir_api_version >= 20:
-    result = mlir.aval_to_ir_types(aval_out)
-    operand = [operand]
-    updates = [updates]
-  else:
-    result = mlir.aval_to_ir_type(aval_out)
+  result = mlir.aval_to_ir_types(aval_out)
+  operand = [operand]
+  updates = [updates]
   op = mhlo.ScatterOp(
       result,
       operand,
@@ -2001,19 +1998,12 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
       scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
       index_vector_dim=len(ctx.avals_in[1].shape) - 1)
   real_dtype = _real_dtype(aval_out.dtype)
-  if jax._src.lib.mlir_api_version >= 20:
-    operand_type_part = mlir.aval_to_ir_types(
-        core.ShapedArray(aval_out.shape, real_dtype))
-  else:
-    operand_type_part = mlir.aval_to_ir_type(
-        core.ShapedArray(aval_out.shape, real_dtype))
+  operand_type_part = mlir.aval_to_ir_types(
+      core.ShapedArray(aval_out.shape, real_dtype))
   def _scatter(operand_part, updates_part):
-    # If the MLIR api supports variadic scatter, we make a variadic scatter op
-    # with arity 1
-    if jax._src.lib.mlir_api_version >= 20:
-      operand_part = [operand_part]
-      updates_part = [updates_part]
+    operand_part = [operand_part]
+    updates_part = [updates_part]
     scatter = mhlo.ScatterOp(
         operand_type_part,

View File

@@ -706,16 +706,10 @@ def _select_and_gather_add_lowering(
   def pack(a, b):
     a_dims = ir.RankedTensorType(a.type).shape
     b_dims = ir.RankedTensorType(b.type).shape
-    if jax._src.lib.mlir_api_version >= 21:
-      a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
-      b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
-    else:
-      a = mhlo.ReducePrecisionOp(a.type, a, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
-      b = mhlo.ReducePrecisionOp(b.type, b, exponent_bits=mlir.i32_attr(nexp),
-                                 mantissa_bits=mlir.i32_attr(nmant))
+    a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
+                               mantissa_bits=mlir.i32_attr(nmant))
+    b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
+                               mantissa_bits=mlir.i32_attr(nmant))
     a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
     b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
     b = mhlo.ShiftRightLogicalOp(

View File

@@ -113,71 +113,10 @@ pytree = xla_client._xla.pytree
 jax_jit = xla_client._xla.jax_jit
 pmap_lib = xla_client._xla.pmap_lib
-# TODO(phawkins): make gpu_... unconditional after jaxlib >= 0.3.11
-# becomes the minimum; remove cuda_... and hip_....
-try:
-  import jaxlib.cusolver as cusolver  # pytype: disable=import-error
-except ImportError:
-  cusolver = None
-try:
-  import jaxlib.hipsolver as hipsolver  # pytype: disable=import-error
-except ImportError:
-  hipsolver = None
-try:
-  import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error
-except ImportError:
-  gpu_solver = None
-try:
-  import jaxlib.cusparse as cusparse  # pytype: disable=import-error
-except ImportError:
-  cusparse = None
-try:
-  import jaxlib.hipsparse as hipsparse  # pytype: disable=import-error
-except ImportError:
-  hipsparse = None
-try:
-  import jaxlib.gpu_sparse as gpu_sparse  # pytype: disable=import-error
-except ImportError:
-  gpu_sparse = None
-sparse_apis = cusparse or hipsparse or None
-solver_apis = cusolver or hipsolver or None
-try:
-  import jaxlib.cuda_prng as cuda_prng  # pytype: disable=import-error
-except ImportError:
-  cuda_prng = None
-try:
-  import jaxlib.hip_prng as hip_prng  # pytype: disable=import-error
-except ImportError:
-  hip_prng = None
-try:
-  import jaxlib.gpu_prng as gpu_prng  # pytype: disable=import-error
-except ImportError:
-  gpu_prng = None
-try:
-  import jaxlib.cuda_linalg as cuda_linalg  # pytype: disable=import-error
-except ImportError:
-  cuda_linalg = None
-try:
-  import jaxlib.hip_linalg as hip_linalg  # pytype: disable=import-error
-except ImportError:
-  hip_linalg = None
-try:
-  import jaxlib.gpu_linalg as gpu_linalg  # pytype: disable=import-error
-except ImportError:
-  gpu_linalg = None
+import jaxlib.gpu_solver as gpu_solver  # pytype: disable=import-error
+import jaxlib.gpu_sparse as gpu_sparse  # pytype: disable=import-error
+import jaxlib.gpu_prng as gpu_prng  # pytype: disable=import-error
+import jaxlib.gpu_linalg as gpu_linalg  # pytype: disable=import-error
 # Jaxlib code is split between the Jax and the Tensorflow repositories.
 # Only for the internal usage of the JAX developers, we expose a version

View File

@@ -219,15 +219,12 @@ def make_gpu_client(platform_name=None):
                          platform_name=platform_name)
 if hasattr(xla_client, "make_gpu_client"):
-  if xla_client._version >= 65:
-    register_backend_factory(
-        'cuda', partial(make_gpu_client, platform_name='cuda'),
-        priority=200)
-    register_backend_factory(
-        'rocm', partial(make_gpu_client, platform_name='rocm'),
-        priority=200)
-  else:
-    register_backend_factory('gpu', make_gpu_client, priority=200)
+  register_backend_factory(
+      'cuda', partial(make_gpu_client, platform_name='cuda'),
+      priority=200)
+  register_backend_factory(
+      'rocm', partial(make_gpu_client, platform_name='rocm'),
+      priority=200)
 if hasattr(xla_client, "make_tpu_client"):
   register_backend_factory(
@@ -310,16 +307,7 @@ def backends():
     for platform, priority in platforms_and_priorites:
       try:
         backend = _init_backend(platform)
-        if platform == "gpu" and xla_client._version <= 64:
-          # TODO(phawkins): remove this special handling when jaxlib v0.3.11
-          # is the minimum.
-          if "rocm" in backend.platform_version:
-            _backends["rocm"] = backend
-          else:
-            _backends["cuda"] = backend
-        else:
-          _backends[platform] = backend
+        _backends[platform] = backend
         if priority > default_priority:
           _default_backend = backend
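
Since the gpu backend is now always registered under the explicit names 'cuda' and 'rocm', a specific platform can be requested by name; a small sketch (assumes a GPU-enabled jaxlib is installed):

import jax

print(jax.devices("cuda"))  # devices served by the CUDA backend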

View File

@@ -38,11 +38,7 @@ from jax._src.numpy.lax_numpy import (
 import jax._src.pretty_printer as pp
 from jax._src.util import canonicalize_axis, prod
-# TODO(phawkins): make gpu_prng unconditional after jaxlib >= 0.3.11
-# becomes the minimum; remove cuda_prng and hip_prng.
 from jax._src.lib import gpu_prng
-from jax._src.lib import cuda_prng
-from jax._src.lib import hip_prng
 UINT_DTYPES = {
@@ -416,27 +412,14 @@ mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
 mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
     partial(_threefry2x32_lowering, use_rolled_loops=True),
     multiple_results=True), platform='cpu')
-if gpu_prng:
-  mlir.register_lowering(
-      threefry2x32_p,
-      partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32),
-      platform='cuda')
-  mlir.register_lowering(
-      threefry2x32_p,
-      partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32),
-      platform='rocm')
-if cuda_prng:
-  mlir.register_lowering(
-      threefry2x32_p,
-      partial(_threefry2x32_gpu_lowering, cuda_prng.threefry2x32_lowering),
-      platform='cuda')
-if hip_prng:
-  mlir.register_lowering(
-      threefry2x32_p,
-      partial(_threefry2x32_gpu_lowering, hip_prng.threefry2x32_lowering),
-      platform='rocm')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering, gpu_prng.cuda_threefry2x32),
+    platform='cuda')
+mlir.register_lowering(
+    threefry2x32_p,
+    partial(_threefry2x32_gpu_lowering, gpu_prng.rocm_threefry2x32),
+    platform='rocm')
 @partial(jit, inline=True)
@partial(jit, inline=True)

View File

@@ -830,10 +830,8 @@ def _initialize_jax_jit_thread_local_state():
         dynamic_trace_state=copy)
-# TODO(phawkins): remove after minimum jaxlib version is > 0.3.11
-if lib.xla_extension_version >= 70:
-  jax_jit.set_thread_local_state_initialization_callback(
-      _initialize_jax_jit_thread_local_state)
+jax_jit.set_thread_local_state_initialization_callback(
+    _initialize_jax_jit_thread_local_state)
 def trace_state_clean() -> bool:
   trace_state = thread_local_state.trace_state

View File

@@ -111,10 +111,7 @@ def _hash_computation(hash_obj, xla_computation):
   hash_obj.update(scrubbed_hlo)
 def _hash_compile_options(hash_obj, compile_options_obj):
-  if xla_client._version >= 68:  # Remove when minimum jaxlib version >= 0.3.11
-    expected_num_compile_options = 32
-  else:
-    expected_num_compile_options = 31
+  expected_num_compile_options = 32
   assert len(dir(compile_options_obj)) == expected_num_compile_options, (
       f"Unexpected number of CompileOption fields: "
      f"{len(dir(compile_options_obj))}. This likely means that an extra "
@@ -128,8 +125,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
   _hash_bool(hash_obj, compile_options_obj.tuple_arguments)
   _hash_int(hash_obj, compile_options_obj.num_replicas)
   _hash_int(hash_obj, compile_options_obj.num_partitions)
-  if xla_client._version >= 68:  # Remove when minimum jaxlib version >= 0.3.11
-    _hash_int(hash_obj, compile_options_obj.profile_version)
+  _hash_int(hash_obj, compile_options_obj.profile_version)
   if compile_options_obj.device_assignment is not None:
     hash_obj.update(compile_options_obj.device_assignment.serialize())

View File

@@ -1932,12 +1932,6 @@ def _ensure_spmd_and(f):
     return f(v)
   return update
-def _ensure_supports_manual_and(f):
-  def update(v):
-    if v and not hasattr(xc.OpSharding.Type, "MANUAL"):
-      raise RuntimeError("This flag requires a version of jaxlib that supports MANUAL sharding type")
-    return f(v)
-  return update
 try:
   config.define_bool_state(
@@ -1955,7 +1949,7 @@ try:
           "the MANUAL partitioning feature of the XLA SPMD partitioner instead of "
          "sharding constraints on vectorized code. "
          "Requires experimental_xmap_spmd_lowering!"),
-      update_global_hook=_ensure_supports_manual_and(_ensure_spmd_and(_clear_compilation_cache)),
+      update_global_hook=_ensure_spmd_and(_clear_compilation_cache),
       update_thread_local_hook=_thread_local_flag_unsupported)
   config.define_bool_state(
     name="experimental_xmap_ensure_fixed_sharding",

View File

@@ -46,7 +46,6 @@ from jax._src.lib.mlir.dialects import mhlo
 from jax._src.numpy.setops import _unique
 from jax._src.lib import gpu_sparse
-from jax._src.lib import sparse_apis
 Dtype = Any
 Shape = Tuple[int, ...]
@@ -924,26 +923,18 @@ batching.primitive_batchers[bcoo_dot_general_p] = _bcoo_dot_general_batch_rule
 mlir.register_lowering(
     bcoo_dot_general_p, _bcoo_dot_general_default_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(bcoo_dot_general_p,
-                           partial(_bcoo_dot_general_gpu_lowering,
-                                   gpu_sparse.cuda_coo_matvec,
-                                   gpu_sparse.cuda_coo_matmat),
-                           platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(bcoo_dot_general_p,
-                           partial(_bcoo_dot_general_gpu_lowering,
-                                   gpu_sparse.rocm_coo_matvec,
-                                   gpu_sparse.rocm_coo_matmat),
-                           platform='rocm')
-if sparse_apis and sparse_apis.is_supported:
+if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(bcoo_dot_general_p,
-                         partial(_bcoo_dot_general_gpu_lowering,
-                                 sparse_apis.coo_matvec_mhlo,
-                                 sparse_apis.coo_matmat_mhlo),
-                         platform='gpu')
+                         partial(_bcoo_dot_general_gpu_lowering,
+                                 gpu_sparse.cuda_coo_matvec,
+                                 gpu_sparse.cuda_coo_matmat),
+                         platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(bcoo_dot_general_p,
+                         partial(_bcoo_dot_general_gpu_lowering,
+                                 gpu_sparse.rocm_coo_matvec,
+                                 gpu_sparse.rocm_coo_matmat),
+                         platform='rocm')
 #----------------------------------------------------------------------
 # bcoo_dot_general_sampled

View File

@@ -31,7 +31,6 @@ from jax import tree_util
 from jax._src.lax.lax import _const
 from jax._src.lib.mlir.dialects import mhlo
 from jax._src.lib import gpu_sparse
-from jax._src.lib import sparse_apis
 from jax._src.numpy.lax_numpy import _promote_dtypes
 import jax.numpy as jnp
@@ -230,23 +229,16 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo):
 ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
 ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
 mlir.register_lowering(coo_todense_p, _coo_todense_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        coo_todense_p,
-        partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        coo_todense_p,
-        partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense),
-        platform='rocm')
-if sparse_apis and sparse_apis.is_supported:
+if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       coo_todense_p,
-      partial(_coo_todense_gpu_lowering, sparse_apis.coo_todense_mhlo),
-      platform='gpu')
+      partial(_coo_todense_gpu_lowering, gpu_sparse.cuda_coo_todense),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      coo_todense_p,
+      partial(_coo_todense_gpu_lowering, gpu_sparse.rocm_coo_todense),
+      platform='rocm')
 #--------------------------------------------------------------------
 # coo_fromdense
@@ -356,23 +348,16 @@ ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose
 mlir.register_lowering(coo_fromdense_p, _coo_fromdense_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        coo_fromdense_p,
-        partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        coo_fromdense_p,
-        partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense),
-        platform='rocm')
-if sparse_apis and sparse_apis.is_supported:
+if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       coo_fromdense_p,
-      partial(_coo_fromdense_gpu_lowering, sparse_apis.coo_fromdense_mhlo),
-      platform='gpu')
+      partial(_coo_fromdense_gpu_lowering, gpu_sparse.cuda_coo_fromdense),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      coo_fromdense_p,
+      partial(_coo_fromdense_gpu_lowering, gpu_sparse.rocm_coo_fromdense),
+      platform='rocm')
 #--------------------------------------------------------------------
 # coo_matvec
@@ -486,23 +471,17 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose):
 ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
 ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
 mlir.register_lowering(coo_matvec_p, _coo_matvec_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        coo_matvec_p,
-        partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        coo_matvec_p,
-        partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec),
-        platform='rocm')
-if sparse_apis and sparse_apis.is_supported:
+if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       coo_matvec_p,
-      partial(_coo_matvec_gpu_lowering, sparse_apis.coo_matvec_mhlo),
-      platform='gpu')
+      partial(_coo_matvec_gpu_lowering, gpu_sparse.cuda_coo_matvec),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      coo_matvec_p,
+      partial(_coo_matvec_gpu_lowering, gpu_sparse.rocm_coo_matvec),
+      platform='rocm')
 #--------------------------------------------------------------------
 # coo_matmat
@@ -611,20 +590,13 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose):
 ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
 ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
 mlir.register_lowering(coo_matmat_p, _coo_matmat_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        coo_matmat_p,
-        partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        coo_matmat_p,
-        partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat),
-        platform='rocm')
-if sparse_apis and sparse_apis.is_supported:
+if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       coo_matmat_p,
-      partial(_coo_matmat_gpu_lowering, sparse_apis.coo_matmat_mhlo),
-      platform='gpu')
+      partial(_coo_matmat_gpu_lowering, gpu_sparse.cuda_coo_matmat),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      coo_matmat_p,
+      partial(_coo_matmat_gpu_lowering, gpu_sparse.rocm_coo_matmat),
+      platform='rocm')

View File

@@ -31,7 +31,6 @@ from jax import lax
 from jax import tree_util
 from jax._src.lax.lax import _const
 from jax._src.lib import gpu_sparse
-from jax._src.lib import sparse_apis
 from jax._src.numpy.lax_numpy import _promote_dtypes
 import jax.numpy as jnp
@@ -241,23 +240,17 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
 ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
 ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
 mlir.register_lowering(csr_todense_p, _csr_todense_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        csr_todense_p,
-        partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        csr_todense_p,
-        partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense),
-        platform='rocm')
-if sparse_apis and sparse_apis.is_supported:
+if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       csr_todense_p,
-      partial(_csr_todense_gpu_lowering, sparse_apis.csr_todense_mhlo),
-      platform='gpu')
+      partial(_csr_todense_gpu_lowering, gpu_sparse.cuda_csr_todense),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      csr_todense_p,
+      partial(_csr_todense_gpu_lowering, gpu_sparse.rocm_csr_todense),
+      platform='rocm')
 #--------------------------------------------------------------------
 # csr_fromdense
@@ -349,23 +342,16 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
 ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
 ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
 mlir.register_lowering(csr_fromdense_p, _csr_fromdense_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        csr_fromdense_p,
-        partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        csr_fromdense_p,
-        partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense),
-        platform='rocm')
-if sparse_apis and sparse_apis.is_supported:
+if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       csr_fromdense_p,
-      partial(_csr_fromdense_gpu_lowering, sparse_apis.csr_fromdense_mhlo),
-      platform='gpu')
+      partial(_csr_fromdense_gpu_lowering, gpu_sparse.cuda_csr_fromdense),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      csr_fromdense_p,
+      partial(_csr_fromdense_gpu_lowering, gpu_sparse.rocm_csr_fromdense),
+      platform='rocm')
 #--------------------------------------------------------------------
 # csr_matvec
@@ -446,23 +432,16 @@ ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
 ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
 mlir.register_lowering(csr_matvec_p, _csr_matvec_lowering)
-if gpu_sparse:
-  if gpu_sparse.cuda_is_supported:
-    mlir.register_lowering(
-        csr_matvec_p,
-        partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec),
-        platform='cuda')
-  if gpu_sparse.rocm_is_supported:
-    mlir.register_lowering(
-        csr_matvec_p,
-        partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec),
-        platform='rocm')
-if sparse_apis and sparse_apis.is_supported:
+if gpu_sparse.cuda_is_supported:
   mlir.register_lowering(
       csr_matvec_p,
-      partial(_csr_matvec_gpu_lowering, sparse_apis.csr_matvec_mhlo),
-      platform='gpu')
+      partial(_csr_matvec_gpu_lowering, gpu_sparse.cuda_csr_matvec),
+      platform='cuda')
+if gpu_sparse.rocm_is_supported:
+  mlir.register_lowering(
+      csr_matvec_p,
+      partial(_csr_matvec_gpu_lowering, gpu_sparse.rocm_csr_matvec),
+      platform='rocm')
 #--------------------------------------------------------------------
@@ -555,9 +534,3 @@ if gpu_sparse:
        csr_matmat_p,
        partial(_csr_matmat_gpu_lowering, gpu_sparse.rocm_csr_matmat),
        platform='rocm')
-
-if sparse_apis and sparse_apis.is_supported:
-  mlir.register_lowering(
-      csr_matmat_p,
-      partial(_csr_matmat_gpu_lowering, sparse_apis.csr_matmat_mhlo),
-      platform='gpu')

View File

@@ -226,10 +226,7 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types
     x = x.view(np.uint16)
   x = np.ascontiguousarray(x)
   attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
-  if jax._src.lib.mlir_api_version < 21:
-    return (mhlo.ConstOp(attr).result,)
-  else:
-    return (mhlo.ConstantOp(attr).result,)
+  return (mhlo.ConstantOp(attr).result,)
@@ -1371,9 +1368,6 @@ def emit_python_callback(
     has_side_effect: bool) -> Tuple[List[ir.Value], Any, Any]:
   """Creates an MHLO `CustomCallOp` that calls back to the provided function."""
   platform = ctx.module_context.platform
-  if platform in {"cuda", "rocm"} and jax._src.lib.version < (0, 3, 11):
-    raise ValueError(
-        "`EmitPythonCallback` on CUDA only supported on jaxlib >= 0.3.11")
   if platform in {"tpu"} and jax._src.lib.version < (0, 3, 15):
     raise ValueError(
         "`EmitPythonCallback` on TPU only supported on jaxlib >= 0.3.15")

View File

@@ -1732,13 +1732,8 @@ def _mhlo_shard(aval, axis_env, xs, in_axis):
     idxs.insert(in_axis, _unravel_index_mhlo(axis_env))
     dims_unsqueezed = dims.copy()
     dims_unsqueezed.insert(in_axis, 1)
-    if jax._src.lib.mlir_api_version < 13:
-      dynamic_slice_result = mhlo.DynamicSliceOp(
-          mlir.aval_to_ir_type(aval.update(shape=dims_unsqueezed)),
-          x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
-    else:
-      dynamic_slice_result = mhlo.DynamicSliceOp(
-          x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
+    dynamic_slice_result = mhlo.DynamicSliceOp(
+        x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
     return [
       mhlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
     ]
@@ -1764,13 +1759,8 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
     padded = mlir.full_like_aval(0, padded_aval)
     zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
     idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
-    if jax._src.lib.mlir_api_version < 9:
-      broadcast_result = mhlo.BroadcastOp(
-          mlir.aval_to_ir_type(aval.update(shape=[1] + dims)), x,
-          mlir.dense_int_elements([1])).result
-    else:
-      broadcast_result = mhlo.BroadcastOp(
-          x, mlir.dense_int_elements([1])).result
+    broadcast_result = mhlo.BroadcastOp(
+        x, mlir.dense_int_elements([1])).result
     padded = mhlo.DynamicUpdateSliceOp(
         padded.type, padded, broadcast_result, idxs).result
     replica_groups = mlir.dense_int_elements(

View File

@@ -16,7 +16,7 @@
 # eval()-ed by setup.py, so it should not have any dependencies.
 __version__ = "0.3.15"
-_minimum_jaxlib_version = "0.3.10"
+_minimum_jaxlib_version = "0.3.14"
 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
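
For reference, _version_as_tuple is what makes the string minimum above comparable with jaxlib's version tuple; a small worked example:

assert _version_as_tuple("0.3.14") == (0, 3, 14)
# Tuples compare lexicographically, so jaxlib 0.3.15 passes the check:
assert _version_as_tuple("0.3.15") >= _version_as_tuple("0.3.14")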

View File

@@ -15,7 +15,6 @@
 # Shims that allow the XLA CPU backend to call scipy-provided LAPACK kernels
 # via CustomCallWithLayout.
-import jax
 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.dialects.mhlo as mhlo
@@ -29,54 +28,17 @@ for _name, _value in _lapack.registrations().items():
   xla_client.register_custom_call_target(_name, _value, platform="cpu")
-if xla_client._version >= 64:
-  def _mhlo_u8(x):
-    if jax._src.lib.mlir_api_version < 21:
-      return mhlo.ConstOp(
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.uint8),
-              type=ir.IntegerType.get_unsigned(8))).result
-    else:
-      return mhlo.ConstantOp(
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.uint8),
-              type=ir.IntegerType.get_unsigned(8))).result
-  def _mhlo_s32(x):
-    if jax._src.lib.mlir_api_version < 21:
-      return mhlo.ConstOp(
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.int32),
-              type=ir.IntegerType.get_signless(32))).result
-    else:
-      return mhlo.ConstantOp(
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.int32),
-              type=ir.IntegerType.get_signless(32))).result
-else:
-  def _mhlo_u8(x):
-    typ = ir.RankedTensorType.get([], ir.IntegerType.get_unsigned(8))
-    if jax._src.lib.mlir_api_version < 21:
-      return mhlo.ConstOp(
-          typ,
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.uint8), type=typ.element_type)).result
-    else:
-      return mhlo.ConstantOp(
-          typ,
-          ir.DenseElementsAttr.get(
-              np.array(x, dtype=np.uint8), type=typ.element_type)).result
-  def _mhlo_s32(x):
-    typ = ir.RankedTensorType.get([], ir.IntegerType.get_signless(32))
-    if jax._src.lib.mlir_api_version < 21:
-      return mhlo.ConstOp(typ,
-                          ir.DenseElementsAttr.get(np.array(
-                              x, dtype=np.int32))).result
-    else:
-      return mhlo.ConstantOp(
-          typ, ir.DenseElementsAttr.get(np.array(x, dtype=np.int32))).result
+def _mhlo_u8(x):
+  return mhlo.ConstantOp(
+      ir.DenseElementsAttr.get(
+          np.array(x, dtype=np.uint8),
+          type=ir.IntegerType.get_unsigned(8))).result
+def _mhlo_s32(x):
+  return mhlo.ConstantOp(
+      ir.DenseElementsAttr.get(
+          np.array(x, dtype=np.int32),
+          type=ir.IntegerType.get_signless(32))).result
 # TODO(phawkins): it would be nice to avoid duplicating code for each type.

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import jax
 from typing import List
 import jaxlib.mlir.ir as ir
@@ -129,51 +128,17 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
     raise ValueError(f"Unknown output type {out_dtype}")
   if 0 in a_type.shape or 0 in out_shape:
-    if xla_client._version >= 64:
-      if jax._src.lib.mlir_api_version < 21:
-        zero = mhlo.ConstOp(
-            ir.DenseElementsAttr.get(
-                np.array(0, dtype=out_dtype), type=out_type))
-      else:
-        zero = mhlo.ConstantOp(
-            ir.DenseElementsAttr.get(
-                np.array(0, dtype=out_dtype), type=out_type))
-    else:
-      if jax._src.lib.mlir_api_version < 21:
-        zero = mhlo.ConstOp(
-            ir.RankedTensorType.get([], out_type),
-            ir.DenseElementsAttr.get(
-                np.array(0, dtype=out_dtype), type=out_type))
-      else:
-        zero = mhlo.ConstantOp(
-            ir.RankedTensorType.get([], out_type),
-            ir.DenseElementsAttr.get(
-                np.array(0, dtype=out_dtype), type=out_type))
+    zero = mhlo.ConstantOp(
+        ir.DenseElementsAttr.get(
+            np.array(0, dtype=out_dtype), type=out_type))
     return mhlo.BroadcastOp(
         zero,
        ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
   u8_type = ir.IntegerType.get_unsigned(8)
-  if xla_client._version >= 64:
-    if jax._src.lib.mlir_api_version < 21:
-      descriptor = mhlo.ConstOp(
-          ir.DenseElementsAttr.get(
-              np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
-    else:
-      descriptor = mhlo.ConstantOp(
-          ir.DenseElementsAttr.get(
-              np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
-  else:
-    if jax._src.lib.mlir_api_version < 21:
-      descriptor = mhlo.ConstOp(
-          ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
-          ir.DenseElementsAttr.get(
-              np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
-    else:
-      descriptor = mhlo.ConstantOp(
-          ir.RankedTensorType.get([len(descriptor_bytes)], u8_type),
-          ir.DenseElementsAttr.get(
-              np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
+  descriptor = mhlo.ConstantOp(
+      ir.DenseElementsAttr.get(
+          np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
   layout = tuple(range(n - 1, -1, -1))
   return custom_call(
       "pocketfft",

View File

@@ -654,9 +654,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
            ".*while trying to hash an object of type "
            "<class 'numpy\\.ndarray'>, 1. The error was:\nTypeError: "
            "unhashable type: 'numpy\\.ndarray'")
-    # Typo was fixed in newer jaxlib
-    if jax._src.lib.xla_extension_version < 66:
-      msg = msg.replace('occurred', 'occured')
     with self.assertRaisesRegex(ValueError, msg):
       jitted_f(1, np.asarray(1))
@@ -9167,13 +9164,11 @@ class DynamicShapeTest(jtu.JaxTestCase):
     self.assertAllClose(ans, expected, check_dtypes=False)
   @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
-  @unittest.skipIf(jaxlib_version < (0, 3, 11), "test requires jaxlib>=0.3.11")
   def test_jit_of_broadcast(self):
     x = jax.jit(jnp.ones)(3)
     self.assertAllClose(x, jnp.ones(3))
   @unittest.skipIf(jtu.device_under_test() != 'iree', "iree test")
-  @unittest.skipIf(jaxlib_version < (0, 3, 11), "test requires jaxlib>=0.3.11")
   def test_jit_of_broadcast2(self):
     x = jax.jit(lambda n: jnp.ones(2 * n))(3)
     self.assertAllClose(x, jnp.ones(2 * 3))

View File

@@ -49,13 +49,9 @@ def setUpModule():
 def tearDownModule():
   prev_xla_flags()
-# TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum
-# version is >= 0.3.11
 # TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum
 # version is >= 0.3.15
 disabled_backends = []
-if jaxlib.version < (0, 3, 11):
-  disabled_backends.append("gpu")
 if jaxlib.version < (0, 3, 15):
   disabled_backends.append("tpu")

View File

@@ -58,13 +58,9 @@ def setUpModule():
 def tearDownModule():
   prev_xla_flags()
-# TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum
-# version is >= 0.3.11
 # TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum
 # version is >= 0.3.15
 disabled_backends = []
-if jaxlib.version < (0, 3, 11):
-  disabled_backends.append("gpu")
 if jaxlib.version < (0, 3, 15):
   disabled_backends.append("tpu")

View File

@@ -35,8 +35,6 @@ except ImportError:
 config.parse_flags_with_absl()
-@unittest.skipIf(jax._src.lib.xla_extension_version < 73,
-                 "Test requires jaxlib 0.3.12 or newer.")
 @unittest.skipIf(not portpicker, "Test requires portpicker")
 class DistributedTest(jtu.JaxTestCase):

View File

@@ -59,13 +59,9 @@ lcf.allowed_effects.add('while')
 lcf.allowed_effects.add('while1')
 lcf.allowed_effects.add('while2')
-# TODO(sharadmv): remove jaxlib guards for GPU tests when jaxlib minimum
-# version is >= 0.3.11
 # TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum
 # version is >= 0.3.15
 disabled_backends = []
-if jaxlib.version < (0, 3, 11):
-  disabled_backends.append('gpu')
 if jaxlib.version < (0, 3, 15):
   disabled_backends.append('tpu')

View File

@@ -183,9 +183,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
            else ["lu", "qr"])
   ))
   def testSlogdet(self, shape, dtype, method):
-    if (method == 'qr' and jax._src.lib.xla_extension_version < 69 and
-        jtu.device_under_test() == "tpu"):
-      raise unittest.SkipTest('qr decomposition is not supported.')
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
     slogdet = partial(jnp.linalg.slogdet, method=method)

View File

@@ -1436,9 +1436,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
                            s2_shape, s3_shape, s4_shape):
     # Disable on SE runtime type because XLA sharding propagation is not
     # supported.
-    if xla_client._version < 74 or xla_bridge.get_backend().runtime_type == 'se':
-      raise unittest.SkipTest('Needs xla_extension_version >= 74 or '
-                              'TFRT runtime.')
+    if xla_bridge.get_backend().runtime_type == 'se':
+      raise unittest.SkipTest('Needs TFRT runtime.')
     global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
     global_input_shape = (8, 2)

View File

@@ -2014,10 +2014,6 @@ class CppPmapTest(PythonPmapTest):
     self.assertEqual(f._cache_size(), size+1)
   def test_cache_hits_across_threads(self):
-    # TODO(phawkins): remove after minimum jaxlib version is > 0.3.11
-    if xla_bridge.xla_client._version < 70:
-      raise unittest.SkipTest("This test requires jaxlib version >= 0.3.11")
     f = lambda x: x+1
     inputs = np.zeros([jax.device_count()], dtype=np.float32)
     pmaped_f = self.pmap(f)
@@ -2035,10 +2031,6 @@ class CppPmapTest(PythonPmapTest):
     self.assertEqual(pmaped_f._cache_size, 1)
   def test_cache_uses_jax_key(self):
-    # TODO(parkers): remove after minimum jaxlib version is > 0.3.11
-    if xla_bridge.xla_client._version < 74:
-      raise unittest.SkipTest("This test requires jaxlib version >= 0.3.11")
     f = lambda x: x+1
     inputs = np.zeros([jax.device_count()], dtype=np.float32)
     pmaped_f = self.pmap(f)

View File

@@ -34,7 +34,6 @@ from jax.experimental.sparse import bcoo as sparse_bcoo
 from jax.experimental.sparse.bcoo import BCOOInfo
 from jax import lax
 from jax._src.lib import gpu_sparse
-from jax._src.lib import sparse_apis
 from jax._src.lib import xla_bridge
 from jax import jit
 from jax import tree_util
@@ -57,11 +56,8 @@ MATMUL_TOL = {
     np.complex128: 1E-10,
 }
-if gpu_sparse:
-  GPU_LOWERING_ENABLED = gpu_sparse and (gpu_sparse.cuda_is_supported or
-                                         gpu_sparse.rocm_is_supported)
-else:
-  GPU_LOWERING_ENABLED = (sparse_apis and sparse_apis.is_supported)
+GPU_LOWERING_ENABLED = gpu_sparse and (gpu_sparse.cuda_is_supported or
+                                       gpu_sparse.rocm_is_supported)
 class BcooDotGeneralProperties(NamedTuple):
   lhs_shape: Tuple[int]
@@ -516,24 +512,15 @@ class cuSparseTest(jtu.JaxTestCase):
      cuda_version = None if version == "<unknown>" else int(
          version.split()[-1])
      if cuda_version is None or cuda_version < 11000:
-        if gpu_sparse:
-          self.assertFalse(gpu_sparse and gpu_sparse.cuda_is_supported)
-        else:
-          self.assertFalse(sparse_apis and sparse_apis.is_supported)
+        self.assertFalse(gpu_sparse and gpu_sparse.cuda_is_supported)
         self.assertNotIn(sparse.csr_todense_p,
                          mlir._platform_specific_lowerings["cuda"])
       else:
-        if gpu_sparse:
-          self.assertTrue(gpu_sparse and gpu_sparse.cuda_is_supported)
-        else:
-          self.assertTrue(sparse_apis and sparse_apis.is_supported)
+        self.assertTrue(gpu_sparse and gpu_sparse.cuda_is_supported)
         self.assertIn(sparse.csr_todense_p,
                       mlir._platform_specific_lowerings["cuda"])
     else:
-      if gpu_sparse:
-        self.assertTrue(gpu_sparse and gpu_sparse.rocm_is_supported)
-      else:
-        self.assertTrue(sparse_apis and sparse_apis.is_supported)
+      self.assertTrue(gpu_sparse and gpu_sparse.rocm_is_supported)
      self.assertIn(sparse.csr_todense_p,
                    mlir._platform_specific_lowerings["rocm"])