170 Commits

Author SHA1 Message Date
Matthew Johnson
be3ca507db del add_any_p and zeros_like_p, replace aval-dispatched traceable 2023-12-21 17:04:21 -08:00
Alexey Radul
fbb587232c Rename Piles to Jumbles, to avoid unfortunate Imperial entanglements. 2023-07-14 15:19:49 -04:00
Alexey Radul
2daeec83ce Redefine the pile representation from concatenated to stacked-and-padded.
The advantage (already being realized) is that the batching rules
become much simpler: we just batch along the stacked axis as always,
and when a reduction is about to occur, also mask out the padding
elements, replacing them with the identity element of the reduction.

This commit

- Changes the intended representation of data for piles and the
  corresponding BatchTracers.
- Re-defines ConcatAxis as RaggedAxis to represent the metadata.
- Updates `defreducer` to require the identity function (in case
  masking is needed), and supplies it everywhere.
- Flushes batching.segment_sum, as it is dead code now.
- Deletes unpack_concat_axes and reassemble_concat_axes, because they
  are irrelevant to the padded representation.
2023-05-19 13:13:15 -07:00
Roy Frostig
a262314934 prune unintended exports from jax.interpreters.batching
PiperOrigin-RevId: 508784928
2023-02-10 16:47:28 -08:00
Roy Frostig
55c2b6dad6 move jax.interpreters.batching to jax._src.interpreters.batching
Re-export roughly all of the same symbols via `jax.interpreters.batching` for now.

PiperOrigin-RevId: 508107044
2023-02-08 09:51:00 -08:00
Yash Katariya
e21c29476d Add batch_jaxpr2 which tells the caller where batch dims are.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795
2023-01-12 21:16:59 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Jake VanderPlas
0241567c3a remove dead code 2022-11-30 12:02:53 -08:00
jax authors
d1e26d9c5d Merge pull request #13139 from mattjj:djax-vmap4
PiperOrigin-RevId: 488458141
2022-11-14 13:48:28 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Matthew Johnson
0b463efb70 tighten up vmap w/ piles: require pile_axis in_axes/out_axes 2022-11-08 10:27:55 -08:00
Matthew Johnson
f2f2faa4fa add a basic prototype of piles, behind jax_dynamic_shapes
Co-authored-by: Adam Paszke <apaszke@google.com>
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-11-06 17:03:04 -08:00
Matthew Johnson
6ebf44a681 make leak checker errors explain why objects are alive
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
Co-authored-by: Roy Frostig <frostig@google.com>
2022-10-28 14:12:17 -07:00
Matthew Johnson
8e8ae8441f fix 2022-10-20 22:23:29 -07:00
Matthew Johnson
f76fc010e4 put back some unsafe_maps 2022-10-20 21:56:00 -07:00
Matthew Johnson
a1d303b081 [dynamic-shapes] fix nested vmap callable annotation logic 2022-10-19 17:01:53 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Parker Schuh
8fb957350c Add spmd_axis_name to vmap to allow constraining mapped PartitionSpecs. 2022-08-08 19:41:42 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Lena Martens
8efeb3e297 Fix getting aval of BatchTracers that are not mapped. 2022-06-23 17:28:45 +01:00
Matthew Johnson
b92c6b1e4d fix ad_checkpoint.checkpoint vmap rule 2022-05-05 13:31:27 -07:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Lukas Geiger
50e8bc4514 Replace reshape with expand_dims if possible 2022-03-31 01:34:26 +01:00
Roy Frostig
0ada0a105e avoid batching units in cond partial eval
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-22 17:42:38 -07:00
Matthew Johnson
24a7afdbf4 improve batch_jaxpr caching from #9196
In #9196 I missed a related utilty function which needed memoization.

Co-authored-by: Adam Paszke <apaszke@google.com>
2022-03-07 12:46:28 -08:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Matthew Johnson
e0fb424d81 use singleton dims in broadcasting binop batchers 2022-02-16 23:11:22 -08:00
Matthew Johnson
13ede5b2eb add origin info to leaked tracer error
add origin info to vmap tracers (BatchTracer)
2022-01-19 12:25:04 -08:00
Matthew Johnson
dc484bf450 Copybara import of the project:
--
06deb73c9be01cedc000efe7b3eb72d68615471a by Matthew Johnson <mattjj@google.com>:

cache initial-style jaxpr transformations

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/9196 from mattjj:issue3847 06deb73c9be01cedc000efe7b3eb72d68615471a
PiperOrigin-RevId: 422604879
2022-01-18 11:25:13 -08:00
Matthew Johnson
08aec823fd fix a custom_vjp post_process bug, related cleanups
related to #8783, doesn't completely fix it
2022-01-12 07:51:50 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Jake VanderPlas
e14eaf0664 cleanup: remove stray debugging breakpoint 2021-11-23 12:17:08 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Matthew Johnson
50e7e952bd add internal vmappable interface (part 1)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2021-11-04 15:01:54 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Sharad Vikram
cc3e197991 Combine initial_style_batchers with collective_rules 2021-09-09 11:23:51 -07:00
Adam Paszke
1158530faa Remove axis name from named_shape when unmapping avals
Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.

PiperOrigin-RevId: 395423006
2021-09-08 01:42:15 -07:00
Sharad Vikram
d693324dab change while loop batching fixed point condition
Fixes #7063

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Adam Paszke <apaszke@google.com>
2021-09-02 19:39:03 -07:00
Adam Paszke
e410085392 Revert of "change while loop batching fixed point condition"
This change seems to have broken some other projects, so reverting.

PiperOrigin-RevId: 393869074
2021-08-30 15:40:40 -07:00
jax authors
bbfd8f7cfc Merge pull request #7206 from sharadmv:collective-cf
PiperOrigin-RevId: 393748817
2021-08-30 05:21:23 -07:00
Matthew Johnson
83f95a5dae custom_jvp/vjp tweaks and fixes 2021-08-17 17:51:35 -07:00
Matthew Johnson
2e6a30a595 always use same object for vmap temp axis name 2021-08-13 14:54:17 -07:00
Sharad Vikram
49f7ac22cc change while loop batching fixed point condition
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
Co-authored-by: Adam Paszke <apaszke@google.com>
2021-07-28 14:37:40 -07:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00
Adam Paszke
bca3d61b3b Insert xmap SPMD axes into pjit sharding annotations
This should let us emit good XLA annotations for `xmap(pjit)`. Previously
we might have been overestimating the set of replicated mesh dimensions.

PiperOrigin-RevId: 377259226
2021-06-03 04:13:29 -07:00
Adam Paszke
8df502aeb2 Use the axis names attached to a primitive when selecting the top trace
This is useful e.g. for handling psums of values that are not sharded,
but are also not statically known constants that we can fold.
2021-04-28 09:46:24 +00:00
Peter Choy
eb9d6e4d21 Pass axis name to _match_axes and add to error message. 2021-04-22 13:34:04 +00:00