David Pizzuto
a8f2d9a186
deprecation_module: Move to new internal_test_util directory.
...
Now we no longer need to mess with sys.path in deprecation_test.
2023-02-17 10:55:04 -08:00
John QiangZhang
fdc8864d9b
Add tf.convert_to_tensor for call_tf gradient outputs.
...
PiperOrigin-RevId: 510453014
2023-02-17 09:39:16 -08:00
Lena Martens
307288b943
Checkify: Remove stray raise_as_much_as_possible
.
...
We switched to initial style, and no longer need to rely on raising all tracers
to an active Checkify trace.
PiperOrigin-RevId: 510450683
2023-02-17 09:29:06 -08:00
jax authors
edff87eb07
Merge pull request #13613 from ROCmSoftwarePlatform:rocm_rt_build
...
PiperOrigin-RevId: 510440289
2023-02-17 08:40:28 -08:00
Chao
0dde7a0fb1
Update Dockerfile.ms
...
update to ROCm5.4
2023-02-17 14:33:33 +00:00
jax authors
5ff462d438
Merge pull request #14556 from jakevdp:doc-pytree
...
PiperOrigin-RevId: 510416966
2023-02-17 06:33:28 -08:00
Jake VanderPlas
47ec553c40
DOC: add alternative for pytree initialization
2023-02-17 06:04:33 -08:00
jax authors
51182258bb
Merge pull request #14529 from jakevdp:lax-bitcast-validation
...
PiperOrigin-RevId: 510410676
2023-02-17 06:01:15 -08:00
jax authors
c62462afda
Merge pull request #14554 from hawkinsp:mlirapi
...
PiperOrigin-RevId: 510410342
2023-02-17 05:53:33 -08:00
Peter Hawkins
9cf3cb4486
Reexport jax.interpreters.mlir.token_type.
...
Fixes https://github.com/google/jax/issues/14551
2023-02-17 13:26:44 +00:00
Roy Frostig
e276859d11
remove several symbols from jax.core
...
* `ClosedCallPrimitive`
* `CustomPpEqnRule`
* `DArray`
* `DArrayDimHandler`
PiperOrigin-RevId: 510343926
2023-02-16 22:55:16 -08:00
jax authors
8962d2f701
Merge pull request #14513 from mattjj:shmap-test
...
PiperOrigin-RevId: 510330159
2023-02-16 21:21:20 -08:00
Matthew Johnson
ab881cb720
[shard-map] add systematic tests for eager, jit, autodiff
2023-02-16 20:40:09 -08:00
Jake VanderPlas
e1333f3de0
Roll-back https://github.com/google/jax/pull/14526 because it breaks view()
on scalar inputs
...
PiperOrigin-RevId: 510281592
2023-02-16 17:07:55 -08:00
jax authors
c467d84eea
Merge pull request #14536 from jakevdp:coo-oob
...
PiperOrigin-RevId: 510281491
2023-02-16 17:00:33 -08:00
Jake VanderPlas
df358242ff
[sparse] test coo/csr extra nse
2023-02-16 16:27:31 -08:00
Yash Katariya
eea1fef6e5
Return jax.Array from GDA's callback APIs if jax.Array is True.
...
PiperOrigin-RevId: 510268071
2023-02-16 16:02:05 -08:00
Peter Hawkins
2b9ad0d93e
Move contents of jax.experimental.global_device_array to jax._src.global_device_array.
...
Make jax.experimental.global_device_array a shim around jax._src.global_device_array.
Change in preparation for deprecating global device arrays.
PiperOrigin-RevId: 510261140
2023-02-16 15:37:10 -08:00
pizzud
631e4ed7e0
lax_test: Create a separate module for lax-specific test utils in a new package.
...
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.
The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.
Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.
PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Yash Katariya
47dc01637f
Create a jax.Array from make_sharded_device_array since SDA is deprecated.
...
PiperOrigin-RevId: 510251301
2023-02-16 14:52:56 -08:00
Tianjian Lu
4fa69e60a0
[sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.
...
PiperOrigin-RevId: 510248091
2023-02-16 14:40:18 -08:00
Peter Hawkins
c368562529
Add keep_dep tag to :global_device_array build target to hint that it should be kept.
...
PiperOrigin-RevId: 510241400
2023-02-16 14:15:21 -08:00
Yash Katariya
941722f7db
Finish jax and jaxlib 0.4.4 release
...
PiperOrigin-RevId: 510234171
2023-02-16 13:54:56 -08:00
Matthew Johnson
ec1e513659
remove accidental re-export of __future__.annotations from jax/core.py
...
PiperOrigin-RevId: 510233347
2023-02-16 13:47:28 -08:00
Peter Hawkins
43b615c0a0
Move global_device_array into its own BUILD target.
...
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
Roy Frostig
591e2c8937
remove some exports from jax.core
...
Namely:
* `AvalMapHandlerPair`
* `AxisEnvFrame`
* `AxisName`
* `AxisPrimitive`
* `AxisSubst`
PiperOrigin-RevId: 510224417
2023-02-16 13:12:35 -08:00
jax authors
fd6174651c
Merge pull request #14535 from jakevdp:csr-api
...
PiperOrigin-RevId: 510221845
2023-02-16 13:02:08 -08:00
Jake VanderPlas
d1334c80d2
[sparse] bring sparse.csr API in line with sparse.coo
2023-02-16 12:47:38 -08:00
Yash Katariya
34324f80e9
Catch ImportError when importing tf
instead of a broad exception catch. If not, this leads to weird errors in the other tests down the line.
...
PiperOrigin-RevId: 510206006
2023-02-16 12:03:58 -08:00
Peter Hawkins
54269c1145
Remove more exported names from jax.interpreters.xla.
...
None of these appear to have public users, and this module is not included in the deprecation policy.
Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Roy Frostig
6b545a2ddc
remove several exported symbols from jax.core
...
All of these are prefixed by an underscore.
PiperOrigin-RevId: 510194304
2023-02-16 11:20:36 -08:00
jax authors
66e7c0cdce
Merge pull request #14526 from jakevdp:ndarray-view
...
PiperOrigin-RevId: 510194136
2023-02-16 11:12:46 -08:00
Jake VanderPlas
11acec03c3
lax.bitcast_convert_type: better input validation
2023-02-16 10:56:06 -08:00
jax authors
bb04686a98
Merge pull request #14503 from sharadmv:fstring-docs
...
PiperOrigin-RevId: 510181634
2023-02-16 10:28:53 -08:00
jax authors
4eeca92e54
Merge pull request #14482 from gijskoning:patch-1
...
PiperOrigin-RevId: 510176609
2023-02-16 10:10:34 -08:00
Jake VanderPlas
b8994f5c3d
jnp.ndarray.view: implement all dtypes
2023-02-16 10:07:24 -08:00
jax authors
68276a9a3e
Merge pull request #14518 from Schoyen:patch-1
...
PiperOrigin-RevId: 510165801
2023-02-16 09:30:35 -08:00
jax authors
f323952400
Merge pull request #14501 from jakevdp:lax-bitcast-convert
...
PiperOrigin-RevId: 510153280
2023-02-16 08:45:32 -08:00
Yash Katariya
58e46b48e6
Prepare for jax and jaxlib 0.4.4 release
...
PiperOrigin-RevId: 510152471
jax-v0.4.4
jaxlib-v0.4.4
jax-v0.4.4-rc
2023-02-16 08:37:15 -08:00
Jake VanderPlas
b18cbbe101
lax.bitcast_convert_type: support casting between types of different width
2023-02-16 08:21:18 -08:00
Peter Hawkins
c6a99b699e
Remove jax.interpreters.xla.lower_fun.
...
This function has been a stub that does nothing useful for a long time, and the only user I can find is Equinox which already guards this with a hasattr(xla, 'lower_fun') guard.
PiperOrigin-RevId: 510142446
2023-02-16 07:51:15 -08:00
George Necula
a9e886f956
[jax2tf] Enable all native lowering jax2tf tests
...
Filed bugs for the few remaining tests, and disabled them.
Fixed the logging of the compiled HLO on test failure.
PiperOrigin-RevId: 510135651
2023-02-16 07:15:57 -08:00
Øyvind Sigmundson Schøyen
c7ddd2a7fa
DOC: fix typo in sph_harm
...
:math:\theta` -> :math:`\theta`
2023-02-16 15:34:18 +01:00
George Necula
454e4de524
[shape_poly] Fix the lowering for symbolic dimension expressions for division
...
The symbolic dimension expression use the Python semantics for division and remainder, while StableHLO is slightly different.
PiperOrigin-RevId: 510056597
2023-02-15 23:51:23 -08:00
John QiangZhang
d0b42f2ce8
Fix the simple bug on call_tf.replace_non_float and add unittest for floating and complex data type.
...
PiperOrigin-RevId: 510055139
2023-02-15 23:40:54 -08:00
Roy Frostig
26045c49e7
remove core.{aval_method,aval_property}
...
PiperOrigin-RevId: 510043837
2023-02-15 22:22:09 -08:00
jax authors
d8514d0ec6
Merge pull request #14500 from jakevdp:bcsr-matmul-test
...
PiperOrigin-RevId: 510034750
2023-02-15 21:26:06 -08:00
Peter Hawkins
0af9fff5ca
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 510027595
2023-02-15 21:03:03 -08:00
Roy Frostig
1b2a318fd1
remove core.axis_substitution_rules
...
PiperOrigin-RevId: 509989925
2023-02-15 18:42:13 -08:00
Peter Hawkins
768960b4e4
Fix pytype errors.
...
PiperOrigin-RevId: 509984207
2023-02-15 18:12:42 -08:00