41 Commits

Author SHA1 Message Date
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
Gunhyun Park
94440c74c8 Register acos primitive to lower to CHLO acos.
Related: https://github.com/openxla/stablehlo/pull/2496
PiperOrigin-RevId: 689890774
2024-10-25 13:20:36 -07:00
Dan Foreman-Mackey
dbc03cf8e5 Re-land #23261 with appropriate compatibility checks.
PiperOrigin-RevId: 676092618
2024-09-18 12:40:53 -07:00
Dan Foreman-Mackey
69ba060957 Reverts e15ec1e8abe3732d747731c15a36facf4169739e
PiperOrigin-RevId: 675987338
2024-09-18 07:41:52 -07:00
jax authors
e15ec1e8ab Merge pull request #23261 from joaospinto:stablehlo.tan
PiperOrigin-RevId: 675973798
2024-09-18 06:56:28 -07:00
Jake VanderPlas
0e6650e89d filecheck test: use lax.cumsum directly to prevent false-positive 2024-09-04 12:31:19 -07:00
Kevin Gleason
5e897c61f5 Integrate StableHLO at openxla/stablehlo@8817ff1d
PiperOrigin-RevId: 652528759
2024-07-15 10:38:09 -07:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Pearu Peterson
6d8b3e4cff Fix complex sin and cos on inputs with small absolute value or large pure imaginary part 2024-02-22 23:42:18 +02:00
Peter Hawkins
d0a6813ea2 Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call().
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.

Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.

This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).

Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.

PiperOrigin-RevId: 561042402
2023-08-29 08:50:07 -07:00
Peter Hawkins
e99ca460e1 Fix symbol collision when merging MLIR modules.
PiperOrigin-RevId: 542039479
2023-06-20 13:53:30 -07:00
Anish Tondwalkar
6842e98ca1 Migrate regularized_incomplete_beta_p off xla_fallback
PiperOrigin-RevId: 519244597
2023-03-24 14:53:20 -07:00
Anish Tondwalkar
ac44d2c2e3 Migrate besseli0e off xla_fallback
PiperOrigin-RevId: 519241252
2023-03-24 14:39:40 -07:00
Anish Tondwalkar
8c75e27f67 Migrate random_gamma_grad off xla_fallback
PiperOrigin-RevId: 519154537
2023-03-24 08:49:40 -07:00
Anish Tondwalkar
8d1d522618 Migrate igamma_grad_a_p off xla_fallback
PiperOrigin-RevId: 519148548
2023-03-24 08:21:22 -07:00
Anish Tondwalkar
4a9b09485e Migrate igammac_p off xla_fallback path
It is now decomposed into stablehlo ops.

PiperOrigin-RevId: 519122775
2023-03-24 05:58:38 -07:00
Anish Tondwalkar
f981243af5 Migrate igamma_p off xla_fallback
We decompose it into a series or a call to igammac.

PiperOrigin-RevId: 518993077
2023-03-23 16:26:59 -07:00
Anish Tondwalkar
3bad6fa223 [CHLO] Add erf_inv and lowering to mhlo
PiperOrigin-RevId: 513183138
2023-03-01 02:52:52 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
jax authors
978dcde8d6 MHLO Pretty Print - Enhance type printing for CopyOp, ClampOp, CstrReshapeOp, ComputeReshapeShapeOp, SelectOp.
Based on:
https://github.com/openxla/stablehlo/pull/37

PiperOrigin-RevId: 484271777
2022-10-27 09:27:00 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Bixia Zheng
0f089e1901 Lower bessel_i1e primitive to chlo.bessel_i1e.
PiperOrigin-RevId: 467996329
2022-08-16 12:39:28 -07:00
Bixia Zheng
bb92038b6f Change jax to lower the asin and atan primitives to their corresponding chlo
ops.

PiperOrigin-RevId: 466766999
2022-08-10 13:05:29 -07:00
jax authors
8e2c68cbe4 MHLO CompareOp pretty printing
PiperOrigin-RevId: 466051458
2022-08-08 08:43:56 -07:00
Benjamin Kramer
9e16efa98a Integrate LLVM at llvm/llvm-project@71c9757474
Updates LLVM usage to match
[71c9757474c3](https://github.com/llvm/llvm-project/commit/71c9757474c3)

PiperOrigin-RevId: 460299215
2022-07-11 14:21:09 -07:00
Eugene Burmako
1fdc9a2b84 [MHLO] Add CHLO lowering for top_k
This follows the introduction of chlo.top_k in the diffbase.

PiperOrigin-RevId: 447581682
2022-05-09 15:45:04 -07:00
Eugene Burmako
0ed29b63f0 [MHLO] Add MHLO lowering for erf and erfc
erf implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=319-336;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443665435
2022-04-22 07:54:23 -07:00
Eugene Burmako
5f16873aad [MHLO] Switch tan to use CHLO lowering
Currently, it's desugared to sin(x)/cos(x) with upcast because CHLO_TanOp
legalization doesn't support complex numbers.

tan implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1175-1177;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443649394
2022-04-22 06:28:01 -07:00
Eugene Burmako
636345fd67 [MHLO] Add MHLO lowerings of remaining ops blocked by the lack of complex support in CHLO
The affected ops are: acosh, asinh and atanh
(in addition to cosh which was fixed a few days ago).

acosh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1181-1216;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

asinh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1218-1270;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

atanh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1272-1292;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443590210
2022-04-22 00:39:02 -07:00
Peter Hawkins
1bed70590a [MHLO] Switch call_tf to use an MHLO lowering (attempt 2).
In passing refactor and fix some bugs in the MHLO helper code:
* mlir.ir_constant() failed to propagate its canonicalize_types argument to its callee.
* Refactor the code to convert an XLA computation to an MHLO module and to merge two MHLO modules from the XLA fallback translation rule path.
* Fix symbol (alpha) renaming of call operator callees when merging MHLO modules.

Attempt 2: In this iteration of the merge_mhlo_modules function, move all the operators into the target module first before doing any symbol table manipulation.

PiperOrigin-RevId: 442904129
2022-04-19 14:00:00 -07:00
jax authors
bbb2bf1b9a [MHLO] Switch call_tf to use an MHLO lowering.
In passing refactor and fix some bugs in the MHLO helper code:
* mlir.ir_constant() failed to propagate its canonicalize_types argument to its callee.
* Refactor the code to convert an XLA computation to an MHLO module and to merge two MHLO modules from the XLA fallback translation rule path.
* Fix symbol (alpha) renaming of call operator callees when merging MHLO modules.

PiperOrigin-RevId: 442803807
2022-04-19 07:17:22 -07:00
Peter Hawkins
d717e3f514 [MHLO] Switch call_tf to use an MHLO lowering.
In passing refactor and fix some bugs in the MHLO helper code:
* mlir.ir_constant() failed to propagate its canonicalize_types argument to its callee.
* Refactor the code to convert an XLA computation to an MHLO module and to merge two MHLO modules from the XLA fallback translation rule path.
* Fix symbol (alpha) renaming of call operator callees when merging MHLO modules.

PiperOrigin-RevId: 442798170
2022-04-19 06:45:21 -07:00
Eugene Burmako
b267fd4336 [MHLO] Add MHLO lowering for cosh
Here's the corresponding old bridge lowering:
https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1294-1309;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

getConstantLike wasn't supporting complex numbers, and proper support required
non-trivial work, so in the meanwhile I've hacked up something that works
for static shapes to unblock the JAX use case (which currently only uses static shapes).

PiperOrigin-RevId: 442083258
2022-04-15 13:26:20 -07:00
Peter Hawkins
c2fe97ae01 Improve precision of chlo.sinh.
Update chlo.sinh lowering to match xla::Sinh(), see https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1311?q=xla%20sinh

[JAX] Use chlo.sinh instead of the XLA client library HLO lowering.

PiperOrigin-RevId: 441851170
2022-04-14 14:10:26 -07:00
Peter Hawkins
1b8be90801 Remove the jax_enable_mlir flag. MLIR is now the only supported code path.
This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes.

PiperOrigin-RevId: 439324450
2022-04-04 08:40:09 -07:00
Sandeep Dasgupta
6cd9804163 Replace (deprecated) StrEnumAttr with EnumAttr.
ref: https://reviews.llvm.org/D120834
PiperOrigin-RevId: 435550738
2022-03-17 23:11:28 -07:00
Roy Frostig
6f519576f6 remove _reduce_sum from public jax.lax module 2022-03-08 16:34:26 -08:00
Peter Hawkins
18baa6e93b [MLIR] Add a @mlir.cache_lowering decorator that lowers a primitive out-of-line as a reusable function.
Some primitives have very large lowerings. This is particularly true for lowerings that use `mlir.lower_fun` (e.g., the threefry PRNG kernel) or some XLA fallback lowerings. In this case it makes sense to lower such computations once for each signature as an out of line function that we can call multiple times.

XLA will inline these functions early in compilation at the moment, but this avoids the need to repeatedly trace, e.g., the threefry kernel when emitting MHLO.

PiperOrigin-RevId: 416818325
2021-12-16 08:34:52 -08:00
Peter Hawkins
3969eec0e0 [MLIR] Keep MLIR IR longer as a Python ir.Module object rather than a string, until it is time to compile it.
Attach a meaningful module name, which is useful in logging, etc.

PiperOrigin-RevId: 415617591
2021-12-10 14:56:48 -08:00
Peter Hawkins
add967db88 [JAX] Add a dialect option to jit(...).lower(...).compiler_ir().
The dialect allows the user to select between HLO and MHLO output.

PiperOrigin-RevId: 415591372
2021-12-10 13:02:25 -08:00
Peter Hawkins
eafaafd624 Add some initial filecheck tests for JAX->MHLO lowering.
The coverage of this test suite is not complete, but it's a start.

PiperOrigin-RevId: 415560462
2021-12-10 10:59:24 -08:00