560 Commits

Author SHA1 Message Date
George Necula
d25bcac93d [shape_poly] Add better support for division, and working with strides
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
2023-01-25 07:37:54 -08:00
George Necula
8e931a82ff [shape_poly] Generalize binary operations with symbolic dimensions.
Previously binary operations involving symbolic dimensions would
work only when the other operand is convertible to a symbolic dimension,
e.g., an integer. This resulted in errors when trying "x.shape[0] * 3.5"
and the recourse was to ask the user to add an explicit
"jnp.array(x.shape[0])".

Now we allow binary operations with any operand and the
"jnp.array" is added automatically if the other operand is not
an integer or a symbolic dimension. This means that instead
of an error they may be an error downstream if one tries to use
the result as a dimension. There is one known case where
JAX works with static shapes and with the previous behavior,
but will fail now. When you operate on `np.ndarray` and
symbolic dimension, previously this was kept as a `np.ndarray`
but not it is turned into a JAX array. The following
program will now fail if `x.shape[0]` is a symbolic dimension.:

`jnp.ones(np.arange(5) * x.shape[0])`

Instead you should write

`jnp.ones([i * x.shape[0] for i in range(5)])`
2023-01-21 04:26:59 -08:00
Qiao Zhang
650e1ef4c3 Expose fp8 types from jnp namespace.
PiperOrigin-RevId: 503353939
2023-01-19 22:23:33 -08:00
George Necula
1b04fcb4be [jax2tf] Improve handling of lax.pad and jnp.pad with polymorphic padding config
PiperOrigin-RevId: 498350702
2022-12-29 03:00:32 -08:00
George Necula
a51b174460 [jax2tf] Improved error checking for jnp.squeeze with shape polymorphism 2022-12-20 17:58:32 +02:00
George Necula
86a70ab811 [jax2tf] Fix for jnp.roll with shape polymorphism
There was a partial fix before, in #13470, but it was incomplete
and the x64 mode was not handled properly.

There are no tests added here; this was discovered by running the
tests with --jax2tf_default_experimental_native_lowering, which
will become default soon.
2022-12-09 08:08:28 +02:00
Jake VanderPlas
09d1b6d8d5 Deprecate jnp.msort following deprecation of numpy.msort 2022-12-07 10:08:18 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
George Necula
fcaf7f1169 [jax2tf] Fix the handling of jnp.roll for polymorphic shapes 2022-12-01 11:18:47 +01:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
jax authors
20f092c916 As explained in the JAX FAQ, jax.numpy.where has suprising behavior when
its gradient is taken and one of the inputs is NaN.  This CL adds
a short description of the behavior to the jax.numpy.where docs,
which is the logical place that users would look for it.

PiperOrigin-RevId: 488654036
2022-11-15 07:42:35 -08:00
Ishtiaq Hussain
09f62dec3c Moved abs to inputs of lcm and added specific test 2022-11-11 22:31:06 +00:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Jake VanderPlas
1627bc6e83 generate dynamic_slice rather than slice for simple indexing/slicing 2022-11-03 11:40:51 -07:00
Jake VanderPlas
8bde3a0a70 Point to ndarray.at from docstring of unimplemented jnp.put & jnp.place 2022-10-28 14:13:36 -07:00
Jake VanderPlas
15b489fb6f [typing] annotate next section of lax_numpy.py 2022-10-28 10:14:00 -07:00
Jake VanderPlas
0a79fa4f1c jax.numpy: implement window functions in terms of lax ops
Including blackman, bartlett, hamming, hanning, kaiser.

Why? Previously these were implemented by computing the output on host at trace-time and embedding the result as a large constant array. Computing the results via lax operations is more in the spirit of jax.numpy.
2022-10-27 15:47:04 -07:00
Jake VanderPlas
51242bcc26 jax.numpy: implement window functions in terms of lax ops
Including blackman, bartlett, hamming, hanning, kaiser.

Why? Previously these were implemented by embedding large constants; this should be more performant.
2022-10-27 15:08:16 -07:00
jax authors
994e0ac1bb Merge pull request #12979 from jakevdp:annotate-lax-numpy-3
PiperOrigin-RevId: 484271670
2022-10-27 09:20:28 -07:00
jax authors
d2df0faf41 Merge pull request #12996 from mattjj:tweak-jnp-canonicalize-shape
PiperOrigin-RevId: 484100902
2022-10-26 16:24:37 -07:00
Matthew Johnson
7e341817b4 [dynamic-shapes] tweak jnp.canonicalize_shape logic
The idea with jnp.canonicalize_shape is that it handles non-tuple shapes, i.e.
intended to be scalar-like arguments like Python builtin ints or numpy scalar
types or 0D arrays. To do that, it checks numpy.ndim(shape) == 0. But
numpy.ndim might attempt to convert its argument to a numpy.ndarray, which
breaks when the argument is a tuple with Tracers inside!

Instead, let's just check if the argument is one of the canonical sequence
types (list or tuple) and if so then not even call numpy.ndim.
2022-10-26 12:01:49 -07:00
Jake VanderPlas
9c0f876bcc [typing] annotate jnp.pad 2022-10-26 11:09:52 -07:00
Jake VanderPlas
a5bccc8bf9 [typing] annotate next chunk of lax_numpy.py 2022-10-25 14:03:43 -07:00
Jake VanderPlas
2f27d516d7 [typing] annotate next part of lax_numpy.py 2022-10-25 12:36:26 -07:00
Jake VanderPlas
2009e65a33 jnp.gradient: call check_arraylike on inputs & clean-up implementation 2022-10-24 15:27:33 -07:00
Jake VanderPlas
56d42c0edf [typing] annotate next batch of lax_numpy 2022-10-24 14:21:35 -07:00
Jake VanderPlas
97b17af5be [typing] add type annotations to the first several lax_numpy functions 2022-10-21 11:59:53 -07:00
Jake VanderPlas
6d308653e4 [typing] annotate jax.numpy ufuncs 2022-10-20 11:22:04 -07:00
Jake VanderPlas
eb2046800c [typing] annotate jax.numpy array creation routines 2022-10-18 13:46:07 -07:00
Matthew Johnson
df5f7cb8d3 Rolling forward https://github.com/google/jax/pull/12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
PiperOrigin-RevId: 480229237
2022-10-10 18:51:37 -07:00
jax authors
9cabd227d7 Copybara import of the project:
--
6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>:

implement bint arrays (opaque dtypes), add padding rules

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
PiperOrigin-RevId: 479883102
2022-10-09 01:25:50 -07:00
Matthew Johnson
6d2aaac245 implement bint arrays (opaque dtypes), add padding rules
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-10-08 22:57:29 -07:00
jax authors
58cd8376ee Merge pull request #12675 from mattjj:device-put2
PiperOrigin-RevId: 479660808
2022-10-07 13:49:57 -07:00
George Necula
fb2141fc3b [jax2tf] Allow the use of DimPolynomial with jnp.array and binary operations
Prior to this the user had to explicitly call core.dimension_as_value whenever
using a potentially polymorphic shape in the computation, e.g., x +
core.dimension_as_value(x.shape[0]). Furthermore, jnp.array(x.shape[0])
would fail.

Now, these operations are allowed implicitly,
and the user can call `jnp.array(x.shape[0])`.

This uses an internal extensibility mechanism called __jax_array__
that is experimental and probably not fully implemented.
2022-10-07 17:58:41 +03:00
George Necula
7c7c94c8dd Expand support for __jax_array__ in jnp.array.
This relates to the long discussion in #4725 and #10065.
2022-10-07 14:25:07 +03:00
Matthew Johnson
ce95ebad94 make device_put work with Sharding 2nd arg
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-10-06 16:14:15 -07:00
Matthew Johnson
06a2c85d52 [dynamic-shapes] small fix to einsum (and indexing) 2022-10-04 12:29:41 -07:00
Jake VanderPlas
069866e07a Add types to jax/_src/numpy/util.py 2022-10-04 10:07:38 -07:00
Jake VanderPlas
fd45035b21 [typing] add full annotations for lax_numpy setops 2022-10-03 12:52:28 -07:00
Jake VanderPlas
d49c5c37ea jnp.take: add optional arguments forwarded to lax.gather 2022-09-29 09:33:38 -07:00
Matthew Johnson
a8826e672b [dynamic-shapes] Add basic slicing support
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-09-28 15:55:51 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
jax authors
ec15e83018 - Wraps calls to lax.xeinsum and _einsum in a named call with their 'spec', the string specifying the computation. Makes xprof traces more interpretable.
PiperOrigin-RevId: 476796185
2022-09-25 20:54:17 -07:00
Ke Wu
c823151771 Allow transpose axes to be negative to match (undocumented) NumPy behavior 2022-09-23 10:18:23 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
49a6034fa3 [dynamic-shapes] enable basic einsum support, following jax2tf shape polys 2022-09-13 16:06:33 -07:00
Yash Katariya
d7726e7b26 Make __getitem__ work for PmapSharding just like SDA works. DA is already covered with the current implementation.
Added TODOs to take fast path for indices wherever it is possible to do that. If a correct index is passed during getitem and if that index exists on `Array`, then the fast path is taken (see the test in this CL).

PiperOrigin-RevId: 473342504
2022-09-09 14:25:22 -07:00