mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU
* Exposes LoadedExecutable.cost_analysis via pybind * Updates XlaExecutable.cost_analysis to try LoadedExecutable.cost_analysis, then fallback to the client method. PiperOrigin-RevId: 542671990
This commit is contained in:
parent
9f4080ae2b
commit
10424c5972
@ -25,6 +25,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
determine the output shardings.
|
||||
* If the mesh context manager is provided, None will imply that the value
|
||||
will be replicated on all devices of the mesh.
|
||||
* Executable.cost_analysis() works on Cloud TPU
|
||||
|
||||
* Bug fixes
|
||||
* Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel
|
||||
|
@ -232,11 +232,21 @@ class XlaExecutable(Executable):
|
||||
else:
|
||||
raise
|
||||
|
||||
# TODO(skyewm): this should return a single Dict (I think returning a list
|
||||
# was to support MPMD executables, which never fully landed)
|
||||
def cost_analysis(self) -> List[Dict[str, float]]:
|
||||
xla_ext_exe = self.xla_extension_executable()
|
||||
err_msg = ("cost analysis unsupported on current XLA backend: "
|
||||
f"{type(xla_ext_exe)}")
|
||||
|
||||
# TODO(b/259255524): Unify/merge the two cost_analysis calls below.
|
||||
if hasattr(xla_ext_exe, "cost_analysis"):
|
||||
try:
|
||||
return [xla_ext_exe.cost_analysis()]
|
||||
except xla_extension.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
|
||||
raise
|
||||
|
||||
# Try client method if executable cost_analysis method is unimplemented
|
||||
if hasattr(xla_ext_exe, "client"):
|
||||
try:
|
||||
return [
|
||||
@ -245,21 +255,12 @@ class XlaExecutable(Executable):
|
||||
]
|
||||
except xla_extension.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
|
||||
raise NotImplementedError(err_msg) from e
|
||||
else:
|
||||
if not (type(msg) is str and msg.startswith("UNIMPLEMENTED")):
|
||||
raise
|
||||
elif hasattr(xla_ext_exe, "cost_analysis"):
|
||||
try:
|
||||
return xla_ext_exe.cost_analysis()
|
||||
except xla_extension.XlaRuntimeError as e:
|
||||
msg, *_ = e.args
|
||||
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
|
||||
raise NotImplementedError(err_msg) from e
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
raise NotImplementedError(err_msg)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"cost analysis unsupported on current XLA backend: {type(xla_ext_exe)}"
|
||||
)
|
||||
|
||||
def memory_analysis(self) -> Any:
|
||||
xla_ext_exe = self.xla_extension_executable()
|
||||
|
@ -15,67 +15,66 @@
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import concurrent.futures
|
||||
from contextlib import contextmanager
|
||||
import copy
|
||||
import enum
|
||||
import functools
|
||||
from functools import partial
|
||||
import inspect
|
||||
import gc
|
||||
import importlib
|
||||
import inspect
|
||||
import itertools as it
|
||||
import operator
|
||||
import operator as op
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
from typing import Callable, List, Optional, NamedTuple
|
||||
from typing import Callable, List, NamedTuple, Optional
|
||||
import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
import functools
|
||||
import itertools as it
|
||||
import operator as op
|
||||
import gc
|
||||
|
||||
from absl import logging
|
||||
from absl.testing import absltest, parameterized
|
||||
import numpy as np
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
from jax import custom_derivatives as custom_derivatives_public
|
||||
from jax import device_put, float0, grad, hessian, jacfwd, jacrev, jit
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import api, api_util, dtypes, lib
|
||||
from jax._src import array
|
||||
from jax._src import config as config_internal
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax._src.util as jax_util
|
||||
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
|
||||
import jax.custom_batching
|
||||
import jax.custom_derivatives
|
||||
import jax.custom_transpose
|
||||
import jax.numpy as jnp
|
||||
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
|
||||
from jax._src import core
|
||||
from jax._src import config as config_internal
|
||||
from jax import lax
|
||||
from jax._src import api, dtypes, lib, api_util
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import ad
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import array
|
||||
from jax.experimental import pjit
|
||||
from jax._src import custom_derivatives
|
||||
from jax import custom_derivatives as custom_derivatives_public
|
||||
from jax._src import prng
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src import linear_util as lu
|
||||
import jax._src.util as jax_util
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_name
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import xla
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
@ -305,7 +304,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
def test_complex_support(self):
|
||||
self.assertEqual(self.jit(lambda x: x + 1)(1 + 1j), 2 + 1j)
|
||||
|
||||
|
||||
@parameterized.parameters("static_argnums", "donate_argnums")
|
||||
def test_jit_argnums_overflow_error(self, argnum_type: str):
|
||||
def f(a, b, c):
|
||||
@ -346,7 +344,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.jit(h, **{argnum_type: (0, 999)})
|
||||
self.jit(h, **{argnum_type: (0, -999)})
|
||||
|
||||
|
||||
# No positional arguments
|
||||
self.jit(i, static_argnums=())
|
||||
self.jit(i)
|
||||
@ -385,7 +382,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
with self.assertWarns(SyntaxWarning):
|
||||
self.jit(h, static_argnames=("args", "c"))
|
||||
|
||||
|
||||
def test_jit_with_many_args_works(self):
|
||||
|
||||
@self.jit
|
||||
@ -468,7 +464,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
|
||||
def test_jit_cache_clear(self):
|
||||
@self.jit
|
||||
def f(x, y): return x + y
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
client = jax.devices()[0].client
|
||||
gc.collect()
|
||||
@ -1106,8 +1103,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
def test_jit_lower_compile_cost_analysis(self):
|
||||
f = self.jit(lambda x: x).lower(1.).compile()
|
||||
g = self.jit(lambda x: x + 4).lower(1.).compile()
|
||||
f.cost_analysis() # doesn't raise
|
||||
g.cost_analysis() # doesn't raise
|
||||
if xla_extension_version >= 164:
|
||||
self.assertIsNotNone(f.cost_analysis())
|
||||
self.assertIsNotNone(g.cost_analysis())
|
||||
else:
|
||||
f.cost_analysis() # doesn't raise
|
||||
g.cost_analysis() # doesn't raise
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def test_jit_lower_compile_memory_analysis(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user