15765 Commits

Author SHA1 Message Date
Yash Katariya
febd339742 [Micro-optimization] Only log the avals and shardings if logging is enabled for that level.
PiperOrigin-RevId: 524845969
2023-04-17 07:53:37 -07:00
jax authors
8ce19eea4f Merge pull request #15408 from gnecula:poly_sharding
PiperOrigin-RevId: 524823660
2023-04-17 05:52:12 -07:00
George Necula
961b0655fa [shape_poly] Lowering sharding annotations in presence of dynamic shapes
Sharding annotations are lowered to custom calls, and in presence of dynamic shapes
we must use the `indices_of_shape_operands` attribute to hlo.CustomCall.
In order to be able to generate the code to compute the result shapes
we must pass the `LoweringRuleContext` and the result abstract value
to the lowering helpers that generate the custom calls.

The above is easy everywhere, except for the sharding annotations for
the inputs and outputs for a function, because we do not yet have
a LoweringRuleContext available.

This code is tested by tests that are still disabled in sharding_test.
They can be enabled once StableHLO improves the support for
dynamic shapes for custom calls: https://github.com/openxla/stablehlo/issues/1367
2023-04-17 14:27:00 +03:00
jax authors
55fbe1c7b5 Merge pull request #15621 from froystig:issue14856
PiperOrigin-RevId: 524595421
2023-04-15 20:36:47 -07:00
Roy Frostig
e9061953f6 fix opaque dtype case of dtypes.dtype 2023-04-15 20:06:37 -07:00
jax authors
20896c1b2d Merge pull request #15619 from froystig:jit-threefry-seed
PiperOrigin-RevId: 524585035
2023-04-15 19:11:35 -07:00
jax authors
153de2684f Merge pull request #15620 from mattjj:custom-jvp-docs-typo
PiperOrigin-RevId: 524573729
2023-04-15 16:49:55 -07:00
Matthew Johnson
690071f1de fix custom_jvp docs typo 2023-04-15 14:51:33 -07:00
Roy Frostig
bf55dc947d jit the threefry seed function
The Threefry PRNG's seeding function involves operations with small
constants, such as `lax.shift_right_logical(seed, 32)`. This causes to
host-to-device transfers of small scalars (e.g. `32`) every time that
one seeds outside of a `jit`. To avoid these transfers, and any
inflexibility under JAX's transfer guard, we `jit` the seeding
function.

This shifts costs around a bit. Whereas previously we were moving
scalars to device on every (eager) seed call, we are now tracing and
compiling the seed function. The latter will only happen once per
input shape.
2023-04-15 10:38:46 -07:00
Mark Sandler
849e47f79a Makes deserializer put tensors on the device before releasing inflight memory, as well as avoids allocating memory before memory is available.
This makes in-flight memory limiter both reflective of the actual peak usage, as well as reduces peak usage since we no longer try to fully materialize sharded tensors on the host.

PiperOrigin-RevId: 524456216
2023-04-14 21:20:38 -07:00
Yash Katariya
9bb971f3ba Fix a bug in converting GSPMDSharding to PositionalSharding. Also assert that we are creating correct OpShardings (like the check in hlo_sharding.cc).
PiperOrigin-RevId: 524405474
2023-04-14 15:51:58 -07:00
Yash Katariya
673730c065 Add is_fully_replicated method to Shardings. This allows to scrub the usage of is_op_sharding_replicated from JAX because we can just query it on Shardings and save an expensive round trip to OpSharding creation.
PiperOrigin-RevId: 524379122
2023-04-14 13:56:33 -07:00
jax authors
88a5ffb2e8 Merge pull request #15563 from nouiz:multi-node-nightly
PiperOrigin-RevId: 524368556
2023-04-14 13:13:12 -07:00
jax authors
dc79db606f Merge pull request #15602 from mattjj:shmap-error-message
PiperOrigin-RevId: 524349769
2023-04-14 11:57:34 -07:00
Matthew Johnson
866e325582 [shard-map] tweak error message to suggest P() for scalars 2023-04-14 09:39:38 -07:00
Yash Katariya
d887fa0b56 [Micro-optimization] Add a fast path for single controller runtimes while calculating addressable_devices and addressable_devices_indices_map
PiperOrigin-RevId: 524301175
2023-04-14 08:46:54 -07:00
Yash Katariya
30c6871618 Deprecate and raise an exception for instantiate_const_outputs argument of jax.xla_computation since it has been unused for a very long time.
PiperOrigin-RevId: 524295738
2023-04-14 08:20:20 -07:00
Yash Katariya
10c4766f6c Remove unordered_effects from lower_jaxpr_to_module since it is unused
PiperOrigin-RevId: 524139972
2023-04-13 16:57:39 -07:00
jax authors
468d7720bf Merge pull request #15592 from mattjj:shmap-post-process-scalar-res
PiperOrigin-RevId: 524121962
2023-04-13 15:40:07 -07:00
Yash Katariya
c235f214d0 Create same Sharding objects wherever possible to get maximum cache hits
PiperOrigin-RevId: 524116574
2023-04-13 15:22:17 -07:00
jax authors
0fd5b2ca61 Remove use of int casting in STFT collapse of batch dimensions.
PiperOrigin-RevId: 524115535
2023-04-13 15:15:11 -07:00
Yash Katariya
3e93833ed8 Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone
Also remove instantiate_const_outputs since that is unused

PiperOrigin-RevId: 524113088
2023-04-13 15:05:21 -07:00
jax authors
0fbaedf45e Merge pull request #15589 from jakevdp:math-prod
PiperOrigin-RevId: 524110280
2023-04-13 14:54:29 -07:00
Matthew Johnson
734a1d797e [shard-map] handle scalar residuals in post_process partial eval
This is the third time fixing one bug! Well, it was three different instances of the same conceptual idea: `shard_map`ped functions can't have scalar outputs (unless they're the same value for every function instance) because out_specs only let us express concatenation not stacking. So when autodiff of the body produces scalar residuals, we need to add a new axis to those scalars to be able to concatenate them in the output of the forward pass, and then remove that extra axis on the backward pass.

We had fixed that bug before in shard_map's partial eval rule, and then in its custom-policy partial eval rule. Here finally we're fixing it in the third and last place: in the post-process partial eval rule, which gets called when we differentiate with respect to a variable closed over by the `shard_map`ped function and none of its arguments.
2023-04-13 14:05:08 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Yash Katariya
fb46d3d084 Add an optional devices option to PmapSharding.default so that we can provide a public API to create PmapShardings without having users to create sharding_specs.
PiperOrigin-RevId: 524034034
2023-04-13 10:14:37 -07:00
Peter Hawkins
2e524411db Add unregistered mhlo.num_replicas and mhlo.num_partitions attributes to HLO output.
These are to allow PJRT plugin developers an inline way to determine the number of replicas/partitions to which the module is targeted. There are no stability guarantees on these attributes at the moment.

PiperOrigin-RevId: 524013922
2023-04-13 08:55:44 -07:00
Yash Katariya
fdbad53b15 Make _device_assignment a Tuple[Device] so that we don't convert a list to a tuple and vice-versa everywhere
PiperOrigin-RevId: 524002310
2023-04-13 08:03:27 -07:00
Peter Hawkins
6c0aa9b18b Small optimization: Add a C++ helper for testing if an OpSharding is fully replicated.
This saves parsing an OpSharding into an HloSharding just to see if it is replicated.

We have ideas on how to not use OpShardings at all for this kind of thing, but they require larger refactorings.

PiperOrigin-RevId: 523981705
2023-04-13 06:13:35 -07:00
jax authors
5bb816c69b Merge pull request #15569 from LenaMartens:caching
PiperOrigin-RevId: 523974305
2023-04-13 05:29:06 -07:00
Yash Katariya
b06d627c05 Remove _allow_propagation_to_outputs from compile in MeshComputation since after jax.Array it is not required and can just default to being set to True if a sharding is unspecified.
PiperOrigin-RevId: 523851611
2023-04-12 17:38:18 -07:00
Roy Frostig
39f7e16c33 add CompilerOptions type alias for AOT compiler options
PiperOrigin-RevId: 523840981
2023-04-12 16:49:13 -07:00
Matthew Johnson
03e72e3b77 Copybara import of the project:
--
75a7e7a07d58e14de73190d060414fd3a1ba3d52 by Matthew Johnson <mattjj@google.com>:

Handle jaxpr-round-tripping of custom jvp rules w/ sym zero

fixes #14833

Co-authored-by: Roy Frostig <frostig@google.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15426 from mattjj:custom-jvp-symbolic-zeros-3 75a7e7a07d58e14de73190d060414fd3a1ba3d52
PiperOrigin-RevId: 523817551
2023-04-12 15:11:55 -07:00
jax authors
62e2860040 Merge pull request #15574 from skye:sharding
PiperOrigin-RevId: 523806105
2023-04-12 14:28:47 -07:00
Skye Wanderman-Milne
e2f1e7d28e Fix ndarray comparision in sharding_impls.py 2023-04-12 20:43:57 +00:00
Peter Hawkins
49e68dbe80 Add more return type annotations.
Fix a new pytype error by adding a checked cast.

PiperOrigin-RevId: 523780354
2023-04-12 12:54:07 -07:00
jax authors
bed81eb013 Merge pull request #15565 from jakevdp:unused-code
PiperOrigin-RevId: 523774097
2023-04-12 12:30:03 -07:00
Jake VanderPlas
fbc1ee2ba3 Remove some dead code and unused imports 2023-04-12 12:15:15 -07:00
George Necula
56aecb9559 [jax2tf] Fix backwards compatibility test to not use goldens for eigh
We now check the expected property for the eigenvectors, instead
of comparing them to the goldens. We still compare the eigenvalues
to goldens. The check is a copy of the similar check in linag_test.py:
6b7ae36f10/tests/linalg_test.py (L1653)

Also discovered a fixed that we were not passing the `rtol` parameter
correctly to tests.

PiperOrigin-RevId: 523762205
2023-04-12 11:45:21 -07:00
Lena Martens
d1438a1205 Checkify: close over all arguments.
This means you don't have to worry about passing in non-jax-types (like
strings) or marking arguments as static.

Fixes #15504.
2023-04-12 18:32:37 +01:00
jax authors
777480c257 Merge pull request #15545 from jakevdp:io-callback-doc
PiperOrigin-RevId: 523734211
2023-04-12 10:15:00 -07:00
jax authors
32f010c4a5 Merge pull request #15561 from gnecula:poly_top_k
PiperOrigin-RevId: 523724157
2023-04-12 09:42:23 -07:00
Peter Hawkins
199c63d6f7 Remove a stale comment about unimplemented MHLO lowerings.
PiperOrigin-RevId: 523723546
2023-04-12 09:35:29 -07:00
George Necula
536f9ce44b [shape_poly] Add support for shape polymorphism for lax.top_k
For now, only for graph serialization.
2023-04-12 19:20:36 +03:00
jax authors
51949502dc Merge pull request #15560 from gnecula:poly_one_hot
PiperOrigin-RevId: 523718306
2023-04-12 09:15:40 -07:00
Peter Hawkins
33acdc0e40 Delete some dead code that pertained to sharded_jit.
sharded_jit is long gone.

PiperOrigin-RevId: 523711890
2023-04-12 08:49:54 -07:00
Jake VanderPlas
3ca7d67e8d Fully implement and test axes argument to jax.scipy.signal.fftconvolve
PiperOrigin-RevId: 523707411
2023-04-12 08:31:30 -07:00
Frederic Bastien
0a55822c87 The date trick doesn't work, so try to sort instead. 2023-04-12 07:43:13 -07:00
Jake VanderPlas
cc7fc2e0af fftconvolve: adjust test tolerances
PiperOrigin-RevId: 523695809
2023-04-12 07:38:19 -07:00
Jake VanderPlas
055edf4a08 DOC: add docstrings for callback functions 2023-04-12 07:33:09 -07:00