12306 Commits

Author SHA1 Message Date
Matthew Johnson
3ab8637a73 Update README discussion of Windows support:
* right next to the pip installation instructions, mention they don't work for Windows;
* add a link to #5795 for an unofficial discussion of Windows native support
2022-07-19 17:12:40 -07:00
jax authors
6ea9e4d4dd Merge pull request #11546 from jakevdp:fix-scipy-sym-pos
PiperOrigin-RevId: 461976166
2022-07-19 14:26:51 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
Yash Katariya
9f914a93d6 Replace input sharding_specs with in_shardings in InputsHandler
PiperOrigin-RevId: 461963206
2022-07-19 13:43:28 -07:00
jax authors
226cef08bf Merge pull request #11544 from jakevdp:conv-general-pad
PiperOrigin-RevId: 461937001
2022-07-19 11:46:08 -07:00
Jake VanderPlas
489596c0e2 lax.conv_general_dilated: validate negative paddings 2022-07-19 11:15:18 -07:00
jax authors
ac731bb4cc Merge pull request #11518 from ROCmSoftwarePlatform:rocm_disable_dlpack_unit_tests
PiperOrigin-RevId: 461904181
2022-07-19 09:48:36 -07:00
jax authors
97fe2262d3 Merge pull request #11526 from atgctg:patch-2
PiperOrigin-RevId: 461903614
2022-07-19 09:42:30 -07:00
jax authors
388733b533 Merge pull request #11460 from jakevdp:profiler-deprecation
PiperOrigin-RevId: 461903612
2022-07-19 09:36:36 -07:00
Benjamin Kramer
2c72858928 Integrate LLVM at llvm/llvm-project@8aff88fd3a
Updates LLVM usage to match
[8aff88fd3a5f](https://github.com/llvm/llvm-project/commit/8aff88fd3a5f)

PiperOrigin-RevId: 461889195
2022-07-19 08:31:24 -07:00
Jake VanderPlas
2543542fa8 jax.profiler: remove deprecated functions 2022-07-19 08:13:44 -07:00
jax authors
12ce369c94 Merge pull request #11506 from gnecula:jax2tf_call_module
PiperOrigin-RevId: 461834766
2022-07-19 02:42:57 -07:00
George Necula
ee50140701 [jax2tf] A new experimental version with JAX native lowering.
In the future JAX will be able to use a serialization format
based on a variant of MHLO. This is not yet ready, but in this PR
we are starting to get jax2tf ready for this. As a temporary
step, we had introduced a TF op called XlaCallModule which carries
a serialized MHLO module and which e can use to wrap the JAX native
MHLO as a TF op. We still reuse parts of jax2tf, in particular
the gradient machinery.

This functionality can be enabled locally with a
`experimental_native_lowering` flag for `jax2tf.convert`, or
globally with the flag `--jax2tf_default_experimental_native_lowering`.
2022-07-19 10:50:04 +02:00
Yash Katariya
ea627b807b Replace out_specs with out_shardings and remove out_indices in ResultsHandler.
PiperOrigin-RevId: 461788039
2022-07-18 20:57:02 -07:00
jax authors
e1fdd57ff3 Merge pull request #11532 from jakevdp:fix-upstream-ci
PiperOrigin-RevId: 461706572
2022-07-18 13:57:56 -07:00
jax authors
b0805a8a31 Fixes the JAX implementation of CELU returning NaN gradients for input
values >= 88.7229.

When a JAX where() op is used to avoid a NaN or undefined value, reverse
differentiation can still return NaN even though the NaN input is not selected
by the conditional:

https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

This change uses jnp.maximum and jnp.minimum to compute CELU without producing an undefined value.

PiperOrigin-RevId: 461678140
2022-07-18 11:58:05 -07:00
Jake VanderPlas
2eaa44e8d8 [CI] upstream-dev: run all tests 2022-07-18 11:06:56 -07:00
jax authors
d98d5ddce5 [JAX] Add jax_unique_mhlo_module_names flag to control if MHLO should be made unique.
Some clients of JAX expect module names to not be altered so that they can cache XLA compilations.

PiperOrigin-RevId: 461648129
2022-07-18 10:05:44 -07:00
atgctg
1d5fbeadd2
Fix typo 2022-07-18 11:23:21 +02:00
jax authors
ae4aee762a [jax2tf] Fix conv1d padding; it's already normalized before the _pad_spatial_dims call. Enable non-XLA tests of conv1d.
PiperOrigin-RevId: 461556553
2022-07-18 01:28:18 -07:00
Rohit Santhanam
235ea7059c [ROCm] Disable new array_interoperability dlpack tests. 2022-07-17 04:48:11 +00:00
jax authors
a08a1f284a Merge pull request #11504 from gnecula:shape_poly_conv2
PiperOrigin-RevId: 461368469
2022-07-16 11:33:59 -07:00
jax authors
a2371925fd Merge pull request #11513 from jakevdp:from-dlpack
PiperOrigin-RevId: 461274756
2022-07-15 17:46:32 -07:00
Jake VanderPlas
2f4c485a54 Add dlpack support to device_array and jax.numpy 2022-07-15 17:31:11 -07:00
jax authors
7d7aa467f4 Merge pull request #11514 from jakevdp:stdout-atty
PiperOrigin-RevId: 461266411
2022-07-15 16:53:33 -07:00
jax authors
0da5657ac7 Merge pull request #11507 from jakevdp:tree-util-warning-level
PiperOrigin-RevId: 461266381
2022-07-15 16:48:13 -07:00
Yash Katariya
90687cc1ff Make lower_mesh_computation accept sharding instances. The new path is tested as everything in pjit goes through the new lower_sharding_computation except of AUTO and UNSPECIFIED (see below for these 2).
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths; `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
  * `lower_sharding_computation` works in SPMD mode since its only used in pjit. Majority of the logic is the same. The only difference is that `mesh` does not exist in this function.

* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.

* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.

* Rest of the changes are to make all other functions play nicely with sharding instances.

PiperOrigin-RevId: 461260553
2022-07-15 16:16:23 -07:00
jax authors
005fcf333f Merge pull request #11515 from google:yashk2810-patch-13
PiperOrigin-RevId: 461260081
2022-07-15 16:10:37 -07:00
Yash Katariya
90433c0518
Update ci-build.yaml 2022-07-15 16:03:16 -07:00
Jake VanderPlas
b41f33b0d7 pretty_printing: handle case where stdout is patched by a logger 2022-07-15 14:50:17 -07:00
jax authors
f12b7fb0bb Merge pull request #11509 from jakevdp:sparse-lower
PiperOrigin-RevId: 461242221
2022-07-15 14:50:02 -07:00
jax authors
5a10c1af3c Merge pull request #11487 from google:upstream-dev-logs
PiperOrigin-RevId: 461241333
2022-07-15 14:43:37 -07:00
Jake VanderPlas
c1549a0a16 [sparse] make sparse objects compatible with jax.jit.lower() 2022-07-15 09:58:31 -07:00
Jake VanderPlas
6907dfad00 tree_util: fix warning category and stacklevel 2022-07-15 09:24:22 -07:00
Tom Hennigan
10720258ea Reduce the verbosity of treedef printing for custom nodes.
For very large trees of custom nodes this printing can be very verbose with a
lot or repetition. Our internal repository also encourages very deep package
names which exacerbates this issue.

Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:

    >>> params = { .. some params .. }
    >>> f = jax.jit(..).lower(params).compile()
    >>> f(params)  # fine
    >>> params['some_new_thing'] = something
    >>> f(params)
    TypeError: function compiled for {treedef}, called with {treedef}.

PiperOrigin-RevId: 461190971
2022-07-15 07:14:28 -07:00
jax authors
023e6f5955 Copybara import of the project:
--
e1f1e93e0c8b53e62a064b06b56c84a2bfedb911 by Roy Frostig <frostig@google.com>:

maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 461146464
2022-07-15 01:23:51 -07:00
jax authors
e19f6973c4 Merge pull request #11498 from froystig:tree-map-top-level
PiperOrigin-RevId: 461131693
2022-07-14 23:23:41 -07:00
Tianjian Lu
b421e24bb0 [sparse] Update _validate_coo_mhlo in gpu_sparse.
PiperOrigin-RevId: 461111317
2022-07-14 20:35:09 -07:00
George Necula
777c129dfb [dynamic-shapes] Split dynamic_api_test.py
PiperOrigin-RevId: 461109288
2022-07-14 20:18:53 -07:00
George Necula
e6f93bcdc0 [shape-poly] Improve the error reporting for division
Added a section to README to explain the division errors
and to show a workaround. Changed the division errors
to include more detail as to what the error is,
and to include a link to the new section in the README
2022-07-15 05:37:55 +03:00
jax authors
a35f9acbff Merge pull request #11476 from gnecula:shape_poly_conv
PiperOrigin-RevId: 461102579
2022-07-14 19:27:12 -07:00
jax authors
e40f9c2af2 Merge pull request #11482 from jakevdp:numpy-api-tests
PiperOrigin-RevId: 461056275
2022-07-14 14:52:45 -07:00
Jake VanderPlas
0f14943524 lax_numpy_test: make compatible with numpy 1.24-dev 2022-07-14 14:35:10 -07:00
Jake VanderPlas
b0cd7de999 CI: use minimum jaxlib in upstream-ci build 2022-07-14 14:34:47 -07:00
jax authors
cf61646ad8 Merge pull request #11491 from shoyer:user-dtype-stacklevel
PiperOrigin-RevId: 461039368
2022-07-14 13:37:20 -07:00
Roy Frostig
e1f1e93e0c maintain an alias to jax.tree_util.tree_map in the top level jax module 2022-07-14 11:00:54 -07:00
jax authors
881d16c8fe Merge pull request #11497 from skye:workspace
PiperOrigin-RevId: 460993471
2022-07-14 10:19:28 -07:00
Skye Wanderman-Milne
9149c38e1e Update WORKSPACE and setup.py in preparation for 0.3.15 jax/jaxlib release 2022-07-14 10:12:59 -07:00
George Necula
63f8ee85d9 Address review comments 2022-07-14 13:20:44 +03:00
George Necula
b22121c0c1 [jax2tf] Fixes for handling of convolutions with shape_polymorphism and enable_xla=False
Issue: #11402

Due to a typo we were running no tests for convolutions with shape
polymorphism and enable_xla=False.

Added a few more tests from #11402 (Thanks @sdenton4).

The main issue was that in presence of shape polymorphism we cannot
just use `x.shape` for a TF value `x` because it will contain `None`
in the place of unknown dimensions. We must use instead the JAX
abstract values.

This does not fix all issues reported in #11402, there is still the
computation of padding or padding="SAME". Commented out the
corresponding test.
2022-07-14 13:20:41 +03:00