Add a function in DeviceBase to deep copy on the device.

This change adds support for a CPU and GPU deep copy function, accessible from
the DeviceBase class.  Other devices are not supported for now.

This new functionality is intended for TensorFlow internal runtime use only.

PiperOrigin-RevId: 243263305
This commit is contained in:
Ayush Dubey 2019-04-12 07:53:06 -07:00 committed by TensorFlower Gardener
parent 8fd880bca0
commit d1db9860a2
12 changed files with 227 additions and 11 deletions

View File

@ -3900,6 +3900,7 @@ tf_cc_tests(
"common_runtime/pending_counts_test.cc", "common_runtime/pending_counts_test.cc",
"common_runtime/placer_test.cc", "common_runtime/placer_test.cc",
"common_runtime/session_test.cc", "common_runtime/session_test.cc",
"common_runtime/threadpool_device_test.cc",
"example/feature_util_test.cc", "example/feature_util_test.cc",
"framework/allocator_test.cc", "framework/allocator_test.cc",
"framework/attr_value_util_test.cc", "framework/attr_value_util_test.cc",

View File

@ -745,6 +745,14 @@ Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
} }
} }
void BaseGPUDevice::CopyTensorInSameDevice(const Tensor* input_tensor,
Tensor* output_tensor,
const DeviceContext* device_context,
StatusCallback done) {
GPUUtil::CopyGPUTensorToSameGPU(static_cast<Device*>(this), device_context,
input_tensor, output_tensor, std::move(done));
}
namespace { namespace {
class ConcretePerOpGpuDevice : public PerOpGpuDevice { class ConcretePerOpGpuDevice : public PerOpGpuDevice {
public: public:

View File

@ -90,6 +90,10 @@ class BaseGPUDevice : public LocalDevice {
const AllocatorAttributes alloc_attrs, const AllocatorAttributes alloc_attrs,
Tensor* tensor) override; Tensor* tensor) override;
void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
const DeviceContext* device_context,
StatusCallback done) override;
// The caller owns the returned device. // The caller owns the returned device.
PerOpGpuDevice* MakeGpuDevice() override; PerOpGpuDevice* MakeGpuDevice() override;

View File

@ -85,6 +85,36 @@ class GPUDeviceTest : public ::testing::Test {
} }
return options; return options;
} }
void InitCPUTensor(Tensor* cpu_tensor, int num_elements, float value) {
auto tensor = cpu_tensor->tensor<float, 1>();
for (int i = 0; i < num_elements; ++i) {
tensor(i) = value;
}
}
void CopyCPUToGPU(Tensor* cpu_tensor, Tensor* gpu_tensor, Device* device,
DeviceContext* device_context) {
Notification note;
device_context->CopyCPUTensorToDevice(cpu_tensor, device, gpu_tensor,
[&note](const Status& s) {
TF_ASSERT_OK(s);
note.Notify();
});
note.WaitForNotification();
}
void CopyGPUToCPU(Tensor* gpu_tensor, Tensor* cpu_tensor, Device* device,
DeviceContext* device_context) {
Notification note;
device_context->CopyDeviceTensorToCPU(gpu_tensor, /*tensor_name=*/"",
device, cpu_tensor,
[&note](const Status& s) {
TF_ASSERT_OK(s);
note.Notify();
});
note.WaitForNotification();
}
}; };
TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) { TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
@ -277,6 +307,45 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
allocator->DeallocateRaw(ptr); allocator->DeallocateRaw(ptr);
} }
TEST_F(GPUDeviceTest, CopyTensorInSameDevice) {
SessionOptions opts = MakeSessionOptions("0");
std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
Device* device = devices[0].get();
auto* device_info = device->tensorflow_gpu_device_info();
CHECK(device_info);
DeviceContext* device_context = device_info->default_context;
Allocator* allocator = device->GetAllocator(AllocatorAttributes());
constexpr int kNumElements = 4;
Tensor input_tensor(allocator, DT_FLOAT, TensorShape({kNumElements}));
Tensor output_tensor(allocator, DT_FLOAT, TensorShape({kNumElements}));
Tensor cpu_tensor(cpu_allocator(), DT_FLOAT, TensorShape({kNumElements}));
// Initialize input as {1, 1, 1, 1} and output as {0, 0, 0, 0}. After copy,
// both should become {1, 1, 1, 1}.
InitCPUTensor(&cpu_tensor, kNumElements, 0);
CopyCPUToGPU(&cpu_tensor, &output_tensor, device, device_context);
InitCPUTensor(&cpu_tensor, kNumElements, 1);
CopyCPUToGPU(&cpu_tensor, &input_tensor, device, device_context);
Notification note;
device->CopyTensorInSameDevice(&input_tensor, &output_tensor, device_context,
[&note](const Status& s) {
TF_ASSERT_OK(s);
note.Notify();
});
note.WaitForNotification();
Tensor output_cpu_tensor(cpu_allocator(), DT_FLOAT,
TensorShape({kNumElements}));
CopyGPUToCPU(&output_tensor, &output_cpu_tensor, device, device_context);
auto input = cpu_tensor.tensor<float, 1>();
auto output = output_cpu_tensor.tensor<float, 1>();
for (int i = 0; i < kNumElements; ++i) {
EXPECT_EQ(input(i), output(i)) << " for index " << i;
}
}
class GPUKernelTrackerTest : public ::testing::Test { class GPUKernelTrackerTest : public ::testing::Test {
protected: protected:
void SetUp() { void SetUp() {

View File

@ -94,6 +94,13 @@ class RenamedDevice : public Device {
return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor); return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
} }
void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
const DeviceContext* device_context,
StatusCallback done) override {
underlying_->CopyTensorInSameDevice(input_tensor, output_tensor,
device_context, std::move(done));
}
// Below are virtual methods defined on Device // Below are virtual methods defined on Device
void Compute(OpKernel* op_kernel, OpKernelContext* context) override { void Compute(OpKernel* op_kernel, OpKernelContext* context) override {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow { namespace tensorflow {
@ -69,6 +70,19 @@ class SingleThreadedCpuDevice : public Device {
return Status::OK(); return Status::OK();
} }
void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
const DeviceContext*,
StatusCallback done) override {
if (input_tensor->NumElements() != output_tensor->NumElements()) {
done(errors::Internal(
"SingleThreadedCPU->SingleThreadedCPU copy shape mismatch: input=",
input_tensor->shape(), ", output=", output_tensor->shape()));
return;
}
tensor::DeepCopy(*input_tensor, output_tensor);
done(Status::OK());
}
Allocator* GetAllocator(AllocatorAttributes attr) override { Allocator* GetAllocator(AllocatorAttributes attr) override {
return cpu_allocator(); return cpu_allocator();
} }

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb_text.h" #include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/types.h" #include "tensorflow/core/graph/types.h"
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/hash/hash.h"
@ -102,6 +103,19 @@ Status ThreadPoolDevice::MakeTensorFromProto(
ProtoDebugString(tensor_proto)); ProtoDebugString(tensor_proto));
} }
void ThreadPoolDevice::CopyTensorInSameDevice(
const Tensor* input_tensor, Tensor* output_tensor,
const DeviceContext* device_context, StatusCallback done) {
if (input_tensor->NumElements() != output_tensor->NumElements()) {
done(errors::Internal(
"CPU->CPU copy shape mismatch: input=", input_tensor->shape(),
", output=", output_tensor->shape()));
return;
}
tensor::DeepCopy(*input_tensor, output_tensor);
done(Status::OK());
}
#ifdef INTEL_MKL #ifdef INTEL_MKL
namespace { namespace {
class MklCPUAllocatorFactory : public AllocatorFactory { class MklCPUAllocatorFactory : public AllocatorFactory {

View File

@ -38,6 +38,9 @@ class ThreadPoolDevice : public LocalDevice {
Status MakeTensorFromProto(const TensorProto& tensor_proto, Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs, const AllocatorAttributes alloc_attrs,
Tensor* tensor) override; Tensor* tensor) override;
void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
const DeviceContext* device_context,
StatusCallback done) override;
Status Sync() override { return Status::OK(); } Status Sync() override { return Status::OK(); }

View File

@ -0,0 +1,72 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {
const int kDimSize = 2;
void InitTensor(Tensor* tensor, float value) {
auto eigen_tensor = tensor->tensor<float, kDimSize>();
for (int i = 0; i < kDimSize; ++i) {
for (int j = 0; j < kDimSize; ++j) {
eigen_tensor(i, j) = value;
}
}
}
bool Equal(const Tensor& tensor1, const Tensor& tensor2) {
auto eigen_tensor1 = tensor1.tensor<float, kDimSize>();
auto eigen_tensor2 = tensor2.tensor<float, kDimSize>();
for (int i = 0; i < kDimSize; ++i) {
for (int j = 0; j < kDimSize; ++j) {
if (eigen_tensor1(i, j) != eigen_tensor2(i, j)) {
return false;
}
}
}
return true;
}
TEST(ThreadPoolDeviceTest, CopyTensor) {
Tensor input(DT_FLOAT, TensorShape({kDimSize, kDimSize}));
Tensor output(DT_FLOAT, TensorShape({kDimSize, kDimSize}));
InitTensor(&input, 1);
InitTensor(&output, 0);
ASSERT_FALSE(Equal(input, output));
ThreadPoolDevice device(SessionOptions(), "/device:CPU:0", Bytes(256),
DeviceLocality(), cpu_allocator());
DeviceContext* device_context = new DeviceContext;
Notification note;
device.CopyTensorInSameDevice(&input, &output, device_context,
[&note](const Status& s) {
TF_ASSERT_OK(s);
note.Notify();
});
note.WaitForNotification();
ASSERT_TRUE(Equal(input, output));
device_context->Unref();
}
} // namespace
} // namespace tensorflow

View File

@ -255,6 +255,22 @@ class DeviceBase {
// be tagging deallocated memory chunks using the same counter. // be tagging deallocated memory chunks using the same counter.
virtual uint64 SafeAllocFrontier() { return 0; } virtual uint64 SafeAllocFrontier() { return 0; }
// Copies `input_tensor` to `output_tensor`, where both tensors are on this
// device. This function assumes that `output_tensor` has already been
// allocated with a buffer that is large enough to hold `input_tensor`'s data.
// Calls `done` from a device-specific thread after copy is finished, which
// may be the same as calling thread.
//
// NOTE(ayushd): This function is for TensorFlow internal use only. Deep copy
// is discouraged and should not be used in OpKernels.
virtual void CopyTensorInSameDevice(const Tensor* input_tensor,
Tensor* output_tensor,
const DeviceContext* device_context,
StatusCallback done) {
done(errors::Internal("Device ", name(), " does not implement ",
"CopyTensorInSameDevice"));
}
protected: protected:
// Does not take ownership. // Does not take ownership.
void set_tensorflow_device_thread_pool(thread::ThreadPool* thread_pool) { void set_tensorflow_device_thread_pool(thread::ThreadPool* thread_pool) {

View File

@ -31,24 +31,28 @@ namespace tensor {
Tensor DeepCopy(const Tensor& other) { Tensor DeepCopy(const Tensor& other) {
Tensor tmp = Tensor(other.dtype(), other.shape()); Tensor tmp = Tensor(other.dtype(), other.shape());
if (DataTypeCanUseMemcpy(other.dtype())) { DeepCopy(other, &tmp);
if (other.NumElements() > 0) { return tmp;
StringPiece other_data = other.tensor_data(); }
void DeepCopy(const Tensor& input, Tensor* output) {
if (DataTypeCanUseMemcpy(input.dtype())) {
if (input.NumElements() > 0) {
StringPiece input_data = input.tensor_data();
// We use StringPiece as a convenient map over the tensor buffer, // We use StringPiece as a convenient map over the tensor buffer,
// but we cast the type to get to the underlying buffer to do the // but we cast the type to get to the underlying buffer to do the
// copy. // copy.
StringPiece tmp_data = tmp.tensor_data(); StringPiece output_data = output->tensor_data();
memcpy(const_cast<char*>(tmp_data.data()), other_data.data(), memcpy(const_cast<char*>(output_data.data()), input_data.data(),
other_data.size()); input_data.size());
} }
} else if (other.dtype() == DT_STRING) { } else if (input.dtype() == DT_STRING) {
tmp.unaligned_flat<string>() = other.unaligned_flat<string>(); output->unaligned_flat<string>() = input.unaligned_flat<string>();
} else { } else {
CHECK_EQ(DT_VARIANT, other.dtype()); CHECK_EQ(DT_VARIANT, input.dtype());
tmp.unaligned_flat<Variant>() = other.unaligned_flat<Variant>(); output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>();
} }
return tmp;
} }
Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) { Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {

View File

@ -38,6 +38,10 @@ namespace tensor {
// 'other' is not appropriately memory-aligned. // 'other' is not appropriately memory-aligned.
Tensor DeepCopy(const Tensor& other); Tensor DeepCopy(const Tensor& other);
// Deep copies input to output. This function is similar to above, but assumes
// that the memory for the output has already been allocated.
void DeepCopy(const Tensor& input, Tensor* output);
// Concatenates 'tensors' into a single tensor, along their 0th dimension. // Concatenates 'tensors' into a single tensor, along their 0th dimension.
// //
// REQUIRES: All members of 'tensors' must have the same data type parameter. // REQUIRES: All members of 'tensors' must have the same data type parameter.