1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-25 23:06:06 +00:00

17 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
45b71e5d83 ensure_index: raise better error for traced inputs 2022-09-12 12:28:14 -07:00
Neil Girdhar
ad38a6bb28 Fix common typo: Tuple[X] -> Tuple[X, ...] 2022-08-16 11:47:22 -04:00
Matthew Johnson
81b6263ed0 Rolling forward after test failures caused roll-back (from use of np.empty).
PiperOrigin-RevId: 465712458
2022-08-05 22:19:33 -07:00
jax authors
6b0c0dc321 Internal change
PiperOrigin-RevId: 465705931
2022-08-05 21:08:43 -07:00
Matthew Johnson
348da51dc6 prototype unfettered element types in jaxpr arrays
From where comes the set of element types in jaxprs? Historically, from NumPy
and XLA element types. But why would jaxprs be constrained to those? After all,
jaxprs are just symbols, my friends. Those symbols need to be grounded when we
translate to another compiler's IR, or when we have input or output values with
a jaxpr evaluation. So if we're lowering we need ways to map jaxpr types to
lowered IR types, and also ways to map any operations allowed on these types to
lowered IR operations. And we may want Python objects representing values of
these types. But once we have those mappings we don't need to be limited by
NumPy/XLA element types.

Within jaxprs, we also need to handle transformations with these types.

In this change we started unfettering jaxpr element types from their vestigial
NumPy/XLA constraints. Concretely, that means:
  * allowing ShapedArray to have any object for its 'dtype' attribute
  * added core.custom_eltype set
  * extended existing handlers for ShapedArray to call the corresponding custom
    element type handlers
  * mlir lowerings of some fully-element-type-polymorphic primitives
  * tests

In this PR, we only actually use these new extension points in tests.

The applications to come that we have in mind are:
  * arrays of prngkeys (and even custom prngs, as well as reuse error checking)
  * arrays of bounded int type for dynamic shapes (and especially raggedness)
  * float0 arrays
We do *not* have in mind opening these mechanisms up to users. Think of these
as yet another JAX-internal extension point, like all our existing 'handler'
tables.

Jargon-wise, we may want to distinguish:
  * 'eltype' meaning jaxpr element types
  * 'dtype' meaning numpy dtypes (an existing convention)
  * 'etype' meaning hlo/mhlo element types (an existing convention)
But the code doesn't model this jargon at the moment, since we left a lot of
attributes and helper functions referring to 'dtype'.

We haven't yet handled all the element-type-polymorphic primitives. Here's the
list we've thought of so far:
  * [x] broadcast
  * [ ] reshape
  * [x] transpose
  * [ ] pad
  * [x] slice, dynamic_slice, dynamic_update_slice
  * [ ] concatenate
  * [ ] all_to_all, gather, scatter, all_gather, collective_permute
  * [x] make empty scalar (only appears in internal-about-to-lower-jaxpr dialect)
That last one is interesting: we introduced it so that the scan lowering rule,
which lowers first to a "lowered jaxpr dialect" involving only those eltypes
which correspond to etypes and involving only while_loop, ds/dus, etc, can be
made simpler. Otherwise we'd need scan, itself a fully-eltype-polymorphic
primitive, have a more complicated lowering rule.

We also haven't handled AD. Our main applications (at least the first two
listed above) don't involve AD types, so it seemed good to skip for now.

Co-authored-by: Roy Frostig <frostig@google.com>
2022-08-05 19:23:55 -07:00
Matthew Johnson
148173630f add an optional fastpath for api_util.shaped_abstractify
also add a benchmark for it, 8.7ms -> 0.2ms on my machine

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-07-27 15:14:37 -07:00
Jeppe Klitgaard
838a05329d feat: validate jit args 2022-05-18 21:54:47 +01:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Matthew Johnson
28672970bb fix grad(..., argnums=-1), regressed in 2022-05-11 11:19:22 -07:00
Jake VanderPlas
5d45458c7b api_util: make shaped_abstractify respect raise_to_shaped 2022-05-05 17:20:00 -07:00
Matthew Johnson
e7acb82b14 [remove-units] remove units from api_util.py 2022-04-26 12:31:08 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Peter Hawkins
6bda1e5dd8 [JAX] Require exact type equality using is for static arguments.
Fixes https://github.com/google/jax/issues/9273.

PiperOrigin-RevId: 424182826
2022-01-25 14:36:02 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00