From 9d5cd9981b6965d5c9b6bd8694897a8d93faad0a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 13 Jan 2021 13:42:39 -0800
Subject: [PATCH] Open source the XRTTpuDeviceAccessor and register XRTStateOps
 on TPU on the open source XRT side if it is a 1vm use case. 2vm code remains
 unchanged.

PiperOrigin-RevId: 351656566
Change-Id: I531e13073558bcd0d85d9fc76eb4358252bc0658
---
 tensorflow/compiler/xrt/BUILD             |  19 ++++
 tensorflow/compiler/xrt/kernels/BUILD     |   5 +
 .../compiler/xrt/kernels/tpu_state_op.cc  | 104 ++++++++++++++++++
 tensorflow/compiler/xrt/xrt_tpu_device.cc |  61 ++++++++++
 tensorflow/compiler/xrt/xrt_tpu_device.h  |  68 ++++++++++++
 tensorflow/core/tpu/BUILD                 |   2 +-
 tensorflow/stream_executor/tpu/BUILD      |   1 +
 7 files changed, 259 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/compiler/xrt/kernels/tpu_state_op.cc
 create mode 100644 tensorflow/compiler/xrt/xrt_tpu_device.cc
 create mode 100644 tensorflow/compiler/xrt/xrt_tpu_device.h

diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD
index 1b699e7d8df..66b53e20681 100644
--- a/tensorflow/compiler/xrt/BUILD
+++ b/tensorflow/compiler/xrt/BUILD
@@ -33,6 +33,25 @@ tf_proto_library(
     visibility = ["//visibility:public"],
 )
 
+cc_library(
+    name = "xrt_tpu_utils",
+    srcs = [
+        "xrt_tpu_device.cc",
+    ],
+    hdrs = [
+        "xrt_tpu_device.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/jit:xla_device",
+        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/tpu:tpu_configuration",
+        "//tensorflow/stream_executor/tpu:tpu_node_context",
+    ],
+)
+
 cc_library(
     name = "xrt_utils",
     srcs = [
diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD
index 0e73ba7ecf3..8dcde2645c9 100644
--- a/tensorflow/compiler/xrt/kernels/BUILD
+++ b/tensorflow/compiler/xrt/kernels/BUILD
@@ -49,11 +49,14 @@ cc_library(
     srcs = [
         "tpu_compile_ops.cc",
         "tpu_execute_op.cc",
+        "tpu_state_op.cc",
     ],
     visibility = [":friends"],
     deps = [
+        ":xrt_state_ops",
         "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
@@ -61,6 +64,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:compile_only_client",
+        "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service:computation_placer",
@@ -68,6 +72,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xrt:xrt_proto_cc",
+        "//tensorflow/compiler/xrt:xrt_tpu_utils",
         "//tensorflow/compiler/xrt:xrt_utils",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
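
(Note on the change so far: the BUILD edits above add a new xrt_tpu_utils library and pull it, together with :xrt_state_ops, into the XRT TPU kernel target. The tpu_state_op.cc file that follows only registers kernels; the op bodies already live in xrt_state_ops.h as class templates parameterized on a device-accessor type, so the TPU registrations reuse them by instantiating those templates with the new XRTTpuDeviceAccessor. Below is a minimal standalone C++14 sketch of that accessor pattern; FakeBackend, FakeAccessor and AllocateOp are invented names for illustration and are not part of the patch.)

// Standalone sketch (C++14, no TensorFlow dependencies) of the device-accessor
// pattern: the op body is written once as a template, and each backend supplies
// an accessor that knows how to reach its devices.
#include <iostream>
#include <memory>

struct FakeBackend {
  int device_ordinal;
};

// Stand-in for the role played by XRTTpuDeviceAccessor (or the generic one).
class FakeAccessor {
 public:
  class ScopedRef {
   public:
    FakeBackend* backend() { return backend_.get(); }

   private:
    friend class FakeAccessor;  // Only the accessor may initialize the ref.
    std::unique_ptr<FakeBackend> backend_;
  };

  static void InitScopedRef(int device_ordinal, ScopedRef* ref) {
    ref->backend_ = std::make_unique<FakeBackend>(FakeBackend{device_ordinal});
  }
};

// Stand-in for a templated state op such as XRTAllocateOp<DeviceAccessor>.
template <class DeviceAccessor>
void AllocateOp(int device_ordinal) {
  typename DeviceAccessor::ScopedRef ref;
  DeviceAccessor::InitScopedRef(device_ordinal, &ref);
  std::cout << "allocating on ordinal " << ref.backend()->device_ordinal
            << "\n";
}

int main() {
  // One instantiation per backend accessor; the registration file that follows
  // does the analogous thing with REGISTER_KERNEL_BUILDER.
  AllocateOp<FakeAccessor>(0);
  return 0;
}
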
diff --git a/tensorflow/compiler/xrt/kernels/tpu_state_op.cc b/tensorflow/compiler/xrt/kernels/tpu_state_op.cc
new file mode 100644
index 00000000000..eeedddc2bdf
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/tpu_state_op.cc
@@ -0,0 +1,104 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h"
+#include "tensorflow/compiler/xrt/xrt_tpu_device.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+
+namespace tensorflow {
+REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("allocation")
+                            .HostMemory("handle"),
+                        XRTAllocateOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("handle"),
+                        XRTAllocateUninitializedOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("inputs")
+                            .HostMemory("handle"),
+                        XRTAllocateFromTensorOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("base_handle")
+                            .HostMemory("shape_index")
+                            .HostMemory("output_handle"),
+                        XRTSubTupleOp<false, XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("base_handle")
+                            .HostMemory("shape_index")
+                            .HostMemory("output_handle"),
+                        XRTSubTupleOp<true, XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("tuple_description")
+                            .HostMemory("input_handles")
+                            .HostMemory("output_handle"),
+                        XRTMakeTupleOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("handle")
+                            .HostMemory("literal"),
+                        XRTReadLiteralOp<false, XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("handle")
+                            .HostMemory("literal")
+                            .HostMemory("output_handle"),
+                        XRTWriteLiteralOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("handle")
+                            .HostMemory("literal"),
+                        XRTReadLiteralOp<true, XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("handles")
+                            .HostMemory("tensors"),
+                        XRTReadToTensorOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("handle"),
+                        XRTReleaseAllocationOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("XRTReleaseAllAllocations").Device(DEVICE_TPU_NODE),
+    XRTReleaseAllAllocationsOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_TPU_NODE),
+                        XRTCompactAllocationsOp<XRTTpuDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_TPU_NODE),
+                        XRTMemoryInfoOp<XRTTpuDeviceAccessor>);
+
+} // namespace tensorflow
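
(As registered above, XRTSubTuple / XRTSubTupleAndRelease and XRTReadLiteral / XRTReadLiteralAndRelease each share one op template whose leading bool parameter decides whether the handle is released after the read. The standalone sketch below illustrates that idiom with invented names and data; it is not TensorFlow code.)

// Standalone sketch of a bool template parameter selecting "AndRelease"
// behaviour; allocations, ReadLiteralOp and the handle value are invented.
#include <iostream>
#include <map>
#include <string>

std::map<int, std::string> allocations = {{7, "literal-bytes"}};

template <bool discard_handle>
std::string ReadLiteralOp(int handle) {
  std::string value = allocations.at(handle);
  if (discard_handle) {
    allocations.erase(handle);  // The "AndRelease" variant drops the handle.
  }
  return value;
}

int main() {
  std::cout << ReadLiteralOp<false>(7) << "\n";  // Handle 7 stays live.
  std::cout << ReadLiteralOp<true>(7) << "\n";   // Handle 7 is released.
  std::cout << allocations.count(7) << "\n";     // Prints 0.
  return 0;
}
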
diff --git a/tensorflow/compiler/xrt/xrt_tpu_device.cc b/tensorflow/compiler/xrt/xrt_tpu_device.cc
new file mode 100644
index 00000000000..65f68d977d3
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_tpu_device.cc
@@ -0,0 +1,61 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xrt/xrt_tpu_device.h"
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/tpu/tpu_configuration.h"
+
+namespace tensorflow {
+
+/*static*/ Status XRTTpuDeviceAccessor::GetResourceManager(OpKernelContext* ctx,
+                                                           ResourceMgr** rm) {
+  // ctx is unused here, but maintained because XRTGenericDeviceAccessor uses
+  // it in its GetResourceManager.
+  *rm = GetTPUConfigResourceMgr();
+  if (*rm == nullptr) {
+    return errors::Internal("No Tpu resource manager.");
+  }
+  return Status::OK();
+}
+
+Status XRTTpuDeviceAccessor::ScopedRef::Acquire(int device_ordinal) {
+  TF_ASSIGN_OR_RETURN(node_context_,
+                      tpu::TpuNodeContext::Create(device_ordinal));
+  ordinal_ = device_ordinal;
+  return Status::OK();
+}
+
+Status XRTTpuDeviceAccessor::ScopedRef::Acquire(OpKernelContext* ctx) {
+  const XlaDevice::Metadata* metadata;
+  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
+  return Acquire(metadata->device_ordinal());
+}
+
+/*static*/ Status XRTTpuDeviceAccessor::InitScopedRef(
+    OpKernelContext* /*unused ctx*/, int device_ordinal,
+    ScopedRef* scoped_ref) {
+  return scoped_ref->Acquire(device_ordinal);
+}
+
+/*static*/ Status XRTTpuDeviceAccessor::InitScopedRef(OpKernelContext* ctx,
+                                                      ScopedRef* scoped_ref) {
+  return scoped_ref->Acquire(ctx);
+}
+
+} // namespace tensorflow
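
(GetResourceManager above deliberately ignores its OpKernelContext argument: unlike XRTGenericDeviceAccessor, which resolves a resource manager through the context, the TPU accessor returns the process-wide manager from GetTPUConfigResourceMgr(), and the parameter is kept only so both accessors share a signature. The standalone sketch below contrasts the two lookup styles; FakeResourceMgr and FakeOpContext are invented names, not TensorFlow types.)

// Standalone sketch (no TensorFlow dependencies) contrasting a per-device
// lookup with the process-wide lookup used by the TPU accessor above.
#include <iostream>
#include <string>

struct FakeResourceMgr {
  std::string name;
};

struct FakeOpContext {
  FakeResourceMgr* device_rm;
};

// Per-device style: read the manager off the op's context.
FakeResourceMgr* GetPerDeviceRm(FakeOpContext* ctx) { return ctx->device_rm; }

// Process-wide style: ignore the context and hand back one shared manager,
// much as GetTPUConfigResourceMgr does for the TPU accessor.
FakeResourceMgr* GetProcessRm(FakeOpContext* /*ctx*/) {
  static FakeResourceMgr process_rm{"tpu_config_rm"};
  return &process_rm;
}

int main() {
  FakeResourceMgr device0_rm{"device0_rm"};
  FakeOpContext ctx{&device0_rm};
  std::cout << GetPerDeviceRm(&ctx)->name << "\n";  // device0_rm
  std::cout << GetProcessRm(&ctx)->name << "\n";    // tpu_config_rm
  return 0;
}
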
diff --git a/tensorflow/compiler/xrt/xrt_tpu_device.h b/tensorflow/compiler/xrt/xrt_tpu_device.h
new file mode 100644
index 00000000000..611d17b6ca1
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_tpu_device.h
@@ -0,0 +1,68 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for keeping track of on-device state for TPUs.
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
+
+namespace tensorflow {
+
+// This accessor is used for XLA TPU. It uses the distributed TPU compilation
+// cache infrastructure which it accesses via the TPU_SYSTEM resource manager.
+class XRTTpuDeviceAccessor {
+ public:
+  static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
+
+  class ScopedRef {
+   public:
+    ScopedRef() {}
+    ~ScopedRef() {}
+
+    ScopedRef(const ScopedRef&) = delete;
+    ScopedRef& operator=(const ScopedRef&) = delete;
+
+    // Returns the XLA device properties from the TpuNodeContext object
+    // protected by this ScopedRef.
+    xla::Backend* backend() { return node_context_->backend(); }
+    int device_ordinal() { return ordinal_; }
+
+   private:
+    // XRTTpuDeviceAccessor::InitScopedRef is the only way to initialize
+    // ScopedRef.
+    friend class XRTTpuDeviceAccessor;
+
+    Status Acquire(int device_ordinal);
+
+    Status Acquire(OpKernelContext* ctx);
+
+    std::unique_ptr<tpu::TpuNodeContext> node_context_;
+    int ordinal_ = 0;
+  };
+
+  static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal,
+                              ScopedRef* scoped_ref);
+
+  static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD
index 2a5a1f5dcd0..ded1ba20bd0 100644
--- a/tensorflow/core/tpu/BUILD
+++ b/tensorflow/core/tpu/BUILD
@@ -10,7 +10,7 @@ package(
     default_visibility = [
         "//tensorflow/compiler/mlir/tensorflow:__subpackages__",
         "//tensorflow/compiler/tf2xla/kernels:__subpackages__",
-        "//tensorflow/compiler/xrt/kernels:__subpackages__",
+        "//tensorflow/compiler/xrt:__subpackages__",
         "//tensorflow/core/tpu:__subpackages__",
         "//tensorflow/stream_executor/tpu:__subpackages__",
     ],
diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD
index 1a7e130f9ac..67b6b3c19dc 100644
--- a/tensorflow/stream_executor/tpu/BUILD
+++ b/tensorflow/stream_executor/tpu/BUILD
@@ -5,6 +5,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 package(
     default_visibility = [
         "//learning/brain/experimental/dtensor:__subpackages__",
+        "//tensorflow/compiler/xrt:__subpackages__",
         "//tensorflow/core/profiler/internal/tpu:__subpackages__",
         "//tensorflow/core/tpu:__subpackages__",
     ],
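
(The header above is the accessor's whole public surface, and the two visibility edits at the end exist only so //tensorflow/compiler/xrt may depend on tpu_configuration and tpu_node_context. The real call sites live in the templated ops of xrt_state_ops.h, which this patch does not touch; the sketch below is only a rough illustration, inside the TensorFlow tree, of how a kernel helper might drive the interface. ExampleUse is a hypothetical name and error handling is abbreviated.)

// Hedged sketch of driving XRTTpuDeviceAccessor from a kernel helper.
#include "tensorflow/compiler/xrt/xrt_tpu_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical helper: resolves the shared TPU resource manager and a
// device-scoped backend before touching XRT allocation state.
Status ExampleUse(OpKernelContext* ctx) {
  ResourceMgr* rm = nullptr;
  TF_RETURN_IF_ERROR(XRTTpuDeviceAccessor::GetResourceManager(ctx, &rm));

  XRTTpuDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(XRTTpuDeviceAccessor::InitScopedRef(ctx, &device_ref));

  xla::Backend* backend = device_ref.backend();
  int ordinal = device_ref.device_ordinal();
  // A real kernel would create or look up XRT allocation state here, keyed
  // in *rm and bound to this backend and ordinal.
  (void)rm;
  (void)backend;
  (void)ordinal;
  return Status::OK();
}

}  // namespace tensorflow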