# Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 21:06:06 +00:00).
# Source listing: 171 lines, 6.2 KiB, Python.
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
|
|
import numpy as np
|
|
from absl.testing import absltest, parameterized
|
|
|
|
import jax
|
|
from jax import lax
|
|
import jax.extend as jex
|
|
import jax.numpy as jnp
|
|
|
|
from jax._src import abstract_arrays
|
|
from jax._src import api
|
|
from jax._src import core
|
|
from jax._src import linear_util
|
|
from jax._src import prng
|
|
from jax._src import test_util as jtu
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.extend import ffi
|
|
|
|
# Hand command-line flags to JAX's config via absl (as the function name
# indicates). This runs at import time, so it happens before any test below.
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
class ExtendTest(jtu.JaxTestCase):
  """Checks that `jax.extend` re-exports the intended internal objects."""

  def test_symbols(self):
    # Each entry is (attribute of `jex`, private module it must alias, names
    # that must resolve to the very same object in both places). Behavior of
    # these objects is assumed to be tested elsewhere (e.g. random_test.py for
    # the PRNG symbols); only identity of the re-exports is checked here.
    aliases = (
        ("random", prng, (
            "seed_with_impl",
            "threefry2x32_p",
            "threefry_2x32",
            "threefry_prng_impl",
            "rbg_prng_impl",
            "unsafe_rbg_prng_impl",
        )),
        ("backend", api, ("clear_backends",)),
        ("core", abstract_arrays, ("array_types",)),
        ("linear_util", linear_util, (
            "StoreException",
            "WrappedFun",
            "cache",
            "merge_linear_aux",
            "transformation",
            "transformation_with_aux",
            "wrap_init",
        )),
    )
    for submodule, private_module, names in aliases:
      public_module = getattr(jex, submodule)
      for name in names:
        self.assertIs(getattr(public_module, name),
                      getattr(private_module, name))
|
|
|
|
|
|
class RandomTest(jtu.JaxTestCase):
  """Tests custom PRNG implementations built with `jex.random.define_prng_impl`."""

  @staticmethod
  def _no_rule(*args, **kwargs):
    """Stand-in for PRNG rules these tests never expect JAX to invoke.

    Raises explicitly instead of using `assert False`: Python strips assert
    statements under `-O`, which would make this silently return None.
    """
    del args, kwargs  # Unused.
    raise AssertionError('unreachable')

  def test_key_make_with_custom_impl(self):
    # Creating a key from a seed may call only the seed rule; all other rules
    # raise if invoked.
    shape = (4, 2, 7)

    def seed_rule(_):
      return jnp.ones(shape, dtype=jnp.dtype('uint32'))

    impl = jex.random.define_prng_impl(
        key_shape=shape, seed=seed_rule, split=self._no_rule,
        fold_in=self._no_rule, random_bits=self._no_rule)
    k = jax.random.key(42, impl=impl)
    # One seed yields a single key (shape `()`) that reports the custom impl.
    self.assertEqual(k.shape, ())
    self.assertEqual(impl, jax.random.key_impl(k))

  def test_key_wrap_with_custom_impl(self):
    # Wrapping existing key data must not call any rule, not even seed.
    shape = (4, 2, 7)
    impl = jex.random.define_prng_impl(
        key_shape=shape, seed=self._no_rule, split=self._no_rule,
        fold_in=self._no_rule, random_bits=self._no_rule)
    # Leading axis of length 3 indexes keys; trailing axes equal `key_shape`.
    data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32'))
    k = jax.random.wrap_key_data(data, impl=impl)
    self.assertEqual(k.shape, (3,))
    self.assertEqual(impl, jax.random.key_impl(k))
|
|
|
|
|
|
class FfiTest(jtu.JaxTestCase):
  """Tests for the foreign function interface helpers under `jex.ffi`."""

  @staticmethod
  def _make_pivots(shape, dtype):
    """Builds LU pivots for the `lu_pivots_to_permutation` FFI tests.

    Args:
      shape: shape of the pivots array; its last entry is the pivot count n.
      dtype: dtype of the pivots.

    Returns:
      A pair `(pivots, permutation_size)` where `permutation_size == 2 * n` and
      each length-n row of `pivots` holds 2n-1, 2n-2, ..., n (broadcast to
      `shape`).
    """
    pivots_size = shape[-1]
    permutation_size = 2 * pivots_size
    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
    return jnp.broadcast_to(pivots, shape), permutation_size

  def testHeadersExist(self):
    # The XLA FFI API headers must be present under jex.ffi.include_dir().
    base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api")
    for header in ["c_api.h", "api.h", "ffi.h"]:
      path = os.path.join(base_dir, header)
      # Name the missing file on failure rather than "False is not true".
      self.assertTrue(os.path.exists(path), msg=f"missing FFI header: {path}")

  @parameterized.parameters(
      [True, 1, 5.0, np.int32(-5), np.float32(0.5)])
  def testIrAttribute(self, value):
    # The FFI attribute built from `value` must have the same element type as
    # an MLIR constant built from the same value.
    with mlir.make_ir_context(), ir.Location.unknown():
      const = mlir.ir_constant(value)
      attr = ffi._ir_attribute(value)
      # A unittest assertion (unlike a bare `assert`) still runs under
      # `python -O` and reports both types when it fails.
      self.assertEqual(const.type.element_type, attr.type)

  @parameterized.parameters([True, 1, 5.0, "param", np.float32(0.5)])
  def testParams(self, param):
    # Throwaway primitive: abstract eval returns its first input unchanged and
    # lowering goes through the generic FFI lowering for target "test_ffi".
    prim = core.Primitive("test_ffi")
    prim.def_abstract_eval(lambda *args, **kwargs: args[0])
    mlir.register_lowering(prim, jex.ffi.ffi_lowering("test_ffi"))

    # TODO(dfm): Currently testing that lowering works with different types of
    # parameters, but we should probably actually check the emitted HLO.
    func = jax.jit(lambda *args: prim.bind(*args, param=param))
    func.lower(jnp.linspace(0, 5, 10))

  @jtu.sample_product(
      shape=[(1,), (4,), (5,)],
      dtype=(np.int32,),
  )
  @jtu.run_on_devices("gpu")
  def testFfiCall(self, shape, dtype):
    # The FFI-backed call must agree with the built-in lax implementation.
    pivots, permutation_size = self._make_pivots(shape, dtype)
    expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
    actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size)
    self.assertArraysEqual(actual, expected)

  @jtu.sample_product(
      shape=[(1,), (4,), (5,)],
      dtype=(np.int32,),
      vectorized=(False, True),
  )
  @jtu.run_on_devices("gpu")
  def testFfiCallBatching(self, shape, dtype, vectorized):
    # Prepend a batch axis of size 10 and map the FFI call over it with vmap,
    # with and without the `vectorized` batching mode.
    pivots, permutation_size = self._make_pivots((10,) + shape, dtype)
    expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
    actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation(
        x, permutation_size, vectorized=vectorized))(pivots)
    self.assertArraysEqual(actual, expected)
|
|
|
|
|
|
# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation`
|
|
# custom call target because that's the only one in jaxlib that uses the
|
|
# new FFI interface. Once more are available, consider using something that
|
|
# can be run on multiple platforms.
|
|
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True):
|
|
return jex.ffi.ffi_call(
|
|
"cu_lu_pivots_to_permutation",
|
|
jax.ShapeDtypeStruct(
|
|
shape=pivots.shape[:-1] + (permutation_size,),
|
|
dtype=pivots.dtype,
|
|
),
|
|
pivots,
|
|
permutation_size=np.int32(permutation_size),
|
|
vectorized=vectorized,
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
  # Run this module's tests through absl, collecting them with JAX's loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|