9 Commits

Author SHA1 Message Date
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Matthew Johnson
28672970bb fix grad(..., argnums=-1), regressed in #10453 2022-05-11 11:19:22 -07:00
Jake VanderPlas
5d45458c7b api_util: make shaped_abstractify respect raise_to_shaped 2022-05-05 17:20:00 -07:00
Matthew Johnson
e7acb82b14 [remove-units] remove units from api_util.py 2022-04-26 12:31:08 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Peter Hawkins
6bda1e5dd8 [JAX] Require exact type equality using is for static arguments.
Fixes https://github.com/google/jax/issues/9273.

PiperOrigin-RevId: 424182826
2022-01-25 14:36:02 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00