Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.

Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
This commit is contained in:
Peter Hawkins 2023-03-27 13:29:59 -07:00 committed by jax authors
parent ae4f1fcb66
commit 6cc1bf54a1
39 changed files with 2705 additions and 2596 deletions

View File

@ -127,6 +127,7 @@ from jax._src.api import xla_computation as xla_computation
from jax.interpreters import ad # TODO(phawkins): update users to avoid this. from jax.interpreters import ad # TODO(phawkins): update users to avoid this.
from jax.interpreters import pxla # TODO(phawkins): update users to avoid this. from jax.interpreters import pxla # TODO(phawkins): update users to avoid this.
from jax.interpreters import partial_eval # TODO(phawkins): update users to avoid this.
from jax.interpreters import xla # TODO(phawkins): update users to avoid this. from jax.interpreters import xla # TODO(phawkins): update users to avoid this.
from jax._src.array import ( from jax._src.array import (

View File

@ -21,12 +21,10 @@ import types
import numpy as np import numpy as np
import jax import jax
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src import ad_util from jax._src import ad_util
from jax._src import core from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src import effects from jax._src import effects
from jax._src import source_info_util from jax._src import source_info_util
@ -35,6 +33,8 @@ from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
@ -737,7 +737,7 @@ def _optimization_barrier(arg):
optimization_barrier_p = core.Primitive('optimization_barrier') optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl( optimization_barrier_p.def_impl(
partial(xla.apply_primitive, optimization_barrier_p)) partial(dispatch.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p, mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule) _optimization_barrier_lowering_rule)

View File

@ -79,8 +79,8 @@ from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (custom_gradient, custom_jvp, from jax.custom_derivatives import (custom_gradient, custom_jvp,
custom_vjp, linear_call) custom_vjp, linear_call)
from jax.custom_transpose import custom_transpose from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import xla from jax._src.interpreters import xla
from jax._src.config import ( from jax._src.config import (

View File

@ -26,8 +26,6 @@ import jax.tree_util as jtu
from jax import lax from jax import lax
from jax.api_util import flatten_fun from jax.api_util import flatten_fun
from jax.experimental import pjit from jax.experimental import pjit
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten from jax.tree_util import tree_flatten
from jax.tree_util import tree_map from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten from jax.tree_util import tree_unflatten
@ -42,6 +40,8 @@ from jax._src import traceback_util
from jax._src.config import config from jax._src.config import config
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.typing import Array from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache) unzip3, weakref_lru_cache)

View File

@ -18,9 +18,6 @@ from typing import Callable, Optional
import jax import jax
from jax import tree_util from jax import tree_util
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_map, tree_structure, from jax.tree_util import (tree_flatten, tree_map, tree_structure,
tree_unflatten, treedef_tuple) tree_unflatten, treedef_tuple)
from jax._src import core from jax._src import core
@ -33,6 +30,9 @@ from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters.batching import not_mapped from jax._src.interpreters.batching import not_mapped
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
source_info_util.register_exclusion(__file__) source_info_util.register_exclusion(__file__)

View File

@ -21,8 +21,6 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple, treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves) register_pytree_node_class, tree_leaves)
from jax.errors import UnexpectedTracerError from jax.errors import UnexpectedTracerError
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.config import config from jax.config import config
from jax._src import core from jax._src import core
@ -38,6 +36,8 @@ from jax._src.core import raise_to_shaped
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters.batching import not_mapped from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable

View File

@ -15,9 +15,6 @@
import functools import functools
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_map, from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten) tree_structure, treedef_tuple, tree_unflatten)
from jax._src import ad_util from jax._src import ad_util
@ -29,6 +26,9 @@ from jax._src import source_info_util
from jax._src import traceback_util from jax._src import traceback_util
from jax._src import util from jax._src import util
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
source_info_util.register_exclusion(__file__) source_info_util.register_exclusion(__file__)

View File

@ -38,7 +38,7 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding from jax._src.sharding import Sharding
from jax._src.sharding_impls import GSPMDSharding, NamedSharding from jax._src.sharding_impls import GSPMDSharding, NamedSharding
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
# pytype: disable=import-error # pytype: disable=import-error
try: try:

View File

@ -32,8 +32,6 @@ import numpy as np
import jax import jax
from jax.monitoring import record_event_duration_secs from jax.monitoring import record_event_duration_secs
import jax.interpreters.mlir as mlir
import jax.interpreters.partial_eval as pe
from jax._src import array from jax._src import array
from jax._src import core from jax._src import core
@ -49,6 +47,8 @@ from jax._src import xla_bridge as xb
from jax._src.config import config, flags from jax._src.config import config, flags
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla from jax._src.interpreters import xla
from jax._src.interpreters import pxla from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir

View File

@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
import jax import jax
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.config import config from jax.config import config
from jax.tree_util import (tree_flatten, tree_unflatten, from jax.tree_util import (tree_flatten, tree_unflatten,
register_pytree_node, Partial) register_pytree_node, Partial)

View File

@ -35,7 +35,7 @@ from jax._src import linear_util as lu
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function, canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache) curry, memoize, weakref_lru_cache)
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
Array = Any Array = Any
map, unsafe_map = safe_map, map map, unsafe_map = safe_map, map

View File

@ -32,9 +32,6 @@ import numpy as np
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax.config import config from jax.config import config
from jax._src.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src import ad_util from jax._src import ad_util
from jax._src import core from jax._src import core
from jax._src import device_array from jax._src import device_array
@ -43,6 +40,9 @@ from jax._src import effects as effects_lib
from jax._src import source_info_util from jax._src import source_info_util
from jax._src import util from jax._src import util
from jax._src import xla_bridge as xb from jax._src import xla_bridge as xb
from jax._src.interpreters import ad
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo

File diff suppressed because it is too large Load Diff

View File

@ -48,7 +48,6 @@ import numpy as np
import jax import jax
from jax.errors import JAXTypeError from jax.errors import JAXTypeError
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map from jax.tree_util import tree_flatten, tree_map
from jax._src import api_util from jax._src import api_util
@ -71,6 +70,7 @@ from jax._src.config import flags
from jax._src.core import ShapedArray from jax._src.core import ShapedArray
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import xla from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc

View File

@ -78,6 +78,7 @@ from jax.interpreters import xla
from jax._src import ad_util from jax._src import ad_util
from jax._src import core from jax._src import core
from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
@ -362,7 +363,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension,
approx_top_k_p = core.Primitive('approx_top_k') approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p)) approx_top_k_p.def_impl(partial(dispatch.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval) approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation) xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation, xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,

View File

@ -25,7 +25,7 @@ from jax._src import ad_util
from jax._src import util from jax._src import util
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3 from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
from jax.api_util import flatten_fun_nokwargs from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.tree_util import tree_map, tree_unflatten from jax.tree_util import tree_map, tree_unflatten
map, unsafe_map = safe_map, map map, unsafe_map = safe_map, map

View File

@ -22,14 +22,10 @@ import operator
from typing import Callable, Sequence, List, Tuple from typing import Callable, Sequence, List, Tuple
from jax.config import config from jax.config import config
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util from jax._src import ad_util
from jax._src import core from jax._src import core
from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
from jax._src import effects from jax._src import effects
from jax._src import linear_util as lu from jax._src import linear_util as lu
@ -37,6 +33,11 @@ from jax._src import source_info_util
from jax._src import util from jax._src import util
from jax._src import state from jax._src import state
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lax import lax from jax._src.lax import lax
from jax._src.traceback_util import api_boundary from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, split_list, partition_list) from jax._src.util import (safe_map, split_list, partition_list)
@ -803,7 +804,7 @@ def cond_bind(*args, branches, linear):
cond_p = core.AxisPrimitive('cond') cond_p = core.AxisPrimitive('cond')
cond_p.multiple_results = True cond_p.multiple_results = True
cond_p.def_impl(partial(xla.apply_primitive, cond_p)) cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval) cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
cond_p.def_custom_bind(cond_bind) cond_p.def_custom_bind(cond_bind)
ad.primitive_jvps[cond_p] = _cond_jvp ad.primitive_jvps[cond_p] = _cond_jvp

View File

@ -20,16 +20,16 @@ from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple,
import jax.numpy as jnp import jax.numpy as jnp
from jax import lax from jax import lax
from jax.api_util import flatten_fun_nokwargs from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import ad from jax._src.interpreters import ad
from jax.interpreters import batching from jax._src.interpreters import batching
from jax.interpreters import mlir from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten, from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
treedef_tuple, tree_map, tree_leaves, PyTreeDef) treedef_tuple, tree_map, tree_leaves, PyTreeDef)
from jax._src import ad_util from jax._src import ad_util
from jax._src import core from jax._src import core
from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src import source_info_util from jax._src import source_info_util
@ -299,7 +299,7 @@ def _for_impl_unrolled(body, nsteps, unroll, *args):
return state return state
mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True)) mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(functools.partial(xla.apply_primitive, for_p)) for_p.def_impl(functools.partial(dispatch.apply_primitive, for_p))
def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
jaxpr, nsteps, reverse, which_linear, unroll): jaxpr, nsteps, reverse, which_linear, unroll):

View File

@ -27,14 +27,15 @@ from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import ad from jax.interpreters import ad
from jax.interpreters import batching from jax.interpreters import batching
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.interpreters import xla from jax._src.interpreters import xla
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf, from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map, tree_flatten_with_path, keystr) tree_map, tree_flatten_with_path, keystr)
from jax._src.tree_util import equality_errors from jax._src.tree_util import equality_errors
from jax._src import ad_checkpoint from jax._src import ad_checkpoint
from jax._src import ad_util from jax._src import ad_util
from jax._src import api from jax._src import api
from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
from jax._src import effects from jax._src import effects
from jax._src import source_info_util from jax._src import source_info_util
@ -1032,7 +1033,7 @@ def scan_bind(*args, **params):
scan_p = core.AxisPrimitive("scan") scan_p = core.AxisPrimitive("scan")
scan_p.multiple_results = True scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind) scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(xla.apply_primitive, scan_p)) scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
scan_p.def_effectful_abstract_eval(_scan_abstract_eval) scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp ad.primitive_jvps[scan_p] = _scan_jvp
ad.reducing_transposes[scan_p] = _scan_transpose ad.reducing_transposes[scan_p] = _scan_transpose
@ -1612,7 +1613,7 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
while_p = core.AxisPrimitive('while') while_p = core.AxisPrimitive('while')
while_p.multiple_results = True while_p.multiple_results = True
while_p.def_impl(partial(xla.apply_primitive, while_p)) while_p.def_impl(partial(dispatch.apply_primitive, while_p))
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval) while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
ad.primitive_jvps[while_p] = _while_loop_jvp ad.primitive_jvps[while_p] = _while_loop_jvp
pe.custom_partial_eval_rules[while_p] = _while_partial_eval pe.custom_partial_eval_rules[while_p] = _while_partial_eval

View File

@ -27,8 +27,6 @@ import numpy as np
import jax import jax
from jax import tree_util from jax import tree_util
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_map from jax.tree_util import tree_map
from jax._src import ad_util from jax._src import ad_util
@ -51,7 +49,9 @@ from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.interpreters.batching import ConcatAxis from jax._src.interpreters.batching import ConcatAxis
from jax._src.lax import slicing from jax._src.lax import slicing
from jax._src.lax.utils import ( from jax._src.lax.utils import (
@ -2264,7 +2264,7 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
return core._pp_eqn(eqn.replace(params=params), context, settings) return core._pp_eqn(eqn.replace(params=params), context, settings)
convert_element_type_p = Primitive('convert_element_type') convert_element_type_p = Primitive('convert_element_type')
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p)) convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
convert_element_type_p.def_abstract_eval( convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p, partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule, _convert_element_type_shape_rule, _convert_element_type_dtype_rule,
@ -3475,7 +3475,7 @@ def _reduce_named_shape_rule(*avals, computation, jaxpr, consts, dimensions):
reduce_p = core.Primitive('reduce') reduce_p = core.Primitive('reduce')
reduce_p.multiple_results = True reduce_p.multiple_results = True
reduce_p.def_impl(partial(xla.apply_primitive, reduce_p)) reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
reduce_p.def_abstract_eval( reduce_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule, partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
_reduce_dtype_rule, _reduce_weak_type_rule, _reduce_dtype_rule, _reduce_weak_type_rule,
@ -3869,7 +3869,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys
sort_p = Primitive('sort') sort_p = Primitive('sort')
sort_p.multiple_results = True sort_p.multiple_results = True
sort_p.def_impl(partial(xla.apply_primitive, sort_p)) sort_p.def_impl(partial(dispatch.apply_primitive, sort_p))
sort_p.def_abstract_eval(_sort_abstract_eval) sort_p.def_abstract_eval(_sort_abstract_eval)
ad.primitive_jvps[sort_p] = _sort_jvp ad.primitive_jvps[sort_p] = _sort_jvp
batching.primitive_batchers[sort_p] = _sort_batch_rule batching.primitive_batchers[sort_p] = _sort_batch_rule
@ -3960,7 +3960,7 @@ def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
top_k_p = Primitive('top_k') top_k_p = Primitive('top_k')
top_k_p.multiple_results = True top_k_p.multiple_results = True
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p)) top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval) top_k_p.def_abstract_eval(_top_k_abstract_eval)
def _top_k_lower(ctx, operand, k): def _top_k_lower(ctx, operand, k):
return chlo.TopKOp(operand, mlir.i64_attr(k)).results return chlo.TopKOp(operand, mlir.i64_attr(k)).results
@ -3993,7 +3993,7 @@ def create_token(_=None):
return create_token_p.bind() return create_token_p.bind()
create_token_p = Primitive("create_token") create_token_p = Primitive("create_token")
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p)) create_token_p.def_impl(partial(dispatch.apply_primitive, create_token_p))
create_token_p.def_abstract_eval(lambda *_: abstract_token) create_token_p.def_abstract_eval(lambda *_: abstract_token)
def _create_token_lowering(ctx, *operands): def _create_token_lowering(ctx, *operands):
@ -4015,7 +4015,7 @@ def _after_all_abstract_eval(*operands):
after_all_p = Primitive("after_all") after_all_p = Primitive("after_all")
after_all_p.def_impl(partial(xla.apply_primitive, after_all_p)) after_all_p.def_impl(partial(dispatch.apply_primitive, after_all_p))
after_all_p.def_abstract_eval(_after_all_abstract_eval) after_all_p.def_abstract_eval(_after_all_abstract_eval)
def _after_all_lowering(ctx, *operands): def _after_all_lowering(ctx, *operands):
@ -4060,7 +4060,7 @@ def _infeed_abstract_eval(token, *, shapes, partitions):
infeed_p = Primitive("infeed") infeed_p = Primitive("infeed")
infeed_p.multiple_results = True infeed_p.multiple_results = True
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p)) infeed_p.def_impl(partial(dispatch.apply_primitive, infeed_p))
infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval) infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
mlir.lowerable_effects.add_type(InOutFeedEffect) mlir.lowerable_effects.add_type(InOutFeedEffect)
@ -4111,7 +4111,7 @@ def _outfeed_abstract_eval(token, *xs, partitions):
return abstract_token, {outfeed_effect} return abstract_token, {outfeed_effect}
outfeed_p = Primitive("outfeed") outfeed_p = Primitive("outfeed")
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p)) outfeed_p.def_impl(partial(dispatch.apply_primitive, outfeed_p))
outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval) outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
mlir.lowerable_effects.add_type(InOutFeedEffect) mlir.lowerable_effects.add_type(InOutFeedEffect)
@ -4153,7 +4153,7 @@ def _rng_uniform_abstract_eval(a, b, *, shape):
weak_type=(a.weak_type and b.weak_type)) weak_type=(a.weak_type and b.weak_type))
rng_uniform_p = Primitive("rng_uniform") rng_uniform_p = Primitive("rng_uniform")
rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p)) rng_uniform_p.def_impl(partial(dispatch.apply_primitive, rng_uniform_p))
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval) rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
def _rng_uniform_lowering(ctx, a, b, *, shape): def _rng_uniform_lowering(ctx, a, b, *, shape):
@ -4247,7 +4247,7 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
rng_bit_generator_p = Primitive("rng_bit_generator") rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl( rng_bit_generator_p.def_impl(
partial(xla.apply_primitive, rng_bit_generator_p)) partial(dispatch.apply_primitive, rng_bit_generator_p))
rng_bit_generator_p.def_abstract_eval( rng_bit_generator_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, rng_bit_generator_p, partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule, _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
@ -4298,7 +4298,7 @@ def _copy_impl(prim, *args, **kwargs):
if isinstance(a, jax.Array) and isinstance(a.sharding, PmapSharding): if isinstance(a, jax.Array) and isinstance(a.sharding, PmapSharding):
sharded_dim = _which_dim_sharded(a.sharding) sharded_dim = _which_dim_sharded(a.sharding)
return _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs) return _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs)
return xla.apply_primitive(prim, *args, **kwargs) return dispatch.apply_primitive(prim, *args, **kwargs)
# The copy_p primitive exists for expressing making copies of runtime arrays. # The copy_p primitive exists for expressing making copies of runtime arrays.
# For that reason we don't simplify it out of jaxprs (e.g. for jit invariance). # For that reason we don't simplify it out of jaxprs (e.g. for jit invariance).
@ -4354,7 +4354,7 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
return core.DShapedArray(shape, dtype, False) return core.DShapedArray(shape, dtype, False)
iota_p = Primitive('iota') iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p)) iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval) iota_p.def_abstract_eval(_iota_abstract_eval)
def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension): def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension):

View File

@ -23,10 +23,10 @@ import numpy as np
import jax import jax
from jax import lax from jax import lax
from jax.interpreters import xla
from jax._src import ad_util from jax._src import ad_util
from jax._src import api from jax._src import api
from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
from jax._src.core import ( from jax._src.core import (
Primitive, ShapedArray, raise_to_shaped, is_constant_shape) Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
@ -448,10 +448,12 @@ mlir.register_lowering(
# Asymmetric eigendecomposition # Asymmetric eigendecomposition
def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
return ( return dispatch.apply_primitive(
xla.apply_primitive(eig_p, operand, eig_p,
compute_left_eigenvectors=compute_left_eigenvectors, operand,
compute_right_eigenvectors=compute_right_eigenvectors)) compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors,
)
def eig_lower(*args, **kw): def eig_lower(*args, **kw):
raise NotImplementedError( raise NotImplementedError(
@ -577,8 +579,8 @@ def eigh_jacobi(x: ArrayLike, *, lower: bool = True,
return w, v return w, v
def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues): def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues):
w, v = xla.apply_primitive(eigh_jacobi_p, operand, lower=lower, w, v = dispatch.apply_primitive(eigh_jacobi_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues) sort_eigenvalues=sort_eigenvalues)
return w, v return w, v
def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues): def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
@ -634,8 +636,8 @@ mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule)
def _eigh_impl(operand, *, lower, sort_eigenvalues): def _eigh_impl(operand, *, lower, sort_eigenvalues):
v, w = xla.apply_primitive(eigh_p, operand, lower=lower, v, w = dispatch.apply_primitive(eigh_p, operand, lower=lower,
sort_eigenvalues=sort_eigenvalues) sort_eigenvalues=sort_eigenvalues)
return v, w return v, w
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues): def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
@ -1016,7 +1018,7 @@ def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation') lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
lu_pivots_to_permutation_p.multiple_results = False lu_pivots_to_permutation_p.multiple_results = False
lu_pivots_to_permutation_p.def_impl( lu_pivots_to_permutation_p.def_impl(
partial(xla.apply_primitive, lu_pivots_to_permutation_p)) partial(dispatch.apply_primitive, lu_pivots_to_permutation_p))
lu_pivots_to_permutation_p.def_abstract_eval( lu_pivots_to_permutation_p.def_abstract_eval(
_lu_pivots_to_permutation_abstract_eval) _lu_pivots_to_permutation_abstract_eval)
batching.primitive_batchers[lu_pivots_to_permutation_p] = ( batching.primitive_batchers[lu_pivots_to_permutation_p] = (
@ -1111,7 +1113,7 @@ def _lu_python(x):
return fn(x) return fn(x)
def _lu_impl(operand): def _lu_impl(operand):
lu, pivot, perm = xla.apply_primitive(lu_p, operand) lu, pivot, perm = dispatch.apply_primitive(lu_p, operand)
return lu, pivot, perm return lu, pivot, perm
def _lu_abstract_eval(operand): def _lu_abstract_eval(operand):
@ -1385,7 +1387,7 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
geqrf_p = Primitive('geqrf') geqrf_p = Primitive('geqrf')
geqrf_p.multiple_results = True geqrf_p.multiple_results = True
geqrf_p.def_impl(partial(xla.apply_primitive, geqrf_p)) geqrf_p.def_impl(partial(dispatch.apply_primitive, geqrf_p))
geqrf_p.def_abstract_eval(_geqrf_abstract_eval) geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
mlir.register_lowering(geqrf_p, _geqrf_lowering_rule) mlir.register_lowering(geqrf_p, _geqrf_lowering_rule)
@ -1474,7 +1476,7 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
householder_product_p = Primitive('householder_product') householder_product_p = Primitive('householder_product')
householder_product_p.def_impl(partial(xla.apply_primitive, householder_product_p)) householder_product_p.def_impl(partial(dispatch.apply_primitive, householder_product_p))
householder_product_p.def_abstract_eval(_householder_product_abstract_eval) householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
mlir.register_lowering(householder_product_p, _householder_product_lowering_rule) mlir.register_lowering(householder_product_p, _householder_product_lowering_rule)
@ -1494,7 +1496,7 @@ mlir.register_lowering(
def _qr_impl(operand, *, full_matrices): def _qr_impl(operand, *, full_matrices):
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices) q, r = dispatch.apply_primitive(qr_p, operand, full_matrices=full_matrices)
return q, r return q, r
def _qr_abstract_eval(operand, *, full_matrices): def _qr_abstract_eval(operand, *, full_matrices):
@ -1572,7 +1574,7 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering));
# Singular value decomposition # Singular value decomposition
def _svd_impl(operand, *, full_matrices, compute_uv): def _svd_impl(operand, *, full_matrices, compute_uv):
return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices, return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices,
compute_uv=compute_uv) compute_uv=compute_uv)
def _svd_abstract_eval(operand, *, full_matrices, compute_uv): def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
@ -1762,7 +1764,7 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t
tridiagonal_solve_p = Primitive('tridiagonal_solve') tridiagonal_solve_p = Primitive('tridiagonal_solve')
tridiagonal_solve_p.multiple_results = False tridiagonal_solve_p.multiple_results = False
tridiagonal_solve_p.def_impl( tridiagonal_solve_p.def_impl(
functools.partial(xla.apply_primitive, tridiagonal_solve_p)) functools.partial(dispatch.apply_primitive, tridiagonal_solve_p))
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b) tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve? # TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
@ -1873,7 +1875,7 @@ def schur(x: ArrayLike, *,
def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals, def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals,
select_callable): select_callable):
return xla.apply_primitive( return dispatch.apply_primitive(
schur_p, schur_p,
operand, operand,
compute_schur_vectors=compute_schur_vectors, compute_schur_vectors=compute_schur_vectors,
@ -2000,7 +2002,7 @@ def _hessenberg_abstract_eval(a):
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)] return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
hessenberg_p = Primitive("hessenberg") hessenberg_p = Primitive("hessenberg")
hessenberg_p.def_impl(partial(xla.apply_primitive, hessenberg_p)) hessenberg_p.def_impl(partial(dispatch.apply_primitive, hessenberg_p))
hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval) hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval)
hessenberg_p.multiple_results = True hessenberg_p.multiple_results = True
@ -2098,7 +2100,7 @@ def _tridiagonal_abstract_eval(a, *, lower):
] ]
tridiagonal_p = Primitive("tridiagonal") tridiagonal_p = Primitive("tridiagonal")
tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p)) tridiagonal_p.def_impl(partial(dispatch.apply_primitive, tridiagonal_p))
tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval) tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval)
tridiagonal_p.multiple_results = True tridiagonal_p.multiple_results = True

View File

@ -20,7 +20,6 @@ import weakref
import numpy as np import numpy as np
import jax import jax
from jax.interpreters import partial_eval as pe
from jax._src import ad_util from jax._src import ad_util
from jax._src import core from jax._src import core
@ -30,6 +29,7 @@ from jax._src import util
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax from jax._src.lax import lax
from jax._src.lax.utils import ( from jax._src.lax.utils import (
_argnum_weak_type, _argnum_weak_type,

View File

@ -23,6 +23,7 @@ from typing import Callable
from jax.interpreters import xla from jax.interpreters import xla
from jax._src import core from jax._src import core
from jax._src import dispatch
from jax._src import dtypes from jax._src import dtypes
from jax._src.util import safe_zip from jax._src.util import safe_zip
from jax._src.lib import xla_client from jax._src.lib import xla_client
@ -44,7 +45,7 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
weak_type_rule = weak_type_rule or _standard_weak_type_rule weak_type_rule = weak_type_rule or _standard_weak_type_rule
named_shape_rule = named_shape_rule or standard_named_shape_rule named_shape_rule = named_shape_rule or standard_named_shape_rule
prim = core.Primitive(name) prim = core.Primitive(name)
prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_impl(partial(dispatch.apply_primitive, prim))
prim.def_abstract_eval( prim.def_abstract_eval(
partial(standard_abstract_eval, prim, shape_rule, dtype_rule, partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
weak_type_rule, named_shape_rule)) weak_type_rule, named_shape_rule))

View File

@ -39,7 +39,7 @@ from jax.errors import JAXTypeError
from jax._src.array import ArrayImpl from jax._src.array import ArrayImpl
from jax._src.sharding_impls import NamedSharding from jax._src.sharding_impls import NamedSharding
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla from jax._src.interpreters import pxla
from jax._src.interpreters import xla from jax._src.interpreters import xla
from jax.interpreters import batching from jax.interpreters import batching
@ -965,7 +965,7 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
# This is DynamicJaxprTrace.process_map with some very minor modifications # This is DynamicJaxprTrace.process_map with some very minor modifications
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params): def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
from jax.interpreters.partial_eval import ( from jax._src.interpreters.partial_eval import (
trace_to_subjaxpr_dynamic, DynamicJaxprTracer, trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
convert_constvars_jaxpr, new_jaxpr_eqn) convert_constvars_jaxpr, new_jaxpr_eqn)
assert primitive is xmap_p assert primitive is xmap_p

View File

@ -28,9 +28,9 @@ import jax
from jax._src import core from jax._src import core
from jax import stages from jax import stages
from jax.errors import JAXTypeError from jax.errors import JAXTypeError
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters.pxla import PartitionSpec from jax._src.interpreters.pxla import PartitionSpec
from jax._src.interpreters import xla
from jax._src.tree_util import ( from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
treedef_tuple, broadcast_prefix, all_leaves) treedef_tuple, broadcast_prefix, all_leaves)

View File

@ -21,10 +21,10 @@ from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
import numpy as np import numpy as np
from jax import lax from jax import lax
from jax.interpreters import partial_eval as pe
from jax._src import core from jax._src import core
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.state.types import AbstractRef from jax._src.state.types import AbstractRef
from jax._src.state.primitives import get_p, swap_p, addupdate_p from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.util import safe_map, safe_zip, split_list from jax._src.util import safe_map, safe_zip, split_list

View File

@ -19,13 +19,13 @@ from typing import Any, List, Tuple, Union
import numpy as np import numpy as np
from jax import lax from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax._src import ad_util from jax._src import ad_util
from jax._src import core from jax._src import core
from jax._src import pretty_printer as pp from jax._src import pretty_printer as pp
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.typing import Array from jax._src.typing import Array
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect, from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
AccumEffect) AccumEffect)

View File

@ -19,11 +19,11 @@ from jax import tree_util
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax.experimental import pjit from jax.experimental import pjit
from jax.errors import UnexpectedTracerError from jax.errors import UnexpectedTracerError
from jax._src import mesh as mesh_lib
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
import jax.interpreters.pxla as pxla
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src import custom_api_util from jax._src import custom_api_util
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.api_util import flatten_fun_nokwargs from jax._src.api_util import flatten_fun_nokwargs
@ -370,7 +370,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
propagate_user_sharding, partition, propagate_user_sharding, partition,
infer_sharding_from_operands, infer_sharding_from_operands,
static_args): static_args):
mesh = pxla.thread_resources.env.physical_mesh mesh = mesh_lib.thread_resources.env.physical_mesh
axis_context = ctx.module_context.axis_context axis_context = ctx.module_context.axis_context
if isinstance(axis_context, mlir.ShardingContext): if isinstance(axis_context, mlir.ShardingContext):

View File

@ -511,9 +511,9 @@ from jax import custom_derivatives
from jax._src import dtypes from jax._src import dtypes
from jax import lax from jax import lax
from jax.experimental import pjit from jax.experimental import pjit
from jax.interpreters import ad, batching, pxla from jax._src.interpreters import ad, batching, pxla
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla from jax._src.interpreters import xla
from jax._src import ad_checkpoint from jax._src import ad_checkpoint
from jax._src import dispatch from jax._src import dispatch

View File

@ -61,7 +61,6 @@ import numpy as np
from jax import lax from jax import lax
import jax.numpy as jnp import jax.numpy as jnp
from jax.experimental import pjit from jax.experimental import pjit
from jax.interpreters import partial_eval as pe
from jax.tree_util import (register_pytree_node, tree_structure, from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten,) treedef_is_leaf, tree_flatten, tree_unflatten,)
@ -70,6 +69,7 @@ from jax._src import core
from jax._src import dispatch from jax._src import dispatch
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax._src.api_util import shaped_abstractify from jax._src.api_util import shaped_abstractify
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, weakref_lru_cache from jax._src.util import unzip2, weakref_lru_cache

View File

@ -43,10 +43,10 @@ from jax._src.util import (HashableFunction, HashablePartial, unzip2,
as_hashable_function, memoize, partition_list, as_hashable_function, memoize, partition_list,
merge_lists) merge_lists)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.interpreters import xla from jax._src.interpreters import xla
from jax._src.interpreters import pxla from jax._src.interpreters import pxla
from jax.interpreters import ad from jax.interpreters import ad
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,

View File

@ -36,7 +36,6 @@ from jax.experimental.sparse.util import (
SparseEfficiencyError, SparseEfficiencyWarning, Shape, SparseEfficiencyError, SparseEfficiencyWarning, Shape,
SparseInfo) SparseInfo)
from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
import jax.numpy as jnp import jax.numpy as jnp
from jax.util import safe_zip, unzip2, split_list from jax.util import safe_zip, unzip2, split_list
@ -45,6 +44,7 @@ from jax._src import core
from jax._src import dispatch from jax._src import dispatch
from jax._src.interpreters import ad from jax._src.interpreters import ad
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.lax import ( from jax._src.lax.lax import (
_const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers) _const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers)
from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode

View File

@ -62,9 +62,9 @@ from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_spar
import jax.numpy as jnp import jax.numpy as jnp
from jax._src.api_util import flatten_fun_nokwargs from jax._src.api_util import flatten_fun_nokwargs
from jax._src.lib import pytree from jax._src.lib import pytree
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.interpreters import xla from jax._src.interpreters import xla
from jax.interpreters import pxla from jax._src.interpreters import pxla
from jax.tree_util import tree_flatten, tree_map, tree_unflatten from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import safe_map, safe_zip, split_list from jax.util import safe_map, safe_zip, split_list
from jax._src.config import config from jax._src.config import config

File diff suppressed because it is too large Load Diff

View File

@ -38,6 +38,7 @@ per-file-ignores =
jax/interpreters/ad.py:F401 jax/interpreters/ad.py:F401
jax/interpreters/batching.py:F401 jax/interpreters/batching.py:F401
jax/interpreters/mlir.py:F401 jax/interpreters/mlir.py:F401
jax/interpreters/partial_eval.py:F401
jax/interpreters/pxla.py:F401 jax/interpreters/pxla.py:F401
jax/interpreters/xla.py:F401 jax/interpreters/xla.py:F401
jax/linear_util.py:F401 jax/linear_util.py:F401

View File

@ -55,7 +55,7 @@ from jax.interpreters import ad
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax.interpreters import xla from jax.interpreters import xla
from jax.interpreters import batching from jax.interpreters import batching
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jax._src import array from jax._src import array
from jax.experimental import pjit from jax.experimental import pjit

View File

@ -29,7 +29,6 @@ from jax import numpy as jnp
from jax import jvp, linearize, vjp, jit, make_jaxpr from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.api_util import flatten_fun_nokwargs from jax.api_util import flatten_fun_nokwargs
from jax.config import config from jax.config import config
from jax.interpreters import partial_eval as pe
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce, from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
tree_leaves) tree_leaves)
@ -38,6 +37,7 @@ from jax._src import linear_util as lu
from jax._src import util from jax._src import util
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.core import UnshapedArray, ShapedArray, DBIdx from jax._src.core import UnshapedArray, ShapedArray, DBIdx
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow from jax._src.lax import control_flow as lax_control_flow

View File

@ -26,8 +26,8 @@ from jax._src import linear_util as lu
from jax.config import config from jax.config import config
from jax.experimental import maps from jax.experimental import maps
from jax.experimental import pjit from jax.experimental import pjit
from jax.interpreters import ad from jax._src.interpreters import ad
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src import ad_checkpoint from jax._src import ad_checkpoint
from jax._src import dispatch from jax._src import dispatch

View File

@ -24,7 +24,7 @@ from jax._src import core
from jax import lax from jax import lax
from jax._src import linear_util as lu from jax._src import linear_util as lu
from jax.config import config from jax.config import config
from jax.interpreters import partial_eval as pe from jax._src.interpreters import partial_eval as pe
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.util import tuple_insert from jax._src.util import tuple_insert
import jax.numpy as jnp import jax.numpy as jnp