8551 Commits

Author SHA1 Message Date
Peter Hawkins
0dfd76af97 Remove additional info return value from jax.scipy.linalg.polar(). 2021-07-20 13:13:31 -04:00
jax authors
c95ef8799d Merge pull request #7335 from hawkinsp:qdwh
PiperOrigin-RevId: 385796794
2021-07-20 08:54:16 -07:00
Adam Lewis
a2073ffcc2 Adds an implementation of a QR-based Dynamically Weighted Halley iteration. 2021-07-20 11:30:36 -04:00
jax authors
8f0ccb4e25 Merge pull request #7304 from j-towns:patch-3
PiperOrigin-RevId: 385720404
2021-07-19 23:37:25 -07:00
jax authors
63a152e72c Merge pull request #7331 from colemanliyah:fix_teardown
PiperOrigin-RevId: 385661922
2021-07-19 16:07:34 -07:00
jax authors
9398dd93be Merge pull request #7330 from colemanliyah:jit_integration
PiperOrigin-RevId: 385660510
2021-07-19 16:00:33 -07:00
colemanliyah
2eaadd86b6 jit integration with persistent compilation cache 2021-07-19 21:56:41 +00:00
colemanliyah
711c313026 added tearDown() method to test class 2021-07-19 21:53:27 +00:00
jax authors
29187a3317 Merge pull request #7315 from ROCmSoftwarePlatform:fix_pr_7306_rocm
PiperOrigin-RevId: 385566677
2021-07-19 09:01:05 -07:00
jax authors
21720373d4 Merge pull request #7319 from cloudhan:win-fix
PiperOrigin-RevId: 385562849
2021-07-19 08:43:06 -07:00
Adam Paszke
d25f4b34b8 Add an option to strictly enforce sharding implied by named axes
At the moment, xmap SPMD lowering only enforces sharding constraints for
computation inputs and outputs, while leaving sharding propagation in the
body entirely up to the XLA SPMD partitioner. This patch adds a new flag,
`experimental_xmap_enforce_inferred_sharding`, that inserts additional
sharding constraints between every pair of JAX primitives in the xmapped function.
Assuming that the SPMD partitioner never overrides user-defined constraints,
this should restrict it sufficiently to generate a computation that is
partitioned exactly as implied by the evolution of intermediate named shapes.

PiperOrigin-RevId: 385562158
2021-07-19 08:39:27 -07:00
jax authors
277f250ffc Merge pull request #7325 from hawkinsp:views
PiperOrigin-RevId: 385557244
2021-07-19 08:14:11 -07:00
Peter Hawkins
5893b92048 Clarify documentation about array views. 2021-07-19 09:49:37 -04:00
jax authors
2ba686ca19 Merge pull request #7295 from apaszke:xmap-ad
PiperOrigin-RevId: 385510420
2021-07-19 02:37:43 -07:00
jax authors
ef2c2d9db3 Merge pull request #7320 from gnecula:shape_poly_error
PiperOrigin-RevId: 385507043
2021-07-19 02:12:04 -07:00
George Necula
0693c316e4 [jax2tf] Improved error checking for inconsistent use of a dimension variable
Previously the check was done only for multiple occurrences of a shape
variable in one argument. Now we check across all arguments.
2021-07-19 09:45:43 +02:00
George Necula
a21683605d [host_callback] Increase number of threads for callback processing.
Previously there was one thread per device for receiving the outfeed from
devices, but there was a single global thread that was calling into the Python
callbacks. This meant that if one of the callbacks was slow, it was blocking
processing of all other callbacks.

One situation where this created difficulties was when one wanted to break a
host_callback into two operations: a quick one to enqueue work on a threadpool,
and a subsequent slow one to wait for and retrieve the result. The slow
callback would block all other callbacks, including possibly some quick ones,
thus missing the opportunity to start the slow work.

With this change there is a separate queue of outfeeds for each device and a
separate thread per device to call into Python. This allows for concurrency
between callbacks from different devices, although the callbacks from one
device are still sequential. If the programmer wants more concurrency, they can use a threadpool. Having more concurrency by default is tricky, because it may mean that the Python callbacks for one device may be seen out of order.

PiperOrigin-RevId: 385493070
2021-07-19 00:18:06 -07:00
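The per-device queue and worker-thread design described in this commit can be sketched in plain Python. This is an illustrative sketch only, not the actual host_callback implementation; the class and method names here are made up:

```python
import queue
import threading

class PerDeviceDispatcher:
    """One queue and one worker thread per device, so a slow callback on
    one device does not block callbacks from other devices. Callbacks for
    the *same* device still run sequentially, in order."""

    def __init__(self, device_ids):
        self._queues = {d: queue.Queue() for d in device_ids}
        self._threads = {
            d: threading.Thread(target=self._worker, args=(d,), daemon=True)
            for d in device_ids
        }
        for t in self._threads.values():
            t.start()

    def _worker(self, device_id):
        q = self._queues[device_id]
        while True:
            callback = q.get()
            if callback is None:  # shutdown sentinel
                break
            callback()

    def submit(self, device_id, callback):
        """Enqueue a callback on the given device's dedicated queue."""
        self._queues[device_id].put(callback)

    def shutdown(self):
        """Signal every worker to stop, then wait for them to finish."""
        for q in self._queues.values():
            q.put(None)
        for t in self._threads.values():
            t.join()
```

Because each device owns its queue and thread, a slow callback submitted for device 0 cannot delay a quick callback submitted for device 1, while per-device ordering is preserved, matching the concurrency guarantee stated above.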
George Necula
58522fd8a1 Change tfxla.variadic_reduce to point to XlaVariadicReduceV2.
PiperOrigin-RevId: 385492367
2021-07-19 00:12:08 -07:00
jax authors
0d069be51c Merge pull request #7321 from gnecula:fix_rtd
PiperOrigin-RevId: 385443835
2021-07-18 14:00:45 -07:00
George Necula
117d0d23ab Attempt to fix RTD build
It seems that the failure is for transformations.md
2021-07-18 19:35:14 +03:00
Cloud Han
6d84e02724 work around compilation issue on Windows when CUDA version < 11.0 2021-07-18 23:01:08 +08:00
Cloud Han
2d321c26e6 Use TF_CUDA_PATHS
CUDA_TOOLKIT_PATH and CUDNN_INSTALL_PATH are deprecated; see the TF 2.0
release notes for more information.
2021-07-18 22:55:34 +08:00
Cloud Han
4fa79ce1cb fix machine tag; on Windows, platforms.machine() returns AMD64 instead of x86_64 2021-07-18 22:55:02 +08:00
Reza Rahimi
ee08acd046 update rocblas because of PR-7306 2021-07-17 08:04:56 +00:00
jax authors
76bd9e176b Merge pull request #7277 from colemanliyah:pmap_integration
PiperOrigin-RevId: 385228055
2021-07-16 14:48:16 -07:00
jax authors
f4a475f5ae Merge pull request #7309 from gnecula:einsum_poly
PiperOrigin-RevId: 385178868
2021-07-16 10:48:05 -07:00
George Necula
41d46b2a91 [jax2tf] Expand the handling of shape-polymorphic einsum.
einsum supports an API where the arrays are interleaved with lists
of indices.
2021-07-16 20:23:07 +03:00
Liyah Coleman
262b10ee59 pmap integration 2021-07-16 17:22:14 +00:00
jax authors
a37fbe8082 Merge pull request #7310 from gnecula:random_gamma_poly
PiperOrigin-RevId: 385173297
2021-07-16 10:22:04 -07:00
jax authors
7bd5fe54fd Merge pull request #7307 from cloudhan:missing_lib
PiperOrigin-RevId: 385170297
2021-07-16 10:08:32 -07:00
George Necula
bd77a61d31 [jax2tf] Fix the shape polymorphism for batched while, and random_gamma. 2021-07-16 20:04:42 +03:00
jax authors
1044401e50 Merge pull request #7306 from tomhennigan:changelist/385115540
PiperOrigin-RevId: 385169180
2021-07-16 10:03:42 -07:00
jax authors
8dc2de552e Merge pull request #7273 from tomhennigan:changelist/384520796
PiperOrigin-RevId: 385126655
2021-07-16 05:20:30 -07:00
Cloud Han
cf7298d238 cusparseGetErrorString is an external symbol; without cusparse_lib as a dependency, linking fails 2021-07-16 19:55:55 +08:00
Tom Hennigan
afbd831ec3 Avoid sharing handles across streams.
When running across 8xV100 GPUs we observed the following error:

    libc++abi: terminating with uncaught exception of type std::runtime_error: third_party/py/jax/jaxlib/cusolver.cc:171: operation cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, static_cast<float*>(workspace), d.lwork, info) failed: cuSolver execution failed

I cannot find documentation to this effect, but I believe that it is unsafe to share cuSolver handles across streams, since keeping the handle pool stream-local does solve the issue.
2021-07-16 11:11:21 +00:00
Jamie Townsend
6ca775b10a Add TPU precision details to README gotchas 2021-07-16 12:24:19 +02:00
jax authors
b744a84fdc Merge pull request #7298 from LenaMartens:patch-4
PiperOrigin-RevId: 385075125
2021-07-15 22:09:18 -07:00
jax authors
1d390ae700 Merge pull request #7299 from hawkinsp:jaxlib
PiperOrigin-RevId: 385014189
2021-07-15 14:48:41 -07:00
jax authors
a91ad579c3 Merge pull request #7300 from VivekThazhathattil:fix-small-typo-in-docs
PiperOrigin-RevId: 385010028
2021-07-15 14:30:29 -07:00
Vivek Thazhathattil
61949b6f8b fix small typo in docs/developer.md 2021-07-16 02:48:45 +05:30
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
jax authors
6aa20d8f8f Merge pull request #7294 from hawkinsp:py36
PiperOrigin-RevId: 384994957
2021-07-15 13:19:23 -07:00
jax authors
e951ac8521 Merge pull request #7296 from apaszke:xmap-flag
PiperOrigin-RevId: 384993575
2021-07-15 13:12:38 -07:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
Lena Martens
fcce9c3309 Unwrap function fully before getting debug_info
A small bug which surfaces in error messages: if a function has been wrapped with `functools.wraps`, some error messages (e.g. UnexpectedTracerError) will point to the name of the wrapped function but to the filename of the wrapping function (because `functools.wraps` does not update the code object of the function, but it _does_ update the name of the function).
2021-07-15 19:11:01 +01:00
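The mismatch this commit describes can be reproduced outside JAX with the standard library alone. A minimal illustrative sketch (the function names are made up, not from the JAX code base):

```python
import functools

def make_wrapper(f):
    @functools.wraps(f)          # copies __name__, __doc__, etc. from f
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper

def original(x):
    return x + 1

wrapped = make_wrapper(original)

# functools.wraps updates the function's name attribute...
print(wrapped.__name__)          # 'original'
# ...but not its code object, whose name (and source filename /
# line numbers) still belong to the wrapper.
print(wrapped.__code__.co_name)  # 'wrapper'
```

Debug info that takes the name from `__name__` but the filename and line from `__code__` therefore mixes the two functions, which is exactly the inconsistency fixed here.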
Tom Hennigan
afa0d5725b Compute gtsv2 buffer size ahead of time and pass in to kernel.
A user reported that with their Quadro M4000 GPU (Driver: 460.56) tridiagonal_solve was throwing an "unsupported operation" error. I improved the logging (also included in this patch) and tracked it down to:

jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported

I had some challenges trying to figure out when async malloc is supported (it seems to fail for cards with compute capability < 6), but I found an alternative approach where we compute the buffer size ahead of time and ask XLA to allocate it. This is preferred for sure (although it requires passing null pointers into cusparseSgtsv2_bufferSizeExt, which seems to work today but might change in future cuSPARSE releases).
2021-07-15 16:06:23 +00:00
jax authors
efd37e9099 Merge pull request #7280 from tomhennigan:changelist/384668789
PiperOrigin-RevId: 384938599
2021-07-15 09:04:41 -07:00
Tom Hennigan
d6e56f2df9 Add source location and expression to error messages for CUDA API calls.
Before:

    jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: operation not supported

After:

    jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported
2021-07-15 15:42:46 +00:00
jax authors
ccfc6f6281 Merge pull request #7254 from hawkinsp:crosscompile
PiperOrigin-RevId: 384930242
2021-07-15 08:22:52 -07:00
Peter Hawkins
f5c61a892a Add support for cross-compiling jaxlib for Mac ARM. 2021-07-15 10:37:53 -04:00