38 Commits

Author SHA1 Message Date
Jake VanderPlas
c4169a0c76 make tests compatible with recent pillow versions 2022-07-22 13:09:52 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
Jake VanderPlas
617df70135 Unpin numpy to ensure most recent version is tested 2022-06-23 12:23:14 -07:00
Yash Katariya
1908da33af Only initialize GPU backends if they are not already initialized
PiperOrigin-RevId: 456664792
2022-06-22 19:39:52 -07:00
Jake VanderPlas
1f300e729b CI: pin pillow<9.1 to prevent deprecation warnings 2022-04-01 09:23:27 -07:00
Peter Hawkins
901d459e0d Add cloudpickle as a test requirement.
We have at least one test that tests pickling JAX objects.
2022-02-16 15:04:56 -05:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Skye Wanderman-Milne
2fcf3f7270 Remove .[minimum-jaxlib] from test-requirements.txt
This means that jax and its dependencies (e.g. jaxlib) must be
manually installed before running the tests. This is useful for
testing an existing jax install, e.g. a later version of jaxlib, GPU
jaxlib, etc.
2021-09-23 12:24:24 -07:00
Jake VanderPlas
a5b6a4e6a9 CI: remove flake8 from test requirements. 2021-08-25 11:07:09 -07:00
dependabot[bot]
9f2863c66b Copybara import of the project:
--
57572d861a8bfe42a3b34b19a6e25a0b7ea4f22f by dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>:

Bump flatbuffers from 1.12 to 2.0

Bumps [flatbuffers](https://github.com/google/flatbuffers) from 1.12 to 2.0.
- [Release notes](https://github.com/google/flatbuffers/releases)
- [Commits](https://github.com/google/flatbuffers/compare/v1.12.0...v2.0.0)

---
updated-dependencies:
- dependency-name: flatbuffers
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7686 from google:dependabot/pip/flatbuffers-2.0 57572d861a8bfe42a3b34b19a6e25a0b7ea4f22f
PiperOrigin-RevId: 392097862
2021-08-20 17:13:26 -07:00
Jake VanderPlas
cbcd6eeadb CI: bump mypy & flake8 versions to newest 2021-08-20 14:35:37 -07:00
Jake VanderPlas
7fa151c5c3 cleanup: remove redundant entry from test-requirements 2021-08-20 10:09:14 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
0de4a60834 Update pillow pin to >= 8.3.1.
8.3.1 fixed the issue from https://github.com/google/jax/pull/7166.
2021-07-07 08:33:29 -04:00
Jake VanderPlas
4ba343aa83 CI: pin pillow dependency to 8.2 to avoid failures under 8.3 2021-07-01 16:32:35 -07:00
Jake VanderPlas
0c91be7b46 CI: temporarily pin numpy to <1.21 2021-06-22 11:15:16 -07:00
Peter Hawkins
07277f0785 Bump mypy version to 0.902. 2021-06-14 10:05:34 -04:00
Peter Hawkins
40c5e376d8 Pin flatbuffers 1.12 for CI tests. 2021-05-10 18:21:25 -04:00
Jake VanderPlas
f9a4162551 Specify minimum jaxlib version in a single location 2021-03-22 16:14:41 -07:00
Peter Hawkins
328930b917 Increase minimum jaxlib version to 0.1.62. 2021-03-16 15:11:36 -04:00
Skye Wanderman-Milne
7a67b974ac jaxlib version bump etc. 2021-02-12 09:42:04 -08:00
Peter Hawkins
13f3819054 Update README.md for jaxlib 0.1.60.
Bump jaxlib version to 0.1.61 and update changelog.

Change jaxlib numpy version limit to >=1.16 for next release. Releases older than 1.16 are deprecated per NEP 00029. Reenable NumPy 1.20.

Bump minimum jaxlib version to 0.1.60.
2021-02-03 20:44:01 -05:00
George Necula
a145e3d414 Pin numpy to max version 1.19, to avoid errors with 1.20
Will fix the numpy errors separately.
2021-01-31 15:18:54 +02:00
Skye Wanderman-Milne
7c2454e969 Update jaxlib version, minimum jaxlib version, readme, and changelog.
Bumping the min jaxlib version to support https://github.com/google/jax/pull/5213.
2021-01-15 12:56:08 -08:00
Peter Hawkins
1312d793ed Update mypy version to 0.790.
This appears to be necessary for Python 3.9.
2020-12-07 09:42:19 -05:00
Benjamin Chetioui
58a117fe0d
Modifies eig_p and related operations to take advantage of the new jaxlib geev API (#4266)
* Add options to compute L/R eigenvectors in geev.

The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.

Reformulate eig-related operations based on the new geev API.

* Addressed hawkinsp's comments from google/jax#3882.

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-09-15 11:45:15 +03:00
George Necula
0b99ca896d
[jax2tf] Disable the CI tests for jax2tf. (#4019)
We do this to see if this reduces the incidence of errors fetching
the tf-nightly package. These tests are being run when we import
the code in Google.
2020-08-11 12:39:54 +03:00
Matthew Johnson
64ec4443be another attempt at fixing github ci 2020-07-30 10:17:00 -07:00
Matthew Johnson
420e8422a4 attempt to fix CI by updating jax2tf test dep 2020-07-30 08:08:48 -07:00
Peter Hawkins
b943b31b22
Add jax.image.resize. (#3703)
* Add jax.image.resize.

This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.

While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
2020-07-10 09:57:59 -04:00
George Necula
4f3011f320
Refactored host_callback to use the C++ runtime. (#3644)
* Refactored host_callback to use the C++ runtime.

* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code
* We don't need msgpack anymore
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test on start, and (b)
  telling pytest-xdist to run all the tests from one file into
  a single worker.
2020-07-04 18:12:58 +03:00
George Necula
166e795d63
Updated minimum jaxlib to 0.1.51 (#3642) 2020-07-02 18:53:58 +03:00
George Necula
448c635a9e
[jax2tf] Update the tf-nightly version to 2020-07-01 (#3635) 2020-07-02 10:21:23 +03:00
Peter Hawkins
f036f5ddb0
Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation war… (#3543)
* Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation warnings.

* Pin a newer tf-nightly to fix jax2tf tests for NumPy 1.19.0
2020-06-24 15:19:00 -04:00
Peter Hawkins
3290e16a9a
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.

Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
   ...:    z = jax.numpy.cos(x)
   ...:    z = z * jax.numpy.tanh(y)
   ...:    return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00
Peter Hawkins
58f8329a40
Switch CI builds from Travis to Github actions (#3409) 2020-06-11 17:10:56 -04:00