549 Commits

Author SHA1 Message Date
George Necula
c70de6deed [better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
2025-02-02 06:23:03 +02:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Jake VanderPlas
cfa38841ce Re-deprecate a number of symbols from jax.core 2024-12-16 13:04:06 -08:00
Jake VanderPlas
d3406768f0 temporarily un-deprecate several jax.core APIs.
These were causing excessive log-spam for some users; I'll work to migrate
them to jax.extend before re-deprecating these.
2024-12-12 13:15:58 -08:00
Fiona Lang
3f58337bbc Fix jax.core deprecation warnings for jax.extend.core.primitives symbols.
PiperOrigin-RevId: 705546724
2024-12-12 10:16:37 -08:00
jax authors
5e887b446b Merge pull request #25414 from jakevdp:finalize-deps
PiperOrigin-RevId: 705197214
2024-12-11 12:24:13 -08:00
Jake VanderPlas
0fe97f97c7 jax.core: remove private API
PiperOrigin-RevId: 705155279
2024-12-11 10:31:34 -08:00
Jake VanderPlas
f858a71461 Finalize some deprecations in jax.core, jax.lib.xla_bridge, and jax.lib.xla_client. 2024-12-11 09:50:33 -08:00
Jake VanderPlas
59b9eefd06 jax.core: more API deprecations 2024-12-10 20:27:28 -08:00
Jake VanderPlas
6541a62099 jax.core: deprecate a number of APIs 2024-12-10 11:11:32 -08:00
George Necula
c92507772c Cleanup more remnants of the jax.experimental.host_callback
Removes the outfeed rewriter mechanism and helper functions
`jaxpr_uses_outfeed`, which were used only by
`jax.experimental.host_callback`.
2024-11-12 03:27:10 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Jake VanderPlas
2b9c73d10d Remove a number of expired deprecations.
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
2024-10-31 15:40:54 -07:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Jake VanderPlas
e05c37c667 Finalize deprecation of pretty-printing utils in jax.core.pp_*
PiperOrigin-RevId: 678775782
2024-09-25 11:20:35 -07:00
Dougal Maclaurin
d2ac88c193 Expose some APIs for querying trace state. This will let us move users away from
depending on our internals. Prep work for "stackless".

PiperOrigin-RevId: 678288660
2024-09-24 09:48:41 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dougal Maclaurin
018189491b Clean up and fix primal type to tangent type mapping
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 676115753
2024-09-18 13:43:54 -07:00
Jake VanderPlas
bd9698ec6d Deprecate several internal utilities in jax.core 2024-08-14 10:06:13 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Jake VanderPlas
fbcb157ad3 Finalize deprecation of several previously-deprecated jax.core functions:
- `jax.core.canonicalize_shape`
- `jax.core.dimension_as_value`
- `jax.core.definitely_equal`
- `jax.core.symbolic_equal_dim`

These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024.

PiperOrigin-RevId: 647671122
2024-06-28 07:28:28 -07:00
Jake VanderPlas
f63b94574a Deprecate internal pretty-printing APIs, jax.core.pp_* 2024-06-13 09:44:56 -07:00
Jake VanderPlas
bb5787da09 Finalize deprecations of several APIs
PiperOrigin-RevId: 633634215
2024-05-14 10:40:40 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
Jake VanderPlas
9b9aa1efaf Finalize a number of deprecations from JAX 0.4.19
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
George Necula
cea77f5d17 Improve some deprecation error messages 2024-01-07 07:09:39 +02:00
Jake VanderPlas
adefbca642 jax.core: deprecate several private APIs 2023-12-15 13:37:09 -08:00
George Necula
0a02d83015 [shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0
Add core.max_dim and core.min_dim as nicer wrappers around the
core.non_negative_dim. Also improve the completeness of the
heuristics for deciding >= 0, and add more tests.
2023-12-07 09:41:47 +01:00
Jake VanderPlas
2edb66de8a jax.core: point deprecation to jax.extend 2023-10-13 12:49:05 -07:00
Jake VanderPlas
e0944c938f jax.core: deprecate some inadvertent exports 2023-10-11 15:22:19 -07:00
Jake Vanderplas
d8f799391b COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Jake VanderPlas
3b6b988473 fix deprecations in core.py 2023-07-25 09:47:04 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
George Necula
aa85cd9a31 [shape_poly] Cleanup exported API symbols.
We remove the API functions related to shape polymorphism from the public API exported in jax.core. I could not remove a few API entry points because they
are referenced in Google. Will cleanup those uses next.

PiperOrigin-RevId: 545349000
2023-07-03 23:10:07 -07:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Peter Hawkins
eaf7eb2626 Break cycle between _src/core.py and _src/dtypes.py.
PiperOrigin-RevId: 532788430
2023-05-17 07:58:59 -07:00
jax authors
59f33a4338 Expose JaxprDebugInfo so others can use it for pytyping.
PiperOrigin-RevId: 525749186
2023-04-20 08:09:26 -07:00
Peter Hawkins
1d4b7a3701 Hide accidental exports from jax.core.
PiperOrigin-RevId: 511350939
2023-02-21 17:48:40 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Roy Frostig
6b4de4f91c remove several more symbols from jax.core
* `DBIdx`
* `DConcreteArray`
* `DimensionHandler`
* `DuplicateAxisNameError`

PiperOrigin-RevId: 510503517
2023-02-17 13:07:00 -08:00
Roy Frostig
e276859d11 remove several symbols from jax.core
* `ClosedCallPrimitive`
* `CustomPpEqnRule`
* `DArray`
* `DArrayDimHandler`

PiperOrigin-RevId: 510343926
2023-02-16 22:55:16 -08:00
Matthew Johnson
ec1e513659 remove accidental re-export of __future__.annotations from jax/core.py
PiperOrigin-RevId: 510233347
2023-02-16 13:47:28 -08:00
Roy Frostig
591e2c8937 remove some exports from jax.core
Namely:
* `AvalMapHandlerPair`
* `AxisEnvFrame`
* `AxisName`
* `AxisPrimitive`
* `AxisSubst`
PiperOrigin-RevId: 510224417
2023-02-16 13:12:35 -08:00