Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
a
2022-05-17 22:14:05 +01:00
Matthew Johnson
28672970bb
fix grad(..., argnums=-1), regressed in #10453
2022-05-11 11:19:22 -07:00
Jake VanderPlas
5d45458c7b
api_util: make shaped_abstractify respect raise_to_shaped
2022-05-05 17:20:00 -07:00
Matthew Johnson
e7acb82b14
[remove-units] remove units from api_util.py
2022-04-26 12:31:08 -07:00
Jake VanderPlas
df1ceaeeb1
Deprecate jax.tree_util.tree_multimap
2022-04-01 14:51:54 -07:00
Peter Hawkins
6bda1e5dd8
[JAX] Require exact type equality using is
for static arguments.
...
Fixes https://github.com/google/jax/issues/9273 .
PiperOrigin-RevId: 424182826
2022-01-25 14:36:02 -08:00
Peter Hawkins
4e21922055
Use imports relative to the jax
package consistently, rather than .
-relative imports.
...
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.
PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
a11d957e61
Disallow non-hashable static arguments in pmap().
...
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00
Peter Hawkins
e869e5e0f8
Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
...
One of many changes to codify the set of exported symbols in the jax.* namespace.
PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00