Add InTopKV2 kernel for GPU device
PiperOrigin-RevId: 248184131
Parent: 7be9519dd0
Commit: c75aa2295d
Changed paths: tensorflow/core/kernels, tensorflow/python/kernel_tests
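For orientation (not part of the diff): InTopK/InTopKV2 emit one bool per row of `predictions` (the "precision" output), true when the class indexed by `targets[i]` ranks within the top `k` scores of that row. A hedged sketch of the user-facing call follows, mirroring the Python kernel test at the end of this commit; it assumes the TF1-style `nn_ops.in_top_k(predictions, targets, k)` wrapper used there and is illustrative only.

# Sketch of the call exercised by the Python test below (not part of the commit).
# Run under a session or with eager execution enabled.
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import nn_ops

predictions = [[0.1, 0.3, 0.2, 0.4],
               [0.1, 0.2, 0.3, 0.4]]
targets = [2, 2]

# k passed as a plain Python int.
precision_int_k = nn_ops.in_top_k(predictions, targets, 2)
# k passed as a runtime tensor -- the InTopKV2 path this commit moves to GPU.
precision_tensor_k = nn_ops.in_top_k(predictions, targets, constant_op.constant(2))
# Both produce a length-2 bool vector (the op's "precision" output).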
tensorflow/core/kernels/BUILD

@@ -1741,6 +1741,25 @@ tf_cc_test(
    ],
)

tf_cc_test(
    name = "in_topk_op_test",
    size = "small",
    srcs = ["in_topk_op_test.cc"],
    deps = [
        ":in_topk_op",
        ":ops_testutil",
        ":ops_util",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/stream_executor/cuda:cudnn_plugin",
    ],
)

tf_kernel_library(
    name = "gather_functor",
    prefix = "gather_functor",

@@ -4148,7 +4167,7 @@ tf_kernel_library(
tf_kernel_library(
    name = "in_topk_op",
    prefix = "in_topk_op",
    deps = NN_DEPS,
    deps = NN_DEPS + [":reduction_ops"],
)

tf_kernel_library(

@@ -6061,6 +6080,7 @@ filegroup(
        "dynamic_stitch_op.cc",
        "fft_ops.cc",
        "in_topk_op.cc",
        "in_topk_op.h",
        "initializable_lookup_table.cc",
        "logging_ops.cc",
        "lookup_table_init_op.cc",
tensorflow/core/kernels/in_topk_op.cc

@@ -17,15 +17,18 @@ limitations under the License.

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/kernels/in_topk_op.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

template <typename T, typename TARGET_T>
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename TARGET_T>
class InTopK : public OpKernel {
 public:
  explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {

@@ -37,7 +40,10 @@ class InTopK : public OpKernel {
  void Compute(OpKernelContext* context) override {
    const auto& predictions_in = context->input(0);
    const auto& targets_in = context->input(1);
    int64 k_val = k_;

    int64 k_value = k_;
    const Tensor* k_tensor = nullptr;

    if (context->num_inputs() == 3) {
      const auto& k_in = context->input(2);

@@ -45,11 +51,7 @@ class InTopK : public OpKernel {
                  errors::InvalidArgument("k must be 0-D, got shape ",
                                          k_in.shape().DebugString()));

      if (k_in.dtype() == DT_INT32) {
        k_val = k_in.scalar<int32>()();
      } else {
        k_val = k_in.scalar<int64>()();
      }
      k_tensor = &k_in;
    }

    OP_REQUIRES(context, predictions_in.dims() == 2,

@@ -61,8 +63,9 @@ class InTopK : public OpKernel {
                    predictions_in.dim_size(0),
                    " must match length of targets ",
                    targets_in.dim_size(0)));
    const auto& predictions = predictions_in.matrix<T>();
    const auto& targets = targets_in.vec<TARGET_T>();

    const auto predictions = predictions_in.matrix<T>();
    const auto targets = targets_in.vec<TARGET_T>();

    Tensor* t_out = nullptr;
    OP_REQUIRES_OK(context,

@@ -70,28 +73,11 @@ class InTopK : public OpKernel {
                       0, TensorShape({targets_in.dim_size(0)}), &t_out));
    auto out = t_out->vec<bool>();

    const auto size = targets.size();
    const auto num_classes = predictions.dimension(1);
    for (int b = 0; b < size; b++) {
      auto target = internal::SubtleMustCopy(targets(b));
      OP_REQUIRES(context, FastBoundsCheck(target, num_classes),
                  errors::InvalidArgument("targets[", b, "] is out of range"));
      T target_prediction = predictions(b, target);
      bool cannot_say = !std::isfinite(target_prediction);
      int more_probable_classes = 0;
      if (!cannot_say) {
        for (int i = 0; i < num_classes; ++i) {
          T pred = predictions(b, i);
          if (!std::isfinite(pred)) {
            cannot_say = true;
            break;
          } else if (pred > target_prediction) {
            ++more_probable_classes;
          }
        }
      }
      out(b) = cannot_say ? false : (more_probable_classes < k_val);
    }
    functor::InTopKFunctor<Device, T, TARGET_T> f;
    functor::TopKArg arg;
    arg.k_value = k_value;
    arg.k_tensor = k_tensor;
    f(context, predictions, targets, arg, out);
  }

 private:

@@ -104,14 +90,14 @@ REGISTER_KERNEL_BUILDER(Name("InTopK")
                            .HostMemory("targets")
                            .HostMemory("precision")
                            .TypeConstraint<int32>("T"),
                        InTopK<float, int32>);
                        InTopK<CPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(Name("InTopK")
                            .Device(DEVICE_CPU)
                            .HostMemory("predictions")
                            .HostMemory("targets")
                            .HostMemory("precision")
                            .TypeConstraint<int64>("T"),
                        InTopK<float, int64>);
                        InTopK<CPUDevice, float, int64>);

REGISTER_KERNEL_BUILDER(Name("InTopKV2")
                            .Device(DEVICE_CPU)

@@ -120,7 +106,7 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2")
                            .HostMemory("k")
                            .HostMemory("precision")
                            .TypeConstraint<int32>("T"),
                        InTopK<float, int32>);
                        InTopK<CPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(Name("InTopKV2")
                            .Device(DEVICE_CPU)
                            .HostMemory("predictions")

@@ -128,6 +114,34 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2")
                            .HostMemory("k")
                            .HostMemory("precision")
                            .TypeConstraint<int64>("T"),
                        InTopK<float, int64>);
                        InTopK<CPUDevice, float, int64>);

#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, TARGET_T)                                \
  template <>                                                        \
  void InTopKFunctor<GPUDevice, T, TARGET_T>::operator()(            \
      OpKernelContext* context,                                      \
      typename TTypes<T, 2>::ConstTensor predictions,                \
      typename TTypes<TARGET_T>::ConstVec targets, const TopKArg k,  \
      typename TTypes<bool>::Vec output);                            \
  extern template struct InTopKFunctor<GPUDevice, T, TARGET_T>;

DECLARE_GPU_SPEC(float, int32);
DECLARE_GPU_SPEC(float, int64);

#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNEL_BUILDER(
    Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
    InTopK<GPUDevice, float, int32>);
REGISTER_KERNEL_BUILDER(
    Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64>("T"),
    InTopK<GPUDevice, float, int64>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
tensorflow/core/kernels/in_topk_op.h (new file, 100 lines)

@@ -0,0 +1,100 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// InTopK argument can be passed either via mode attribute (InTopK op), or as an
// input tensor (InTopKV2 op).
struct TopKArg {
  int64 k_value = -1;
  const Tensor* k_tensor = nullptr;
};

template <typename Device, typename T, typename TargetT>
struct InTopKFunctor {
  template <int ndims>
  using Dims = Eigen::DSizes<Eigen::Index, ndims>;

  void operator()(OpKernelContext* context,
                  typename TTypes<T, 2>::ConstTensor predictions,
                  typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
                  typename TTypes<bool>::Vec output) {}
};

template <typename T, typename TargetT>
struct InTopKFunctor<CPUDevice, T, TargetT> {
  void operator()(OpKernelContext* context,
                  typename TTypes<T, 2>::ConstTensor predictions,
                  typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
                  typename TTypes<bool>::Vec output) {
    const Eigen::Index num_targets = predictions.dimension(0);
    const Eigen::Index num_classes = predictions.dimension(1);

    int64 k_val = k.k_value;
    if (k.k_tensor != nullptr) {
      if (k.k_tensor->dtype() == DT_INT32) {
        k_val = k.k_tensor->scalar<int32>()();
      } else {
        k_val = k.k_tensor->scalar<int64>()();
      }
    }

    for (int batch_idx = 0; batch_idx < num_targets; batch_idx++) {
      auto target = internal::SubtleMustCopy(targets(batch_idx));

      bool cannot_say = !FastBoundsCheck(target, num_classes) ||
                        !std::isfinite(predictions(batch_idx, target));

      int more_probable_classes = 0;
      if (!cannot_say) {
        const T target_prediction = predictions(batch_idx, target);

        for (int class_idx = 0; class_idx < num_classes; ++class_idx) {
          T pred = predictions(batch_idx, class_idx);
          if (!std::isfinite(pred)) {
            cannot_say = true;
            break;
          } else if (pred > target_prediction) {
            ++more_probable_classes;
            if (more_probable_classes > k_val) break;
          }
        }
      }
      output(batch_idx) = cannot_say ? false : (more_probable_classes < k_val);
    }
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
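The TopKArg comment above captures the main difference between the two ops: k arrives either as a node attribute (k_value) or as a runtime scalar tensor (k_tensor), and the functor resolves whichever is present. A hedged NumPy sketch of the CPU reference semantics implemented by the specialization above; the function name and structure are illustrative only, not TensorFlow code.

import numpy as np

def in_top_k_reference(predictions, targets, k):
    # Mirrors InTopKFunctor<CPUDevice, ...>: count classes strictly more
    # probable than the target class; an out-of-range target or any
    # non-finite prediction in the row yields False ("cannot_say").
    p = np.asarray(predictions, dtype=np.float32)
    t = np.asarray(targets)
    num_targets, num_classes = p.shape
    out = np.zeros(num_targets, dtype=bool)
    for i in range(num_targets):
        if not (0 <= t[i] < num_classes) or not np.all(np.isfinite(p[i])):
            continue  # cannot_say -> False
        more_probable = np.sum(p[i] > p[i, t[i]])
        out[i] = more_probable < k
    return out

# Edge case exercised by the updated Python test below: out-of-range target.
print(in_top_k_reference([[0.1, 0.3, 0.2, 0.2],
                          [0.1, 0.3, 0.2, 0.2]], [2, 12345], 2))  # [ True False]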
tensorflow/core/kernels/in_topk_op_gpu.cu.cc (new file, 176 lines)

@@ -0,0 +1,176 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/in_topk_op.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Compare each prediction in 'predictions' with a target prediction for the
// batch, and write result to the 'mask':
//  -1: If the target class is out of range, or if the prediction value is not
//      finite and can't be compared to target prediction (and vice versa).
//   0: If prediction is smaller than the target prediction for the batch.
//   1: If prediction is larger than the target prediction for the batch.
template <typename T, typename TargetT>
__global__ void ComputePredictionMaskKernel(
    const T* predictions,    // dims: [ num_targets x num_classes ]
    const TargetT* targets,  // dims: [ num_targets ]
    int64* mask,             // dims: [ num_targets x num_classes ]
    int num_targets, int num_classes) {
  CUDA_1D_KERNEL_LOOP(i, num_targets * num_classes) {
    const int batch_index = i / num_classes;
    TargetT target_idx = ldg(targets + batch_index);

    if (!FastBoundsCheck(target_idx, num_classes)) {
      mask[i] = -1;
      return;
    }

    T prediction = ldg(predictions + i);
    T target_prediction =
        ldg(predictions + batch_index * num_classes + target_idx);

    if (!Eigen::numext::isfinite(prediction) ||
        !Eigen::numext::isfinite(target_prediction)) {
      mask[i] = -1;
    } else {
      mask[i] = prediction > target_prediction ? 1 : 0;
    }
  }
}

// Reduce all prediction masks either to the sum of '1' for each prediction
// larger than the target, or to '-1' if the target class is invalid or
// predictions in a batch have non-finite values.
struct MaskSum {
  __host__ __device__ int64 operator()(const int64& a, const int64& b) const {
    if (a < 0 || b < 0)
      return -1;
    else
      return a + b;
  }
};

namespace reduction_op_helper {
template <>
struct IdentityValue<int64, MaskSum> {
  int64 operator()() { return 0; }
};

}  // namespace reduction_op_helper

template <typename T, typename TargetT>
struct InTopKFunctor<GPUDevice, T, TargetT> {
  template <int ndims>
  using Dims = Eigen::DSizes<Eigen::Index, ndims>;

  void operator()(OpKernelContext* context,
                  typename TTypes<T, 2>::ConstTensor predictions,
                  typename TTypes<TargetT>::ConstVec targets, const TopKArg k,
                  typename TTypes<bool>::Vec output) {
    const Eigen::Index num_targets = predictions.dimension(0);
    const Eigen::Index num_classes = predictions.dimension(1);

    OP_REQUIRES(
        context, num_targets * num_classes < std::numeric_limits<int>::max(),
        errors::InvalidArgument(
            "Number of targets * number of classes must be less than INT_MAX"));

    // Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
    Tensor predictions_mask;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DT_INT64,
                                        TensorShape({num_targets, num_classes}),
                                        &predictions_mask));

    // Number of predictions for each target that are larger than the target
    // prediction (or -1 if we can't compute this number, because not all
    // predictions are finite or target class is out of range).
    Tensor num_larger_prediction;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({num_targets}),
                                          &num_larger_prediction));

    const auto& d = context->eigen_device<GPUDevice>();

    // Compute a mask for all predictions.
    CudaLaunchConfig config = GetCudaLaunchConfig(num_targets * num_classes, d);
    OP_REQUIRES_OK(context, CudaLaunchKernel(
                                ComputePredictionMaskKernel<T, TargetT>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), predictions.data(), targets.data(),
                                predictions_mask.flat<int64>().data(),
                                num_targets, num_classes));

    // Reduce prediction masks to number of predictions larger than the target
    // prediction, or to the negative value if we can't compute an answer.
    {
      auto in = predictions_mask.matrix<int64>();
      auto out = num_larger_prediction.flat<int64>();

      ReduceImpl<int64, MaskSum, int64*, int64*, Dims<1>>(
          context, (int64*)out.data(), (int64*)in.data(), in.rank(),
          in.dimension(0), in.rank() >= 2 ? in.dimension(1) : 1,
          in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), Dims<1>(1),
          MaskSum());
    }

    // Compute if target prediction is in top K predictions.
    auto cnt = num_larger_prediction.flat<int64>();

    if (k.k_tensor != nullptr) {
      if (k.k_tensor->dtype() == DT_INT32) {
        output.device(d) =
            (cnt >= cnt.constant(0)) &&
            (cnt < k.k_tensor->flat<int32>().template cast<int64>().broadcast(
                       Dims<1>(num_targets)));
      } else {
        output.device(d) =
            (cnt >= cnt.constant(0)) &&
            (cnt < k.k_tensor->flat<int64>().broadcast(Dims<1>(num_targets)));
      }
    } else {
      output.device(d) =
          (cnt >= cnt.constant(0)) && (cnt < targets.constant(k.k_value));
    }
  }
};

}  // namespace functor

// Definition of the GPU implementations declared in in_topk_op.cc.
#define DEFINE_GPU_KERNELS(T, TARGET_T) \
  template struct functor::InTopKFunctor<GPUDevice, T, TARGET_T>;

DEFINE_GPU_KERNELS(float, int32);
DEFINE_GPU_KERNELS(float, int64);

#undef DEFINE_GPU_KERNELS

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA
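The GPU path above avoids a per-row host loop: a CUDA kernel writes a {-1, 0, 1} mask per (target, class) pair, a row reduction with MaskSum collapses it to the number of strictly larger classes (propagating -1 for invalid rows), and an Eigen expression compares that count with k. A hedged NumPy emulation of that three-stage pipeline, illustrative only:

import numpy as np

def in_top_k_gpu_style(predictions, targets, k):
    # Stage 1: per-element mask, as in ComputePredictionMaskKernel:
    #   -1 for invalid target or non-finite value, 1 if strictly larger
    #   than the target score, else 0.
    p = np.asarray(predictions, dtype=np.float32)
    t = np.asarray(targets)
    n, c = p.shape
    valid_target = (t >= 0) & (t < c)
    tgt = p[np.arange(n), np.where(valid_target, t, 0)]
    mask = (p > tgt[:, None]).astype(np.int64)
    mask[~np.isfinite(p) | ~np.isfinite(tgt)[:, None] | ~valid_target[:, None]] = -1
    # Stage 2: MaskSum-style row reduction; any -1 poisons the row's sum.
    summed = np.where((mask < 0).any(axis=1), -1, mask.sum(axis=1))
    # Stage 3: in top-k iff the count is non-negative and below k.
    return (summed >= 0) & (summed < k)

print(in_top_k_gpu_style([[0.1, 0.3, 0.2, 0.4],
                          [0.1, 0.2, 0.3, 0.4]], [2, 2], 2))  # [False  True]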
tensorflow/core/kernels/in_topk_op_test.cc (new file, 84 lines)

@@ -0,0 +1,84 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

template <typename T>
static Graph* InTopK(int num_targets, int num_classes, T top_k) {
  Graph* g = new Graph(OpRegistry::Global());

  DataType dtype = DataTypeToEnum<T>::value;

  Tensor predictions_t(DT_FLOAT, TensorShape({num_targets, num_classes}));
  predictions_t.flat<float>().setRandom();

  Tensor targets_t(dtype, TensorShape({num_targets}));
  targets_t.flat<T>().setRandom();

  Tensor k_t(dtype, TensorShape({}));
  k_t.scalar<T>() = k_t.scalar<T>().constant(top_k);

  Node* predictions = test::graph::Constant(g, predictions_t, "predictions");
  Node* targets = test::graph::Constant(g, targets_t, "targets");
  Node* k = test::graph::Constant(g, k_t, "k");

  Node* in_topk;
  TF_CHECK_OK(NodeBuilder(g->NewName("in_topk"), "InTopKV2")
                  .Input(predictions)
                  .Input(targets)
                  .Input(k)
                  .Attr("T", dtype)
                  .Finalize(g, &in_topk));

  return g;
}

#define BM_NAME(T, TARGETS, CLASSES, K, DEVICE) \
  BM_InTopK##_##T##_##TARGETS##_##CLASSES##_##K##_##DEVICE

#define BM_InTopK(T, TARGETS, CLASSES, K, DEVICE)                            \
  static void BM_NAME(T, TARGETS, CLASSES, K, DEVICE)(int iters) {           \
    testing::UseRealTime();                                                  \
    testing::ItemsProcessed(static_cast<int64>(iters) * TARGETS * CLASSES);  \
    test::Benchmark(#DEVICE, InTopK<T>(TARGETS, CLASSES, K)).Run(iters);     \
  }                                                                          \
  BENCHMARK(BM_NAME(T, TARGETS, CLASSES, K, DEVICE));

BM_InTopK(int64, 64, 1000, 10, cpu);
BM_InTopK(int64, 64, 10000, 10, cpu);

#ifdef GOOGLE_CUDA
BM_InTopK(int64, 64, 1000, 10, gpu);
BM_InTopK(int64, 64, 10000, 10, gpu);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow
tensorflow/python/kernel_tests/BUILD

@@ -536,7 +536,7 @@ tf_py_test(
    ],
)

tf_py_test(
cuda_py_test(
    name = "in_topk_op_test",
    size = "small",
    srcs = ["in_topk_op_test.py"],
tensorflow/python/kernel_tests/in_topk_op_test.py

@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test

@@ -30,7 +29,7 @@ class InTopKTest(test.TestCase):

  def _validateInTopK(self, predictions, target, k, expected):
    np_ans = np.array(expected)
    with self.cached_session():
    with self.cached_session(use_gpu=True):
      precision = nn_ops.in_top_k(predictions, target, k)
      out = self.evaluate(precision)
      self.assertAllClose(np_ans, out)

@@ -63,12 +62,9 @@ class InTopKTest(test.TestCase):
    self._validateInTopK(predictions, target, 2, [False, False])

  def testBadTarget(self):
    predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
    target = [0, 80000]
    with self.cached_session():
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "target.*out of range"):
        nn_ops.in_top_k(predictions, target, 2).eval()
    predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]]
    target = [2, 12345]  # must return False for invalid target
    self._validateInTopK(predictions, target, 2, [True, False])

  def testTensorK(self):
    predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]

@@ -81,5 +77,6 @@ class InTopKTest(test.TestCase):
      self.assertAllClose(np_ans, out)
      self.assertShapeEqual(np_ans, precision)


if __name__ == "__main__":
  test.main()