# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import contextlib
import functools
import logging
import math
import re
from typing import Sequence
import unittest

from absl.testing import absltest

import jax
from jax import numpy as jnp
from jax import tree_util
from jax.experimental.export import export
from jax.experimental import pjit
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir.dialects import hlo

import numpy as np

config.parse_flags_with_absl()

prev_xla_flags = None

def setUpModule():
  global prev_xla_flags
  # This will control the CPU devices. On TPU we always have 2 devices
  prev_xla_flags = jtu.set_host_platform_device_count(2)

# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  prev_xla_flags()

# A primitive for testing multi-platform lowering. Takes one argument and
# adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
_testing_multi_platform_p = core.Primitive("testing_multi_platform")
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)

@_testing_multi_platform_p.def_abstract_eval
def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue):
  assert xaval.dtype == np.float32  # type: ignore
  return xaval

def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
                                     x: mlir.Value,
                                     *,
                                     platform: str) -> Sequence[mlir.Value]:
  to_add = _testing_multi_platform_to_add[platform]
  to_add_value = mlir.broadcast_in_dim(ctx,
                                       mlir.ir_constant(np.float32(to_add)),
                                       ctx.avals_in[0],
                                       broadcast_dimensions=())
  return mlir.hlo.AddOp(x, to_add_value).results

# Register a default rule for cuda, to test the default-platform rule selection.
mlir.register_lowering(_testing_multi_platform_p,
                       functools.partial(_testing_multi_platform_lowering,
                                         platform="cuda"))
for platform in ["cpu", "tpu", "rocm"]:
  mlir.register_lowering(_testing_multi_platform_p,
                         functools.partial(_testing_multi_platform_lowering,
                                           platform=platform),
                         platform=platform)

def _testing_multi_platform_func(x):
  return _testing_multi_platform_p.bind(x)

def _testing_multi_platform_fun_expected(x,
                                         platform: str | None = None):
  return x + _testing_multi_platform_to_add[
      xb.canonicalize_platform(platform or jtu.device_under_test())
  ]


class JaxExportTest(jtu.JaxTestCase):

  def override_serialization_version(self, version_override: int):
    version = config.jax_serialization_version.value
    if version != version_override:
      self.enter_context(config.jax_serialization_version(version_override))
    logging.info(
        "Using JAX serialization version %s",
        config.jax_serialization_version.value)

  @classmethod
  def setUpClass(cls):
    # Find the available platforms
    cls.platforms = []
    for backend in ["cpu", "gpu", "tpu"]:
      try:
        jax.devices(backend)
      except RuntimeError:
        continue
      cls.platforms.append(backend)
    super(JaxExportTest, cls).setUpClass()

  def setUp(self):
    super().setUp()
    # Run tests with the maximum supported version by default
    self.override_serialization_version(
        export.maximum_supported_serialization_version)

  def test_basic_export_only(self):
    def my_fun(x):
      return jnp.sin(x)
    exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
    self.assertEqual("my_fun", exp.fun_name)
    self.assertEqual((export.default_lowering_platform(),),
                     exp.lowering_platforms)
    self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
    self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
    self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)

  def test_pytree_export_only(self):
    a = np.arange(4, dtype=np.float32)
    b = np.arange(6, dtype=np.float32)
    def f(a_b_pair, *, a, b):
      return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))

    exp = export.export(f, lowering_platforms=("cpu",))((a, b), a=a, b=b)

    a_aval = core.ShapedArray(a.shape, a.dtype)
    b_aval = core.ShapedArray(b.shape, b.dtype)
    self.assertEqual(exp.lowering_platforms, ("cpu",))
    args = ((a, b),)
    kwargs = dict(a=a, b=b)
    self.assertEqual(exp.in_tree, tree_util.tree_flatten((args, kwargs))[1])
    self.assertEqual(exp.in_avals, (a_aval, b_aval, a_aval, b_aval))
    self.assertEqual(exp.out_tree, tree_util.tree_flatten(f(*args, **kwargs))[1])
    self.assertEqual(exp.out_avals, (a_aval, b_aval, a_aval, b_aval, a_aval, b_aval))

  def test_poly_export_only(self):
    a = np.arange(12, dtype=np.float32).reshape((3, 4))
    def f(a, b):  # a: f32[2w,h]  b: f32[w,h]
      return jnp.concatenate([a, b], axis=0)

    exp = export.export(f)(
        export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
        export.poly_spec(a.shape, a.dtype, "(w, h)"))
    self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
    self.assertEqual("(w, h)", str(exp.in_avals[1].shape))
    self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))

  def test_poly_pytree_export_only(self):
    a = np.arange(12, dtype=np.float32).reshape((3, 4))
    def f(a0, a1, *, ak):
      return jnp.concatenate([a0, a1, ak], axis=0)

    a_poly_spec = export.poly_spec(a.shape, a.dtype, "(w, h)")
    exp = export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
    self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
    self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))

  def test_basic(self):
    f = jnp.sin
    x = np.arange(4, dtype=np.float32)
    exp_f = export.export(f)(x)

    f1 = export.call_exported(exp_f)
    self.assertAllClose(f(x), f1(x))

  def test_call_exported_lambda(self):
    # When we export a lambda, the exported.fun_name is not a valid MLIR function name
    f = lambda x: jnp.sin(x)
    x = np.arange(4, dtype=np.float32)
    exp_f = export.export(f)(x)
    f1 = export.call_exported(exp_f)
    self.assertAllClose(f(x), f1(x))

  def test_call_twice_exported(self):
    def f(x): return jnp.sin(x)
    x = np.arange(4, dtype=np.float32)

    @jax.jit
    def f1(x):
      exp_f = export.export(f)(x)
      return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x)

    self.assertAllClose(2. * f(x), f1(x))

  def test_unused_args(self):
    f = lambda x, y: jnp.sin(x)
    x = np.arange(4, dtype=np.float32)
    y = np.arange(6, dtype=np.float32)
    exp_f = export.export(f)(x, y)

    f1 = export.call_exported(exp_f)
    self.assertAllClose(f(x, y), f1(x, y))

  def test_pytree(self):
    a = np.arange(4, dtype=np.float32)
    b = np.arange(6, dtype=np.float32)
    def f(a_b_pair, a, b):
      return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
    exp_f = export.export(f)((a, b), a=a, b=b)
    f1 = export.call_exported(exp_f)
    self.assertAllClose(f((a, b), a=a, b=b),
                        f1((a, b), a=a, b=b))

  def test_error_wrong_intree(self):
    def f(a_b_pair, *, c):
      return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
    a = b = c = np.arange(4, dtype=np.float32)
    exp_f = export.export(f)((a, b), c=c)

    with self.assertRaisesRegex(
        ValueError,
        "The invocation args and kwargs must have the same pytree structure"):
      export.call_exported(exp_f)(a, b, c=(a, b))

  def test_error_wrong_avals(self):
    def f(a, *, b):  # a: f32[4] and b: f32[4]
      return jnp.sin(a) + jnp.cos(b)
    f32_4 = np.arange(4, dtype=np.float32)
    exp_f = export.export(f)(f32_4, b=f32_4)

    with self.assertRaisesRegex(ValueError,
        r"Shape mismatch for args\[0\].shape\[0\]"):
      export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)

    with self.assertRaisesRegex(ValueError,
        r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
      export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))

    with self.assertRaisesRegex(ValueError,
        r"Rank mismatch for args\[0\]"):
      export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4)

    with self.assertRaisesRegex(ValueError,
        r"Dtype mismatch for args\[0\]"):
      export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)

  @jtu.parameterized_filterable(
    testcase_name=lambda kw: kw["platform"],
    kwargs=[dict(platform=p)
            for p in ("cpu", "cuda", "rocm", "tpu")])
  def test_error_wrong_platform(self, platform):
    a = np.arange(4, dtype=np.float32)

    exp_f = export.export(jnp.sin, lowering_platforms=(platform,))(a)
    if xb.canonicalize_platform(jtu.device_under_test()) == platform:
      raise unittest.SkipTest("Uninteresting scenario")

    with self.assertRaisesRegex(
        ValueError, "The exported function .* was lowered for platform"):
      export.call_exported(exp_f)(a)

    # Now try with the platform check disabled
    exp_f_no_platform_check = export.export(
      jnp.sin, lowering_platforms=(platform,),
      disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
    res = export.call_exported(exp_f_no_platform_check)(a)
    self.assertAllClose(res, jnp.sin(a))

  @jtu.parameterized_filterable(
    testcase_name=lambda kw: kw["dialect"],
    kwargs=[dict(dialect=dialect)
            for dialect in ("mhlo", "stablehlo")]
  )
  def test_error_disallowed_custom_call(self, dialect):
    # If we use hlo.custom_call or mhlo.custom_call we detect
    # invalid custom call targets.
    # Set up a primitive with custom lowering rules
    test_primitive = core.Primitive("_test_primitive_disallowed_custom_call")
    test_primitive.def_abstract_eval(lambda in_aval: in_aval)
    def test_primitive_lowering(ctx, arg):
      from jax._src.lib.mlir.dialects import mhlo
      op = dict(stablehlo=hlo.CustomCallOp, mhlo=mhlo.CustomCallOp)[dialect]
      return op([arg.type], [arg], "disallowed_call_target").results
    mlir.register_lowering(test_primitive, test_primitive_lowering)
    self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))

    a = np.arange(3, dtype=np.float32)
    with self.assertRaisesRegex(ValueError,
        "Cannot serialize code with custom calls whose targets .*"):
      export.export(
        lambda a: a + test_primitive.bind(a)
      )(a)

    # Now try again with the safety check disabled
    exp = export.export(
      lambda a: a + test_primitive.bind(a),
      disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")]
    )(a)
    self.assertIn("disallowed_call_target", exp.mlir_module())

  def test_grad(self):
    f = lambda x: jnp.sum(jnp.sin(x))
    x = np.arange(4, dtype=np.float32)
    exp_f = export.export(f)(x)

    f1 = export.call_exported(exp_f)
    self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))

  def test_higher_order_grad(self):
    f = lambda x: x ** 3
    x = np.float32(4.)
    exp_f = export.export(f)(x)
    f1 = export.call_exported(exp_f)
    self.assertAllClose(jax.grad(f)(x),
                        jax.grad(f1)(x))
    self.assertAllClose(jax.grad(jax.grad(f))(x),
                        jax.grad(jax.grad(f1))(x))
    self.assertAllClose(jax.grad(jax.grad(jax.grad(f)))(x),
                        jax.grad(jax.grad(jax.grad(f1)))(x))

  def test_pytree_vjp(self):
    def f(a_b_pair, *, a, b):
      return (dict(res=a_b_pair, a=2. * a, b=3. * b),
              jnp.sin(4. * a))

    a = np.arange(4, dtype=np.float32)
    b = np.arange(6, dtype=np.float32)
    exp_f = export.export(f)((a, b), a=a, b=b)

    out_ct = f((a, b), a=a, b=b)  # The output has the right structure as the cotangent
    def f1_jax(a, b):  # For VJP, make a function without kwargs
      res = f((a, b), a=a, b=b)
      return res
    def f1_exp(a, b):  # For VJP, make a function without kwargs
      res = export.call_exported(exp_f)((a, b), a=a, b=b)
      return res
    jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
    exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
    self.assertAllClose(jax_vjp, exp_vjp)

  def test_roundtrip(self):
    def f1(x):
      return jnp.sin(x)
    a = np.arange(4, dtype=np.float32)
    exp_f1 = export.export(f1)(a)
    def f2(x):
      res1 = export.call_exported(exp_f1)(x)
      res2 = export.call_exported(exp_f1)(res1)
      return jnp.cos(res2)
    exp_f2 = export.export(f2)(a)

    self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
                        export.call_exported(exp_f2)(a))

  @jtu.parameterized_filterable(
    kwargs=[
      dict(v=v)
      for v in range(export.minimum_supported_serialization_version - 1,
                     export.maximum_supported_serialization_version + 2)])
  def test_shape_poly_basic_versions(self, v: int):
    self.override_serialization_version(v)
    with contextlib.ExitStack() as e:
      if not (export.minimum_supported_serialization_version <= v
              <= export.maximum_supported_serialization_version):
        e.enter_context(self.assertRaisesRegex(
          ValueError,
          f"The requested jax_serialization version {v} is outside the range of supported versions"))

      exp = export.export(jnp.sin)(
          export.poly_spec((3, 4), np.float32, "w, h"))
      # Peek at the module
      module_str = exp.mlir_module()
      self.assertEqual(config.jax_serialization_version.value >= 7,
                       "shape_assertion" in module_str)
      self.assertIn("jax.uses_shape_polymorphism = true",
                    module_str)
      dim_vars = re.findall(
        r"(%arg\d):\s*tensor<i..>\s*{jax.dimension_variable = true}",
        module_str)
      self.assertEqual(["%arg0", "%arg1"], dim_vars,
                       f"Found {dim_vars} in {module_str}")

      x = np.arange(30, dtype=np.float32).reshape((5, 6))
      res = export.call_exported(exp)(x)
      self.assertAllClose(res, np.sin(x))

  # A function is exported with f32[poly_spec] and is called with different arg
  # shapes. We use export.call_exported and we also run the shape check
  # module.
  @jtu.parameterized_filterable(
    testcase_name=lambda kw: f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}",  # type: ignore
    kwargs=[
      dict(poly_spec="3,4,12", arg_shape=(3, 4, 12)),
      dict(poly_spec="3,4,12", arg_shape=(3, 4, 13),
           # The shape check module does not test constant dimensions
           expect_error=re.escape(
             r"Shape mismatch for args[0].shape[2] (expected same constant)")),
      dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 12)),
      dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 12)),
      dict(poly_spec="3,4,a+1", arg_shape=(3, 4, 1),
           expect_error=re.escape(
             "Expected value >= 1 for dimension variable 'a'. "
             "Using the following polymorphic shapes specifications: args[0].shape = (3, 4, a + 1). "
             "Obtained dimension variables: 'a' = 0"
           )),
      dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 13),
           expect_error=re.escape(
             "Division had remainder 1 when computing the value of 'a'"
           )),
      dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 13),
           expect_error=re.escape(
             "Found inconsistency between dimension size "
             "args[0].shape[2] (= 13) and the specification 'a + 8' (= 12)"
           )),
  ])
  def test_poly_shape_checks(
      self, poly_spec="3,a,a+8",
      arg_shape=(3, 4, 12), arg_dtype=np.float32,
      expect_error=None):  # If given, error from running the exported module

    def f(x):  # x: f32[poly_spec]
      return jnp.reshape(x, (-1, x.shape[1]))

    disabled_checks = ()
    exp_f = export.export(f, disabled_checks=disabled_checks)(
        export.poly_spec((3, 4, 12), np.float32, poly_spec))
    self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
    arg = np.arange(np.prod(arg_shape),
                    dtype=arg_dtype).reshape(arg_shape)  # arg : f32[3,4,12]

    with contextlib.ExitStack() as stack:
      if expect_error is not None:
        stack.push(self.assertRaisesRegex(Exception, expect_error))

      assert core.is_constant_shape(arg.shape)
      res = export.call_exported(exp_f)(arg)

    if not expect_error:
      self.assertAllClose(res, f(arg))

  # An inner function is exported with polymorphic shapes inner_poly_spec, and
  # is called from an outer function, which is exported with outer_poly_spec.
  @jtu.parameterized_filterable(
    testcase_name=lambda kw: f"inner={kw['inner_poly_spec']}_outer={kw['outer_poly_spec']}",  # type: ignore
    #one_containing="",
    # By default arg_shape = (3, 4, 12) for both the outer function and the inner
    # The inner function is exported for f32.
    kwargs=[
      # Both inner and outer are static shapes
      dict(inner_poly_spec="3,4,12", outer_poly_spec="3,4,12"),
      # Inner has poly shapes but outer has static shapes. When we call inner
      # we do the shape constraint checking
      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
      dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
      dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,12",
           expect_error_outer_exp=re.escape(
             "Found inconsistency between dimension size "
             "args[0].shape[2] (= 12) and the specification 'a' (= 4)")),
      dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
           expect_error_outer_exp=re.escape(
             "Division had remainder 2 when computing the value of 'a'")),
      dict(inner_poly_spec="3,4,12+a", outer_poly_spec="3,4,12",
           expect_error_outer_exp=re.escape(
             "Expected value >= 1 for dimension variable 'a'. "
             "Using the following polymorphic shapes specifications: args[0].shape = (3, 4, a + 12). "
             "Obtained dimension variables: 'a' = 0 from specification "
             "'a + 12' for dimension args[0].shape[2] (= 12)")),
      # Both inner and outer have poly shapes.
      dict(inner_poly_spec="3,a,b", outer_poly_spec="3,4,c"),
      dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,6*c"),
      dict(inner_poly_spec="3,a,a+8", outer_poly_spec="3,c+2,c+10"),
      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c",
           expect_error_outer_exp=re.escape(
             "Expected value >= 1 for dimension variable 'b'. "
             "Using the following polymorphic shapes specifications: args[0].shape = (3, a, a + b). "
             "Obtained dimension variables: 'a' = 4 from specification "
             "'a' for dimension args[0].shape[1] (= 4), "
             "'b' = c + -4 from specification 'a + b' for dimension args[0].shape[2] (= c),")),
      dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,c",
           expect_error_outer_exp=re.escape(
             "Found inconsistency between dimension size "
             "args[0].shape[2] (= c) and the specification 'a' (= 4)")),
      dict(inner_poly_spec="3,a,a", arg_shape=(3, 4),
           outer_poly_spec="3,c",
           expect_error_outer_exp=r"Rank mismatch for args\[0\]"),
      dict(inner_poly_spec="3,a,a+b", arg_dtype=np.int32,
           outer_poly_spec="3,c,d",
           expect_error_outer_exp=r"Dtype mismatch for args\[0\]"),
      dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c",
           expect_error_outer_exp=re.escape(
             "Division had remainder mod(c, 5) when computing the value of 'a'")),
      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
           expect_error_outer_exp=re.escape(
             "Expected value >= 1 for dimension variable 'b'. "
             "Using the following polymorphic shapes specifications: args[0].shape = (3, a, a + b). "
             "Obtained dimension variables: 'a' = c from "
             "specification 'a' for dimension args[0].shape[1] (= c), "
             "'b' = 0 from specification 'a + b' for dimension args[0].shape[2] (= c)")),
      dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
           expect_error_outer_exp=re.escape(
             "Shape mismatch for args[0].shape[0] (expected same constant)")),
      dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,25*c",
           expect_error_run=re.escape(
             "Division had remainder 12 when computing the value of 'c'")),
      dict(inner_poly_spec="3,a,b", outer_poly_spec="3,c+4,12",
           expect_error_run=re.escape(
             "Expected value >= 1 for dimension variable 'c'. "
             "Using the following polymorphic shapes specifications: args[0].shape = (3, c + 4, 12). "
             "Obtained dimension variables: 'c' = 0")),
      dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
           expect_error_run=re.escape(
             "Found inconsistency between dimension size "
             "args[0].shape[2] (= 12) and the specification 'a' (= 4)")),
  ])
  def test_poly_shape_checks_nested(
      self, inner_poly_spec="3,4,5*a",
      arg_shape=(3, 4, 12), arg_dtype=np.float32,
      outer_poly_spec="3,4,25*c",
      expect_error_outer_exp=None,
      expect_error_run=None):
    # Polymorphic export called with static or polymorphic shapes
    def inner(x):  # x: inner_poly_spec
      return jnp.reshape(x, (-1, x.shape[1]))

    arg = np.arange(np.prod(arg_shape),
                    dtype=arg_dtype).reshape(arg_shape)  # x : f32[3,4,12]
    inner_exp = export.export(inner)(
        export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))

    self.assertEqual(inner_exp.uses_shape_polymorphism,
                     (inner_poly_spec != "3,4,12"))

    def outer(x):  # x: outer_poly_spec
      # Use an addition to test that the shapes are refined properly for the
      # result of the call_exported.
      return export.call_exported(inner_exp)(x) + inner(x)

    with contextlib.ExitStack() as stack:
      if expect_error_outer_exp is not None:
        stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))

      # Call it after exporting again, with polymorphic shapes
      outer_exp = export.export(outer)(
          export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))

    if expect_error_outer_exp is not None:
      return

    self.assertEqual(outer_exp.uses_shape_polymorphism,
                     (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))

    with contextlib.ExitStack() as stack:
      if expect_error_run is not None:
        stack.push(self.assertRaisesRegex(Exception, expect_error_run))

      res = export.call_exported(outer_exp)(arg)

    if expect_error_run is not None:
      return
    self.assertAllClose(2. * inner(arg), res)

  # Tests details of the shape constraints errors
  # This test exists also in shape_poly_test.py. Here we test the
  # call_exported error reporting.
  @jtu.parameterized_filterable(
    testcase_name=lambda kw: kw["shape"],  # assume "shape" is unique
    kwargs=[
      dict(shape=(8, 2, 9),  # a = 2, b = 3, c = 4
           poly_spec="(a + 2*b, a, a + b + c)"),
      dict(shape=(2, 2, 6),  # a = 2, b = 0, c = 4
           poly_spec="(a + 2*b, a, a + b + c)",
           expect_error=(
             "Input shapes do not match the polymorphic shapes specification. "
             "Expected value >= 1 for dimension variable 'b'. "
             "Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
             "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
             "'b' = 0 from specification 'a + 2*b' for dimension args[0].shape[0] (= 2), . "
             "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
           )),
      dict(shape=(3, 2, 6),  # a = 2, b = 0.5, c = 4 - b is not integer
           poly_spec="(a + 2*b, a, a + b + c)",
           expect_error=(
             "Input shapes do not match the polymorphic shapes specification. "
             "Division had remainder 1 when computing the value of 'b'. "
             "Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b + c). "
             "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . "
             "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
           )),
      dict(shape=(8, 2, 6),  # a = 2, b = 3 - inconsistency
           poly_spec="(a + 2*b, a, a + b)",
           expect_error=(
             "Input shapes do not match the polymorphic shapes specification. "
             "Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification 'a + 2*b' (= 10). "
             "Using the following polymorphic shapes specifications: args[0].shape = (a + 2*b, a, a + b). "
             "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
             "'b' = 4 from specification 'a + b' for dimension args[0].shape[2] (= 6), . "
             "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details."
           )),
      dict(shape=(7, 2, 36),  # a = 2, b = 3, c = 6 - cannot solve c
           poly_spec="(2 * a + b, a, c * c)",
           expect_error=(
             "Cannot solve for values of dimension variables {'c'}. "
             "We can only solve linear uni-variate constraints. "
             "Using the following polymorphic shapes specifications: args[0].shape = (2*a + b, a, c^2). "
             "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. "
             "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
           )),
  ])
  def test_shape_constraints_errors(self, *,
      shape, poly_spec: str, expect_error: str | None = None):
    def f_jax(x):  # x: f32[a + 2*b, a, a + b + c]
      return 0.

    if shape == (8, 2, 6) and jaxlib_version <= (0, 4, 14):
      raise unittest.SkipTest("Test requires jaxlib >= 0.4.14")

    x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
    with contextlib.ExitStack() as stack:
      if expect_error is not None:
        stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
      exp = export.export(f_jax)(
          export.poly_spec(x.shape, x.dtype, poly_spec))
      export.call_exported(exp)(x)

  def test_with_sharding(self):
    nr_devices = 2
    if len(jax.devices()) < nr_devices:
      self.skipTest("Need at least 2 devices")
    export_devices = jax.devices()[0:nr_devices]
    export_mesh = Mesh(export_devices, axis_names=("x",))
    a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))
    @functools.partial(
        jax.jit,
        in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),),
        out_shardings=jax.sharding.NamedSharding(export_mesh, P(None, "x")))
    def f_jax(b):  # b: f32[16 // DEVICES, 4]
      return b * 2.

    res_native = f_jax(a)
    exp = export.export(f_jax)(a)

    run_devices = export_devices[::-1]  # We can use other devices
    run_mesh = Mesh(run_devices, "y")
    a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P()))

    expected_re = re.compile(
      # The top-level input is replicated
      r"func.func .* @main\(%arg0: tensor<16x4xf32> {mhlo.sharding = \"{replicated}\"}\).*"
      # We apply the in_shardings for f_jax
      r".*custom_call @Sharding\(%arg0\) {mhlo.sharding = \"{devices=\[2,1\]<=\[2\]}\"}.*"
      r"%1 = .*call @call_exported_f_jax.*"
      # We apply the out_shardings for f_jax
      r".*custom_call @Sharding\(%1\) {mhlo.sharding = \"{devices=\[1,2\]<=\[2\]}\"}.*",
      re.DOTALL)
    hlo = jax.jit(export.call_exported(exp)).lower(a_device).as_text()
    self.assertRegex(hlo, expected_re)

    res_exported = export.call_exported(exp)(a_device)
    self.assertAllClose(res_native, res_exported)

    # Test error reporting
    with self.assertRaisesRegex(
        NotImplementedError,
        "Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
      _ = export.call_exported(exp)(a)

    with self.assertRaisesRegex(
        NotImplementedError,
        "Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
      mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
      _ = jax.jit(
        export.call_exported(exp),
        in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
      )(a)

  @jtu.parameterized_filterable(
    kwargs=[
      dict(testcase_name=f"_in_shardings={in_shardings}_out_shardings={out_shardings}",
           in_shardings=in_shardings, out_shardings=out_shardings)
      for in_shardings in ("missing", None, "P")
      for out_shardings in ("missing", None, "P")
  ])
  def test_grad_with_sharding(self, in_shardings="P", out_shardings=None):
    if len(jax.devices()) < 2:
      self.skipTest("Test requires at least 2 devices")
    x_shape = (10, 20)
    x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
    def f_jax(x):  # x: f32[10,20] -> f32[20,10]
      return jnp.sin(x.T)

    pjit_kwargs = {}
    if in_shardings != "missing":
      pjit_kwargs["in_shardings"] = (P(None, "x") if in_shardings == "P" else None)
    if out_shardings != "missing":
      pjit_kwargs["out_shardings"] = (P("x", None) if out_shardings == "P" else None)
    f_jax = pjit.pjit(f_jax, **pjit_kwargs)

    with Mesh(jax.devices()[:2], "x"):
      exp = export.export(f_jax)(x)
      exp_vjp = exp.vjp()

    vjp_module_str = str(exp_vjp.mlir_module())

    if in_shardings == "P":
      primal_in_sharding = "{devices=[1,2]<=[2]}"
    else:
      primal_in_sharding = "{replicated}"
    if out_shardings == "P":
      primal_out_sharding = "{devices=[2,1]<=[2]}"
    else:
      primal_out_sharding = "{replicated}"

    main = re.search(
      r"func.func public @main\(%arg0: tensor<10x20xf32> {mhlo.sharding = \"([^\"]+)\""
      r".*%arg1: tensor<20x10xf32> {mhlo.sharding = \"([^\"]+)\""
      # result
      r".* -> \(tensor<10x20xf32>.*mhlo.sharding = \"([^\"]+)\"",
      vjp_module_str)
    self.assertEqual(
      main.groups(),
      (primal_in_sharding, primal_out_sharding, primal_in_sharding))

    # Custom calls for the primal input shape
    primal_in_calls = re.findall(
      r"custom_call @Sharding.*{mhlo.sharding = \"(.+)\"} : .*tensor<10x20xf32>",
      vjp_module_str)
    self.assertTrue(
      all(s == primal_in_sharding for s in primal_in_calls),
      primal_in_calls
    )

    # Custom calls for the primal output shape
    primal_out_calls = re.findall(
      r"custom_call @Sharding.*{mhlo.sharding = \"(.+)\"} : .*tensor<20x10xf32>",
      vjp_module_str)
    self.assertTrue(
      all(s == primal_out_sharding for s in primal_out_calls),
      primal_out_calls
    )

  def test_multi_platform(self):
    x = np.arange(8, dtype=np.float32)
    exp = export.export(_testing_multi_platform_func,
                        lowering_platforms=("cpu", "tpu", "cuda"))(x)
    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
    module_str = str(exp.mlir_module())
    expected_main_re = (
      r"@main\("
      r"%arg0: tensor<i..> {jax.platform_index = true}.*, "
      r"%arg1: tensor<8xf32>.* ->")
    self.assertRegex(module_str, expected_main_re)
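    # Multi-platform lowering threads a platform-index argument through the
    # module and relies on the shape refinement pass to constant-fold it, so
    # the module is expected to be tagged with jax.uses_shape_polymorphism.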
    self.assertIn("jax.uses_shape_polymorphism = true",
                  module_str)

    # Call with the argument placed on different platforms
    for platform in self.__class__.platforms:
      x_device = jax.device_put(x, jax.devices(platform)[0])
      res_exp = export.call_exported(exp)(x_device)
      self.assertAllClose(
        res_exp,
        _testing_multi_platform_fun_expected(x, platform=platform))

  def test_multi_platform_nested(self):
    x = np.arange(5, dtype=np.float32)
    exp = export.export(_testing_multi_platform_func,
                        lowering_platforms=("cpu", "tpu", "cuda"))(x)
    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))

    # Now serialize the call to the exported using a different sequence of
    # lowering platforms, but included in the lowering platforms for the
    # nested exported.
    exp2 = export.export(export.call_exported(exp),
                         lowering_platforms=("cpu", "cuda"))(x)

    # Call with the argument placed on different platforms
    for platform in self.__class__.platforms:
      if platform == "tpu": continue
      x_device = jax.device_put(x, jax.devices(platform)[0])
      res_exp = export.call_exported(exp2)(x_device)
      self.assertAllClose(
        res_exp,
        _testing_multi_platform_fun_expected(x, platform=platform))

  def test_multi_platform_nested_inside_single_platform_export(self):
    x = np.arange(5, dtype=np.float32)
    exp = export.export(_testing_multi_platform_func,
                        lowering_platforms=("cpu", "tpu", "cuda"))(x)
    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))

    # Now serialize the call for the current platform.
    exp2 = export.export(export.call_exported(exp))(x)
    module_str = str(exp2.mlir_module())
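    # The nested multi-platform module keeps its platform-index argument, so
    # the wrapping single-platform module must also be marked for shape
    # refinement.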
    self.assertIn("jax.uses_shape_polymorphism = true",
                  module_str)
    res2 = export.call_exported(exp2)(x)
    self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))

  def test_multi_platform_and_poly(self):
    if jtu.test_device_matches(["gpu"]):
      # The export is not applicable to GPU
      raise unittest.SkipTest("Not intended for running on GPU")
    exp = export.export(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)),
                        lowering_platforms=("cpu", "tpu"))(
        export.poly_spec((5, 6), np.float32, "b1, b2")
    )
    x = np.arange(12, dtype=np.float32).reshape((3, 4))
    res = export.call_exported(exp)(x)
    self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
    # Now serialize the call to the exported
    exp2 = export.export(export.call_exported(exp))(x)
    res2 = export.call_exported(exp2)(x)
    self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))

  def test_multi_platform_and_sharding(self):
    export_devices = jax.devices()[0:2]
    export_mesh = Mesh(export_devices, axis_names=("x",))
    a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4))
    @functools.partial(
        jax.jit,
        in_shardings=(jax.sharding.NamedSharding(export_mesh, P("x", None),),),
        out_shardings=jax.sharding.NamedSharding(export_mesh, P(None, "x")))
    def f_jax(b):  # b: f32[16 // DEVICES, 4]
      return b * 2.

    res_native = f_jax(a)
    exp = export.export(f_jax,
                        lowering_platforms=("cpu", "tpu", "cuda"))(a)

    # Call with the argument placed on different platforms
    for platform in self.__class__.platforms:
      run_devices = jax.devices(platform)[0:len(export_devices)]
      if len(run_devices) != len(export_devices):
        continue
      run_mesh = Mesh(run_devices, ("x",))
      a_device = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, None))
      res_exp = export.call_exported(exp)(a_device)
      self.assertArraysAllClose(res_native, res_exp)

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())