217 Commits

Author SHA1 Message Date
George Necula
b62ceba91c [jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimenion needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.

This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,

```
def average(x):
   return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.

Note that this does not change fundamentally the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but never does a shape depend on the
input values. In fact, one could have expressed the `dim_as_value`
already:

```
def dim_as_value(d):
   jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to suppot `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll-up the solution
we need to make `core.dim_as_value` a public API and teach the
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
2021-07-27 09:02:15 +03:00
Peter Hawkins
278ff13b66 Improve implementation of cbrt() in JAX.
Lower to XLA cbrt() operator in sufficiently new jaxlibs.
On TPU, use a Newton-Raphson step to improve the cube root.

Remove support for complex cbrt() in jax.numpy; the existing lowering was wrong and it is not entirely clear to me that we actually want to support complex `jnp.cbrt()`. NumPy itself does not support complex numbers in this case.

Add testing for `sqrt`/`rsqrt` for more types.

[XLA:Python] Add cbrt to XLA:Python bindings.

PiperOrigin-RevId: 386316949
2021-07-22 14:01:28 -07:00
Peter Hawkins
5893b92048 Clarify documentation about array views. 2021-07-19 09:49:37 -04:00
Jake VanderPlas
3902515ef2 Fix rank promotion error in jnp.cov 2021-07-12 13:35:19 -07:00
Jake VanderPlas
17e562df13 Fix rank promotion warning in polyint/polyder 2021-07-12 11:39:25 -07:00
Jake VanderPlas
9ee12207ab Fix rank promotion error for jnp.packbits 2021-07-12 09:24:42 -07:00
Jake VanderPlas
7f8b81b1e6 Fix rank promotion warnings for jnp.in1d 2021-07-09 08:36:48 -07:00
Jake VanderPlas
591c315b0f Fix rank promotion warning in jnp.vander 2021-07-09 06:36:17 -07:00
jax authors
d2861b120e Merge pull request #7114 from charmerDark:s_implementation
PiperOrigin-RevId: 383835003
2021-07-09 06:35:42 -07:00
Jake VanderPlas
f3234b5642 jnp.expand_dims: accept any sequence of dims, not just tuples 2021-07-08 14:45:57 -07:00
jax authors
4e34914484 Merge pull request #7219 from jakevdp:unpack-bits-rank-promotion
PiperOrigin-RevId: 383716713
2021-07-08 14:35:38 -07:00
vishnu
c60115ab9d adding s_ and index_exp 2021-07-09 02:35:28 +05:30
Jake VanderPlas
61be1e57d4 Fix rank promotion warning in DeviceArray.view() 2021-07-08 13:01:51 -07:00
Jake VanderPlas
6d26f6e980 Fix & test rank promotion in jnp.unpackbits 2021-07-08 12:05:47 -07:00
jax authors
78a689bb09 Merge pull request #6933 from lgeiger:multi-slice
PiperOrigin-RevId: 383608401
2021-07-08 04:54:03 -07:00
jax authors
da48776188 Merge pull request #7141 from wdphy16:complex_logaddexp
PiperOrigin-RevId: 383489570
2021-07-07 14:07:42 -07:00
Peter Hawkins
2168483a62 Add x.at[idx].get().
This allows the sorted/unique keyword arguments to be passed to indexed gather operations.
2021-07-07 08:51:09 -04:00
Dian Wu
7b10965794 Implement complex logaddexp 2021-07-03 18:09:58 +02:00
Nicholas Junge
cba5b13a39 Improve concretization errors for jnp indexing routines 2021-06-17 20:26:16 +02:00
jax authors
31e9c65f2a Merge pull request #6952 from jakevdp:hstack-reshape
PiperOrigin-RevId: 378916971
2021-06-11 11:43:38 -07:00
jax authors
8fcdb85f09 Merge pull request #6940 from jakevdp:fix-sinc
PiperOrigin-RevId: 378904804
2021-06-11 10:47:57 -07:00
Jake VanderPlas
17710c0711 add efficient path for array input to jnp.stack, jnp.[hvd]stack, jnp.column_stack 2021-06-11 10:42:06 -07:00
Jake VanderPlas
0470f4f368 jnp.concatenate: add fast path based on lax.reshape for array inputs 2021-06-10 13:25:33 -07:00
Jake VanderPlas
80d8f2d56c jnp.sinc: fix NaNs at x=0 2021-06-10 09:14:07 -07:00
Peter Hawkins
1ff12f05b3 Add unique/sorted annotations to gather().
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.

Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
2021-06-09 21:05:41 -04:00
Lukas Geiger
8e8cbac33d Simplify multi slice 2021-06-10 00:37:14 +01:00
Jake VanderPlas
79d0852145 Add optional size argument to jnp.union1d for JIT compatibility 2021-06-09 11:36:34 -07:00
Jake VanderPlas
f97e2f945f jnp.argwhere: add optional size parameter for JIT compatibility 2021-06-08 16:17:37 -07:00
Jake VanderPlas
022464b91b jnp.where: add optional size argument 2021-06-08 15:53:12 -07:00
Jake VanderPlas
1296dc3f1e jnp.flatnonzero: add optional size argument for JIT compatibility 2021-06-08 13:16:51 -07:00
Jake VanderPlas
d198ad0ac1 jnp.unique: add optional size argument for JIT compatibility 2021-06-08 11:31:42 -07:00
Jake VanderPlas
21dbe30fbb BUG: return JAX arrays rather than NumPy arrays in jnp.unravel_index 2021-06-03 09:15:01 -07:00
George Necula
2ccda70d83 [jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials.
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multi-variate dimension polynomials. These polynomials overload addition, subtraction,
multiplication. They also partially support equality and inequality checking.

Equality and inequality are supported only when the operation result is the
same for all valuations of variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, for
the following a `core.InconclusiveDimensionOperation` is raised: `a = b`, `a
>= 2`.

Division  is supported only in the cases when either there is no remainder,
or the divisor is a constant.

This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:

```
   y = x.reshape((2, -1))
   z = ... y ...
   return z.reshape(x.shape)
```
2021-06-03 10:58:06 +03:00
George Necula
c07d54aab0 [jax2tf] Add shape polymorphism support for jnp.einsum.
The main problem was that jnp.einsum uses opt_einsum.contract_path
to parse the specification string and compute the order or the
contractions. This function wants to compute the sizes of operands
and intermediate results, and will fail if some dimensions are
polymorphic.

The (partial) solution here is to replace the operands with
jax.ShapeDtypeStruct with a fixed size for all dimension variables,
then call opt_einsum.contract_path and use that result if there
is only one contraction. We abort if there are multiple contractions.
This behavior is clearly sound. If there were multiple contractions,
perhaps their order would be different with different dimension sizes.
2021-05-31 19:06:15 +03:00
jax authors
e0f285fd21 Merge pull request #6839 from jakevdp:reshape-doc
PiperOrigin-RevId: 376645137
2021-05-31 02:30:04 -07:00
jax authors
44b1791b7a Copybara import of the project:
--
8226dfc8a4974b4c8031ee267fa5327e778140ee by Nicholas Junge <nicholas.junge@web.de>:

Handle negative values for list-like sections in jnp.split

PiperOrigin-RevId: 376302305
2021-05-27 20:33:49 -07:00
Jake VanderPlas
dded0e38b3 DOC: add notes to jax.numpy docstrings about returning copies rather than views 2021-05-27 18:05:45 -07:00
Nicholas Junge
8226dfc8a4 Handle negative values for list-like sections in jnp.split 2021-05-27 18:25:18 +02:00
Peter Hawkins
c2b0f72d66 Fix handling of empty dimensions in jnp.take(). 2021-05-24 11:59:41 -04:00
Lukas Geiger
b4b02cb7c3 Add jnp.resize 2021-05-21 09:25:49 +01:00
Lukas Geiger
a5a1f62c5c Improve docs for jnp.round and export jnp.round_ 2021-05-21 01:43:56 +01:00
Paul Nguyen
bbcaec4a3a Initial implementation of jax.numpy.poly
This is an initial jax.numpy.poly implementation. It is tested by testPoly in the tests/lax_numpy_test.py test file.
2021-05-20 13:51:14 -05:00
Peter Hawkins
2f7ef94562 Support complex numbers in jnp.convolve and jnp.correlate. 2021-05-13 09:09:46 -04:00
Peter Hawkins
6d2344d5b8 Change jnp scalar types to consider numpy scalars as instances. 2021-05-12 20:31:49 -04:00
Jake VanderPlas
5943594041 Improve error message for shape 2021-05-11 15:44:09 -07:00
Peter Hawkins
d005e38f78 Promote the x.at[idx].set(y) operators as the preferred way to do indexed updates.
Mark the index_update() etc. operators as deprecated in the documentation.

Add new .divide and .power operators. Fixes #2694.
Add .multiply as an alias for .mul. To be more numpy-like we should probably prefer the longer names.
2021-05-10 20:32:00 -04:00
jax authors
3af76ff013 Merge pull request #6704 from jakevdp:at-doc
PiperOrigin-RevId: 372999519
2021-05-10 13:30:10 -07:00
Jake VanderPlas
26f74e64a6 DOC: add documentation of DeviceArray object properties & methods 2021-05-10 12:37:59 -07:00
Peter Hawkins
d96bf0bef6 Allow non-inexact dtypes for jnp.nan...() reductions.
Fixes #2349.
2021-05-10 13:21:12 -04:00
George Necula
244fc7b11c [jax2tf] Expand shape polymorphism support for some instances of lax.gather 2021-05-08 07:38:50 +03:00