Add a custom device tensor handle type
This will replace TensorHandles holding custom device tensors, which will clean up the runtime a bit, provide flexibility for custom devices (especially for shape handling and copies), and morph custom devices a bit more toward op handlers. I'll do the migration itself in a follow-up. For now I'm leaving the naming as "custom device" for consistency. Once custom devices are migrated to this type it will be easier to bump the abstraction level a bit (AbstractTensorHandles instead of ImmediateExecutionTensorHandles) and rename things. PiperOrigin-RevId: 346871040 Change-Id: Ic1d4791b27d37b1078bd985be344c54748ae49b8
This commit is contained in:
parent
bd9ffecdb6
commit
0796a95989
@ -27,7 +27,7 @@ namespace tensorflow {
|
|||||||
// execution mode.
|
// execution mode.
|
||||||
class AbstractTensorHandle : public core::RefCounted {
|
class AbstractTensorHandle : public core::RefCounted {
|
||||||
protected:
|
protected:
|
||||||
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
|
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice };
|
||||||
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
|
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
|
||||||
virtual ~AbstractTensorHandle() {}
|
virtual ~AbstractTensorHandle() {}
|
||||||
|
|
||||||
|
@ -77,6 +77,7 @@ tf_cuda_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":eager_executor",
|
":eager_executor",
|
||||||
":kernel_and_device",
|
":kernel_and_device",
|
||||||
|
":custom_device",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"//tensorflow/c:tf_tensor_internal",
|
"//tensorflow/c:tf_tensor_internal",
|
||||||
"//tensorflow/c/eager:immediate_execution_context",
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
@ -110,6 +111,39 @@ tf_cuda_library(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cuda_library(
|
||||||
|
name = "custom_device",
|
||||||
|
srcs = ["custom_device.cc"],
|
||||||
|
hdrs = ["custom_device.h"],
|
||||||
|
visibility = ["//tensorflow:internal"],
|
||||||
|
deps = select({
|
||||||
|
"//tensorflow:android": [
|
||||||
|
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||||
|
],
|
||||||
|
"//conditions:default": [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/core/lib/core:status",
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "custom_device_test",
|
||||||
|
srcs = ["custom_device_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":context",
|
||||||
|
":core",
|
||||||
|
":custom_device",
|
||||||
|
":tensor_handle",
|
||||||
|
"//tensorflow/core:core_cpu_base",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "context_distributed_manager",
|
name = "context_distributed_manager",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -665,6 +699,7 @@ filegroup(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"attr_builder.h",
|
"attr_builder.h",
|
||||||
"context.h",
|
"context.h",
|
||||||
|
"custom_device.h",
|
||||||
"eager_executor.h",
|
"eager_executor.h",
|
||||||
"eager_operation.h",
|
"eager_operation.h",
|
||||||
"kernel_and_device.h",
|
"kernel_and_device.h",
|
||||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/custom_device.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
@ -81,26 +82,6 @@ class RemoteMgr;
|
|||||||
class TensorHandle;
|
class TensorHandle;
|
||||||
class EagerOperation;
|
class EagerOperation;
|
||||||
|
|
||||||
class CustomDevice {
|
|
||||||
public:
|
|
||||||
virtual ~CustomDevice() {}
|
|
||||||
virtual const string& name() = 0;
|
|
||||||
virtual Status CopyTensorToDevice(TensorHandle* tensor,
|
|
||||||
TensorHandle** result) = 0;
|
|
||||||
|
|
||||||
virtual Status CopyTensorFromDevice(TensorHandle* tensor,
|
|
||||||
const string& target_device_name,
|
|
||||||
TensorHandle** result) = 0;
|
|
||||||
|
|
||||||
virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
|
|
||||||
int* num_retvals) = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Custom devices do many of the same things as physical Devices, but have a
|
|
||||||
// much more restricted interface. We pass around ambiguous pointers since
|
|
||||||
// TensorHandles may be placed either on custom or physical devices.
|
|
||||||
using VariantDevice = absl::variant<Device*, CustomDevice*>;
|
|
||||||
|
|
||||||
class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||||
public:
|
public:
|
||||||
static constexpr uint64 kInvalidContextId = 0;
|
static constexpr uint64 kInvalidContextId = 0;
|
||||||
|
82
tensorflow/core/common_runtime/eager/custom_device.cc
Normal file
82
tensorflow/core/common_runtime/eager/custom_device.cc
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/eager/custom_device.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Builds `*shape` from this handle's per-dimension queries.
//
// Asks NumDims() for the rank, then Dim() for the size of every axis, and
// hands the collected sizes to PartialTensorShape::MakePartialShape. The first
// error returned by NumDims() or Dim() is propagated unchanged; otherwise the
// status of MakePartialShape is returned.
Status CustomDeviceTensorHandle::Shape(PartialTensorShape* shape) const {
  int rank;
  TF_RETURN_IF_ERROR(NumDims(&rank));

  // One slot per axis, filled in axis order.
  std::vector<int64> dim_sizes(rank);
  int axis = 0;
  for (int64& dim_size : dim_sizes) {
    TF_RETURN_IF_ERROR(Dim(axis++, &dim_size));
  }
  return PartialTensorShape::MakePartialShape(dim_sizes.data(), rank, shape);
}
|
||||||
|
|
||||||
|
// Computes the total number of elements as the product of all dimension sizes.
//
// `*num_elements` is set to 1 up front and multiplied in place, so when an
// error is returned it may hold a partial product and should be ignored.
//
// Returns:
//   - the first error reported by NumDims() or Dim(), if any;
//   - InvalidArgument if a dimension size is negative. Negative sizes are the
//     usual marker for a dimension that is not statically known, and folding
//     one into the product would yield a meaningless (possibly negative) count
//     reported with an OK status;
//   - OK otherwise.
Status CustomDeviceTensorHandle::NumElements(int64* num_elements) const {
  *num_elements = 1;
  int num_dims;
  TF_RETURN_IF_ERROR(NumDims(&num_dims));
  for (int i = 0; i < num_dims; ++i) {
    int64 dim;
    TF_RETURN_IF_ERROR(Dim(i, &dim));
    if (dim < 0) {
      // Refuse to multiply an unknown size into the running product.
      return errors::InvalidArgument(absl::StrCat(
          "Tried to compute the number of elements of a custom device tensor "
          "whose dimension ",
          i, " has an unknown size (", dim, ")"));
    }
    *num_elements *= dim;
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Returns the `type` component of the custom device's parsed name.
//
// The name is obtained through ParsedName(); if `*status` is not OK after that
// call, an empty string is returned instead. The returned pointer refers to
// storage inside the cached parsed name held by this handle.
const char* CustomDeviceTensorHandle::DeviceType(Status* status) const {
  const DeviceNameUtils::ParsedName* name_parts = ParsedName(status);
  return status->ok() ? name_parts->type.c_str() : "";
}
|
||||||
|
|
||||||
|
// Returns the `id` component of the custom device's parsed name.
//
// The name is obtained through ParsedName(); if `*status` is not OK after that
// call, 0 is returned instead.
int CustomDeviceTensorHandle::DeviceId(Status* status) const {
  const DeviceNameUtils::ParsedName* name_parts = ParsedName(status);
  return status->ok() ? name_parts->id : 0;
}
|
||||||
|
|
||||||
|
// Resolves this handle by first copying it off the custom device.
//
// Asks the owning context to copy the handle to the fixed device name
// "/job:localhost/replica:0/task:0/device:CPU:0", then forwards to Resolve()
// on the copy. The copy is held in a core::RefCountPtr for the duration of
// this call. Returns nullptr when the copy leaves `*status` not OK.
AbstractTensorInterface* CustomDeviceTensorHandle::Resolve(Status* status) {
  core::RefCountPtr<ImmediateExecutionTensorHandle> host_copy(
      context_->CopyTensorHandleToDevice(
          this, "/job:localhost/replica:0/task:0/device:CPU:0", status));
  return status->ok() ? host_copy->Resolve(status) : nullptr;
}
|
||||||
|
|
||||||
|
// Returns the parsed form of the custom device's name, parsing it on first use
// and caching the result in `parsed_name_`.
//
// If the name cannot be parsed by DeviceNameUtils::ParseFullOrLocalName,
// `*status` is set to InvalidArgument, nothing is cached, and nullptr is
// returned. `*status` is not written on success.
// NOTE(review): the cache is filled without any synchronization visible in
// this function -- confirm handles are not queried concurrently.
const DeviceNameUtils::ParsedName* CustomDeviceTensorHandle::ParsedName(
    Status* status) const {
  // Fast path: a previous call already parsed the name.
  if (parsed_name_.has_value()) {
    return &*parsed_name_;
  }

  DeviceNameUtils::ParsedName freshly_parsed;
  const bool parsed_ok =
      DeviceNameUtils::ParseFullOrLocalName(device_->name(), &freshly_parsed);
  if (!parsed_ok) {
    *status = errors::InvalidArgument(
        absl::StrCat("Invalid custom device name ", device_->name()));
    return nullptr;
  }

  parsed_name_.emplace(std::move(freshly_parsed));
  return &*parsed_name_;
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
107
tensorflow/core/common_runtime/eager/custom_device.h
Normal file
107
tensorflow/core/common_runtime/eager/custom_device.h
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
|
||||||
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||||
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class TensorHandle;
|
||||||
|
class EagerOperation;
|
||||||
|
|
||||||
|
// Custom devices intercept the execution of operations (the `Execute` method),
|
||||||
|
// typically implemented with one or more of the custom device's own executions.
|
||||||
|
class CustomDevice {
|
||||||
|
public:
|
||||||
|
virtual ~CustomDevice() {}
|
||||||
|
virtual const string& name() = 0;
|
||||||
|
virtual Status CopyTensorToDevice(TensorHandle* tensor,
|
||||||
|
TensorHandle** result) = 0;
|
||||||
|
|
||||||
|
virtual Status CopyTensorFromDevice(TensorHandle* tensor,
|
||||||
|
const string& target_device_name,
|
||||||
|
TensorHandle** result) = 0;
|
||||||
|
|
||||||
|
virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
|
||||||
|
int* num_retvals) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Custom devices do many of the same things as physical Devices, but have a
|
||||||
|
// much more restricted interface. We pass around ambiguous pointers since
|
||||||
|
// operations may be placed either on custom or physical devices.
|
||||||
|
using VariantDevice = absl::variant<Device*, CustomDevice*>;
|
||||||
|
|
||||||
|
// A tensor handle produced by a custom device. Generally they can only be
|
||||||
|
// consumed by executing an operation on the same custom device that produced it
|
||||||
|
// originally, or by attempting to copy the handle off the custom device.
|
||||||
|
//
|
||||||
|
// TODO(allenl): Currently custom devices are tied to the eager C API. They
|
||||||
|
// should be renamed op handlers and subclass AbstractTensorHandle instead so
|
||||||
|
// they are eager/graph agnostic.
|
||||||
|
class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle {
|
||||||
|
public:
|
||||||
|
CustomDeviceTensorHandle(ImmediateExecutionContext* context,
|
||||||
|
CustomDevice* device, tensorflow::DataType dtype)
|
||||||
|
: ImmediateExecutionTensorHandle(kCustomDevice),
|
||||||
|
context_(context),
|
||||||
|
device_(device),
|
||||||
|
dtype_(dtype) {}
|
||||||
|
|
||||||
|
tensorflow::DataType DataType() const override { return dtype_; }
|
||||||
|
Status Shape(PartialTensorShape* shape) const override;
|
||||||
|
Status NumElements(int64* num_elements) const override;
|
||||||
|
|
||||||
|
const char* DeviceName(Status* status) const override {
|
||||||
|
return device_->name().c_str();
|
||||||
|
}
|
||||||
|
const char* BackingDeviceName(Status* status) const override {
|
||||||
|
return device_->name().c_str();
|
||||||
|
}
|
||||||
|
CustomDevice* device() const { return device_; }
|
||||||
|
const char* DeviceType(Status* status) const override;
|
||||||
|
int DeviceId(Status* status) const override;
|
||||||
|
|
||||||
|
AbstractTensorInterface* Resolve(Status* status) override;
|
||||||
|
|
||||||
|
ImmediateExecutionTensorHandle* Copy() override {
|
||||||
|
Ref();
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
void Release() override { Unref(); }
|
||||||
|
|
||||||
|
// For LLVM style RTTI.
|
||||||
|
static bool classof(const AbstractTensorHandle* ptr) {
|
||||||
|
return ptr->getKind() == kCustomDevice;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const DeviceNameUtils::ParsedName* ParsedName(Status* status) const;
|
||||||
|
|
||||||
|
ImmediateExecutionContext* const context_;
|
||||||
|
CustomDevice* const device_;
|
||||||
|
const tensorflow::DataType dtype_;
|
||||||
|
|
||||||
|
mutable absl::optional<DeviceNameUtils::ParsedName> parsed_name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_
|
99
tensorflow/core/common_runtime/eager/custom_device_test.cc
Normal file
99
tensorflow/core/common_runtime/eager/custom_device_test.cc
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/eager/custom_device.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
|
#include "tensorflow/core/framework/device_factory.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class TestCustomDevice : public CustomDevice {
|
||||||
|
public:
|
||||||
|
explicit TestCustomDevice(std::string name) : name_(name) {}
|
||||||
|
const std::string& name() override { return name_; }
|
||||||
|
Status CopyTensorToDevice(TensorHandle* tensor,
|
||||||
|
TensorHandle** result) override {
|
||||||
|
tensor->Ref();
|
||||||
|
*result = tensor;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status CopyTensorFromDevice(TensorHandle* tensor,
|
||||||
|
const std::string& target_device_name,
|
||||||
|
TensorHandle** result) override {
|
||||||
|
tensor->Ref();
|
||||||
|
*result = tensor;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status Execute(const EagerOperation* op, TensorHandle** retvals,
|
||||||
|
int* num_retvals) override {
|
||||||
|
return errors::Unimplemented("Not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle {
|
||||||
|
public:
|
||||||
|
TestCustomDeviceTensorHandle(ImmediateExecutionContext* context,
|
||||||
|
TestCustomDevice* device,
|
||||||
|
tensorflow::DataType dtype)
|
||||||
|
: CustomDeviceTensorHandle(context, device, dtype) {}
|
||||||
|
|
||||||
|
Status NumDims(int* num_dims) const override {
|
||||||
|
*num_dims = 1;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status Dim(int dim_index, int64* dim) const override {
|
||||||
|
if (dim_index == 0) {
|
||||||
|
*dim = 3;
|
||||||
|
return Status::OK();
|
||||||
|
} else {
|
||||||
|
return errors::Internal("Dim out of bounds");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Checks that a CustomDeviceTensorHandle derives its device type and id from
// the custom device's name, and its element count from NumDims()/Dim().
TEST(CustomDevice, TestTensorHandle) {
  // An eager context backed by a single CPU device.
  StaticDeviceMgr cpu_device_mgr(DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  core::RefCountPtr<EagerContext> eager_context(new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      false, &cpu_device_mgr, false, nullptr, nullptr));

  // A handle living on a custom device named ".../device:CUSTOM:15".
  std::string custom_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:15";
  TestCustomDevice custom_device(custom_device_name);
  core::RefCountPtr<TestCustomDeviceTensorHandle> handle(
      new TestCustomDeviceTensorHandle(eager_context.get(), &custom_device,
                                       DT_FLOAT));

  Status status;

  // Device type comes from the parsed device name.
  std::string reported_type = handle->DeviceType(&status);
  ASSERT_TRUE(status.ok()) << status.error_message();
  EXPECT_EQ("CUSTOM", reported_type);

  // So does the device id.
  int reported_id = handle->DeviceId(&status);
  ASSERT_TRUE(status.ok()) << status.error_message();
  EXPECT_EQ(15, reported_id);

  // The test handle is rank 1 with a dimension of size 3.
  int64 element_count = 0;
  status = handle->NumElements(&element_count);
  ASSERT_TRUE(status.ok()) << status.error_message();
  EXPECT_EQ(3, element_count);
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user