981 Commits

Author SHA1 Message Date
che-shr-cat
d2c6c06546
Fix DeviceArray class reference 2022-01-10 17:34:09 +03:00
che-shr-cat
78977d6f5a fix broken links and update texts in thinking_in_jax.ipynb 2022-01-10 16:19:57 +03:00
Roman Novak
b9b759d4ff
Merge branch 'main' into conv_local 2022-01-07 09:51:46 -08:00
Jake VanderPlas
eba2ed2fd6 Update sphinx-related packages 2022-01-04 14:16:57 -08:00
jax authors
2e60850192 Merge pull request #9058 from che-shr-cat:main
PiperOrigin-RevId: 418917696
2021-12-30 01:39:40 -08:00
Grigory Sapunov
504728d8b6 link directly to the documentation for the jnp.ndarray.at property 2021-12-29 12:29:16 +03:00
Jake VanderPlas
b889282f6d docs: add FAQ section about jit compilation & numerics 2021-12-28 08:57:51 -08:00
Grigory Sapunov
f93531b020 replace deprecated jax.ops.index_* functions with the new index update operators 2021-12-27 20:23:29 +03:00
Vlad Feinberg
cd333f0f5b Fix straight-through estimator example in docs (#9032) 2021-12-21 22:25:12 +00:00
Jake VanderPlas
1f7d6316c2 doc: move stub section to bottom of FAQ 2021-12-15 16:19:14 -08:00
Matthew Johnson
0c68605bf1 add jax.block_until_ready to docs and changelog
also unrelatedly fix a couple of the uses of rst in changelog.md (though
many others remain)
2021-12-14 13:39:47 -08:00
jax authors
404c3c7d25 Merge pull request #8718 from jakevdp:config-doc
PiperOrigin-RevId: 413630185
2021-12-02 03:14:31 -08:00
jax authors
800aac8fd3 Merge pull request #8681 from jakevdp:numpy-faq
PiperOrigin-RevId: 413316336
2021-11-30 21:33:37 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Jake VanderPlas
0e4e30f4e5 DOC: add documentation for configuration functionality 2021-11-29 10:44:54 -08:00
Jake VanderPlas
4a72e57ce0 DOC: add FAQ section on JAX vs. Numpy performance 2021-11-24 12:04:02 -08:00
Matthew Johnson
8430deda3e custom pp_eqn rules, simpler xla_call print 2021-11-23 15:52:52 -08:00
Peter Hawkins
f3aa5fa92f Document lax.GatherScatterMode.
Recommend the .at[...] property in the docstrings for lax.scatter_ operators.

Add several missing lax.scatter_ operators to the index.
2021-11-22 15:43:02 -05:00
jax authors
f08a5a07a8 Merge pull request #8552 from mattjj:elide-more-convert-element-types
PiperOrigin-RevId: 411082070
2021-11-19 09:44:30 -08:00
Matthew Johnson
abbf78b5c3 generalize jaxpr simplification machinery
also:
* fix jit invariance bug around weak types
* elide trivial broadcasts

This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.

But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](b53a174042/jax/interpreters/partial_eval.py (L1225))
was not paying attention to weak types and hence clobbered them.

In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.

These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
2021-11-19 09:00:59 -08:00
Peter Hawkins
58199b4b9a Delete the XLA in Python notebook.
Its tests are failing, and it describes a non-public API that we are phasing out.
2021-11-18 09:45:06 -05:00
Tianjian Lu
c5f73b3d8e [JAX] Added jax.lax.linalg.qdwh.
PiperOrigin-RevId: 406453671
2021-10-29 14:45:06 -07:00
Peter Hawkins
9ea55468ab [JAX] Update users of jax.ops.index... functions, which are deprecated.
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.

PiperOrigin-RevId: 406162068
2021-10-28 09:54:26 -07:00
Jake VanderPlas
cfe0156f5c readthedocs: use new build configuration & update to Python 3.9 2021-10-27 20:44:25 -07:00
Jake VanderPlas
bae93ed9b1 DOC: pin jupyter-core to fix RTD build 2021-10-27 09:02:24 -07:00
Yash Katariya
821fcaa750 Make the pjit docs clear about who does local and global communication
PiperOrigin-RevId: 405421833
2021-10-25 09:37:15 -07:00
jax authors
d9ae5a1696 Merge pull request #8298 from jakevdp:lax-doc
PiperOrigin-RevId: 404618408
2021-10-20 12:53:23 -07:00
jax authors
09c2c9a24b [JAX] Export ann documentation.
PiperOrigin-RevId: 404615254
2021-10-20 12:39:42 -07:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Jake VanderPlas
94169b96a8 DOC: add conv_dimension_numbers and ConvGeneralDilatedDimensionNumbers to docs 2021-10-19 17:18:15 -07:00
Peter Hawkins
e783cbcb72 Port remaining translation rules inside JAX to new style.
PiperOrigin-RevId: 404288551
2021-10-19 09:48:37 -07:00
Peter Hawkins
1a73743610 Move xla_bridge.constant to jax.interpreter.xla.pyval_to_ir_constant.
This is a more descriptive name and a better location (next to other facilities for building XLA IR).

Quite a few users of the former xla_bridge.constant() didn't need anything other than uncanonicalized array constants. Change these users to use xla_client.ops.Constant instead; no need for the fancy utility in these cases.

PiperOrigin-RevId: 404270649
2021-10-19 08:40:51 -07:00
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Peter Hawkins
95f47074da Remove xla_bridge.{constant, register_constant_handler, _python_scalar_constant} from API.
An upcoming change will move and rename these functions, and it's not clear they should have been public in the first place.

PiperOrigin-RevId: 404051961
2021-10-18 13:56:58 -07:00
Peter Hawkins
714e19a794 Remove xla_bridge.make_computation_builder().
This is a vestigal wrapper around xla_client.XlaBuilder whose purpose is long gone.

Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
2021-10-18 13:20:34 -04:00
Basil
6fac5d36a1 Remove unused variable 2021-10-13 14:38:33 +01:00
Peter Hawkins
e81f5cdf39 Remove documentation for --xla_hlo_profile.
It doesn't work on CPU and GPU, and it's probably hurting more than it's helping to have it documented.
2021-10-06 09:43:32 -04:00
Peter Hawkins
104a46594b Add DeprecationWarnings to jax.ops.index_... operators.
Remove uses of index_... in Common Gotchas notebook.
2021-10-05 20:47:22 -04:00
Jake VanderPlas
198cc5ee4f chore: update jupytext to v0.1.13 & re-sync notebooks 2021-10-05 14:30:16 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Jake VanderPlas
c35b2f2485 DOC: move index update API docs to jnp.ndarray.at
- Add docstring to abstract  property
- Add explicit HTML documentation of this property
- Mark index update functions as deprecated, linking to this documentation
2021-10-01 14:06:08 -07:00
Peter Hawkins
d4023508a4 Uniquify variable names globally within a jaxpr.
It is confusing when the same name is shadowed within an inner lambda expression. Use globally unique variable names in each pretty-printed jaxpr.
2021-10-01 12:49:47 -04:00
jax authors
ef696a0b43 Merge pull request #8019 from hawkinsp:pprint
PiperOrigin-RevId: 399424971
2021-09-28 06:26:24 -07:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Jake VanderPlas
ff2bfc0e87 DOC: fix docstring and add docs. 2021-09-27 09:48:27 -07:00
Peter Hawkins
2a6f836d30 Fix rendering of lax.gather docs.
Remove spurious extra argument to `get()` in documentation.
2021-09-24 09:59:10 -04:00
jax authors
45ce1d3489 Merge pull request #7595 from hawkinsp:gather
PiperOrigin-RevId: 398623114
2021-09-23 18:18:41 -07:00
Skye Wanderman-Milne
2fcf3f7270 Remove .[minimum-jaxlib] from test-requirements.txt
This means that jax and its dependencies (e.g. jaxlib) must be
manually installed before running the tests. This is useful for
testing an existing jax install, e.g. a later version of jaxlib, GPU
jaxlib, etc.
2021-09-23 12:24:24 -07:00
Peter Hawkins
867068821e Drop out-of-bounds indexes in gather. 2021-09-23 10:35:03 -04:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00