698 Commits

Author SHA1 Message Date
carlosgmartin
3cb504c583 Add jax.numpy.fill_diagonal. 2023-10-20 16:47:46 -04:00
Jake VanderPlas
1815bc7632 [typing] allow scalar shape for jnp.broadcast_to 2023-10-13 13:37:20 -07:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Sergei Lebedev
5d9c39f4b0 MAINT Use a generator expression with all() and any()
There is no reason to allocate a list only for the purpose of iteration.
2023-10-10 22:33:03 +01:00
Jake VanderPlas
911f745775 Make jax._src.typing.DTypeLike more strictly defined
This is in preparation for exporting this to `jax.typing.DTypeLike`. Currently this is effectively just Any, and we want to make certain it's a meaningful type before exporting.

PiperOrigin-RevId: 572260744
2023-10-10 09:01:19 -07:00
Jake VanderPlas
2902b32e33 [typing] allow Sequence inputs in several jax.numpy functions 2023-10-02 11:48:36 -07:00
Jake VanderPlas
e8ebe462d2 Deprecate non-array inputs to several jax.numpy functions 2023-09-22 14:21:23 -07:00
Yash Katariya
426970591b If an input to jnp.asarray is a numpy array, then convert it to a jax.Array via device_put to avoid a copy.
Do a similar thing for jax.Array too if dtypes match.

Fixes https://github.com/google/jax/issues/17702

PiperOrigin-RevId: 567644997
2023-09-22 09:40:25 -07:00
Jake VanderPlas
4edb74ba7b Fix some numpy 2.0 incompatibilities 2023-09-21 10:24:52 -07:00
Jake VanderPlas
505f03b40f Avoid references to symbols removed in numpy 2.0 2023-09-19 11:50:21 -07:00
Jake VanderPlas
3386e54fe0 jnp.inner: add preferred_element_type argument 2023-09-14 16:40:19 -07:00
Brennan Saeta
1cef3e85f4 Fix error message for zeros_like which was referencing ones_like.
PiperOrigin-RevId: 565413589
2023-09-14 10:43:57 -07:00
Brian Patton
ed955ea7bf Fully unroll the scan in jnp.searchsorted, when method 'scan_unrolled' is specified. On GPU, XLA's 'scan' (fori_loop) implementation launches multiple calls to the body_fun GPU kernel, whereas a fully unrolled scan can be fused into a single kernel launch.
Since we only require log-many steps, this is often quite practical, and can be a nice speedup. (from 4.5ms down to 1.5ms in my scenario.)

PiperOrigin-RevId: 565371859
2023-09-14 08:10:49 -07:00
Jake VanderPlas
9289f3250b Add missing preferred_element_type tests
Followup to https://github.com/google/jax/pull/17506
2023-09-08 13:07:37 -07:00
Jake VanderPlas
2451f34233 jax.numpy: add preferred_element_type argument to matmul functions 2023-09-07 15:16:22 -07:00
Adam Paszke
bb8d5a0121 Rewrite simple slicing to the static slicing primitive whenever possible
This makes it a lot easier to handle within Pallas and Mosaic.

PiperOrigin-RevId: 563128943
2023-09-06 09:43:00 -07:00
Jake VanderPlas
7d29ed6bdd Lower jax.numpy matmul functions to mixed-precision dot_general 2023-09-05 08:37:51 -07:00
Miha Zgubic
992e5e4479
Fix typo in jnp.interp docstring. 2023-08-25 22:39:15 +01:00
Jake VanderPlas
0da3a7ffb5 jnp.einsum: lower to mixed-precision dot_general when possible.
This is a re-landing of https://github.com/google/jax/pull/16733. The downstream issues should be fixed by https://github.com/google/jax/pull/17152.

Reverts c6f40e202c7f5724b9be61afa33541a8f4abfdd0

PiperOrigin-RevId: 559794120
2023-08-24 10:31:39 -07:00
Jake VanderPlas
19a57e1a01 Deprecate jax.numpy.row_stack 2023-08-22 13:12:49 -07:00
Peter Hawkins
9f5999d545 Improve type annotations for jax.numpy.
* Allow sequences of axes to jnp.flip, rather than mandating tuples. Users sometimes pass lists here.
* Allow array-like pad_width values to pad().

PiperOrigin-RevId: 558923802
2023-08-21 15:56:14 -07:00
Jake VanderPlas
8bba992f9a deprecate jax.numpy.issubsctype 2023-08-17 12:27:52 -07:00
Parker Schuh
c6f40e202c Reverts 75c3457264f9cc117ff09551ce3174d72689fa3d
PiperOrigin-RevId: 557628297
2023-08-16 16:06:28 -07:00
Jake VanderPlas
14d52fca55 jnp.einsum: lower to mixed-precision dot_general when possible 2023-08-15 15:57:19 -07:00
Peter Hawkins
78cfdd1b35 Add some more type annotations to lax_numpy.py.
These type annotations are of course mostly ignored because the pytype: skip-file comment, but they help readers if nothing else.

PiperOrigin-RevId: 555955257
2023-08-11 08:07:24 -07:00
Jake VanderPlas
4df58052aa jnp.unpackbits: fix handling of count & add tests 2023-08-10 14:34:11 -07:00
Peter Hawkins
0e80d959c8 Mark jnp.{NINF,NZERO,PZERO} as deprecated.
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357).

PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Mateusz Sokół
1fedf04ed5 API: Remove NINF and PINF usages 2023-08-09 14:16:33 +02:00
jax authors
e21945661f Merge pull request #16972 from mtsokol:update-np-exceptions-imports
PiperOrigin-RevId: 554548376
2023-08-07 11:58:59 -07:00
Mateusz Sokół
d183a2c02f ENH: Update numpy exceptions imports 2023-08-07 19:08:41 +02:00
Jake Hall
85f124c18d Add support for float8_e4m3fnuz and float8_e5m2fnuz. 2023-08-07 11:48:53 +01:00
Jake VanderPlas
bd5a4571d1 Implement jax.numpy.place with required inplace parameter 2023-08-02 14:29:26 -07:00
Jake VanderPlas
5a5730d9fc Fix type annotations for jnp.where 2023-08-02 13:42:20 -07:00
Jake VanderPlas
88c42da7f4 Add implementation of jnp.put 2023-07-26 08:54:54 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
jax authors
1b33a4eb05 Merge pull request #16815 from hawkinsp:py39
PiperOrigin-RevId: 550014612
2023-07-21 12:12:47 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
561c9531ff Lower jax.numpy.dot to mixed-precision dot_general 2023-07-21 10:10:30 -07:00
Jake VanderPlas
2ffa9bd8df Refactor opaque dtype implementation.
This makes it closer to numpy, with dtypes.OpaqueDtype analogous to np.dtype,
and dtypes.opaque analogous to np.numeric. This will let us replace the
dtypes.is_opaque_dtype function with jnp.issubdtype(dtype, dtypes.opaque).
2023-07-20 19:51:52 -07:00
George Necula
4fdc134543 [shape_poly] Add support for max0 for symbolic dimensions.
There are a few cases when JAX computes `max(v, 0)`, most
notably when computing the sizes of strided access,
dilated convolutions and padding, and for the size
of jnp.arange.

Until now these cases were supported
for shape polymorphism only when we can tell statically
that the size is >= 0. Here we add support to the
symbolic expressions for a `non_negative` operator,
which essentially implements `max(v, 0)` and with this
we can now support the general case for `jnp.arange`, with
simpler code.

We could add a general `max` operator, and we may do so in the
future, but for now `non_negative` suffices.

Note that this fixes a couple of bugs

  * for core.dilated_dim we had the code "if d == 0 then 0 else ..."
  but this works only if we can tell statically that `d == 0`, and
  it produced wrong results when `d` was symbolic and could take
  the value 0.
  * for core.stride_dim we did not handle correctly the case when
  `d < window_size`.

Handling the above fundamentally requires a `max(d, 0)` operation.
2023-07-19 16:15:04 +03:00
Jake VanderPlas
74159132b6 support np.array(x) where x is a custom pytree with __jax_array__ 2023-07-17 13:33:17 -07:00
Patrick Kidger
8bce54e5cb Add type annotation to jnp.tensordot
Just stopping pyright from complaining at me.
2023-07-17 11:30:16 -07:00
George Necula
71ac0bb446 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 16:37:53 +03:00
George Necula
58d6c4c1ec Roll back #16689
PiperOrigin-RevId: 547773322
2023-07-13 06:05:50 -07:00
George Necula
d21a667235 [shape_poly] More cleanup for the internal APIs for shape polymorphism.
Previously we had a number of APIs in core.py that operated on dimensions
and shapes and delegated to instances of DimensionHandler. We remove most
of those APIs because by now they ended up doing very little, e.g.,
`core.sum_dim` was the same as `operator.add`, and `core.sum_shape` was
the same as `tuple(map(operator.add))`.

We also remove the whole `DimensionHandler` machinery because by now
the only other use of non-constant dimensions using this mechanism
are the symbolic dimensions used for shape polymorphism, and those
support now full operator overloading. (When we introduced `DimensionHandler`
we had the masking transformation around that needed it also.)
2023-07-13 09:59:41 +03:00
Matthew Johnson
6bdb5821c3 einsum: inf inputs could cause superfluous nan outputs 2023-07-11 17:19:40 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
01a16f5914 Merge pull request #16487 from jakevdp:convolve-dtype
PiperOrigin-RevId: 542929304
2023-06-23 12:32:36 -07:00