23755 Commits

Author SHA1 Message Date
Ayaka
2b7b0742a4 [Pallas TPU] Add lowerings for bf16 jnp.ceil and jnp.floor in TPU v6+
This PR is similar to https://github.com/jax-ml/jax/pull/24284

Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()` because the latter is a wrapper with bfloat16 support.

PiperOrigin-RevId: 688581914
2024-10-22 09:33:53 -07:00
Jake VanderPlas
48dd153e18 Better docs for jnp.insert 2024-10-22 09:20:48 -07:00
jax authors
2596a4059b Merge pull request #24412 from jakevdp:fromfunction-doc
PiperOrigin-RevId: 688576529
2024-10-22 09:18:42 -07:00
Jake VanderPlas
74fc7360f5 Require ml_dtypes >= 0.4.0
This is the minimum version that supports NumPy 2.0.
2024-10-22 09:04:28 -07:00
Jake VanderPlas
7e38cbd604 Better docs for jnp.fromfunction 2024-10-22 08:42:22 -07:00
jax authors
587832f295 Merge pull request #24442 from jakevdp:lexsort-doc
PiperOrigin-RevId: 688563766
2024-10-22 08:40:21 -07:00
jax authors
a2e4aff897 Merge pull request #24425 from dfm:rename-vmap-methods
PiperOrigin-RevId: 688547393
2024-10-22 07:51:29 -07:00
Christos Perivolaropoulos
4f9356361a [pallas] Support for setting explicit backends to pallas_call.
PiperOrigin-RevId: 688511303
2024-10-22 05:37:15 -07:00
Adam Paszke
2db03ba54b [Pallas:MGPU] Add support for grid dims in GPUMesh
Of course no communication can happen across grid dimensions (unlike over the WG dim),
but we need to be able to launch multiple blocks somehow.

PiperOrigin-RevId: 688488660
2024-10-22 04:10:46 -07:00
jax authors
0b3f0e11fb Reverts ebb75db8a523150c48376d15391f84380a2bb110
PiperOrigin-RevId: 688477769
2024-10-22 03:29:32 -07:00
Adam Paszke
84a303f32f [Pallas:MGPU] Allow allocating transformed refs in run_scoped
PiperOrigin-RevId: 688448592
2024-10-22 01:38:46 -07:00
Yash Katariya
ebb75db8a5 [sharding_in_types] Add out_type argument to einsum and dot_general to allow specifying for the output type. Right now, it only accept a NamedSharding but in the future we can allow a polymorphic type of: jax.ShapeDtypeStruct | Sharding | Layout.
PiperOrigin-RevId: 688399552
2024-10-21 22:23:53 -07:00
Jake VanderPlas
8800fe2870 Better documentation for jnp.lexsort 2024-10-21 16:33:14 -07:00
Hernan Moraldo
5d3cac6603 Fix documentation.
PiperOrigin-RevId: 688293390
2024-10-21 15:29:59 -07:00
Praveen Narayanan
ad1aff098d Respect dot algorithm spec on TPU backends.
PiperOrigin-RevId: 688274131
2024-10-21 14:30:48 -07:00
jax authors
441aeebb29 Merge pull request #24420 from superbobry:maint-2
PiperOrigin-RevId: 688271404
2024-10-21 14:22:43 -07:00
Ezekiel Calubaquib
81bf626501 Move converter related tflite functions to tensorflow/lite repo
PiperOrigin-RevId: 688270228
2024-10-21 14:19:30 -07:00
jax authors
11eeff072f Merge pull request #22410 from garymm:patch-1
PiperOrigin-RevId: 688265373
2024-10-21 14:05:22 -07:00
Sergei Lebedev
3ad1985e1a Bumped mypy and ruff versions used by pre-commit 2024-10-21 21:58:41 +01:00
jax authors
5de878085a Merge pull request #24437 from jakevdp:fix-diff
PiperOrigin-RevId: 688261632
2024-10-21 13:56:14 -07:00
jax authors
1260951d57 Merge pull request #24434 from justinjfu:tutorial_mesh_update
PiperOrigin-RevId: 688259634
2024-10-21 13:50:23 -07:00
Jake VanderPlas
66971a2869 Fix jnp.diff for boolean inputs 2024-10-21 13:35:13 -07:00
Justin Fu
0b46a236c1 Update Pallas distributed tutorials with jax.make_mesh 2024-10-21 12:49:56 -07:00
jax authors
16fca386a3 Update XLA dependency to use revision
76da730179.

PiperOrigin-RevId: 688222632
2024-10-21 12:03:12 -07:00
Dan Foreman-Mackey
61701af4a2 Rename vmap methods for callbacks. 2024-10-21 15:03:04 -04:00
jax authors
4a5ca2fd00 Merge pull request #24400 from jakevdp:subtract-ufunc
PiperOrigin-RevId: 688190106
2024-10-21 10:38:52 -07:00
jax authors
65307abd81 Merge pull request #24370 from dfm:ffi-call-to-callable
PiperOrigin-RevId: 688188390
2024-10-21 10:34:56 -07:00
Gary Miguel
dc908b4843 Update installation instructions
Apple GPUs and Mac x86_64 is a non-existent combination.
Mac x86_64 with AMD GPU is supported.

It's a bit of a confusing situation so hard to summarize, but hopefully this is more accurate and less confusing

Fixes: #24408
2024-10-21 10:20:09 -07:00
Jake VanderPlas
6467d03925 Make jnp.subtract a ufunc 2024-10-21 10:11:51 -07:00
Ezekiel Calubaquib
ad53addb74 Move out mnist py/Jax tensorflow lite tests to tensorflow lite repo
PiperOrigin-RevId: 688178268
2024-10-21 10:08:21 -07:00
jax authors
e29b93ff3e Merge pull request #24421 from jakevdp:cross-doc
PiperOrigin-RevId: 688175417
2024-10-21 10:01:45 -07:00
Dan Foreman-Mackey
0b651f0f45 Make ffi_call return a callable 2024-10-21 12:16:57 -04:00
jax authors
fe83d888b9 Merge pull request #24417 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 688159150
2024-10-21 09:10:45 -07:00
rajasekharporeddy
02f65bb11a Update warning message for jit of pmap 2024-10-21 21:17:59 +05:30
Yash Katariya
783285a71c FIx jax2tf breakge of iota
PiperOrigin-RevId: 688146581
2024-10-21 08:30:49 -07:00
Adam Paszke
f833891c87 [Pallas:MGPU] Add support for passing in WGMMA lhs from registers
PiperOrigin-RevId: 688117316
2024-10-21 06:42:18 -07:00
Adam Paszke
f08801b8d6 [Pallas:MGPU] Allow indexing to appear anywhere in the list of transforms
We only need to exchange the transforms preceding the indexer, while
the rest can remain unmodified.

PiperOrigin-RevId: 688112088
2024-10-21 06:22:16 -07:00
Jake VanderPlas
a1140e9246 Better docs for jnp.cross 2024-10-21 05:59:22 -07:00
Nitin Srinivasan
a2bc8c2e07 Remove temporary aliases from .bazelrc
These aliases were added to not break existing presubmit builds. Now that the presubmit builds have been updated, these aliases can be removed.

Also, corrects some comments.

PiperOrigin-RevId: 688096364
2024-10-21 05:20:13 -07:00
jax authors
f4b84e1c97 Merge pull request #24342 from gnecula:export_custom_types
PiperOrigin-RevId: 688093192
2024-10-21 05:08:04 -07:00
George Necula
2feea414ac [export] Add support for serialization for some custom PyTree nodes
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.

Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).

When writing this I have looked at how TensorFlow handles namedtuple. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserializaton a fresh distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than then function that was serialized. This can be confusing.

The Python pickle mode does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
importing automatically the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.

Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.

Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
2024-10-21 11:38:13 +02:00
jax authors
0d7ef9c9ca Merge pull request #24403 from jakevdp:load-doc
PiperOrigin-RevId: 688048891
2024-10-21 02:19:32 -07:00
jax authors
33a73852eb Update XLA dependency to use revision
8a7920d699.

PiperOrigin-RevId: 687898456
2024-10-20 13:03:47 -07:00
Yash Katariya
ca2d1584f8 Remove mesh_utils.create_device_mesh from docs
PiperOrigin-RevId: 687695419
2024-10-19 15:48:42 -07:00
jax authors
77fb1eee11 Update XLA dependency to use revision
d0d716fb63.

PiperOrigin-RevId: 687675747
2024-10-19 13:38:07 -07:00
jax authors
48bddc6f6c Adds arith.select to the op patters in order to canonicalize non 32 bit selects.
PiperOrigin-RevId: 687635492
2024-10-19 09:09:06 -07:00
Jake VanderPlas
0a85ba5f82 Better documentation for jnp.load 2024-10-19 06:20:20 -07:00
Ayaka
884f1dc3a1 [Pallas TPU] Use new MLIR op names
PiperOrigin-RevId: 687454709
2024-10-18 16:14:27 -07:00
jax authors
22426519b7 Update XLA dependency to use revision
7e3b0097bd.

PiperOrigin-RevId: 687427622
2024-10-18 14:35:02 -07:00
Adam Paszke
bbcc3eef3c [Pallas:MGPU] Fix the implementation of WGMMA with transposed RHS
It's not enough that we have the physical transpose between the order
of tiled dimensions, we also need the user to explicitly transpose the
logical dimensions. This fixes a shape error that was previously hidden
because the RHS was square.

PiperOrigin-RevId: 687350270
2024-10-18 10:31:42 -07:00