mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

In some environments this appears to import the config module rather than the config object.
55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from functools import partial
|
|
import inspect
|
|
|
|
from absl.testing import absltest
|
|
|
|
import jax
|
|
from jax import lax
|
|
from jax._src import source_info_util
|
|
from jax._src import test_util as jtu
|
|
|
|
# Runs at import time, before any test class is defined. Judging by the method
# name this lets absl's flag parser populate JAX's config options -- TODO
# confirm the exact semantics against the JAX configuration docs.
jax.config.parse_flags_with_absl()
|
|
|
|
|
|
class SourceInfoTest(jtu.JaxTestCase):
  """Checks the source locations attached to jaxpr equations."""

  def test_inline_jit_location_uses_callee_location(self):
    """Equations produced via an inlined jit must point inside their caller.

    For each of the callers `g` and `h`, every equation of the traced jaxpr
    must have a user frame whose line span lies within the source lines that
    `inspect` reports for that caller.
    """

    # 'f' should be inlined into both 'g' and 'h', using the source line
    # information of the call site. In particular, the source line information
    # of 'h' should not refer to the source information of 'g'.
    @partial(jax.jit, inline=True)
    def f(x): return lax.add(x, 3)

    def g(x): return lax.add(f(x), 4)

    def h(x): return lax.add(f(x), 5)

    for fn in (g, h):
      lines, fn_startline = inspect.getsourcelines(fn)
      # Start line plus line count, i.e. one past the last source line of
      # `fn`; it is used as an inclusive upper bound in the check below.
      fn_endline = fn_startline + len(lines)
      jaxpr = jax.make_jaxpr(fn)(2)
      for eqn in jaxpr.eqns:
        frame = source_info_util.user_frame(eqn.source_info)
        # Use a unittest assertion rather than a bare `assert` so the check
        # still runs under `python -O`; the equation is the failure message.
        self.assertIsNotNone(frame, msg=eqn)
        self.assertLessEqual(fn_startline, frame.start_line)
        self.assertLessEqual(frame.end_line, fn_endline)
|
|
|
|
|
|
# When executed as a script, run this module's tests through absltest using
# the loader object provided by jax's test utilities.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
|