Yash Katariya
13c34f9dc5
Move with_sharding_constraint
out of experimental into jax.lax
namespace.
...
PiperOrigin-RevId: 494635809
2022-12-11 22:55:21 -08:00
Matthew Johnson
1185c895ca
in jax.Array notebook, polish beginning and tweak title and some wording
2022-12-10 22:16:54 -08:00
Jake VanderPlas
09d1b6d8d5
Deprecate jnp.msort following deprecation of numpy.msort
2022-12-07 10:08:18 -08:00
Peter Hawkins
33a1b8866a
Mark arguments to ufuncs as positional-only.
...
PiperOrigin-RevId: 493311821
2022-12-06 08:24:11 -08:00
Peter Hawkins
c2c3669c15
Remove long-deprecated method .block_host_until_ready().
...
PiperOrigin-RevId: 492571809
2022-12-02 15:18:11 -08:00
Peter Hawkins
f9b5312149
Do not mirror JAX config options back to ABSL flags.
...
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.
However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.
This gives a decent improvement on the device_put benchmark:
```
name old cpu/op new cpu/op delta
device_put 79.5µs ± 6% 69.4µs ± 7% -12.73% (p=0.000 n=10+9)
name old time/op new time/op delta
device_put 79.5µs ± 6% 69.4µs ± 7% -12.73% (p=0.000 n=10+9)
```
PiperOrigin-RevId: 492519085
2022-12-02 11:37:22 -08:00
Yash Katariya
934bc4e1b3
Move PartitionSpec
and Mesh
out of experimental and into the sharding
namespace. The new API endpoint is jax.sharding.PartitionSpec
and jax.sharding.Mesh
.
...
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
Skye Wanderman-Milne
82b442fa52
[docs] Replace one more jax_Array.html reference
...
I missed this in #13479 , thanks @yashk2810 for flagging!
2022-12-01 20:50:55 +00:00
George Necula
4ca05f428f
[call_tf] Use the same platform for TF lowering as the embedding JAX computation
...
This requires some changes for abstract evaluation, when
JAX does not use a specific platform.
Also attempt to fix the case when the TF lowering fails because the TF computation
uses a tf.Variable on another device as that used for lowering.
PiperOrigin-RevId: 492112847
2022-11-30 23:22:24 -08:00
Jake VanderPlas
cb62a31653
Drop support for Python 3.7
2022-11-29 15:01:47 -08:00
TJ
5fb0215d4d
updated jaxlib CHANGELOG
2022-11-28 10:37:42 -08:00
Yash Katariya
9799d5b139
Add the jax.Array
change to the changelog.
...
PiperOrigin-RevId: 488929264
2022-11-16 06:56:09 -08:00
Yash Katariya
8c42edfec1
Finish jax and jaxlib release 0.3.25. The next release will be 0.4.0 (since jax.Array will be enabled in that release)
...
PiperOrigin-RevId: 488672395
2022-11-15 09:02:53 -08:00
Peter Hawkins
ebd9840e1f
Add several recent changes to the CHANGELOG.
...
PiperOrigin-RevId: 488362198
2022-11-14 07:39:13 -08:00
Sharad Vikram
e15619ceab
Convert string axis name into tuple of strings in Mesh constructor
...
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
Sharad Vikram
4bdfdd7363
Update changelog w/ info about deleting jax_experimental_name_stack
2022-11-11 14:02:30 -08:00
Peter Hawkins
1cead779a3
Add support for Hessenberg and tridiagonal matrix reductions on CPU.
...
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
ab8cde9ed4
Add support for the hermitian option on jnp.linalg.pinv.
...
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
Yash Katariya
1d48c93b0e
Finish the release of jax and jaxlib 0.3.24
...
PiperOrigin-RevId: 486162090
2022-11-04 09:43:12 -07:00
Yash Katariya
cc5af7ed98
Rename ReshapeableDevicesSharding
to PositionalSharding
and add an alias NamedSharding
for MeshPspecSharding
.
...
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.
PiperOrigin-RevId: 485753078
2022-11-02 19:13:13 -07:00
jax authors
ef63f75e39
Merge pull request #13039 from skye:cache_compile_time_heuristic
...
PiperOrigin-RevId: 485644419
2022-11-02 11:13:52 -07:00
Skye Wanderman-Milne
cc5171034f
Add new config jax_persistent_cache_min_compile_time_secs
.
...
This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798 , since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).
I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt ) numbers from
name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344
Also adds new float config functionality.
2022-11-02 00:56:19 +00:00
Jake VanderPlas
2416d15435
Call _check_arraylike for jnp.linalg & jnp.fft functions
2022-10-31 09:19:53 -07:00
Peter Hawkins
bf21391248
[JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.
...
PiperOrigin-RevId: 484062717
2022-10-26 13:56:07 -07:00
Peter Hawkins
ce9e009c4c
[JAX:CPU] Enable buffer donation on CPU.
...
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.
PiperOrigin-RevId: 484001545
2022-10-26 10:13:01 -07:00
Jake VanderPlas
2009e65a33
jnp.gradient: call check_arraylike on inputs & clean-up implementation
2022-10-24 15:27:33 -07:00
Jake VanderPlas
4aceb81570
Add docs & changelog for jax.scipy.stats.mode
2022-10-20 15:55:57 -07:00
Skye Wanderman-Milne
81eb3fca55
Add new config jax_persistent_cache_min_instruction_count
.
...
This can be used to limit the number of entries written to the
persistent compilation cache.
I defaulted to setting 6 as the minimum threshold based on running the
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt ) and logging
the instruction counts and complilation time:
name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344
Also adds new int config functionality.
Fixes #12583
2022-10-20 00:17:24 +00:00
Peter Hawkins
9ab88071a7
Avoid loading scipy eagerly.
...
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Skye Wanderman-Milne
1511439c01
Update version.py and CHANGELOG for jax 0.3.23 release
2022-10-12 11:09:15 -07:00
Skye Wanderman-Milne
e2aa939147
Update Colab TPU driver version
2022-10-11 17:49:10 -07:00
Skye Wanderman-Milne
a6e0e77624
Update version.py, setup.py and CHANGELOG post jax 0.3.22 release
2022-10-11 21:37:53 +00:00
jax authors
af98d8b00f
Merge pull request #12746 from jakevdp:fix-changelog
...
PiperOrigin-RevId: 480375907
2022-10-11 09:21:28 -07:00
Jake VanderPlas
05faf0f40d
Remove deprecated functionality from jax.test_util
...
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Jake VanderPlas
943129419c
changelog: add missing github commit links
2022-10-11 06:24:38 -07:00
Ran Ran
f3ded0fc1e
Address comments for change log
2022-10-05 18:16:49 +00:00
Ran Ran
4870fd3be8
Update message and change log
2022-10-05 04:39:04 +00:00
Skye Wanderman-Milne
2a3460ff8a
Update version and CHANGELOG for jax 0.3.21 release
2022-10-03 22:06:19 +00:00
Skye Wanderman-Milne
15e5f38a16
Make persistent compilation cache warn instead of raise an error on cache read/write failures
...
Fixes #12582 . Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.
Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
2022-09-30 18:38:22 +00:00
Peter Hawkins
b49e31a012
Update version numbers after release.
2022-09-28 18:49:22 +00:00
Peter Hawkins
8d8643664c
jax/jaxlib 0.3.20 release candidate.
2022-09-28 13:33:52 +00:00
Skye Wanderman-Milne
d028d93983
Update version and changelog for jax 0.3.19 release
2022-09-27 11:00:27 -07:00
Skye Wanderman-Milne
3c0d280bc0
Update version and changelog for jax 0.3.18 release
2022-09-26 12:43:39 -07:00
Jake VanderPlas
0cb233eec9
Add initial jax.Array base class for instance checks & annotation
2022-09-26 07:48:43 -07:00
Peter Hawkins
bcd36d8eb2
Jax and jaxlib 0.3.18 release candidate.
2022-09-26 14:10:57 +00:00
Yash Katariya
da90234cae
Delete soft_pmap as it has no users. Please use pjit
or xmap
if you do want soft_pmap.
...
`jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.
PiperOrigin-RevId: 474145090
2022-09-13 15:52:10 -07:00
Peter Hawkins
40c80d7d0a
Remove jax._src from JAX namespace.
...
This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX.
PiperOrigin-RevId: 473243243
2022-09-09 07:06:00 -07:00
Roy Frostig
a2ad414e7c
mention AOT readiness in changelog
2022-09-02 13:02:25 -07:00
Jake VanderPlas
ee4ea27c3e
update version and changelog for pypi
2022-08-31 11:37:09 -07:00
Peter Hawkins
5527966b27
[JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
...
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.
PiperOrigin-RevId: 469984029
2022-08-25 07:28:27 -07:00