mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Merge pull request #9955 from nicholasjng:add-itemsize
PiperOrigin-RevId: 435650867
This commit is contained in:
commit
e9f59aed84
@ -5159,6 +5159,10 @@ def _nbytes(arr):
|
||||
return size(arr) * _dtype(arr).itemsize
|
||||
|
||||
|
||||
def _itemsize(arr):
|
||||
return _dtype(arr).itemsize
|
||||
|
||||
|
||||
def _clip(number, min=None, max=None, out=None, *, a_min=None, a_max=None):
|
||||
# ndarray.clip has a slightly different API from clip (min -> a_min, max -> a_max)
|
||||
# TODO: remove after deprecation window
|
||||
@ -5655,6 +5659,7 @@ def _set_shaped_array_attributes(shaped_array):
|
||||
setattr(shaped_array, "astype", core.aval_method(_astype))
|
||||
setattr(shaped_array, "view", core.aval_method(_view))
|
||||
setattr(shaped_array, "nbytes", core.aval_property(_nbytes))
|
||||
setattr(shaped_array, "itemsize", core.aval_property(_itemsize))
|
||||
setattr(shaped_array, "clip", core.aval_method(_clip))
|
||||
|
||||
setattr(shaped_array, "_array_module", staticmethod(__array_module__))
|
||||
@ -5685,6 +5690,7 @@ def _set_device_array_base_attributes(device_array):
|
||||
setattr(device_array, "astype", _astype)
|
||||
setattr(device_array, "view", _view)
|
||||
setattr(device_array, "nbytes", property(_nbytes))
|
||||
setattr(device_array, "itemsize", property(_itemsize))
|
||||
setattr(device_array, "clip", _clip)
|
||||
|
||||
_set_device_array_base_attributes(device_array.DeviceArray)
|
||||
|
@ -4203,6 +4203,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype)),
|
||||
"shape": shape, "dtype": dtype}
|
||||
for shape in array_shapes
|
||||
for dtype in all_dtypes))
|
||||
def testItemsize(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
np_op = lambda x: np.asarray(x).itemsize
|
||||
jnp_op = lambda x: jnp.asarray(x).itemsize
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_{}_dtype={}".format(
|
||||
jtu.format_shape_dtype_string(shape, a_dtype), dtype),
|
||||
|
Loading…
x
Reference in New Issue
Block a user