14902 Commits

Author SHA1 Message Date
David Pizzuto
a8f2d9a186 deprecation_module: Move to new internal_test_util directory.
Now we no longer need to mess with sys.path in deprecation_test.
2023-02-17 10:55:04 -08:00
John QiangZhang
fdc8864d9b Add tf.convert_to_tensor for call_tf gradient outputs.
PiperOrigin-RevId: 510453014
2023-02-17 09:39:16 -08:00
Lena Martens
307288b943 Checkify: Remove stray raise_as_much_as_possible.
We switched to initial style, and no longer need to rely on raising all tracers
to an active Checkify trace.

PiperOrigin-RevId: 510450683
2023-02-17 09:29:06 -08:00
jax authors
edff87eb07 Merge pull request #13613 from ROCmSoftwarePlatform:rocm_rt_build
PiperOrigin-RevId: 510440289
2023-02-17 08:40:28 -08:00
Chao
0dde7a0fb1
Update Dockerfile.ms
update to ROCm5.4
2023-02-17 14:33:33 +00:00
jax authors
5ff462d438 Merge pull request #14556 from jakevdp:doc-pytree
PiperOrigin-RevId: 510416966
2023-02-17 06:33:28 -08:00
Jake VanderPlas
47ec553c40 DOC: add alternative for pytree initialization 2023-02-17 06:04:33 -08:00
jax authors
51182258bb Merge pull request #14529 from jakevdp:lax-bitcast-validation
PiperOrigin-RevId: 510410676
2023-02-17 06:01:15 -08:00
jax authors
c62462afda Merge pull request #14554 from hawkinsp:mlirapi
PiperOrigin-RevId: 510410342
2023-02-17 05:53:33 -08:00
Peter Hawkins
9cf3cb4486 Reexport jax.interpreters.mlir.token_type.
Fixes https://github.com/google/jax/issues/14551
2023-02-17 13:26:44 +00:00
Roy Frostig
e276859d11 remove several symbols from jax.core
* `ClosedCallPrimitive`
* `CustomPpEqnRule`
* `DArray`
* `DArrayDimHandler`

PiperOrigin-RevId: 510343926
2023-02-16 22:55:16 -08:00
jax authors
8962d2f701 Merge pull request #14513 from mattjj:shmap-test
PiperOrigin-RevId: 510330159
2023-02-16 21:21:20 -08:00
Matthew Johnson
ab881cb720 [shard-map] add systematic tests for eager, jit, autodiff 2023-02-16 20:40:09 -08:00
Jake VanderPlas
e1333f3de0 Roll-back https://github.com/google/jax/pull/14526 because it breaks view() on scalar inputs
PiperOrigin-RevId: 510281592
2023-02-16 17:07:55 -08:00
jax authors
c467d84eea Merge pull request #14536 from jakevdp:coo-oob
PiperOrigin-RevId: 510281491
2023-02-16 17:00:33 -08:00
Jake VanderPlas
df358242ff [sparse] test coo/csr extra nse 2023-02-16 16:27:31 -08:00
Yash Katariya
eea1fef6e5 Return jax.Array from GDA's callback APIs if jax.Array is True.
PiperOrigin-RevId: 510268071
2023-02-16 16:02:05 -08:00
Peter Hawkins
2b9ad0d93e Move contents of jax.experimental.global_device_array to jax._src.global_device_array.
Make jax.experimental.global_device_array a shim around jax._src.global_device_array.

Change in preparation for deprecating global device arrays.

PiperOrigin-RevId: 510261140
2023-02-16 15:37:10 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Yash Katariya
47dc01637f Create a jax.Array from make_sharded_device_array since SDA is deprecated.
PiperOrigin-RevId: 510251301
2023-02-16 14:52:56 -08:00
Tianjian Lu
4fa69e60a0 [sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.
PiperOrigin-RevId: 510248091
2023-02-16 14:40:18 -08:00
Peter Hawkins
c368562529 Add keep_dep tag to :global_device_array build target to hint that it should be kept.
PiperOrigin-RevId: 510241400
2023-02-16 14:15:21 -08:00
Yash Katariya
941722f7db Finish jax and jaxlib 0.4.4 release
PiperOrigin-RevId: 510234171
2023-02-16 13:54:56 -08:00
Matthew Johnson
ec1e513659 remove accidental re-export of __future__.annotations from jax/core.py
PiperOrigin-RevId: 510233347
2023-02-16 13:47:28 -08:00
Peter Hawkins
43b615c0a0 Move global_device_array into its own BUILD target.
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
Roy Frostig
591e2c8937 remove some exports from jax.core
Namely:
* `AvalMapHandlerPair`
* `AxisEnvFrame`
* `AxisName`
* `AxisPrimitive`
* `AxisSubst`
PiperOrigin-RevId: 510224417
2023-02-16 13:12:35 -08:00
jax authors
fd6174651c Merge pull request #14535 from jakevdp:csr-api
PiperOrigin-RevId: 510221845
2023-02-16 13:02:08 -08:00
Jake VanderPlas
d1334c80d2 [sparse] bring sparse.csr API in line with sparse.coo 2023-02-16 12:47:38 -08:00
Yash Katariya
34324f80e9 Catch ImportError when importing tf instead of a broad exception catch. If not, this leads to weird errors in the other tests down the line.
PiperOrigin-RevId: 510206006
2023-02-16 12:03:58 -08:00
Peter Hawkins
54269c1145 Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Roy Frostig
6b545a2ddc remove several exported symbols from jax.core
All of these are prefixed by an underscore.

PiperOrigin-RevId: 510194304
2023-02-16 11:20:36 -08:00
jax authors
66e7c0cdce Merge pull request #14526 from jakevdp:ndarray-view
PiperOrigin-RevId: 510194136
2023-02-16 11:12:46 -08:00
Jake VanderPlas
11acec03c3 lax.bitcast_convert_type: better input validation 2023-02-16 10:56:06 -08:00
jax authors
bb04686a98 Merge pull request #14503 from sharadmv:fstring-docs
PiperOrigin-RevId: 510181634
2023-02-16 10:28:53 -08:00
jax authors
4eeca92e54 Merge pull request #14482 from gijskoning:patch-1
PiperOrigin-RevId: 510176609
2023-02-16 10:10:34 -08:00
Jake VanderPlas
b8994f5c3d jnp.ndarray.view: implement all dtypes 2023-02-16 10:07:24 -08:00
jax authors
68276a9a3e Merge pull request #14518 from Schoyen:patch-1
PiperOrigin-RevId: 510165801
2023-02-16 09:30:35 -08:00
jax authors
f323952400 Merge pull request #14501 from jakevdp:lax-bitcast-convert
PiperOrigin-RevId: 510153280
2023-02-16 08:45:32 -08:00
Yash Katariya
58e46b48e6 Prepare for jax and jaxlib 0.4.4 release
PiperOrigin-RevId: 510152471
jax-v0.4.4 jaxlib-v0.4.4 jax-v0.4.4-rc
2023-02-16 08:37:15 -08:00
Jake VanderPlas
b18cbbe101 lax.bitcast_convert_type: support casting between types of different width 2023-02-16 08:21:18 -08:00
Peter Hawkins
c6a99b699e Remove jax.interpreters.xla.lower_fun.
This function has been a stub that does nothing useful for a long time, and the only user I can find is Equinox which already guards this with a hasattr(xla, 'lower_fun') guard.

PiperOrigin-RevId: 510142446
2023-02-16 07:51:15 -08:00
George Necula
a9e886f956 [jax2tf] Enable all native lowering jax2tf tests
Filed bugs for the few remaining tests, and disabled them.
Fixed the logging of the compiled HLO on test failure.

PiperOrigin-RevId: 510135651
2023-02-16 07:15:57 -08:00
Øyvind Sigmundson Schøyen
c7ddd2a7fa
DOC: fix typo in sph_harm
:math:\theta` -> :math:`\theta`
2023-02-16 15:34:18 +01:00
George Necula
454e4de524 [shape_poly] Fix the lowering for symbolic dimension expressions for division
The symbolic dimension expression use the Python semantics for division and remainder, while StableHLO is slightly different.

PiperOrigin-RevId: 510056597
2023-02-15 23:51:23 -08:00
John QiangZhang
d0b42f2ce8 Fix the simple bug on call_tf.replace_non_float and add unittest for floating and complex data type.
PiperOrigin-RevId: 510055139
2023-02-15 23:40:54 -08:00
Roy Frostig
26045c49e7 remove core.{aval_method,aval_property}
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
jax authors
d8514d0ec6 Merge pull request #14500 from jakevdp:bcsr-matmul-test
PiperOrigin-RevId: 510034750
2023-02-15 21:26:06 -08:00
Peter Hawkins
0af9fff5ca Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 510027595
2023-02-15 21:03:03 -08:00
Roy Frostig
1b2a318fd1 remove core.axis_substitution_rules
PiperOrigin-RevId: 509989925
2023-02-15 18:42:13 -08:00
Peter Hawkins
768960b4e4 Fix pytype errors.
PiperOrigin-RevId: 509984207
2023-02-15 18:12:42 -08:00