Add a custom device tensor handle type

This will replace TensorHandles holding custom device tensors, which will clean up the runtime a bit, provide flexibility for custom devices (especially for shape handling and copies), and morph custom devices a bit more toward op handlers.

I'll do the migration itself in a followup.

For now I'm leaving the naming as "custom device" for consistency. Once custom devices are migrated to this type it will be easier to bump the abstraction level a bit (AbstractTensorHandles instead of ImmediateExecutionTensorHandles) and rename things.

PiperOrigin-RevId: 346871040
Change-Id: Ic1d4791b27d37b1078bd985be344c54748ae49b8
This commit is contained in:
Allen Lavoie 2020-12-10 14:53:59 -08:00 committed by TensorFlower Gardener
parent bd9ffecdb6
commit 0796a95989
6 changed files with 325 additions and 21 deletions

View File

@ -27,7 +27,7 @@ namespace tensorflow {
// execution mode. // execution mode.
class AbstractTensorHandle : public core::RefCounted { class AbstractTensorHandle : public core::RefCounted {
protected: protected:
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt }; enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
virtual ~AbstractTensorHandle() {} virtual ~AbstractTensorHandle() {}

View File

@ -77,6 +77,7 @@ tf_cuda_library(
deps = [ deps = [
":eager_executor", ":eager_executor",
":kernel_and_device", ":kernel_and_device",
":custom_device",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
@ -110,6 +111,39 @@ tf_cuda_library(
}), }),
) )
tf_cuda_library(
name = "custom_device",
srcs = ["custom_device.cc"],
hdrs = ["custom_device.h"],
visibility = ["//tensorflow:internal"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core/lib/core:status",
],
}),
)
tf_cc_test(
name = "custom_device_test",
srcs = ["custom_device_test.cc"],
deps = [
":context",
":core",
":custom_device",
":tensor_handle",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_library( tf_cuda_library(
name = "context_distributed_manager", name = "context_distributed_manager",
srcs = [ srcs = [
@ -665,6 +699,7 @@ filegroup(
srcs = [ srcs = [
"attr_builder.h", "attr_builder.h",
"context.h", "context.h",
"custom_device.h",
"eager_executor.h", "eager_executor.h",
"eager_operation.h", "eager_operation.h",
"kernel_and_device.h", "kernel_and_device.h",

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
@ -81,26 +82,6 @@ class RemoteMgr;
class TensorHandle; class TensorHandle;
class EagerOperation; class EagerOperation;
class CustomDevice {
public:
virtual ~CustomDevice() {}
virtual const string& name() = 0;
virtual Status CopyTensorToDevice(TensorHandle* tensor,
TensorHandle** result) = 0;
virtual Status CopyTensorFromDevice(TensorHandle* tensor,
const string& target_device_name,
TensorHandle** result) = 0;
virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
int* num_retvals) = 0;
};
// Custom devices do many of the same things as physical Devices, but have a
// much more restricted interface. We pass around ambiguous pointers since
// TensorHandles may be placed either on custom or physical devices.
using VariantDevice = absl::variant<Device*, CustomDevice*>;
class EagerContext : public ImmediateExecutionContext, public core::RefCounted { class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
public: public:
static constexpr uint64 kInvalidContextId = 0; static constexpr uint64 kInvalidContextId = 0;

View File

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/custom_device.h"
namespace tensorflow {
// Assembles a PartialTensorShape one dimension at a time from the
// subclass-provided NumDims()/Dim() accessors, propagating the first error
// either accessor reports.
Status CustomDeviceTensorHandle::Shape(PartialTensorShape* shape) const {
  int num_dims;
  TF_RETURN_IF_ERROR(NumDims(&num_dims));
  std::vector<int64> dims(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    TF_RETURN_IF_ERROR(Dim(i, &dims[i]));
  }
  return PartialTensorShape::MakePartialShape(dims.data(), num_dims, shape);
}
// Computes the total element count as the product of all dimension extents
// reported by the subclass's Dim() accessor.
//
// NOTE(review): partial shapes conventionally report unknown dimensions as
// -1, which would make this product meaningless, and very large shapes could
// overflow int64 — confirm callers only invoke this on fully-defined,
// reasonably-sized shapes.
Status CustomDeviceTensorHandle::NumElements(int64* num_elements) const {
  *num_elements = 1;
  int num_dims;
  TF_RETURN_IF_ERROR(NumDims(&num_dims));
  for (int i = 0; i < num_dims; ++i) {
    int64 dim;
    TF_RETURN_IF_ERROR(Dim(i, &dim));
    *num_elements *= dim;
  }
  return Status::OK();
}
// Returns the type component of the parsed device name (e.g. "CUSTOM"), or
// the empty string if parsing fails (in which case `status` carries the
// error).
const char* CustomDeviceTensorHandle::DeviceType(Status* status) const {
  const DeviceNameUtils::ParsedName* parsed = ParsedName(status);
  return status->ok() ? parsed->type.c_str() : "";
}
// Returns the id component of the parsed device name, or 0 if parsing fails
// (in which case `status` carries the error).
int CustomDeviceTensorHandle::DeviceId(Status* status) const {
  const DeviceNameUtils::ParsedName* parsed = ParsedName(status);
  return status->ok() ? parsed->id : 0;
}
// Materializes the handle's contents by copying it off the custom device and
// resolving the copy. The intermediate handle's reference is released when
// `copied_off` goes out of scope.
//
// NOTE(review): the destination is hard-coded to the local task's CPU:0 —
// confirm this is intended for multi-task / remote configurations.
AbstractTensorInterface* CustomDeviceTensorHandle::Resolve(Status* status) {
  core::RefCountPtr<ImmediateExecutionTensorHandle> copied_off(
      context_->CopyTensorHandleToDevice(
          this, "/job:localhost/replica:0/task:0/device:CPU:0", status));
  if (!status->ok()) {
    return nullptr;
  }
  return copied_off->Resolve(status);
}
// Returns the device's name parsed into components, parsing lazily on first
// use and caching the result in `parsed_name_`. On a parse failure, sets
// `status` and returns nullptr (nothing is cached, so a later call retries).
//
// NOTE(review): the cache is written through a `mutable` member from a const
// method with no synchronization — confirm handles are not first queried
// concurrently from multiple threads.
const DeviceNameUtils::ParsedName* CustomDeviceTensorHandle::ParsedName(
    Status* status) const {
  if (!parsed_name_.has_value()) {
    DeviceNameUtils::ParsedName parsed_name;
    if (!DeviceNameUtils::ParseFullOrLocalName(device_->name(), &parsed_name)) {
      *status = errors::InvalidArgument(
          absl::StrCat("Invalid custom device name ", device_->name()));
      return nullptr;
    }
    parsed_name_.emplace(std::move(parsed_name));
  }
  return &*parsed_name_;
}
} // namespace tensorflow

View File

@ -0,0 +1,107 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
class TensorHandle;
class EagerOperation;
// Custom devices intercept the execution of operations (the `Execute`
// method), typically implementing each intercepted op with one or more
// executions of the custom device's own.
class CustomDevice {
 public:
  virtual ~CustomDevice() {}

  // The device's full name, e.g.
  // "/job:localhost/replica:0/task:0/device:CUSTOM:0".
  virtual const string& name() = 0;

  // Copies `tensor` onto this custom device, returning a new handle through
  // `result`.
  virtual Status CopyTensorToDevice(TensorHandle* tensor,
                                    TensorHandle** result) = 0;

  // Copies `tensor` off this custom device to `target_device_name`,
  // returning a new handle through `result`.
  virtual Status CopyTensorFromDevice(TensorHandle* tensor,
                                      const string& target_device_name,
                                      TensorHandle** result) = 0;

  // Executes `op`, writing output handles to `retvals` and the output count
  // to `*num_retvals` (presumably the caller-provided capacity on entry —
  // confirm against callers).
  virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
                         int* num_retvals) = 0;
};
// Custom devices do many of the same things as physical Devices, but have a
// much more restricted interface. We pass around ambiguous pointers since
// operations may be placed either on custom or physical devices.
// (Both alternatives are raw pointers; presumably non-owning, with the
// owner tracked elsewhere — confirm against the context's device registry.)
using VariantDevice = absl::variant<Device*, CustomDevice*>;
// A tensor handle produced by a custom device. Generally they can only be
// consumed by executing an operation on the same custom device that produced it
// originally, or by attempting to copy the handle off the custom device.
//
// Relies on the inherited NumDims()/Dim() accessors (overridden by
// subclasses) to derive Shape() and NumElements().
//
// TODO(allenl): Currently custom devices are tied to the eager C API. They
// should be renamed op handlers and subclass AbstractTensorHandle instead so
// they are eager/graph agnostic.
class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle {
 public:
  CustomDeviceTensorHandle(ImmediateExecutionContext* context,
                           CustomDevice* device, tensorflow::DataType dtype)
      : ImmediateExecutionTensorHandle(kCustomDevice),
        context_(context),
        device_(device),
        dtype_(dtype) {}

  tensorflow::DataType DataType() const override { return dtype_; }
  // Built dimension-by-dimension from NumDims()/Dim().
  Status Shape(PartialTensorShape* shape) const override;
  // Product of all dimension extents reported by Dim().
  Status NumElements(int64* num_elements) const override;

  // For a custom device handle, the logical and backing device names are the
  // same: the custom device's own name.
  const char* DeviceName(Status* status) const override {
    return device_->name().c_str();
  }
  const char* BackingDeviceName(Status* status) const override {
    return device_->name().c_str();
  }
  CustomDevice* device() const { return device_; }
  // Type and id components of the parsed device name; on a parse failure
  // these set `status` and return "" / 0 respectively.
  const char* DeviceType(Status* status) const override;
  int DeviceId(Status* status) const override;

  // Copies the tensor off the custom device (to the local CPU) and resolves
  // the copy.
  AbstractTensorInterface* Resolve(Status* status) override;

  // "Copying" a handle just shares a reference to this same object.
  ImmediateExecutionTensorHandle* Copy() override {
    Ref();
    return this;
  }
  void Release() override { Unref(); }

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kCustomDevice;
  }

 protected:
  // Parses (and lazily caches) the device's name; sets `status` and returns
  // nullptr on failure.
  const DeviceNameUtils::ParsedName* ParsedName(Status* status) const;

  ImmediateExecutionContext* const context_;  // not owned
  CustomDevice* const device_;                // not owned
  const tensorflow::DataType dtype_;
  // Lazily initialized by ParsedName().
  mutable absl::optional<DeviceNameUtils::ParsedName> parsed_name_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_

View File

@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// CustomDevice test double: copies in either direction simply alias the
// input handle (adding a reference), and Execute is unsupported.
class TestCustomDevice : public CustomDevice {
 public:
  explicit TestCustomDevice(std::string name) : device_name_(std::move(name)) {}

  const std::string& name() override { return device_name_; }

  Status CopyTensorToDevice(TensorHandle* tensor,
                            TensorHandle** result) override {
    *result = tensor;
    (*result)->Ref();
    return Status::OK();
  }

  Status CopyTensorFromDevice(TensorHandle* tensor,
                              const std::string& target_device_name,
                              TensorHandle** result) override {
    *result = tensor;
    (*result)->Ref();
    return Status::OK();
  }

  Status Execute(const EagerOperation* op, TensorHandle** retvals,
                 int* num_retvals) override {
    return errors::Unimplemented("Not implemented");
  }

 private:
  std::string device_name_;
};
// A fixed-shape handle for tests: always rank 1 with a single dimension of
// extent 3.
class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle {
 public:
  TestCustomDeviceTensorHandle(ImmediateExecutionContext* context,
                               TestCustomDevice* device,
                               tensorflow::DataType dtype)
      : CustomDeviceTensorHandle(context, device, dtype) {}

  Status NumDims(int* num_dims) const override {
    *num_dims = 1;
    return Status::OK();
  }

  Status Dim(int dim_index, int64* dim) const override {
    // Only index 0 is valid for a rank-1 shape.
    if (dim_index != 0) {
      return errors::Internal("Dim out of bounds");
    }
    *dim = 3;
    return Status::OK();
  }
};
// Exercises the derived accessors of CustomDeviceTensorHandle: device type,
// device id, and element count.
TEST(CustomDevice, TestTensorHandle) {
  StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  core::RefCountPtr<EagerContext> context(new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      false, &device_mgr, false, nullptr, nullptr));
  const std::string device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:15";
  TestCustomDevice custom_device(device_name);
  core::RefCountPtr<TestCustomDeviceTensorHandle> handle(
      new TestCustomDeviceTensorHandle(context.get(), &custom_device,
                                       DT_FLOAT));

  Status status;
  const std::string device_type = handle->DeviceType(&status);
  ASSERT_TRUE(status.ok()) << status.error_message();
  EXPECT_EQ("CUSTOM", device_type);

  const int device_index = handle->DeviceId(&status);
  ASSERT_TRUE(status.ok()) << status.error_message();
  EXPECT_EQ(15, device_index);

  int64 num_elements = 0;
  status = handle->NumElements(&num_elements);
  ASSERT_TRUE(status.ok()) << status.error_message();
  EXPECT_EQ(3, num_elements);
}
} // namespace
} // namespace tensorflow