George Necula
bf97e47929
Make infeed_test and host_callback_test independent. ( #3676 )
...
* Make infeed_test and host_callback_test independent.
* the infeed_test will stop the outfeed receiver
* Remove the use of --dist=loadfile.
* Prevent logging on exit
2020-07-07 11:03:30 +03:00
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. ( #3644 )
...
* Refactored host_callback to use the C++ runtime.
* The new runtime makes it unnecessary to start the outfeed_receiver
in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
I am trying to solve this by (a) making host_callback_test stop the
outfeed receiver on finish and infeed_test on start, and (b)
telling pytest-xdist to run all the tests from one file into
a single worker.
2020-07-04 18:12:58 +03:00
Matthew Johnson
75278309aa
refactor call primitives, simpler param processing ( #3491 )
2020-06-23 09:39:45 -07:00
Neil
046006e047
Fix typo: np.bool -> np.bool_ ( #3525 )
...
Replaced np.bool (which is just bool) with np.bool_, which is numpy's
Boolean type.
2020-06-22 19:43:25 -07:00
Peter Hawkins
3290e16a9a
Attach source info to Jaxpr equations. ( #3421 )
...
* Attach source info to Jaxpr equations.
Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
...: z = jax.numpy.cos(x)
...: z = z * jax.numpy.tanh(y)
...: return z + 2
...:
In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda ; a b.
let c = cos a [<ipython-input-2-5d59f71cb65d>:2 (f)]
d = tanh b [<ipython-input-2-5d59f71cb65d>:3 (f)]
e = mul c d [<ipython-input-2-5d59f71cb65d>:3 (f)]
f = add e 2.0 [<ipython-input-2-5d59f71cb65d>:4 (f)]
g = mul 1.0 d [<ipython-input-2-5d59f71cb65d>:3 (f)]
h = neg g [<ipython-input-2-5d59f71cb65d>:2 (f)]
i = sin a [<ipython-input-2-5d59f71cb65d>:2 (f)]
j = mul h i [<ipython-input-2-5d59f71cb65d>:2 (f)]
in (f, j) }
In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15
ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
%constant.3 = pred[] constant(false)
%parameter.1 = f32[] parameter(0)
%cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%parameter.2 = f32[] parameter(1)
%tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
%constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
%negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
%multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00
Adam Paszke
e36c72b983
Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) ( #3222 )
2020-06-08 17:50:14 +02:00
Jake Vanderplas
fb1717233a
Cleanup: deflake jax.experimental and jax.ops ( #3329 )
2020-06-05 19:00:04 -07:00
Adam Paszke
c5d870738c
Fix host_callback
2020-06-05 15:51:30 +00:00
Roy Frostig
dc4c9f0450
change cond primitive to an indexed conditional with multiple branch functions
...
in the core:
* bind and check cond primitive in indexed form
* rewrite abstract evaluation rule
* rewrite translation rule
* rewrite partial evaluation rule
* rewrite batching rule
* rewrite JVP rule
* rewrite transpose rule
* update jaxpr typechecker
* update pretty printer
* update outfeed-usage check
* update reference jaxpr in cond jaxpr test
* update reference regexes in HLO test
in experimental modules:
* update host_callback rewriter
* update loops expression builder
* generalize tf_impl rule
2020-06-03 22:19:15 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace ( #3287 )
2020-06-02 17:37:20 -07:00
Peter Hawkins
6ddac1d542
Disabled host_callback infrastructure for the HLO interpreter backend, which doesn't support infeed/outfeed. ( #3294 )
2020-06-02 13:16:13 -04:00
Peter Hawkins
34065df248
Add some type annotations to core and partial_eval. ( #3251 )
2020-06-01 21:45:36 -04:00
Roy Frostig
c5010cda47
use new gensym in host_callback jaxpr rewriter
2020-05-27 12:03:34 -07:00
James Bradbury
0eace80a6e
Fix experimental host callback on multi-host ( #3200 )
...
* Fix experimental host callback on multi-host
Hosts can only access the outfeed queue for local devices, while `api.devices` returns all devices in the system.
* Update host_callback.py
2020-05-25 08:12:58 +03:00
George Necula
afadb12b64
Improved tapping support for while: tap inside cond, vmap of while ( #3195 )
...
* Improved tapping support for while: tap inside cond, vmap of while
* Fix float64->float32 in tests
2020-05-24 10:50:07 +03:00
George Necula
b493a7e5df
Fix the handling of repeated vmap for id_tap ( #3132 )
...
* Fix the handling of repeated vmap for id_tap
* Updated the transforms to always be a tuple of tuples
* Changed the transforms to be dictionaries
2020-05-23 13:49:27 +03:00
Matthew Johnson
850f1afd95
improve errors for complex derivs, fixes #3121 ( #3149 )
2020-05-19 15:17:03 -07:00
Stephan Hoyer
77e31323f7
Fix indentation for docstrings in jax.experimental.host_callback ( #3119 )
2020-05-19 18:23:45 +03:00
Roy Frostig
cb5b1d10d9
handle single-operand cond in token-threading rewriter
2020-05-14 09:04:48 -07:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. ( #3046 )
...
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.
* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
George Necula
c171c33b1c
Update numpy references to use np. Added to Changelog ( #3029 )
2020-05-10 19:54:46 +03:00
George Necula
c375adf52a
Implementation of id_tap/id_print using outfeed. ( #3006 )
...
This was already merged as #2791 but reverted due to XLA crashes.
This reverts commit 769d703b7ac1011babef6289382f1a14d7aafc42.
2020-05-08 17:18:11 +03:00
George Necula
769d703b7a
Undo the id_print/id_tap feature (PR #2791 )
...
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00
George Necula
9f0795b8f1
Unified the eager and jit paths
...
Added error checking for outfeed_receiver not started to primitive computations
2020-05-07 16:24:13 +03:00
George Necula
d8b75e1913
Reimplemented the passing of tokens with a Jaxpr transform
2020-05-07 16:24:13 +03:00
George Necula
8fc96910c2
Improved documentation
2020-05-07 16:24:13 +03:00
George Necula
304009d772
Added error checking when starting compiled computations without starting
...
the outfeed receiver.
2020-05-07 16:24:13 +03:00
George Necula
cee5989a1d
Implemented pytree support for arg and result.
...
Enabled outfeed for all arrays as a tuple
2020-05-07 16:24:13 +03:00
George Necula
0515f8e42b
Added error handling for tap function errors
2020-05-07 16:24:13 +03:00
George Necula
0444f92795
Added support for sending all arrays in a single message
2020-05-07 16:24:13 +03:00
George Necula
653cad6302
Added support for multiple backends to outfeed receiver
...
Changed the encoding of the header to be uin32
2020-05-07 16:24:13 +03:00
George Necula
47cb5eaa86
Added masking transformation, added batch_dims to vmap
2020-05-07 16:24:13 +03:00
George Necula
931cb3f684
Ensure that we carry state only for control-flow conditionals that use print
2020-05-07 16:24:13 +03:00
George Necula
a16584d280
Fixed scan, and grad. Added multiplexing protocol.
2020-05-07 16:24:13 +03:00
George Necula
de685c9d5a
An experiment for id_print implemented with outfeed
...
* Added print descriptors, support multiple types
* Added a state-passing mechanism to XLA interpreter
2020-05-07 16:24:13 +03:00