mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters. PiperOrigin-RevId: 519813664
This commit is contained in:
parent
ae4f1fcb66
commit
6cc1bf54a1
@ -127,6 +127,7 @@ from jax._src.api import xla_computation as xla_computation
|
||||
|
||||
from jax.interpreters import ad # TODO(phawkins): update users to avoid this.
|
||||
from jax.interpreters import pxla # TODO(phawkins): update users to avoid this.
|
||||
from jax.interpreters import partial_eval # TODO(phawkins): update users to avoid this.
|
||||
from jax.interpreters import xla # TODO(phawkins): update users to avoid this.
|
||||
|
||||
from jax._src.array import (
|
||||
|
@ -21,12 +21,10 @@ import types
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import effects
|
||||
from jax._src import source_info_util
|
||||
@ -35,6 +33,8 @@ from jax._src import util
|
||||
from jax._src.api_util import flatten_fun, shaped_abstractify
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import convolution as lax_convolution
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -737,7 +737,7 @@ def _optimization_barrier(arg):
|
||||
optimization_barrier_p = core.Primitive('optimization_barrier')
|
||||
optimization_barrier_p.multiple_results = True
|
||||
optimization_barrier_p.def_impl(
|
||||
partial(xla.apply_primitive, optimization_barrier_p))
|
||||
partial(dispatch.apply_primitive, optimization_barrier_p))
|
||||
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
|
||||
mlir.register_lowering(optimization_barrier_p,
|
||||
_optimization_barrier_lowering_rule)
|
||||
|
@ -79,8 +79,8 @@ from jax.custom_batching import custom_vmap
|
||||
from jax.custom_derivatives import (custom_gradient, custom_jvp,
|
||||
custom_vjp, linear_call)
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
|
||||
from jax._src.config import (
|
||||
|
@ -26,8 +26,6 @@ import jax.tree_util as jtu
|
||||
from jax import lax
|
||||
from jax.api_util import flatten_fun
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten
|
||||
from jax.tree_util import tree_map
|
||||
from jax.tree_util import tree_unflatten
|
||||
@ -42,6 +40,8 @@ from jax._src import traceback_util
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
|
||||
unzip3, weakref_lru_cache)
|
||||
|
@ -18,9 +18,6 @@ from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import (tree_flatten, tree_map, tree_structure,
|
||||
tree_unflatten, treedef_tuple)
|
||||
from jax._src import core
|
||||
@ -33,6 +30,9 @@ from jax._src.api_util import flatten_fun_nokwargs
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters.batching import not_mapped
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
@ -21,8 +21,6 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class, tree_leaves)
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.config import config
|
||||
|
||||
from jax._src import core
|
||||
@ -38,6 +36,8 @@ from jax._src.core import raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters.batching import not_mapped
|
||||
from jax._src.lax import lax
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||
|
@ -15,9 +15,6 @@
|
||||
import functools
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
|
||||
tree_structure, treedef_tuple, tree_unflatten)
|
||||
from jax._src import ad_util
|
||||
@ -29,6 +26,9 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
@ -38,7 +38,7 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import GSPMDSharding, NamedSharding
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
# pytype: disable=import-error
|
||||
try:
|
||||
|
@ -32,8 +32,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.monitoring import record_event_duration_secs
|
||||
import jax.interpreters.mlir as mlir
|
||||
import jax.interpreters.partial_eval as pe
|
||||
|
||||
from jax._src import array
|
||||
from jax._src import core
|
||||
@ -49,6 +47,8 @@ from jax._src import xla_bridge as xb
|
||||
from jax._src.config import config, flags
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib.mlir import ir
|
||||
|
@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
|
||||
|
||||
import jax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.config import config
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten,
|
||||
register_pytree_node, Partial)
|
||||
|
@ -35,7 +35,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
|
||||
canonicalize_axis, moveaxis, as_hashable_function,
|
||||
curry, memoize, weakref_lru_cache)
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
Array = Any
|
||||
map, unsafe_map = safe_map, map
|
||||
|
@ -32,9 +32,6 @@ import numpy as np
|
||||
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax._src.interpreters import ad
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
@ -43,6 +40,9 @@ from jax._src import effects as effects_lib
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
2506
jax/_src/interpreters/partial_eval.py
Normal file
2506
jax/_src/interpreters/partial_eval.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -48,7 +48,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten, tree_map
|
||||
|
||||
from jax._src import api_util
|
||||
@ -71,6 +70,7 @@ from jax._src.config import flags
|
||||
from jax._src.core import ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
@ -78,6 +78,7 @@ from jax.interpreters import xla
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
@ -362,7 +363,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension,
|
||||
|
||||
approx_top_k_p = core.Primitive('approx_top_k')
|
||||
approx_top_k_p.multiple_results = True
|
||||
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
|
||||
approx_top_k_p.def_impl(partial(dispatch.apply_primitive, approx_top_k_p))
|
||||
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
|
||||
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
|
||||
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
|
||||
|
@ -25,7 +25,7 @@ from jax._src import ad_util
|
||||
from jax._src import util
|
||||
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_map, tree_unflatten
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
|
@ -22,14 +22,10 @@ import operator
|
||||
from typing import Callable, Sequence, List, Tuple
|
||||
|
||||
from jax.config import config
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
@ -37,6 +33,11 @@ from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import state
|
||||
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lax import lax
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (safe_map, split_list, partition_list)
|
||||
@ -803,7 +804,7 @@ def cond_bind(*args, branches, linear):
|
||||
|
||||
cond_p = core.AxisPrimitive('cond')
|
||||
cond_p.multiple_results = True
|
||||
cond_p.def_impl(partial(xla.apply_primitive, cond_p))
|
||||
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
|
||||
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
|
||||
cond_p.def_custom_bind(cond_bind)
|
||||
ad.primitive_jvps[cond_p] = _cond_jvp
|
||||
|
@ -20,16 +20,16 @@ from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple,
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.tree_util import (tree_flatten, tree_structure, tree_unflatten,
|
||||
treedef_tuple, tree_map, tree_leaves, PyTreeDef)
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
@ -299,7 +299,7 @@ def _for_impl_unrolled(body, nsteps, unroll, *args):
|
||||
return state
|
||||
|
||||
mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
|
||||
for_p.def_impl(functools.partial(xla.apply_primitive, for_p))
|
||||
for_p.def_impl(functools.partial(dispatch.apply_primitive, for_p))
|
||||
|
||||
def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
|
||||
jaxpr, nsteps, reverse, which_linear, unroll):
|
||||
|
@ -27,14 +27,15 @@ from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
||||
tree_map, tree_flatten_with_path, keystr)
|
||||
from jax._src.tree_util import equality_errors
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import source_info_util
|
||||
@ -1032,7 +1033,7 @@ def scan_bind(*args, **params):
|
||||
scan_p = core.AxisPrimitive("scan")
|
||||
scan_p.multiple_results = True
|
||||
scan_p.def_custom_bind(scan_bind)
|
||||
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
|
||||
scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
|
||||
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
|
||||
ad.primitive_jvps[scan_p] = _scan_jvp
|
||||
ad.reducing_transposes[scan_p] = _scan_transpose
|
||||
@ -1612,7 +1613,7 @@ def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
|
||||
while_p = core.AxisPrimitive('while')
|
||||
while_p.multiple_results = True
|
||||
while_p.def_impl(partial(xla.apply_primitive, while_p))
|
||||
while_p.def_impl(partial(dispatch.apply_primitive, while_p))
|
||||
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
|
||||
ad.primitive_jvps[while_p] = _while_loop_jvp
|
||||
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
|
||||
|
@ -27,8 +27,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_map
|
||||
|
||||
from jax._src import ad_util
|
||||
@ -51,7 +49,9 @@ from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters.batching import ConcatAxis
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lax.utils import (
|
||||
@ -2264,7 +2264,7 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
|
||||
return core._pp_eqn(eqn.replace(params=params), context, settings)
|
||||
|
||||
convert_element_type_p = Primitive('convert_element_type')
|
||||
convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p))
|
||||
convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p))
|
||||
convert_element_type_p.def_abstract_eval(
|
||||
partial(standard_abstract_eval, convert_element_type_p,
|
||||
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
|
||||
@ -3475,7 +3475,7 @@ def _reduce_named_shape_rule(*avals, computation, jaxpr, consts, dimensions):
|
||||
|
||||
reduce_p = core.Primitive('reduce')
|
||||
reduce_p.multiple_results = True
|
||||
reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
|
||||
reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
|
||||
reduce_p.def_abstract_eval(
|
||||
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
|
||||
_reduce_dtype_rule, _reduce_weak_type_rule,
|
||||
@ -3869,7 +3869,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys
|
||||
|
||||
sort_p = Primitive('sort')
|
||||
sort_p.multiple_results = True
|
||||
sort_p.def_impl(partial(xla.apply_primitive, sort_p))
|
||||
sort_p.def_impl(partial(dispatch.apply_primitive, sort_p))
|
||||
sort_p.def_abstract_eval(_sort_abstract_eval)
|
||||
ad.primitive_jvps[sort_p] = _sort_jvp
|
||||
batching.primitive_batchers[sort_p] = _sort_batch_rule
|
||||
@ -3960,7 +3960,7 @@ def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
|
||||
|
||||
top_k_p = Primitive('top_k')
|
||||
top_k_p.multiple_results = True
|
||||
top_k_p.def_impl(partial(xla.apply_primitive, top_k_p))
|
||||
top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
|
||||
top_k_p.def_abstract_eval(_top_k_abstract_eval)
|
||||
def _top_k_lower(ctx, operand, k):
|
||||
return chlo.TopKOp(operand, mlir.i64_attr(k)).results
|
||||
@ -3993,7 +3993,7 @@ def create_token(_=None):
|
||||
return create_token_p.bind()
|
||||
|
||||
create_token_p = Primitive("create_token")
|
||||
create_token_p.def_impl(partial(xla.apply_primitive, create_token_p))
|
||||
create_token_p.def_impl(partial(dispatch.apply_primitive, create_token_p))
|
||||
create_token_p.def_abstract_eval(lambda *_: abstract_token)
|
||||
|
||||
def _create_token_lowering(ctx, *operands):
|
||||
@ -4015,7 +4015,7 @@ def _after_all_abstract_eval(*operands):
|
||||
|
||||
|
||||
after_all_p = Primitive("after_all")
|
||||
after_all_p.def_impl(partial(xla.apply_primitive, after_all_p))
|
||||
after_all_p.def_impl(partial(dispatch.apply_primitive, after_all_p))
|
||||
after_all_p.def_abstract_eval(_after_all_abstract_eval)
|
||||
|
||||
def _after_all_lowering(ctx, *operands):
|
||||
@ -4060,7 +4060,7 @@ def _infeed_abstract_eval(token, *, shapes, partitions):
|
||||
|
||||
infeed_p = Primitive("infeed")
|
||||
infeed_p.multiple_results = True
|
||||
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
|
||||
infeed_p.def_impl(partial(dispatch.apply_primitive, infeed_p))
|
||||
infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
|
||||
mlir.lowerable_effects.add_type(InOutFeedEffect)
|
||||
|
||||
@ -4111,7 +4111,7 @@ def _outfeed_abstract_eval(token, *xs, partitions):
|
||||
return abstract_token, {outfeed_effect}
|
||||
|
||||
outfeed_p = Primitive("outfeed")
|
||||
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
|
||||
outfeed_p.def_impl(partial(dispatch.apply_primitive, outfeed_p))
|
||||
outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
|
||||
mlir.lowerable_effects.add_type(InOutFeedEffect)
|
||||
|
||||
@ -4153,7 +4153,7 @@ def _rng_uniform_abstract_eval(a, b, *, shape):
|
||||
weak_type=(a.weak_type and b.weak_type))
|
||||
|
||||
rng_uniform_p = Primitive("rng_uniform")
|
||||
rng_uniform_p.def_impl(partial(xla.apply_primitive, rng_uniform_p))
|
||||
rng_uniform_p.def_impl(partial(dispatch.apply_primitive, rng_uniform_p))
|
||||
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
|
||||
|
||||
def _rng_uniform_lowering(ctx, a, b, *, shape):
|
||||
@ -4247,7 +4247,7 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
|
||||
rng_bit_generator_p = Primitive("rng_bit_generator")
|
||||
rng_bit_generator_p.multiple_results = True
|
||||
rng_bit_generator_p.def_impl(
|
||||
partial(xla.apply_primitive, rng_bit_generator_p))
|
||||
partial(dispatch.apply_primitive, rng_bit_generator_p))
|
||||
rng_bit_generator_p.def_abstract_eval(
|
||||
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
|
||||
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
|
||||
@ -4298,7 +4298,7 @@ def _copy_impl(prim, *args, **kwargs):
|
||||
if isinstance(a, jax.Array) and isinstance(a.sharding, PmapSharding):
|
||||
sharded_dim = _which_dim_sharded(a.sharding)
|
||||
return _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs)
|
||||
return xla.apply_primitive(prim, *args, **kwargs)
|
||||
return dispatch.apply_primitive(prim, *args, **kwargs)
|
||||
|
||||
# The copy_p primitive exists for expressing making copies of runtime arrays.
|
||||
# For that reason we don't simplify it out of jaxprs (e.g. for jit invariance).
|
||||
@ -4354,7 +4354,7 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
|
||||
return core.DShapedArray(shape, dtype, False)
|
||||
|
||||
iota_p = Primitive('iota')
|
||||
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
|
||||
iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
|
||||
iota_p.def_abstract_eval(_iota_abstract_eval)
|
||||
|
||||
def _iota_staging_rule(trace, *dyn_shape, dtype, shape, dimension):
|
||||
|
@ -23,10 +23,10 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.core import (
|
||||
Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
|
||||
@ -448,10 +448,12 @@ mlir.register_lowering(
|
||||
# Asymmetric eigendecomposition
|
||||
|
||||
def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
|
||||
return (
|
||||
xla.apply_primitive(eig_p, operand,
|
||||
compute_left_eigenvectors=compute_left_eigenvectors,
|
||||
compute_right_eigenvectors=compute_right_eigenvectors))
|
||||
return dispatch.apply_primitive(
|
||||
eig_p,
|
||||
operand,
|
||||
compute_left_eigenvectors=compute_left_eigenvectors,
|
||||
compute_right_eigenvectors=compute_right_eigenvectors,
|
||||
)
|
||||
|
||||
def eig_lower(*args, **kw):
|
||||
raise NotImplementedError(
|
||||
@ -577,8 +579,8 @@ def eigh_jacobi(x: ArrayLike, *, lower: bool = True,
|
||||
return w, v
|
||||
|
||||
def _eigh_jacobi_impl(operand, *, lower, sort_eigenvalues):
|
||||
w, v = xla.apply_primitive(eigh_jacobi_p, operand, lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
w, v = dispatch.apply_primitive(eigh_jacobi_p, operand, lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
return w, v
|
||||
|
||||
def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||
@ -634,8 +636,8 @@ mlir.register_lowering(eigh_jacobi_p, _eigh_jacobi_lowering_rule)
|
||||
|
||||
|
||||
def _eigh_impl(operand, *, lower, sort_eigenvalues):
|
||||
v, w = xla.apply_primitive(eigh_p, operand, lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
v, w = dispatch.apply_primitive(eigh_p, operand, lower=lower,
|
||||
sort_eigenvalues=sort_eigenvalues)
|
||||
return v, w
|
||||
|
||||
def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||
@ -1016,7 +1018,7 @@ def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
|
||||
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')
|
||||
lu_pivots_to_permutation_p.multiple_results = False
|
||||
lu_pivots_to_permutation_p.def_impl(
|
||||
partial(xla.apply_primitive, lu_pivots_to_permutation_p))
|
||||
partial(dispatch.apply_primitive, lu_pivots_to_permutation_p))
|
||||
lu_pivots_to_permutation_p.def_abstract_eval(
|
||||
_lu_pivots_to_permutation_abstract_eval)
|
||||
batching.primitive_batchers[lu_pivots_to_permutation_p] = (
|
||||
@ -1111,7 +1113,7 @@ def _lu_python(x):
|
||||
return fn(x)
|
||||
|
||||
def _lu_impl(operand):
|
||||
lu, pivot, perm = xla.apply_primitive(lu_p, operand)
|
||||
lu, pivot, perm = dispatch.apply_primitive(lu_p, operand)
|
||||
return lu, pivot, perm
|
||||
|
||||
def _lu_abstract_eval(operand):
|
||||
@ -1385,7 +1387,7 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
|
||||
|
||||
geqrf_p = Primitive('geqrf')
|
||||
geqrf_p.multiple_results = True
|
||||
geqrf_p.def_impl(partial(xla.apply_primitive, geqrf_p))
|
||||
geqrf_p.def_impl(partial(dispatch.apply_primitive, geqrf_p))
|
||||
geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
|
||||
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
|
||||
mlir.register_lowering(geqrf_p, _geqrf_lowering_rule)
|
||||
@ -1474,7 +1476,7 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
|
||||
|
||||
|
||||
householder_product_p = Primitive('householder_product')
|
||||
householder_product_p.def_impl(partial(xla.apply_primitive, householder_product_p))
|
||||
householder_product_p.def_impl(partial(dispatch.apply_primitive, householder_product_p))
|
||||
householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
|
||||
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
|
||||
mlir.register_lowering(householder_product_p, _householder_product_lowering_rule)
|
||||
@ -1494,7 +1496,7 @@ mlir.register_lowering(
|
||||
|
||||
|
||||
def _qr_impl(operand, *, full_matrices):
|
||||
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
|
||||
q, r = dispatch.apply_primitive(qr_p, operand, full_matrices=full_matrices)
|
||||
return q, r
|
||||
|
||||
def _qr_abstract_eval(operand, *, full_matrices):
|
||||
@ -1572,7 +1574,7 @@ mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering));
|
||||
# Singular value decomposition
|
||||
|
||||
def _svd_impl(operand, *, full_matrices, compute_uv):
|
||||
return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
|
||||
return dispatch.apply_primitive(svd_p, operand, full_matrices=full_matrices,
|
||||
compute_uv=compute_uv)
|
||||
|
||||
def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
|
||||
@ -1762,7 +1764,7 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b, *, m, n, ldb, t
|
||||
tridiagonal_solve_p = Primitive('tridiagonal_solve')
|
||||
tridiagonal_solve_p.multiple_results = False
|
||||
tridiagonal_solve_p.def_impl(
|
||||
functools.partial(xla.apply_primitive, tridiagonal_solve_p))
|
||||
functools.partial(dispatch.apply_primitive, tridiagonal_solve_p))
|
||||
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
|
||||
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
|
||||
|
||||
@ -1873,7 +1875,7 @@ def schur(x: ArrayLike, *,
|
||||
|
||||
def _schur_impl(operand, *, compute_schur_vectors, sort_eig_vals,
|
||||
select_callable):
|
||||
return xla.apply_primitive(
|
||||
return dispatch.apply_primitive(
|
||||
schur_p,
|
||||
operand,
|
||||
compute_schur_vectors=compute_schur_vectors,
|
||||
@ -2000,7 +2002,7 @@ def _hessenberg_abstract_eval(a):
|
||||
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
|
||||
|
||||
hessenberg_p = Primitive("hessenberg")
|
||||
hessenberg_p.def_impl(partial(xla.apply_primitive, hessenberg_p))
|
||||
hessenberg_p.def_impl(partial(dispatch.apply_primitive, hessenberg_p))
|
||||
hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval)
|
||||
hessenberg_p.multiple_results = True
|
||||
|
||||
@ -2098,7 +2100,7 @@ def _tridiagonal_abstract_eval(a, *, lower):
|
||||
]
|
||||
|
||||
tridiagonal_p = Primitive("tridiagonal")
|
||||
tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p))
|
||||
tridiagonal_p.def_impl(partial(dispatch.apply_primitive, tridiagonal_p))
|
||||
tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval)
|
||||
tridiagonal_p.multiple_results = True
|
||||
|
||||
|
@ -20,7 +20,6 @@ import weakref
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
@ -30,6 +29,7 @@ from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax.utils import (
|
||||
_argnum_weak_type,
|
||||
|
@ -23,6 +23,7 @@ from typing import Callable
|
||||
|
||||
from jax.interpreters import xla
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.lib import xla_client
|
||||
@ -44,7 +45,7 @@ def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None,
|
||||
weak_type_rule = weak_type_rule or _standard_weak_type_rule
|
||||
named_shape_rule = named_shape_rule or standard_named_shape_rule
|
||||
prim = core.Primitive(name)
|
||||
prim.def_impl(partial(xla.apply_primitive, prim))
|
||||
prim.def_impl(partial(dispatch.apply_primitive, prim))
|
||||
prim.def_abstract_eval(
|
||||
partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
|
||||
weak_type_rule, named_shape_rule))
|
||||
|
@ -39,7 +39,7 @@ from jax.errors import JAXTypeError
|
||||
from jax._src.array import ArrayImpl
|
||||
from jax._src.sharding_impls import NamedSharding
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
@ -965,7 +965,7 @@ pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
|
||||
|
||||
# This is DynamicJaxprTrace.process_map with some very minor modifications
|
||||
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
from jax.interpreters.partial_eval import (
|
||||
from jax._src.interpreters.partial_eval import (
|
||||
trace_to_subjaxpr_dynamic, DynamicJaxprTracer,
|
||||
convert_constvars_jaxpr, new_jaxpr_eqn)
|
||||
assert primitive is xmap_p
|
||||
|
@ -28,9 +28,9 @@ import jax
|
||||
from jax._src import core
|
||||
from jax import stages
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters.pxla import PartitionSpec
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.tree_util import (
|
||||
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
|
||||
treedef_tuple, broadcast_prefix, all_leaves)
|
||||
|
@ -21,10 +21,10 @@ from typing import Any, Dict, List, Optional, Protocol, Sequence, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.state.primitives import get_p, swap_p, addupdate_p
|
||||
from jax._src.util import safe_map, safe_zip, split_list
|
||||
|
@ -19,13 +19,13 @@ from typing import Any, List, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.typing import Array
|
||||
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
|
||||
AccumEffect)
|
||||
|
@ -19,11 +19,11 @@ from jax import tree_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax.experimental import pjit
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir import ir
|
||||
import jax.interpreters.pxla as pxla
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import custom_api_util
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
@ -370,7 +370,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
propagate_user_sharding, partition,
|
||||
infer_sharding_from_operands,
|
||||
static_args):
|
||||
mesh = pxla.thread_resources.env.physical_mesh
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
axis_context = ctx.module_context.axis_context
|
||||
|
||||
if isinstance(axis_context, mlir.ShardingContext):
|
||||
|
@ -511,9 +511,9 @@ from jax import custom_derivatives
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad, batching, pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import ad, batching, pxla
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import dispatch
|
||||
|
@ -61,7 +61,6 @@ import numpy as np
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
treedef_is_leaf, tree_flatten, tree_unflatten,)
|
||||
|
||||
@ -70,6 +69,7 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.util import unzip2, weakref_lru_cache
|
||||
|
||||
|
@ -43,10 +43,10 @@ from jax._src.util import (HashableFunction, HashablePartial, unzip2,
|
||||
as_hashable_function, memoize, partition_list,
|
||||
merge_lists)
|
||||
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
|
@ -36,7 +36,6 @@ from jax.experimental.sparse.util import (
|
||||
SparseEfficiencyError, SparseEfficiencyWarning, Shape,
|
||||
SparseInfo)
|
||||
from jax.experimental.sparse._lowerings import coo_spmv_p, coo_spmm_p
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
import jax.numpy as jnp
|
||||
from jax.util import safe_zip, unzip2, split_list
|
||||
@ -45,6 +44,7 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax.lax import (
|
||||
_const, ranges_like, remaining, _dot_general_batch_dim_nums, DotDimensionNumbers)
|
||||
from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
|
||||
|
@ -62,9 +62,9 @@ from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_spar
|
||||
import jax.numpy as jnp
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
from jax._src.lib import pytree
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
|
||||
from jax.util import safe_map, safe_zip, split_list
|
||||
from jax._src.config import config
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -38,6 +38,7 @@ per-file-ignores =
|
||||
jax/interpreters/ad.py:F401
|
||||
jax/interpreters/batching.py:F401
|
||||
jax/interpreters/mlir.py:F401
|
||||
jax/interpreters/partial_eval.py:F401
|
||||
jax/interpreters/pxla.py:F401
|
||||
jax/interpreters/xla.py:F401
|
||||
jax/linear_util.py:F401
|
||||
|
@ -55,7 +55,7 @@ from jax.interpreters import ad
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import array
|
||||
from jax.experimental import pjit
|
||||
|
@ -29,7 +29,6 @@ from jax import numpy as jnp
|
||||
from jax import jvp, linearize, vjp, jit, make_jaxpr
|
||||
from jax.api_util import flatten_fun_nokwargs
|
||||
from jax.config import config
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
|
||||
tree_leaves)
|
||||
|
||||
@ -38,6 +37,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src import util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.core import UnshapedArray, ShapedArray, DBIdx
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
|
||||
|
@ -26,8 +26,8 @@ from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import dispatch
|
||||
|
@ -24,7 +24,7 @@ from jax._src import core
|
||||
from jax import lax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import tuple_insert
|
||||
import jax.numpy as jnp
|
||||
|
Loading…
x
Reference in New Issue
Block a user