mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[PJRT] Delete the old :cpu_device target that uses StreamExecutor.
The TFRT CPU client is better in every way and the SE CPU client is unmaintained and has not been used by JAX in many months. PiperOrigin-RevId: 489246256
This commit is contained in:
parent
ebee4f4bfd
commit
88379603e0
@ -29,8 +29,8 @@ tf_cc_binary(
|
||||
"@org_tensorflow//tensorflow/compiler/xla:shape_util",
|
||||
"@org_tensorflow//tensorflow/compiler/xla:status",
|
||||
"@org_tensorflow//tensorflow/compiler/xla:statusor",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:cpu_device",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
|
||||
"@org_tensorflow//tensorflow/core/platform:logging",
|
||||
|
@ -42,8 +42,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
|
||||
@ -67,7 +67,7 @@ int main(int argc, char** argv) {
|
||||
|
||||
// Get a CPU client.
|
||||
std::unique_ptr<xla::PjRtClient> client =
|
||||
xla::GetCpuClient(/*asynchronous=*/true).value();
|
||||
xla::GetTfrtCpuClient(/*asynchronous=*/true).value();
|
||||
|
||||
// Compile XlaComputation to PjRtExecutable.
|
||||
xla::XlaComputation xla_computation(test_module_proto);
|
||||
|
Loading…
x
Reference in New Issue
Block a user