From 0796a9598950d7a59d76f4d3d9063c19b68645ef Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 10 Dec 2020 14:53:59 -0800 Subject: [PATCH] Add a custom device tensor handle type This will replace TensorHandles holding custom device tensors, which will clean up the runtime a bit, provide flexibility for custom devices (especially for shape handling and copies), and morph custom devices a bit more toward op handlers. I'll do the migration itself in a followup. For now I'm leaving the naming as "custom device" for consistency. Once custom devices are migrated to this type it will be easier to bump the abstraction level a bit (AbstractTensorHandles instead of ImmediateExecutionTensorHandles) and re-name things. PiperOrigin-RevId: 346871040 Change-Id: Ic1d4791b27d37b1078bd985be344c54748ae49b8 --- tensorflow/c/eager/abstract_tensor_handle.h | 2 +- tensorflow/core/common_runtime/eager/BUILD | 35 ++++++ .../core/common_runtime/eager/context.h | 21 +--- .../common_runtime/eager/custom_device.cc | 82 ++++++++++++++ .../core/common_runtime/eager/custom_device.h | 107 ++++++++++++++++++ .../eager/custom_device_test.cc | 99 ++++++++++++++++ 6 files changed, 325 insertions(+), 21 deletions(-) create mode 100644 tensorflow/core/common_runtime/eager/custom_device.cc create mode 100644 tensorflow/core/common_runtime/eager/custom_device.h create mode 100644 tensorflow/core/common_runtime/eager/custom_device_test.cc diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h index 1ca4a9a8ecb..a560c0d58c9 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.h +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -27,7 +27,7 @@ namespace tensorflow { // execution mode. 
class AbstractTensorHandle : public core::RefCounted { protected: - enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt }; + enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice }; explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} virtual ~AbstractTensorHandle() {} diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 728aacb36e4..2ae079be53c 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -77,6 +77,7 @@ tf_cuda_library( deps = [ ":eager_executor", ":kernel_and_device", + ":custom_device", "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_tensor_internal", "//tensorflow/c/eager:immediate_execution_context", @@ -110,6 +111,39 @@ tf_cuda_library( }), ) +tf_cuda_library( + name = "custom_device", + srcs = ["custom_device.cc"], + hdrs = ["custom_device.h"], + visibility = ["//tensorflow:internal"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core/lib/core:status", + ], + }), +) + +tf_cc_test( + name = "custom_device_test", + srcs = ["custom_device_test.cc"], + deps = [ + ":context", + ":core", + ":custom_device", + ":tensor_handle", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cuda_library( name = "context_distributed_manager", srcs = [ @@ -665,6 +699,7 @@ filegroup( srcs = [ "attr_builder.h", "context.h", + "custom_device.h", "eager_executor.h", "eager_operation.h", "kernel_and_device.h", diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 886ed498c07..ef99dfac3a1 100644 --- 
a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/function.h" @@ -81,26 +82,6 @@ class RemoteMgr; class TensorHandle; class EagerOperation; -class CustomDevice { - public: - virtual ~CustomDevice() {} - virtual const string& name() = 0; - virtual Status CopyTensorToDevice(TensorHandle* tensor, - TensorHandle** result) = 0; - - virtual Status CopyTensorFromDevice(TensorHandle* tensor, - const string& target_device_name, - TensorHandle** result) = 0; - - virtual Status Execute(const EagerOperation* op, TensorHandle** retvals, - int* num_retvals) = 0; -}; - -// Custom devices do many of the same things as physical Devices, but have a -// much more restricted interface. We pass around ambiguous pointers since -// TensorHandles may be placed either on custom or physical devices. -using VariantDevice = absl::variant<Device*, CustomDevice*>; - class EagerContext : public ImmediateExecutionContext, public core::RefCounted { public: static constexpr uint64 kInvalidContextId = 0; diff --git a/tensorflow/core/common_runtime/eager/custom_device.cc b/tensorflow/core/common_runtime/eager/custom_device.cc new file mode 100644 index 00000000000..c3cfe910e92 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/custom_device.cc @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/eager/custom_device.h" + +namespace tensorflow { + +Status CustomDeviceTensorHandle::Shape(PartialTensorShape* shape) const { + int num_dims; + TF_RETURN_IF_ERROR(NumDims(&num_dims)); + std::vector<int64> dims(num_dims); + for (int i = 0; i < num_dims; ++i) { + TF_RETURN_IF_ERROR(Dim(i, &dims[i])); + } + return PartialTensorShape::MakePartialShape(dims.data(), num_dims, shape); +} + +Status CustomDeviceTensorHandle::NumElements(int64* num_elements) const { + *num_elements = 1; + int num_dims; + TF_RETURN_IF_ERROR(NumDims(&num_dims)); + for (int i = 0; i < num_dims; ++i) { + int64 dim; + TF_RETURN_IF_ERROR(Dim(i, &dim)); + *num_elements *= dim; + } + return Status::OK(); +} + +const char* CustomDeviceTensorHandle::DeviceType(Status* status) const { + const DeviceNameUtils::ParsedName* parsed = ParsedName(status); + if (!status->ok()) { + return ""; + } + return parsed->type.c_str(); +} + +int CustomDeviceTensorHandle::DeviceId(Status* status) const { + const DeviceNameUtils::ParsedName* parsed = ParsedName(status); + if (!status->ok()) { + return 0; + } + return parsed->id; +} + +AbstractTensorInterface* CustomDeviceTensorHandle::Resolve(Status* status) { + core::RefCountPtr<ImmediateExecutionTensorHandle> copied_off( + context_->CopyTensorHandleToDevice( + this, "/job:localhost/replica:0/task:0/device:CPU:0", status)); + if (!status->ok()) { + return nullptr; + } + return copied_off->Resolve(status); +} + +const DeviceNameUtils::ParsedName* CustomDeviceTensorHandle::ParsedName( +
Status* status) const { + if (!parsed_name_.has_value()) { + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullOrLocalName(device_->name(), &parsed_name)) { + *status = errors::InvalidArgument( + absl::StrCat("Invalid custom device name ", device_->name())); + return nullptr; + } + parsed_name_.emplace(std::move(parsed_name)); + } + return &*parsed_name_; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/custom_device.h b/tensorflow/core/common_runtime/eager/custom_device.h new file mode 100644 index 00000000000..e3168b6265b --- /dev/null +++ b/tensorflow/core/common_runtime/eager/custom_device.h @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_ + +#include <string> + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class TensorHandle; +class EagerOperation; + +// Custom devices intercept the execution of operations (the `Execute` method), +// typically implemented with one or more of the custom device's own executions.
+class CustomDevice { + public: + virtual ~CustomDevice() {} + virtual const string& name() = 0; + virtual Status CopyTensorToDevice(TensorHandle* tensor, + TensorHandle** result) = 0; + + virtual Status CopyTensorFromDevice(TensorHandle* tensor, + const string& target_device_name, + TensorHandle** result) = 0; + + virtual Status Execute(const EagerOperation* op, TensorHandle** retvals, + int* num_retvals) = 0; +}; + +// Custom devices do many of the same things as physical Devices, but have a +// much more restricted interface. We pass around ambiguous pointers since +// operations may be placed either on custom or physical devices. +using VariantDevice = absl::variant<Device*, CustomDevice*>; + +// A tensor handle produced by a custom device. Generally they can only be +// consumed by executing an operation on the same custom device that produced it +// originally, or by attempting to copy the handle off the custom device. +// +// TODO(allenl): Currently custom devices are tied to the eager C API. They +// should be renamed op handlers and subclass AbstractTensorHandle instead so +// they are eager/graph agnostic.
+class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle { + public: + CustomDeviceTensorHandle(ImmediateExecutionContext* context, + CustomDevice* device, tensorflow::DataType dtype) + : ImmediateExecutionTensorHandle(kCustomDevice), + context_(context), + device_(device), + dtype_(dtype) {} + + tensorflow::DataType DataType() const override { return dtype_; } + Status Shape(PartialTensorShape* shape) const override; + Status NumElements(int64* num_elements) const override; + + const char* DeviceName(Status* status) const override { + return device_->name().c_str(); + } + const char* BackingDeviceName(Status* status) const override { + return device_->name().c_str(); + } + CustomDevice* device() const { return device_; } + const char* DeviceType(Status* status) const override; + int DeviceId(Status* status) const override; + + AbstractTensorInterface* Resolve(Status* status) override; + + ImmediateExecutionTensorHandle* Copy() override { + Ref(); + return this; + } + void Release() override { Unref(); } + + // For LLVM style RTTI. + static bool classof(const AbstractTensorHandle* ptr) { + return ptr->getKind() == kCustomDevice; + } + + protected: + const DeviceNameUtils::ParsedName* ParsedName(Status* status) const; + + ImmediateExecutionContext* const context_; + CustomDevice* const device_; + const tensorflow::DataType dtype_; + + mutable absl::optional<DeviceNameUtils::ParsedName> parsed_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/eager/custom_device_test.cc b/tensorflow/core/common_runtime/eager/custom_device_test.cc new file mode 100644 index 00000000000..32dbf2ce9b6 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/custom_device_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/eager/custom_device.h" + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class TestCustomDevice : public CustomDevice { + public: + explicit TestCustomDevice(std::string name) : name_(name) {} + const std::string& name() override { return name_; } + Status CopyTensorToDevice(TensorHandle* tensor, + TensorHandle** result) override { + tensor->Ref(); + *result = tensor; + return Status::OK(); + } + Status CopyTensorFromDevice(TensorHandle* tensor, + const std::string& target_device_name, + TensorHandle** result) override { + tensor->Ref(); + *result = tensor; + return Status::OK(); + } + Status Execute(const EagerOperation* op, TensorHandle** retvals, + int* num_retvals) override { + return errors::Unimplemented("Not implemented"); + } + + private: + std::string name_; +}; + +class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle { + public: + TestCustomDeviceTensorHandle(ImmediateExecutionContext* context, + TestCustomDevice* device, + tensorflow::DataType dtype) + : CustomDeviceTensorHandle(context, device, dtype) {} + + Status NumDims(int* num_dims) const override { + *num_dims = 1; + return Status::OK(); + } + Status Dim(int dim_index, int64* dim) const override { 
+ if (dim_index == 0) { + *dim = 3; + return Status::OK(); + } else { + return errors::Internal("Dim out of bounds"); + } + } +}; + +TEST(CustomDevice, TestTensorHandle) { + StaticDeviceMgr device_mgr(DeviceFactory::NewDevice( + "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0")); + core::RefCountPtr<EagerContext> ctx(new EagerContext( + SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false, + false, &device_mgr, false, nullptr, nullptr)); + std::string device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:15"; + TestCustomDevice device(device_name); + core::RefCountPtr<TestCustomDeviceTensorHandle> tensor( + new TestCustomDeviceTensorHandle(ctx.get(), &device, DT_FLOAT)); + Status s; + std::string device_type = tensor->DeviceType(&s); + ASSERT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ("CUSTOM", device_type); + int device_index = tensor->DeviceId(&s); + ASSERT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(15, device_index); + int64 num_elements = 0; + s = tensor->NumElements(&num_elements); + ASSERT_TRUE(s.ok()) << s.error_message(); + EXPECT_EQ(3, num_elements); +} + +} // namespace +} // namespace tensorflow