12 Commits

Author SHA1 Message Date
Peter Hawkins
3969eec0e0 [MLIR] Keep MLIR IR longer as a Python ir.Module object rather than a string, until it is time to compile it.
Attach a meaningful module name, which is useful in logging, etc.

PiperOrigin-RevId: 415617591
2021-12-10 14:56:48 -08:00
Peter Hawkins
add967db88 [JAX] Add a dialect option to jit(...).lower(...).compiler_ir().
The dialect allows the user to select between HLO and MHLO output.

PiperOrigin-RevId: 415591372
2021-12-10 13:02:25 -08:00
Trevor Cai
56f029f7f0 [jax] Add computation name to cache hit logging.
PiperOrigin-RevId: 414697336
2021-12-07 05:34:41 -08:00
Ryan Sepassi
ae0f4f6c55 Add profiler annotations
PiperOrigin-RevId: 414552750
2021-12-06 15:13:35 -08:00
Yash Katariya
14bc95fe1b Internal change
PiperOrigin-RevId: 414405773
2021-12-06 04:34:39 -08:00
Ryan Sepassi
ac0d02743c Use profiler.annotate_function 2021-12-03 10:39:34 -08:00
Ryan Sepassi
6f9c5abd38 Add profiler.annotate_function on some internals 2021-12-03 10:38:44 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Peter Hawkins
12512cc96a Merge most of the MLIR JIT dispatch logic into the common primitive and JIT computation path.
Change the representation of both units and tokens at the runtime level to be a single buffer with shape pred[0]. While the MLIR lowering is happy to have a non 1:1 mapping between avals and IR values, the XLA lowering is not, so until we remove the XLA lowering it's easiest just to keep the mapping 1:1.

PiperOrigin-RevId: 412957231
2021-11-29 12:40:05 -08:00
Peter Hawkins
52fe821719 Merge xla._partition_outputs and util.unflatten.
PiperOrigin-RevId: 412117736
2021-11-24 12:52:40 -08:00
Peter Hawkins
839d410de0 [MLIR] Move most MLIR translation rules into lax.
PiperOrigin-RevId: 411942327
2021-11-23 18:58:28 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00