Mirror of https://github.com/ROCm/jax.git
synced 2025-04-25 09:06:06 +00:00

But this is a contradiction, since layouts apply to the device-local shape, and you can't decide the layout without knowing the sharding. Still, there are cases where you don't care what the sharding is and just want to force, for example, a row-major layout. **This API should only be used for those cases**.
PiperOrigin-RevId: 744888557
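The dependence of layouts on sharding can be made concrete with a small plain-Python model. The helper `local_shape` below is hypothetical and not part of JAX. It assumes each array dimension is either replicated or split evenly across some number of devices, and it returns the per-device (device-local) shape, which is the shape a layout actually describes.

```python
def local_shape(global_shape, num_shards):
    """Hypothetical helper: per-device shape of an evenly sharded array.

    num_shards[i] is how many devices dimension i is split across
    (1 means the dimension is replicated, i.e. not split).
    """
    if len(global_shape) != len(num_shards):
        raise ValueError("need one shard count per dimension")
    out = []
    for size, n in zip(global_shape, num_shards):
        if size % n != 0:
            raise ValueError(f"dimension of size {size} is not divisible by {n}")
        out.append(size // n)
    return tuple(out)


# The same (8, 6) global array yields different device-local shapes
# depending on how it is sharded.
print(local_shape((8, 6), (1, 1)))  # replicated
print(local_shape((8, 6), (4, 1)))  # rows split over 4 devices
print(local_shape((8, 6), (2, 3)))  # split over a 2x3 device mesh
```

Because the buffer being laid out changes with the sharding, choosing a layout generally requires knowing the sharding. A row-major order is stated only in terms of dimension order, though, so one reading of the commit message is that this is why forcing it without a sharding is acceptable.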
22 lines
754 B
Python
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

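# Public re-exports; the `name as name` form marks each one as an explicit
# re-export for type checkers.
from jax._src.layout import (
    DeviceLocalLayout as DeviceLocalLayout,
    Layout as Layout,
)
# Constrains only the device-local layout of a value, without specifying
# its sharding; intended for cases such as forcing a row-major layout.
from jax._src.pjit import (
    with_dll_constraint as with_dll_constraint,
)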
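`DeviceLocalLayout` is re-exported here unchanged. In JAX's layout module, the dimension order of a device-local buffer is expressed as a `major_to_minor` tuple, and row-major for a rank-`n` array corresponds to `(0, 1, ..., n - 1)`. To show what that order means in memory, the plain-Python sketch below computes element strides for a dense device-local buffer from a major-to-minor order. The helper `strides_for` is hypothetical and is not a JAX API. It ignores tiling and padding, which real device layouts may add.

```python
def strides_for(shape, major_to_minor):
    """Hypothetical helper: element strides for a dense, untiled buffer.

    major_to_minor lists dimension indices from slowest-varying
    (most major) to fastest-varying (most minor).
    """
    if sorted(major_to_minor) != list(range(len(shape))):
        raise ValueError("major_to_minor must be a permutation of the dimensions")
    strides = [0] * len(shape)
    step = 1
    # Walk from the most minor dimension outward; each dimension's stride
    # is the number of elements spanned by all more-minor dimensions.
    for dim in reversed(major_to_minor):
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


# e.g. the device-local shape of an (8, 6) array whose rows are split 4 ways
local = (2, 6)
print(strides_for(local, (0, 1)))  # row-major
print(strides_for(local, (1, 0)))  # column-major
```

The row-major order `(0, 1)` is valid for any rank-2 local shape, which is consistent with the commit message's point that it can be requested without first knowing the sharding.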