Peter Hawkins 3fef74b2d0 [JAX] Change signature of .mhlo() method on compiler IR objects to return an ir.Module object instead of its string representation.
Pretty-printing IR isn't free, so it's best avoided unless necessary. In addition, because an IR object is returned, the caller can now choose how to print it, for example with different printing options.

For example, one can now write things like:

```
In [1]: import numpy as np, jax, jax.numpy as jnp
In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo')
In [3]: m.operation.print(large_elements_limit=10)
module @jit__lambda_.4 {
  func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> {
    %0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32>
    %1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32>
    %2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32>
    %3 = mhlo.add %2, %1 : tensor<1000xf32>
    return %3 : tensor<1000xf32>
  }
}
```
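
Callers that relied on the old string return value need to convert explicitly. The following is a minimal sketch of one way to do that, not part of this change: `ir_to_text` is a hypothetical helper name, and it assumes the returned object behaves like an MLIR Python `ir.Module` (supports `str()` and exposes `.operation.get_asm(large_elements_limit=...)`).

```python
def ir_to_text(ir, large_elements_limit=None):
    """Return textual IR for a pre-rendered string or an IR module object.

    Hypothetical helper for illustration only. It assumes `ir` is either a
    str (the old return type) or an object that behaves like an MLIR
    `ir.Module` (the new return type).
    """
    if isinstance(ir, str):
        # Old behavior: the text was already rendered by the callee.
        return ir
    if large_elements_limit is None:
        # Full pretty-print; this is the potentially expensive step.
        return str(ir)
    # Elide large constants, mirroring the print() call in the session above.
    return ir.operation.get_asm(large_elements_limit=large_elements_limit)
```

With the `m` from the session above, `ir_to_text(m, large_elements_limit=10)` should return the elided text as a string instead of printing it, so the rendering cost is paid only when the text is actually needed.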

Fixes https://github.com/google/jax/issues/9226

PiperOrigin-RevId: 422855649
2022-01-19 11:04:48 -08:00