Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.

Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
This commit is contained in:
Peter Hawkins 2023-03-27 13:29:59 -07:00 committed by jax authors
parent ae4f1fcb66
commit 6cc1bf54a1
39 changed files with 2705 additions and 2596 deletions

View File

@ -127,6 +127,7 @@ from jax._src.api import xla_computation as xla_computation
from jax.interpreters import ad # TODO(phawkins): update users to avoid this.
from jax.interpreters import pxla # TODO(phawkins): update users to avoid this.
from jax.interpreters import partial_eval # TODO(phawkins): update users to avoid this.
from jax.interpreters import xla # TODO(phawkins): update users to avoid this.
from jax._src.array import (

View File

@ -21,12 +21,10 @@ import types
import numpy as np
import jax
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import effects
from jax._src import source_info_util
@ -35,6 +33,8 @@ from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo
@ -737,7 +737,7 @@ def _optimization_barrier(arg):
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
partial(xla.apply_primitive, optimization_barrier_p))
partial(dispatch.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)

View File

@ -79,8 +79,8 @@ from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (custom_gradient, custom_jvp,
custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.config import (

View File

@ -26,8 +26,6 @@ import jax.tree_util as jtu
from jax import lax
from jax.api_util import flatten_fun
from jax.experimental import pjit
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten
@ -42,6 +40,8 @@ from jax._src import traceback_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache)

View File

@ -18,9 +18,6 @@ from typing import Callable, Optional
import jax
from jax import tree_util
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_map, tree_structure,
tree_unflatten, treedef_tuple)
from jax._src import core
@ -33,6 +30,9 @@ from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters.batching import not_mapped
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
source_info_util.register_exclusion(__file__)

View File

@ -21,8 +21,6 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves)
from jax.errors import UnexpectedTracerError
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.config import config
from jax._src import core
@ -38,6 +36,8 @@ from jax._src.core import raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable

View File

@ -15,9 +15,6 @@
import functools
from typing import Any, Callable, Optional, Tuple
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
from jax._src import ad_util
@ -29,6 +26,9 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
source_info_util.register_exclusion(__file__)

View File

@ -38,7 +38,7 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding
from jax._src.sharding_impls import GSPMDSharding, NamedSharding
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
# pytype: disable=import-error
try:

View File

@ -32,8 +32,6 @@ import numpy as np
import jax
from jax.monitoring import record_event_duration_secs
import jax.interpreters.mlir as mlir
import jax.interpreters.partial_eval as pe
from jax._src import array
from jax._src import core
@ -49,6 +47,8 @@ from jax._src import xla_bridge as xb
from jax._src.config import config, flags
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir

View File

@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
import jax
from jax._src import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
from jax.config import config
from jax.tree_util import (tree_flatten, tree_unflatten,
register_pytree_node, Partial)

View File

@ -35,7 +35,7 @@ from jax._src import linear_util as lu
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache)
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
Array = Any
map, unsafe_map = safe_map, map

View File

@ -32,9 +32,6 @@ import numpy as np
from jax._src import linear_util as lu
from jax.config import config
from jax._src.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src import ad_util
from jax._src import core
from jax._src import device_array
@ -43,6 +40,9 @@ from jax._src import effects as effects_lib
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.interpreters import ad
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo

File diff suppressed because it is too large Load Diff

View File

@ -48,7 +48,6 @@ import numpy as np
import jax
from jax.errors import JAXTypeError
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map
from jax._src import api_util
@ -71,6 +70,7 @@ from jax._src.config import flags
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc

View File

@ -78,6 +78,7 @@ from jax.interpreters import xla
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -362,7 +363,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension,
approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
approx_top_k_p.def_impl(partial(dispatch.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,

View File

@ -25,7 +25,7 @@ from jax._src import ad_util
from jax._src import util
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_map, tree_unflatten
map, unsafe_map = safe_map, map

View File

@ -22,14 +22,10 @@ import operator
from typing import Callable, Sequence, List, Tuple
from jax.config import config
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
@ -37,6 +33,11 @@ from jax._src import source_info_util
from jax._src import util
from jax._src import state
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, split_list, partition_list)
@ -803,7 +804,7 @@ def cond_bind(*args, branches, linear):
cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp

View File

@ -20,16 +20,16 @@ from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple,
import jax.numpy as jnp
from jax import lax
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
treedef_tuple, tree_map, tree_leaves, PyTreeDef)
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
@ -299,7 +299,7 @@ def _for_impl_unrolled(body, nsteps, unroll, *args):
return state
mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(functools.partial(xla.apply_primitive, for_p))
for_p.def_impl(functools.partial(dispatch.apply_primitive, for_p))
def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
jaxpr, nsteps, reverse, which_linear, unroll):

View File

@ -27,14 +27,15 @@ from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map, tree_flatten_with_path, keystr)
from jax._src.tree_util import equality_errors
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import source_info_util
@ -1032,7 +1033,7 @@ def scan_bind(*args, **params):
scan_p = core.AxisPrimitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose
@ -1612,7 +1613,7 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
while_p = core.AxisPrimitive('while')
while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p))
while_p.def_impl(partial(dispatch.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval

View File

@ -27,8 +27,6 @@ import numpy as np
import jax
from jax import tree_util
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_map
from jax._src import ad_util
@ -51,7 +49,9 @@ from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.interpreters.batching import ConcatAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
@ -2264,7 +2264,7 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
return core._pp_eqn(eqn.replace(params=params), context, settings)
convert_element_type_p = Primitive('convert_element_type')
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p))
convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
@ -3475,7 +3475,7 @@ def _reduce_named_shape_rule(*avals, computation, jaxpr, consts, dimensions):
reduce_p = core.Primitive('reduce')
reduce_p.multiple_results = True
reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
reduce_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
_reduce_dtype_rule, _reduce_weak_type_rule,
@ -3869,7 +3869,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys
sort_p = Primitive('sort')
sort_p.multiple_results = True
sort_p.def_impl(partial(xla.apply_primitive, sort_p))
sort_p.def_impl(partial(dispatch.apply_primitive, sort_p))
sort_p.def_abstract_eval(_sort_abstract_eval)
ad.primitive_jvps[sort_p] = _sort_jvp
batching.primitive_batchers[sort_p] = _sort_batch_rule
@ -3960,7 +3960,7 @@ def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
top_k_p = Primitive('top_k')
top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
def _top_k_lower(ctx, operand, k):
return chlo.TopKOp(operand, mlir.i64_attr(k)).results
@ -3993,7 +3993,7 @@ def create_token(_=None):
return create_token_p.bind()
create_token_p = Primitive("create_token")
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
create_token_p.def_impl(partial(dispatch.apply_primitive, create_token_p))
create_token_p.def_abstract_eval(lambda *_: abstract_token)
def _create_token_lowering(ctx, *operands):
@ -4015,7 +4015,7 @@ def _after_all_abstract_eval(*operands):
after_all_p = Primitive("after_all")
after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
after_all_p.def_impl(partial(dispatch.apply_primitive, after_all_p))
after_all_p.def_abstract_eval(_after_all_abstract_eval)
def _after_all_lowering(ctx, *operands):
@ -4060,7 +4060,7 @@ def _infeed_abstract_eval(token, *, shapes, partitions):
infeed_p = Primitive("infeed")
infeed_p.multiple_results = True
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
infeed_p.def_impl(partial(dispatch.apply_primitive, infeed_p))
infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
mlir.lowerable_effects.add_type(InOutFeedEffect)
@ -4111,7 +4111,7 @@ def _outfeed_abstract_eval(token, *xs, partitions):
return abstract_token, {outfeed_effect}
outfeed_p = Primitive("outfeed")
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
outfeed_p.def_impl(partial(dispatch.apply_primitive, outfeed_p))
outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
mlir.lowerable_effects.add_type(InOutFeedEffect)
@ -4153,7 +4153,7 @@ def _rng_uniform_abstract_eval(a, b, *, shape):
weak_type=(a.weak_type and b.weak_type))
rng_uniform_p = Primitive("rng_uniform")
rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p))
rng_uniform_p.def_impl(partial(dispatch.apply_primitive, rng_uniform_p))
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
def _rng_uniform_lowering(ctx, a, b, *, shape):
@ -4247,7 +4247,7 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
partial(xla.apply_primitive, rng_bit_generator_p))
partial(dispatch.apply_primitive, rng_bit_generator_p))
rng_bit_generator_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
@ -4298,7 +4298,7 @@ def _copy_impl(prim, *args, **kwargs):
if isinstance(a, jax.Array) and isinstance(a.sharding, PmapSharding):
sharded_dim = _which_dim_sharded(a.sharding)
return _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs)
return xla.apply_primitive(prim, *args, **kwargs)
return dispatch.apply_primitive(prim, *args, **kwargs)
# The copy_p primitive exists for expressing making copies of runtime arrays.
# For that reason we don't simplify it out of jaxprs (e.g. for jit invariance).
@ -4354,7 +4354,7 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
return core.DShapedArray(shape, dtype, False)
iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension):

View File

@ -23,10 +23,10 @@ import numpy as np
import jax
from jax import lax
from jax.interpreters import xla
from jax._src import ad_util
from jax._src import api
from jax._src import dispatch
from jax._src import dtypes
from jax._src.core import (
Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
@ -448,10 +448,12 @@ mlir.register_lowering(
# Asymmetric eigendecomposition
def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
return (
xla.apply_primitive(eig_p, operand,
compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors))
return dispatch.apply_primitive(
eig_p,
operand,
compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors,
)
def eig_lower(*args, **kw):
raise NotImplementedError(
@ -577,8 +579,8 @@ def eigh_jacobi(x: ArrayLike, *, lower: bool = True,
return w, v
def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues):
w, v = xla.apply_primitive(eigh_jacobi_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues)
w, v = dispatch.apply_primitive(eigh_jacobi_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues)
return w, v
def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
@ -634,8 +636,8 @@ mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule)
def _eigh_impl(operand, *, lower, sort_eigenvalues):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues)
v, w = dispatch.apply_primitive(eigh_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues)
return v, w
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
@ -1016,7 +1018,7 @@ def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
lu_pivots_to_permutation_p.multiple_results = False
lu_pivots_to_permutation_p.def_impl(
partial(xla.apply_primitive, lu_pivots_to_permutation_p))
partial(dispatch.apply_primitive, lu_pivots_to_permutation_p))
lu_pivots_to_permutation_p.def_abstract_eval(
_lu_pivots_to_permutation_abstract_eval)
batching.primitive_batchers[lu_pivots_to_permutation_p] = (
@ -1111,7 +1113,7 @@ def _lu_python(x):
return fn(x)
def _lu_impl(operand):
lu, pivot, perm = xla.apply_primitive(lu_p, operand)
lu, pivot, perm = dispatch.apply_primitive(lu_p, operand)
return lu, pivot, perm
def _lu_abstract_eval(operand):
@ -1385,7 +1387,7 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
geqrf_p = Primitive('geqrf')
geqrf_p.multiple_results = True
geqrf_p.def_impl(partial(xla.apply_primitive, geqrf_p))
geqrf_p.def_impl(partial(dispatch.apply_primitive, geqrf_p))
geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
mlir.register_lowering(geqrf_p, _geqrf_lowering_rule)
@ -1474,7 +1476,7 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
householder_product_p = Primitive('householder_product')
householder_product_p.def_impl(partial(xla.apply_primitive, householder_product_p))
householder_product_p.def_impl(partial(dispatch.apply_primitive, householder_product_p))
householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
mlir.register_lowering(householder_product_p, _householder_product_lowering_rule)
@ -1494,7 +1496,7 @@ mlir.register_lowering(
def _qr_impl(operand, *, full_matrices):
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
q, r = dispatch.apply_primitive(qr_p, operand, full_matrices=full_matrices)
return q, r
def _qr_abstract_eval(operand, *, full_matrices):
@ -1572,7 +1574,7 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering));
# Singular value decomposition
def _svd_impl(operand, *, full_matrices, compute_uv):
return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices,
compute_uv=compute_uv)
def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
@ -1762,7 +1764,7 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t
tridiagonal_solve_p = Primitive('tridiagonal_solve')
tridiagonal_solve_p.multiple_results = False
tridiagonal_solve_p.def_impl(
functools.partial(xla.apply_primitive, tridiagonal_solve_p))
functools.partial(dispatch.apply_primitive, tridiagonal_solve_p))
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
@ -1873,7 +1875,7 @@ def schur(x: ArrayLike, *,
def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals,
select_callable):
return xla.apply_primitive(
return dispatch.apply_primitive(
schur_p,
operand,
compute_schur_vectors=compute_schur_vectors,
@ -2000,7 +2002,7 @@ def _hessenberg_abstract_eval(a):
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
hessenberg_p = Primitive("hessenberg")
hessenberg_p.def_impl(partial(xla.apply_primitive, hessenberg_p))
hessenberg_p.def_impl(partial(dispatch.apply_primitive, hessenberg_p))
hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval)
hessenberg_p.multiple_results = True
@ -2098,7 +2100,7 @@ def _tridiagonal_abstract_eval(a, *, lower):
]
tridiagonal_p = Primitive("tridiagonal")
tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p))
tridiagonal_p.def_impl(partial(dispatch.apply_primitive, tridiagonal_p))
tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval)
tridiagonal_p.multiple_results = True

View File

@ -20,7 +20,6 @@ import weakref
import numpy as np
import jax
from jax.interpreters import partial_eval as pe
from jax._src import ad_util
from jax._src import core
@ -30,6 +29,7 @@ from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src.lax.utils import (
_argnum_weak_type,

View File

@ -23,6 +23,7 @@ from typing import Callable
from jax.interpreters import xla
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.util import safe_zip
from jax._src.lib import xla_client
@ -44,7 +45,7 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
weak_type_rule = weak_type_rule or _standard_weak_type_rule
named_shape_rule = named_shape_rule or standard_named_shape_rule
prim = core.Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_impl(partial(dispatch.apply_primitive, prim))
prim.def_abstract_eval(
partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
weak_type_rule, named_shape_rule))

View File

@ -39,7 +39,7 @@ from jax.errors import JAXTypeError
from jax._src.array import ArrayImpl
from jax._src.sharding_impls import NamedSharding
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax.interpreters import batching
@ -965,7 +965,7 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
# This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
from jax.interpreters.partial_eval import (
from jax._src.interpreters.partial_eval import (
trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
convert_constvars_jaxpr, new_jaxpr_eqn)
assert primitive is xmap_p

View File

@ -28,9 +28,9 @@ import jax
from jax._src import core
from jax import stages
from jax.errors import JAXTypeError
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters.pxla import PartitionSpec
from jax._src.interpreters import xla
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
treedef_tuple, broadcast_prefix, all_leaves)

View File

@ -21,10 +21,10 @@ from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
import numpy as np
from jax import lax
from jax.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.state.types import AbstractRef
from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.util import safe_map, safe_zip, split_list

View File

@ -19,13 +19,13 @@ from typing import Any, List, Tuple, Union
import numpy as np
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax._src import ad_util
from jax._src import core
from jax._src import pretty_printer as pp
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.typing import Array
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
AccumEffect)

View File

@ -19,11 +19,11 @@ from jax import tree_util
from jax._src import linear_util as lu
from jax.experimental import pjit
from jax.errors import UnexpectedTracerError
from jax._src import mesh as mesh_lib
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import ir
import jax.interpreters.pxla as pxla
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
from jax._src import custom_api_util
from jax._src.lib import xla_client as xc
from jax._src.api_util import flatten_fun_nokwargs
@ -370,7 +370,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
propagate_user_sharding, partition,
infer_sharding_from_operands,
static_args):
mesh = pxla.thread_resources.env.physical_mesh
mesh = mesh_lib.thread_resources.env.physical_mesh
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, mlir.ShardingContext):

View File

@ -511,9 +511,9 @@ from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
from jax.experimental import pjit
from jax.interpreters import ad, batching, pxla
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import ad, batching, pxla
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import dispatch

View File

@ -61,7 +61,6 @@ import numpy as np
from jax import lax
import jax.numpy as jnp
from jax.experimental import pjit
from jax.interpreters import partial_eval as pe
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten,)
@ -70,6 +69,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src.api_util import shaped_abstractify
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, weakref_lru_cache

View File

@ -43,10 +43,10 @@ from jax._src.util import (HashableFunction, HashablePartial, unzip2,
as_hashable_function, memoize, partition_list,
merge_lists)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax.interpreters import batching
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax.interpreters import ad
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,

View File

@ -36,7 +36,6 @@ from jax.experimental.sparse.util import (
SparseEfficiencyError, SparseEfficiencyWarning, Shape,
SparseInfo)
from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
import jax.numpy as jnp
from jax.util import safe_zip, unzip2, split_list
@ -45,6 +44,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.lax import (
_const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers)
from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode

View File

@ -62,9 +62,9 @@ from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_spar
import jax.numpy as jnp
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.lib import pytree
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import safe_map, safe_zip, split_list
from jax._src.config import config

File diff suppressed because it is too large Load Diff

View File

@ -38,6 +38,7 @@ per-file-ignores =
jax/interpreters/ad.py:F401
jax/interpreters/batching.py:F401
jax/interpreters/mlir.py:F401
jax/interpreters/partial_eval.py:F401
jax/interpreters/pxla.py:F401
jax/interpreters/xla.py:F401
jax/linear_util.py:F401

View File

@ -55,7 +55,7 @@ from jax.interpreters import ad
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
from jax.sharding import PartitionSpec as P
from jax._src import array
from jax.experimental import pjit

View File

@ -29,7 +29,6 @@ from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
tree_leaves)
@ -38,6 +37,7 @@ from jax._src import linear_util as lu
from jax._src import util
from jax._src import test_util as jtu
from jax._src.core import UnshapedArray, ShapedArray, DBIdx
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow

View File

@ -26,8 +26,8 @@ from jax._src import linear_util as lu
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import ad
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src import ad_checkpoint
from jax._src import dispatch

View File

@ -24,7 +24,7 @@ from jax._src import core
from jax import lax
from jax._src import linear_util as lu
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import partial_eval as pe
from jax._src import test_util as jtu
from jax._src.util import tuple_insert
import jax.numpy as jnp