12969 Commits

Author SHA1 Message Date
Meekail Zain
79005c1e69 Deprecate newshape argument of jnp.reshape 2024-05-09 21:02:07 +00:00
jax authors
1a7a2aa555 Merge pull request #21106 from jakevdp:linalg-precision
PiperOrigin-RevId: 632217396
2024-05-09 11:33:54 -07:00
jax authors
0c4d81c8cd Merge pull request #21138 from jakevdp:einsum-doc
PiperOrigin-RevId: 632198113
2024-05-09 10:38:51 -07:00
Jake VanderPlas
2ddb7ff801 jnp.linalg: add precision & preferred_element_type to dot-like functions 2024-05-09 10:06:51 -07:00
Yash Katariya
671fb1265d Update the multi-process note in pjit's docstring
PiperOrigin-RevId: 632160561
2024-05-09 08:38:29 -07:00
Jake VanderPlas
5edfaa6782 jnp.linalg.multi_dot: use optimize='auto' 2024-05-09 06:47:30 -07:00
Yash Katariya
96f888bcfe Reverts 1956ff7d7b73794012fece2d8452e097196587fc
PiperOrigin-RevId: 631974751
2024-05-08 17:23:13 -07:00
Jake VanderPlas
e8700523d3 jnp.einsum: improve documentation 2024-05-08 14:30:59 -07:00
jax authors
962f084543 Merge pull request #21137 from superbobry:pallas
PiperOrigin-RevId: 631923082
2024-05-08 14:20:10 -07:00
jax authors
65d4c688e0 Generic reduce window jvp
The problem is that we want to generically jvp and tranpose over any reduction_fn. Jax already handles some of the hard parts for us, namely, ensuring that the user provided fn is jax capturable. All that is left then, is to write a jvp and tranpose fn that utilize the jax utils correctly.

However, this is not so straightforward because in order to get the transpose of a reduction window, we need to be able to use both the tangents and primals. The current reduce_fn operates on (x, y) - but we actually need is, under jvp, to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down notions of a jvp-specific reduction_fn (captured via the usual machinery of as_fun `as_fun(jvp_fn(closed(user_reduction_jaxp)))`).

For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together - this means a good deal of changes to safety checks and assurances downstream (as well as unpacking) as the shape of the operand has changed from [K,...Kt] to [K, ...Kt, 2] where the last dim is the stacked primal and tangent values.

In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition

PiperOrigin-RevId: 631916764
2024-05-08 14:00:39 -07:00
Sergei Lebedev
4b62425b42 Renamed is_device_gpu_at_least to is_cuda_compute_capability_at_least
This makes it clear that the predicate is only supposed to be used for NVidia
GPUs at the moment.
2024-05-08 21:41:50 +01:00
Sergei Lebedev
575ba942e0 Removed get_compute_capability from jax.experimental.pallas.gpu
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Anselm Levskaya
f768cb74b9 Refactor pipeline emitter API.
Building on enrique's work, this CL refactors the emit_pipeline abstraction:
1) factors out the VMEM double-buffering bookkeeping into a helper class.
2) concentrate the intricate copy/wait scheduling logic into one place inside a scheduler helper
while allowing manual overrides, callbacks don't control scheduling anymore, rather we have
explicit loop scheduling.
3) minimize callbacks and simplify the "defaults" for fusing pipelines together.

Examples of fully overlapped versions of latency- and throughput- optimized AG-matmuls and
matmul-RSs are included in new tests.

PiperOrigin-RevId: 631865641
2024-05-08 11:22:47 -07:00
Parker Schuh
e652d62b85 Cleanup second registration of custom_partitioning callbacks now that
the jaxlib version has been bumped.

PiperOrigin-RevId: 631852273
2024-05-08 10:45:39 -07:00
Sergei Lebedev
0feeaa5999 Removed stale version guards and try/except blocks from Pallas GPU
They are unnecessary now that the minimum jaxlib version is 0.4.27.
2024-05-08 17:05:45 +01:00
jax authors
11da3df238 Merge pull request #21096 from gspschmid:gschmid/sourcemaps
PiperOrigin-RevId: 631769572
2024-05-08 05:44:08 -07:00
Georg Stefan Schmid
b0b322d486 Add sourcemap module to generate TC39-compliant source maps 2024-05-08 01:54:25 -07:00
jax authors
e9e4e5341e [jax:mosaic-gpu] FragmentedArray can do tiled load.
PiperOrigin-RevId: 631611060
2024-05-07 18:13:55 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
jax authors
78e10eea79 Merge pull request #21115 from jakevdp:multi-dot
PiperOrigin-RevId: 631545161
2024-05-07 14:14:02 -07:00
jax authors
2b3251ed46 Merge pull request #21092 from jakevdp:dot-doc
PiperOrigin-RevId: 631536980
2024-05-07 13:51:13 -07:00
Jake VanderPlas
09810be0cd Implement jnp.linalg.multi_dot using opt_einsum 2024-05-07 13:40:25 -07:00
jax authors
7153738551 Merge pull request #21104 from superbobry:triton-fixes
PiperOrigin-RevId: 631521767
2024-05-07 13:06:33 -07:00
Yash Katariya
5031a1ddc4 Finish jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631486157
2024-05-07 11:14:09 -07:00
Jake VanderPlas
034d843e8c jax.numpy: better docs for matmul-like functions 2024-05-07 11:01:54 -07:00
jax authors
3dd8af93b4 Merge pull request #21090 from jakevdp:extract
PiperOrigin-RevId: 631480474
2024-05-07 10:59:10 -07:00
jax authors
daab7a0329 Handle ellipsis ... in _attempt_rewriting_take_via_slice.
Previously `model['some_array'][:,0,0,:]` would generate a `slice`, while `model['some_array'][...,0,0,:]` would generate a `gather`. Now both of these generate `slice` eqns.

PiperOrigin-RevId: 631469837
2024-05-07 10:30:08 -07:00
Sergei Lebedev
2ca3264052 Updated the Pallas GPU lowering to work with older jaxlib versions
Triton changed the signatures of LoadOp and DotOp upstream, and the lowering
code was not ready to handle both old and new signatures.
2024-05-07 17:30:52 +01:00
Jake VanderPlas
0c26a34df2 Add optional size argument to jnp.compress & jnp.extract. 2024-05-07 08:47:34 -07:00
Jake VanderPlas
9b79f6520a Remove deprecated kind argument from jnp.sort and jnp.argsort.
PiperOrigin-RevId: 631429900
2024-05-07 08:18:59 -07:00
jax authors
500da57e91 Merge pull request #21077 from merrymercy:patch-1
PiperOrigin-RevId: 631409738
2024-05-07 07:07:04 -07:00
Adam Paszke
326adc01a5 [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args
PiperOrigin-RevId: 631404097
2024-05-07 06:36:36 -07:00
jax authors
cb0c49850c Merge pull request #21081 from hawkinsp:sourcemap
PiperOrigin-RevId: 631236806
2024-05-06 17:33:12 -07:00
jax authors
4de346485d Fix that the insufficient output HBM buffer init would cause the <unk> token generated for quantized int8 model.
PiperOrigin-RevId: 631235764
2024-05-06 17:28:13 -07:00
jax authors
f6d88525a8 Merge pull request #20327 from selamw1:add_examples
PiperOrigin-RevId: 631186425
2024-05-06 14:30:06 -07:00
Selam Waktola
9caf59d68b improve documentation for ix_ 2024-05-06 13:43:55 -07:00
jax authors
3d3cb0bd2c Merge pull request #20842 from Micky774:array-api-default-promotion
PiperOrigin-RevId: 631168892
2024-05-06 13:39:56 -07:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
Jake VanderPlas
4a363156b9 jnp.linalg tensorinv & tensorsolve: improve implementation & docs 2024-05-06 11:08:36 -07:00
jax authors
7e9ef1e4d2 Merge pull request #21078 from jakevdp:numpy-linalg-doc
PiperOrigin-RevId: 631112228
2024-05-06 10:44:42 -07:00
jax authors
fb65ba4adf Add a config for using Clang on Windows.
PiperOrigin-RevId: 631112031
2024-05-06 10:39:28 -07:00
Jake VanderPlas
40b2d4852e jnp.linalg: improve API documentation 2024-05-06 09:22:59 -07:00
Meekail Zain
34c5163fd2 Refactored common upcast for integral-type accumulators 2024-05-06 15:13:10 +00:00
Lianmin Zheng
0eed28a010
Fix a typo in jax.jit docstring 2024-05-06 04:59:23 -07:00
jax authors
7681493760 Don't create temp directory when module is getting imported.
PiperOrigin-RevId: 630958402
2024-05-06 00:58:45 -07:00
jax authors
1b804a7720 Merge pull request #21056 from mattjj:vmap-grad-remat-shmap-bug
PiperOrigin-RevId: 630555588
2024-05-03 19:06:46 -07:00
Matthew Johnson
7a87010f84 [shard_map] better fix for spmd_axis_name issues with shmap residuals
The fix in #21032 was not correct because it assumed that the set of all mesh
axis names appearing in in_specs was an upper bound on the set of mesh axes
over which residuals could be device-varying. But collectives can introduce
device variance! So it's not an upper bound.

We track device variance when check_rep=True, but often people set
check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our
device variance tracking would be limiting. That may be a decent long term
solution, if we can make it easy to annotate pallas_calls with device variance
information. But it's not a great short term one to unblock things.

So instead I temporrarily went with context sensitivity: instead of making
residuals sharded over all mesh.axis_names (as we did before these patches), we
make them sharded over all mesh axis names _excluding_ any spmd_axis_names in
our dynamic context (by looking at the traces in our trace stack). It's illegal
to mention any spmd_axis_names in collectives (indeed anywhere in the body of
the function being vmapped), but I don't think we check it.

TODO(mattjj): add more testing (maybe in follow-ups)
2024-05-04 01:31:15 +00:00
Jake VanderPlas
e95173a4d3 Require arraylike input for several jax.numpy functions
PiperOrigin-RevId: 630532821
2024-05-03 16:55:10 -07:00
Jake VanderPlas
88318e60d2 jnp.delete: better docs 2024-05-03 14:41:06 -07:00
Jake VanderPlas
ff67e51e7e Remove last scipy imports 2024-05-03 10:20:05 -07:00