mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finish supporing bints and DArrays where the pile tests touch them.
This commit is contained in:
parent
e04db23651
commit
8774352f20
@ -143,7 +143,11 @@ if dtypes.int4 is not None:
|
||||
})
|
||||
|
||||
|
||||
def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
|
||||
def dtype_to_ir_type(dtype: Union[core.bint, np.dtype, np.generic]) -> ir.Type:
|
||||
if isinstance(dtype, core.bint):
|
||||
# TODO Support different-size underlying dtypes to take advantage of the
|
||||
# bound for packing?
|
||||
dtype = np.dtype(np.int32)
|
||||
assert isinstance(dtype, (np.dtype, np.generic)), type(dtype)
|
||||
dtype = np.dtype(dtype)
|
||||
try:
|
||||
|
@ -162,6 +162,10 @@ def _shard_array(x, devices, indices, sharding):
|
||||
for _t in array_types:
|
||||
shard_arg_handlers[_t] = _shard_array
|
||||
|
||||
def _shard_darray(x, devices, indices, sharding):
|
||||
return shard_arg(x._data, devices, indices, sharding)
|
||||
shard_arg_handlers[core.DArray] = _shard_darray
|
||||
|
||||
def shard_device_array(x, devices, indices, sharding):
|
||||
start_indices, limit_indices, removed_dims = unzip3(
|
||||
as_slice_indices(x, idx) for idx in indices)
|
||||
|
Loading…
x
Reference in New Issue
Block a user