135 Commits

Author SHA1 Message Date
George Necula
bf97e47929
Make infeed_test and host_callback_test independent. (#3676)
* Make infeed_test and host_callback_test independent.

* the infeed_test will stop the outfeed receiver
* Remove the use of --dist=loadfile.
* Prevent logging on exit
2020-07-07 11:03:30 +03:00
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test on start, and (b)
  telling pytest-xdist to run all the tests from one file into
  a single worker.
2020-07-04 18:12:58 +03:00
Matthew Johnson
75278309aa
refactor call primitives, simpler param processing (#3491) 2020-06-23 09:39:45 -07:00
Neil
046006e047
Fix typo: np.bool -> np.bool_ (#3525)
Replaced np.bool (which is just bool) with np.bool_, which is numpy's
Boolean type.
2020-06-22 19:43:25 -07:00
Peter Hawkins
3290e16a9a
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.

Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
   ...:    z = jax.numpy.cos(x)
   ...:    z = z * jax.numpy.tanh(y)
   ...:    return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00
Adam Paszke
e36c72b983
Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) (#3222) 2020-06-08 17:50:14 +02:00
Jake Vanderplas
fb1717233a
Cleanup: deflake jax.experimental and jax.ops (#3329) 2020-06-05 19:00:04 -07:00
Adam Paszke
c5d870738c Fix host_callback 2020-06-05 15:51:30 +00:00
Roy Frostig
dc4c9f0450 change cond primitive to an indexed conditional with multiple branch functions
in the core:

* bind and check cond primitive in indexed form
* rewrite abstract evaluation rule
* rewrite translation rule
* rewrite partial evaluation rule
* rewrite batching rule
* rewrite JVP rule
* rewrite transpose rule
* update jaxpr typechecker
* update pretty printer
* update outfeed-usage check
* update reference jaxpr in cond jaxpr test
* update reference regexes in HLO test

in experimental modules:

* update host_callback rewriter
* update loops expression builder
* generalize tf_impl rule
2020-06-03 22:19:15 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
Peter Hawkins
6ddac1d542
Disabled host_callback infrastructure for the HLO interpreter backend, which doesn't support infeed/outfeed. (#3294) 2020-06-02 13:16:13 -04:00
Peter Hawkins
34065df248
Add some type annotations to core and partial_eval. (#3251) 2020-06-01 21:45:36 -04:00
Roy Frostig
c5010cda47 use new gensym in host_callback jaxpr rewriter 2020-05-27 12:03:34 -07:00
James Bradbury
0eace80a6e
Fix experimental host callback on multi-host (#3200)
* Fix experimental host callback on multi-host

Hosts can only access the outfeed queue for local devices, while `api.devices` returns all devices in the system.

* Update host_callback.py
2020-05-25 08:12:58 +03:00
George Necula
afadb12b64
Improved tapping support for while: tap inside cond, vmap of while (#3195)
* Improved tapping support for while: tap inside cond, vmap of while

* Fix float64->float32 in tests
2020-05-24 10:50:07 +03:00
George Necula
b493a7e5df
Fix the handling of repeated vmap for id_tap (#3132)
* Fix the handling of repeated vmap for id_tap

* Updated the transforms to always be a tuple of tuples

* Changed the transforms to be dictionaries
2020-05-23 13:49:27 +03:00
Matthew Johnson
850f1afd95
improve errors for complex derivs, fixes #3121 (#3149) 2020-05-19 15:17:03 -07:00
Stephan Hoyer
77e31323f7
Fix indentation for docstrings in jax.experimental.host_callback (#3119) 2020-05-19 18:23:45 +03:00
Roy Frostig
cb5b1d10d9 handle single-operand cond in token-threading rewriter 2020-05-14 09:04:48 -07:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
George Necula
c171c33b1c
Update numpy references to use np. Added to Changelog (#3029) 2020-05-10 19:54:46 +03:00
George Necula
c375adf52a
Implementation of id_tap/id_print using outfeed. (#3006)
This was already merged as #2791 but reverted due to XLA crashes.

This reverts commit 769d703b7ac1011babef6289382f1a14d7aafc42.
2020-05-08 17:18:11 +03:00
George Necula
769d703b7a Undo the id_print/id_tap feature (PR #2791)
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00
George Necula
9f0795b8f1 Unified the eager and jit paths
Added error checking for outfeed_receiver not started to primitive computations
2020-05-07 16:24:13 +03:00
George Necula
d8b75e1913 Reimplemented the passing of tokens with a Jaxpr transform 2020-05-07 16:24:13 +03:00
George Necula
8fc96910c2 Improved documentation 2020-05-07 16:24:13 +03:00
George Necula
304009d772 Added error checking when starting compiled computations without starting
the outfeed receiver.
2020-05-07 16:24:13 +03:00
George Necula
cee5989a1d Implemented pytree support for arg and result.
Enabled outfeed for all arrays as a tuple
2020-05-07 16:24:13 +03:00
George Necula
0515f8e42b Added error handling for tap function errors 2020-05-07 16:24:13 +03:00
George Necula
0444f92795 Added support for sending all arrays in a single message 2020-05-07 16:24:13 +03:00
George Necula
653cad6302 Added support for multiple backends to outfeed receiver
Changed the encoding of the header to be uin32
2020-05-07 16:24:13 +03:00
George Necula
47cb5eaa86 Added masking transformation, added batch_dims to vmap 2020-05-07 16:24:13 +03:00
George Necula
931cb3f684 Ensure that we carry state only for control-flow conditionals that use print 2020-05-07 16:24:13 +03:00
George Necula
a16584d280 Fixed scan, and grad. Added multiplexing protocol. 2020-05-07 16:24:13 +03:00
George Necula
de685c9d5a An experiment for id_print implemented with outfeed
* Added print descriptors, support multiple types
* Added a state-passing mechanism to XLA interpreter
2020-05-07 16:24:13 +03:00