Yash Katariya
934bc4e1b3
Move PartitionSpec
and Mesh
out of experimental and into the sharding
namespace. The new API endpoint is jax.sharding.PartitionSpec
and jax.sharding.Mesh
.
...
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
Skye Wanderman-Milne
82b442fa52
[docs] Replace one more jax_Array.html reference
...
I missed this in #13479 , thanks @yashk2810 for flagging!
2022-12-01 20:50:55 +00:00
George Necula
4ca05f428f
[call_tf] Use the same platform for TF lowering as the embedding JAX computation
...
This requires some changes for abstract evaluation, when
JAX does not use a specific platform.
Also attempt to fix the case when the TF lowering fails because the TF computation
uses a tf.Variable on another device as that used for lowering.
PiperOrigin-RevId: 492112847
2022-11-30 23:22:24 -08:00
Jake VanderPlas
cb62a31653
Drop support for Python 3.7
2022-11-29 15:01:47 -08:00
TJ
5fb0215d4d
updated jaxlib CHANGELOG
2022-11-28 10:37:42 -08:00
Yash Katariya
9799d5b139
Add the jax.Array
change to the changelog.
...
PiperOrigin-RevId: 488929264
2022-11-16 06:56:09 -08:00
Yash Katariya
8c42edfec1
Finish jax and jaxlib release 0.3.25. The next release will be 0.4.0 (since jax.Array will be enabled in that release)
...
PiperOrigin-RevId: 488672395
2022-11-15 09:02:53 -08:00
Peter Hawkins
ebd9840e1f
Add several recent changes to the CHANGELOG.
...
PiperOrigin-RevId: 488362198
2022-11-14 07:39:13 -08:00
Sharad Vikram
e15619ceab
Convert string axis name into tuple of strings in Mesh constructor
...
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
Sharad Vikram
4bdfdd7363
Update changelog w/ info about deleting jax_experimental_name_stack
2022-11-11 14:02:30 -08:00
Peter Hawkins
1cead779a3
Add support for Hessenberg and tridiagonal matrix reductions on CPU.
...
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.
None of these primitives are differentiable at the moment.
PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
ab8cde9ed4
Add support for the hermitian option on jnp.linalg.pinv.
...
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
Yash Katariya
1d48c93b0e
Finish the release of jax and jaxlib 0.3.24
...
PiperOrigin-RevId: 486162090
2022-11-04 09:43:12 -07:00
Yash Katariya
cc5af7ed98
Rename ReshapeableDevicesSharding
to PositionalSharding
and add an alias NamedSharding
for MeshPspecSharding
.
...
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.
PiperOrigin-RevId: 485753078
2022-11-02 19:13:13 -07:00
jax authors
ef63f75e39
Merge pull request #13039 from skye:cache_compile_time_heuristic
...
PiperOrigin-RevId: 485644419
2022-11-02 11:13:52 -07:00
Skye Wanderman-Milne
cc5171034f
Add new config jax_persistent_cache_min_compile_time_secs
.
...
This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798 , since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).
I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt ) numbers from
name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344
Also adds new float config functionality.
2022-11-02 00:56:19 +00:00
Jake VanderPlas
2416d15435
Call _check_arraylike for jnp.linalg & jnp.fft functions
2022-10-31 09:19:53 -07:00
Peter Hawkins
bf21391248
[JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.
...
PiperOrigin-RevId: 484062717
2022-10-26 13:56:07 -07:00
Peter Hawkins
ce9e009c4c
[JAX:CPU] Enable buffer donation on CPU.
...
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.
PiperOrigin-RevId: 484001545
2022-10-26 10:13:01 -07:00
Jake VanderPlas
2009e65a33
jnp.gradient: call check_arraylike on inputs & clean-up implementation
2022-10-24 15:27:33 -07:00
Jake VanderPlas
4aceb81570
Add docs & changelog for jax.scipy.stats.mode
2022-10-20 15:55:57 -07:00
Skye Wanderman-Milne
81eb3fca55
Add new config jax_persistent_cache_min_instruction_count
.
...
This can be used to limit the number of entries written to the
persistent compilation cache.
I defaulted to setting 6 as the minimum threshold based on running the
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt ) and logging
the instruction counts and complilation time:
name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344
Also adds new int config functionality.
Fixes #12583
2022-10-20 00:17:24 +00:00
Peter Hawkins
9ab88071a7
Avoid loading scipy eagerly.
...
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Skye Wanderman-Milne
1511439c01
Update version.py and CHANGELOG for jax 0.3.23 release
2022-10-12 11:09:15 -07:00
Skye Wanderman-Milne
e2aa939147
Update Colab TPU driver version
2022-10-11 17:49:10 -07:00
Skye Wanderman-Milne
a6e0e77624
Update version.py, setup.py and CHANGELOG post jax 0.3.22 release
2022-10-11 21:37:53 +00:00
jax authors
af98d8b00f
Merge pull request #12746 from jakevdp:fix-changelog
...
PiperOrigin-RevId: 480375907
2022-10-11 09:21:28 -07:00
Jake VanderPlas
05faf0f40d
Remove deprecated functionality from jax.test_util
...
PiperOrigin-RevId: 480360504
2022-10-11 08:16:34 -07:00
Jake VanderPlas
943129419c
changelog: add missing github commit links
2022-10-11 06:24:38 -07:00
Ran Ran
f3ded0fc1e
Address comments for change log
2022-10-05 18:16:49 +00:00
Ran Ran
4870fd3be8
Update message and change log
2022-10-05 04:39:04 +00:00
Skye Wanderman-Milne
2a3460ff8a
Update version and CHANGELOG for jax 0.3.21 release
2022-10-03 22:06:19 +00:00
Skye Wanderman-Milne
15e5f38a16
Make persistent compilation cache warn instead of raise an error on cache read/write failures
...
Fixes #12582 . Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning.
Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
2022-09-30 18:38:22 +00:00
Peter Hawkins
b49e31a012
Update version numbers after release.
2022-09-28 18:49:22 +00:00
Peter Hawkins
8d8643664c
jax/jaxlib 0.3.20 release candidate.
2022-09-28 13:33:52 +00:00
Skye Wanderman-Milne
d028d93983
Update version and changelog for jax 0.3.19 release
2022-09-27 11:00:27 -07:00
Skye Wanderman-Milne
3c0d280bc0
Update version and changelog for jax 0.3.18 release
2022-09-26 12:43:39 -07:00
Jake VanderPlas
0cb233eec9
Add initial jax.Array base class for instance checks & annotation
2022-09-26 07:48:43 -07:00
Peter Hawkins
bcd36d8eb2
Jax and jaxlib 0.3.18 release candidate.
2022-09-26 14:10:57 +00:00
Yash Katariya
da90234cae
Delete soft_pmap as it has no users. Please use pjit
or xmap
if you do want soft_pmap.
...
`jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.
PiperOrigin-RevId: 474145090
2022-09-13 15:52:10 -07:00
Peter Hawkins
40c80d7d0a
Remove jax._src from JAX namespace.
...
This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX.
PiperOrigin-RevId: 473243243
2022-09-09 07:06:00 -07:00
Roy Frostig
a2ad414e7c
mention AOT readiness in changelog
2022-09-02 13:02:25 -07:00
Jake VanderPlas
ee4ea27c3e
update version and changelog for pypi
2022-08-31 11:37:09 -07:00
Peter Hawkins
5527966b27
[JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
...
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.
PiperOrigin-RevId: 469984029
2022-08-25 07:28:27 -07:00
jax authors
0a51a5a2ba
Merge pull request #11944 from jakevdp:rm-tile
...
PiperOrigin-RevId: 469539007
2022-08-23 13:16:50 -07:00
Jake VanderPlas
b8fe0ab8b1
Fix JVP rule for lax.pow()
2022-08-23 11:18:43 -07:00
Jake VanderPlas
8378d08fcd
Remove deprecated array.tile() method
2022-08-23 07:36:59 -07:00
Sharad Vikram
b0fdf10a63
Apply suggestions from code review
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-18 10:50:50 -07:00
Sharad Vikram
393bca122d
Expose pure callback and enable rank polymorphic callbacks
2022-08-17 10:56:42 -07:00
Matthew Johnson
d19e34fa4a
delete old remat implementation
...
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00