715 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dan Foreman-Mackey
56d0c695c9 Condition tan lowering on jaxlib version rather than forward compatibility mode.
PiperOrigin-RevId: 676436269
2024-09-19 09:03:51 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Dan Foreman-Mackey
dbc03cf8e5 Re-land #23261 with appropriate compatibility checks.
PiperOrigin-RevId: 676092618
2024-09-18 12:40:53 -07:00
Dan Foreman-Mackey
69ba060957 Reverts e15ec1e8abe3732d747731c15a36facf4169739e
PiperOrigin-RevId: 675987338
2024-09-18 07:41:52 -07:00
jax authors
e15ec1e8ab Merge pull request #23261 from joaospinto:stablehlo.tan
PiperOrigin-RevId: 675973798
2024-09-18 06:56:28 -07:00
Joao Sousa-Pinto
3f2bc9b608 Lower tan to StableHLO instead of CHLO.
Fixes #23259
2024-09-13 08:50:40 -07:00
Peter Hawkins
9c86fdec02 Make optimization_barrier a public lax API. 2024-09-06 00:18:57 +00:00
Jake VanderPlas
7b41583414 refactor jax.lax to not depend on jax.numpy 2024-09-01 07:49:49 -07:00
Yash Katariya
164b884f33 Fix failing tests in CI
PiperOrigin-RevId: 669357019
2024-08-30 09:49:58 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we just support nary ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Peter Hawkins
ba5b081571 [numpy] Fix test failures under NumPy 2.0.
PiperOrigin-RevId: 664465687
2024-08-18 09:09:37 -07:00
jax authors
9785368c7f [Easy] Refactor ragged_dot transpose, combine ragged_to_dense
PiperOrigin-RevId: 663630185
2024-08-16 00:32:42 -07:00
Sergei Lebedev
c9142cbe75 Collapsed a few unnecessary `if TYPE_CHECKING` blocks 2024-08-12 13:08:55 +01:00
Roy Frostig
c54ffd41bc in dot docstring, format and link to dot_general 2024-08-11 12:44:50 -07:00
Sergei Lebedev
92b1f71314 Removed various ununsed functions
To rerun the analysis do

    python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
2024-08-01 11:18:19 +01:00
George Necula
ffd2b00516 Add concretization error check in core.min_dim and core.max_dim
Fixes: #22751
2024-08-01 07:27:35 +02:00
jax authors
aeff5b61a9 Merge pull request #22080 from vfdev-5:add-device-kwarg-linspace-array
PiperOrigin-RevId: 656467191
2024-07-26 11:18:24 -07:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
vfdev-5
76d61f9d8f Added device kwargs to jnp.linspace, jnp.array, jnp.asarray 2024-07-26 00:36:34 +02:00
jax authors
76b4c70c23 Merge pull request #22628 from hawkinsp:broadcast2
PiperOrigin-RevId: 655779730
2024-07-24 19:17:25 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Peter Hawkins
52fa165d75 Simplify promote_shapes.
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.

Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
2024-07-24 19:42:16 -04:00
Sergei Lebedev
969431f1fc Removed unused `_broadcast_translate` 2024-07-22 22:47:49 +01:00
jax authors
9632a2d1a8 Add jvp and transpose rule for ragged dot.
The numerical accuracy test is perfect against the reference implementation, and somewhat loose against the alt grad implementation used for testing.

PiperOrigin-RevId: 654381378
2024-07-20 17:56:59 -07:00
jax authors
ac4ca35221 Merge pull request #22263 from hawkinsp:tuples
PiperOrigin-RevId: 653267867
2024-07-17 09:56:18 -07:00
Roy Frostig
e8d9a54b1b extend type annotation for lax.convert_element_type
... to also accept extended dtypes (as defined internally).

PiperOrigin-RevId: 651372438
2024-07-11 05:27:11 -07:00
Yash Katariya
0426388d31 Add sharding to convert_element_type_p primitive.
There are 2 reasons for doing this:

* Avoid an extra allocation by putting the output on the correct sharding that the user specified. If you device_put the output of `_convert_element_type`, then you pay the cost of 2 transfers which is not ideal at all since this path would be critical (when users use `device`) and we should avoid doing extra transfers at all costs.

* This will allow us to streamline `device` arguments being added to all `jnp` functions as we will have one place (`_convert_element_type`) which will handle the logic of putting things on the right device.

Also fixes: https://github.com/google/jax/issues/17422

PiperOrigin-RevId: 650621659
2024-07-09 07:33:29 -07:00
Peter Hawkins
3d5784a343 Don't wrap singleton ir.Types during HLO lowering.
This is similar to https://github.com/google/jax/pull/22211, but for MLIR types instead of MLIR values.
2024-07-08 12:24:45 -04:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
Peter Hawkins
ac3cb6f954 Simplify mlir.dense_int_array.
The NumPy array conversion here is pointless and slightly slower than not doing it.

PiperOrigin-RevId: 647520922
2024-06-27 19:33:06 -07:00
jax authors
43dc4c1ff8 Fix the jax.lax.Precision documentation.
- Make it clear that this only impacts FP32 computations.
- Remove incorrect aliases, eg. 'bfloat16' for default. This does not do as advertised for GPU.
- explicitly specify GPU and TPU device-dependent behaviour.

PiperOrigin-RevId: 647342888
2024-06-27 09:22:07 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Yimei Sun
b37f51487d Remove the blocking for float16 dot on CPU platform to take advantage of CPU
platforms supporting float16 matmul computation for performance optimization.
With this PR change, JAX will allow dot float16 HLO being created. When the
HLO modules are processed during cpu compile stage in open xla, the
ChangeOpDataType pass will upcast the dot to float type if the CPU platform
does not support float16 computation, but for the platform supporting float16
computation, dot will stay as float16 type for execution.
2024-06-23 23:51:30 -07:00
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
jax authors
be1f4ba380 Merge pull request #21905 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 644068464
2024-06-17 11:04:28 -07:00
Sergei Lebedev
4913fff971 Rollback #21888, because it breaks multiple internal users
Reverts 193591b5c0b90ce498015b2e3d48950615253380

PiperOrigin-RevId: 643965549
2024-06-17 05:01:04 -07:00
rajasekharporeddy
b93da3873b Fix Typos 2024-06-17 13:55:46 +05:30
Jake VanderPlas
4f7cd03893 lax.mul: accept boolean inputs 2024-06-14 13:47:11 -07:00
Jake VanderPlas
6b8e2f3467 DOC: jax.lax.top_k: fix docstring rendering & add example 2024-06-10 13:57:21 -07:00
Yash Katariya
1273028018 Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes.
PiperOrigin-RevId: 639920049
2024-06-03 14:52:50 -07:00
Yash Katariya
0591620932 Fix copy.deepcopy support for arrays in pinned_host memory.
PiperOrigin-RevId: 639145872
2024-05-31 14:04:02 -07:00
George Necula
87b81fc768 [shape_polyO] Add support for jnp.tril. 2024-05-30 02:53:00 +03:00
Michael Levesque-Dion
43f51d73ce Clean up version switches from dense array migration
PiperOrigin-RevId: 637955865
2024-05-28 10:58:51 -07:00
jax authors
720d2b8708 Merge pull request #21376 from ROCm:ci_f8
PiperOrigin-RevId: 637884483
2024-05-28 06:56:26 -07:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Ruturaj4
41e4c25dc1 [ROCm] Add float8_e4m3fnuz and float8_e5m2fnuz support for Rocm 2024-05-22 05:50:28 +00:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
Elfie Guo
43d19161ac Remove type promotion for mixed fp8 matmuls. 2024-05-13 16:50:52 +00:00