14081 Commits

Author SHA1 Message Date
jax authors
9f5a6312c5 Merge pull request #13518 from jakevdp:multiv-normal
PiperOrigin-RevId: 493311895
2022-12-06 08:31:32 -08:00
Peter Hawkins
33a1b8866a Mark arguments to ufuncs as positional-only.
PiperOrigin-RevId: 493311821
2022-12-06 08:24:11 -08:00
jax authors
a7900166d1 Merge pull request #12962 from hawkinsp:rocm
PiperOrigin-RevId: 493308550
2022-12-06 08:09:45 -08:00
George Necula
938b625b2d Remove unreliable call_tf_test
PiperOrigin-RevId: 493225546
2022-12-06 00:31:19 -08:00
Yash Katariya
e7e9687161 Allow pjit's C++ dispatch path to operate on uncommitted array only if it belongs on a single device. This will bring pjit's dispatch performance in line with jit to prepare for jit/pjit frontend merge.
PiperOrigin-RevId: 493164446
2022-12-05 18:09:59 -08:00
John QiangZhang
2fc7bbce49 Cache the exec time tf.get_current_name_scope into _thread_local_state.exec_time_tf_name_scope and add it as prefix during tracing.
Also move the test cases to right code location.

PiperOrigin-RevId: 493163018
2022-12-05 18:02:44 -08:00
Peter Hawkins
516f0d0d0a Support negative axes in all_gather.
Previously we didn't check for these and they caused crashes during MHLO verification.

PiperOrigin-RevId: 493160581
2022-12-05 17:48:50 -08:00
jax authors
189d0c0456 Merge pull request #13525 from jakevdp:typing-extensions
PiperOrigin-RevId: 493158064
2022-12-05 17:35:04 -08:00
jax authors
9400555b66 Merge pull request #13526 from jakevdp:isinstance-array
PiperOrigin-RevId: 493156840
2022-12-05 17:27:55 -08:00
Jake VanderPlas
88bfe90a23 Replace utility function with isinstance check 2022-12-05 16:13:33 -08:00
jax authors
382a248f4b Merge pull request #13512 from jakevdp:fix-optimizers
PiperOrigin-RevId: 493139751
2022-12-05 16:11:59 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
a42e9e2363 [sparse] delete _bcoo_unbatch utility 2022-12-05 14:35:34 -08:00
jax authors
23261d78da Merge pull request #13515 from jakevdp:sparse-typing
PiperOrigin-RevId: 493113475
2022-12-05 14:28:18 -08:00
Peter Hawkins
1f9c988e63 Use _thread_local_state.__dict__.get() instead of getattr(_thread_local_state, ...).
`getattr` turns out to be a tiny bit slower than `__get__()` on `__dict__` in the case that the attribute is absent. `getattr` appears to form an error message that is thrown away if a default is present.

Improves the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  51.4µs ± 1%  48.9µs ± 3%  -4.87%  (p=0.000 n=8+9)

name        old time/op             new time/op             delta
device_put  51.4µs ± 1%             48.9µs ± 3%  -4.87%          (p=0.000 n=8+9)
```

PiperOrigin-RevId: 493108288
2022-12-05 14:09:47 -08:00
Jake VanderPlas
c376836721 [typing] annotate jax.experimental.sparse 2022-12-05 14:05:25 -08:00
jax authors
45707a2729 Merge pull request #13499 from jakevdp:error-docs
PiperOrigin-RevId: 493092719
2022-12-05 13:14:37 -08:00
Yash Katariya
25d1a0b4c6 Add cudnn 86 (for cuda 11.8) so that I can release cuda 11.8 nightlies.
PiperOrigin-RevId: 493086060
2022-12-05 12:50:09 -08:00
Jake VanderPlas
58d6a3b164 random.multivariate_normal: add note about singular covariance 2022-12-05 12:43:05 -08:00
Jake VanderPlas
29942e312b docs: add another example to the ConcretizationTypeError docs 2022-12-05 11:24:54 -08:00
Roy Frostig
431c51a3eb rename iota_32x2_shape to iota_2x32_shape
... for consistency with other internal Threefry primitive names.
2022-12-05 11:09:56 -08:00
jax authors
dad21d3d79 Merge pull request #13506 from simonbutt:bugfix/jax101-pytrees
PiperOrigin-RevId: 493052164
2022-12-05 10:45:21 -08:00
Yash Katariya
b8b6e272d3 Add typehints and point to the correct endpoint of Mesh and PartitionSpec in the args section.
PiperOrigin-RevId: 493035898
2022-12-05 09:49:18 -08:00
Jake VanderPlas
d317cfa37b Revert part of #13498 2022-12-05 09:21:09 -08:00
Roy Frostig
75af6b58d9 add a jax2tf translation rule for the shaped-iota primitive
This allows for jax2tf conversion of the partitionable Threefry RNG.
2022-12-05 09:19:25 -08:00
Roy Frostig
a3483dbe32 docstring for shaped iota primitive 2022-12-05 09:15:27 -08:00
Peter Hawkins
401fbb61a9 Disable xmap_test on TPU under asan due to CI timeouts.
PiperOrigin-RevId: 492994226
2022-12-05 06:52:09 -08:00
Qiao Zhang
55d6daacfa Skip test_lstm on CPU and TPU for jax OSS build.
PiperOrigin-RevId: 492722650
2022-12-03 14:16:07 -08:00
Simon Butt
542b38a1c2 Updated jax.tree_leaves --> jax.tree_util.tree_leaves to remove deprecation notice in jax101-pytrees tutorial
Signed-off-by: Simon Butt <simonbutt123@gmail.com>
2022-12-03 21:22:25 +00:00
Yash Katariya
e814f70547 Raise an error when a numpy input is passed with a non-trivial sharding. This can lead to weird behavior with pjit and XLA since host-local inputs are not allowed with pjit anymore.
PiperOrigin-RevId: 492621424
2022-12-02 20:47:45 -08:00
Hyeontaek Lim
02fab525a7 Add tests to check if pjit handles deleted array inputs gracefully and consistently
pjit dispatch paths should check deleted array inputs when attempting to use
them. These new tests ensure that various pjit dispatch paths detect and handle
them gracefully and consistently.

Add a check to the PyArray argument handling to make the tests pass.

PiperOrigin-RevId: 492605524
2022-12-02 18:41:31 -08:00
jax authors
693047a14b Merge pull request #13498 from jakevdp:x64-other-tests
PiperOrigin-RevId: 492593760
2022-12-02 17:10:04 -08:00
jax authors
f22d4a84cf Merge pull request #13490 from jakevdp:x64-check-grads
PiperOrigin-RevId: 492592519
2022-12-02 17:01:56 -08:00
Hyeontaek Lim
06755ad249 Reduce the buffer size used in ShardedDeviceArrayTest.testThreadsafeIndexing
testThreadsafeIndexing uses a fairly large buffer size. When overlapping many
executions under a constraint host memory for testing using an alternative
backend, this test may hit the maximum allowed memory use.

This change reduces the buffer size by half, which is likely still interesting
and runs more reliably on an alternative backend.

PiperOrigin-RevId: 492588538
2022-12-02 16:38:47 -08:00
Peter Hawkins
c2c3669c15 Remove long-deprecated method .block_host_until_ready().
PiperOrigin-RevId: 492571809
2022-12-02 15:18:11 -08:00
Peter Hawkins
5e102c17d6 Implement .on_device_size_in_bytes() on jax.Array.
This is an array present in DeviceArray that is missing from Array.

PiperOrigin-RevId: 492571171
2022-12-02 15:11:27 -08:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
Jake VanderPlas
9e53de888a [x64] make chack_grads() more type-safe 2022-12-02 12:51:41 -08:00
jax authors
8a28ccd6fd Merge pull request #13491 from jakevdp:x64-stax
PiperOrigin-RevId: 492526309
2022-12-02 12:06:11 -08:00
Peter Hawkins
f9b5312149 Do not mirror JAX config options back to ABSL flags.
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.

However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.

This gives a decent improvement on the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)

name        old time/op             new time/op             delta
device_put  79.5µs ± 6%             69.4µs ± 7%  -12.73%         (p=0.000 n=10+9)
```

PiperOrigin-RevId: 492519085
2022-12-02 11:37:22 -08:00
Jake VanderPlas
8431e43fe5 [x64] more type safety in stax_test.py 2022-12-02 10:02:25 -08:00
jax authors
1027d55b8c Optimize core.find_top_trace
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.

PiperOrigin-RevId: 492481837
2022-12-02 09:00:50 -08:00
Adam Paszke
bbf22db08b Optimize core.find_top_trace
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.

PiperOrigin-RevId: 492460102
2022-12-02 07:04:52 -08:00
jax authors
01377bc9a6 Merge pull request #13485 from jakevdp:x64-random
PiperOrigin-RevId: 492367153
2022-12-01 20:43:18 -08:00
jax authors
5927032664 Merge pull request #13482 from jakevdp:x64-signal
PiperOrigin-RevId: 492367133
2022-12-01 20:36:41 -08:00
jax authors
e1d118c38d Merge pull request #13476 from jakevdp:x64-lax-numpy
PiperOrigin-RevId: 492367125
2022-12-01 20:29:46 -08:00
jax authors
ff6f215ceb Merge pull request #13481 from jakevdp:x64-line-search
PiperOrigin-RevId: 492359730
2022-12-01 19:40:58 -08:00
Yash Katariya
934bc4e1b3 Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoint is jax.sharding.PartitionSpec and jax.sharding.Mesh.
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
jax authors
ed9519dadf Merge pull request #13484 from google:index
PiperOrigin-RevId: 492323132
2022-12-01 16:02:50 -08:00
Jake VanderPlas
fdf5894c75 [x64] make random_test more type-safe 2022-12-01 15:51:37 -08:00