Peter Hawkins
7f4ef63cd8
Run pyupgrade --py310-plus
.
...
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Sergei Lebedev
2bb80d540c
Removed unused which_linear= param from pallas_call_p
...
As far as I can tell, it was threaded through everywhere, but never actually used.
PiperOrigin-RevId: 644457293
2024-06-18 11:31:54 -07:00
jax authors
c8cdf303fb
Merge pull request #20761 from superbobry:config
...
PiperOrigin-RevId: 644435642
2024-06-18 10:33:44 -07:00
Chris Jones
8ce0b55c86
[jax:pallas] Fix Pallas kernel batching rule where an input is aliased with an output and the input is batched on a non-zero axis.
...
PiperOrigin-RevId: 644348136
2024-06-18 05:29:43 -07:00
Sergei Lebedev
ce0d9e9b9f
Changed the naming of internal config APIs
...
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.
The renames are
* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
2024-06-18 11:48:57 +01:00
Justin Fu
fb68f3449b
[Pallas] Add checkify support for pallas_call in interpret mode.
...
PiperOrigin-RevId: 644181742
2024-06-17 17:15:42 -07:00
Justin Fu
1d77720e9a
[Pallas] Add initial DMA interpret mode rules. Currently this only supports LOGICAL device ids with one sharding axis.
...
PiperOrigin-RevId: 644171210
2024-06-17 16:36:21 -07:00
Sharad Vikram
9499de4358
[Pallas] Make num_programs return an int if the grid is not dynamic
...
PiperOrigin-RevId: 644149441
2024-06-17 15:18:40 -07:00
jax authors
c9b23f0f5a
Merge pull request #21912 from superbobry:pallas
...
PiperOrigin-RevId: 644093480
2024-06-17 12:16:49 -07:00
Sergei Lebedev
01f182e772
Use `unitialized_value
` for allocating outputs for interpreted Pallas kernels
...
PiperOrigin-RevId: 644057616
2024-06-17 10:34:38 -07:00
Sergei Lebedev
550862f8c1
Added some docs to `_hoist_consts_to_refs
`
...
I also restructured the implementation slightly, because most list allocations
were in fact unnecessary.
2024-06-17 15:33:05 +01:00
Sergei Lebedev
f67f2e06ce
Fixed a `ValueError
` when a Pallas GPU kernel closed over array constants
...
The fix idea is based on the investigation by @zhixuan-lin in #21557 .
PiperOrigin-RevId: 643965836
2024-06-17 05:05:01 -07:00
Sharad Vikram
e12656002f
[Pallas] Don't actually vmap if we're vmapping over axis size 1
...
PiperOrigin-RevId: 643209848
2024-06-13 20:24:43 -07:00
Sergei Lebedev
2466ae3e93
Added docstrings to pl.num_programs() and pl.program_id()
2024-06-13 21:57:52 +01:00
George Necula
97db0e758d
[pallas] Add support for cross-platform lowering
...
When implementing this I have discovered that the
multi-platform lowering support does not handle the case when
the lowering rule for a platform invoke tracing (via `mlir.lower_fun`)
and that tracing encounters a primitive that has lowering rules
only for a particular platform. To support this, I have added
the `LoweringRuleContext.platforms` to override
`ModuleContext.platforms` with a potentially narrower set
of lowering platforms. Added a test for this scenario.
2024-06-12 08:48:58 +02:00
Sergei Lebedev
70f6ab3128
Updated the type annotations of *_spec= parameters of pl.pallas_call
...
The previous type did not work for nested pytrees and for some reason neither
pytype nor mypy flagged that.
I also re-enabled type checking for most pallas/*.py files.
2024-06-11 12:22:00 +01:00
jax authors
7d913f763a
Merge pull request #21298 from oliverdutton:pallas_interpreter_indexing_fix
...
PiperOrigin-RevId: 641325047
2024-06-07 12:29:31 -07:00
jax authors
621814bd7d
Add loop-based vmap lowering for pallas calls
...
Loop-based vmap is used for cases in which a pipeline-based vmap is currently not feasible:
* Dynamic grid dimensions
* Batched scalar prefetch arguments
PiperOrigin-RevId: 640530524
2024-06-05 08:15:27 -07:00
Oliver Dutton
fa29f1865b
[pallas] align interpreter load/store for masked OOB slicing
2024-06-04 23:19:55 +01:00
Sergei Lebedev
fe35acd413
Fixed a typo in _pallas_call_impl
...
PiperOrigin-RevId: 639747084
2024-06-03 05:33:47 -07:00
Sergei Lebedev
daa99025b9
Updated the JVP rule for pallas_call_p to propagate new invar indices to effects
...
Prior to this change some of the tests in PallasTest were failing under
JAX_ENABLE_CHECKS=1, because the effects in the JVP jaxpr did not type check.
PiperOrigin-RevId: 638652928
2024-05-30 07:58:59 -07:00
Sergei Lebedev
f5617d7323
Removed noop # type: ignore comments
...
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
e2918ca138
Added a very rough sketch of Mosaic GPU lowering for Pallas
...
Almost nothing is supported, including
* PyTree inputs/outputs
* indexers
* non-trivial grids
* block specs
* any primitives beyond the ones added here
* etc etc
PiperOrigin-RevId: 633713366
2024-05-14 14:48:09 -07:00
Justin Fu
1e48adc698
[Pallas] Pad input/outputs in interpret mode to fix errors in OOB memory accesses.
...
PiperOrigin-RevId: 633283991
2024-05-13 11:50:21 -07:00
Justin Fu
5d2e8615af
Reverts 7844bac5d220b41253495cacf719f61905f46925
...
PiperOrigin-RevId: 629123629
2024-04-29 11:13:43 -07:00
Justin Fu
7844bac5d2
Add proper handling of OOB array accesses in pallas interpret mode.
...
PiperOrigin-RevId: 628202600
2024-04-25 15:05:52 -07:00
Sergei Lebedev
a205c9120a
pallas_call now has only one way to pass compiler_params=
...
Previously, it was possible to do
pallas_call(..., foo=42)
and also
pallas_call(..., compiler_params=dict(foo=42))
PiperOrigin-RevId: 623277572
2024-04-09 14:23:20 -07:00
Sergei Lebedev
f74f4ed48b
Removed unnecessary BUILD dependencies from :ops_test
...
I also re-added the accidentally removed JAX_TRITON_COMPILE_VIA_XLA variable
to :pallas_test.
PiperOrigin-RevId: 621299158
2024-04-02 14:36:41 -07:00
Sergei Lebedev
2ee4c0f644
Added installation instructions to the error in _pallas_call_lowering
...
PiperOrigin-RevId: 621168804
2024-04-02 07:36:28 -07:00
Sergei Lebedev
16b3f00e42
Register GPU/TPU lowering for pallas_call_p lazily
...
Prior to this change we had to import jax.experimental.pallas.{gpu,tpu} in
jax.experimental.pallas only to get the lowering rules registered.
PiperOrigin-RevId: 620957622
2024-04-01 14:40:33 -07:00
Sharad Vikram
30973a9474
[Pallas] Pass in compiler params via explicit compiler_params argument instead of passing via **kwargs
...
This is a change that makes the API a bit more intuitive and avoids footguns like accidentally passing in `in_spec` instead of `in_specs` because previously kwargs that weren't used by any downstream lowering would be ignored and users would get weird errors as a result.
This change doesn't deprecate the old way of passing in compiler params but it will be deprecated soon after this.
PiperOrigin-RevId: 613239439
2024-03-06 09:16:22 -08:00
Adam Paszke
516b75dc24
Add pl.num_programs to make it easier to query the dynamic grid size
...
The new function can be used both in the kernel body and in the block specs.
PiperOrigin-RevId: 610391119
2024-02-26 06:39:03 -08:00
Adam Paszke
0b04ff1241
Add support for non-disjoint windows in Pallas/Mosaic
...
This enables the index function to select a window starting from
any element. However, the Mosaic implementation still requires it
to be at least tile aligned.
PiperOrigin-RevId: 605254616
2024-02-08 02:48:43 -08:00
Sharad Vikram
a7a6b40b55
[Pallas] Add interpret mode support for dynamic grid
...
PiperOrigin-RevId: 603818776
2024-02-02 16:37:47 -08:00
Sharad Vikram
a41385c860
[Pallas/TPU] Allow 1-sized batch dim in vmap of dynamic grid
...
PiperOrigin-RevId: 603518847
2024-02-01 16:43:18 -08:00
Sharad Vikram
d76705da94
[Pallas/TPU] Add vmap support for dynamic grid
...
PiperOrigin-RevId: 603502393
2024-02-01 15:39:32 -08:00
Adam Paszke
21070b24d7
Add support for dynamically computed grid bounds in Pallas kernels.
...
PiperOrigin-RevId: 603389883
2024-02-01 09:15:19 -08:00
Matthew Johnson
4a8babb101
integrate attrs in jax.jit
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2024-01-27 17:44:43 -08:00
Sharad Vikram
037bc5edbc
[Pallas] Add support for trivial vmap of scalar prefetch
...
PiperOrigin-RevId: 600898742
2024-01-23 13:57:12 -08:00
Sharad Vikram
3990a0571e
[Pallas/TPU] Add pallas call tests
...
PiperOrigin-RevId: 599681509
2024-01-18 18:16:54 -08:00
Sharad Vikram
edef6d17fa
[Pallas] Use AbstractMemoryRefs for all Pallas tracing.
...
This simplifies a lot of the Pallas tracing and lowering logic because memory spaces are passed through the Ref type instead of through the BlockMapping.
PiperOrigin-RevId: 599670626
2024-01-18 17:20:11 -08:00
Sergei Lebedev
36f6b52e42
Upgrade most .py sources to 3.9
...
This commit was generated by running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
jax authors
1189d61bc0
[Pallas] Fix batching rule for kernels with scratch inputs
...
Scratch inputs do not need a batching dimension.
PiperOrigin-RevId: 588921137
2023-12-07 15:10:12 -08:00
Sharad Vikram
3403631d99
[Pallas/Mosaic] Fixes for interpret mode on TPU
...
* scratch space support
* trivial lowering for trace_start/end
PiperOrigin-RevId: 588482689
2023-12-06 11:03:05 -08:00
Sharad Vikram
6299ff8023
[Pallas] Allow interpret mode on non-CPU backends if backend-specific lowerings are not registered
...
PiperOrigin-RevId: 583132671
2023-11-16 12:46:43 -08:00
Neil Girdhar
3c920c0120
Switch from flake8 to Ruff
2023-11-15 22:35:52 -05:00
Sharad Vikram
8fbcfce2dd
[Pallas] Enable interpreter mode as default lowering for CPU
...
PiperOrigin-RevId: 580700740
2023-11-08 16:35:31 -08:00
Sharad Vikram
fdc2f9cab7
[Pallas] Add async_copy_to
and async_remote_copy_to
for doing DMAs.
...
Also add `.at` view syntax for `Ref`s
PiperOrigin-RevId: 565478936
2023-09-14 14:32:08 -07:00
Sharad Vikram
cb114f247a
[Pallas] Refactor memory space handling
...
PiperOrigin-RevId: 563586933
2023-09-07 17:08:57 -07:00
Sharad Vikram
d0c4c9b3fe
[Pallas] Add support for scoped allocations to Pallas TPU
...
PiperOrigin-RevId: 563580548
2023-09-07 16:41:01 -07:00