mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
deprecate PRNGKeyArray.unsafe_raw_array
in favor of jax.random.key_data
The latter function is also better in that its behavior is invariant to `jit`, whereas the `unsafe_raw_array` method only works in eager mode. PiperOrigin-RevId: 565195381
This commit is contained in:
parent
cd2d419f6f
commit
1f8cc44f4e
@ -17,6 +17,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `jax.numpy.sometrue`: use `jax.numpy.any`.
|
||||
* `jax.numpy.product`: use `jax.numpy.prod`.
|
||||
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
|
||||
|
||||
* Internal deprecations/removals:
|
||||
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
|
||||
{mod}`jax.extend.random`.
|
||||
@ -25,6 +26,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
|
||||
for type annotations, and `jax.dtypes.issubdtype(arr, jax.dtypes.prng_key)`` for runtime
|
||||
detection of typed prng keys.
|
||||
* The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use
|
||||
{func}`jax.random.key_data` instead.
|
||||
|
||||
## jaxlib 0.4.16
|
||||
|
||||
|
@ -19,6 +19,7 @@ from functools import partial, reduce
|
||||
import math
|
||||
import operator as op
|
||||
from typing import Any, Callable, NamedTuple
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -141,9 +142,6 @@ class PRNGKeyArrayMeta(abc.ABCMeta):
|
||||
class PRNGKeyArray(jax.Array, metaclass=PRNGKeyArrayMeta):
|
||||
"""An array whose elements are PRNG keys"""
|
||||
|
||||
@abc.abstractmethod # TODO(frostig): rename
|
||||
def unsafe_raw_array(self) -> PRNGKeyArray: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def unsafe_buffer_pointer(self) -> int: ...
|
||||
|
||||
@ -255,15 +253,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
self.impl = impl
|
||||
self._base_array = key_data
|
||||
|
||||
# TODO(frostig): rename to unsafe_base_array, or just offer base_array attr?
|
||||
def unsafe_raw_array(self):
|
||||
"""Access the raw numerical array that carries underlying key data.
|
||||
|
||||
Returns:
|
||||
A uint32 JAX array whose leading dimensions are ``self.shape``.
|
||||
"""
|
||||
return self._base_array
|
||||
|
||||
def block_until_ready(self):
|
||||
_ = self._base_array.block_until_ready()
|
||||
return self
|
||||
@ -302,6 +291,13 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
|
||||
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
|
||||
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]
|
||||
|
||||
def unsafe_raw_array(self):
|
||||
# deprecated on 13 Sept 2023
|
||||
raise warnings.warn(
|
||||
'The `unsafe_raw_array` method of PRNG key arrays is deprecated. '
|
||||
'Use `jax.random.key_data` instead.', DeprecationWarning, stacklevel=2)
|
||||
return self._base_array
|
||||
|
||||
def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
|
||||
return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))
|
||||
|
||||
@ -472,7 +468,7 @@ class KeyTyRules:
|
||||
|
||||
@staticmethod
|
||||
def physical_const(val) -> Array:
|
||||
return val.unsafe_raw_array()
|
||||
return val._base_array
|
||||
|
||||
@staticmethod
|
||||
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
|
||||
@ -624,7 +620,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
|
||||
def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
|
||||
aval = x.aval
|
||||
key_shape = aval.dtype.impl.key_shape
|
||||
arr = x.unsafe_raw_array()
|
||||
arr = x._base_array
|
||||
|
||||
# TODO(yashkatariya,frostig): This assumes that the last dimensions are not
|
||||
# sharded. This is only true when enable_custom_prng is True.
|
||||
@ -641,7 +637,7 @@ pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
|
||||
|
||||
|
||||
def key_array_constant_handler(x):
|
||||
arr = x.unsafe_raw_array()
|
||||
arr = x._base_array
|
||||
return mlir.get_constant_handler(type(arr))(arr)
|
||||
mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
|
||||
|
||||
@ -740,7 +736,7 @@ def random_split_abstract_eval(keys_aval, *, shape):
|
||||
@random_split_p.def_impl
|
||||
def random_split_impl(keys, *, shape):
|
||||
base_arr = random_split_impl_base(
|
||||
keys.impl, keys.unsafe_raw_array(), keys.ndim, shape=shape)
|
||||
keys.impl, keys._base_array, keys.ndim, shape=shape)
|
||||
return PRNGKeyArrayImpl(keys.impl, base_arr)
|
||||
|
||||
def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
|
||||
@ -777,7 +773,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
|
||||
@random_fold_in_p.def_impl
|
||||
def random_fold_in_impl(keys, msgs):
|
||||
base_arr = random_fold_in_impl_base(
|
||||
keys.impl, keys.unsafe_raw_array(), msgs, keys.shape)
|
||||
keys.impl, keys._base_array, msgs, keys.shape)
|
||||
return PRNGKeyArrayImpl(keys.impl, base_arr)
|
||||
|
||||
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
|
||||
@ -826,7 +822,7 @@ def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
|
||||
|
||||
@random_bits_p.def_impl
|
||||
def random_bits_impl(keys, *, bit_width, shape):
|
||||
return random_bits_impl_base(keys.impl, keys.unsafe_raw_array(), keys.ndim,
|
||||
return random_bits_impl_base(keys.impl, keys._base_array, keys.ndim,
|
||||
bit_width=bit_width, shape=shape)
|
||||
|
||||
def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):
|
||||
@ -912,7 +908,7 @@ def random_unwrap_abstract_eval(keys_aval):
|
||||
|
||||
@random_unwrap_p.def_impl
|
||||
def random_unwrap_impl(keys):
|
||||
return keys.unsafe_raw_array()
|
||||
return keys._base_array
|
||||
|
||||
def random_unwrap_lowering(ctx, keys):
|
||||
return [keys]
|
||||
|
@ -1210,7 +1210,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
out = f(seeds)
|
||||
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
jax.random.key_data(out) # doesn't crash
|
||||
|
||||
def test_with_sharding_constraint_is_compatible_error(self):
|
||||
mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
|
||||
@ -1843,7 +1843,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out = make_keys(seeds)
|
||||
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
jax.random.key_data(out) # doesn't crash
|
||||
|
||||
def test_globally_sharded_key_array_8x4_multi_device_with_out_sharding(self):
|
||||
input_shape = (8, 4)
|
||||
@ -1860,7 +1860,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out = make_keys(seeds)
|
||||
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
jax.random.key_data(out) # doesn't crash
|
||||
|
||||
def test_globally_sharded_key_array_8x4_multi_device(self):
|
||||
input_shape = (8, 4)
|
||||
@ -1877,7 +1877,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
out = make_keys(seeds)
|
||||
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
|
||||
self.assertEqual(out.shape, input_shape)
|
||||
out.unsafe_raw_array() # doesn't crash
|
||||
jax.random.key_data(out) # doesn't crash
|
||||
|
||||
def test_array_device_assignment_mismatch_out_shardings(self):
|
||||
input_shape = (8, 2)
|
||||
|
@ -55,7 +55,7 @@ uint_dtypes = jtu.dtypes.all_unsigned
|
||||
def _prng_key_as_array(key):
|
||||
# TODO(frostig): remove some day when we deprecate "raw" key arrays
|
||||
if jnp.issubdtype(key.dtype, dtypes.prng_key):
|
||||
return key.unsafe_raw_array()
|
||||
return random.key_data(key)
|
||||
else:
|
||||
return key
|
||||
|
||||
@ -1788,7 +1788,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_key_dtype_attributes(self):
|
||||
key = self.make_keys()
|
||||
key_raw = key.unsafe_raw_array()
|
||||
key_raw = random.key_data(key)
|
||||
|
||||
self.assertStartsWith(key.dtype.name, "key")
|
||||
self.assertEqual(key.size * key.dtype.itemsize,
|
||||
@ -2251,8 +2251,8 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
|
||||
vmapped_keys = vmap(random.split)(mapped_keys)
|
||||
self.assertEqual(vmapped_keys.shape, (3, 2))
|
||||
for fk, vk in zip(forloop_keys, vmapped_keys):
|
||||
self.assertArraysEqual(fk.unsafe_raw_array(),
|
||||
vk.unsafe_raw_array())
|
||||
self.assertArraysEqual(random.key_data(fk),
|
||||
random.key_data(vk))
|
||||
|
||||
def test_cannot_add(self):
|
||||
key = self.make_key(73)
|
||||
@ -2379,16 +2379,17 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out_key.shape, out_like_key.shape)
|
||||
|
||||
def check_against_reference(self, key_func, arr_func, *key_args):
|
||||
out_arr = arr_func(*tree_util.tree_map(lambda x: x.unsafe_raw_array(), key_args))
|
||||
out_arr = arr_func(*tree_util.tree_map(lambda x: random.key_data(x),
|
||||
key_args))
|
||||
self.assertIsInstance(out_arr, jax.Array)
|
||||
|
||||
out_key = key_func(*key_args)
|
||||
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
|
||||
self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)
|
||||
self.assertArraysEqual(random.key_data(out_key), out_arr)
|
||||
|
||||
out_key = jax.jit(key_func)(*key_args)
|
||||
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
|
||||
self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)
|
||||
self.assertArraysEqual(random.key_data(out_key), out_arr)
|
||||
|
||||
@parameterized.parameters([
|
||||
[(2, 3), 'shape', (2, 3)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user