Merge pull request #9955 from nicholasjng:add-itemsize

PiperOrigin-RevId: 435650867
This commit is contained in:
jax authors 2022-03-18 08:55:19 -07:00
commit e9f59aed84
2 changed files with 20 additions and 0 deletions

View File

@ -5159,6 +5159,10 @@ def _nbytes(arr):
return size(arr) * _dtype(arr).itemsize
def _itemsize(arr):
return _dtype(arr).itemsize
def _clip(number, min=None, max=None, out=None, *, a_min=None, a_max=None):
# ndarray.clip has a slightly different API from clip (min -> a_min, max -> a_max)
# TODO: remove after deprecation window
@ -5655,6 +5659,7 @@ def _set_shaped_array_attributes(shaped_array):
setattr(shaped_array, "astype", core.aval_method(_astype))
setattr(shaped_array, "view", core.aval_method(_view))
setattr(shaped_array, "nbytes", core.aval_property(_nbytes))
setattr(shaped_array, "itemsize", core.aval_property(_itemsize))
setattr(shaped_array, "clip", core.aval_method(_clip))
setattr(shaped_array, "_array_module", staticmethod(__array_module__))
@ -5685,6 +5690,7 @@ def _set_device_array_base_attributes(device_array):
setattr(device_array, "astype", _astype)
setattr(device_array, "view", _view)
setattr(device_array, "nbytes", property(_nbytes))
setattr(device_array, "itemsize", property(_itemsize))
setattr(device_array, "clip", _clip)
_set_device_array_base_attributes(device_array.DeviceArray)

View File

@ -4203,6 +4203,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype}
for shape in array_shapes
for dtype in all_dtypes))
def testItemsize(self, shape, dtype):
rng = jtu.rand_default(self.rng())
np_op = lambda x: np.asarray(x).itemsize
jnp_op = lambda x: jnp.asarray(x).itemsize
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_{}_dtype={}".format(
jtu.format_shape_dtype_string(shape, a_dtype), dtype),