diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py
index 86ebad744..791f33c1b 100644
--- a/benchmarks/api_benchmark.py
+++ b/benchmarks/api_benchmark.py
@@ -20,6 +20,7 @@ import google_benchmark
 import jax
 from jax import lax
 from jax.experimental import sparse
+from jax._src.api_util import shaped_abstractify  # technically not an api fn
 import jax.numpy as jnp
 import numpy as np
 
@@ -461,6 +462,15 @@ def sparse_bcoo_matvec_compile(state):
   return _sparse_bcoo_matvec(state, compile=True)
 
 
+@google_benchmark.register
+@google_benchmark.option.unit(google_benchmark.kMillisecond)
+def bench_shaped_abstractify(state):
+  device, *_ = jax.devices()
+  args = [jax.device_put_replicated(1, [device])] * 1000
+  while state:
+    _ = [shaped_abstractify(x) for x in args]
+
+
 def swap(a, b):
   return b, a
 
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 6f22cc7c1..c27e0d8c6 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -15,7 +15,8 @@
 import inspect
 import operator
 from functools import partial
-from typing import Any, Dict, Iterable, Sequence, Set, Tuple, Union, Optional
+from typing import (Any, Dict, Iterable, Sequence, Set, Tuple, Union, Optional,
+                    Callable)
 import warnings
 
 import numpy as np
@@ -425,7 +426,7 @@ def _dtype(x):
   except ValueError:
     return dtypes.result_type(getattr(x, 'dtype'))
 
-def shaped_abstractify(x):
+def _shaped_abstractify_slow(x):
   try:
     return core.raise_to_shaped(
       x if isinstance(x, core.AbstractValue) else core.get_aval(x))
@@ -437,6 +438,14 @@
     return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
                             named_shape=named_shape)
 
+# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
+def shaped_abstractify(x):
+  try:
+    return _shaped_abstractify_handlers[type(x)](x)
+  except KeyError:
+    return _shaped_abstractify_slow(x)
+_shaped_abstractify_handlers: Dict[Any, Callable[[Any], core.ShapedArray]] = {}
+
 # This decorator exists to make it easier to monkey-patch APIs in JAX.
 # By default it does nothing, but it can be monkey-patched to do other things.
 def api_hook(fun, tag: str):
diff --git a/jax/experimental/array.py b/jax/experimental/array.py
index 01a6718fd..a8f82b729 100644
--- a/jax/experimental/array.py
+++ b/jax/experimental/array.py
@@ -18,6 +18,7 @@ import numpy as np
 from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
 
 from jax import core
+from jax._src import api_util
 from jax._src import dispatch
 from jax._src.config import config
 from jax._src.util import prod
@@ -242,6 +243,8 @@ def make_array_from_callback(shape: Shape, sharding: Sharding,
 core.pytype_aval_mappings[Array] = lambda x: core.ShapedArray(x.shape, x.dtype)
 xla.pytype_aval_mappings[Array] = lambda x: core.ShapedArray(x.shape, x.dtype)
 xla.canonicalize_dtype_handlers[Array] = pxla.identity
+api_util._shaped_abstractify_handlers[Array] = \
+    lambda x: core.ShapedArray(x.shape, x.dtype)
 
 
 def _device_put_array(x, device: Optional[Device]):
diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py
index 136e66cb2..ed5898d25 100644
--- a/jax/experimental/global_device_array.py
+++ b/jax/experimental/global_device_array.py
@@ -19,6 +19,7 @@ import numpy as np
 from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
 
 from jax import core
+from jax._src import api_util
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.config import config
@@ -557,6 +558,8 @@ core.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
     x.shape, x.dtype)
 xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
     x.shape, x.dtype)
 xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
+api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \
+    lambda x: core.ShapedArray(x.shape, x.dtype)
 
 def _gda_shard_arg(x, devices, indices):
   return x._device_buffers
diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py
index 688fb7efb..2d2208cae 100644
--- a/jax/interpreters/partial_eval.py
+++ b/jax/interpreters/partial_eval.py
@@ -27,8 +27,9 @@ from weakref import ref
 import numpy as np
 
 from jax import core
-from jax._src import dtypes
 from jax import linear_util as lu
+from jax._src import api_util
+from jax._src import dtypes
 from jax._src import profiler
 from jax._src.ad_util import Zero
 from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
@@ -1578,6 +1579,7 @@ class DynamicJaxprTracer(core.Tracer):
     frame = self._trace.frame
     val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
     return self if val is None else get_referent(val)
+api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
 
 
 class JaxprStackFrame:
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 6ca40172c..f650d2656 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -59,6 +59,7 @@ from jax.interpreters import xla
 from jax.tree_util import tree_flatten, tree_map
 
 from jax._src import abstract_arrays
+from jax._src import api_util
 from jax._src import device_array
 from jax._src import source_info_util
 from jax._src import util
@@ -808,6 +809,7 @@ def _register_handlers_for_sharded_device_array(sda):
   dispatch.device_put_handlers[sda] = dispatch._device_put_array
   xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
   xla.canonicalize_dtype_handlers[sda] = identity
+  api_util._shaped_abstractify_handlers[sda] = op.attrgetter("aval")
 
 _register_handlers_for_sharded_device_array(_ShardedDeviceArray)
 _register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray)
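
Note on the pattern above: the patch splits shaped_abstractify into a fast path (a plain dict keyed on type(x)) and the original reflection-based logic, renamed _shaped_abstractify_slow, which remains the fallback for unregistered types. Each array-like class (Array, GlobalDeviceArray, DynamicJaxprTracer, and both ShardedDeviceArray variants) then registers its own handler in api_util._shaped_abstractify_handlers. The sketch below reproduces that registry-with-fallback dispatch in isolation so it can be run directly; the ShapedArray class and the np.ndarray registration are illustrative stand-ins, not JAX's actual definitions.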
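# Standalone sketch (not JAX source): the registry-with-fallback dispatch
# introduced by this patch, reduced to plain NumPy.
from typing import Any, Callable, Dict

import numpy as np

class ShapedArray:
  """Hypothetical stand-in for jax.core.ShapedArray: a (shape, dtype) pair."""
  def __init__(self, shape, dtype):
    self.shape, self.dtype = shape, dtype
  def __repr__(self):
    return f"ShapedArray({self.shape}, {self.dtype})"

def _shaped_abstractify_slow(x):
  # Generic fallback: handles any array-like, at the cost of reflection.
  return ShapedArray(np.shape(x), np.result_type(x))

# Fast path: exact-type handlers, registered by the module that owns each type.
_shaped_abstractify_handlers: Dict[Any, Callable[[Any], ShapedArray]] = {}

def shaped_abstractify(x):
  try:
    return _shaped_abstractify_handlers[type(x)](x)   # one dict probe
  except KeyError:
    return _shaped_abstractify_slow(x)                # unregistered type

# Registration, mirroring what array.py / global_device_array.py / pxla.py do:
_shaped_abstractify_handlers[np.ndarray] = \
    lambda x: ShapedArray(x.shape, x.dtype)

print(shaped_abstractify(np.ones((2, 3))))  # fast path: dict hit on np.ndarray
print(shaped_abstractify([1.0, 2.0]))       # fallback: list has no handler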
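Two details of the design worth noting. First, the lookup is on type(x) exactly, so subclasses do not inherit a handler; that is why _register_handlers_for_sharded_device_array is invoked once per concrete class. Second, the EAFP try/except KeyError keeps the registered common case to a single dict probe, with the exception path only taken for types that fall through to _shaped_abstractify_slow. The new bench_shaped_abstractify benchmark exercises exactly that hot path, calling shaped_abstractify over a list of 1000 device arrays per iteration.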