Merge code from PR #11940 with internal changes from cl/164796436, and update Python tests to also run on GPU.
PiperOrigin-RevId: 164929133
This commit is contained in:
parent
9fba8c1851
commit
e2a163a905
tensorflow
core/kernels
python/kernel_tests
@ -1573,6 +1573,10 @@ tf_kernel_library(
|
||||
|
||||
tf_kernel_library(
|
||||
name = "dynamic_stitch_op",
|
||||
gpu_srcs = [
|
||||
"cuda_device_array.h",
|
||||
"cuda_device_array_gpu.h",
|
||||
],
|
||||
prefix = "dynamic_stitch_op",
|
||||
deps = DYNAMIC_DEPS,
|
||||
)
|
||||
|
@ -21,8 +21,17 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/cuda_device_array.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
#ifdef GOOGLE_CUDA
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
template <class T>
|
||||
class DynamicStitchOpImplBase : public OpKernel {
|
||||
public:
|
||||
@ -66,17 +75,24 @@ class DynamicStitchOpImplBase : public OpKernel {
|
||||
void CheckArgsAndAllocateResult(OpKernelContext* c,
|
||||
OpInputList* indices_inputs,
|
||||
OpInputList* data_inputs, int* first_dim_size,
|
||||
int* data_elements_size,
|
||||
Tensor** result_ptr) {
|
||||
// Find maximum index in the indices vectors
|
||||
OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));
|
||||
|
||||
int32 max_index = -1;
|
||||
if (data_elements_size) {
|
||||
*data_elements_size = 0;
|
||||
}
|
||||
for (const Tensor& indices : *indices_inputs) {
|
||||
if (indices.NumElements() > 0) {
|
||||
Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
|
||||
indices.flat<int32>().maximum();
|
||||
max_index = std::max(m(), max_index);
|
||||
}
|
||||
if (data_elements_size) {
|
||||
*data_elements_size += indices.NumElements();
|
||||
}
|
||||
}
|
||||
|
||||
*first_dim_size = max_index + 1;
|
||||
@ -90,18 +106,19 @@ class DynamicStitchOpImplBase : public OpKernel {
|
||||
const Tensor& data = (*data_inputs)[input_num];
|
||||
OP_REQUIRES(
|
||||
c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
|
||||
errors::InvalidArgument("data[", input_num, "].shape = ",
|
||||
data.shape().DebugString(),
|
||||
errors::InvalidArgument("data[", input_num,
|
||||
"].shape = ", data.shape().DebugString(),
|
||||
" does not start with indices[", input_num,
|
||||
"].shape = ", indices.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
|
||||
errors::InvalidArgument(
|
||||
"Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
|
||||
"].shape[", indices.dims(), ":], got data[0].shape = ",
|
||||
data0.shape().DebugString(), ", data[", input_num, "].shape = ",
|
||||
data.shape().DebugString(), ", indices[0].shape = ",
|
||||
indices0.shape().DebugString(), ", indices[", input_num,
|
||||
"].shape[", indices.dims(),
|
||||
":], got data[0].shape = ", data0.shape().DebugString(),
|
||||
", data[", input_num, "].shape = ", data.shape().DebugString(),
|
||||
", indices[0].shape = ", indices0.shape().DebugString(),
|
||||
", indices[", input_num,
|
||||
"].shape = ", indices.shape().DebugString()));
|
||||
}
|
||||
|
||||
@ -116,10 +133,90 @@ class DynamicStitchOpImplBase : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
template <class T, bool Parallel>
|
||||
class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
template <typename T>
|
||||
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
|
||||
const int32 slice_size, const int32 first_dim_size,
|
||||
const CudaDeviceArrayStruct<int>& input_indices,
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs,
|
||||
T* output);
|
||||
|
||||
template <class T>
|
||||
class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
|
||||
public:
|
||||
explicit DynamicStitchOpImpl(OpKernelConstruction* c)
|
||||
explicit DynamicStitchOpGPU(OpKernelConstruction* c)
|
||||
: DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}
|
||||
|
||||
void Compute(OpKernelContext* c) override {
|
||||
OpInputList indices_inputs;
|
||||
OpInputList data_inputs;
|
||||
int first_dim_size;
|
||||
int data_elements_size;
|
||||
Tensor* merged = nullptr;
|
||||
this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
|
||||
&first_dim_size, &data_elements_size,
|
||||
&merged);
|
||||
if (!c->status().ok()) {
|
||||
// Avoid segmentation faults if merged cannot be allocated and an error is
|
||||
// passed back in the context.
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(jeff): Currently we leave uninitialized any portions of
|
||||
// merged that aren't covered by an index in indices. What should we do?
|
||||
if (first_dim_size > 0) {
|
||||
// because the collision requirements, we have to deal with
|
||||
// collion first before send data to gpu kernel.
|
||||
// TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
|
||||
// last of duplicated indices, it could instead be done of the GPU
|
||||
// implicitly using atomics to make sure the last index is the final
|
||||
// write.
|
||||
const int slice_size = merged->flat_outer_dims<T>().dimension(1);
|
||||
CudaDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
|
||||
CudaDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
|
||||
OP_REQUIRES_OK(c, indices_flat.Init());
|
||||
OP_REQUIRES_OK(c, data_flat.Init());
|
||||
// initialize the indices_flat (-1 represents missing indices)
|
||||
for (int i = 0; i < first_dim_size; ++i) {
|
||||
indices_flat.Set(i, -1);
|
||||
}
|
||||
|
||||
// data_flat index
|
||||
int32 idx = 0;
|
||||
// sum of indices_inputs[i].NumElements() for compute indicies_flat value.
|
||||
int32 base_size = 0;
|
||||
for (int i = 0; i < indices_inputs.size(); ++i) {
|
||||
auto indices_vec = indices_inputs[i].flat<int32>();
|
||||
auto data_ptr_base = data_inputs[i].template flat<T>().data();
|
||||
for (int j = 0; j < indices_vec.size(); ++j) {
|
||||
// indices_flat's indices represent the indices of output.
|
||||
// indices_flat's values represent the indices of input_data where the
|
||||
// data located.
|
||||
indices_flat.Set(indices_vec(j), base_size + j);
|
||||
data_flat.Set(
|
||||
idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
|
||||
j * slice_size));
|
||||
++idx;
|
||||
}
|
||||
base_size += indices_vec.size();
|
||||
}
|
||||
OP_REQUIRES_OK(c, indices_flat.Finalize());
|
||||
OP_REQUIRES_OK(c, data_flat.Finalize());
|
||||
|
||||
auto output = merged->template flat<T>().data();
|
||||
DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size,
|
||||
indices_flat.data(), data_flat.data(), output);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
template <class T, bool Parallel>
|
||||
class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
|
||||
public:
|
||||
explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
|
||||
: DynamicStitchOpImplBase<T>(
|
||||
c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}
|
||||
|
||||
@ -129,7 +226,7 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
|
||||
int first_dim_size;
|
||||
Tensor* merged = nullptr;
|
||||
this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
|
||||
&first_dim_size, &merged);
|
||||
&first_dim_size, nullptr, &merged);
|
||||
if (!c->status().ok()) {
|
||||
// Avoid segmentation faults if merged cannot be allocated and an error is
|
||||
// passed back in the context.
|
||||
@ -207,13 +304,13 @@ class DynamicStitchOpImpl : public DynamicStitchOpImplBase<T> {
|
||||
// functionality later.
|
||||
|
||||
template <typename T>
|
||||
struct DynamicStitchOp : DynamicStitchOpImpl<T, false> {
|
||||
using DynamicStitchOpImpl<T, false>::DynamicStitchOpImpl;
|
||||
struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
|
||||
using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
|
||||
using DynamicStitchOpImpl<T, true>::DynamicStitchOpImpl;
|
||||
struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
|
||||
using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
|
||||
};
|
||||
|
||||
#define REGISTER_DYNAMIC_STITCH(type) \
|
||||
@ -221,12 +318,12 @@ struct ParallelDynamicStitchOp : DynamicStitchOpImpl<T, true> {
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("indices"), \
|
||||
DynamicStitchOp<type>) \
|
||||
DynamicStitchOpCPU<type>) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("indices"), \
|
||||
ParallelDynamicStitchOp<type>)
|
||||
ParallelDynamicStitchOpCPU<type>)
|
||||
|
||||
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
|
||||
#undef REGISTER_DYNAMIC_STITCH
|
||||
@ -236,19 +333,21 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
|
||||
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("indices") \
|
||||
.HostMemory("data") \
|
||||
.HostMemory("merged"), \
|
||||
DynamicStitchOp<type>) \
|
||||
.HostMemory("indices"), \
|
||||
DynamicStitchOpGPU<type>) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("indices") \
|
||||
.HostMemory("data") \
|
||||
.HostMemory("merged"), \
|
||||
ParallelDynamicStitchOp<type>)
|
||||
ParallelDynamicStitchOpCPU<type>)
|
||||
|
||||
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
|
||||
TF_CALL_complex64(REGISTER_DYNAMIC_STITCH_GPU);
|
||||
TF_CALL_complex128(REGISTER_DYNAMIC_STITCH_GPU);
|
||||
TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
|
||||
TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
|
||||
#undef REGISTER_DYNAMIC_STITCH_GPU
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
81
tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc
Normal file
81
tensorflow/core/kernels/dynamic_stitch_op_gpu.cu.cc
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/cuda_device_array_gpu.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
__global__ void DynamicStitchKernel(const int32 slice_size,
|
||||
const int32 output_size,
|
||||
CudaDeviceArrayStruct<int32> input_indices,
|
||||
CudaDeviceArrayStruct<const T*> input_ptrs,
|
||||
T* output) {
|
||||
int32* data_indices = GetCudaDeviceArrayOnDevice(&input_indices);
|
||||
const T** data_ptrs = GetCudaDeviceArrayOnDevice(&input_ptrs);
|
||||
CUDA_1D_KERNEL_LOOP(output_index, output_size) {
|
||||
const int32 slice_id = output_index / slice_size;
|
||||
const int32 slice_offset = output_index % slice_size;
|
||||
const int32 input_index = data_indices[slice_id];
|
||||
if (input_index != -1) {
|
||||
output[output_index] = ldg(data_ptrs[input_index] + slice_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
|
||||
const int32 slice_size, const int32 first_dim_size,
|
||||
const CudaDeviceArrayStruct<int>& input_indices,
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs,
|
||||
T* output) {
|
||||
const int32 output_size = first_dim_size * slice_size;
|
||||
auto config = GetCudaLaunchConfig(output_size, gpu_device);
|
||||
|
||||
DynamicStitchKernel<T>
|
||||
<<<config.block_count, config.thread_per_block, 0, gpu_device.stream()>>>(
|
||||
slice_size, output_size, input_indices, input_ptrs, output);
|
||||
}
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
template void DynamicStitchGPUImpl( \
|
||||
const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
|
||||
const int32 first_dim_size, \
|
||||
const CudaDeviceArrayStruct<int32>& input_indices, \
|
||||
const CudaDeviceArrayStruct<const T*>& input_ptrs, T* output);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
|
||||
TF_CALL_complex64(REGISTER_GPU);
|
||||
TF_CALL_complex128(REGISTER_GPU);
|
||||
TF_CALL_int64(REGISTER_GPU);
|
||||
TF_CALL_int32(REGISTER_GPU)
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // GOOGLE_CUDA
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import
|
||||
@ -33,7 +34,7 @@ class DynamicStitchTestBase(object):
|
||||
self.stitch_op = stitch_op
|
||||
|
||||
def testScalar(self):
|
||||
with self.test_session():
|
||||
with self.test_session(use_gpu=True):
|
||||
indices = [constant_op.constant(0), constant_op.constant(1)]
|
||||
data = [constant_op.constant(40), constant_op.constant(60)]
|
||||
for step in -1, 1:
|
||||
@ -46,7 +47,7 @@ class DynamicStitchTestBase(object):
|
||||
self.assertEqual([None], stitched_t.get_shape().as_list())
|
||||
|
||||
def testSimpleOneDimensional(self):
|
||||
with self.test_session():
|
||||
with self.test_session(use_gpu=True):
|
||||
indices = [
|
||||
constant_op.constant([0, 4, 7]), constant_op.constant([1, 6, 2, 3, 5])
|
||||
]
|
||||
@ -63,7 +64,7 @@ class DynamicStitchTestBase(object):
|
||||
self.assertEqual([None], stitched_t.get_shape().as_list())
|
||||
|
||||
def testOneListOneDimensional(self):
|
||||
with self.test_session():
|
||||
with self.test_session(use_gpu=True):
|
||||
indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
|
||||
data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
|
||||
stitched_t = self.stitch_op(indices, data)
|
||||
@ -75,7 +76,7 @@ class DynamicStitchTestBase(object):
|
||||
self.assertEqual([None], stitched_t.get_shape().as_list())
|
||||
|
||||
def testSimpleTwoDimensional(self):
|
||||
with self.test_session():
|
||||
with self.test_session(use_gpu=True):
|
||||
indices = [
|
||||
constant_op.constant([0, 4, 7]), constant_op.constant([1, 6]),
|
||||
constant_op.constant([2, 3, 5])
|
||||
@ -95,7 +96,7 @@ class DynamicStitchTestBase(object):
|
||||
self.assertEqual([None, 2], stitched_t.get_shape().as_list())
|
||||
|
||||
def testHigherRank(self):
|
||||
with self.test_session() as sess:
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
indices = [
|
||||
constant_op.constant(6), constant_op.constant([4, 1]),
|
||||
constant_op.constant([[5, 2], [0, 3]])
|
||||
@ -176,6 +177,45 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
|
||||
test.TestCase.__init__(self, *test_case_args)
|
||||
DynamicStitchTestBase.__init__(self, data_flow_ops.parallel_dynamic_stitch)
|
||||
|
||||
def testScalar(self):
|
||||
with self.test_session(use_gpu=True):
|
||||
indices = [constant_op.constant(0), constant_op.constant(1)]
|
||||
data = [constant_op.constant(40.0), constant_op.constant(60.0)]
|
||||
for step in -1, 1:
|
||||
stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
|
||||
stitched_val = stitched_t.eval()
|
||||
self.assertAllEqual([40.0, 60.0][::step], stitched_val)
|
||||
# Dimension 0 is determined by the max index in indices, so we
|
||||
# can only infer that the output is a vector of some unknown
|
||||
# length.
|
||||
self.assertEqual([None], stitched_t.get_shape().as_list())
|
||||
|
||||
def testHigherRank(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
indices = [
|
||||
constant_op.constant(6),
|
||||
constant_op.constant([4, 1]),
|
||||
constant_op.constant([[5, 2], [0, 3]])
|
||||
]
|
||||
data = [
|
||||
constant_op.constant([61, 62], dtype=dtypes.float32),
|
||||
constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
|
||||
constant_op.constant(
|
||||
[[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32)
|
||||
]
|
||||
stitched_t = data_flow_ops.dynamic_stitch(indices, data)
|
||||
stitched_val = stitched_t.eval()
|
||||
correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
|
||||
self.assertAllEqual(correct, stitched_val)
|
||||
self.assertEqual([None, 2], stitched_t.get_shape().as_list())
|
||||
# Test gradients
|
||||
stitched_grad = 7 * stitched_val
|
||||
grads = gradients_impl.gradients(stitched_t, indices + data,
|
||||
stitched_grad)
|
||||
self.assertEqual(grads[:3], [None] * 3) # Indices have no gradients
|
||||
for datum, grad in zip(data, sess.run(grads[3:])):
|
||||
self.assertAllEqual(7.0 * datum.eval(), grad)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user