Ruturaj4
ce7347a52b
Fix circular import in pallas core file
2025-04-09 19:57:24 -05:00
Dragan Mladjenovic
a8c11ba79e
[pallas:triton] Fix atomic min/max lowering for unsigned integers and float types ( #263 )
2025-03-10 10:39:22 -05:00
Ruturaj Vaidya
fd1e518e44
Fix RNN patch for 35-qa branch ( #245 )
2025-02-27 10:03:47 -06:00
Ruturaj Vaidya
dceb5310fe
[ROCm] Implement RNN support ( #217 )
2025-02-07 11:08:12 -06:00
Peter Hawkins
81991d87c8
JAX release 0.4.35
2024-10-22 15:00:23 -04:00
jax authors
1c6b0a9193
Merge pull request #24465 from jakevdp:fix-mypy
...
PiperOrigin-RevId: 688632024
2024-10-22 11:45:27 -07:00
jax authors
9a2dd19a92
Merge pull request #21524 from andportnoy:aportnoy/unknown-platform-lowering-warning
...
PiperOrigin-RevId: 688630259
2024-10-22 11:40:39 -07:00
jax authors
1e41d5ef6f
Merge pull request #24452 from jakevdp:insert-doc
...
PiperOrigin-RevId: 688624762
2024-10-22 11:26:38 -07:00
jax authors
1a2737b72b
Merge pull request #24467 from andportnoy:patch-2
...
PiperOrigin-RevId: 688620752
2024-10-22 11:17:15 -07:00
Jake VanderPlas
849850216d
fix mypy error
2024-10-22 11:10:10 -07:00
Ayaka
c60bafcc33
[Pallas TPU] Fix lowering for jnp.remainder
...
Fixes https://github.com/jax-ml/jax/issues/24027
PiperOrigin-RevId: 688614799
2024-10-22 11:01:58 -07:00
Andrey Portnoy
637898493e
Add back the import of jtu
in flash_attention.py
...
This was erroneously removed in de3191fab.
2024-10-22 13:37:35 -04:00
Andrey Portnoy
2aaa108f06
Raise an error when registering a lowering for an unknown platform
2024-10-22 13:29:48 -04:00
Jake VanderPlas
48dd153e18
Better docs for jnp.insert
2024-10-22 09:20:48 -07:00
Jake VanderPlas
7e38cbd604
Better docs for jnp.fromfunction
2024-10-22 08:42:22 -07:00
jax authors
587832f295
Merge pull request #24442 from jakevdp:lexsort-doc
...
PiperOrigin-RevId: 688563766
2024-10-22 08:40:21 -07:00
jax authors
a2e4aff897
Merge pull request #24425 from dfm:rename-vmap-methods
...
PiperOrigin-RevId: 688547393
2024-10-22 07:51:29 -07:00
Christos Perivolaropoulos
4f9356361a
[pallas] Support for setting explicit backends to pallas_call.
...
PiperOrigin-RevId: 688511303
2024-10-22 05:37:15 -07:00
Adam Paszke
2db03ba54b
[Pallas:MGPU] Add support for grid dims in GPUMesh
...
Of course no communication can happen across grid dimensions (unlike over the WG dim),
but we need to be able to launch multiple blocks somehow.
PiperOrigin-RevId: 688488660
2024-10-22 04:10:46 -07:00
jax authors
0b3f0e11fb
Reverts ebb75db8a523150c48376d15391f84380a2bb110
...
PiperOrigin-RevId: 688477769
2024-10-22 03:29:32 -07:00
Adam Paszke
84a303f32f
[Pallas:MGPU] Allow allocating transformed refs in run_scoped
...
PiperOrigin-RevId: 688448592
2024-10-22 01:38:46 -07:00
Yash Katariya
ebb75db8a5
[sharding_in_types] Add out_type
argument to einsum
and dot_general
to allow specifying for the output type. Right now, it only accept a NamedSharding
but in the future we can allow a polymorphic type of: jax.ShapeDtypeStruct | Sharding | Layout
.
...
PiperOrigin-RevId: 688399552
2024-10-21 22:23:53 -07:00
Jake VanderPlas
8800fe2870
Better documentation for jnp.lexsort
2024-10-21 16:33:14 -07:00
jax authors
441aeebb29
Merge pull request #24420 from superbobry:maint-2
...
PiperOrigin-RevId: 688271404
2024-10-21 14:22:43 -07:00
Ezekiel Calubaquib
81bf626501
Move converter related tflite functions to tensorflow/lite repo
...
PiperOrigin-RevId: 688270228
2024-10-21 14:19:30 -07:00
Sergei Lebedev
3ad1985e1a
Bumped mypy and ruff versions used by pre-commit
2024-10-21 21:58:41 +01:00
Jake VanderPlas
66971a2869
Fix jnp.diff for boolean inputs
2024-10-21 13:35:13 -07:00
Dan Foreman-Mackey
61701af4a2
Rename vmap methods for callbacks.
2024-10-21 15:03:04 -04:00
jax authors
4a5ca2fd00
Merge pull request #24400 from jakevdp:subtract-ufunc
...
PiperOrigin-RevId: 688190106
2024-10-21 10:38:52 -07:00
jax authors
65307abd81
Merge pull request #24370 from dfm:ffi-call-to-callable
...
PiperOrigin-RevId: 688188390
2024-10-21 10:34:56 -07:00
Jake VanderPlas
6467d03925
Make jnp.subtract a ufunc
2024-10-21 10:11:51 -07:00
Ezekiel Calubaquib
ad53addb74
Move out mnist py/Jax tensorflow lite tests to tensorflow lite repo
...
PiperOrigin-RevId: 688178268
2024-10-21 10:08:21 -07:00
jax authors
e29b93ff3e
Merge pull request #24421 from jakevdp:cross-doc
...
PiperOrigin-RevId: 688175417
2024-10-21 10:01:45 -07:00
Dan Foreman-Mackey
0b651f0f45
Make ffi_call return a callable
2024-10-21 12:16:57 -04:00
jax authors
fe83d888b9
Merge pull request #24417 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 688159150
2024-10-21 09:10:45 -07:00
rajasekharporeddy
02f65bb11a
Update warning message for jit of pmap
2024-10-21 21:17:59 +05:30
Yash Katariya
783285a71c
FIx jax2tf breakge of iota
...
PiperOrigin-RevId: 688146581
2024-10-21 08:30:49 -07:00
Adam Paszke
f833891c87
[Pallas:MGPU] Add support for passing in WGMMA lhs from registers
...
PiperOrigin-RevId: 688117316
2024-10-21 06:42:18 -07:00
Adam Paszke
f08801b8d6
[Pallas:MGPU] Allow indexing to appear anywhere in the list of transforms
...
We only need to exchange the transforms preceding the indexer, while
the rest can remain unmodified.
PiperOrigin-RevId: 688112088
2024-10-21 06:22:16 -07:00
Jake VanderPlas
a1140e9246
Better docs for jnp.cross
2024-10-21 05:59:22 -07:00
jax authors
f4b84e1c97
Merge pull request #24342 from gnecula:export_custom_types
...
PiperOrigin-RevId: 688093192
2024-10-21 05:08:04 -07:00
George Necula
2feea414ac
[export] Add support for serialization for some custom PyTree nodes
...
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.
Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).
When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserializaton a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than then function that was serialized. This can be confusing.
The Python pickle mode does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.
Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.
Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
2024-10-21 11:38:13 +02:00
jax authors
0d7ef9c9ca
Merge pull request #24403 from jakevdp:load-doc
...
PiperOrigin-RevId: 688048891
2024-10-21 02:19:32 -07:00
Yash Katariya
ca2d1584f8
Remove mesh_utils.create_device_mesh
from docs
...
PiperOrigin-RevId: 687695419
2024-10-19 15:48:42 -07:00
Jake VanderPlas
0a85ba5f82
Better documentation for jnp.load
2024-10-19 06:20:20 -07:00
Ayaka
884f1dc3a1
[Pallas TPU] Use new MLIR op names
...
PiperOrigin-RevId: 687454709
2024-10-18 16:14:27 -07:00
Adam Paszke
bbcc3eef3c
[Pallas:MGPU] Fix the implementation of WGMMA with transposed RHS
...
It's not enough that we have the physical transpose between the order
of tiled dimensions, we also need the user to explicitly transpose the
logical dimensions. This fixes a shape error that was previously hidden
because the RHS was square.
PiperOrigin-RevId: 687350270
2024-10-18 10:31:42 -07:00
Yash Katariya
2153de4ce0
[sharding_in_types] If out_aval.sharding is not None and the user specified out_sharding is None, concretize it with the device assignment available and add it to the final out_shardings that's used for lowering and compilation.
...
This will allow us to return the exact sharding spec that sharding propagation rules figured out.
PiperOrigin-RevId: 687349015
2024-10-18 10:27:58 -07:00
Christos Perivolaropoulos
f8a3c0366b
[pallas] run_scoped now supports partial discharge.
...
PiperOrigin-RevId: 687347284
2024-10-18 10:22:31 -07:00
Adam Paszke
e138e8e49d
[Pallas:MGPU] Fix docstring for commit_shared
...
PiperOrigin-RevId: 687308732
2024-10-18 08:16:55 -07:00