39 Commits

Author SHA1 Message Date
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
Yash Katariya
b337c26c72 Add donate_argnames to jax.jit. This works similarly to static_argnames.
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names which will then we used to find the donation_vector. This is fine because currently, the same thing happens from static_argnums and static_argnames.

I'll fix the TODOs, etc in follow up CLs.

Fixes https://github.com/google/jax/issues/10539

PiperOrigin-RevId: 547612861
2023-07-12 15:09:57 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Yash Katariya
01fdd91a5f Use _to_xla_hlo_sharding everywhere in JAX. Remove _to_xla_op_sharding in favor of _to_xla_hlo_sharding since constructing a C++ class is faster than protos and will help with further changes coming to HloSharding.
PiperOrigin-RevId: 537969500
2023-06-05 13:41:31 -07:00
Yash Katariya
42a8982d23 Smuggle _experimental_lowering_platform via kwargs to make it hidden and extremely private temporary.
PiperOrigin-RevId: 532644979
2023-05-16 19:47:58 -07:00
Yash Katariya
b196ad2e8c Remove the f-string evaluation during logging the elapsed time by passing in fun_name to log_elapsed_time
PiperOrigin-RevId: 532132574
2023-05-15 09:15:58 -07:00
Cristian Garcia
36c6fce9d5 improve serial_loop docstring 2023-05-05 15:27:33 +00:00
Yash Katariya
34d5a6259f Default jax_spmd_mode to allow_jit which will allow explicit jax.jit to not raise the multihost error (since jit and pjit have been merged).
Implicit jit and apply_primitive will still raise an error though (which is recognized via inline parameter). Majority of jnp operations in JAX should be inlined.

PiperOrigin-RevId: 527398394
2023-04-26 15:56:46 -07:00
George Necula
961b0655fa [shape_poly] Lowering sharding annotations in presence of dynamic shapes
Sharding annotations are lowered to custom calls, and in presence of dynamic shapes
we must use the `indices_of_shape_operands` attribute to hlo.CustomCall.
In order to be able to generate the code to compute the result shapes
we must pass the `LoweringRuleContext` and the result abstract value
to the lowering helpers that generate the custom calls.

The above is easy everywhere, except for the sharding annotations for
the inputs and outputs for a function, because we do not yet have
a LoweringRuleContext available.

This code is tested by tests that are still disabled in sharding_test.
They can be enabled once StableHLO improves the support for
dynamic shapes for custom calls: https://github.com/openxla/stablehlo/issues/1367
2023-04-17 14:27:00 +03:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Peter Hawkins
49e68dbe80 Add more return type annotations.
Fix a new pytype error by adding a checked cast.

PiperOrigin-RevId: 523780354
2023-04-12 12:54:07 -07:00
Jake VanderPlas
fbc1ee2ba3 Remove some dead code and unused imports 2023-04-12 12:15:15 -07:00
Yash Katariya
5d1abe1ba9 Make apply_primitive preserve shardings on outputs.
PiperOrigin-RevId: 523186148
2023-04-10 12:41:02 -07:00
Peter Hawkins
be1cf46a49 Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
2023-04-10 10:15:58 -07:00
Peter Hawkins
452f3c55e3 Rename jax._src.sharding_utils to jax._src.op_shardings.
Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
a9e48af260 Deprecated xla_call_p since it has been replaced with pjit.pjit_p
PiperOrigin-RevId: 518921538
2023-03-23 11:44:42 -07:00
Matthew Johnson
268456ef54 enable pjit recursive typechecking
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).

This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!

It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
2023-03-22 16:59:22 -07:00
Yash Katariya
23d3dfd834 Remove _PositionalSemantics class since it is not used anymore because jax.Array always has GLOBAL semantics
PiperOrigin-RevId: 517493710
2023-03-17 13:30:04 -07:00
Yash Katariya
7c7c60eabf Remove in_positional_semantics and out_positional_semantics from xmap
PiperOrigin-RevId: 517477866
2023-03-17 12:24:26 -07:00
Yash Katariya
d02f28199b Clean up pjit after jax.Array
* Remove {in|out}_positional_semantics from pjit_p.bind
* Remove `in_is_global` from lower_sharding_computation
* Remove local_to_global and global_to_local
* Clean up some arguments of sharded_lowering since they are not needed

PiperOrigin-RevId: 517469390
2023-03-17 11:53:00 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
634035abd7 Remove GDA from JAX since jax.Array is the default type and cannot be disabled anymore as per https://jax.readthedocs.io/en/latest/jax_array_migration.html#how-can-i-disable-jax-array-for-now
PiperOrigin-RevId: 516905931
2023-03-15 13:00:00 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Peter Hawkins
623282715d Split Mesh and ResourceEnv into a new module jax._src.mesh.
This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00
George Necula
9a424aabbd [jax2tf] Clean up the support for cross-lowering.
In a previous CL we introduced cross-lowering support without any
changes in JAX core, but at the expense of some overly complex code
in jax2tf, along with overriding a JAX core function. Plus, those
changes were not enough to handle some xmap and pmap cases.

Here we introduce a `_experimental_lowering_platform: Optional[str]` parameter
to the `.lower()` methods and then we thread the `lowering_platform`
all the way to the calls to `mlir.lower_jaxpr_to_module2`. That's it.

Note that this parameter to `.lower()` is experimental and not supposed
to be used outside jax2tf. It may also gobble user kwargs.
2023-03-01 09:53:22 +01:00
Peter Hawkins
148774587a Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().

PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
Ikko Eltociear Ashimine
28f89f6244
Fix typo in maps.py
conjuction -> conjunction
2023-02-21 17:22:11 +09:00
jax authors
c0107cc836 Merge pull request #14549 from sharadmv:dbidx-effects
PiperOrigin-RevId: 510608031
2023-02-17 23:43:38 -08:00
Yash Katariya
d93aa70801 Replace op_sharding_sharding with gspmd_sharding. This is purely an internal change.
PiperOrigin-RevId: 510562354
2023-02-17 17:53:13 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Peter Hawkins
2b9ad0d93e Move contents of jax.experimental.global_device_array to jax._src.global_device_array.
Make jax.experimental.global_device_array a shim around jax._src.global_device_array.

Change in preparation for deprecating global device arrays.

PiperOrigin-RevId: 510261140
2023-02-16 15:37:10 -08:00
Peter Hawkins
54269c1145 Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Sharad Vikram
442aa028c2 Fix xmap staging rule to handle positional semantics
PiperOrigin-RevId: 509356614
2023-02-13 16:05:17 -08:00
Peter Hawkins
4a523e3d74 Minimize exported names from jax.experimental.maps.
Move implementation of maps to jax._src.maps.

PiperOrigin-RevId: 509309092
2023-02-13 12:57:54 -08:00