Open source the XRTTpuDeviceAccessor and register the XRTStateOps on TPU on the open source XRT side when it is a 1VM use case. The 2VM code path remains unchanged.
PiperOrigin-RevId: 351656566 Change-Id: I531e13073558bcd0d85d9fc76eb4358252bc0658
This commit is contained in:
parent
e513f24d41
commit
9d5cd9981b
@ -33,6 +33,25 @@ tf_proto_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
    name = "xrt_tpu_utils",
    srcs = ["xrt_tpu_device.cc"],
    hdrs = ["xrt_tpu_device.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/jit:xla_device",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/tpu:tpu_configuration",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
    ],
)
|
||||
|
||||
cc_library(
|
||||
name = "xrt_utils",
|
||||
srcs = [
|
||||
|
@ -49,11 +49,14 @@ cc_library(
|
||||
srcs = [
|
||||
"tpu_compile_ops.cc",
|
||||
"tpu_execute_op.cc",
|
||||
"tpu_state_op.cc",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":xrt_state_ops",
|
||||
"//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -61,6 +64,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:compile_only_client",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
@ -68,6 +72,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/compiler/xrt:xrt_proto_cc",
|
||||
"//tensorflow/compiler/xrt:xrt_tpu_utils",
|
||||
"//tensorflow/compiler/xrt:xrt_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
104
tensorflow/compiler/xrt/kernels/tpu_state_op.cc
Normal file
104
tensorflow/compiler/xrt/kernels/tpu_state_op.cc
Normal file
@ -0,0 +1,104 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Registers the XRT state ops (allocation, tuple construction/decomposition,
// literal read/write, and handle-release ops) on the TPU_NODE device, backed
// by XRTTpuDeviceAccessor.

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h"
#include "tensorflow/compiler/xrt/xrt_tpu_device.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {

// Every registration below pins its handle/literal/shape arguments to host
// memory via HostMemory(); only the allocations the handles refer to live on
// the TPU device.
REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("allocation")
                            .HostMemory("handle"),
                        XRTAllocateOp<XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("handle"),
                        XRTAllocateUninitializedOp<XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("inputs")
                            .HostMemory("handle"),
                        XRTAllocateFromTensorOp<XRTTpuDeviceAccessor>);

// XRTSubTupleOp's first template argument selects whether the base handle is
// released after the sub-tuple is extracted (false = keep, true = release).
REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("base_handle")
                            .HostMemory("shape_index")
                            .HostMemory("output_handle"),
                        XRTSubTupleOp<false, XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("base_handle")
                            .HostMemory("shape_index")
                            .HostMemory("output_handle"),
                        XRTSubTupleOp<true, XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("tuple_description")
                            .HostMemory("input_handles")
                            .HostMemory("output_handle"),
                        XRTMakeTupleOp<XRTTpuDeviceAccessor>);

// XRTReadLiteralOp's first template argument selects whether the handle is
// released after reading (false = keep, true = release).
REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("handle")
                            .HostMemory("literal"),
                        XRTReadLiteralOp<false, XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("handle")
                            .HostMemory("literal")
                            .HostMemory("output_handle"),
                        XRTWriteLiteralOp<XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("handle")
                            .HostMemory("literal"),
                        XRTReadLiteralOp<true, XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("handles")
                            .HostMemory("tensors"),
                        XRTReadToTensorOp<XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("handle"),
                        XRTReleaseAllocationOp<XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(
    Name("XRTReleaseAllAllocations").Device(DEVICE_TPU_NODE),
    XRTReleaseAllAllocationsOp<XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_TPU_NODE),
                        XRTCompactAllocationsOp<XRTTpuDeviceAccessor>);

REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_TPU_NODE),
                        XRTMemoryInfoOp<XRTTpuDeviceAccessor>);

}  // namespace tensorflow
|
61
tensorflow/compiler/xrt/xrt_tpu_device.cc
Normal file
61
tensorflow/compiler/xrt/xrt_tpu_device.cc
Normal file
@ -0,0 +1,61 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xrt/xrt_tpu_device.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/tpu/tpu_configuration.h"

namespace tensorflow {

/*static*/ Status XRTTpuDeviceAccessor::GetResourceManager(OpKernelContext* ctx,
                                                           ResourceMgr** rm) {
  // `ctx` is deliberately ignored; it is kept so this signature mirrors
  // XRTGenericDeviceAccessor::GetResourceManager.
  *rm = GetTPUConfigResourceMgr();
  return *rm == nullptr ? errors::Internal("No Tpu resource manager.")
                        : Status::OK();
}

Status XRTTpuDeviceAccessor::ScopedRef::Acquire(int device_ordinal) {
  // Create the per-device TPU node context and remember which ordinal this
  // ScopedRef is bound to.
  TF_ASSIGN_OR_RETURN(node_context_,
                      tpu::TpuNodeContext::Create(device_ordinal));
  ordinal_ = device_ordinal;
  return Status::OK();
}

Status XRTTpuDeviceAccessor::ScopedRef::Acquire(OpKernelContext* ctx) {
  // Resolve the ordinal of the XLA device this op is running on, then
  // delegate to the ordinal-based overload above.
  const XlaDevice::Metadata* metadata = nullptr;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
  return Acquire(metadata->device_ordinal());
}

/*static*/ Status XRTTpuDeviceAccessor::InitScopedRef(
    OpKernelContext* /*unused ctx*/, int device_ordinal,
    ScopedRef* scoped_ref) {
  return scoped_ref->Acquire(device_ordinal);
}

/*static*/ Status XRTTpuDeviceAccessor::InitScopedRef(OpKernelContext* ctx,
                                                      ScopedRef* scoped_ref) {
  return scoped_ref->Acquire(ctx);
}

}  // namespace tensorflow
|
68
tensorflow/compiler/xrt/xrt_tpu_device.h
Normal file
68
tensorflow/compiler/xrt/xrt_tpu_device.h
Normal file
@ -0,0 +1,68 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for keeping track of on-device state for TPUs.

#ifndef TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
#define TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_

#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {

// This accessor is used for XLA TPU. It uses the distributed TPU compilation
// cache infrastructure which it accesses via the TPU_SYSTEM resource manager.
class XRTTpuDeviceAccessor {
 public:
  // Sets *rm to the TPU config resource manager. `ctx` is unused; it exists
  // so the signature is parallel with XRTGenericDeviceAccessor's. Returns an
  // Internal error if no TPU resource manager is available.
  static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);

  // Non-copyable holder of a tpu::TpuNodeContext for one device ordinal.
  // Obtain an initialized instance only via XRTTpuDeviceAccessor::
  // InitScopedRef.
  class ScopedRef {
   public:
    ScopedRef() = default;
    ~ScopedRef() = default;

    ScopedRef(const ScopedRef&) = delete;
    ScopedRef& operator=(const ScopedRef&) = delete;

    // Returns the XLA device properties from the TpuNodeContext object
    // protected by this ScopedRef. Only valid after a successful Acquire.
    xla::Backend* backend() { return node_context_->backend(); }
    int device_ordinal() { return ordinal_; }

   private:
    // XRTTpuDeviceAccessor::InitScopedRef is the only way to initialize
    // ScopedRef.
    friend class XRTTpuDeviceAccessor;

    // Binds this ScopedRef to the given device ordinal.
    Status Acquire(int device_ordinal);

    // Binds this ScopedRef to the device `ctx` is running on.
    Status Acquire(OpKernelContext* ctx);

    std::unique_ptr<tpu::TpuNodeContext> node_context_;
    int ordinal_ = 0;
  };

  // Initializes `scoped_ref` for a specific device ordinal; `ctx` is unused
  // (kept for signature parity with XRTGenericDeviceAccessor).
  static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal,
                              ScopedRef* scoped_ref);

  // Initializes `scoped_ref` for the device `ctx` is running on.
  static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
|
@ -10,7 +10,7 @@ package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/mlir/tensorflow:__subpackages__",
|
||||
"//tensorflow/compiler/tf2xla/kernels:__subpackages__",
|
||||
"//tensorflow/compiler/xrt/kernels:__subpackages__",
|
||||
"//tensorflow/compiler/xrt:__subpackages__",
|
||||
"//tensorflow/core/tpu:__subpackages__",
|
||||
"//tensorflow/stream_executor/tpu:__subpackages__",
|
||||
],
|
||||
|
@ -5,6 +5,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
package(
|
||||
default_visibility = [
|
||||
"//learning/brain/experimental/dtensor:__subpackages__",
|
||||
"//tensorflow/compiler/xrt:__subpackages__",
|
||||
"//tensorflow/core/profiler/internal/tpu:__subpackages__",
|
||||
"//tensorflow/core/tpu:__subpackages__",
|
||||
],
|
||||
|
Loading…
x
Reference in New Issue
Block a user