Add a custom device tensor handle type
This will replace TensorHandles holding custom device tensors, which will clean up the runtime a bit, provide flexibility for custom devices (especially for shape handling and copies), and morph custom devices a bit more toward op handlers. I'll do the migration itself in a followup. For now I'm leaving the naming as "custom device" for consistency. Once custom devices are migrated to this type it will be easier to bump the abstraction level a bit (AbstractTensorHandles instead of ImmediateExecutionTensorHandles) and re-name things. PiperOrigin-RevId: 346871040 Change-Id: Ic1d4791b27d37b1078bd985be344c54748ae49b8
This commit is contained in:
parent
bd9ffecdb6
commit
0796a95989
@ -27,7 +27,7 @@ namespace tensorflow {
|
||||
// execution mode.
|
||||
class AbstractTensorHandle : public core::RefCounted {
|
||||
protected:
|
||||
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
|
||||
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice };
|
||||
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractTensorHandle() {}
|
||||
|
||||
|
@ -77,6 +77,7 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":eager_executor",
|
||||
":kernel_and_device",
|
||||
":custom_device",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
@ -110,6 +111,39 @@ tf_cuda_library(
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "custom_device",
|
||||
srcs = ["custom_device.cc"],
|
||||
hdrs = ["custom_device.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core/lib/core:status",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "custom_device_test",
|
||||
srcs = ["custom_device_test.cc"],
|
||||
deps = [
|
||||
":context",
|
||||
":core",
|
||||
":custom_device",
|
||||
":tensor_handle",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "context_distributed_manager",
|
||||
srcs = [
|
||||
@ -665,6 +699,7 @@ filegroup(
|
||||
srcs = [
|
||||
"attr_builder.h",
|
||||
"context.h",
|
||||
"custom_device.h",
|
||||
"eager_executor.h",
|
||||
"eager_operation.h",
|
||||
"kernel_and_device.h",
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/custom_device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -81,26 +82,6 @@ class RemoteMgr;
|
||||
class TensorHandle;
|
||||
class EagerOperation;
|
||||
|
||||
class CustomDevice {
|
||||
public:
|
||||
virtual ~CustomDevice() {}
|
||||
virtual const string& name() = 0;
|
||||
virtual Status CopyTensorToDevice(TensorHandle* tensor,
|
||||
TensorHandle** result) = 0;
|
||||
|
||||
virtual Status CopyTensorFromDevice(TensorHandle* tensor,
|
||||
const string& target_device_name,
|
||||
TensorHandle** result) = 0;
|
||||
|
||||
virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
|
||||
int* num_retvals) = 0;
|
||||
};
|
||||
|
||||
// Custom devices do many of the same things as physical Devices, but have a
|
||||
// much more restricted interface. We pass around ambiguous pointers since
|
||||
// TensorHandles may be placed either on custom or physical devices.
|
||||
using VariantDevice = absl::variant<Device*, CustomDevice*>;
|
||||
|
||||
class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
public:
|
||||
static constexpr uint64 kInvalidContextId = 0;
|
||||
|
82
tensorflow/core/common_runtime/eager/custom_device.cc
Normal file
82
tensorflow/core/common_runtime/eager/custom_device.cc
Normal file
@ -0,0 +1,82 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/custom_device.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status CustomDeviceTensorHandle::Shape(PartialTensorShape* shape) const {
|
||||
int num_dims;
|
||||
TF_RETURN_IF_ERROR(NumDims(&num_dims));
|
||||
std::vector<int64> dims(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
TF_RETURN_IF_ERROR(Dim(i, &dims[i]));
|
||||
}
|
||||
return PartialTensorShape::MakePartialShape(dims.data(), num_dims, shape);
|
||||
}
|
||||
|
||||
Status CustomDeviceTensorHandle::NumElements(int64* num_elements) const {
|
||||
*num_elements = 1;
|
||||
int num_dims;
|
||||
TF_RETURN_IF_ERROR(NumDims(&num_dims));
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
int64 dim;
|
||||
TF_RETURN_IF_ERROR(Dim(i, &dim));
|
||||
*num_elements *= dim;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const char* CustomDeviceTensorHandle::DeviceType(Status* status) const {
|
||||
const DeviceNameUtils::ParsedName* parsed = ParsedName(status);
|
||||
if (!status->ok()) {
|
||||
return "";
|
||||
}
|
||||
return parsed->type.c_str();
|
||||
}
|
||||
|
||||
int CustomDeviceTensorHandle::DeviceId(Status* status) const {
|
||||
const DeviceNameUtils::ParsedName* parsed = ParsedName(status);
|
||||
if (!status->ok()) {
|
||||
return 0;
|
||||
}
|
||||
return parsed->id;
|
||||
}
|
||||
|
||||
AbstractTensorInterface* CustomDeviceTensorHandle::Resolve(Status* status) {
|
||||
core::RefCountPtr<ImmediateExecutionTensorHandle> copied_off(
|
||||
context_->CopyTensorHandleToDevice(
|
||||
this, "/job:localhost/replica:0/task:0/device:CPU:0", status));
|
||||
if (!status->ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return copied_off->Resolve(status);
|
||||
}
|
||||
|
||||
const DeviceNameUtils::ParsedName* CustomDeviceTensorHandle::ParsedName(
|
||||
Status* status) const {
|
||||
if (!parsed_name_.has_value()) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!DeviceNameUtils::ParseFullOrLocalName(device_->name(), &parsed_name)) {
|
||||
*status = errors::InvalidArgument(
|
||||
absl::StrCat("Invalid custom device name ", device_->name()));
|
||||
return nullptr;
|
||||
}
|
||||
parsed_name_.emplace(std::move(parsed_name));
|
||||
}
|
||||
return &*parsed_name_;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
107
tensorflow/core/common_runtime/eager/custom_device.h
Normal file
107
tensorflow/core/common_runtime/eager/custom_device.h
Normal file
@ -0,0 +1,107 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorHandle;
|
||||
class EagerOperation;
|
||||
|
||||
// Custom devices intercept the execution of operations (the `Execute` method),
|
||||
// typically implemented with one or more of the custom device's own executions.
|
||||
class CustomDevice {
|
||||
public:
|
||||
virtual ~CustomDevice() {}
|
||||
virtual const string& name() = 0;
|
||||
virtual Status CopyTensorToDevice(TensorHandle* tensor,
|
||||
TensorHandle** result) = 0;
|
||||
|
||||
virtual Status CopyTensorFromDevice(TensorHandle* tensor,
|
||||
const string& target_device_name,
|
||||
TensorHandle** result) = 0;
|
||||
|
||||
virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
|
||||
int* num_retvals) = 0;
|
||||
};
|
||||
|
||||
// Custom devices do many of the same things as physical Devices, but have a
|
||||
// much more restricted interface. We pass around ambiguous pointers since
|
||||
// operations may be placed either on custom or physical devices.
|
||||
using VariantDevice = absl::variant<Device*, CustomDevice*>;
|
||||
|
||||
// A tensor handle produced by a custom device. Generally they can only be
|
||||
// consumed by executing an operation on the same custom device that produced it
|
||||
// originally, or by attempting to copy the handle off the custom device.
|
||||
//
|
||||
// TODO(allenl): Currently custom devices are tied to the eager C API. They
|
||||
// should be renamed op handlers and subclass AbstractTensorHandle instead so
|
||||
// they are eager/graph agnostic.
|
||||
class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle {
|
||||
public:
|
||||
CustomDeviceTensorHandle(ImmediateExecutionContext* context,
|
||||
CustomDevice* device, tensorflow::DataType dtype)
|
||||
: ImmediateExecutionTensorHandle(kCustomDevice),
|
||||
context_(context),
|
||||
device_(device),
|
||||
dtype_(dtype) {}
|
||||
|
||||
tensorflow::DataType DataType() const override { return dtype_; }
|
||||
Status Shape(PartialTensorShape* shape) const override;
|
||||
Status NumElements(int64* num_elements) const override;
|
||||
|
||||
const char* DeviceName(Status* status) const override {
|
||||
return device_->name().c_str();
|
||||
}
|
||||
const char* BackingDeviceName(Status* status) const override {
|
||||
return device_->name().c_str();
|
||||
}
|
||||
CustomDevice* device() const { return device_; }
|
||||
const char* DeviceType(Status* status) const override;
|
||||
int DeviceId(Status* status) const override;
|
||||
|
||||
AbstractTensorInterface* Resolve(Status* status) override;
|
||||
|
||||
ImmediateExecutionTensorHandle* Copy() override {
|
||||
Ref();
|
||||
return this;
|
||||
}
|
||||
void Release() override { Unref(); }
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractTensorHandle* ptr) {
|
||||
return ptr->getKind() == kCustomDevice;
|
||||
}
|
||||
|
||||
protected:
|
||||
const DeviceNameUtils::ParsedName* ParsedName(Status* status) const;
|
||||
|
||||
ImmediateExecutionContext* const context_;
|
||||
CustomDevice* const device_;
|
||||
const tensorflow::DataType dtype_;
|
||||
|
||||
mutable absl::optional<DeviceNameUtils::ParsedName> parsed_name_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
|
99
tensorflow/core/common_runtime/eager/custom_device_test.cc
Normal file
99
tensorflow/core/common_runtime/eager/custom_device_test.cc
Normal file
@ -0,0 +1,99 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/custom_device.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/device_factory.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class TestCustomDevice : public CustomDevice {
|
||||
public:
|
||||
explicit TestCustomDevice(std::string name) : name_(name) {}
|
||||
const std::string& name() override { return name_; }
|
||||
Status CopyTensorToDevice(TensorHandle* tensor,
|
||||
TensorHandle** result) override {
|
||||
tensor->Ref();
|
||||
*result = tensor;
|
||||
return Status::OK();
|
||||
}
|
||||
Status CopyTensorFromDevice(TensorHandle* tensor,
|
||||
const std::string& target_device_name,
|
||||
TensorHandle** result) override {
|
||||
tensor->Ref();
|
||||
*result = tensor;
|
||||
return Status::OK();
|
||||
}
|
||||
Status Execute(const EagerOperation* op, TensorHandle** retvals,
|
||||
int* num_retvals) override {
|
||||
return errors::Unimplemented("Not implemented");
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle {
|
||||
public:
|
||||
TestCustomDeviceTensorHandle(ImmediateExecutionContext* context,
|
||||
TestCustomDevice* device,
|
||||
tensorflow::DataType dtype)
|
||||
: CustomDeviceTensorHandle(context, device, dtype) {}
|
||||
|
||||
Status NumDims(int* num_dims) const override {
|
||||
*num_dims = 1;
|
||||
return Status::OK();
|
||||
}
|
||||
Status Dim(int dim_index, int64* dim) const override {
|
||||
if (dim_index == 0) {
|
||||
*dim = 3;
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::Internal("Dim out of bounds");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST(CustomDevice, TestTensorHandle) {
|
||||
StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
|
||||
"CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
|
||||
core::RefCountPtr<EagerContext> ctx(new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
|
||||
false, &device_mgr, false, nullptr, nullptr));
|
||||
std::string device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:15";
|
||||
TestCustomDevice device(device_name);
|
||||
core::RefCountPtr<TestCustomDeviceTensorHandle> tensor(
|
||||
new TestCustomDeviceTensorHandle(ctx.get(), &device, DT_FLOAT));
|
||||
Status s;
|
||||
std::string device_type = tensor->DeviceType(&s);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
EXPECT_EQ("CUSTOM", device_type);
|
||||
int device_index = tensor->DeviceId(&s);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
EXPECT_EQ(15, device_index);
|
||||
int64 num_elements = 0;
|
||||
s = tensor->NumElements(&num_elements);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
EXPECT_EQ(3, num_elements);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user