Add InTopKV2 kernel for GPU device

PiperOrigin-RevId: 248184131
This commit is contained in:
Eugene Zhulenev 2019-05-14 12:00:30 -07:00 committed by TensorFlower Gardener
parent 7be9519dd0
commit c75aa2295d
7 changed files with 438 additions and 47 deletions

View File

@ -1741,6 +1741,25 @@ tf_cc_test(
],
)
tf_cc_test(
name = "in_topk_op_test",
size = "small",
srcs = ["in_topk_op_test.cc"],
deps = [
":in_topk_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/stream_executor/cuda:cudnn_plugin",
],
)
tf_kernel_library(
name = "gather_functor",
prefix = "gather_functor",
@ -4148,7 +4167,7 @@ tf_kernel_library(
tf_kernel_library(
name = "in_topk_op",
prefix = "in_topk_op",
deps = NN_DEPS,
deps = NN_DEPS + [":reduction_ops"],
)
tf_kernel_library(
@ -6061,6 +6080,7 @@ filegroup(
"dynamic_stitch_op.cc",
"fft_ops.cc",
"in_topk_op.cc",
"in_topk_op.h",
"initializable_lookup_table.cc",
"logging_ops.cc",
"lookup_table_init_op.cc",

View File

@ -17,15 +17,18 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/kernels/in_topk_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
template <typename T, typename TARGET_T>
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T, typename TARGET_T>
class InTopK : public OpKernel {
public:
explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
@ -37,7 +40,10 @@ class InTopK : public OpKernel {
void Compute(OpKernelContext* context) override {
const auto& predictions_in = context->input(0);
const auto& targets_in = context->input(1);
int64 k_val = k_;
int64 k_value = k_;
const Tensor* k_tensor = nullptr;
if (context->num_inputs() == 3) {
const auto& k_in = context->input(2);
@ -45,11 +51,7 @@ class InTopK : public OpKernel {
errors::InvalidArgument("k must be 0-D, got shape ",
k_in.shape().DebugString()));
if (k_in.dtype() == DT_INT32) {
k_val = k_in.scalar<int32>()();
} else {
k_val = k_in.scalar<int64>()();
}
k_tensor = &k_in;
}
OP_REQUIRES(context, predictions_in.dims() == 2,
@ -61,8 +63,9 @@ class InTopK : public OpKernel {
predictions_in.dim_size(0),
" must match length of targets ",
targets_in.dim_size(0)));
const auto& predictions = predictions_in.matrix<T>();
const auto& targets = targets_in.vec<TARGET_T>();
const auto predictions = predictions_in.matrix<T>();
const auto targets = targets_in.vec<TARGET_T>();
Tensor* t_out = nullptr;
OP_REQUIRES_OK(context,
@ -70,28 +73,11 @@ class InTopK : public OpKernel {
0, TensorShape({targets_in.dim_size(0)}), &t_out));
auto out = t_out->vec<bool>();
const auto size = targets.size();
const auto num_classes = predictions.dimension(1);
for (int b = 0; b < size; b++) {
auto target = internal::SubtleMustCopy(targets(b));
OP_REQUIRES(context, FastBoundsCheck(target, num_classes),
errors::InvalidArgument("targets[", b, "] is out of range"));
T target_prediction = predictions(b, target);
bool cannot_say = !std::isfinite(target_prediction);
int more_probable_classes = 0;
if (!cannot_say) {
for (int i = 0; i < num_classes; ++i) {
T pred = predictions(b, i);
if (!std::isfinite(pred)) {
cannot_say = true;
break;
} else if (pred > target_prediction) {
++more_probable_classes;
}
}
}
out(b) = cannot_say ? false : (more_probable_classes < k_val);
}
functor::InTopKFunctor<Device, T, TARGET_T> f;
functor::TopKArg arg;
arg.k_value = k_value;
arg.k_tensor = k_tensor;
f(context, predictions, targets, arg, out);
}
private:
@ -104,14 +90,14 @@ REGISTER_KERNEL_BUILDER(Name("InTopK")
.HostMemory("targets")
.HostMemory("precision")
.TypeConstraint<int32>("T"),
InTopK<float, int32>);
InTopK<CPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(Name("InTopK")
.Device(DEVICE_CPU)
.HostMemory("predictions")
.HostMemory("targets")
.HostMemory("precision")
.TypeConstraint<int64>("T"),
InTopK<float, int64>);
InTopK<CPUDevice, float, int64>);
REGISTER_KERNEL_BUILDER(Name("InTopKV2")
.Device(DEVICE_CPU)
@ -120,7 +106,7 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2")
.HostMemory("k")
.HostMemory("precision")
.TypeConstraint<int32>("T"),
InTopK<float, int32>);
InTopK<CPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(Name("InTopKV2")
.Device(DEVICE_CPU)
.HostMemory("predictions")
@ -128,6 +114,34 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2")
.HostMemory("k")
.HostMemory("precision")
.TypeConstraint<int64>("T"),
InTopK<float, int64>);
InTopK<CPUDevice, float, int64>);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, TARGET_T) \
template <> \
void InTopKFunctor<GPUDevice, T, TARGET_T>::operator()( \
OpKernelContext* context, \
typename TTypes<T, 2>::ConstTensor predictions, \
typename TTypes<TARGET_T>::ConstVec targets, const TopKArg k, \
typename TTypes<bool>::Vec output); \
extern template struct InTopKFunctor<GPUDevice, T, TARGET_T>;
DECLARE_GPU_SPEC(float, int32);
DECLARE_GPU_SPEC(float, int64);
#undef DECLARE_GPU_SPEC
} // namespace functor
REGISTER_KERNEL_BUILDER(
Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
InTopK<GPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(
Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64>("T"),
InTopK<GPUDevice, float, int64>);
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -0,0 +1,100 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
namespace functor {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
// InTopK argument can be passed either via mode attribute (InTopK op), or as an
// input tensor (InTopKV2 op).
struct TopKArg {
int64 k_value = -1;
const Tensor* k_tensor = nullptr;
};
template <typename Device, typename T, typename TargetT>
struct InTopKFunctor {
template <int ndims>
using Dims = Eigen::DSizes<Eigen::Index, ndims>;
void operator()(OpKernelContext* context,
typename TTypes<T, 2>::ConstTensor predictions,
typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
typename TTypes<bool>::Vec output) {}
};
template <typename T, typename TargetT>
struct InTopKFunctor<CPUDevice, T, TargetT> {
void operator()(OpKernelContext* context,
typename TTypes<T, 2>::ConstTensor predictions,
typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
typename TTypes<bool>::Vec output) {
const Eigen::Index num_targets = predictions.dimension(0);
const Eigen::Index num_classes = predictions.dimension(1);
int64 k_val = k.k_value;
if (k.k_tensor != nullptr) {
if (k.k_tensor->dtype() == DT_INT32) {
k_val = k.k_tensor->scalar<int32>()();
} else {
k_val = k.k_tensor->scalar<int64>()();
}
}
for (int batch_idx = 0; batch_idx < num_targets; batch_idx++) {
auto target = internal::SubtleMustCopy(targets(batch_idx));
bool cannot_say = !FastBoundsCheck(target, num_classes) ||
!std::isfinite(predictions(batch_idx, target));
int more_probable_classes = 0;
if (!cannot_say) {
const T target_prediction = predictions(batch_idx, target);
for (int class_idx = 0; class_idx < num_classes; ++class_idx) {
T pred = predictions(batch_idx, class_idx);
if (!std::isfinite(pred)) {
cannot_say = true;
break;
} else if (pred > target_prediction) {
++more_probable_classes;
if (more_probable_classes > k_val) break;
}
}
}
output(batch_idx) = cannot_say ? false : (more_probable_classes < k_val);
}
}
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_

View File

@ -0,0 +1,176 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/in_topk_op.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// Compare each prediction in 'predictions' with a target prediction for the
// batch, and write result to the 'mask':
// -1: If the target class is out of range, or if the prediction value is not
// finite and can't be compared to target prediction (and vice versa).
// 0: If prediction is smaller than the target prediction for the batch.
// 1: If prediction is larger than the target prediction for the batch.
template <typename T, typename TargetT>
__global__ void ComputePredictionMaskKernel(
const T* predictions, // dims: [ num_targets x num_classes ]
const TargetT* targets, // dims: [ num_targets ]
int64* mask, // dims: [ num_targets x num_classes ]
int num_targets, int num_classes) {
CUDA_1D_KERNEL_LOOP(i, num_targets * num_classes) {
const int batch_index = i / num_classes;
TargetT target_idx = ldg(targets + batch_index);
if (!FastBoundsCheck(target_idx, num_classes)) {
mask[i] = -1;
return;
}
T prediction = ldg(predictions + i);
T target_prediction =
ldg(predictions + batch_index * num_classes + target_idx);
if (!Eigen::numext::isfinite(prediction) ||
!Eigen::numext::isfinite(target_prediction)) {
mask[i] = -1;
} else {
mask[i] = prediction > target_prediction ? 1 : 0;
}
}
}
// Reduce all prediction masks either to the sum of '1' for each prediction
// larger than the target, or to '-1' if target class in invalid of predictions
// in a batch have non-finite values.
struct MaskSum {
__host__ __device__ int64 operator()(const int64& a, const int64& b) const {
if (a < 0 || b < 0)
return -1;
else
return a + b;
}
};
namespace reduction_op_helper {
template <>
struct IdentityValue<int64, MaskSum> {
int64 operator()() { return 0; }
};
} // namespace reduction_op_helper
template <typename T, typename TargetT>
struct InTopKFunctor<GPUDevice, T, TargetT> {
template <int ndims>
using Dims = Eigen::DSizes<Eigen::Index, ndims>;
void operator()(OpKernelContext* context,
typename TTypes<T, 2>::ConstTensor predictions,
typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
typename TTypes<bool>::Vec output) {
const Eigen::Index num_targets = predictions.dimension(0);
const Eigen::Index num_classes = predictions.dimension(1);
OP_REQUIRES(
context, num_targets * num_classes < std::numeric_limits<int>::max(),
errors::InvalidArgument(
"Number of targets * number of classes must be less than INT_MAX"));
// Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
Tensor predictions_mask;
OP_REQUIRES_OK(
context, context->allocate_temp(DT_INT64,
TensorShape({num_targets, num_classes}),
&predictions_mask));
// Number of predictions for each target that are larger than the target
// prediction (or -1 if we can't compute this number, because not all
// predictions are finite or target class is out of range).
Tensor num_larger_prediction;
OP_REQUIRES_OK(context,
context->allocate_temp(DT_INT64, TensorShape({num_targets}),
&num_larger_prediction));
const auto& d = context->eigen_device<GPUDevice>();
// Compute a mask for all predictions.
CudaLaunchConfig config = GetCudaLaunchConfig(num_targets * num_classes, d);
OP_REQUIRES_OK(context, CudaLaunchKernel(
ComputePredictionMaskKernel<T, TargetT>,
config.block_count, config.thread_per_block, 0,
d.stream(), predictions.data(), targets.data(),
predictions_mask.flat<int64>().data(),
num_targets, num_classes));
// Reduce prediction masks to number of predictions larger than the target
// prediction, or to the negative value if we can't compute an answer.
{
auto in = predictions_mask.matrix<int64>();
auto out = num_larger_prediction.flat<int64>();
ReduceImpl<int64, MaskSum, int64*, int64*, Dims<1>>(
context, (int64*)out.data(), (int64*)in.data(), in.rank(),
in.dimension(0), in.rank() >= 2 ? in.dimension(1) : 1,
in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), Dims<1>(1),
MaskSum());
}
// Compute if target prediction is in top K predictions.
auto cnt = num_larger_prediction.flat<int64>();
if (k.k_tensor != nullptr) {
if (k.k_tensor->dtype() == DT_INT32) {
output.device(d) =
(cnt >= cnt.constant(0)) &&
(cnt < k.k_tensor->flat<int32>().template cast<int64>().broadcast(
Dims<1>(num_targets)));
} else {
output.device(d) =
(cnt >= cnt.constant(0)) &&
(cnt < k.k_tensor->flat<int64>().broadcast(Dims<1>(num_targets)));
}
} else {
output.device(d) =
(cnt >= cnt.constant(0)) && (cnt < targets.constant(k.k_value));
}
}
};
} // namespace functor
// Definition of the GPU implementations declared in in_topk_op.cc.
#define DEFINE_GPU_KERNELS(T, TARGET_T) \
template struct functor::InTopKFunctor<GPUDevice, T, TARGET_T>;
DEFINE_GPU_KERNELS(float, int32);
DEFINE_GPU_KERNELS(float, int64);
#undef DEFINE_GPU_KERNELS
} // end namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,84 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
template <typename T>
static Graph* InTopK(int num_targets, int num_classes, T top_k) {
Graph* g = new Graph(OpRegistry::Global());
DataType dtype = DataTypeToEnum<T>::value;
Tensor predictions_t(DT_FLOAT, TensorShape({num_targets, num_classes}));
predictions_t.flat<float>().setRandom();
Tensor targets_t(dtype, TensorShape({num_targets}));
targets_t.flat<T>().setRandom();
Tensor k_t(dtype, TensorShape({}));
k_t.scalar<T>() = k_t.scalar<T>().constant(top_k);
Node* predictions = test::graph::Constant(g, predictions_t, "predictions");
Node* targets = test::graph::Constant(g, targets_t, "targets");
Node* k = test::graph::Constant(g, k_t, "k");
Node* in_topk;
TF_CHECK_OK(NodeBuilder(g->NewName("in_topk"), "InTopKV2")
.Input(predictions)
.Input(targets)
.Input(k)
.Attr("T", dtype)
.Finalize(g, &in_topk));
return g;
}
#define BM_NAME(T, TARGETS, CLASSES, K, DEVICE) \
BM_InTopK##_##T##_##TARGETS##_##CLASSES##_##K##_##DEVICE
#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE) \
static void BM_NAME(T, TARGETS, CLASSES, K, DEVICE)(int iters) { \
testing::UseRealTime(); \
testing::ItemsProcessed(static_cast<int64>(iters) * TARGETS * CLASSES); \
test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K)).Run(iters); \
} \
BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE));
BM_InTopK(int64, 64, 1000, 10, cpu);
BM_InTopK(int64, 64, 10000, 10, cpu);
#ifdef GOOGLE_CUDA
BM_InTopK(int64, 64, 1000, 10, gpu);
BM_InTopK(int64, 64, 10000, 10, gpu);
#endif // GOOGLE_CUDA
} // namespace tensorflow

View File

@ -536,7 +536,7 @@ tf_py_test(
],
)
tf_py_test(
cuda_py_test(
name = "in_topk_op_test",
size = "small",
srcs = ["in_topk_op_test.py"],

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
@ -30,7 +29,7 @@ class InTopKTest(test.TestCase):
def _validateInTopK(self, predictions, target, k, expected):
np_ans = np.array(expected)
with self.cached_session():
with self.cached_session(use_gpu=True):
precision = nn_ops.in_top_k(predictions, target, k)
out = self.evaluate(precision)
self.assertAllClose(np_ans, out)
@ -63,12 +62,9 @@ class InTopKTest(test.TestCase):
self._validateInTopK(predictions, target, 2, [False, False])
def testBadTarget(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = [0, 80000]
with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"target.*out of range"):
nn_ops.in_top_k(predictions, target, 2).eval()
predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]]
target = [2, 12345] # must return False for invalid target
self._validateInTopK(predictions, target, 2, [True, False])
def testTensorK(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
@ -81,5 +77,6 @@ class InTopKTest(test.TestCase):
self.assertAllClose(np_ans, out)
self.assertShapeEqual(np_ans, precision)
if __name__ == "__main__":
test.main()