From 4d5ff2ee958d789b45535ba5dad0ae837b730d13 Mon Sep 17 00:00:00 2001
From: Yong Tang
Date: Wed, 19 Jul 2017 14:34:03 -0700
Subject: [PATCH] Support placeholder for parameter k in tf.nn.in_top_k (#11197)

* Support placeholder for parameter k in tf.nn.in_top_k

This fix addresses the issue raised in #9717, where it was not possible
to pass a tensor for k in nn.in_top_k.

This fix adds the implementation of InTopKV2Op, adds additional test
cases, and follows a workflow similar to #10840:
1. Register the new kernel InTopKV2Op
2. Hide InTopK and InTopKV2 in python (tensorflow/python/ops/hidden_ops.txt)
3. Add a wrapper in_top_k (in tensorflow/python/ops/nn_ops.py) pointing
   to gen_nn_ops._in_top_k

Another PR will be created after 3 weeks once this PR is merged:
1. Change the implementation of the wrapper in_top_k (in
   tensorflow/python/ops/nn_ops.py) so that it points to
   gen_nn_ops._in_top_kv2

Signed-off-by: Yong Tang

* Address review comments

Signed-off-by: Yong Tang

* Add HostMemory to InTopK kernel

Add HostMemory to the InTopK kernel based on the review feedback.

Signed-off-by: Yong Tang
---
 tensorflow/core/kernels/in_topk_op.cc              | 52 ++++++++++++++++++---
 tensorflow/core/ops/nn_ops.cc                      | 43 +++++++++++++++
 .../python/kernel_tests/in_topk_op_test.py         | 14 +++++
 tensorflow/python/ops/hidden_ops.txt               |  2 +
 tensorflow/python/ops/nn_ops.py                    | 33 ++++++++++++
 5 files changed, 139 insertions(+), 5 deletions(-)

diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc
index 13890e5b7ff..e2861ae090c 100644
--- a/tensorflow/core/kernels/in_topk_op.cc
+++ b/tensorflow/core/kernels/in_topk_op.cc
@@ -17,11 +17,11 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/bounds_check.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 
@@ -29,12 +29,29 @@ template <typename T, typename TARGET_T>
 class InTopK : public OpKernel {
  public:
   explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+    if (context->num_inputs() == 2) {
+      OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+    }
   }
 
   void Compute(OpKernelContext* context) override {
     const auto& predictions_in = context->input(0);
     const auto& targets_in = context->input(1);
+    int64 k_val = k_;
+    if (context->num_inputs() == 3) {
+      const auto& k_in = context->input(2);
+
+      OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
+                  errors::InvalidArgument("k must be 0-D, got shape ",
+                                          k_in.shape().DebugString()));
+
+      if (k_in.dtype() == DT_INT32) {
+        k_val = k_in.scalar<int32>()();
+      } else {
+        k_val = k_in.scalar<int64>()();
+      }
+    }
+
     OP_REQUIRES(context, predictions_in.dims() == 2,
                 errors::InvalidArgument("predictions must be 2-dimensional"));
     OP_REQUIRES(context, targets_in.dims() == 1,
@@ -73,7 +90,7 @@ class InTopK : public OpKernel {
         }
       }
     }
-      out(b) = cannot_say ? false : (more_probable_classes < k_);
+      out(b) = cannot_say ? false : (more_probable_classes < k_val);
     }
   }
 
@@ -82,10 +99,35 @@
 };
 
 REGISTER_KERNEL_BUILDER(
-    Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
+    Name("InTopK").Device(DEVICE_CPU)
+        .HostMemory("predictions")
+        .HostMemory("targets")
+        .HostMemory("precision")
+        .TypeConstraint<int32>("T"),
     InTopK<float, int32>);
 REGISTER_KERNEL_BUILDER(
-    Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int64>("T"),
+    Name("InTopK").Device(DEVICE_CPU)
+        .HostMemory("predictions")
+        .HostMemory("targets")
+        .HostMemory("precision")
+        .TypeConstraint<int64>("T"),
     InTopK<float, int64>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("InTopKV2").Device(DEVICE_CPU)
+        .HostMemory("predictions")
+        .HostMemory("targets")
+        .HostMemory("k")
+        .HostMemory("precision")
+        .TypeConstraint<int32>("T"),
+    InTopK<float, int32>);
+REGISTER_KERNEL_BUILDER(
+    Name("InTopKV2").Device(DEVICE_CPU)
+        .HostMemory("predictions")
+        .HostMemory("targets")
+        .HostMemory("k")
+        .HostMemory("precision")
+        .TypeConstraint<int64>("T"),
+    InTopK<float, int64>);
 
 }  // namespace tensorflow
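A brief usage sketch (not part of the patch): with the kernel change above, `k`
may arrive as either an int32 or an int64 scalar tensor, matching the
DT_INT32/DT_INT64 branch in Compute(). The snippet drives the generated op
directly, the same way the new unit test does, and assumes the TF 1.x Python
API of this era.

    # Illustration only: `k` supplied as an int32 or int64 scalar tensor.
    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops

    predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]

    with tf.Session() as sess:
      # T = int32: int32 targets with an int32 scalar k.
      out32 = gen_nn_ops._in_top_kv2(predictions,
                                     tf.constant([0, 2], dtype=tf.int32),
                                     tf.constant(3, dtype=tf.int32))
      # T = int64: int64 targets with an int64 scalar k.
      out64 = gen_nn_ops._in_top_kv2(predictions,
                                     tf.constant([0, 2], dtype=tf.int64),
                                     tf.constant(3, dtype=tf.int64))
      # Both evaluate to [False  True].
      print(sess.run([out32, out64]))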
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 3a25fd15daa..555b97f53b3 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1979,6 +1979,49 @@
 precision: Computed Precision at `k` as a `bool Tensor`.
 )doc");
 
+// This is the same as `InTopK`, but takes `k` as an input rather than an attr.
+REGISTER_OP("InTopKV2")
+    .Input("predictions: float")
+    .Input("targets: T")
+    .Input("k: T")
+    .Output("precision: bool")
+    .Attr("T: {int32, int64} = DT_INT32")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle predictions;
+      ShapeHandle targets;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
+      DimensionHandle batch_size;
+      TF_RETURN_IF_ERROR(
+          c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
+      c->set_output(0, c->Vector(batch_size));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Says whether the targets are in the top `K` predictions.
+
+This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
+prediction for the target class is among the top `k` predictions among
+all predictions for example `i`. Note that the behavior of `InTopK` differs
+from the `TopK` op in its handling of ties; if multiple classes have the
+same prediction value and straddle the top-`k` boundary, all of those
+classes are considered to be in the top `k`.
+
+More formally, let
+
+  \\(predictions_i\\) be the predictions for all classes for example `i`,
+  \\(targets_i\\) be the target class for example `i`,
+  \\(out_i\\) be the output for example `i`,
+
+$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
+
+predictions: A `batch_size` x `classes` tensor.
+targets: A `batch_size` vector of class ids.
+k: Number of top elements to look at for computing precision.
+precision: Computed precision at `k` as a `bool Tensor`.
+
+)doc");
+
 namespace {
 
 Status TopKShapeFn(InferenceContext* c) {
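As a side note (not part of the patch), the tie rule spelled out in the doc
above can be observed with the existing static-`k` API; a minimal sketch,
assuming the TF 1.x API:

    # Illustration only: ties that straddle the top-k boundary all count.
    import tensorflow as tf

    predictions = [[0.4, 0.3, 0.3, 0.2]]  # classes 1 and 2 tie at 0.3
    with tf.Session() as sess:
      # With k=2 the tie straddles the boundary, so classes 0, 1 and 2 all
      # count as "in the top 2"; class 3 does not.
      print(sess.run(tf.nn.in_top_k(predictions, [1], 2)))  # [ True]
      print(sess.run(tf.nn.in_top_k(predictions, [2], 2)))  # [ True]
      print(sess.run(tf.nn.in_top_k(predictions, [3], 2)))  # [False]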
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
index 4a4686d1b99..37e9a8e3d1b 100644
--- a/tensorflow/python/kernel_tests/in_topk_op_test.py
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -20,7 +20,9 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import test
 
@@ -69,6 +71,18 @@ class InTopKTest(test.TestCase):
                                    "target.*out of range"):
         nn_ops.in_top_k(predictions, target, 2).eval()
 
+  def testTensorK(self):
+    predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+    target = [0, 2]
+    k = constant_op.constant(3)
+    np_ans = np.array([False, True])
+    with self.test_session():
+      # TODO (yongtang): The test will be switched to nn_ops.in_top_k
+      # once nn_ops.in_top_k points to _in_top_kv2 later.
+      precision = gen_nn_ops._in_top_kv2(predictions, target, k)
+      out = precision.eval()
+      self.assertAllClose(np_ans, out)
+      self.assertShapeEqual(np_ans, precision)
 
 if __name__ == "__main__":
   test.main()

diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 28817e99e8c..81e0852e62c 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -302,6 +302,8 @@ Softmax
 LogSoftmax
 FractionalAvgPoolGrad
 FractionalMaxPoolGrad
+InTopK
+InTopKV2
 
 # parsing_ops
 ParseExample

diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 44f98d64cde..1ce7ea179f1 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2085,3 +2085,36 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
                               rates=rates,
                               padding=padding,
                               name=name))
+
+def in_top_k(predictions, targets, k, name=None):
+  r"""Says whether the targets are in the top `K` predictions.
+
+  This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
+  prediction for the target class is among the top `k` predictions among
+  all predictions for example `i`. Note that the behavior of `InTopK` differs
+  from the `TopK` op in its handling of ties; if multiple classes have the
+  same prediction value and straddle the top-`k` boundary, all of those
+  classes are considered to be in the top `k`.
+
+  More formally, let
+
+    \\(predictions_i\\) be the predictions for all classes for example `i`,
+    \\(targets_i\\) be the target class for example `i`,
+    \\(out_i\\) be the output for example `i`,
+
+  $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
+
+  Args:
+    predictions: A `Tensor` of type `float32`.
+      A `batch_size` x `classes` tensor.
+    targets: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+      A `batch_size` vector of class ids.
+    k: An `int`. Number of top elements to look at for computing precision.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor` of type `bool`. Computed Precision at `k` as a `bool Tensor`.
+  """
+  with ops.name_scope(name, 'in_top_k'):
+    # TODO (yongtang): Need to switch to v2 after 3 weeks.
+    return gen_nn_ops._in_top_k(predictions, targets, k, name=name)
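Finally, a usage sketch (not part of the patch) of what this change enables end
to end: `k` supplied through a placeholder and fed at run time. It calls the
generated `_in_top_kv2` directly; once the wrapper above is switched over,
`tf.nn.in_top_k` can be used instead. Assumes the TF 1.x API.

    # Illustration only: a placeholder for k, fed at session run time.
    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops

    predictions = tf.placeholder(tf.float32, shape=[None, 4])
    targets = tf.placeholder(tf.int32, shape=[None])
    k = tf.placeholder(tf.int32, shape=[])  # must be a scalar (0-D)

    precision = gen_nn_ops._in_top_kv2(predictions, targets, k)

    with tf.Session() as sess:
      out = sess.run(precision,
                     feed_dict={predictions: [[0.1, 0.3, 0.2, 0.4],
                                              [0.1, 0.2, 0.3, 0.4]],
                                targets: [0, 2],
                                k: 3})
      print(out)  # [False  True]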