55 Commits

Author SHA1 Message Date
George Necula
dd8ab85121 [jax2tf] Support inequality and min/max for booleans.
For inequalities we add casts to int8. For min/max we rewrite
to logical operations and/or.
2021-06-12 21:08:37 +03:00
Peter Hawkins
b130257ee1 Drop support for NumPy 1.16. 2021-06-11 09:03:09 -04:00
George Necula
1994f6df4a [jax2tf] Fix the round-trip call_tf(convert)
Also cleaned the handling of global state in jax2tf.
2021-06-11 11:57:27 +03:00
Skye Wanderman-Milne
063401f3ef Update jax version to 0.2.14 2021-06-10 13:15:53 -07:00
George Necula
59ae45a83c [jax2tf] Add support for generating HLO OpMetadata in the TF graph
The goal is to ensure that the HLO that
jax2tf->TF/XLA generates has the same metadata as what JAX generates.
This includes `op_type`, `op_name`, and source information, which are
used for debugging and profiling.

In order to ensure that this metadata is carried from the JAX tracing
time to TF/XLA, we save the metadata in custom TF op attributes. These
attributes are automatically preserved through SavedModel. This relies
on a separate change in TF/XLA to look for these custom attributes
and override its default.

For the source information, we use pretty much the same code that
xla.py uses. HLO OpMetadata has room for only one source location.
JAX (xla.py) picks the top-most user frame, which is obtained by
filtering out the stack frames in the JAX source tree. When used
with jax2tf we also need to filter out stack frames in the
TensorFlow source tree.

The hardest part is to generate the `op_name`, which is a hierarchical
name with components separated by '/', e.g., `jax2tf(top_func)/while/cond/le`.
We carry the current `name_stack` in thread-local state. Unfortunately, there
is no easy way to share the exact code that achieves this in xla.py. At the
same time it is not crucial that we have exactly identical name stacks as in
JAX.

I attempted to also carry this state in the JAX `MainTrace`, but could not
fully control the name stack. E.g., when calling a jitted-function we
have to reuse the current `MainTrace` although we want to push an element
on the name stack.

For now this option is not yet enabled until we make the necessary
changes in TensorFlow.
2021-06-09 08:08:42 +02:00
George Necula
d243258b86 [jax2tf] Implement inequalities and friends for complex numbers.
This requires re-using JAX's lowering rule for comparisons of
complex numbers to use lexicographic comparison.
2021-06-04 17:56:44 +03:00
jax authors
ecab743e5c Merge pull request #6877 from hawkinsp:tracebacks
PiperOrigin-RevId: 377247694
2021-06-03 02:47:21 -07:00
George Necula
d03d849a19 [jax2tf] Fix the 32/64-bit behavior to follow JAX rules
JAX and TensorFlow have different behavior w.r.t. 32-64 bit
computations. This PR cleans up the handling of types in jax2tf
to ensure that we follow the same behavior in jax2tf and in JAX.

This means that f_jax(args) always does the computation with the
same precision as jax2tf.convert(f_jax)(args). This may mean that
the result of the conversion depends on the value of JAX_ENABLE_x64.

See README.md for more details.
2021-06-03 10:12:58 +03:00
Peter Hawkins
2882286b50 Add a --jax_traceback_filtering flag to control the traceback filtering mode.
Add a new traceback filtering mode that uses __tracebackhide__, and use it in IPython.
2021-06-02 16:25:37 -04:00
jax authors
8e6101c6a1 Merge pull request #6866 from gnecula:tf_pjit
PiperOrigin-RevId: 376989780
2021-06-01 22:50:12 -07:00
George Necula
2ad9c0c34c [jax2tf] Fix the scoping of the enable_xla conversion parameter
Previously, the global enable_xla flag was set upon entry to
`jax.convert`. It should instead be set only for the duration
of the just-in-time conversion, which may happen later when
the converted function is invoked.
2021-05-21 11:22:21 +03:00
Peter Hawkins
f83e309fe7 Update changelog. 2021-05-12 09:46:17 -04:00
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
Qiao Zhang
528d5bbb11 Update README etc for jaxlib 0.1.66 release. 2021-05-11 16:49:32 -07:00
jax authors
c31943cfe5 Merge pull request #6622 from hawkinsp:eightr
PiperOrigin-RevId: 372035283
2021-05-04 18:17:56 -07:00
Peter Hawkins
97e89bde18 Add a tridiagonal eigh solver. 2021-05-04 20:43:41 -04:00
Skye Wanderman-Milne
c7485b7a19 Bump jax version and changelog for jax 0.1.13 release 2021-05-03 16:32:00 -07:00
George Necula
d762ec1d21 [host_callback] Minor fix to use the new xla_shape.is_token 2021-04-28 12:22:32 +03:00
Peter Hawkins
79a7f7bca8 Don't build CUDA 11.2 wheels.
Update XLA.

CUDA 11.1 wheels are compatible with CUDA versions 11.1+, since NVidia now promises enhanced version compatibility between CUDA minor releases starting with CUDA 11.1
2021-04-26 09:43:29 -04:00
Jake VanderPlas
bb543f2b5b jnp.unique: add support for axis argument 2021-04-21 16:00:14 -07:00
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
Jake VanderPlas
8d17cce80e Add JIT-compatible version of jnp.nonzero 2021-04-20 09:18:49 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Peter Hawkins
fb2824bdbb [JAX] Add static_argnames support to jax.jit.
Requires a new jaxlib build.

Add support for static_argnames in C++ JIT implementation.

PiperOrigin-RevId: 367627359
2021-04-09 07:11:04 -07:00
Peter Hawkins
9fad2441a2 Mark arguments to jax.jit() other than the function as keyword-only.
This change is to prevent breakage when options are added or removed.
2021-04-08 10:32:35 -04:00
Skye Wanderman-Milne
f8f373466c Update README, etc. for jaxlib 0.1.65 release 2021-04-07 17:51:20 -07:00
George Necula
dce31e9631 [jax2tf] Fix handling of float0 2021-04-07 13:48:39 +03:00
Skye Wanderman-Milne
7b42011f7c Update jax version and changelog 2021-04-01 10:11:52 -07:00
Jake VanderPlas
2a091d2629 Update changelog for #5868 2021-04-01 09:29:22 -07:00
Peter Hawkins
3fc1fdb148 Add a JVP rule for the general case of lax.reduce. 2021-03-30 17:31:47 -04:00
Jake VanderPlas
c11e725ecb X32 mode: raise OverflowError for large integers 2021-03-30 10:05:03 -07:00
George Necula
d323ad0f2b [host_callback] Add support for tapping empty arrays
We make sure that both the inputs and the outputs of
callbacks can contain empty arrays.
Most platforms do not support empty infeed, so we ensure
we do not send those.
2021-03-30 10:48:58 +03:00
Jake VanderPlas
9790232556 Python integer conversion: always return int64 or OverflowError 2021-03-29 09:26:19 -07:00
Jake VanderPlas
40dac9425c pre-release omnistaging cleanup 2021-03-25 16:44:58 -07:00
Skye Wanderman-Milne
b68a08adf1 Add programmatic profiling APIs, and rename some existing APIs.
This change provides aliases for the renamed APIs so existing code
won't break. We should remove these aliases after the next release.
2021-03-25 11:50:55 -07:00
Peter Hawkins
7052a87ab6 Increase minimum jaxlib version to 0.1.64. 2021-03-24 13:49:48 -04:00
Matthew Johnson
c4a099093c update version and changelog for pypi 2021-03-23 19:33:04 -07:00
Skye Wanderman-Milne
0cbe2c1c05 Update README, etc. for jaxlib 0.1.64 release 2021-03-18 16:11:40 -07:00
Skye Wanderman-Milne
757247b791 Update README, etc. for jaxlib 0.1.63 release 2021-03-17 10:14:52 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Skye Wanderman-Milne
f06bb9a7f4 Update jaxlib version etc. 2021-03-09 17:55:40 -08:00
Roy Frostig
9c420653c3 move changelog to top level 2021-03-08 10:44:52 -08:00
Jake VanderPlas
94484d85aa Migrate CHANGELOG.rst -> CHANGELOG.md 2021-02-12 17:03:53 -08:00
George Necula
89514f9278
Moved CHANGELOG to docs (#2252)
* Moved CHANGELOG to docs

This puts the documentation also on RTD, with TOC.
Also changed its format to .rst, for consistency.
Added GitHub links to the change log.

* Actually add the CHANGELOG.rst

* Added reminder comments to the CHANGELOG.rst
2020-02-23 19:18:06 +01:00
Alexander Botev
43ee917511
Adding broadcast_argnums to pmap for allowing similar behaviour t… (#1786)
* Adding `static_argnums` to `pmap` for similar behaviour to `static_argnums` of `jit`.

* Removed check for ShardedDeviceArray

* Final clean up and rename.
2020-02-14 07:45:26 -08:00
Matthew Johnson
9e6fe64a66 bump version and update changelog for pypi 2020-02-11 07:22:17 -08:00
George Necula
20f9230f6e Simplify Jaxpr: remove the bound_subjaxpr field, all subjaxprs are in params.
The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).

For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.
2020-02-11 10:06:08 +01:00
George Necula
b18a4d8583 Disabled tests known to fail on Mac, and optionally slow tests.
Issue: #2166

Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known
to be slow.
2020-02-05 18:02:56 +01:00
George Necula
272620e66c Added note to CHANGELOG.md 2020-02-04 10:22:54 +01:00