jax authors
ce5f9a6da9
Merge pull request #22530 from superbobry:maint
...
PiperOrigin-RevId: 654881710
2024-07-22 13:46:58 -07:00
jax authors
20231e1d98
Update XLA dependency to use revision
...
5e2fc1f94a
.
PiperOrigin-RevId: 654878074
2024-07-22 13:36:16 -07:00
jax authors
57d8dde65d
Merge pull request #22571 from jakevdp:dct-norm
...
PiperOrigin-RevId: 654862753
2024-07-22 12:56:41 -07:00
Robert Dyro
eb3f538c7e
Correctly counting cache miss logs
...
PiperOrigin-RevId: 654860872
2024-07-22 12:53:09 -07:00
jax authors
71422c6272
Merge pull request #22512 from dfm:gh22501
...
PiperOrigin-RevId: 654858612
2024-07-22 12:49:07 -07:00
jax authors
5ecd1965d1
Merge pull request #22544 from mattjj:19175
...
PiperOrigin-RevId: 654858318
2024-07-22 12:45:22 -07:00
jax authors
75bbf4019d
Merge pull request #22514 from dfm:gh22493
...
PiperOrigin-RevId: 654858304
2024-07-22 12:41:18 -07:00
Vladimir Belitskiy
a1f2a50cfa
Increase shard count under TPU for //third_party/py/jax/tests:lax_numpy_test.
...
PiperOrigin-RevId: 654847718
2024-07-22 12:08:04 -07:00
Dan Foreman-Mackey
991187aaa8
Fix dtype canonicalization in jnp.indices
.
...
`jnp.indices` was hard coded to default to `dtype = np.int32`, but it
should default to the canonicalized `np.int64`.
Fixes https://github.com/google/jax/issues/22501
2024-07-22 15:02:48 -04:00
Dan Foreman-Mackey
705eed3388
Fixing dtype canonicalization in sharp edges tutorial.
...
As reported in https://github.com/google/jax/issues/22493 , the sharp
edges tutorial doesn't seem to actually enable x64 when it says it does.
Fixes https://github.com/google/jax/issues/22493
2024-07-22 15:02:02 -04:00
Vladimir Belitskiy
d7b821b04d
The newly added test class is failing, and blocking presubmits
...
Reverts 09523adf7dd5b5b1099780785a73a12bf6664c53
PiperOrigin-RevId: 654842341
2024-07-22 11:52:24 -07:00
Jake VanderPlas
2efd1ec011
jax.scipy.fft.dct: implement & test norm='backward'
2024-07-22 11:18:35 -07:00
jax authors
0d7531b4f1
Merge pull request #22567 from jakevdp:fft-norm-validation
...
PiperOrigin-RevId: 654828825
2024-07-22 11:15:43 -07:00
Jake VanderPlas
326559ca47
jax.scipy.fft: error for unsupported norm argument
2024-07-22 10:32:03 -07:00
Matthew Johnson
f7cef92ed7
[shard_map] fix psum rewrite rule's pbroadcast logic
...
We want to allow `psum(x, axes)` regardless of how `x` is replicated. That
means when we rewrite it into the stricter `psum2`, which can only sum over
non-replicated axes, we need to insert a pbroadcast like this:
```
psum(x, axes) == psum2(pbroadcast(x, axes & input_replicated_axes), axes)
```
In words, we need to first `pbroadcast` over all those axes we're about to sum
over but that the input is already replicated over.
We write it as a comprehension over mesh.axis_names, rather than just that set
intersection, just to ensure deterministic ordering, since Python set
operations are not guaranteed to be deterministic. There are other places in
the file where we don't ensure deterministic ordering; someday I'll come back
and fix those.
fixes #19175
2024-07-22 17:16:30 +00:00
jax authors
db05734041
Merge pull request #22515 from dfm:pre-commit-filter
...
PiperOrigin-RevId: 654799238
2024-07-22 10:07:01 -07:00
jax authors
83f0c979fa
Merge pull request #22456 from google:dependabot/github_actions/actions/setup-python-5.1.1
...
PiperOrigin-RevId: 654787031
2024-07-22 09:36:33 -07:00
jax authors
48a0f9c3f0
Merge pull request #22540 from ppwwyyxx:patch-1
...
PiperOrigin-RevId: 654747206
2024-07-22 07:50:58 -07:00
jax authors
8ec0cc2c82
Merge pull request #22547 from sameer-dudeja:dev-fix-broken-link
...
PiperOrigin-RevId: 654724736
2024-07-22 06:43:28 -07:00
Sergei Lebedev
7157839853
Fixed pl.BlockSpec argument ordering in the Pallas TPU matmul tutorial
2024-07-22 12:28:40 +01:00
jax authors
433f66ad02
Merge pull request #22550 from gnecula:pallas_consts
...
PiperOrigin-RevId: 654686876
2024-07-22 04:04:08 -07:00
George Necula
b7105ccd19
[pallas] Fix the handling of captured consts
...
There was an attempt to handle consts captured by the kernel,
but it was incomplete and with errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.
I expanded the tests: one in pallas_tests.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).
The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.
I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
2024-07-22 13:34:32 +03:00
Ram Rachum
00bd6ddf95
Show cache_key when logging compilation cache hits/misses
2024-07-22 11:56:02 +03:00
Sergei Lebedev
4fa93cff35
Documented a few more Pallas APIs and added them to the API docs
2024-07-21 22:32:51 +01:00
jax authors
b3469a61d1
Update XLA dependency to use revision
...
1d9074cac0
.
PiperOrigin-RevId: 654531486
2024-07-21 12:55:15 -07:00
jax authors
946145e752
Merge pull request #22543 from mattjj:20162
...
PiperOrigin-RevId: 654523567
2024-07-21 11:48:01 -07:00
jax authors
a8ef5eaf11
Merge pull request #22519 from cool-RR:patch-profiler
...
PiperOrigin-RevId: 654522898
2024-07-21 11:43:27 -07:00
jax authors
2a2aa612be
Merge pull request #22541 from mattjj:21343
...
PiperOrigin-RevId: 654522436
2024-07-21 11:37:39 -07:00
jax authors
09523adf7d
Merge pull request #22442 from gnecula:pallas_interpret_splash
...
PiperOrigin-RevId: 654520725
2024-07-21 11:23:43 -07:00
George Necula
28d4caefb0
[pallas] Thread the interpret= parmeter to pallas_call.
...
This aligns more tests to use the same testing structure: tests
can run on CPU (in interpreter mode) or TPU/GPU, and for each
test class MyTest we have a sibling test class MyInterpreterTest.
This is useful when developing on a machine without accelerators.
2024-07-21 20:17:12 +03:00
Sameer Dudeja
993a1e74ba
Fix broken export links
2024-07-21 11:37:01 +05:30
jax authors
9632a2d1a8
Add jvp and transpose rule for ragged dot.
...
The numerical accuracy test is perfect against the reference implementation, and somewhat loose against the alt grad implementation used for testing.
PiperOrigin-RevId: 654381378
2024-07-20 17:56:59 -07:00
Tomás Longeri
f4b09234a0
[Mosaic TPU] Set in_bounds for transfer_read used in replicated loads
...
This is in preparation for integrating changes from MLIR:
2ee5586ac7 (diff-3cbcc8f6c740f2d6e16f5a0c19daf4bb8224ad92d9e430fc10c935587a67dcce)
Also don't pass in `padding` since there is a builder that uses `padding` of zero as default.
PiperOrigin-RevId: 654370142
2024-07-20 16:26:18 -07:00
jax authors
ff36ea5de3
Merge pull request #21567 from mattjj:skip-invar-origin-msg-if-malformed
...
PiperOrigin-RevId: 654356735
2024-07-20 14:30:17 -07:00
jax authors
da9b24f833
Update XLA dependency to use revision
...
a01a3db2e7
.
PiperOrigin-RevId: 654347848
2024-07-20 13:19:01 -07:00
Matthew Johnson
173794bcef
[shard_map] shard_map check_rep=True rules for custom_linear_solve
...
fixes #21855
2024-07-20 18:06:55 +00:00
Matthew Johnson
c5fd3b0ced
skip _origin_msg invar debug info if invar_pos/arg_info is malformed
...
cf #20397 , #20396
2024-07-20 17:22:49 +00:00
Matthew Johnson
83dfed1c02
make core.as_named_shape treat int
like tuple[int]
...
fixes #21343
2024-07-20 17:14:38 +00:00
Yash Katariya
82c608674a
Fix the efficient reshard path in device_put when you want to go from 1 mesh to another with different device assignments.
...
The old code lead to the wrong answer as shown in the test added in this PR.
PiperOrigin-RevId: 654318251
2024-07-20 09:09:05 -07:00
Yuxin Wu
cbb3d99279
Fix docstring of jnp.nonzero
2024-07-19 18:39:12 -07:00
jax authors
e9c40467d7
Merge pull request #22526 from sharadmv:pallas-matmul-docs
...
PiperOrigin-RevId: 654112294
2024-07-19 13:44:48 -07:00
Ram Rachum
799c79ca94
Support pathlib.Path
input to start_trace
2024-07-19 23:41:02 +03:00
jax authors
3dd0d74f2b
Update XLA dependency to use revision
...
e608ef43d7
.
PiperOrigin-RevId: 654089480
2024-07-19 12:28:12 -07:00
Jevin Jiang
faf89ab0da
[XLA:Mosaic] Simplify the logic in converting dynamic roll to Log(N) static ops.
...
PiperOrigin-RevId: 654065156
2024-07-19 11:11:22 -07:00
jax authors
d2d2eaebc2
Merge pull request #22529 from dfm:ffi-export-registration
...
PiperOrigin-RevId: 654058232
2024-07-19 10:54:08 -07:00
Enrique Piqueras
ca6d6341f5
look mum I can still can edit Pallas, AGI! Optimize pipeline emitter scheduler by omitting copies of accumulators during iteration in which they are going to be zeroed out.
...
Also, add some clarifying comments and set fixed RHS schedules of matmul reduce scatter implementations.
PiperOrigin-RevId: 654015498
2024-07-19 08:30:14 -07:00
Adam Paszke
7d6e76f8df
[Mosaic GPU] Add reverse arithmetic functions for FragmentedArrary for convenience
...
PiperOrigin-RevId: 654011040
2024-07-19 08:13:18 -07:00
Adam Paszke
d621eb1e31
[Mosaic GPU] Fix a tiny bug in the profiler
...
It would previously reallocate the event name that was interned as 0.
PiperOrigin-RevId: 654010678
2024-07-19 08:09:14 -07:00
Tom Cobley
c27f3cf187
Allow implementations of XLA:CPU collectives to be passed directly into make_cpu_client
.
...
PiperOrigin-RevId: 653998378
2024-07-19 07:14:45 -07:00
Dan Foreman-Mackey
b308c64936
Export jaxlib.xla_client.register_custom_call_target
as jax.extend.ffi.register_ffi_target
.
...
This means that users of the FFI interface won't need to directly
interact with `jaxlib.xla_client` at all.
I've expanded the doctring a little and changed one default: the default
`api_version` is `1` instead of `0` to be consistent with the new name.
2024-07-19 08:12:25 -04:00