22860 Commits

Author SHA1 Message Date
Peter Hawkins
95f38d95d7 Update TPU test configuration tags.
PiperOrigin-RevId: 672562923
2024-09-09 09:02:51 -07:00
jax authors
c0dacbf724 Merge pull request #23484 from justinjfu:pallas_prefetch_docs
PiperOrigin-RevId: 672538687
2024-09-09 07:33:57 -07:00
jax authors
c28b3de599 Merge pull request #23459 from jakevdp:split-doc
PiperOrigin-RevId: 672532876
2024-09-09 07:09:52 -07:00
jax authors
f069aebfa9 Merge pull request #23469 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 672530119
2024-09-09 06:57:49 -07:00
Sergei Lebedev
05cdcb8ce5 Slightly re-arranged Pallas Mosaic GPU pipelining logic
This change prepares a few pipelining optimizations which will be done in a
follow up.

PiperOrigin-RevId: 672530087
2024-09-09 06:56:40 -07:00
Jake VanderPlas
0320a792ba Improve docs for jnp.split & related APIs 2024-09-09 05:34:45 -07:00
Peter Hawkins
fe63b991dd Disable cudnn_fusion_test from CI.
This test isn't passing in our internal CI.

PiperOrigin-RevId: 672507574
2024-09-09 05:16:13 -07:00
Peter Hawkins
aa16abe511 [pallas] Fix test failures on Windows.
Avoid importing Triton modules on Windows, since we don't build it.
Also avoid using an unescaped `\` in a regular expression.

PiperOrigin-RevId: 672507555
2024-09-09 05:14:58 -07:00
rajasekharporeddy
d37c8501ea Better dosc for jax.numpy: minimum and maximum 2024-09-09 10:25:23 +05:30
jax authors
201d3ff8f1 Update XLA dependency to use revision
492151fc19.

PiperOrigin-RevId: 672324700
2024-09-08 13:38:28 -07:00
Peter Hawkins
b6abd738d9 Relax some test tolerances in for_loop_test.py.
This PR attempts to fix some CI failures on Mac ARM.

PiperOrigin-RevId: 672312564
2024-09-08 12:09:45 -07:00
Peter Hawkins
5af1efb285 Skip symmetric product test on older jaxlibs.
The new symmetric product operator will appear to jaxlib 0.4.32.

PiperOrigin-RevId: 672311569
2024-09-08 12:04:20 -07:00
jax authors
79dabe530c Merge pull request #23462 from hawkinsp:mlir
PiperOrigin-RevId: 672188226
2024-09-07 21:04:56 -07:00
Keith Rush
265bb7bf4c Adds failing test for https://github.com/google/jax/issues/23476.
PiperOrigin-RevId: 672183133
2024-09-07 20:30:18 -07:00
jax authors
cd782643a1 Update XLA dependency to use revision
8c6dafbe7e.

PiperOrigin-RevId: 672126938
2024-09-07 14:16:22 -07:00
jax authors
703a8a6c2b Merge pull request #23490 from mattjj:readme-title
PiperOrigin-RevId: 672104872
2024-09-07 11:34:37 -07:00
Sergei Lebedev
3e1c2b3ee9 Removed dead code from add_jaxvals
PiperOrigin-RevId: 672103395
2024-09-07 11:26:33 -07:00
Matthew Johnson
8ab503158b tweak readme title to be more about what jax can do for you, dear user
we should rewrite this whole readme...
2024-09-07 18:09:19 +00:00
jax authors
02b7a76768 Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
jax authors
671acef5ab Update XLA dependency to use revision
b0368b065c.

PiperOrigin-RevId: 671893460
2024-09-06 14:35:59 -07:00
Justin Fu
51a666fb8c [Pallas] Update Pallas docs with new figures and TPUCompilerParams 2024-09-06 14:30:29 -07:00
Dan Foreman-Mackey
2ce0fc25e0 Fix tolerances for failing linalg tests.
PiperOrigin-RevId: 671881600
2024-09-06 13:58:20 -07:00
Dan Foreman-Mackey
7266e338c8 Update FFI target name for syrk operation to be consistent with other kernels.
PiperOrigin-RevId: 671870569
2024-09-06 13:21:38 -07:00
jax authors
b6213aaa85 Make pltpu key derivation more robust.
PiperOrigin-RevId: 671857080
2024-09-06 12:35:08 -07:00
Dan Foreman-Mackey
1d12a9934c Port GPU kernel for symmetric eigendecomposition to GPU.
Of note, I moved the logic about which algorithm to use, and when to use the batched algorithm into the kernel in order to support shape polymorphism and export.

PiperOrigin-RevId: 671853879
2024-09-06 12:23:04 -07:00
jax authors
f97bfc85a3 Implement symmetric_product() to produce a symmetric matrix: C = alpha * X @ X.T + beta * C
PiperOrigin-RevId: 671845818
2024-09-06 11:58:20 -07:00
George Necula
fc6b22e2e4 [host_callback] Fix type promotion error
Fix a type error that arises when we try to run the host callback tests with JAX_HOST_CALLBACK_LEGACY=False (in the process of deprecating jax.experimental.host_callback).

PiperOrigin-RevId: 671825020
2024-09-06 10:56:51 -07:00
jax authors
878b6b5743 Merge pull request #23369 from sergachev:cleanup_nsys_converter
PiperOrigin-RevId: 671801177
2024-09-06 09:54:59 -07:00
jax authors
7326db7791 Update XLA dependency to use revision
53fd000440.

PiperOrigin-RevId: 671741772
2024-09-06 07:00:52 -07:00
jax authors
d776f1da76 Merge pull request #23470 from gnecula:poly_fix_eq_constraints
PiperOrigin-RevId: 671727351
2024-09-06 05:53:53 -07:00
Sergei Lebedev
ef947a0ce6 Added a bit more error checking to Pallas Mosaic GPU pipelining logic
PiperOrigin-RevId: 671711873
2024-09-06 04:39:32 -07:00
George Necula
0d8ffd33ab [shape_polyO] Improve handling of equality shape constraints
This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.

First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b` when we parse "a" we obtain `b`.

Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.

Fixes: #23437
Fixes: #23456
2024-09-06 13:55:38 +03:00
jax authors
b6031a9c82 Merge pull request #23454 from carlosgmartin:shape_polymorphism_docs_fix_division_parentheses
PiperOrigin-RevId: 671697208
2024-09-06 03:33:33 -07:00
Sebastian Bodenstein
e3b8177af3 Internal change.
PiperOrigin-RevId: 671583042
2024-09-05 18:42:22 -07:00
jax authors
7d438601ae Merge pull request #23419 from pkgoogle:better_bitwise_right_shift_doc
PiperOrigin-RevId: 671580626
2024-09-05 18:32:46 -07:00
jax authors
25ef8ac015 Merge pull request #23107 from ykirpichev:add_xla_flags_md
PiperOrigin-RevId: 671580619
2024-09-05 18:31:25 -07:00
Peter Hawkins
27e19239ca Fix triton capi_objects target to depend on MLIR CAPIIRObjects bazel
target.

"...Objects" targets should only depend on other "...Objects" targets in
MLIR land. Don't mix them.
2024-09-06 01:06:27 +00:00
jax authors
107cacafbc Merge pull request #23466 from hawkinsp:optbarrier
PiperOrigin-RevId: 671567949
2024-09-05 17:35:00 -07:00
Peter Hawkins
9c86fdec02 Make optimization_barrier a public lax API. 2024-09-06 00:18:57 +00:00
jax authors
2904180ae0 Merge pull request #23451 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 671546105
2024-09-05 16:10:28 -07:00
jax authors
ae400f8d2a Merge pull request #23465 from frederikwilde:typos
PiperOrigin-RevId: 671545152
2024-09-05 16:07:00 -07:00
jax authors
125bb4f158 Update XLA dependency to use revision
3c4102f71c.

PiperOrigin-RevId: 671522199
2024-09-05 14:59:32 -07:00
jax authors
45dc05eaa6 Delete remote python repository rule calls from TF configs.
Remote configurations of python repositories are removed because hermetic Python repository rules install and configure python modules in Bazel cache on the host machine. The cache is shared across host and remote machines.

PiperOrigin-RevId: 671512134
2024-09-05 14:31:20 -07:00
Piseth Ky
02334cdaa5 updating bitwise_right_shift_doc as an alias
simpler bitwise_right_shift implementation

to match previous PR

updating bitwise_right_shift_doc as an alias

readded jnp.bitwise_left_shift, jnp.bitwise_right_shift

Update sharded-computation doc to use make_mesh()

Rename `jtu.create_global_mesh` to `jtu.create_mesh` and use `jax.make_mesh` inside `jtu.create_mesh` to get maximum test coverage of the new API.

PiperOrigin-RevId: 670744047

better true_divide and divide docs

doc wording update

[Mosaic TPU] Fix mosaic alignment check in concatenate rule.

PiperOrigin-RevId: 670837792

Fix pytype errors and args for jax.Array methods

Add docker builds for ubu22 and 24

Better docs for jax.numpy: log and log1p

random.key_impl: improve repr of output

Remove unused docstring addition: _PRECISION_DOC

update example optimizers library docstring

* JAXopt is being merged into Optax, so point only to Optax
* Update Optax's github repository URL

fixing merge duplication

updating tests to skip bitwise shift if numpy major version < 2

removed whitespace 659

keep non-bitwise tests for numpy < 2.0.0

more readable edit
2024-09-05 14:24:11 -07:00
Yash Katariya
a144eb234b Add compute_on_context_manager to thread local jit state. This is to avoid getting false cache hits
PiperOrigin-RevId: 671507042
2024-09-05 14:16:13 -07:00
Yash Katariya
4c8bed9270 Don't add a sharding property to ShapedArray if sharding_in_types flag is not switched on.
PiperOrigin-RevId: 671475186
2024-09-05 12:48:10 -07:00
Frederik Wilde
d08b68996a
Update jax-primitives.md 2024-09-05 21:29:56 +02:00
jax authors
0cfb9ac35a Merge pull request #23458 from dfm:ffi-layouts
PiperOrigin-RevId: 671465163
2024-09-05 12:19:07 -07:00
Frederik Wilde
be4383b3f9 Typo 2024-09-05 20:47:40 +02:00
Dan Foreman-Mackey
86f48a85b4 Add support for the DeviceLocalLayout API when lowering FFI calls.
This PR updates the FFI lowering rule to support a DeviceLoweringLayout
object as input when specifying the input and output layouts. For now,
this just converts the DLL object to its appropriate list of
minor-to-major integers because that's what the custom call op expects.
2024-09-05 14:30:06 -04:00