22 Commits

Author SHA1 Message Date
Peter Hawkins
a4412e2715 Remove internal ndarray type name. Use Array throughout.
jax.numpy.ndarray remains an exported alias for jax.Array.

PiperOrigin-RevId: 513046188
2023-02-28 14:51:08 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Jake VanderPlas
dafb88a649 jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
Jake VanderPlas
58323d5b40 jax.numpy reductions: better validation of initial value 2023-02-13 08:43:25 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
b037feb105 [x64] more type safety for lax_numpy-related tests 2022-12-01 11:18:02 -08:00
Jake VanderPlas
709ffd7e77 [typing] annotate jax.numpy reduction operations 2022-10-26 13:33:15 -07:00
Jake VanderPlas
32ef3ba37b jnp.average: support tuple axis 2022-10-06 10:20:46 -07:00
Jake VanderPlas
3be2087424 jnp.prod & jnp.sum: promote to default integer type rather than int64/uint64 2022-10-04 10:08:30 -07:00
Jake VanderPlas
1860f6d839 [x64] add promote_integers argument to jnp.prod & jnp.sum 2022-09-26 13:31:43 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
Jake VanderPlas
df800f39d3 jnp.average: support keepdims argument 2022-06-28 10:55:55 -07:00
Pavel Sountsov
ff637e12f1 Allow doing reductions on empty arrays in some cases.
Namely, when the reduction axis is not over the zero-sized dimension.
2022-06-14 21:57:56 +00:00
Jake VanderPlas
010e490128 [x64] make jax.numpy reductions respect input dtypes
Also make then compatible with strict dtype promotion mode.
2022-06-01 16:24:36 -07:00
Jake VanderPlas
1c555dc956 [x64] make jnp.average compatible with strict promotion 2022-06-01 14:25:35 -07:00
Jake VanderPlas
97a80ecb1d [x64] jax.numpy reductions: avoid binary promotion for upcast_bf16 2022-05-27 11:08:47 -07:00
Jake VanderPlas
9ab42ed2c6 [x64] handle strict promotion for jnp.var 2022-05-27 09:27:57 -07:00
Sergei Lebedev
c5d3ece6f5 MAINT Fixed new mypy errors
mypy seems to handle lambdas and named functions differently. So, I had to
promote a few helpers to named functions to get them to type check.
2022-05-23 20:21:00 +01:00
Jake VanderPlas
0e17c5acea maint: remove numpy reductions dependence on ufuncs 2022-03-23 15:56:59 -07:00
Jake VanderPlas
121d8d6320 Factor-out reductions from lax_numpy.py 2022-03-18 11:47:22 -07:00