21020 Commits

Author SHA1 Message Date
Sergei Lebedev
befa10c1d7 Slightly rearranged NDIndexer.from_indices_shape and added missing tests 2024-06-03 12:33:14 +01:00
Tomás Longeri
7c471e2533 [Pallas] Clean up MLIR compatibility check
PiperOrigin-RevId: 639718864
2024-06-03 03:33:50 -07:00
jax authors
a5e7613092 Update XLA dependency to use revision
e351b8186b.

PiperOrigin-RevId: 639620308
2024-06-02 19:02:06 -07:00
jax authors
e81c82605f Update XLA dependency to use revision
4d355f39c0.

PiperOrigin-RevId: 639431808
2024-06-01 19:43:25 -07:00
Dateng Lin
20379c636d Fixed the logging due to a recent change.
PiperOrigin-RevId: 639392396
2024-06-01 14:28:55 -07:00
George Necula
be1e40dc2e Copybara import of the project:
--
f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8 by George Necula <gcnecula@gmail.com>:

[export] Fix

A user reported an error when trying to export a function
that has a "lower" attribute (to impersonate a jitted function)
but does not have a "__name__" attribute.
The solution is to use the default name "<unnamed function>".

While I was at it I have added a `util.fun_name` to get
the name of a Callable, and I use it in several places.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21572 from gnecula:exp_fix_name f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8
PiperOrigin-RevId: 639236990
2024-05-31 20:40:42 -07:00
jax authors
432159a9d3 Update XLA dependency to use revision
ea480cecc0.

PiperOrigin-RevId: 639227230
2024-05-31 19:55:16 -07:00
Zixuan Jiang
dfd7d17c1d [JAX] Use iota_reshape_dims and iota_transpose_perm in pxla, which is more efficient than tile_assignment_devices.
HloSharding V1 -> HloSharding V2.

PiperOrigin-RevId: 639210975
2024-05-31 18:15:00 -07:00
jax authors
1edb94ec46 [XLA][Mosaic] Add support for fp8 matmuls in TPUv5+
Needed a little more backfill for TPU load

PiperOrigin-RevId: 639206243
2024-05-31 17:51:25 -07:00
Parker Schuh
5eb679ee34 Hypothesis isn't compatible with subTest and jax tests sometimes report these
warnings as errors. Disable to allow these tests to pass.

PiperOrigin-RevId: 639189411
2024-05-31 16:35:49 -07:00
jax authors
c8207ee831 Merge pull request #21568 from mattjj:custom-vjp-bwd-produce-list-or-tuple
PiperOrigin-RevId: 639170346
2024-05-31 15:28:09 -07:00
Matthew Johnson
2f4cc6e9cb [custom_vjp] allow bwd rule to produce top-level list (not just tuple) 2024-05-31 21:49:06 +00:00
Yash Katariya
0591620932 Fix copy.deepcopy support for arrays in pinned_host memory.
PiperOrigin-RevId: 639145872
2024-05-31 14:04:02 -07:00
jax authors
4afd32ee32 Merge pull request #21563 from dfm:fix-ffi-test
PiperOrigin-RevId: 639125422
2024-05-31 12:57:02 -07:00
Jevin Jiang
389bf93abf [XLA:Mosaic] Fix infer/apply vector layout rule for terminators (scf::yieldOp, scf::conditionOp).
We should infer layout for each terminator inside its own region and find a compatible layout for a final result if the result is based on terminators from multiple regions like scf::ifOp, scf::whileOp, scf::forOp. If no compatible layout is found, we will fall back to a normalized layout. Finally we also need to ensure the layouts in input, terminator and output are consistent across loops.

PiperOrigin-RevId: 639122434
2024-05-31 12:47:33 -07:00
Dan Foreman-Mackey
690fa1d90c Remove failing ffi test
The FFI headers aren't properly exposed during a bazel build, so these
tests are failing. I'll re-enable them when I get a chance to get that
working properly.
2024-05-31 15:36:33 -04:00
jax authors
d9f07d0350 Merge pull request #21531 from dfm:move-ffi-submodule
PiperOrigin-RevId: 639077602
2024-05-31 10:31:32 -07:00
jax authors
29632fb3dc Merge pull request #21544 from marcoselvi:patch-2
PiperOrigin-RevId: 639077520
2024-05-31 10:28:24 -07:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Sergei Lebedev
d6a84cc5f3 Pallas GPU no longer assumes that all slices have stride 1
Fixes #20895.

PiperOrigin-RevId: 639031975
2024-05-31 07:44:11 -07:00
George Necula
d26819d6cd [jax2tf] Bump tolerance for FFT tests on CPU.
PiperOrigin-RevId: 639029683
2024-05-31 07:33:44 -07:00
Adam Paszke
d01496a09a [Mosaic GPU] Restore the PTX/PTXAS/SASS dump flags
They're very useful while prototyping the kernels.

PiperOrigin-RevId: 639027506
2024-05-31 07:27:36 -07:00
Emilio Cota
ccfe01c6bc Test the loading of parameters from host memory
PiperOrigin-RevId: 639027224
2024-05-31 07:24:39 -07:00
Christos Perivolaropoulos
8eaea2b13d [Mosaic GPU] Add a simple benchmark.
PiperOrigin-RevId: 639023867
2024-05-31 07:08:28 -07:00
Sergei Lebedev
d2a39bc61b Updated the layer norm implementation in Mosaic GPU tests
jnp.var now needs lax.gt_p, which we don't currently support.

PiperOrigin-RevId: 639011383
2024-05-31 06:11:48 -07:00
Adam Paszke
41685db0cb Wrap wgmma.fence in llvm.inline_asm to constrain LLVM scheduling
wgmma.fence.aligned is a weird PTX instruction in that it is one of the
few (if not the only one?) that disallows sinking ALU ops on registers
below it. But, LLVM assumes that all operations on registers are pure
and will often happily sink them below this instruction. By wrapping
the fence in an inline assembly block that simply copies over the
registers, we can force LLVM to construct the registers before the fence.
And ptxas should be able to eliminate the unnecessary register copies.

PiperOrigin-RevId: 639011288
2024-05-31 06:08:29 -07:00
jax authors
33c7c8d30e Disable PgleTest.testPGLEProfilerGetFDOProfile.
This is a new test that is failing in CI.

PiperOrigin-RevId: 639005238
2024-05-31 05:40:04 -07:00
Adam Paszke
3fb6817ffd Decrease tile sizes in Pallas tests
Otherwise ptxas might fail at register allocation due to WGMMA having a large
footprint.

PiperOrigin-RevId: 639003292
2024-05-31 05:29:54 -07:00
Chris Jones
9a572d23d2 [mosaic:gpu] Minor cleanup in matmul example.
PiperOrigin-RevId: 638996405
2024-05-31 04:57:23 -07:00
Chris Jones
3c64066097 [mosaic:gpu] Relax constraint on stages in matmul example.
PiperOrigin-RevId: 638993045
2024-05-31 04:40:14 -07:00
Marco Selvi
7a8bcf6ee5
Fix parenthesis in "Gradients contain NaN where using where" 2024-05-31 12:03:39 +01:00
jax authors
8deed95c7f Merge pull request #21430 from Cjkkkk:remove_is_training
PiperOrigin-RevId: 638974548
2024-05-31 03:14:39 -07:00
jax authors
38b34b2378 Update XLA dependency to use revision
7f0ee78fca.

PiperOrigin-RevId: 638871244
2024-05-30 19:49:27 -07:00
Yash Katariya
54eaef9f62 Make sure that the sharding and unconstrained_dims in with_sharding_constraint are correct when wsc is vmapped.
In other words, if unconstrained_dims is specified, then the sharding should also contain P.UNCONSTRAINED under vmap.

PiperOrigin-RevId: 638843222
2024-05-30 17:44:51 -07:00
Yash Katariya
7aaf29bf82 Make UNCONSTRAINED's __str__ and __repr__ so it prints nicely in all cases
Before:

```
NamedSharding(mesh=Mesh('x': 2, 'y': 1), spec=PartitionSpec(<jax._src.partition_spec._UnconstrainedPartitionSingleton object at 0x13047cc3b850>, 'x'))
ParsedPartitionSpec(partitions=(None, ('x',)), unsafe_user_spec=PartitionSpec(<jax._src.partition_spec._UnconstrainedPartitionSingleton object at 0x13047cc3b850>, 'x'), sync=2)
```

After:

```
NamedSharding(mesh=Mesh('x': 2, 'y': 1), spec=PartitionSpec(UNCONSTRAINED, 'x'))
ParsedPartitionSpec(partitions=(None, ('x',)), unsafe_user_spec=PartitionSpec(UNCONSTRAINED, 'x'), sync=2)
```

PiperOrigin-RevId: 638842878
2024-05-30 17:41:32 -07:00
Yash Katariya
bfaf0b74e8 Improve the error message when users pass DeviceLocalLayout.AUTO to jax.jit and a jax.Array as an argument.
PiperOrigin-RevId: 638797194
2024-05-30 15:07:01 -07:00
jax authors
bca9882d20 Merge pull request #19538 from mattjj:vjp-mismatch-error-message
PiperOrigin-RevId: 638791209
2024-05-30 14:51:14 -07:00
jax authors
b1d37d1d20 Merge pull request #21071 from mattjj:vmap-spmd-axis-name-errors
PiperOrigin-RevId: 638783978
2024-05-30 14:30:21 -07:00
Matthew Johnson
10d285dea7 fix error message for vjp arguments 2024-05-30 21:22:35 +00:00
Matthew Johnson
3984d822ba add error checks for vmap spmd_axis_name 2024-05-30 20:48:11 +00:00
jax authors
51e743139b Merge pull request #21440 from vfdev-5:add-op-name-dtype-in-testname-lax-numpy-ops-test
PiperOrigin-RevId: 638700042
2024-05-30 10:30:03 -07:00
Chris Jones
8bbb8983b1 [mosaic:gpu] Partially revert previous change as types.py not present in jaxlib.
Reverts a5fc31e42582ca8e2de5dea6936460795ae2c4af

PiperOrigin-RevId: 638665968
2024-05-30 08:43:51 -07:00
Sergei Lebedev
daa99025b9 Updated the JVP rule for pallas_call_p to propagate new invar indices to effects
Prior to this change some of the tests in PallasTest were failing under
JAX_ENABLE_CHECKS=1, because the effects in the JVP jaxpr did not type check.
PiperOrigin-RevId: 638652928
2024-05-30 07:58:59 -07:00
Sergei Lebedev
8729952d82 Added a missing return to MosaicGPUCustomCall
PiperOrigin-RevId: 638627696
2024-05-30 06:13:01 -07:00
Peter Hawkins
f24d2a71bb Disable PgleTest.testAutoPgleWithPersistentCache.
This is a new test that is failing in CI.

PiperOrigin-RevId: 638619001
2024-05-30 05:36:20 -07:00
Peter Hawkins
01b4cb6de0 Bump memory limits for array_test and layout_test on TPU CI.
These use more than our CI's default memory limit (12GB) when run under tsan.

PiperOrigin-RevId: 638618718
2024-05-30 05:33:14 -07:00
Chris Jones
a5fc31e425 [mosaic:gpu] Minor cleanup in FragmentedArray.transfer_tile.
- Remove redundant line.
- Use `ConstantOp.create_index`.
- Use `BoolAttr`.

PiperOrigin-RevId: 638616982
2024-05-30 05:24:38 -07:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Sergei Lebedev
f04800f80d Call setUp only if the test is not skipped in Pallas tests
unittest does not call tearDown if setUp raised unittest.SkipTest.

PiperOrigin-RevId: 638565553
2024-05-30 01:29:09 -07:00
Michael Levesque-Dion
9309592ac3 Integrate StableHLO at openxla/stablehlo@c44d9af8
PiperOrigin-RevId: 638559828
2024-05-30 01:04:35 -07:00