29 Commits

Author SHA1 Message Date
Yash Katariya
60fccc2aac Disable test_source_file_prefix_removal test because there is cross-contamination of metadata information from different call sites because of cached jitted functions
PiperOrigin-RevId: 651847347
2024-07-12 12:07:24 -07:00
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Junwhan Ahn
0f760ee545 Avoid using lambda as the reducer fn
Lambdas are represented by their ids in the metadata of lowered HLO (see example below) and they change every time. This makes the compilation cache less effective as it causes the computation's fingerprint to change every time.

```
get-tuple-element.41724 = bf16[8]{0} get-tuple-element(reduce.41723), index=0, metadata={op_name="pjit(_wrapped_fn)/jit(main)/.../reduce[computation=<function _compute_argminmax.<locals>.reducer_fn at 0x7fa6ecfb2200> dimensions=(1,)]" source_file="..." source_line=...}
```

PiperOrigin-RevId: 601910715
2024-01-26 17:43:57 -08:00
Peter Hawkins
c787b3da07 Change metadata_test to tolerate paths with backslashes.
Fixes a test failure under Windows.

The backslashes end up doubled in the MLIR string because of escaping.
2023-06-12 21:57:41 -04:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
37d4ad910a Remove uses of jax.xla_computation from metadata_test.py
Add HLO source path canonicalization regex to trace state key because otherwise MetadataTest.test_source_file_prefix_removal fails due to caching of lowerings with different canonicalization regexs.

PiperOrigin-RevId: 509975754
2023-02-15 17:26:21 -08:00
Yash Katariya
7b1128fdc4 Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Sharad Vikram
74b136e62c Delete jax_experimental_name_stack flag
PiperOrigin-RevId: 487601864
2022-11-10 11:59:50 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Sharad Vikram
5ff2e8eb4c Fix name stack bugs 2022-04-19 11:14:41 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Skye Wanderman-Milne
350045c60c Add jax_hlo_source_file_canonicalization_regex config.
This is meant to be used with @colemanliyah's persistent compilation
cache, since the serialized HLO computation (including the source_file
metadata) is used in the cache key. The config can be used to remove
bits of the source file path that vary between program invocations, to
avoid spurious cache misses.
2021-07-30 14:42:41 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Roy Frostig
dc4c9f0450 change cond primitive to an indexed conditional with multiple branch functions
in the core:

* bind and check cond primitive in indexed form
* rewrite abstract evaluation rule
* rewrite translation rule
* rewrite partial evaluation rule
* rewrite batching rule
* rewrite JVP rule
* rewrite transpose rule
* update jaxpr typechecker
* update pretty printer
* update outfeed-usage check
* update reference jaxpr in cond jaxpr test
* update reference regexes in HLO test

in experimental modules:

* update host_callback rewriter
* update loops expression builder
* generalize tf_impl rule
2020-06-03 22:19:15 -07:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Lena Martens
1cc471928b
Remove pe from name_stack and test. (#3209) 2020-05-27 00:59:31 -07:00
Peter Hawkins
5290c03a17
Remove usage of xla_client.{Computation,ComputationBuilder}. (#2808)
* Remove usage of xla_client.{Computation,ComputationBuilder}.

ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the the C++ XLA API. It dates back from when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.

Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
2020-04-23 18:30:47 -04:00
Roy Frostig
664a4e123d
VJP of cond, via partial eval + transpose (#2091)
VJP (grad) of lax.cond, via partial eval + transpose


Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-01-30 15:03:00 -08:00
Matthew Johnson
96102dc727
simplify cond by removing consts (#2102)
Some higher-order primitives, like 'scan' and 'while', benefit from
distinguishing constants from other inputs to their closure-converted
function arguments; the reason is that for those primitives constants
act differently from the other inputs, which are loop carries or
scanned-over values, and are handled differently by transformations. For
example, they're used differently than loop carries in lattice
fixed-point computations. As another example, in scan the constants in
the forward computation are fanned out, so when transposing scan we
generate an accumulate-add.

However, these considerations don't hold true for cond: since there's no
looping going on (and hence no lattice fixed-points), constants are
treated just like the other operands. So we don't need to carry around
the distinction. That simplifies the cond rules a bit.

Co-authored-by: Roy Frostig <frostig@google.com>
2020-01-29 13:17:39 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
James Bradbury
1a5d9c531a
clear compilation cache before metadata tests (#2103) 2020-01-28 18:45:45 -08:00
James Bradbury
a15aa9bd4d
include call stack + transforms in XLA metadata (#2073) 2020-01-26 23:27:56 -08:00