14 Commits

Author SHA1 Message Date
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
Yash Katariya
3a26804c68 Rename get_ty to typeof which is an alias of get_aval
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00
jax authors
6095af050f Merge pull request #26427 from mattjj:direct-linearize-fixes
PiperOrigin-RevId: 734687601
2025-03-07 14:22:16 -08:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Matthew Johnson
0e30a3ace9 [mutable-arrays] read values should have the same explicit sharding as ref
fixes #26936
2025-03-07 20:53:29 +00:00
George Necula
1be801bac8 [better_errors] Cleanup use of DebugInfo.arg_names and result_paths
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.

I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.

Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
2025-02-23 08:27:56 +02:00
Matthew Johnson
719031c1fd [mutable-arrays] persist shardings through xla computations 2025-02-07 18:33:24 +00:00
Matthew Johnson
b6482f126e add mutable array ref error checks to cond and custom_vjp 2024-12-20 01:44:50 +00:00
Matthew Johnson
e52856261f add mutable array ref error checks to scan 2024-12-19 01:33:39 +00:00
Matthew Johnson
42ac4ca357 ref errors 2024-12-18 07:46:14 +00:00
Dougal
fc2edbfac8 Add a freeze primitive to delimit ref lifetimes for AD.
Also some basic AD through mutable_array/freeze.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-12-09 20:57:07 -05:00
Sharad Vikram
80f963c003 Fix mutable array effects not being tracked properly
PiperOrigin-RevId: 680801564
2024-09-30 18:55:15 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00