4026 Commits

Author SHA1 Message Date
Peter Hawkins
9455254b9f [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
This is a second attempt at this change. In this version, check for and report an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted.

PiperOrigin-RevId: 441214076
2022-04-12 10:30:52 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
Tianjian Lu
a11b41f581 [sparse] Use sorted indices instead of sorted rows only.
PiperOrigin-RevId: 440579642
2022-04-09 08:33:48 -07:00
Peter Hawkins
94307a02c8 Revert: [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
PiperOrigin-RevId: 440452521
2022-04-08 14:22:15 -07:00
Matthew Johnson
272ed95858 remove experimental/djax
PiperOrigin-RevId: 440445082
2022-04-08 13:55:22 -07:00
jax authors
0bfb3efcd7 [JAX] Fix batch logic for approx_min/max_k
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.

PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00
Peter Hawkins
0f15fa3b10 [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
PiperOrigin-RevId: 440433044
2022-04-08 12:57:59 -07:00
jax authors
b8602d018b Merge pull request #10198 from jakevdp:bcoo-duplicates
PiperOrigin-RevId: 440423326
2022-04-08 12:12:12 -07:00
Yash Katariya
654e5bd922 Roll forward again after the fix in the auto sharding pass.
PiperOrigin-RevId: 440412218
2022-04-08 11:25:07 -07:00
Jake VanderPlas
8b9efe79e7 [sparse] fix autodiff bug in spdot_general 2022-04-08 11:04:26 -07:00
Peter Hawkins
648a512488 [MHLO] Add direct MHLO lowerings for sparse primitives.
PiperOrigin-RevId: 440374054
2022-04-08 08:43:57 -07:00
Joan Puigcerver
0c02f7935a Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.
PiperOrigin-RevId: 440300882
2022-04-08 01:08:55 -07:00
Jake VanderPlas
01e4fa8a78 [sparse] consolidate flavors of bcoo_dot_general 2022-04-07 11:28:12 -07:00
jax authors
8b3f039252 Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf
PiperOrigin-RevId: 439997145
2022-04-06 19:55:29 -07:00
jax authors
7ee6adb1a5 Merge pull request #10173 from jakevdp:bcoo-add-batchdim
PiperOrigin-RevId: 439955276
2022-04-06 15:50:10 -07:00
Yash Katariya
6a7a34603d Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
This move is because sharded_jit is being deprecated.

PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Jake VanderPlas
93a24f3b83 [sparse] add bcoo_add_batchdim 2022-04-06 14:44:29 -07:00
Alex Riley
869596fc2c Add jax.scipy.linalg.rsf2csf 2022-04-06 21:06:23 +01:00
Jake VanderPlas
5a96c0cb18 Skip test outside x64 2022-04-04 16:00:18 -07:00
Peter Hawkins
71a5eb263b [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
Fixes https://github.com/google/jax/issues/9946

PiperOrigin-RevId: 439414091
2022-04-04 14:38:52 -07:00
Yash Katariya
6825f654b1 * Disallow any other type other than GDA and ShapedArray for auto sharding.
* Raise errors in the following 4 cases when GDAs sharding does not match the input sharding. **In all the 4 cases below, the check only runs once! There is no double checking going on. I have added tests for these cases. Please check them out.**
  * Auto sharding
    * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. Only doing this check when `compiled._auto_spmd_lowering` is True.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
  * NO auto sharding
    * f_pjitted(gda) -- This is already covered and tested and happens in `infer_params`
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch

PiperOrigin-RevId: 439413895
2022-04-04 14:33:51 -07:00
Peter Hawkins
1b8be90801 Remove the jax_enable_mlir flag. MLIR is now the only supported code path.
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.

PiperOrigin-RevId: 439324450
2022-04-04 08:40:09 -07:00
jax authors
e64a57d2c3 Merge pull request #10121 from hawkinsp:hcbcache
PiperOrigin-RevId: 439036780
2022-04-02 08:57:24 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
1c3edc811d Merge pull request #10110 from pschuh:weakref-bug
PiperOrigin-RevId: 438887762
2022-04-01 12:45:35 -07:00
Parker Schuh
df1c478ec5 Fix race condition for weakref destructor by catching rare exceptions. 2022-04-01 12:04:36 -07:00
Peter Hawkins
208e83ceb7 Avoid retracing when a host_callback.call is called multiple times with the same function.
If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality.

This issue was found in passing while debugging #9970.
2022-04-01 14:41:14 -04:00
jax authors
e766b96063 Merge pull request #10058 from yotarok:istft
PiperOrigin-RevId: 438832534
2022-04-01 08:43:27 -07:00
jax authors
4decbcb00e Merge pull request #10103 from LenaMartens:changelist/438319917
PiperOrigin-RevId: 438821559
2022-04-01 07:40:45 -07:00
Yotaro Kubo
a7fd751acf Add istft to jax.scipy.signal. 2022-04-01 14:28:53 +09:00
Yash Katariya
8ca8f74456 First attempt to enable auto-sharding. This CL adds support for GDA (no SDA support yet).
An example of using auto sharding with GDA:

```
f = pjit(lambda x: x, in_axis_resources=pjit.AUTO, out_axis_resources=pjit.AUTO)

sharding_info = pjit.get_sharding_from_xla(f, mesh, [(8, 2)], [np.int32])

inputs = [GlobalDeviceArray.from_callback(shape, mesh, ip, cb) for ip in sharding_info.in_pspec]

# Use the compiled function (which was compiled in get_sharding_from_xla)
out = sharding_info.compiled(*inputs) # Recommended way!
# OR
out = f(*inputs)
```
PiperOrigin-RevId: 438708483
2022-03-31 18:22:02 -07:00
jax authors
5181692b0e Merge pull request #10102 from gnecula:hcb_fix
PiperOrigin-RevId: 438638041
2022-03-31 12:52:40 -07:00
Lena Martens
15d2ccaeba Checkify: add axis and axis size to OOB error message. 2022-03-31 15:16:35 +01:00
George Necula
84e73598e1 [host_callback] Fix tests to ensure we use the correct platform
In host_callback_test, there are a few tests that inspect compiled HLO.
In some cases, we're explicitly creating a CPU XLA computation, but we're handing
it off the to the default backend. When we're on a TPU machine, we're asking a
TPU backend to compile a CPU XLA computation.

Fixes internal b/227521177.
2022-03-31 15:25:21 +03:00
Jake VanderPlas
34f116c0e0 vmap: preserve weak_type in batching tracer 2022-03-30 11:06:56 -07:00
jax authors
17fc5bd02e Merge pull request #9290 from fehiepsi:named
PiperOrigin-RevId: 438290209
2022-03-30 06:54:10 -07:00
jax authors
f8cddf0eca Merge pull request #9689 from fehiepsi:image
PiperOrigin-RevId: 438282642
2022-03-30 06:05:30 -07:00
jax authors
b31cf89e48 Merge pull request #10072 from jakevdp:fromiter
PiperOrigin-RevId: 438141629
2022-03-29 15:28:36 -07:00
Jake VanderPlas
fbfc3d8edf Better error messages for jnp.fromiter and jnp.fromfile 2022-03-29 14:30:32 -07:00
jax authors
085d3901fd Merge pull request #9369 from froystig:custom-vmap-outputs
PiperOrigin-RevId: 438111265
2022-03-29 13:35:49 -07:00
Roy Frostig
b2de101be7 require consistent output structure in custom vmap rules
... not always a sequence.
2022-03-29 12:28:04 -07:00
Jake VanderPlas
093b7032a8 Implement jnp.from* array creation functions 2022-03-29 10:52:47 -07:00
Tianjian Lu
e5d7f65b6a [sparse] Change call signature of bcoo primitive wrappers.
PiperOrigin-RevId: 437923482
2022-03-28 20:42:52 -07:00
jax authors
5dc068f54b Merge pull request #10042 from yotarok:fix_fft_helper
PiperOrigin-RevId: 437804929
2022-03-28 11:12:15 -07:00
Lena Martens
d72687c990 Checkify: add way to embed runtime info in error messages. 2022-03-28 11:55:34 +01:00
Yotaro Kubo
2e70177385 Fix a bug in fft helper appears when nperseg=1. 2022-03-28 14:51:54 +09:00
Roy Frostig
a6a43e2715 allow for recursive uses of custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-26 12:09:15 -07:00
Matthew Johnson
78cf4df21b improve remat transpose caching (cf. #9661) 2022-03-25 16:33:46 -07:00