From d6c67c97db118ea7d65bd3109a04357b47cfcba5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 Nov 2022 13:46:51 -0800 Subject: [PATCH] Remove redundant dtype canonicalization from jax.device_put(). MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Gives a small improvement to the included jax.device_put() benchmark on my VM: ``` name old cpu/op new cpu/op delta device_put 91.3µs ± 5% 80.1µs ± 3% -12.29% (p=0.008 n=5+5) name old time/op new time/op delta device_put 91.4µs ± 5% 80.1µs ± 3% -12.29% (p=0.008 n=5+5) ``` jax.device_put() has not been optimized that much yet and there is plenty of room for further improvement. PiperOrigin-RevId: 491727173 --- benchmarks/api_benchmark.py | 5 +++++ jax/_src/abstract_arrays.py | 15 +++++++++------ jax/interpreters/xla.py | 12 +++++++++--- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index 18970eccc..768cb4b1f 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -830,6 +830,11 @@ def host_local_array_to_global_array(state): multihost_utils.host_local_array_to_global_array( (input_data, input_data), global_mesh, (in_pspec, in_pspec)) +@google_benchmark.register +def device_put(state): + x = np.array(1, np.int32) + while state: + _ = jax.device_put(x).block_until_ready() if __name__ == "__main__": google_benchmark.main() diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 15d7049f1..e38d58e36 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -42,12 +42,15 @@ def zeros_like_array(x): aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type) return ad_util.zeros_like_aval(aval) -array_types = {np.ndarray, np.bool_, - np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64, - dtypes.bfloat16, np.float16, np.float32, np.float64, - np.complex64, np.complex128, - np.longlong, np.intc} +numpy_scalar_types = { + np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + dtypes.bfloat16, np.float16, np.float32, np.float64, + np.complex64, np.complex128, + np.bool_, np.longlong, np.intc, +} + +array_types = {np.ndarray} | numpy_scalar_types def canonical_concrete_aval(val, weak_type=None): return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val, diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index aefc882a0..b4c25074d 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -32,7 +32,8 @@ from jax import core from jax._src import device_array from jax._src import dtypes from jax._src import source_info_util -from jax._src.abstract_arrays import (make_shaped_array, array_types) +from jax._src.abstract_arrays import (make_shaped_array, array_types, + numpy_scalar_types) from jax.core import (ConcreteArray, ShapedArray, str_eqn_compact) import jax._src.pretty_printer as pp from jax._src.util import (prod, new_name_stack, safe_zip, safe_map, @@ -247,8 +248,12 @@ def canonicalize_dtype(x): return canonicalize_dtype(x.__jax_array__()) raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}") + def _canonicalize_ndarray_dtype(x): - return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x))) + return np.asarray(x, dtypes.canonicalize_dtype(x.dtype)) + +def _canonicalize_numpy_scalar_dtype(x): + return np.asarray(x, dtypes.canonicalize_dtype(np.dtype(x))) def _canonicalize_python_scalar_dtype(typ, x): return np.asarray( @@ -258,7 +263,8 @@ canonicalize_dtype_handlers: Dict[Any, Callable] = {} for t in device_array.device_array_types: canonicalize_dtype_handlers[t] = identity canonicalize_dtype_handlers.update( - (t, _canonicalize_ndarray_dtype) for t in array_types) + (t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types) +canonicalize_dtype_handlers[np.ndarray] = _canonicalize_ndarray_dtype canonicalize_dtype_handlers.update( (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types) canonicalize_dtype_handlers[core.Token] = identity