remove core.bint

PiperOrigin-RevId: 509932914
This commit is contained in:
Roy Frostig 2023-02-15 14:27:55 -08:00 committed by jax authors
parent 9b288e9ab9
commit 537372a637
2 changed files with 21 additions and 21 deletions

View File

@ -118,7 +118,6 @@ from jax._src.core import (
aval_property as aval_property,
axis_frame as axis_frame,
axis_substitution_rules as axis_substitution_rules,
bint as bint,
call as call,
call_bind_with_continuation as call_bind_with_continuation,
call_impl as call_impl,

View File

@ -22,14 +22,15 @@ from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import core
from jax import lax
from jax.interpreters import batching
import jax._src.lib
from jax._src import test_util as jtu
import jax._src.util
from jax.config import config
from jax.interpreters import batching
import jax._src.lib
import jax._src.util
from jax._src import core
from jax._src import test_util as jtu
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@ -1144,10 +1145,10 @@ class DynamicShapeTest(jtu.JaxTestCase):
(jnp.ones((256, 10)), jnp.ones( 10))]
# two different batch sizes *with bints*
bs1 = jax.lax.convert_element_type(128, jax.core.bint(128))
bs1 = jax.lax.convert_element_type(128, core.bint(128))
batch1 = (jnp.ones((bs1, 784)), jnp.ones((bs1, 10)))
bs2 = jax.lax.convert_element_type(32, jax.core.bint(128))
bs2 = jax.lax.convert_element_type(32, core.bint(128))
batch2 = (jnp.ones((bs2, 784)), jnp.ones((bs2, 10)))
# count retraces (and don't crash)
@ -1164,7 +1165,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
atol=1e-3, rtol=1e-3)
def test_bint_basic(self):
d = lax.convert_element_type(3, jax.core.bint(5))
d = lax.convert_element_type(3, core.bint(5))
self.assertEqual(str(d), '3{≤5}')
@jax.jit
@ -1174,7 +1175,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
f(d) # doesn't crash
def test_bint_broadcast(self):
d = lax.convert_element_type(3, jax.core.bint(5))
d = lax.convert_element_type(3, core.bint(5))
bint = lambda x, b: lax.convert_element_type(x, core.bint(b))
x = lax.broadcast_in_dim(0, (d,), ()) # doesn't crash
@ -1208,11 +1209,11 @@ class DynamicShapeTest(jtu.JaxTestCase):
def f(d):
return jnp.arange(d, dtype='int32')
y = f(lax.convert_element_type(3, jax.core.bint(5)))
y = f(lax.convert_element_type(3, core.bint(5)))
self.assertIsInstance(y, core.DArray)
self.assertAllClose(y._data, np.arange(5), check_dtypes=False)
d = lax.convert_element_type(3, jax.core.bint(5))
d = lax.convert_element_type(3, core.bint(5))
y = jax.jit(f)(d)
self.assertIsInstance(y, core.DArray)
self.assertAllClose(y._data, np.arange(5), check_dtypes=False)
@ -1225,8 +1226,8 @@ class DynamicShapeTest(jtu.JaxTestCase):
nonlocal count
count += 1
return jnp.zeros(n)
f(lax.convert_element_type(3, jax.core.bint(5)))
f(lax.convert_element_type(4, jax.core.bint(5)))
f(lax.convert_element_type(3, core.bint(5)))
f(lax.convert_element_type(4, core.bint(5)))
self.assertEqual(count, 1)
def test_bint_compilation_cache2(self):
@ -1238,19 +1239,19 @@ class DynamicShapeTest(jtu.JaxTestCase):
count += 1
return x.sum()
d = lax.convert_element_type(3, jax.core.bint(5))
d = lax.convert_element_type(3, core.bint(5))
x = jnp.arange(d)
y = f(x)
self.assertEqual(y, 3)
self.assertEqual(count, 1)
d = lax.convert_element_type(4, jax.core.bint(5))
d = lax.convert_element_type(4, core.bint(5))
x = jnp.arange(d)
y = f(x)
self.assertEqual(y, 6)
self.assertEqual(count, 1)
d = lax.convert_element_type(4, jax.core.bint(6))
d = lax.convert_element_type(4, core.bint(6))
x = jnp.arange(d)
y = f(x)
self.assertEqual(y, 6)
@ -1258,7 +1259,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
@unittest.skip('do we want to support this?')
def test_bint_add(self):
d = lax.convert_element_type(4, jax.core.bint(6))
d = lax.convert_element_type(4, core.bint(6))
x = jnp.arange(d)
@jax.jit
@ -1418,11 +1419,11 @@ class DynamicShapeTest(jtu.JaxTestCase):
def f(i):
return x[i]
sz = jax.lax.convert_element_type(2, jax.core.bint(3))
sz = jax.lax.convert_element_type(2, core.bint(3))
idx = jnp.arange(sz)
y = jax.jit(jax.vmap(f), abstracted_axes=('n',))(idx)
self.assertIsInstance(y, jax.core.DArray)
self.assertIsInstance(y, core.DArray)
self.assertEqual(y.shape, (sz, 4))
self.assertAllClose(y._data, x)