2123 Commits

Author SHA1 Message Date
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiatied more efficiently.

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
2024-11-19 15:25:47 -05:00
jax authors
6929a97c0c Merge pull request #24968 from nireekshak:testingbranch
PiperOrigin-RevId: 698051658
2024-11-19 09:50:34 -08:00
jax authors
9d3eda17fd Merge pull request #24942 from jeertmans:patch-1
PiperOrigin-RevId: 698031586
2024-11-19 08:44:30 -08:00
Jérome Eertmans
d912034cb5
fix(docs): typos in macro name
chore(docs): sync .md file
2024-11-19 16:42:19 +01:00
nireekshak
1458d3dd56 Fix some typos 2024-11-19 15:04:55 +00:00
barnesjoseph
d4316b5760 Adds font fallbacks 2024-11-18 14:46:10 -08:00
jax authors
2de40e7dbf Merge pull request #24916 from jakevdp:update-lp
PiperOrigin-RevId: 697652214
2024-11-18 09:19:09 -08:00
jax authors
efd232762c Merge pull request #24917 from emilyfertig:emilyaf-sharp-bits
PiperOrigin-RevId: 697020253
2024-11-15 15:37:16 -08:00
Emily Fertig
225a2a5f8b Consolidate material on PRNGs and add a short summary to Key Concepts. 2024-11-15 14:44:57 -08:00
barnesjoseph
81cdc882ae DOC: update main landing page style
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2024-11-15 13:44:31 -08:00
Emily Fertig
5f1e3f5644 Add an example on logical operators to the tutorial. 2024-11-15 12:40:41 -08:00
jax authors
4511f0c66b Merge pull request #24862 from emilyfertig:emilyaf-control-flow-tutorial
PiperOrigin-RevId: 696692588
2024-11-14 16:50:14 -08:00
jax authors
4fe9164548 Merge pull request #24871 from carlosgmartin:numpy_put_along_axis
PiperOrigin-RevId: 696679735
2024-11-14 16:00:51 -08:00
jax authors
8e292122b7 Merge pull request #24567 from Intel-tensorflow:minigoel/intel-plugin
PiperOrigin-RevId: 696677564
2024-11-14 15:52:38 -08:00
carlosgmartin
1f114b1cf7 Add numpy.put_along_axis. 2024-11-14 15:23:26 -05:00
Emily Fertig
e6f6a8af8d Move Control Flow text from Sharp Bits into its own tutorial. 2024-11-14 11:07:52 -08:00
jax authors
426e13a5aa Merge pull request #24886 from carlosgmartin:fix_typos
PiperOrigin-RevId: 696272953
2024-11-13 14:33:38 -08:00
jax authors
a792fb0618 Merge pull request #24883 from jakevdp:doc-installation
PiperOrigin-RevId: 696269924
2024-11-13 14:25:29 -08:00
carlosgmartin
307e88f280 Fix typos: Change 'arugments' to 'arguments'. 2024-11-13 15:58:45 -05:00
Jake VanderPlas
72a4692b94 doc: link directly to installation on the main page 2024-11-13 12:44:05 -08:00
Trevor Morris
a79d307ac7 When caching is enabled, also enable XLA caching features as well
Add unit test

Fix typechecker

Set caching mode depending on process id
2024-11-13 10:30:04 -08:00
Dan Foreman-Mackey
f757054267 Update some outdated syntax in FFI tutorial. 2024-11-12 08:34:24 -08:00
jax authors
ddce8670a5 Merge pull request #24759 from froystig:docs-about
PiperOrigin-RevId: 694361967
2024-11-07 22:11:01 -08:00
Roy Frostig
0a42bf12c3 add about page
This is an initial draft. There is more to come back and add/improve.
2024-11-07 21:47:50 -08:00
jax authors
60a6cd475b Add note on etils requirement for the Jax compilation cache.
The compilation cache has a dependency on etils.epath if the
cache is not on a local filesystem.

PiperOrigin-RevId: 694311585
2024-11-07 18:08:52 -08:00
jax authors
563ecdf2a2 Merge pull request #24704 from andportnoy:patch-3
PiperOrigin-RevId: 693443674
2024-11-05 12:41:54 -08:00
jax authors
c1af808c8c Merge pull request #24710 from rajasekharporeddy:typos
PiperOrigin-RevId: 693412112
2024-11-05 11:06:20 -08:00
Robert Dyro
04f2ef9e93 Adding JAX_LOGGING_LEVEL configuration option 2024-11-05 09:56:46 -08:00
rajasekharporeddy
a80d027dd7 Fix Typos 2024-11-05 12:29:20 +05:30
8bitmp3
9f0e6237a3 Update JAX landing page - Flax 2024-11-05 00:28:31 +00:00
Andrey Portnoy
74da736e0e
Update link to algebraic_simplifier.cc to point to OpenXLA instead of TF 2024-11-04 18:02:44 -05:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Li-Jesse-Jiaze
5e1366c4ce Fix #24661: Add zsh support to conda install documentation 2024-11-01 17:57:18 +01:00
Matthew Johnson
26f70c9c16 remove busted example from shmap jep 2024-11-01 16:37:46 +00:00
jax authors
5a3ed6c792 Merge pull request #24647 from emilyfertig:emilyaf-doc-pytree-dataclass
PiperOrigin-RevId: 691984161
2024-10-31 17:16:31 -07:00
Emily Fertig
467bd09f03 Add a register_dataclass example to the pytree tutorial. 2024-10-31 16:26:42 -07:00
Dan Foreman-Mackey
ce8dba98fb Move the CUDA end-to-end example to FFI examples workflow + hosted
runner.
2024-10-31 12:21:51 -04:00
Sergei Lebedev
85662f6dd8 [pallas:mosaic_gpu] plgpu.copy_smem_to_gmem no longer transparently commits SMEM
Users are expected to call `pltpu.commit_smem` manually instead.

PiperOrigin-RevId: 691724662
2024-10-31 02:21:10 -07:00
Jake VanderPlas
abf14323dc Adjust copyright notice.
Previously we had been pulling-in NumPy and SciPy docs at runtime, but
after the work in #21461 this is no longer the case.
2024-10-28 18:53:38 -07:00
minigoel
68428488c8 Add a link to Intel plugin for JAX 2024-10-28 10:47:59 -07:00
Sergei Lebedev
dfa6fcd56b [pallas:mosaic_gpu] Extracted a basic emit_pipeline API from the in kernel pipelining test
PiperOrigin-RevId: 690619853
2024-10-28 08:25:47 -07:00
jax authors
6e06110e1e Merge pull request #24538 from jakevdp:cumulative-prod
PiperOrigin-RevId: 690606656
2024-10-28 07:45:15 -07:00
Jim Lin
e4eca9ec59 #jax Adds a missing comma to Pallas Quickstart
PiperOrigin-RevId: 689907976
2024-10-25 14:14:11 -07:00
Jake VanderPlas
02daf75f97 Add new jnp.cumulative_prod function.
This follows the API of the similar function added in NumPy 2.1.0
2024-10-25 13:45:54 -07:00
jax authors
3b42a6b413 Merge pull request #24391 from keshavb96:remat_documentation
PiperOrigin-RevId: 689888674
2024-10-25 13:13:01 -07:00
Sergei Lebedev
5a2128e44b [pallas] Removed deprecated aliases to CostEstimate and run_scoped
PiperOrigin-RevId: 689871787
2024-10-25 12:16:58 -07:00
jax authors
8c9dc21e30 Update hermetic CUDA docs.
PiperOrigin-RevId: 689463215
2024-10-24 11:51:02 -07:00
jax authors
7ad73e44ce Merge pull request #24446 from gnecula:export_doc
PiperOrigin-RevId: 688886756
2024-10-23 02:50:57 -07:00
Peter Hawkins
e4f3f8f064 Use libtpu releases rather than libtpu-nightly for jax[tpu].
PiperOrigin-RevId: 688632409
2024-10-22 11:47:07 -07:00