Open source the XRTTpuDeviceAccessor and register XRTStateOps on TPU on the open source XRT side.
PiperOrigin-RevId: 345582775 Change-Id: Ie97bae4e37b16ba8bce411c016e80b50f7a26ff1
This commit is contained in:
parent
1943e58d29
commit
293bd7502b
@ -41,7 +41,6 @@ cc_library(
|
||||
"xrt_memory_manager.cc",
|
||||
"xrt_metrics.cc",
|
||||
"xrt_state.cc",
|
||||
"xrt_tpu_device.cc",
|
||||
"xrt_util.cc",
|
||||
],
|
||||
hdrs = [
|
||||
@ -51,7 +50,6 @@ cc_library(
|
||||
"xrt_metrics.h",
|
||||
"xrt_refptr.h",
|
||||
"xrt_state.h",
|
||||
"xrt_tpu_device.h",
|
||||
"xrt_util.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
@ -78,10 +76,8 @@ cc_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:regexp",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/core/tpu:tpu_configuration",
|
||||
"//tensorflow/stream_executor",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//tensorflow/stream_executor/tpu:tpu_node_context",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
|
@ -73,7 +73,6 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
"//tensorflow/stream_executor:stream_executor_headers",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xrt/xrt_metrics.h"
|
||||
#include "tensorflow/core/tpu/tpu_defs.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -68,11 +67,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
|
||||
.HostMemory("allocation")
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("allocation")
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -82,10 +76,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
|
||||
.Device(DEVICE_XLA_CPU)
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateUninitializedOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -97,11 +87,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
|
||||
.HostMemory("inputs")
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateFromTensorOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("inputs")
|
||||
.HostMemory("handle"),
|
||||
XRTAllocateFromTensorOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -115,12 +100,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
|
||||
.HostMemory("shape_index")
|
||||
.HostMemory("output_handle"),
|
||||
XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("base_handle")
|
||||
.HostMemory("shape_index")
|
||||
.HostMemory("output_handle"),
|
||||
XRTSubTupleOp<false, XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -134,12 +113,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
|
||||
.HostMemory("shape_index")
|
||||
.HostMemory("output_handle"),
|
||||
XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("base_handle")
|
||||
.HostMemory("shape_index")
|
||||
.HostMemory("output_handle"),
|
||||
XRTSubTupleOp<true, XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -153,12 +126,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
|
||||
.HostMemory("input_handles")
|
||||
.HostMemory("output_handle"),
|
||||
XRTMakeTupleOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("tuple_description")
|
||||
.HostMemory("input_handles")
|
||||
.HostMemory("output_handle"),
|
||||
XRTMakeTupleOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -170,11 +137,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
|
||||
.HostMemory("handle")
|
||||
.HostMemory("literal"),
|
||||
XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("handle")
|
||||
.HostMemory("literal"),
|
||||
XRTReadLiteralOp<false, XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -188,12 +150,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
|
||||
.HostMemory("literal")
|
||||
.HostMemory("output_handle"),
|
||||
XRTWriteLiteralOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("handle")
|
||||
.HostMemory("literal")
|
||||
.HostMemory("output_handle"),
|
||||
XRTWriteLiteralOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -205,11 +161,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
|
||||
.HostMemory("handle")
|
||||
.HostMemory("literal"),
|
||||
XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("handle")
|
||||
.HostMemory("literal"),
|
||||
XRTReadLiteralOp<true, XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -221,11 +172,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
|
||||
.HostMemory("handles")
|
||||
.HostMemory("tensors"),
|
||||
XRTReadToTensorOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("handles")
|
||||
.HostMemory("tensors"),
|
||||
XRTReadToTensorOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
|
||||
.Device(DEVICE_XLA_GPU)
|
||||
@ -235,25 +181,16 @@ REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
|
||||
.Device(DEVICE_XLA_CPU)
|
||||
.HostMemory("handle"),
|
||||
XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
|
||||
.Device(DEVICE_TPU_NODE)
|
||||
.HostMemory("handle"),
|
||||
XRTReleaseAllocationOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU),
|
||||
XRTReleaseAllAllocationsOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU),
|
||||
XRTReleaseAllAllocationsOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("XRTReleaseAllAllocations").Device(DEVICE_TPU_NODE),
|
||||
XRTReleaseAllAllocationsOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_GPU),
|
||||
XRTCompactAllocationsOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_CPU),
|
||||
XRTCompactAllocationsOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_TPU_NODE),
|
||||
XRTCompactAllocationsOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTMetricsCollect").Device(DEVICE_CPU),
|
||||
XRTMetricsCollectOp);
|
||||
@ -262,7 +199,5 @@ REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_GPU),
|
||||
XRTMemoryInfoOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_CPU),
|
||||
XRTMemoryInfoOp<XRTGenericDeviceAccessor>);
|
||||
REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_TPU_NODE),
|
||||
XRTMemoryInfoOp<XRTTpuDeviceAccessor>);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -37,7 +37,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
|
||||
#include "tensorflow/compiler/xrt/xrt_metrics.h"
|
||||
#include "tensorflow/compiler/xrt/xrt_state.h"
|
||||
#include "tensorflow/compiler/xrt/xrt_tpu_device.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
@ -1,61 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xrt/xrt_tpu_device.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/tpu/tpu_configuration.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Stores the resource manager returned by GetTPUConfigResourceMgr() in `*rm`.
//
// `ctx` is not consulted; the parameter exists only so this accessor has the
// same signature as XRTGenericDeviceAccessor::GetResourceManager.
//
// Returns OK when a manager was obtained, or an Internal error when
// GetTPUConfigResourceMgr() returned null (in which case `*rm` is null too).
/*static*/ Status XRTTpuDeviceAccessor::GetResourceManager(OpKernelContext* ctx,
                                                           ResourceMgr** rm) {
  ResourceMgr* const tpu_resource_mgr = GetTPUConfigResourceMgr();
  // The out-parameter is written unconditionally, including on failure.
  *rm = tpu_resource_mgr;
  if (tpu_resource_mgr != nullptr) {
    return Status::OK();
  }
  return errors::Internal("No Tpu resource manager.");
}
|
||||
|
||||
// Binds this ScopedRef to the TPU node context for `device_ordinal`.
//
// If tpu::TpuNodeContext::Create() fails, its error is returned and
// `ordinal_` is left untouched; the ordinal is recorded only after the node
// context has been stored.
Status XRTTpuDeviceAccessor::ScopedRef::Acquire(int device_ordinal) {
  // Early-returns the creation error, otherwise assigns into node_context_.
  TF_ASSIGN_OR_RETURN(node_context_,
                      tpu::TpuNodeContext::Create(device_ordinal));

  ordinal_ = device_ordinal;
  return Status::OK();
}
|
||||
|
||||
// Binds this ScopedRef to the device that `ctx` is running on.
//
// The device ordinal is read from the XlaDevice metadata attached to `ctx`;
// any error from XlaDevice::GetMetadata() is returned unchanged. The actual
// acquisition is delegated to the Acquire(int) overload.
Status XRTTpuDeviceAccessor::ScopedRef::Acquire(OpKernelContext* ctx) {
  const XlaDevice::Metadata* device_metadata = nullptr;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &device_metadata));

  const auto ordinal = device_metadata->device_ordinal();
  return Acquire(ordinal);
}
|
||||
|
||||
// Initializes `scoped_ref` for an explicitly specified `device_ordinal`.
//
// The kernel context argument is ignored by this overload. The returned
// status is whatever ScopedRef::Acquire(int) produced.
/*static*/ Status XRTTpuDeviceAccessor::InitScopedRef(
    OpKernelContext* /*unused ctx*/, int device_ordinal,
    ScopedRef* scoped_ref) {
  Status acquire_status = scoped_ref->Acquire(device_ordinal);
  return acquire_status;
}
|
||||
|
||||
// Initializes `scoped_ref` for the device associated with `ctx`.
//
// Forwards to ScopedRef::Acquire(OpKernelContext*) and returns its status.
/*static*/ Status XRTTpuDeviceAccessor::InitScopedRef(OpKernelContext* ctx,
                                                      ScopedRef* scoped_ref) {
  Status acquire_status = scoped_ref->Acquire(ctx);
  return acquire_status;
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,68 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Classes for keeping track of on-device state for TPUs.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
|
||||
#define TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This accessor is used for XLA TPU. It uses the distributed TPU compilation
|
||||
// cache infrastructure which it accesses via the TPU_SYSTEM resource manager.
|
||||
class XRTTpuDeviceAccessor {
|
||||
public:
|
||||
static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
|
||||
|
||||
class ScopedRef {
|
||||
public:
|
||||
ScopedRef() {}
|
||||
~ScopedRef() {}
|
||||
|
||||
ScopedRef(const ScopedRef&) = delete;
|
||||
ScopedRef& operator=(const ScopedRef&) = delete;
|
||||
|
||||
// Returns the XLA device properties from the TpuNodeContext object
|
||||
// protected by this ScopedRef.
|
||||
xla::Backend* backend() { return node_context_->backend(); }
|
||||
int device_ordinal() { return ordinal_; }
|
||||
|
||||
private:
|
||||
// XRTTpuDeviceAccessor::InitScopedRef is the only way to initialize
|
||||
// ScopedRef.
|
||||
friend class XRTTpuDeviceAccessor;
|
||||
|
||||
Status Acquire(int device_ordinal);
|
||||
|
||||
Status Acquire(OpKernelContext* ctx);
|
||||
|
||||
std::unique_ptr<tpu::TpuNodeContext> node_context_;
|
||||
int ordinal_ = 0;
|
||||
};
|
||||
|
||||
static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal,
|
||||
ScopedRef* scoped_ref);
|
||||
|
||||
static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
|
@ -9,7 +9,6 @@ load(
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/compiler/tf2xla/kernels:__subpackages__",
|
||||
"//tensorflow/compiler/xrt:__subpackages__",
|
||||
"//tensorflow/core/tpu:__subpackages__",
|
||||
"//tensorflow/stream_executor/tpu:__subpackages__",
|
||||
],
|
||||
|
@ -5,7 +5,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
package(
|
||||
default_visibility = [
|
||||
"//learning/brain/experimental/dtensor:__subpackages__",
|
||||
"//tensorflow/compiler/xrt:__subpackages__",
|
||||
"//tensorflow/core/tpu:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
|
Loading…
Reference in New Issue
Block a user