20 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Chris Jones
15c542228a [jax] Make the SupportsDType protocol runtime checkable.
This allows `DTypeLike` to be used as a type annotation for type-checked functions without triggering a warning.

PiperOrigin-RevId: 652905699
2024-07-16 10:57:15 -07:00
Jake VanderPlas
0ff0d7b95d jnp.take: fix annotation for fill_value 2024-05-25 14:20:55 -07:00
Meekail Zain
a2feff2e54 Add support for max_version, dl_device, copy kwargs in __dlpack__ 2024-04-11 16:44:19 +00:00
Meekail Zain
8b7aae586b Update jnp.clip to Array API 2023 standard 2024-04-04 22:55:10 +00:00
George Necula
c6afdfd8d6 [shape_poly] Simplify the API for processing polymorphic_shape specifications
Before, we had `export.poly_spec` to create a jax.ShapedDtypeStruct`
given a polymorphic shape specification. This function was
invoked `poly_spec(arg_shape, arg_dtype, polymorphic_shape)`.
The `arg_shape` was only needed when the polymorphic shape spec
contained placeholders.

We break out an `export.symbolic_shape` that is just a parser
of polymorphic shape specs and we ask the user to invoke
`jax.ShapeDtypeStruct` directly:

`jax.ShapeDtypeStruct(export.symbolic_shape(polymorphic_shape, like=arg_shape), arg_dtype)`.

We also rename the `export.poly_specs` to `export.arg_specs`.
2023-11-28 12:45:59 +02:00
Jake VanderPlas
911f745775 Make jax._src.typing.DTypeLike more strictly defined
This is in preparation for exporting this to `jax.typing.DTypeLike`. Currently this is effectively just Any, and we want to make certain it's a meaningful type before exporting.

PiperOrigin-RevId: 572260744
2023-10-10 09:01:19 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
a59a4d3bd2 Allow duck-typed inputs to zeros_like, ones_like, etc. 2023-06-06 00:59:51 -07:00
Peter Hawkins
e4b154b660 Split basearray into separate Bazel module.
Move the definition of ArrayLike into basearray to avoid a cyclic dependency between array.py and basearray.

PiperOrigin-RevId: 516264828
2023-03-13 11:14:41 -07:00
Jake VanderPlas
de673ce297 DOC: improve usage recommendation in jax.typing 2023-02-21 04:58:21 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
1ed18fa500 add allow_opaque_dtype to dtypes.canonicalize_dtype utility 2022-10-17 13:47:42 -07:00
Jake VanderPlas
8196a6a9f0 [typing] clarify jax._src.typing 2022-10-14 11:52:04 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5829c6ae9d Change case of typing.Dtype -> typing.DType
This follows the convention used in numpy.typing.DType.
2022-09-14 15:03:55 -07:00
Jake VanderPlas
b3c31ebe7d Add typing_test.py 2022-09-13 12:43:51 -07:00
Jake VanderPlas
4fed097b1f jax._src/typing: add basic types 2022-09-12 09:07:56 -07:00