Remove redundant dtype canonicalization from jax.device_put().

Gives a small improvement to the included jax.device_put() benchmark on my VM:

```
name        old cpu/op  new cpu/op  delta
device_put  91.3µs ± 5%  80.1µs ± 3%  -12.29%  (p=0.008 n=5+5)

name        old time/op             new time/op             delta
device_put  91.4µs ± 5%             80.1µs ± 3%  -12.29%          (p=0.008 n=5+5)
```

jax.device_put() has not been optimized that much yet and there is plenty of room for further improvement.

PiperOrigin-RevId: 491727173
This commit is contained in:
Peter Hawkins 2022-11-29 13:46:51 -08:00 committed by jax authors
parent 7495a9e370
commit d6c67c97db
3 changed files with 23 additions and 9 deletions

View File

@ -830,6 +830,11 @@ def host_local_array_to_global_array(state):
multihost_utils.host_local_array_to_global_array(
(input_data, input_data), global_mesh, (in_pspec, in_pspec))
@google_benchmark.register
def device_put(state):
x = np.array(1, np.int32)
while state:
_ = jax.device_put(x).block_until_ready()
if __name__ == "__main__":
google_benchmark.main()

View File

@ -42,12 +42,15 @@ def zeros_like_array(x):
aval = ShapedArray(np.shape(x), dtype, weak_type=weak_type)
return ad_util.zeros_like_aval(aval)
array_types = {np.ndarray, np.bool_,
np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
dtypes.bfloat16, np.float16, np.float32, np.float64,
np.complex64, np.complex128,
np.longlong, np.intc}
numpy_scalar_types = {
np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
dtypes.bfloat16, np.float16, np.float32, np.float64,
np.complex64, np.complex128,
np.bool_, np.longlong, np.intc,
}
array_types = {np.ndarray} | numpy_scalar_types
def canonical_concrete_aval(val, weak_type=None):
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,

View File

@ -32,7 +32,8 @@ from jax import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import source_info_util
from jax._src.abstract_arrays import (make_shaped_array, array_types)
from jax._src.abstract_arrays import (make_shaped_array, array_types,
numpy_scalar_types)
from jax.core import (ConcreteArray, ShapedArray, str_eqn_compact)
import jax._src.pretty_printer as pp
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
@ -247,8 +248,12 @@ def canonicalize_dtype(x):
return canonicalize_dtype(x.__jax_array__())
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
def _canonicalize_ndarray_dtype(x):
return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
def _canonicalize_numpy_scalar_dtype(x):
return np.asarray(x, dtypes.canonicalize_dtype(np.dtype(x)))
def _canonicalize_python_scalar_dtype(typ, x):
return np.asarray(
@ -258,7 +263,8 @@ canonicalize_dtype_handlers: Dict[Any, Callable] = {}
for t in device_array.device_array_types:
canonicalize_dtype_handlers[t] = identity
canonicalize_dtype_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in array_types)
(t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types)
canonicalize_dtype_handlers[np.ndarray] = _canonicalize_ndarray_dtype
canonicalize_dtype_handlers.update(
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = identity