Merge pull request #11634 from mattjj:fastpath-for-shaped-abstractify

PiperOrigin-RevId: 463718000
This commit is contained in:
jax authors 2022-07-27 17:33:58 -07:00
commit 27655af6b9
6 changed files with 32 additions and 3 deletions

View File

@ -20,6 +20,7 @@ import google_benchmark
import jax
from jax import lax
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify # technically not an api fn
import jax.numpy as jnp
import numpy as np
@ -461,6 +462,15 @@ def sparse_bcoo_matvec_compile(state):
return _sparse_bcoo_matvec(state, compile=True)
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def bench_shaped_abstractify(state):
device, *_ = jax.devices()
args = [jax.device_put_replicated(1, [device])] * 1000
while state:
_ = [shaped_abstractify(x) for x in args]
def swap(a, b):
return b, a

View File

@ -15,7 +15,8 @@
import inspect
import operator
from functools import partial
from typing import Any, Dict, Iterable, Sequence, Set, Tuple, Union, Optional
from typing import (Any, Dict, Iterable, Sequence, Set, Tuple, Union, Optional,
Callable)
import warnings
import numpy as np
@ -425,7 +426,7 @@ def _dtype(x):
except ValueError:
return dtypes.result_type(getattr(x, 'dtype'))
def shaped_abstractify(x):
def _shaped_abstractify_slow(x):
try:
return core.raise_to_shaped(
x if isinstance(x, core.AbstractValue) else core.get_aval(x))
@ -437,6 +438,14 @@ def shaped_abstractify(x):
return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
named_shape=named_shape)
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
def shaped_abstractify(x):
try:
return _shaped_abstractify_handlers[type(x)](x)
except KeyError:
return _shaped_abstractify_slow(x)
_shaped_abstractify_handlers: Dict[Any, Callable[[Any], core.ShapedArray]] = {}
# This decorator exists to make it easier to monkey-patch APIs in JAX.
# By default it does nothing, but it can be monkey-patched to do other things.
def api_hook(fun, tag: str):

View File

@ -18,6 +18,7 @@ import numpy as np
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
from jax import core
from jax._src import api_util
from jax._src import dispatch
from jax._src.config import config
from jax._src.util import prod
@ -242,6 +243,8 @@ def make_array_from_callback(shape: Shape, sharding: Sharding,
core.pytype_aval_mappings[Array] = lambda x: core.ShapedArray(x.shape, x.dtype)
xla.pytype_aval_mappings[Array] = lambda x: core.ShapedArray(x.shape, x.dtype)
xla.canonicalize_dtype_handlers[Array] = pxla.identity
api_util._shaped_abstractify_handlers[Array] = \
lambda x: core.ShapedArray(x.shape, x.dtype)
def _device_put_array(x, device: Optional[Device]):

View File

@ -19,6 +19,7 @@ import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
from jax import core
from jax._src import api_util
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.config import config
@ -557,6 +558,8 @@ core.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
x.shape, x.dtype)
xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \
lambda x: core.ShapedArray(x.shape, x.dtype)
def _gda_shard_arg(x, devices, indices):
return x._device_buffers

View File

@ -27,8 +27,9 @@ from weakref import ref
import numpy as np
from jax import core
from jax._src import dtypes
from jax import linear_util as lu
from jax._src import api_util
from jax._src import dtypes
from jax._src import profiler
from jax._src.ad_util import Zero
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
@ -1578,6 +1579,7 @@ class DynamicJaxprTracer(core.Tracer):
frame = self._trace.frame
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
return self if val is None else get_referent(val)
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
class JaxprStackFrame:

View File

@ -59,6 +59,7 @@ from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_map
from jax._src import abstract_arrays
from jax._src import api_util
from jax._src import device_array
from jax._src import source_info_util
from jax._src import util
@ -808,6 +809,7 @@ def _register_handlers_for_sharded_device_array(sda):
dispatch.device_put_handlers[sda] = dispatch._device_put_array
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity
api_util._shaped_abstractify_handlers[sda] = op.attrgetter("aval")
_register_handlers_for_sharded_device_array(_ShardedDeviceArray)
_register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray)