Merge code from PR with internal changes from cl/164796436, and update Python tests to also run on GPU.

PiperOrigin-RevId: 164929133
A. Unique TensorFlower 2017-08-10 17:55:10 -07:00 committed by TensorFlower Gardener
parent 9fba8c1851
commit e2a163a905
4 changed files with 251 additions and 27 deletions


@@ -1573,6 +1573,10 @@ tf_kernel_library(
tf_kernel_library(
name = "dynamic_stitch_op",
gpu_srcs = [
"cuda_device_array.h",
"cuda_device_array_gpu.h",
],
prefix = "dynamic_stitch_op",
deps = DYNAMIC_DEPS,
)


@@ -21,8 +21,17 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/threadpool.h"
#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_device_array.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
#ifdef GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif // GOOGLE_CUDA
template <class T>
class DynamicStitchOpImplBase : public OpKernel {
public:
@@ -66,17 +75,24 @@ class DynamicStitchOpImplBase : public OpKernel {
void CheckArgsAndAllocateResult(OpKernelContext* c,
OpInputList* indices_inputs,
OpInputList* data_inputs, int* first_dim_size,
int* data_elements_size,
Tensor** result_ptr) {
// Find maximum index in the indices vectors
OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));
int32 max_index = -1;
if (data_elements_size) {
*data_elements_size = 0;
}
for (const Tensor& indices : *indices_inputs) {
if (indices.NumElements() > 0) {
Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
indices.flat<int32>().maximum();
max_index = std::max(m(), max_index);
}
if (data_elements_size) {
*data_elements_size += indices.NumElements();
}
}
*first_dim_size = max_index + 1;
@@ -90,18 +106,19 @@ class DynamicStitchOpImplBase : public OpKernel {
const Tensor& data = (*data_inputs)[input_num];
OP_REQUIRES(
c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
errors::InvalidArgument("data[", input_num, "].shape = ",
data.shape().DebugString(),
errors::InvalidArgument("data[", input_num,
"].shape = ", data.shape().DebugString(),
" does not start with indices[", input_num,
"].shape = ", indices.shape().DebugString()));
OP_REQUIRES(
c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
errors::InvalidArgument(
"Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
"].shape[", indices.dims(), ":], got data[0].shape = ",
data0.shape().DebugString(), ", data[", input_num, "].shape = ",
data.shape().DebugString(), ", indices[0].shape = ",
indices0.shape().DebugString(), ", indices[", input_num,
"].shape[", indices.dims(),
":], got data[0].shape = ", data0.shape().DebugString(),
", data[", input_num, "].shape = ", data.shape().DebugString(),
", indices[0].shape = ", indices0.shape().DebugString(),
", indices[", input_num,
"].shape = ", indices.shape().DebugString()));
}
@@ -116,10 +133,90 @@ class DynamicStitchOpImplBase : public OpKernel {
}
};
template <class T, bool Parallel>
class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
#if GOOGLE_CUDA
template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
const int32 slice_size, const int32 first_dim_size,
const CudaDeviceArrayStruct<int>& input_indices,
const CudaDeviceArrayStruct<const T*>& input_ptrs,
T* output);
template <class T>
class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
public:
explicit DynamicStitchOpImpl(OpKernelConstruction* c)
explicit DynamicStitchOpGPU(OpKernelConstruction* c)
: DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}
void Compute(OpKernelContext* c) override {
OpInputList indices_inputs;
OpInputList data_inputs;
int first_dim_size;
int data_elements_size;
Tensor* merged = nullptr;
this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
&first_dim_size, &data_elements_size,
&merged);
if (!c->status().ok()) {
// Avoid segmentation faults if merged cannot be allocated and an error is
// passed back in the context.
return;
}
// TODO(jeff): Currently we leave uninitialized any portions of
// merged that aren't covered by an index in indices. What should we do?
if (first_dim_size > 0) {
// Because of the collision requirements, we have to resolve collisions
// before sending the data to the GPU kernel.
// TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
// last of duplicated indices, it could instead be done on the GPU
// implicitly using atomics to make sure the last index is the final
// write.
const int slice_size = merged->flat_outer_dims<T>().dimension(1);
CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
OP_REQUIRES_OK(c, indices_flat.Init());
OP_REQUIRES_OK(c, data_flat.Init());
// initialize the indices_flat (-1 represents missing indices)
for (int i = 0; i < first_dim_size; ++i) {
indices_flat.Set(i, -1);
}
// data_flat index
int32 idx = 0;
// Running sum of indices_inputs[i].NumElements(), used to compute the indices_flat values.
int32 base_size = 0;
for (int i = 0; i < indices_inputs.size(); ++i) {
auto indices_vec = indices_inputs[i].flat<int32>();
auto data_ptr_base = data_inputs[i].template flat<T>().data();
for (int j = 0; j < indices_vec.size(); ++j) {
// indices_flat's indices represent the indices of the output.
// indices_flat's values represent the indices into the flattened input data
// where the corresponding slice is located.
indices_flat.Set(indices_vec(j), base_size + j);
data_flat.Set(
idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
j * slice_size));
++idx;
}
base_size += indices_vec.size();
}
OP_REQUIRES_OK(c, indices_flat.Finalize());
OP_REQUIRES_OK(c, data_flat.Finalize());
auto output = merged->template flat<T>().data();
DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size,
indices_flat.data(), data_flat.data(), output);
}
}
};
#endif // GOOGLE_CUDA
template <class T, bool Parallel>
class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
public:
explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
: DynamicStitchOpImplBase<T>(
c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}
@@ -129,7 +226,7 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
int first_dim_size;
Tensor* merged = nullptr;
this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
&first_dim_size, &merged);
&first_dim_size, nullptr, &merged);
if (!c->status().ok()) {
// Avoid segmentation faults if merged cannot be allocated and an error is
// passed back in the context.
@@ -207,13 +304,13 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
// functionality later.
template <typename T>
struct DynamicStitchOp : DynamicStitchOpImpl<T, false> {
using DynamicStitchOpImpl<T, false>::DynamicStitchOpImpl;
struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
};
template <typename T>
struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
using DynamicStitchOpImpl<T, true>::DynamicStitchOpImpl;
struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
};
#define REGISTER_DYNAMIC_STITCH(type) \
@@ -221,12 +318,12 @@ struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("indices"), \
DynamicStitchOp<type>) \
DynamicStitchOpCPU<type>) \
REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("indices"), \
ParallelDynamicStitchOp<type>)
ParallelDynamicStitchOpCPU<type>)
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH
@@ -236,19 +333,21 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("indices") \
.HostMemory("data") \
.HostMemory("merged"), \
DynamicStitchOp<type>) \
.HostMemory("indices"), \
DynamicStitchOpGPU<type>) \
REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("indices") \
.HostMemory("data") \
.HostMemory("merged"), \
ParallelDynamicStitchOp<type>)
ParallelDynamicStitchOpCPU<type>)
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
#undef REGISTER_DYNAMIC_STITCH_GPU
#endif // GOOGLE_CUDA
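
To make the collision handling above concrete, here is a minimal usage sketch in the style of the test file below (illustrative only, not part of the change): when an output index appears in more than one indices tensor, DynamicStitch keeps the value from the later input, and output rows not covered by any index are left unspecified.

# Minimal sketch of DynamicStitch collision semantics (illustrative only).
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import data_flow_ops

indices = [constant_op.constant([0, 1, 2]), constant_op.constant([1])]
data = [constant_op.constant([10.0, 11.0, 12.0]), constant_op.constant([99.0])]
merged = data_flow_ops.dynamic_stitch(indices, data)

with session.Session() as sess:
  # Expected: [10., 99., 12.] -- the slice from the later input overwrites the
  # earlier write at output index 1, matching the host-side scan above.
  print(sess.run(merged))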


@@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
using GPUDevice = Eigen::GpuDevice;
namespace {
template <typename T>
__global__ void DynamicStitchKernel(const int32 slice_size,
const int32 output_size,
CudaDeviceArrayStruct<int32> input_indices,
CudaDeviceArrayStruct<const T*> input_ptrs,
T* output) {
int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices);
const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs);
CUDA_1D_KERNEL_LOOP(output_index, output_size) {
const int32 slice_id = output_index / slice_size;
const int32 slice_offset = output_index % slice_size;
const int32 input_index = data_indices[slice_id];
if (input_index != -1) {
output[output_index] = ldg(data_ptrs[input_index] + slice_offset);
}
}
}
} // namespace
template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
const int32 slice_size, const int32 first_dim_size,
const CudaDeviceArrayStruct<int>& input_indices,
const CudaDeviceArrayStruct<const T*>& input_ptrs,
T* output) {
const int32 output_size = first_dim_size * slice_size;
auto config = GetCudaLaunchConfig(output_size, gpu_device);
DynamicStitchKernel<T>
<<<config.block_count, config.thread_per_block, 0, gpu_device.stream()>>>(
slice_size, output_size, input_indices, input_ptrs, output);
}
#define REGISTER_GPU(T) \
template void DynamicStitchGPUImpl( \
const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
const int32 first_dim_size, \
const CudaDeviceArrayStruct<int32>& input_indices, \
const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_int32(REGISTER_GPU)
#undef REGISTER_GPU
} // namespace tensorflow
#endif // GOOGLE_CUDA
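
A rough NumPy model of the kernel's addressing (illustrative only; the helper name and arguments are hypothetical): the output is treated as first_dim_size * slice_size contiguous elements, each flat index is split into a slice id and an offset within that slice, and slices whose data_indices entry is -1 are skipped, which is why uncovered rows stay uninitialized.

# Hypothetical reference model of the CUDA kernel's index arithmetic.
import numpy as np

def stitch_reference(data_indices, input_rows, slice_size, first_dim_size):
  # data_indices[slice_id] == -1 marks a missing output row; otherwise it is the
  # position of the source row in the flattened list of input rows.
  output = np.zeros(first_dim_size * slice_size, dtype=np.float32)
  for output_index in range(first_dim_size * slice_size):
    slice_id = output_index // slice_size     # which output row
    slice_offset = output_index % slice_size  # position within that row
    input_index = data_indices[slice_id]
    if input_index != -1:
      output[output_index] = input_rows[input_index][slice_offset]
  return output.reshape(first_dim_size, slice_size)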


@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import
@@ -33,7 +34,7 @@ class DynamicStitchTestBase(object):
self.stitch_op = stitch_op
def testScalar(self):
with self.test_session():
with self.test_session(use_gpu=True):
indices = [constant_op.constant(0), constant_op.constant(1)]
data = [constant_op.constant(40), constant_op.constant(60)]
for step in -1, 1:
@@ -46,7 +47,7 @@ class DynamicStitchTestBase(object):
self.assertEqual([None], stitched_t.get_shape().as_list())
def testSimpleOneDimensional(self):
with self.test_session():
with self.test_session(use_gpu=True):
indices = [
constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5])
]
@@ -63,7 +64,7 @@ class DynamicStitchTestBase(object):
self.assertEqual([None], stitched_t.get_shape().as_list())
def testOneListOneDimensional(self):
with self.test_session():
with self.test_session(use_gpu=True):
indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
stitched_t = self.stitch_op(indices, data)
@@ -75,7 +76,7 @@ class DynamicStitchTestBase(object):
self.assertEqual([None], stitched_t.get_shape().as_list())
def testSimpleTwoDimensional(self):
with self.test_session():
with self.test_session(use_gpu=True):
indices = [
constant_op.constant([0, 4, 7]), constant_op.constant([1, 6]),
constant_op.constant([2, 3, 5])
@@ -95,7 +96,7 @@ class DynamicStitchTestBase(object):
self.assertEqual([None, 2], stitched_t.get_shape().as_list())
def testHigherRank(self):
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
indices = [
constant_op.constant(6), constant_op.constant([4, 1]),
constant_op.constant([[5, 2], [0, 3]])
@@ -176,6 +177,45 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
test.TestCase.__init__(self, *test_case_args)
DynamicStitchTestBase.__init__(self, data_flow_ops.parallel_dynamic_stitch)
def testScalar(self):
with self.test_session(use_gpu=True):
indices = [constant_op.constant(0), constant_op.constant(1)]
data = [constant_op.constant(40.0), constant_op.constant(60.0)]
for step in -1, 1:
stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
stitched_val = stitched_t.eval()
self.assertAllEqual([40.0, 60.0][::step], stitched_val)
# Dimension 0 is determined by the max index in indices, so we
# can only infer that the output is a vector of some unknown
# length.
self.assertEqual([None], stitched_t.get_shape().as_list())
def testHigherRank(self):
with self.test_session(use_gpu=True) as sess:
indices = [
constant_op.constant(6),
constant_op.constant([4, 1]),
constant_op.constant([[5, 2], [0, 3]])
]
data = [
constant_op.constant([61, 62], dtype=dtypes.float32),
constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
constant_op.constant(
[[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32)
]
stitched_t = data_flow_ops.dynamic_stitch(indices, data)
stitched_val = stitched_t.eval()
correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
self.assertAllEqual(correct, stitched_val)
self.assertEqual([None, 2], stitched_t.get_shape().as_list())
# Test gradients
stitched_grad = 7 * stitched_val
grads = gradients_impl.gradients(stitched_t, indices + data,
stitched_grad)
self.assertEqual(grads[:3], [None] * 3) # Indices have no gradients
for datum, grad in zip(data, sess.run(grads[3:])):
self.assertAllEqual(7.0 * datum.eval(), grad)
if __name__ == "__main__":
test.main()
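
For reference, a hedged usage sketch of parallel_dynamic_stitch with the same imports as the test above (illustrative, not part of the diff): with disjoint index sets, the common case when merging dynamic_partition outputs, it merges exactly like dynamic_stitch; if the same index appears in more than one indices tensor, the result is not guaranteed.

# Hedged usage sketch of parallel_dynamic_stitch (illustrative only).
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import data_flow_ops

indices = [constant_op.constant([0, 2]), constant_op.constant([1, 3])]
data = [constant_op.constant([0.0, 20.0]), constant_op.constant([10.0, 30.0])]
stitched_t = data_flow_ops.parallel_dynamic_stitch(indices, data)

with session.Session() as sess:
  print(sess.run(stitched_t))  # expected: [0., 10., 20., 30.]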