mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
PRNGKeyArrayImpl: add aval property
This makes it more readily compatible with jax.numpy routines.
This commit is contained in:
parent
035f585e43
commit
e50138608a
@ -22,7 +22,7 @@ import jax
|
||||
import weakref
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax import config
|
||||
from jax import config # type: ignore[no-redef]
|
||||
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
||||
tree_map, tree_flatten_with_path, keystr)
|
||||
|
@ -222,6 +222,10 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
_ = self._base_array.block_until_ready()
|
||||
return self
|
||||
|
||||
@property
|
||||
def aval(self):
|
||||
return keys_shaped_array(self.impl, self.shape)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return base_arr_shape_to_keys_shape(self.impl, self._base_array.shape)
|
||||
@ -239,9 +243,8 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
|
||||
@property
|
||||
def sharding(self):
|
||||
aval = keys_shaped_array(self.impl, self.shape)
|
||||
phys_sharding = self._base_array.sharding
|
||||
return KeyTyRules.logical_op_sharding(aval, phys_sharding)
|
||||
return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)
|
||||
|
||||
def _is_scalar(self):
|
||||
base_ndim = len(self.impl.key_shape)
|
||||
@ -544,11 +547,8 @@ class KeyTy:
|
||||
core.opaque_dtypes.add(KeyTy)
|
||||
|
||||
|
||||
core.pytype_aval_mappings[PRNGKeyArrayImpl] = (
|
||||
lambda x: keys_shaped_array(x.impl, x.shape))
|
||||
|
||||
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = (
|
||||
lambda x: keys_shaped_array(x.impl, x.shape))
|
||||
core.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
|
||||
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
|
||||
|
||||
xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user