602 Commits

Author SHA1 Message Date
Jake VanderPlas
6dd0e0153a jnp.ndarray.at: deprecate passing additional arguments by position 2023-03-13 10:04:39 -07:00
jax authors
03fc8a4766 Merge pull request #14912 from jakevdp:fail-on-out
PiperOrigin-RevId: 515702712
2023-03-10 12:05:38 -08:00
Jake VanderPlas
d579c3fcbb jnp.argmin/max: correctly handle out argument 2023-03-10 10:35:29 -08:00
Jake VanderPlas
5c2eefff38 Fix jnp.sort & jnp.vdot in no-jit mode 2023-03-10 10:22:34 -08:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
Stephan Hoyer
d4f70c8071 Add "compare_all" method to searchsorted 2023-03-07 16:34:24 -07:00
George Necula
afe4f8ed1a [shape_poly] Add support for shape polymorphism for jnp.{argsort,bincount,insert,nonzero} 2023-03-07 08:29:07 +01:00
jax authors
7d154103e3 Merge pull request #14789 from patrick-kidger:patch-2
PiperOrigin-RevId: 514490253
2023-03-06 12:27:53 -08:00
Patrick Kidger
17afaf67b9 Add _ScalarMeta(dtype=...) field for static type checkers 2023-03-06 10:49:30 -08:00
Peter Hawkins
eb286315fa Fix a TODO updating users of lax._check_user_dtype_supported to dtypes.check_user_dtype_supported.
PiperOrigin-RevId: 514435374
2023-03-06 09:33:01 -08:00
Peter Hawkins
a4412e2715 Remove internal ndarray type name. Use Array throughout.
jax.numpy.ndarray remains an exported alias for jax.Array.

PiperOrigin-RevId: 513046188
2023-02-28 14:51:08 -08:00
Johannes Reifferscheid
3ecff30b4e Don't create invalid bools in lax_numpy_test/testView.
Currently, JAX is generating random 8 bit ints for bools, which usually doesn't cause any issues, but in some special cases does. One example is the HLO snapshot dumping code, which surprisingly creates unparseable protos for such inputs.

PiperOrigin-RevId: 513032802
2023-02-28 14:03:08 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Jake VanderPlas
7f6826659e BUG: raise error when shaped_abstractify is called on JAX scalar types
PiperOrigin-RevId: 512163825
2023-02-24 14:27:57 -08:00
Jake VanderPlas
a283aa0cc3 Deprecate three jax.Array methods:
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
2023-02-23 16:15:09 -08:00
Jake VanderPlas
0913c5a009 jnp.ndarray.view: implement all dtypes
Re-land #14526 with fixes to scalar views
2023-02-17 10:54:37 -08:00
Jake VanderPlas
dc3e5f11af TMP 2023-02-17 09:49:11 -08:00
Jake VanderPlas
52bdbcf2dc BUG: avoid passing functions directly to abstractmethod
abstractmethod mutates its arguments, which causes problems
(see https://github.com/google/jax/discussions/14548)
2023-02-17 09:49:11 -08:00
Jake VanderPlas
e1333f3de0 Roll-back https://github.com/google/jax/pull/14526 because it breaks view() on scalar inputs
PiperOrigin-RevId: 510281592
2023-02-16 17:07:55 -08:00
Jake VanderPlas
b8994f5c3d jnp.ndarray.view: implement all dtypes 2023-02-16 10:07:24 -08:00
jax authors
6f1527f81a Merge pull request #14489 from jakevdp:copy-array
PiperOrigin-RevId: 509960582
2023-02-15 16:16:44 -08:00
Jake VanderPlas
a6d68581b4 DOC: add better documentation for array methods 2023-02-15 09:21:56 -08:00
Jake VanderPlas
58b800db84 jnp.copy: ensure inputs are array-like 2023-02-15 08:29:45 -08:00
jax authors
3bd6ca014c Merge pull request #14469 from gnecula:poly_percentile
PiperOrigin-RevId: 509828200
2023-02-15 07:33:26 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
jax authors
c2b7c5f132 Merge pull request #14474 from jakevdp:doc-array-methods
PiperOrigin-RevId: 509639140
2023-02-14 14:29:13 -08:00
Jake VanderPlas
5958bf0d2f DOC: improve documentation for jax.Array methods 2023-02-14 13:04:27 -08:00
Jake VanderPlas
967f2118bf DOC: improve documentation for jax.Array methods 2023-02-14 13:03:10 -08:00
jax authors
aa98c99d3a Merge pull request #14275 from xoiga123:fix-jax.numpy.hsplit
PiperOrigin-RevId: 509585801
2023-02-14 11:24:55 -08:00
George Necula
28efb6050a [shape_poly] Add support for shape polymorphism to jnp.{percentile,quantile,nanquantile} 2023-02-14 17:40:50 +01:00
jax authors
57900d7ef2 Merge pull request #14364 from jakevdp:fix-tril-indices
PiperOrigin-RevId: 508723970
2023-02-10 12:25:06 -08:00
Ngo Viet Hoai Bao
82e5767f77 update hsplit and testHVDSplit for 1D array 2023-02-10 14:27:37 +07:00
Jake VanderPlas
4fbaee5920 Implement jax.numpy.argpartition 2023-02-08 14:41:39 -08:00
Jake VanderPlas
794557d349 tril_indices/triu_indices: fix call signature & add type annotations 2023-02-08 11:19:06 -08:00
Jake VanderPlas
a76a024548 tril/triu_indices: compute arrays at runtime 2023-02-08 09:52:41 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
George Necula
15be538ebe [shape_poly] Fix the hashing and equality of symbolic dimensions 2023-02-04 08:30:44 +02:00
Peter Hawkins
b730ed4645 Remove placeholder functions for unimplemented NumPy functions.
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Jake VanderPlas
0b5443c6e8 Clean up: remove unused helper functions 2023-02-01 09:55:58 -08:00
Jake VanderPlas
14a0fe08c8 DOC: improve documentation of OOB indices in jnp.take 2023-01-31 15:59:06 -08:00
Jake VanderPlas
217ca5db4b Add implementation of jnp.partition 2023-01-30 13:50:25 -08:00
Qiao Zhang
65ef487a82 Allow jnp.nan_to_num handle integer types like numpy.
See current behavior difference wrt np.nan_to_num
```
>>> np.nan_to_num(np.array(1, dtype=np.int32))
1
>>> jnp.nan_to_num(jnp.array(1, dtype=jnp.int32))
ValueError: data type <class 'numpy.int32'> not inexact
```
PiperOrigin-RevId: 505735212
2023-01-30 10:37:17 -08:00
George Necula
d25bcac93d [shape_poly] Add better support for division, and working with strides
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
2023-01-25 07:37:54 -08:00
George Necula
8e931a82ff [shape_poly] Generalize binary operations with symbolic dimensions.
Previously binary operations involving symbolic dimensions would
work only when the other operand is convertible to a symbolic dimension,
e.g., an integer. This resulted in errors when trying "x.shape[0] * 3.5"
and the recourse was to ask the user to add an explicit
"jnp.array(x.shape[0])".

Now we allow binary operations with any operand and the
"jnp.array" is added automatically if the other operand is not
an integer or a symbolic dimension. This means that instead
of an error they may be an error downstream if one tries to use
the result as a dimension. There is one known case where
JAX works with static shapes and with the previous behavior,
but will fail now. When you operate on `np.ndarray` and
symbolic dimension, previously this was kept as a `np.ndarray`
but not it is turned into a JAX array. The following
program will now fail if `x.shape[0]` is a symbolic dimension.:

`jnp.ones(np.arange(5) * x.shape[0])`

Instead you should write

`jnp.ones([i * x.shape[0] for i in range(5)])`
2023-01-21 04:26:59 -08:00
Qiao Zhang
650e1ef4c3 Expose fp8 types from jnp namespace.
PiperOrigin-RevId: 503353939
2023-01-19 22:23:33 -08:00
George Necula
1b04fcb4be [jax2tf] Improve handling of lax.pad and jnp.pad with polymorphic padding config
PiperOrigin-RevId: 498350702
2022-12-29 03:00:32 -08:00
George Necula
a51b174460 [jax2tf] Improved error checking for jnp.squeeze with shape polymorphism 2022-12-20 17:58:32 +02:00
George Necula
86a70ab811 [jax2tf] Fix for jnp.roll with shape polymorphism
There was a partial fix before, in #13470, but it was incomplete
and the x64 mode was not handled properly.

There are no tests added here; this was discovered by running the
tests with --jax2tf_default_experimental_native_lowering, which
will become default soon.
2022-12-09 08:08:28 +02:00
Jake VanderPlas
09d1b6d8d5 Deprecate jnp.msort following deprecation of numpy.msort 2022-12-07 10:08:18 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00