mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove redundant dtype canonicalization from jax.device_put().
Gives a small improvement to the included jax.device_put() benchmark on my VM: ``` name old cpu/op new cpu/op delta device_put 91.3µs ± 5% 80.1µs ± 3% -12.29% (p=0.008 n=5+5) name old time/op new time/op delta device_put 91.4µs ± 5% 80.1µs ± 3% -12.29% (p=0.008 n=5+5) ``` jax.device_put() has not been optimized that much yet and there is plenty of room for further improvement. PiperOrigin-RevId: 491727173
This commit is contained in:
parent
7495a9e370
commit
d6c67c97db
@ -830,6 +830,11 @@ def host_local_array_to_global_array(state):
|
||||
multihost_utils.host_local_array_to_global_array(
|
||||
(input_data, input_data), global_mesh, (in_pspec, in_pspec))
|
||||
|
||||
@google_benchmark.register
|
||||
def device_put(state):
|
||||
x = np.array(1, np.int32)
|
||||
while state:
|
||||
_ = jax.device_put(x).block_until_ready()
|
||||
|
||||
if __name__ == "__main__":
|
||||
google_benchmark.main()
|
||||
|
@ -42,12 +42,15 @@ def zeros_like_array(x):
|
||||
aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
|
||||
return ad_util.zeros_like_aval(aval)
|
||||
|
||||
array_types = {np.ndarray, np.bool_,
|
||||
np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
dtypes.bfloat16, np.float16, np.float32, np.float64,
|
||||
np.complex64, np.complex128,
|
||||
np.longlong, np.intc}
|
||||
numpy_scalar_types = {
|
||||
np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
dtypes.bfloat16, np.float16, np.float32, np.float64,
|
||||
np.complex64, np.complex128,
|
||||
np.bool_, np.longlong, np.intc,
|
||||
}
|
||||
|
||||
array_types = {np.ndarray} | numpy_scalar_types
|
||||
|
||||
def canonical_concrete_aval(val, weak_type=None):
|
||||
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
|
||||
|
@ -32,7 +32,8 @@ from jax import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import (make_shaped_array, array_types)
|
||||
from jax._src.abstract_arrays import (make_shaped_array, array_types,
|
||||
numpy_scalar_types)
|
||||
from jax.core import (ConcreteArray, ShapedArray, str_eqn_compact)
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
|
||||
@ -247,8 +248,12 @@ def canonicalize_dtype(x):
|
||||
return canonicalize_dtype(x.__jax_array__())
|
||||
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
|
||||
|
||||
|
||||
def _canonicalize_ndarray_dtype(x):
|
||||
return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
|
||||
return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
|
||||
|
||||
def _canonicalize_numpy_scalar_dtype(x):
|
||||
return np.asarray(x, dtypes.canonicalize_dtype(np.dtype(x)))
|
||||
|
||||
def _canonicalize_python_scalar_dtype(typ, x):
|
||||
return np.asarray(
|
||||
@ -258,7 +263,8 @@ canonicalize_dtype_handlers: Dict[Any, Callable] = {}
|
||||
for t in device_array.device_array_types:
|
||||
canonicalize_dtype_handlers[t] = identity
|
||||
canonicalize_dtype_handlers.update(
|
||||
(t, _canonicalize_ndarray_dtype) for t in array_types)
|
||||
(t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types)
|
||||
canonicalize_dtype_handlers[np.ndarray] = _canonicalize_ndarray_dtype
|
||||
canonicalize_dtype_handlers.update(
|
||||
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
|
||||
canonicalize_dtype_handlers[core.Token] = identity
|
||||
|
Loading…
x
Reference in New Issue
Block a user