mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
remove core.bint
PiperOrigin-RevId: 509932914
This commit is contained in:
parent
9b288e9ab9
commit
537372a637
@ -118,7 +118,6 @@ from jax._src.core import (
|
||||
aval_property as aval_property,
|
||||
axis_frame as axis_frame,
|
||||
axis_substitution_rules as axis_substitution_rules,
|
||||
bint as bint,
|
||||
call as call,
|
||||
call_bind_with_continuation as call_bind_with_continuation,
|
||||
call_impl as call_impl,
|
||||
|
@ -22,14 +22,15 @@ from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.interpreters import batching
|
||||
import jax._src.lib
|
||||
from jax._src import test_util as jtu
|
||||
import jax._src.util
|
||||
|
||||
from jax.config import config
|
||||
from jax.interpreters import batching
|
||||
|
||||
import jax._src.lib
|
||||
import jax._src.util
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
@ -1144,10 +1145,10 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
(jnp.ones((256, 10)), jnp.ones( 10))]
|
||||
|
||||
# two different batch sizes *with bints*
|
||||
bs1 = jax.lax.convert_element_type(128, jax.core.bint(128))
|
||||
bs1 = jax.lax.convert_element_type(128, core.bint(128))
|
||||
batch1 = (jnp.ones((bs1, 784)), jnp.ones((bs1, 10)))
|
||||
|
||||
bs2 = jax.lax.convert_element_type(32, jax.core.bint(128))
|
||||
bs2 = jax.lax.convert_element_type(32, core.bint(128))
|
||||
batch2 = (jnp.ones((bs2, 784)), jnp.ones((bs2, 10)))
|
||||
|
||||
# count retraces (and don't crash)
|
||||
@ -1164,7 +1165,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_bint_basic(self):
|
||||
d = lax.convert_element_type(3, jax.core.bint(5))
|
||||
d = lax.convert_element_type(3, core.bint(5))
|
||||
self.assertEqual(str(d), '3{≤5}')
|
||||
|
||||
@jax.jit
|
||||
@ -1174,7 +1175,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
f(d) # doesn't crash
|
||||
|
||||
def test_bint_broadcast(self):
|
||||
d = lax.convert_element_type(3, jax.core.bint(5))
|
||||
d = lax.convert_element_type(3, core.bint(5))
|
||||
bint = lambda x, b: lax.convert_element_type(x, core.bint(b))
|
||||
|
||||
x = lax.broadcast_in_dim(0, (d,), ()) # doesn't crash
|
||||
@ -1208,11 +1209,11 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
def f(d):
|
||||
return jnp.arange(d, dtype='int32')
|
||||
|
||||
y = f(lax.convert_element_type(3, jax.core.bint(5)))
|
||||
y = f(lax.convert_element_type(3, core.bint(5)))
|
||||
self.assertIsInstance(y, core.DArray)
|
||||
self.assertAllClose(y._data, np.arange(5), check_dtypes=False)
|
||||
|
||||
d = lax.convert_element_type(3, jax.core.bint(5))
|
||||
d = lax.convert_element_type(3, core.bint(5))
|
||||
y = jax.jit(f)(d)
|
||||
self.assertIsInstance(y, core.DArray)
|
||||
self.assertAllClose(y._data, np.arange(5), check_dtypes=False)
|
||||
@ -1225,8 +1226,8 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
nonlocal count
|
||||
count += 1
|
||||
return jnp.zeros(n)
|
||||
f(lax.convert_element_type(3, jax.core.bint(5)))
|
||||
f(lax.convert_element_type(4, jax.core.bint(5)))
|
||||
f(lax.convert_element_type(3, core.bint(5)))
|
||||
f(lax.convert_element_type(4, core.bint(5)))
|
||||
self.assertEqual(count, 1)
|
||||
|
||||
def test_bint_compilation_cache2(self):
|
||||
@ -1238,19 +1239,19 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
count += 1
|
||||
return x.sum()
|
||||
|
||||
d = lax.convert_element_type(3, jax.core.bint(5))
|
||||
d = lax.convert_element_type(3, core.bint(5))
|
||||
x = jnp.arange(d)
|
||||
y = f(x)
|
||||
self.assertEqual(y, 3)
|
||||
self.assertEqual(count, 1)
|
||||
|
||||
d = lax.convert_element_type(4, jax.core.bint(5))
|
||||
d = lax.convert_element_type(4, core.bint(5))
|
||||
x = jnp.arange(d)
|
||||
y = f(x)
|
||||
self.assertEqual(y, 6)
|
||||
self.assertEqual(count, 1)
|
||||
|
||||
d = lax.convert_element_type(4, jax.core.bint(6))
|
||||
d = lax.convert_element_type(4, core.bint(6))
|
||||
x = jnp.arange(d)
|
||||
y = f(x)
|
||||
self.assertEqual(y, 6)
|
||||
@ -1258,7 +1259,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skip('do we want to support this?')
|
||||
def test_bint_add(self):
|
||||
d = lax.convert_element_type(4, jax.core.bint(6))
|
||||
d = lax.convert_element_type(4, core.bint(6))
|
||||
x = jnp.arange(d)
|
||||
|
||||
@jax.jit
|
||||
@ -1418,11 +1419,11 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
def f(i):
|
||||
return x[i]
|
||||
|
||||
sz = jax.lax.convert_element_type(2, jax.core.bint(3))
|
||||
sz = jax.lax.convert_element_type(2, core.bint(3))
|
||||
idx = jnp.arange(sz)
|
||||
y = jax.jit(jax.vmap(f), abstracted_axes=('n',))(idx)
|
||||
|
||||
self.assertIsInstance(y, jax.core.DArray)
|
||||
self.assertIsInstance(y, core.DArray)
|
||||
self.assertEqual(y.shape, (sz, 4))
|
||||
self.assertAllClose(y._data, x)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user