mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19403 from jakevdp:jnp-bool
PiperOrigin-RevId: 599312340
This commit is contained in:
commit
83ad09c4e6
@ -53,6 +53,7 @@ from jax._src.numpy.lax_numpy import (
|
||||
bincount as bincount,
|
||||
blackman as blackman,
|
||||
block as block,
|
||||
bool_ as bool, # Array API alias for bool_
|
||||
bool_ as bool_,
|
||||
broadcast_arrays as broadcast_arrays,
|
||||
broadcast_shapes as broadcast_shapes,
|
||||
|
@ -156,6 +156,7 @@ def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def blackman(M: int) -> Array: ...
|
||||
def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ...
|
||||
bool: Any
|
||||
bool_: Any
|
||||
def broadcast_arrays(*args: ArrayLike) -> list[Array]: ...
|
||||
|
||||
|
@ -175,7 +175,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
return f
|
||||
|
||||
@parameterized.parameters(
|
||||
[dtype for dtype in [jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32,
|
||||
[dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32,
|
||||
jnp.uint64, jnp.int8, jnp.int16, jnp.int32, jnp.int64,
|
||||
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
|
||||
jnp.complex64, jnp.complex128]
|
||||
@ -191,6 +191,9 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
prims = [eqn.primitive for eqn in jaxpr.eqns]
|
||||
self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated.
|
||||
|
||||
def testBoolDtypeAlias(self):
|
||||
self.assertIs(jnp.bool, jnp.bool_)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=float_dtypes + [object],
|
||||
allow_pickle=[True, False],
|
||||
|
Loading…
x
Reference in New Issue
Block a user