16835 Commits

Author SHA1 Message Date
jax authors
88a60b808c Merge pull request #16870 from skye:version
PiperOrigin-RevId: 551636421
jax-v0.4.14 jax-v0.4.14-rc
2023-07-27 14:11:05 -07:00
Skye Wanderman-Milne
e132a0e5d5 Slightly downgrade xla version to avoid PJRT C API incompat 2023-07-27 14:05:14 -07:00
jax authors
c75e85da16 Merge pull request #16869 from skye:version
PiperOrigin-RevId: 551628701
2023-07-27 13:47:46 -07:00
Peter Hawkins
a480aa8dbd Work around pytype error.
An upcoming pytype release complains about unpacking a non-deterministic order iterable for this line of code. Work around pytype.

PiperOrigin-RevId: 551627521
2023-07-27 13:39:48 -07:00
Skye Wanderman-Milne
0b24b2ba6a Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.14 release 2023-07-27 13:35:04 -07:00
jax authors
7c0ef8660d Merge pull request #16433 from JGameCreation:patch-1
PiperOrigin-RevId: 551614086
2023-07-27 12:52:28 -07:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
jax authors
f35f226b44 Merge pull request #16865 from google:xla
PiperOrigin-RevId: 551565217
2023-07-27 10:07:24 -07:00
Peter Hawkins
0cd025348d Update XLA version to fix build failure. 2023-07-27 12:38:51 +00:00
Skye Wanderman-Milne
a03d6e6613 Move _tpu_ext.cc to jaxlib/mlir/_mlir_libs and set RPATH correctly
_tpu_ext.so dynamically links in libjaxlib_mlir_capi.so (in
jaxlib/mlir/_mlir_libs), so needs to include jaxlib/mlir/_mlir_libs in
its RPATH or similar on other platforms.

We achieve this by moving _tpu_ext.cc to jaxlib/mlir/_mlir_libs so it
can use the same linkopts as other mlir targets that depend on
libjaxlib_mlir_capi.so. In particular, we want this to work correctly
across platforms, and it's not clear if Windows supports RPATH-like
functionality beyond the current directory.

PiperOrigin-RevId: 551372130
2023-07-26 18:25:17 -07:00
jax authors
bcddc504bf Merge pull request #16852 from hawkinsp:builddeps
PiperOrigin-RevId: 551333094
2023-07-26 15:33:01 -07:00
jax authors
5c39a0d39f Merge pull request #16844 from jakevdp:jnp-put
PiperOrigin-RevId: 551331519
2023-07-26 15:24:37 -07:00
jax authors
735637e313 Previously, using sparse.todense on a BCSR matrix with sparse.sparsify would raise NotImplementedError: sparse rule for todense is not implemented. By adding the sparse rule, it will resolve this issue.
PiperOrigin-RevId: 551291543
2023-07-26 13:01:02 -07:00
Peter Hawkins
3c4527b6b0 Check build and wheel are installed before building jaxlib. 2023-07-26 11:46:11 -07:00
jax authors
416814df2a Merge pull request #16826 from mattjj:issue16805
PiperOrigin-RevId: 551263673
2023-07-26 11:20:31 -07:00
jax authors
9e69277402 Merge pull request #16849 from skye:mosaic_test_build_fix
PiperOrigin-RevId: 551235536
2023-07-26 09:50:14 -07:00
Jake VanderPlas
88c42da7f4 Add implementation of jnp.put 2023-07-26 08:54:54 -07:00
Skye Wanderman-Milne
d0b65f2ab8 Make //jax:tpu_custom_call respect --//jax:build_jaxlib=false
Otherwise jaxlib is partially built and doesn't work properly.
2023-07-26 15:50:42 +00:00
jax authors
1054fe5a3b Merge pull request #16846 from gnecula:poly_dot
PiperOrigin-RevId: 551174686
2023-07-26 05:21:54 -07:00
Sharad Vikram
3baa6e7a89 Enable building jaxlib w/ Mosaic
PiperOrigin-RevId: 551159246
2023-07-26 03:59:30 -07:00
George Necula
c9f9f28b2c [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype
Add primitives tests for the case of dot_general with different lhs_dtype and
rhs_dtype. Then fix the lowering to work with dynamic shapes.
2023-07-26 12:29:12 +02:00
jax authors
f66d3cf016 Merge pull request #16842 from jakevdp:dynamic-slice-unsigned
PiperOrigin-RevId: 550981737
2023-07-25 13:37:31 -07:00
Jake VanderPlas
0dbda849ef lax.dynamic_slice: avoid negative index correction for unsigned indices 2023-07-25 13:09:09 -07:00
jax authors
def6190dc2 Merge pull request #16833 from gnecula:poly_v8_1
PiperOrigin-RevId: 550929048
2023-07-25 10:36:35 -07:00
jax authors
b11661696c Merge pull request #16839 from jakevdp:fix-core-deprecation
PiperOrigin-RevId: 550920851
2023-07-25 10:12:04 -07:00
Jake VanderPlas
3b6b988473 fix deprecations in core.py 2023-07-25 09:47:04 -07:00
George Necula
b6ed0568d8 [jax2tf] Add support for serialization version 8.
In this version the serialized module contain a StableHLO module
boolean attribute `jax.uses_shape_polymorphism` that specifies
whether the module uses shape polymorphism. If it doesn't then
we do not need to do shape refinement.

Note that we are still keeping the default serialization version to
6, for forward compatibility. However, the serialization unit tests
now run at version 8.

Made Exported.mlir_module a method instead of a propery, to make it
more obvious that it is a derived artifact.
2023-07-25 07:13:34 +02:00
jax authors
95b410c772 Merge pull request #16820 from gnecula:export_ver
PiperOrigin-RevId: 550769937
2023-07-24 21:58:07 -07:00
George Necula
40810353e9 [jax2tf] Added error for attempting to use wrong jax_serialization_version
Previously, the serialization would use the specified serialization version
without checking if it supported by the serialzier.
This could result in invalid serializations

Also add some compatibility tests for all supported versions.
2023-07-25 06:42:49 +02:00
Junwhan Ahn
14a6089e89 Change mhlo.is_same_data_across_replicas from unit attr to bool attr
Using bool attrs aligns better with StableHLO. Since [VHLO does not define unit attrs](https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/VhloAttrs.td), serializing StableHLO modules containing unit attrs fails. This becomes a problem when we want to serialize MHLO modules containing `mhlo.is_same_data_across_replicas` by converting them into StableHLO then VHLO.

JAX emits `mhlo.is_same_data_across_replicas` as a bool attr only after a new jaxlib version since this requires the jaxlib to understand the new attr type.

PiperOrigin-RevId: 550745955
2023-07-24 19:50:33 -07:00
Yash Katariya
7821516105 Make _remake internal and add return type hints
PiperOrigin-RevId: 550721261
2023-07-24 17:36:36 -07:00
jax authors
727af17cfd Merge pull request #16829 from jakevdp:has-opaque
PiperOrigin-RevId: 550686011
2023-07-24 15:11:55 -07:00
Jake VanderPlas
e1a1377cde replace use of has_opaque_dtype 2023-07-24 14:46:58 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
jax authors
c6fa3d93a6 Merge pull request #16822 from jakevdp:jupytext-version
PiperOrigin-RevId: 550673985
2023-07-24 14:29:37 -07:00
Matthew Johnson
9ddef5cf84 make _dot_general_batch_rule handle python builtin numeric types 2023-07-24 14:01:07 -07:00
Jake VanderPlas
7bb8312f82 CI: update jupytext to v0.14.7 2023-07-24 11:51:45 -07:00
jax authors
83d99bbb17 Merge pull request #16813 from gnecula:poly_err_msg
PiperOrigin-RevId: 550530238
2023-07-24 05:52:01 -07:00
George Necula
deb8fdfe0b [shape_poly] Improve error messages for shape assertions
Starting with serialization version 7 we introduce shape
assertions that are checked at runtime. In the process of
rolling out version 7 we encoutered projects with failed
shape assertions and it became clear that we need better
error messages.

See the changes here in tests and README.md for example of
the updated assertions.

To produce these assertions we now pass multiple operands to
the shape assertion, and we introduce a CachedShapeEvaluator
to reduce the amount of duplicate code generated.
2023-07-24 14:57:06 +03:00
jax authors
32cbc3678d Integrate LLVM at llvm/llvm-project@571c1292b6
Updates LLVM usage to match
[571c1292b693](https://github.com/llvm/llvm-project/commit/571c1292b693)

PiperOrigin-RevId: 550071080
2023-07-21 15:56:28 -07:00
jax authors
8a1a5fac8e Merge pull request #16781 from jakevdp:prng-dtypes
PiperOrigin-RevId: 550068690
2023-07-21 15:45:29 -07:00
Jake VanderPlas
7d7a536b55 custom prng: introduce mechanism to identify key arrays by dtype 2023-07-21 12:27:32 -07:00
jax authors
1b33a4eb05 Merge pull request #16815 from hawkinsp:py39
PiperOrigin-RevId: 550014612
2023-07-21 12:12:47 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
John QiangZhang
4114e6c428 Improve the default value of output_shape_dtype.
PiperOrigin-RevId: 549988693
2023-07-21 10:45:02 -07:00
jax authors
c8f4650933 Merge pull request #16721 from jakevdp:dot-mixed-precision
PiperOrigin-RevId: 549986744
2023-07-21 10:37:11 -07:00
Jake VanderPlas
561c9531ff Lower jax.numpy.dot to mixed-precision dot_general 2023-07-21 10:10:30 -07:00
George Necula
884bcf4efc Introduce version 8 of XlaCallModule.
Previously, XlaCallModule was running the shape refinement pass for all
compilations, even if the module did not use shape polymorphism.
Currently shape refinement changes the structure of the module,
through inlining and constant folding all integer operations.
This complicates debugging because the HLO dump is very different
than the one from JAX native executions.

Starting with version 8, we run shape refinement only
if the module contains a boolean module attribute
jax.uses_shape_polymorphism=true. I think it makes sense
to put this flag as a module attribute, rather than
as a TF op attribute, because the same processing will
be needed when the module is executed from JAX.

This attribute is not yet populated by the JAX exporter.

As part of this change we moved the error check for the
number of invocation arguments from RefineDynamicShapes
to LoadAndPreprocessModule. This required adding a couple
more arguments to the loader constructor.

PiperOrigin-RevId: 549973693
2023-07-21 09:51:19 -07:00
jax authors
90840e4ca9 Merge pull request #16795 from jakevdp:refactor-opaque
PiperOrigin-RevId: 549962048
2023-07-21 09:02:59 -07:00
Lena Martens
684228e832 Add back tracebacks to checkify's Error without leaking tracers.
The trick is to save the traceback as an XLA traceback, then turn it into a
python traceback only when throwing the error. No locals are leaked in the
process.

PiperOrigin-RevId: 549957746
2023-07-21 08:44:26 -07:00