Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU

* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
  LoadedExecutable.cost_analysis, then fall back to the client method.

PiperOrigin-RevId: 542671990
This commit is contained in:
Skye Wanderman-Milne 2023-06-22 14:42:14 -07:00 committed by jax authors
parent 9f4080ae2b
commit 10424c5972
3 changed files with 61 additions and 58 deletions

View File

@ -25,6 +25,7 @@ Remember to align the itemized text with the first line of an item within a list
determine the output shardings.
* If the mesh context manager is provided, None will imply that the value
will be replicated on all devices of the mesh.
* Executable.cost_analysis() works on Cloud TPU
* Bug fixes
* Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel

View File

@ -232,11 +232,21 @@ class XlaExecutable(Executable):
else:
raise
# TODO(skyewm): this should return a single Dict (I think returning a list
# was to support MPMD executables, which never fully landed)
def cost_analysis(self) -> List[Dict[str, float]]:
xla_ext_exe = self.xla_extension_executable()
err_msg = ("cost analysis unsupported on current XLA backend: "
f"{type(xla_ext_exe)}")
# TODO(b/259255524): Unify/merge the two cost_analysis calls below.
if hasattr(xla_ext_exe, "cost_analysis"):
try:
return [xla_ext_exe.cost_analysis()]
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
raise
# Try client method if executable cost_analysis method is unimplemented
if hasattr(xla_ext_exe, "client"):
try:
return [
@ -245,21 +255,12 @@ class XlaExecutable(Executable):
]
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
raise NotImplementedError(err_msg) from e
else:
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
raise
elif hasattr(xla_ext_exe, "cost_analysis"):
try:
return xla_ext_exe.cost_analysis()
except xla_extension.XlaRuntimeError as e:
msg, *_ = e.args
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
raise NotImplementedError(err_msg) from e
else:
raise
else:
raise NotImplementedError(err_msg)
raise NotImplementedError(
f"cost analysis unsupported on current XLA backend: {type(xla_ext_exe)}"
)
def memory_analysis(self) -> Any:
xla_ext_exe = self.xla_extension_executable()

View File

@ -15,67 +15,66 @@
import collections
import collections.abc
import concurrent.futures
from contextlib import contextmanager
import copy
import enum
import functools
from functools import partial
import inspect
import gc
import importlib
import inspect
import itertools as it
import operator
import operator as op
import os
import platform
import re
import subprocess
import sys
import types
from typing import Callable, List, Optional, NamedTuple
from typing import Callable, List, NamedTuple, Optional
import unittest
import warnings
import weakref
import functools
import itertools as it
import operator as op
import gc
from absl import logging
from absl.testing import absltest, parameterized
import numpy as np
import concurrent.futures
import jax
from jax import config
from jax import custom_derivatives as custom_derivatives_public
from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit
from jax import lax
from jax import tree_util
from jax._src import api, api_util, dtypes, lib
from jax._src import array
from jax._src import config as config_internal
from jax._src import core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import prng
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.ad_checkpoint import saved_residuals
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
import jax.custom_derivatives
import jax.custom_transpose
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax._src import core
from jax._src import config as config_internal
from jax import lax
from jax._src import api, dtypes, lib, api_util
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax.sharding import PartitionSpec as P
from jax._src import array
from jax.experimental import pjit
from jax._src import custom_derivatives
from jax import custom_derivatives as custom_derivatives_public
from jax._src import prng
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src import test_util as jtu
from jax import tree_util
from jax._src import linear_util as lu
import jax._src.util as jax_util
from jax._src.ad_checkpoint import saved_residuals
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import xla
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
from jax import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS
@ -305,7 +304,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
def test_complex_support(self):
self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)
@parameterized.parameters("static_argnums", "donate_argnums")
def test_jit_argnums_overflow_error(self, argnum_type: str):
def f(a, b, c):
@ -346,7 +344,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.jit(h, **{argnum_type: (0, 999)})
self.jit(h, **{argnum_type: (0, -999)})
# No positional arguments
self.jit(i, static_argnums=())
self.jit(i)
@ -385,7 +382,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
with self.assertWarns(SyntaxWarning):
self.jit(h, static_argnames=("args", "c"))
def test_jit_with_many_args_works(self):
@self.jit
@ -468,7 +464,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
def test_jit_cache_clear(self):
@self.jit
def f(x, y): return x + y
def f(x, y):
return x + y
client = jax.devices()[0].client
gc.collect()
@ -1106,8 +1103,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
def test_jit_lower_compile_cost_analysis(self):
f = self.jit(lambda x: x).lower(1.).compile()
g = self.jit(lambda x: x + 4).lower(1.).compile()
f.cost_analysis() # doesn't raise
g.cost_analysis() # doesn't raise
if xla_extension_version >= 164:
self.assertIsNotNone(f.cost_analysis())
self.assertIsNotNone(g.cost_analysis())
else:
f.cost_analysis() # doesn't raise
g.cost_analysis() # doesn't raise
@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_compile_memory_analysis(self):