79 Commits

Author SHA1 Message Date
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Yash Katariya
1a62df1ac0 Rename sharding argument to out_sharding for lax.reshape, lax.broadcast_in_dim, lax.broadcast and lax.broadcasted_iota. .bind of these APIs still take sharding as a parameter though (but that's fine since it's internal and not public facing)
PiperOrigin-RevId: 726187934
2025-02-12 13:59:23 -08:00
Jake VanderPlas
7bacfbc658 refactor: move array creation routines out of lax_numpy.py 2025-02-06 15:47:30 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can atleast upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Jake VanderPlas
45a352041c internal: check integer overflow in lax.asarray 2025-01-17 14:38:13 -08:00
Jake VanderPlas
4c926c8d4c Add ensure_arraylike utility for lax.numpy implementations 2025-01-16 16:46:11 -08:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
Yash Katariya
39e4f7f2ce [sharding_in_types] Make jnp.where broadcast shardings properly when a scalar exists
PiperOrigin-RevId: 705283318
2024-12-11 16:41:18 -08:00
Jake VanderPlas
14030801a5 Remove obsolete implements() decorator & fix tests 2024-10-28 15:22:09 -07:00
Yash Katariya
57a95a77ff [sharding_in_types] Support jnp.array with sharding_in_types. When the input array has a sharding, propagate it through without dropping the sharding.
PiperOrigin-RevId: 687089357
2024-10-17 16:51:41 -07:00
Jake VanderPlas
36d6bb9013 Better docs for jnp.gradient
Also remove skip_params option from util.implements, as this was its last usage.
2024-09-30 13:07:52 -07:00
Jake VanderPlas
ad6c3a7f64 Improve docs for jnp.pad 2024-09-25 14:41:13 -07:00
Jake VanderPlas
8bd84913a6 Better docs for array, asarray, linspace
This allows removal of extra_params handling from util.implements
2024-08-09 13:34:50 -07:00
Peter Hawkins
52fa165d75 Simplify promote_shapes.
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.

Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
2024-07-24 19:42:16 -04:00
jax authors
f1cfd99fe8 Merge pull request #22625 from hawkinsp:broadcast
PiperOrigin-RevId: 655738756
2024-07-24 16:29:13 -07:00
Peter Hawkins
7527101672 Don't broadcast scalar conditions in the jnp.where implementation().
The underlying lax primitive is perfectly happy to accept scalar conditions with the other arguments being non-scalar.
2024-07-24 12:06:51 -04:00
Peter Hawkins
34ce9f21db Simplify implementation of _broadcast_to.
_broadcast_to needlessly squeezes away size 1 dimensions before passing its input to broadcast_in_dim. But broadcast_in_dim is perfectly happy to broadcast size 1 dimensions, so we don't need this squeeze.
2024-07-24 10:57:54 -04:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Jake VanderPlas
6de6983d59 jnp.broadcast_to: better error for invalid shape 2024-04-02 08:38:51 -07:00
Jake VanderPlas
43a9faa06a Rename _wraps to implements 2024-01-24 14:14:19 -08:00
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Jake VanderPlas
d77cd9a0f4 Add jax.numpy.astype function 2023-11-30 15:50:22 -08:00
Jake VanderPlas
ae662be5ef Fix typo in deprecation warning 2023-11-28 13:56:49 -08:00
Jake VanderPlas
96d9f89415 [random] better errors for unsupported operations on prng keys 2023-11-03 19:23:18 -07:00
Jake VanderPlas
1815bc7632 [typing] allow scalar shape for jnp.broadcast_to 2023-10-13 13:37:20 -07:00
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Jake VanderPlas
e8ebe462d2 Deprecate non-array inputs to several jax.numpy functions 2023-09-22 14:21:23 -07:00
Matthew Johnson
69ad4df9a5 fix pow_p jvp rule at x=0. y=0
fixes #14397

For autodiff purposes (and eventually for evaluation implementation purposes)
we need to distinguish between pow :: inexact -> int -> inexact (which is
differentiable at (0., 0)) and pow :: inexact -> inexact -> inexact (which
isn't); see https://github.com/google/jax/issues/14397#issuecomment-1426386290.

Instead of making a new primitive, we made the old one polymorphic and switch
its behavior on the element type of its second argument.

There were also some other cases with special handling for algorithmic reasons
(e.g. doing binary exponentiation), so these autodiff cases had to be merged
with those algorithmic cases.

Co-authored-by: Roy Frostig <frostig@google.com>
2023-07-28 17:14:47 -07:00
Jake Vanderplas
b4132b4c50 Copybara import of the project:
--
b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>:

Rename opaque dtype to extended dtype.

This includes three deprecations:
 - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended)
 - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended)
 - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype
Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be
subject to a standard 3-month deprecation period.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b
PiperOrigin-RevId: 550674205
2023-07-24 14:38:20 -07:00
jax authors
1b33a4eb05 Merge pull request #16815 from hawkinsp:py39
PiperOrigin-RevId: 550014612
2023-07-21 12:12:47 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
2ffa9bd8df Refactor opaque dtype implementation.
This makes it closer to numpy, with dtypes.OpaqueDtype analogous to np.dtype,
and dtypes.opaque analogous to np.numeric. This will let us replace the
dtypes.is_opaque_dtype function with jnp.issubdtype(dtype, dtypes.opaque).
2023-07-20 19:51:52 -07:00
George Necula
9261edaf94 [shape_poly] Cleanups for the shape polymorphism APIs.
Shape polymorphism relies on a number of functions defined
in core.py. Overtime we have accumulated some duplicate functionality
in those functions. Here we do some cleanups:

  * remove symbolic_equal_dim and symbolic_equal_shape in favor of the
    newer definitely_equal and definitely_equal_shape
  * remove is_special_dim_size, which checks that a value is a
    dimension expression (not a constant). Some uses are replaced
    with `not is_constant_dim` and others with `is_dim`.
  * introduce concrete_dim_or_error to check that a value is
    a dimension
2023-06-30 15:56:57 +03:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Alexey Radul
703aa8f025 Propagate more definitely_equal checks to fix more cases of broadcast_to. 2023-06-13 17:04:43 -04:00
Jake VanderPlas
6374a77176 KeyArray: remove _stackable registration mechanism 2023-04-24 15:06:22 -07:00
Jake VanderPlas
a5737f82af custom prng: remove stackable override for jnp.concatenate 2023-04-24 12:26:58 -07:00
Jake VanderPlas
8f72454bdf Add internal jax.lax.asarray utility 2023-03-30 10:21:55 -07:00
Jake VanderPlas
b308312986 jnp.arange: better validation of inputs 2023-03-14 16:41:58 -07:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Peter Hawkins
a4412e2715 Remove internal ndarray type name. Use Array throughout.
jax.numpy.ndarray remains an exported alias for jax.Array.

PiperOrigin-RevId: 513046188
2023-02-28 14:51:08 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Jake VanderPlas
8562e76e94 jax.numpy docstrings: remove empty parameters section 2022-10-28 14:13:29 -07:00
Jake VanderPlas
2f27d516d7 [typing] annotate next part of lax_numpy.py 2022-10-25 12:36:26 -07:00
Jake VanderPlas
48e680c839 CI: avoid raising error when wrapped function is None 2022-10-24 08:57:53 -07:00
Jake VanderPlas
069866e07a Add types to jax/_src/numpy/util.py 2022-10-04 10:07:38 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
d20dcf4b50 Merge pull request #11857 from jakevdp:fix-bool-args
PiperOrigin-RevId: 467259299
2022-08-12 11:45:35 -07:00
Jake VanderPlas
3f06195994 jax.numpy: improve support for boolean inputs 2022-08-12 09:51:25 -07:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00