mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Merge pull request #11634 from mattjj:fastpath-for-shaped-abstractify
PiperOrigin-RevId: 463718000
This commit is contained in:
commit
27655af6b9
@ -20,6 +20,7 @@ import google_benchmark
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import sparse
|
||||
from jax._src.api_util import shaped_abstractify # technically not an api fn
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -461,6 +462,15 @@ def sparse_bcoo_matvec_compile(state):
|
||||
return _sparse_bcoo_matvec(state, compile=True)
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.unit(google_benchmark.kMillisecond)
|
||||
def bench_shaped_abstractify(state):
|
||||
device, *_ = jax.devices()
|
||||
args = [jax.device_put_replicated(1, [device])] * 1000
|
||||
while state:
|
||||
_ = [shaped_abstractify(x) for x in args]
|
||||
|
||||
|
||||
def swap(a, b):
|
||||
return b, a
|
||||
|
||||
|
@ -15,7 +15,8 @@
|
||||
import inspect
|
||||
import operator
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, Sequence, Set, Tuple, Union, Optional
|
||||
from typing import (Any, Dict, Iterable, Sequence, Set, Tuple, Union, Optional,
|
||||
Callable)
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -425,7 +426,7 @@ def _dtype(x):
|
||||
except ValueError:
|
||||
return dtypes.result_type(getattr(x, 'dtype'))
|
||||
|
||||
def shaped_abstractify(x):
|
||||
def _shaped_abstractify_slow(x):
|
||||
try:
|
||||
return core.raise_to_shaped(
|
||||
x if isinstance(x, core.AbstractValue) else core.get_aval(x))
|
||||
@ -437,6 +438,14 @@ def shaped_abstractify(x):
|
||||
return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
|
||||
named_shape=named_shape)
|
||||
|
||||
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
|
||||
def shaped_abstractify(x):
|
||||
try:
|
||||
return _shaped_abstractify_handlers[type(x)](x)
|
||||
except KeyError:
|
||||
return _shaped_abstractify_slow(x)
|
||||
_shaped_abstractify_handlers: Dict[Any, Callable[[Any], core.ShapedArray]] = {}
|
||||
|
||||
# This decorator exists to make it easier to monkey-patch APIs in JAX.
|
||||
# By default it does nothing, but it can be monkey-patched to do other things.
|
||||
def api_hook(fun, tag: str):
|
||||
|
@ -18,6 +18,7 @@ import numpy as np
|
||||
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
|
||||
|
||||
from jax import core
|
||||
from jax._src import api_util
|
||||
from jax._src import dispatch
|
||||
from jax._src.config import config
|
||||
from jax._src.util import prod
|
||||
@ -242,6 +243,8 @@ def make_array_from_callback(shape: Shape, sharding: Sharding,
|
||||
core.pytype_aval_mappings[Array] = lambda x: core.ShapedArray(x.shape, x.dtype)
|
||||
xla.pytype_aval_mappings[Array] = lambda x: core.ShapedArray(x.shape, x.dtype)
|
||||
xla.canonicalize_dtype_handlers[Array] = pxla.identity
|
||||
api_util._shaped_abstractify_handlers[Array] = \
|
||||
lambda x: core.ShapedArray(x.shape, x.dtype)
|
||||
|
||||
|
||||
def _device_put_array(x, device: Optional[Device]):
|
||||
|
@ -19,6 +19,7 @@ import numpy as np
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
|
||||
|
||||
from jax import core
|
||||
from jax._src import api_util
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.config import config
|
||||
@ -557,6 +558,8 @@ core.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
|
||||
xla.pytype_aval_mappings[GlobalDeviceArray] = lambda x: core.ShapedArray(
|
||||
x.shape, x.dtype)
|
||||
xla.canonicalize_dtype_handlers[GlobalDeviceArray] = pxla.identity
|
||||
api_util._shaped_abstractify_handlers[GlobalDeviceArray] = \
|
||||
lambda x: core.ShapedArray(x.shape, x.dtype)
|
||||
|
||||
def _gda_shard_arg(x, devices, indices):
|
||||
return x._device_buffers
|
||||
|
@ -27,8 +27,9 @@ from weakref import ref
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax import linear_util as lu
|
||||
from jax._src import api_util
|
||||
from jax._src import dtypes
|
||||
from jax._src import profiler
|
||||
from jax._src.ad_util import Zero
|
||||
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
|
||||
@ -1578,6 +1579,7 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
frame = self._trace.frame
|
||||
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
|
||||
return self if val is None else get_referent(val)
|
||||
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
|
||||
|
||||
|
||||
class JaxprStackFrame:
|
||||
|
@ -59,6 +59,7 @@ from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_map
|
||||
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import api_util
|
||||
from jax._src import device_array
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
@ -808,6 +809,7 @@ def _register_handlers_for_sharded_device_array(sda):
|
||||
dispatch.device_put_handlers[sda] = dispatch._device_put_array
|
||||
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
|
||||
xla.canonicalize_dtype_handlers[sda] = identity
|
||||
api_util._shaped_abstractify_handlers[sda] = op.attrgetter("aval")
|
||||
|
||||
_register_handlers_for_sharded_device_array(_ShardedDeviceArray)
|
||||
_register_handlers_for_sharded_device_array(pmap_lib.ShardedDeviceArray)
|
||||
|
Loading…
x
Reference in New Issue
Block a user