5 Commits

Author SHA1 Message Date
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
George Necula
b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Export = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
2024-06-10 19:31:51 +02:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
George Necula
1e7642279b [shape_poly] Add another micro-benchmark
A very common operation is to compute linear combinations of expressions.

PiperOrigin-RevId: 605644781
2024-02-09 09:02:59 -08:00
George Necula
313caba4cb [shape_poly] Add micro-benchmarks for symbolic shape manipulation.
In some of the cases when we use many symbolic expressions for shapes, the operations with
symbolic expressions are becoming somewhat costly. These benchmarks help with picking
good internal representations.

PiperOrigin-RevId: 604289958
2024-02-05 05:38:25 -08:00