Merge pull request #19403 from jakevdp:jnp-bool

PiperOrigin-RevId: 599312340
This commit is contained in:
jax authors 2024-01-17 15:48:46 -08:00
commit 83ad09c4e6
3 changed files with 6 additions and 1 deletions

View File

@ -53,6 +53,7 @@ from jax._src.numpy.lax_numpy import (
bincount as bincount,
blackman as blackman,
block as block,
bool_ as bool, # Array API alias for bool_
bool_ as bool_,
broadcast_arrays as broadcast_arrays,
broadcast_shapes as broadcast_shapes,

View File

@ -156,6 +156,7 @@ def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def blackman(M: int) -> Array: ...
def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ...
bool: Any
bool_: Any
def broadcast_arrays(*args: ArrayLike) -> list[Array]: ...

View File

@ -175,7 +175,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
return f
@parameterized.parameters(
[dtype for dtype in [jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32,
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
jnp.complex64, jnp.complex128]
@ -191,6 +191,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
prims = [eqn.primitive for eqn in jaxpr.eqns]
self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated.
def testBoolDtypeAlias(self):
self.assertIs(jnp.bool, jnp.bool_)
@jtu.sample_product(
dtype=float_dtypes + [object],
allow_pickle=[True, False],