deprecate PRNGKeyArray.unsafe_raw_array in favor of jax.random.key_data

The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.

PiperOrigin-RevId: 565195381
This commit is contained in:
Roy Frostig 2023-09-13 16:33:21 -07:00 committed by jax authors
parent cd2d419f6f
commit 1f8cc44f4e
4 changed files with 30 additions and 30 deletions

View File

@ -17,6 +17,7 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.numpy.sometrue`: use `jax.numpy.any`.
* `jax.numpy.product`: use `jax.numpy.prod`.
* `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
* Internal deprecations/removals:
* The internal submodule `jax.prng` is now deprecated. Its contents are available at
{mod}`jax.extend.random`.
@ -25,6 +26,8 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array`
for type annotations, and `jax.dtypes.issubdtype(arr, jax.dtypes.prng_key)`` for runtime
detection of typed prng keys.
* The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use
{func}`jax.random.key_data` instead.
## jaxlib 0.4.16

View File

@ -19,6 +19,7 @@ from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, NamedTuple
import warnings
import numpy as np
@ -141,9 +142,6 @@ class PRNGKeyArrayMeta(abc.ABCMeta):
class PRNGKeyArray(jax.Array, metaclass=PRNGKeyArrayMeta):
"""An array whose elements are PRNG keys"""
@abc.abstractmethod # TODO(frostig): rename
def unsafe_raw_array(self) -> PRNGKeyArray: ...
@abc.abstractmethod
def unsafe_buffer_pointer(self) -> int: ...
@ -255,15 +253,6 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
self.impl = impl
self._base_array = key_data
# TODO(frostig): rename to unsafe_base_array, or just offer base_array attr?
def unsafe_raw_array(self):
"""Access the raw numerical array that carries underlying key data.
Returns:
A uint32 JAX array whose leading dimensions are ``self.shape``.
"""
return self._base_array
def block_until_ready(self):
_ = self._base_array.block_until_ready()
return self
@ -302,6 +291,13 @@ class PRNGKeyArrayImpl(PRNGKeyArray):
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]
def unsafe_raw_array(self):
# deprecated on 13 Sept 2023
raise warnings.warn(
'The `unsafe_raw_array` method of PRNG key arrays is deprecated. '
'Use `jax.random.key_data` instead.', DeprecationWarning, stacklevel=2)
return self._base_array
def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
return PRNGKeyArrayImpl(self.impl, self._base_array.addressable_data(index))
@ -472,7 +468,7 @@ class KeyTyRules:
@staticmethod
def physical_const(val) -> Array:
return val.unsafe_raw_array()
return val._base_array
@staticmethod
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
@ -624,7 +620,7 @@ xla.canonicalize_dtype_handlers[PRNGKeyArrayImpl] = lambda x: x
def key_array_shard_arg_handler(x: PRNGKeyArrayImpl, devices, indices, sharding):
aval = x.aval
key_shape = aval.dtype.impl.key_shape
arr = x.unsafe_raw_array()
arr = x._base_array
# TODO(yashkatariya,frostig): This assumes that the last dimensions are not
# sharded. This is only true when enable_custom_prng is True.
@ -641,7 +637,7 @@ pxla.shard_arg_handlers[PRNGKeyArrayImpl] = key_array_shard_arg_handler
def key_array_constant_handler(x):
arr = x.unsafe_raw_array()
arr = x._base_array
return mlir.get_constant_handler(type(arr))(arr)
mlir.register_constant_handler(PRNGKeyArrayImpl, key_array_constant_handler)
@ -740,7 +736,7 @@ def random_split_abstract_eval(keys_aval, *, shape):
@random_split_p.def_impl
def random_split_impl(keys, *, shape):
base_arr = random_split_impl_base(
keys.impl, keys.unsafe_raw_array(), keys.ndim, shape=shape)
keys.impl, keys._base_array, keys.ndim, shape=shape)
return PRNGKeyArrayImpl(keys.impl, base_arr)
def random_split_impl_base(impl, base_arr, keys_ndim, *, shape):
@ -777,7 +773,7 @@ def random_fold_in_abstract_eval(keys_aval, msgs_aval):
@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):
base_arr = random_fold_in_impl_base(
keys.impl, keys.unsafe_raw_array(), msgs, keys.shape)
keys.impl, keys._base_array, msgs, keys.shape)
return PRNGKeyArrayImpl(keys.impl, base_arr)
def random_fold_in_impl_base(impl, base_arr, msgs, keys_shape):
@ -826,7 +822,7 @@ def random_bits_abstract_eval(keys_aval, *, bit_width, shape):
@random_bits_p.def_impl
def random_bits_impl(keys, *, bit_width, shape):
return random_bits_impl_base(keys.impl, keys.unsafe_raw_array(), keys.ndim,
return random_bits_impl_base(keys.impl, keys._base_array, keys.ndim,
bit_width=bit_width, shape=shape)
def random_bits_impl_base(impl, base_arr, keys_ndim, *, bit_width, shape):
@ -912,7 +908,7 @@ def random_unwrap_abstract_eval(keys_aval):
@random_unwrap_p.def_impl
def random_unwrap_impl(keys):
return keys.unsafe_raw_array()
return keys._base_array
def random_unwrap_lowering(ctx, keys):
return [keys]

View File

@ -1210,7 +1210,7 @@ class PJitTest(jtu.BufferDonationTestCase):
out = f(seeds)
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
jax.random.key_data(out) # doesn't crash
def test_with_sharding_constraint_is_compatible_error(self):
mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl'))
@ -1843,7 +1843,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = make_keys(seeds)
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
jax.random.key_data(out) # doesn't crash
def test_globally_sharded_key_array_8x4_multi_device_with_out_sharding(self):
input_shape = (8, 4)
@ -1860,7 +1860,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = make_keys(seeds)
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
jax.random.key_data(out) # doesn't crash
def test_globally_sharded_key_array_8x4_multi_device(self):
input_shape = (8, 4)
@ -1877,7 +1877,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = make_keys(seeds)
self.assertTrue(jax.dtypes.issubdtype(out.dtype, jax.dtypes.prng_key))
self.assertEqual(out.shape, input_shape)
out.unsafe_raw_array() # doesn't crash
jax.random.key_data(out) # doesn't crash
def test_array_device_assignment_mismatch_out_shardings(self):
input_shape = (8, 2)

View File

@ -55,7 +55,7 @@ uint_dtypes = jtu.dtypes.all_unsigned
def _prng_key_as_array(key):
# TODO(frostig): remove some day when we deprecate "raw" key arrays
if jnp.issubdtype(key.dtype, dtypes.prng_key):
return key.unsafe_raw_array()
return random.key_data(key)
else:
return key
@ -1788,7 +1788,7 @@ class KeyArrayTest(jtu.JaxTestCase):
def test_key_dtype_attributes(self):
key = self.make_keys()
key_raw = key.unsafe_raw_array()
key_raw = random.key_data(key)
self.assertStartsWith(key.dtype.name, "key")
self.assertEqual(key.size * key.dtype.itemsize,
@ -2251,8 +2251,8 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
vmapped_keys = vmap(random.split)(mapped_keys)
self.assertEqual(vmapped_keys.shape, (3, 2))
for fk, vk in zip(forloop_keys, vmapped_keys):
self.assertArraysEqual(fk.unsafe_raw_array(),
vk.unsafe_raw_array())
self.assertArraysEqual(random.key_data(fk),
random.key_data(vk))
def test_cannot_add(self):
key = self.make_key(73)
@ -2379,16 +2379,17 @@ class JnpWithKeyArrayTest(jtu.JaxTestCase):
self.assertEqual(out_key.shape, out_like_key.shape)
def check_against_reference(self, key_func, arr_func, *key_args):
out_arr = arr_func(*tree_util.tree_map(lambda x: x.unsafe_raw_array(), key_args))
out_arr = arr_func(*tree_util.tree_map(lambda x: random.key_data(x),
key_args))
self.assertIsInstance(out_arr, jax.Array)
out_key = key_func(*key_args)
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)
self.assertArraysEqual(random.key_data(out_key), out_arr)
out_key = jax.jit(key_func)(*key_args)
self.assertIsInstance(out_key, jax_random.PRNGKeyArray)
self.assertArraysEqual(out_key.unsafe_raw_array(), out_arr)
self.assertArraysEqual(random.key_data(out_key), out_arr)
@parameterized.parameters([
[(2, 3), 'shape', (2, 3)],