16775 Commits

Author SHA1 Message Date
Peter Hawkins
fe30d3fd4b Move _array_shard_arg helpers from pxla into array.
Refactoring only which fixes a TODO.

Add a canonicalize argument to pxla.shard_arg so we can call that API from array yet  avoid double-canonicalization.

PiperOrigin-RevId: 549658117
2023-07-20 09:48:10 -07:00
jax authors
08366b21a1 Merge pull request #15679 from mattjj:issue15676
PiperOrigin-RevId: 549656910
2023-07-20 09:40:03 -07:00
Jake VanderPlas
65751bb328 make jvp(asarray, (1.,), (2.,)) produce Arrays
fixes #15676

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-20 09:21:55 -07:00
Peter Hawkins
fa5915b34d Delete pxla.make_sharded_device_array.
This function is unused and not exported from JAX.

PiperOrigin-RevId: 549606907
2023-07-20 05:58:35 -07:00
jax authors
1ceddfc98a Merge pull request #16710 from gnecula:poly_max0
PiperOrigin-RevId: 549515427
2023-07-19 21:40:17 -07:00
jax authors
e2a49ee297 Tweaks the utility function _get_ppspec_from_executable to get the shardings directly from the executable (instead of from its HLO modules).
PiperOrigin-RevId: 549473458
2023-07-19 17:38:59 -07:00
jax authors
c006e52f1a Merge pull request #16779 from jakevdp:random-gamma
PiperOrigin-RevId: 549460066
2023-07-19 16:39:05 -07:00
jax authors
54adc744df Merge pull request #16794 from lgeiger:py39-version-checks
PiperOrigin-RevId: 549457819
2023-07-19 16:29:18 -07:00
Jake VanderPlas
7205160095 Re-parameterize jax.random.gamma for better behavior at endpoints 2023-07-19 16:15:03 -07:00
jax authors
cd951f4917 Merge pull request #16793 from lgeiger:numpy-version-checks
PiperOrigin-RevId: 549454320
2023-07-19 16:13:55 -07:00
Lukas Geiger
6812d5c0ca Remove unneeded Python 3.9+ version checks 2023-07-19 23:37:30 +01:00
Lukas Geiger
de2c8541be Remove obsolete numpy version checks 2023-07-19 23:33:47 +01:00
jax authors
0c4c020716 Include compile time along with executable in cache entry.
In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit.

Testing: updated tests.
PiperOrigin-RevId: 549439628
2023-07-19 15:17:45 -07:00
jax authors
5ae3ac28cd Add deprecation of jax.stages.Compiled.compiler_ir to the change log
PiperOrigin-RevId: 549415191
2023-07-19 13:48:55 -07:00
Yash Katariya
cd2dc2f2fa Error if memory_kind is not correct for the devices in Shardings during initialization.
PiperOrigin-RevId: 549410478
2023-07-19 13:32:39 -07:00
Peter Hawkins
7df3477926 [JAX] Use MLIR argument locations instead of a bespoke jax.arg_info attribute.
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.

Example of new output:

```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
    return %0 : tensor<i32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```

Note debug information must be enabled.

PiperOrigin-RevId: 549325621
2023-07-19 08:39:16 -07:00
Yash Katariya
f94104f71a Skip PgleTest.testPassingFDOProfile if xla_extension_version < 169
PiperOrigin-RevId: 549322105
2023-07-19 08:22:32 -07:00
Peter Hawkins
cdb48134e5 [JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
2023-07-19 06:48:21 -07:00
George Necula
4fdc134543 [shape_poly] Add support for max0 for symbolic dimensions.
There are a few cases when JAX computes `max(v, 0)`, most
notably when computing the sizes of strided access,
dilated convolutions and padding, and for the size
of jnp.arange.

Until now these cases were supported
for shape polymorphism only when we can tell statically
that the size is >= 0. Here we add support to the
symbolic expressions for a `non_negative` operator,
which essentially implements `max(v, 0)` and with this
we can now support the general case for `jnp.arange`, with
simpler code.

We could add a general `max` operator, and we may do so in the
future, but for now `non_negative` suffices.

Note that this fixes a couple of bugs

  * for core.dilated_dim we had the code "if d == 0 then 0 else ..."
  but this works only if we can tell statically that `d == 0`, and
  it produced wrong results when `d` was symbolic and could take
  the value 0.
  * for core.stride_dim we did not handle correctly the case when
  `d < window_size`.

Handling the above fundamentally requires a `max(d, 0)` operation.
2023-07-19 16:15:04 +03:00
George Necula
e643f98558 [shape_poly] Reimplement the shape constraint checking using shape assertions.
Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality to xla_extension.refine_polymorphic_shapes
that handles custom calls @static_assertion.

As a beneficial side-effect now we get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
2023-07-19 09:56:33 +03:00
jax authors
f97dca79a2 Merge pull request #16752 from gnecula:bc_schur
PiperOrigin-RevId: 549219190
2023-07-18 23:48:32 -07:00
George Necula
f29544bfb2 [jax2tf] Add backwards compatibility for lax.linalg.schur on CPU 2023-07-19 09:39:50 +03:00
jax authors
3f2bff5182 Merge pull request #16751 from gnecula:bc_triangular
PiperOrigin-RevId: 549216345
2023-07-18 23:31:08 -07:00
George Necula
0dd45ddcdd [jax2tf] Add backwards compatibility test for lax.triangular_solve on CPU 2023-07-19 08:40:22 +03:00
jax authors
4b72163423 Merge pull request #16775 from froystig:random-api-policy
PiperOrigin-RevId: 549122781
2023-07-18 15:06:56 -07:00
Tao Wang
b7686f41aa Enable passing fdo_profile in compiler_options in pxla.py
PiperOrigin-RevId: 549109629
2023-07-18 14:18:28 -07:00
Roy Frostig
9150b239ff add jax.prng to uncovered modules list in API policy 2023-07-18 14:13:25 -07:00
Roy Frostig
9aa5307e2f API compatibility policy: expand on numerics and randomness 2023-07-18 14:13:25 -07:00
Yash Katariya
579808d986 Add memory_kind to NamedSharding, SingleDeviceSharding, PositionalSharding and GSPMDSharding.
PiperOrigin-RevId: 548997870
2023-07-18 07:39:36 -07:00
Peter Hawkins
59509dc2b3 Remove the jax_array config option, which does nothing.
PiperOrigin-RevId: 548981491
2023-07-18 06:16:06 -07:00
jax authors
8016fb3b66 Merge pull request #16769 from jakevdp:fix-jax-array
PiperOrigin-RevId: 548835682
2023-07-17 16:51:13 -07:00
Jake VanderPlas
74159132b6 support np.array(x) where x is a custom pytree with __jax_array__ 2023-07-17 13:33:17 -07:00
jax authors
68ea651ae4 Merge pull request #16740 from jakevdp:spdot-general-args
PiperOrigin-RevId: 548773744
2023-07-17 12:52:33 -07:00
jax authors
909df91b2b Merge pull request #16754 from patrick-kidger:patch-5
PiperOrigin-RevId: 548759179
2023-07-17 12:00:00 -07:00
jax authors
1b1e74f944 Merge pull request #16767 from jakevdp:mypy-fix
PiperOrigin-RevId: 548752096
2023-07-17 11:35:07 -07:00
Patrick Kidger
8bce54e5cb Add type annotation to jnp.tensordot
Just stopping pyright from complaining at me.
2023-07-17 11:30:16 -07:00
Jake VanderPlas
4bb54d32d8 mypy: suppress annotation-unchecked notes 2023-07-17 11:18:48 -07:00
Yash Katariya
f123ba9730 Make _parsed_pspec a kw_only argument. This should be backwards compatible since you pass an arg as a kwarg too.
PiperOrigin-RevId: 548741701
2023-07-17 10:59:48 -07:00
jax authors
d8bc033bd2 Merge pull request #16759 from patrick-kidger:patch-6
PiperOrigin-RevId: 548737104
2023-07-17 10:45:13 -07:00
Artem Belevich
3a7857130f Disable tests triggering a known bug in cuda-12.
PiperOrigin-RevId: 548727901
2023-07-17 10:26:12 -07:00
Artem Belevich
d49b67a73a Disable tests that trigger a known bug in cublasDtrsmBatched in cuda-12 on sm_60.
PiperOrigin-RevId: 548727690
2023-07-17 10:17:21 -07:00
Parker Schuh
e9873c4683 Change visibility for jax_export to public.
PiperOrigin-RevId: 548713622
2023-07-17 09:25:21 -07:00
George Necula
b1985db27b [jax2tf] Remove unnecessary test
The `jax2tf_test.XlaCallModuleTest` was added in the early days of native serialization. Now we have much better testing through `xla_call_module_test.py` (near the `XlaCallModule` definition in TF) and through all the native serialization tests in jax2tf test suite.

PiperOrigin-RevId: 548701049
2023-07-17 08:42:57 -07:00
George Necula
ac7246902d [jax2tf] Re-enable some tests for native serialization
The custom calls for CPU have been fixed already.
We also fix a bug that effectively disabled all shape
polymorphism tests on GPU.

PiperOrigin-RevId: 548700405
2023-07-17 08:34:06 -07:00
Patrick Kidger
997d60ed1a
Fix docstring for jax.debug.{print,callback} 2023-07-17 15:58:58 +01:00
Yash Katariya
f0ce0d8c6a Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
PiperOrigin-RevId: 548673905
2023-07-17 06:35:49 -07:00
George Necula
603eeb1901 Copybara import of the project:
--
06bf5fe7b2ac97156df541bab989dc5beb1aff0c by George Necula <gcnecula@gmail.com>:

[jax2tf] Added a flag and environment variable to control the serialization version.

This allows us to control the serialization version to be compatible with
the deployed version of tf.XlaCallModule. In particular, we can run
most tests with the maximum available version, while keeping the
default lower.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16746 from gnecula:tf_version 06bf5fe7b2ac97156df541bab989dc5beb1aff0c
PiperOrigin-RevId: 548504243
2023-07-16 09:27:12 -07:00
Alexey Radul
cd39128c09 Fix silly type error involving dims_out sometimes being a thunk and sometimes not.
PiperOrigin-RevId: 548343565
2023-07-15 05:13:43 -07:00
Jake VanderPlas
7986ba75c6 [sparse] support preferred_element_type in dot_general 2023-07-14 18:23:34 -07:00
Peter Hawkins
651f87733b Remove jax_jit_pjit_api_merge.
PiperOrigin-RevId: 548236671
2023-07-14 15:25:00 -07:00