300 Commits

Author SHA1 Message Date
Yash Katariya
13c34f9dc5 Move with_sharding_constraint out of experimental into jax.lax namespace.
PiperOrigin-RevId: 494635809
2022-12-11 22:55:21 -08:00
Matthew Johnson
1185c895ca in jax.Array notebook, polish beginning and tweak title and some wording 2022-12-10 22:16:54 -08:00
Jake VanderPlas
09d1b6d8d5 Deprecate jnp.msort following deprecation of numpy.msort 2022-12-07 10:08:18 -08:00
Peter Hawkins
33a1b8866a Mark arguments to ufuncs as positional-only.
PiperOrigin-RevId: 493311821
2022-12-06 08:24:11 -08:00
Peter Hawkins
c2c3669c15 Remove long-deprecated method .block_host_until_ready().
PiperOrigin-RevId: 492571809
2022-12-02 15:18:11 -08:00
Peter Hawkins
f9b5312149 Do not mirror JAX config options back to ABSL flags.
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.

However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.

This gives a decent improvement on the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)

name        old time/op             new time/op             delta
device_put  79.5µs ± 6%             69.4µs ± 7%  -12.73%         (p=0.000 n=10+9)
```

PiperOrigin-RevId: 492519085
2022-12-02 11:37:22 -08:00
Yash Katariya
934bc4e1b3 Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoint is jax.sharding.PartitionSpec and jax.sharding.Mesh.
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
Skye Wanderman-Milne
82b442fa52 [docs] Replace one more jax_Array.html reference
I missed this in #13479, thanks @yashk2810 for flagging!
2022-12-01 20:50:55 +00:00
George Necula
4ca05f428f [call_tf] Use the same platform for TF lowering as the embedding JAX computation
This requires some changes for abstract evaluation, when
JAX does not use a specific platform.

Also attempt to fix the case when the TF lowering fails because the TF computation
uses a tf.Variable on another device as that used for lowering.

PiperOrigin-RevId: 492112847
2022-11-30 23:22:24 -08:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
TJ
5fb0215d4d updated jaxlib CHANGELOG 2022-11-28 10:37:42 -08:00
Yash Katariya
9799d5b139 Add the jax.Array change to the changelog.
PiperOrigin-RevId: 488929264
2022-11-16 06:56:09 -08:00
Yash Katariya
8c42edfec1 Finish jax and jaxlib release 0.3.25. The next release will be 0.4.0 (since jax.Array will be enabled in that release)
PiperOrigin-RevId: 488672395
2022-11-15 09:02:53 -08:00
Peter Hawkins
ebd9840e1f Add several recent changes to the CHANGELOG.
PiperOrigin-RevId: 488362198
2022-11-14 07:39:13 -08:00
Sharad Vikram
e15619ceab Convert string axis name into tuple of strings in Mesh constructor
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
Sharad Vikram
4bdfdd7363 Update changelog w/ info about deleting jax_experimental_name_stack 2022-11-11 14:02:30 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
ab8cde9ed4 Add support for the hermitian option on jnp.linalg.pinv.
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
Yash Katariya
1d48c93b0e Finish the release of jax and jaxlib 0.3.24
PiperOrigin-RevId: 486162090
2022-11-04 09:43:12 -07:00
Yash Katariya
cc5af7ed98 Rename ReshapeableDevicesSharding to PositionalSharding and add an alias NamedSharding for MeshPspecSharding.
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.

PiperOrigin-RevId: 485753078
2022-11-02 19:13:13 -07:00
jax authors
ef63f75e39 Merge pull request #13039 from skye:cache_compile_time_heuristic
PiperOrigin-RevId: 485644419
2022-11-02 11:13:52 -07:00
Skye Wanderman-Milne
cc5171034f Add new config jax_persistent_cache_min_compile_time_secs.
This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798, since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).

I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) numbers from

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new float config functionality.
2022-11-02 00:56:19 +00:00
Jake VanderPlas
2416d15435 Call _check_arraylike for jnp.linalg & jnp.fft functions 2022-10-31 09:19:53 -07:00
Peter Hawkins
bf21391248 [JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.
PiperOrigin-RevId: 484062717
2022-10-26 13:56:07 -07:00
Peter Hawkins
ce9e009c4c [JAX:CPU] Enable buffer donation on CPU.
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.

PiperOrigin-RevId: 484001545
2022-10-26 10:13:01 -07:00
Jake VanderPlas
2009e65a33 jnp.gradient: call check_arraylike on inputs & clean-up implementation 2022-10-24 15:27:33 -07:00
Jake VanderPlas
4aceb81570 Add docs & changelog for jax.scipy.stats.mode 2022-10-20 15:55:57 -07:00
Skye Wanderman-Milne
81eb3fca55 Add new config jax_persistent_cache_min_instruction_count.
This can be used to limit the number of entries written to the
persistent compilation cache.

I defaulted to setting 6 as the minimum threshold based on running the
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) and logging
the instruction counts and complilation time:

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new int config functionality.

Fixes #12583
2022-10-20 00:17:24 +00:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Skye Wanderman-Milne
1511439c01 Update version.py and CHANGELOG for jax 0.3.23 release 2022-10-12 11:09:15 -07:00
Skye Wanderman-Milne
e2aa939147 Update Colab TPU driver version 2022-10-11 17:49:10 -07:00
Skye Wanderman-Milne
a6e0e77624 Update version.py, setup.py and CHANGELOG post jax 0.3.22 release 2022-10-11 21:37:53 +00:00
jax authors
af98d8b00f Merge pull request #12746 from jakevdp:fix-changelog
PiperOrigin-RevId: 480375907
2022-10-11 09:21:28 -07:00
Jake VanderPlas
05faf0f40d Remove deprecated functionality from jax.test_util
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Jake VanderPlas
943129419c changelog: add missing github commit links 2022-10-11 06:24:38 -07:00
Ran Ran
f3ded0fc1e Address comments for change log 2022-10-05 18:16:49 +00:00
Ran Ran
4870fd3be8 Update message and change log 2022-10-05 04:39:04 +00:00
Skye Wanderman-Milne
2a3460ff8a Update version and CHANGELOG for jax 0.3.21 release 2022-10-03 22:06:19 +00:00
Skye Wanderman-Milne
15e5f38a16 Make persistent compilation cache warn instead of raise an error on cache read/write failures
Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.

Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
2022-09-30 18:38:22 +00:00
Peter Hawkins
b49e31a012 Update version numbers after release. 2022-09-28 18:49:22 +00:00
Peter Hawkins
8d8643664c jax/jaxlib 0.3.20 release candidate. 2022-09-28 13:33:52 +00:00
Skye Wanderman-Milne
d028d93983 Update version and changelog for jax 0.3.19 release 2022-09-27 11:00:27 -07:00
Skye Wanderman-Milne
3c0d280bc0 Update version and changelog for jax 0.3.18 release 2022-09-26 12:43:39 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
Peter Hawkins
bcd36d8eb2 Jax and jaxlib 0.3.18 release candidate. 2022-09-26 14:10:57 +00:00
Yash Katariya
da90234cae Delete soft_pmap as it has no users. Please use pjit or xmap if you do want soft_pmap.
`jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.

PiperOrigin-RevId: 474145090
2022-09-13 15:52:10 -07:00
Peter Hawkins
40c80d7d0a Remove jax._src from JAX namespace.
This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX.

PiperOrigin-RevId: 473243243
2022-09-09 07:06:00 -07:00
Roy Frostig
a2ad414e7c mention AOT readiness in changelog 2022-09-02 13:02:25 -07:00
Jake VanderPlas
ee4ea27c3e update version and changelog for pypi 2022-08-31 11:37:09 -07:00
Peter Hawkins
5527966b27 [JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.

PiperOrigin-RevId: 469984029
2022-08-25 07:28:27 -07:00