Increase minimum jaxlib version to 0.3.7.

Drop backwards compatibility with older jaxlib versions.
Peter Hawkins 2022-04-16 09:59:48 -04:00
parent 057412bdff
commit 0150d15cb2
15 changed files with 100 additions and 195 deletions
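
Most of the changes below follow one pattern: MLIR lowering registrations that were gated on the installed jaxlib version become unconditional. A minimal sketch of that pattern, using a made-up primitive (my_p) and lowering rule rather than code from this commit:

    from jax import core
    from jax.interpreters import mlir

    my_p = core.Primitive("my_example_prim")

    def _my_lowering(ctx, x):
      # Stand-in identity lowering; a real rule would emit MHLO ops.
      return [x]

    # Before this commit: register only if the installed jaxlib was new enough.
    import jax._src.lib
    if jax._src.lib.version >= (0, 3, 6):
      mlir.register_lowering(my_p, _my_lowering, platform='cpu')

    # After: the minimum jaxlib is 0.3.7, so the guard is vacuous and the
    # registration is made unconditionally.
    mlir.register_lowering(my_p, _my_lowering, platform='cpu')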

View File

@@ -27,7 +27,6 @@ from jax._src import dtypes
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
-import jax._src.lib
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_client
from jax._src.lib import pocketfft
@@ -186,5 +185,4 @@ ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
if pocketfft:
xla.register_translation(fft_p, _fft_translation_rule_cpu, platform='cpu')
-if jax._src.lib.version >= (0, 3, 6):
-mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
+mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')

View File

@@ -1741,8 +1741,7 @@ mlir.register_lowering(atan2_p, partial(_nary_lower_mhlo, mhlo.Atan2Op))
sinh_p = standard_unop(_float | _complex, 'sinh')
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
-if jax._src.lib.mlir_api_version >= 7:
-mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
+mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
@@ -2653,11 +2652,8 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
full_precision = (precision, precision)
else:
full_precision = precision
-if jax._src.lib.mlir_api_version >= 3:
-return ir.ArrayAttr.get(
-[mhlo.PrecisionAttr.get(str(p)) for p in full_precision])
-else:
-return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in full_precision])
+return ir.ArrayAttr.get(
+[mhlo.PrecisionAttr.get(str(p)) for p in full_precision])
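
For reference, a hedged sketch of what the now-unconditional branch builds; standing up an MLIR context by hand like this is an assumption about how to exercise the attribute builders outside a lowering, not something this diff does:

    from jax.interpreters import mlir
    from jax._src.lib.mlir import ir
    from jax._src.lib.mlir.dialects import mhlo

    with mlir.make_ir_context():
      # Two-element precision attribute, built the same way precision_attr now does.
      attr = ir.ArrayAttr.get(
          [mhlo.PrecisionAttr.get(p) for p in ("HIGHEST", "HIGHEST")])
      print(attr)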

View File

@@ -34,7 +34,6 @@ from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype)
from jax._src.lax import lax as lax_internal
-import jax._src.lib
from jax._src.lib import lapack
from jax._src.lib import cuda_linalg
@@ -392,17 +391,16 @@ def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand):
ok, mlir.dense_int_elements(range(len(batch_dims)))).result,
result, _nan_like_mhlo(out_aval))]
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-cholesky_p,
-partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo),
-platform='cpu')
+mlir.register_lowering(
+cholesky_p,
+partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo),
+platform='cpu')
-if solver_apis is not None:
-mlir.register_lowering(
-cholesky_p,
-partial(_cholesky_cpu_gpu_lowering, solver_apis.potrf_mhlo),
-platform='gpu')
+if solver_apis is not None:
+mlir.register_lowering(
+cholesky_p,
+partial(_cholesky_cpu_gpu_lowering, solver_apis.potrf_mhlo),
+platform='gpu')
# Asymmetric eigendecomposition
@@ -543,8 +541,7 @@ eig_p.def_abstract_eval(eig_abstract_eval)
xla.register_translation(eig_p, eig_lower)
xla.register_translation(eig_p, _eig_cpu_translation_rule, platform='cpu')
mlir.register_lowering(eig_p, eig_lower)
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu')
+mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu')
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule
@@ -656,20 +653,18 @@ xla.register_translation(
eigh_p, partial(_eigh_cpu_gpu_translation_rule, lapack.syevd),
platform='cpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
-platform='cpu')
+mlir.register_lowering(
+eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
+platform='cpu')
if solver_apis is not None:
xla.register_translation(
eigh_p, partial(_eigh_cpu_gpu_translation_rule, solver_apis.syevd),
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo),
-platform='gpu')
+mlir.register_lowering(
+eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo),
+platform='gpu')
triangular_solve_dtype_rule = partial(
@@ -865,9 +860,8 @@ def _triangular_solve_cpu_lower(
ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
-platform='cpu')
+mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
+platform='cpu')
def _triangular_solve_gpu_translation_rule(
@@ -922,11 +916,10 @@ if solver_apis is not None:
triangular_solve_p,
partial(_triangular_solve_gpu_translation_rule, solver_apis.trsm),
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-triangular_solve_p,
-partial(_triangular_solve_gpu_lower, solver_apis.trsm_mhlo),
-platform='gpu')
+mlir.register_lowering(
+triangular_solve_p,
+partial(_triangular_solve_gpu_lower, solver_apis.trsm_mhlo),
+platform='gpu')
# Support operation for LU decomposition: Transformation of the pivots returned
@@ -1038,19 +1031,17 @@ if cuda_linalg:
xla.register_translation(lu_pivots_to_permutation_p,
_lu_pivots_to_permutation_gpu,
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(lu_pivots_to_permutation_p,
-_lu_pivots_to_permutation_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(lu_pivots_to_permutation_p,
+_lu_pivots_to_permutation_gpu_lowering,
+platform='gpu')
if hip_linalg:
xla.register_translation(lu_pivots_to_permutation_p,
_lu_pivots_to_permutation_gpu,
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(lu_pivots_to_permutation_p,
-_lu_pivots_to_permutation_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(lu_pivots_to_permutation_p,
+_lu_pivots_to_permutation_gpu_lowering,
+platform='gpu')
# LU decomposition
# Computes a pivoted LU decomposition such that
@@ -1264,19 +1255,17 @@ batching.primitive_batchers[lu_p] = _lu_batching_rule
xla.register_translation(lu_p,
partial(_lu_cpu_gpu_translation_rule, lapack.getrf),
platform='cpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(lu_p,
-partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo),
-platform='cpu')
+mlir.register_lowering(lu_p,
+partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo),
+platform='cpu')
if solver_apis is not None:
xla.register_translation(
lu_p, partial(_lu_cpu_gpu_translation_rule, solver_apis.getrf),
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-lu_p, partial(_lu_cpu_gpu_lowering, solver_apis.getrf_mhlo),
-platform='gpu')
+mlir.register_lowering(
+lu_p, partial(_lu_cpu_gpu_lowering, solver_apis.getrf_mhlo),
+platform='gpu')
xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu')
@@ -1480,21 +1469,19 @@ xla.register_translation(
qr_p, partial(_qr_cpu_gpu_translation_rule, lapack.geqrf, lapack.orgqr),
platform='cpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo),
-platform='cpu')
+mlir.register_lowering(
+qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo),
+platform='cpu')
if solver_apis is not None:
xla.register_translation(
qr_p,
partial(_qr_cpu_gpu_translation_rule, solver_apis.geqrf, solver_apis.orgqr),
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-qr_p,
-partial(_qr_cpu_gpu_lowering, solver_apis.geqrf_mhlo, solver_apis.orgqr_mhlo),
-platform='gpu')
+mlir.register_lowering(
+qr_p,
+partial(_qr_cpu_gpu_lowering, solver_apis.geqrf_mhlo, solver_apis.orgqr_mhlo),
+platform='gpu')
# Singular value decomposition
@@ -1712,19 +1699,17 @@ xla.register_translation(svd_p, _svd_translation_rule)
xla.register_translation(
svd_p, partial(_svd_cpu_gpu_translation_rule, lapack.gesdd),
platform='cpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
-platform='cpu')
+mlir.register_lowering(
+svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
+platform='cpu')
if solver_apis is not None:
xla.register_translation(
svd_p, partial(_svd_cpu_gpu_translation_rule, solver_apis.gesvd),
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(
-svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
-platform='gpu')
+mlir.register_lowering(
+svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
+platform='gpu')
def _tridiagonal_solve_gpu_translation_rule(ctx, avals_in, avals_out, dl, d, du,
b, *, m, n, ldb, t):
@@ -1743,10 +1728,9 @@ if sparse_apis and hasattr(sparse_apis, "gtsv2"):
xla.register_translation(tridiagonal_solve_p,
_tridiagonal_solve_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(tridiagonal_solve_p,
-_tridiagonal_solve_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(tridiagonal_solve_p,
+_tridiagonal_solve_gpu_lowering,
+platform='gpu')
def _tridiagonal_solve_jax(dl, d, du, b, **kw):
"""Pure JAX implementation of `tridiagonal_solve`."""
@@ -1983,8 +1967,7 @@ schur_p.def_abstract_eval(_schur_abstract_eval)
xla.register_translation(schur_p, _schur_translation_rule)
xla.register_translation(schur_p, _schur_cpu_translation_rule, platform='cpu')
mlir.register_lowering(schur_p, _schur_translation_rule)
-if jax._src.lib.version >= (0, 3, 6):
-mlir.register_lowering(schur_p, _schur_cpu_lowering, platform='cpu')
+mlir.register_lowering(schur_p, _schur_cpu_lowering, platform='cpu')
batching.primitive_batchers[schur_p] = _schur_batching_rule
ad.primitive_jvps[schur_p] = _schur_jvp_rule

View File

@@ -16,9 +16,4 @@
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.mhlo as mhlo
-try:
-import jaxlib.mlir.dialects.func as func # pytype: disable=import-error
-except ImportError:
-# TODO(phawkins): remove std dialect after MLIR change lands, most likely in
-# jaxlib > 0.3.0.
-import jaxlib.mlir.dialects.std as func # pytype: disable=import-error
+import jaxlib.mlir.dialects.func as func
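
A quick way to confirm the assumption baked in here, namely that jaxlib now always ships the upstream func dialect bindings (a sanity-check sketch, not part of the commit):

    import jaxlib.mlir.dialects.func as func

    # FuncOp is the symbol JAX's lowering code relies on; see the
    # jax/interpreters MLIR changes later in this commit.
    assert hasattr(func, "FuncOp")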

View File

@@ -113,9 +113,8 @@ def get_compile_options(
compile_options.num_partitions = num_partitions
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
-if jax._src.lib.xla_extension_version >= 61:
-build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
-if use_auto_spmd_partitioning and jax._src.lib.xla_extension_version >= 62:
+build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
+if use_auto_spmd_partitioning:
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
if device_assignment is not None:
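
The practical effect is that use_auto_spmd_partitioning is now always forwarded to the executable build options. A hedged usage sketch (only the two required arguments are passed; everything else is left at its defaults):

    from jax._src.lib import xla_bridge

    opts = xla_bridge.get_compile_options(num_replicas=1, num_partitions=1)
    build_options = opts.executable_build_options
    # With the defaults, auto SPMD partitioning is expected to stay disabled.
    print(build_options.use_auto_spmd_partitioning)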

View File

@@ -30,7 +30,6 @@ from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
-import jax._src.lib
from jax._src.lib import xla_client
from jax._src.lib import cuda_prng
from jax._src.lib.mlir.dialects import mhlo
@@ -452,9 +451,8 @@ mlir.register_lowering(threefry2x32_p, mlir.lower_fun(
if cuda_prng or hip_prng:
xla.register_translation(threefry2x32_p, _threefry2x32_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version >= (0, 3, 3):
-mlir.register_lowering(threefry2x32_p, _threefry2x32_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(threefry2x32_p, _threefry2x32_gpu_lowering,
+platform='gpu')
@partial(jit, inline=True)
def threefry_2x32(keypair, count):

View File

@@ -129,12 +129,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
hash_obj.update(compile_options_obj.device_assignment.serialize())
def _hash_executable_build_options(hash_obj, executable_obj):
-if jax._src.lib.xla_extension_version >= 62:
-expected_options = 34
-elif jax._src.lib.xla_extension_version >= 61:
-expected_options = 32
-else:
-expected_options = 31
+expected_options = 34
assert len(dir(executable_obj)) == expected_options, (
f"Unexpected number of executable_build_options fields: "
f"{len(dir(executable_obj))}. This likely means that an extra "
@@ -147,9 +142,8 @@ def _hash_executable_build_options(hash_obj, executable_obj):
if executable_obj.device_assignment is not None:
hash_obj.update(executable_obj.device_assignment.serialize())
_hash_bool(hash_obj, executable_obj.use_spmd_partitioning)
-if jax._src.lib.xla_extension_version >= 61:
-_hash_bool(hash_obj, executable_obj.use_auto_spmd_partitioning)
-if jax._src.lib.xla_extension_version >= 62 and executable_obj.use_auto_spmd_partitioning:
+_hash_bool(hash_obj, executable_obj.use_auto_spmd_partitioning)
+if executable_obj.use_auto_spmd_partitioning:
if executable_obj.auto_spmd_partitioning_mesh_shape is not None:
hash_obj.update(
executable_obj.auto_spmd_partitioning_mesh_shape.serialize())
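
The hard-coded expected_options = 34 keeps the existing trick of pinning len(dir(...)) so that a newly added build option cannot silently go unhashed. A toy illustration of the mechanism with a stand-in class (not the real xla_client.ExecutableBuildOptions):

    class FakeBuildOptions:
      use_spmd_partitioning = True
      use_auto_spmd_partitioning = False

    expected_options = len(dir(FakeBuildOptions()))

    # Adding a field changes the dir() count, so a pinned expectation like the
    # assert above would fail and remind maintainers to hash the new field.
    FakeBuildOptions.new_flag = True
    assert len(dir(FakeBuildOptions())) == expected_options + 1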

View File

@@ -39,7 +39,6 @@ from jax._src.api_util import flatten_axes
from jax._src.lax.lax import (
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
-import jax._src.lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_client as xc
@@ -1069,10 +1068,9 @@ if sparse_apis and sparse_apis.is_supported:
xla.register_translation(bcoo_dot_general_p,
_bcoo_dot_general_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(bcoo_dot_general_p,
-_bcoo_dot_general_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(bcoo_dot_general_p,
+_bcoo_dot_general_gpu_lowering,
+platform='gpu')
#----------------------------------------------------------------------
# bcoo_dot_general_sampled

View File

@@ -28,7 +28,6 @@ from jax.interpreters import xla
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax import tree_util
-import jax._src.lib
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import sparse_apis
from jax._src.numpy.lax_numpy import _promote_dtypes
@@ -241,9 +240,8 @@ mlir.register_lowering(coo_todense_p, _coo_todense_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(coo_todense_p, _coo_todense_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(coo_todense_p, _coo_todense_gpu_lowering,
+platform='gpu')
#--------------------------------------------------------------------
# coo_fromdense
@@ -371,10 +369,9 @@ if sparse_apis and sparse_apis.is_supported:
xla.register_translation(coo_fromdense_p,
_coo_fromdense_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(coo_fromdense_p,
-_coo_fromdense_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(coo_fromdense_p,
+_coo_fromdense_gpu_lowering,
+platform='gpu')
#--------------------------------------------------------------------
# coo_matvec
@@ -519,9 +516,8 @@ mlir.register_lowering(coo_matvec_p, _coo_matvec_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(coo_matvec_p, _coo_matvec_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(coo_matvec_p, _coo_matvec_gpu_lowering,
+platform='gpu')
#--------------------------------------------------------------------
# coo_matmat
@@ -661,6 +657,5 @@ mlir.register_lowering(coo_matmat_p, _coo_matmat_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(coo_matmat_p, _coo_matmat_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(coo_matmat_p, _coo_matmat_gpu_lowering,
+platform='gpu')

View File

@@ -28,7 +28,6 @@ from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, _safe_asarray, CuSparseEfficiencyWarning
from jax import tree_util
-import jax._src.lib
from jax._src.lib import sparse_apis
from jax._src.numpy.lax_numpy import _promote_dtypes
import jax.numpy as jnp
@@ -222,9 +221,8 @@ mlir.register_lowering(csr_todense_p, _csr_todense_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(csr_todense_p, _csr_todense_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(csr_todense_p, _csr_todense_gpu_lowering,
+platform='gpu')
#--------------------------------------------------------------------
# csr_fromdense
@@ -336,10 +334,9 @@ if sparse_apis and sparse_apis.is_supported:
xla.register_translation(csr_fromdense_p,
_csr_fromdense_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(csr_fromdense_p,
-_csr_fromdense_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(csr_fromdense_p,
+_csr_fromdense_gpu_lowering,
+platform='gpu')
#--------------------------------------------------------------------
# csr_matvec
@@ -437,9 +434,8 @@ mlir.register_lowering(csr_matvec_p, _csr_matvec_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(csr_matvec_p, _csr_matvec_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(csr_matvec_p, _csr_matvec_gpu_lowering,
+platform='gpu')
#--------------------------------------------------------------------
@@ -537,6 +533,5 @@ mlir.register_lowering(csr_matmat_p, _csr_matmat_lowering)
if sparse_apis and sparse_apis.is_supported:
xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
platform='gpu')
-if jax._src.lib.version > (0, 3, 5):
-mlir.register_lowering(csr_matmat_p, _csr_matmat_gpu_lowering,
-platform='gpu')
+mlir.register_lowering(csr_matmat_p, _csr_matmat_gpu_lowering,
+platform='gpu')

View File

@@ -28,7 +28,6 @@ from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
from typing_extensions import Protocol
import warnings
-import jax
from jax import core
from jax import linear_util as lu
from jax._src import ad_util
@@ -47,16 +46,6 @@ import jax.interpreters.partial_eval as pe
import jax.interpreters.xla as xla
import numpy as np
-# TODO(jakevdp): remove this when minimum_jaxlib_version >= 0.3.3
-if jax._src.lib.mlir_api_version >= 4:
-FuncOp = func_dialect.FuncOp # pytype: disable=module-attr
-else:
-from jax._src.lib.mlir.dialects import builtin # pylint: disable=import-not-at-top
-FuncOp = builtin.FuncOp # pytype: disable=module-attr
-# mypy gets confused by conditional imports, so alias to Any for now.
-FuncOpType = Any
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
@@ -373,7 +362,7 @@ class ModuleContext:
name_stack: NameStack
# Cached primitive lowerings.
-cached_primitive_lowerings: Dict[Any, FuncOpType]
+cached_primitive_lowerings: Dict[Any, func_dialect.FuncOp]
@property
def axis_env(self) -> xla.AxisEnv:
@@ -388,7 +377,7 @@ class ModuleContext:
module: Optional[ir.Module] = None,
ip: Optional[ir.InsertionPoint] = None,
symbol_table: Optional[ir.SymbolTable] = None,
-cached_primitive_lowerings: Optional[Dict[Any, FuncOpType]] = None):
+cached_primitive_lowerings: Optional[Dict[Any, func_dialect.FuncOp]] = None):
assert platform is not None
self.context = context or make_ir_context()
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
@@ -571,7 +560,7 @@ def lower_jaxpr_to_fun(
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
use_sharding_annotations: bool = True,
input_output_aliases: Optional[Sequence[Optional[int]]] = None
-) -> FuncOpType:
+) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
Assumes that an MLIR context, location, and insertion point are set.
@@ -610,7 +599,7 @@ def lower_jaxpr_to_fun(
flat_input_types = util.flatten(input_types)
flat_output_types = util.flatten(output_types)
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
-func_op = FuncOp(name, ftype, ip=ctx.ip)
+func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
"public" if public else "private")
ctx.symbol_table.insert(func_op)
@@ -706,7 +695,7 @@ def lower_jaxpr_to_fun(
return func_op
def _emit_lowering_rule_as_fun(lowering_rule,
-ctx: LoweringRuleContext) -> FuncOpType:
+ctx: LoweringRuleContext) -> func_dialect.FuncOp:
"""Emits the contents of a lowering rule as a private function."""
input_types = map(aval_to_ir_types, ctx.avals_in)
output_types = map(aval_to_ir_types, ctx.avals_out)
@@ -714,7 +703,8 @@ def _emit_lowering_rule_as_fun(lowering_rule,
flat_output_types = util.flatten(output_types)
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
assert ctx.primitive is not None
-func_op = FuncOp(ctx.primitive.name, ftype, ip=ctx.module_context.ip)
+func_op = func_dialect.FuncOp(ctx.primitive.name, ftype,
+ip=ctx.module_context.ip)
func_op.attributes["sym_visibility"] = ir.StringAttr.get("private")
ctx.module_context.symbol_table.insert(func_op)
entry_block = func_op.add_entry_block()
@@ -880,17 +870,8 @@ register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])
def compare_mhlo(x, y, direction, type):
"""Creates mhlo.CompareOp."""
-if jax._src.lib.mlir_api_version >= 5:
-return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
-mhlo.ComparisonTypeAttr.get(type))
-dims = ir.RankedTensorType(x.type).shape
-bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
-if jax._src.lib.mlir_api_version >= 3:
-return mhlo.CompareOp(bool_shape, x, y,
-mhlo.ComparisonDirectionAttr.get(direction),
-mhlo.ComparisonTypeAttr.get(type))
-return mhlo.CompareOp(bool_shape, x, y, ir.StringAttr.get(direction),
-ir.StringAttr.get(type))
+return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
+mhlo.ComparisonTypeAttr.get(type))
def _minmax_mhlo(op, cmp, x, y):
"""Min/max that compares complex values lexicographically as pairs."""
@@ -1010,7 +991,7 @@ def xla_fallback_lowering(prim: core.Primitive):
submodule = ir.Module.parse(submodule_str)
callee_name = None
for op in submodule.body.operations:
-op = typing.cast(FuncOpType, op)
+op = typing.cast(func_dialect.FuncOp, op)
module_ctx.module.body.append(op)
if op.name.value == "main":
op.attributes["sym_name"] = ir.StringAttr.get(f"xla_fallback_{prim.name}")
@@ -1029,12 +1010,8 @@ def xla_fallback_lowering(prim: core.Primitive):
flatten_lowering_ir_args(args)).result
if not prim.multiple_results:
return [call]
-if jax._src.lib.mlir_api_version < 6:
-flat_results = [mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
-for i, typ in enumerate(flat_output_types)]
-else:
-flat_results = [mhlo.GetTupleElementOp(call, i32_attr(i)).result
-for i in range(len(flat_output_types))]
+flat_results = [mhlo.GetTupleElementOp(call, i32_attr(i)).result
+for i in range(len(flat_output_types))]
return util.unflatten(flat_results, map(len, output_types))
return fallback
@@ -1069,6 +1046,3 @@ register_lowering(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
# # Not present in cHLO or mHLO (b/203798239), although we could just emit the
# # lowered pattern ourselves.
# lax.top_k_p,
-# # TODO(phawkins): implement these lax ops:
-# lax.rng_bit_generator_p,

View File

@@ -44,7 +44,6 @@ import sys
from absl import logging
import numpy as np
-import jax
from jax import core
from jax import linear_util as lu
from jax.core import ConcreteArray, ShapedArray
@@ -1730,21 +1729,6 @@ def _mhlo_shard(aval, axis_env, xs, in_axis):
raise TypeError(aval)
-def _compare_mhlo(x, y, direction, type):
-"""Creates mhlo.CompareOp."""
-if jax._src.lib.mlir_api_version >= 5:
-return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
-mhlo.ComparisonTypeAttr.get(type))
-tensor_type = ir.RankedTensorType(x.type)
-dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
-bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
-if jax._src.lib.mlir_api_version >= 3:
-return mhlo.CompareOp(bool_shape, x, y,
-mhlo.ComparisonDirectionAttr.get(direction),
-mhlo.ComparisonTypeAttr.get(type))
-return mhlo.CompareOp(bool_shape, x, y, ir.StringAttr.get(direction),
-ir.StringAttr.get(type))
# TODO(b/110096942): more efficient gather
def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
if aval is core.abstract_unit:
@@ -1786,7 +1770,9 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
float_zero = mlir.full_like_aval(0, padded_aval)
-out = _compare_mhlo(out, float_zero, "NE", "FLOAT").result
+out = mhlo.CompareOp(out, float_zero,
+mhlo.ComparisonDirectionAttr.get("NE"),
+mhlo.ComparisonTypeAttr.get("FLOAT")).result
return out
else:
raise TypeError(aval)

View File

@@ -18,5 +18,5 @@ def _version_as_tuple(version_str):
__version__ = "0.3.8"
__version_info__ = _version_as_tuple(__version__)
-_minimum_jaxlib_version = "0.3.2"
+_minimum_jaxlib_version = "0.3.7"
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)
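
The bump only changes the string; comparisons happen on the tuple form. A sketch of the helper's behaviour, with an assumed body for _version_as_tuple since this hunk does not show it:

    def _version_as_tuple(version_str):
      # Assumed implementation: keep the numeric dot-separated components.
      return tuple(int(part) for part in version_str.split(".") if part.isdigit())

    assert _version_as_tuple("0.3.7") >= _version_as_tuple("0.3.7")
    assert not _version_as_tuple("0.3.2") >= _version_as_tuple("0.3.7")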

View File

@@ -14,7 +14,6 @@
import itertools
import unittest
import numpy as np
@@ -109,8 +108,6 @@ class FftTest(jtu.JaxTestCase):
lax.fft(x, "FFT", fft_lengths=(10,)))
@parameterized.parameters((np.float32,), (np.float64,))
-@unittest.skipIf(jax._src.lib.xla_extension_version < 63,
-"Test fails for jaxlib <= 0.3.2")
def testLaxIrfftDoesNotMutateInputs(self, dtype):
if dtype == np.float64 and not config.x64_enabled:
raise self.skipTest("float64 requires jax_enable_x64=true")

View File

@@ -33,9 +33,6 @@ class RemoteTransferTest(jtu.JaxTestCase):
if jax.device_count() < 2:
raise unittest.SkipTest("Remote transfer requires at lest 2 devices")
dev_a, dev_b = jax.local_devices()[:2]
-if not hasattr(dev_a.client, "make_cross_host_receive_buffers"):
-# TODO(jheek) remove this once a new version of JAX lib is released
-raise unittest.SkipTest("jax-lib doesn't include cross host APIs")
if "libtpu" in jax.local_devices()[0].client.platform_version:
raise unittest.SkipTest("Test does not yet work on cloud TPU")
send_buf = jax.device_put(np.ones((32,)), dev_a)