PRNGKeyArrayImpl: add aval property

This makes it more readily compatible with jax.numpy routines.
This commit is contained in:
Jake VanderPlas 2023-04-21 16:44:32 -07:00
parent 035f585e43
commit e50138608a
2 changed files with 8 additions and 8 deletions

View File

@ -22,7 +22,7 @@ import jax
import weakref
from jax._src import core
from jax._src import linear_util as lu
from jax import config
from jax import config # type: ignore[no-redef]
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
tree_map, tree_flatten_with_path, keystr)

View File

@ -222,6 +222,10 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
_ = self._base_array.block_until_ready()
return self
@property
def aval(self):
return keys_shaped_array(self.impl, self.shape)
@property
def shape(self):
return base_arr_shape_to_keys_shape(self.impl, self._base_array.shape)
@ -239,9 +243,8 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
@property
def sharding(self):
aval = keys_shaped_array(self.impl, self.shape)
phys_sharding = self._base_array.sharding
return KeyTyRules.logical_op_sharding(aval, phys_sharding)
return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)
def _is_scalar(self):
base_ndim = len(self.impl.key_shape)
@ -544,11 +547,8 @@ class KeyTy:
core.opaque_dtypes.add(KeyTy)
core.pytype_aval_mappings[PRNGKeyArrayImpl] = (
lambda x: keys_shaped_array(x.impl, x.shape))
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = (
lambda x: keys_shaped_array(x.impl, x.shape))
core.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArrayImpl] = lambda x: x.aval
xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x