Support placeholder for parameter k in tf.nn.in_top_k (#11197)

* Support placeholder for parameter k in tf.nn.in_top_k

This fix tries to address the issue raised in #9717 where
it was not possible to have tensor for k in nn.in_top_k.

This fix adds the implementation of InTopKV2Op, adds addition test
cases, and following similiar workflow in #10840:
1. Register new kennel InTopKV2Op
2. Hide InTopK and InTopKV2 in python (tensorflow/python/ops/hidden_ops.txt)
3. Add a wrapper in_top_k (in tensorflow/python/ops/nn_ops.py) pointing to gen_nn_ops._in_top_k

Another PR will be created after 3 weeks once this PR is merged:
1. Change the implementation of the wrapper in_top_k (in tensorflow/python/ops/nn_ops.py)
pointing to gen_nn_ops._in_top_kv2

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Address review comments

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>

* Add HostMemory to InTopK kernel

Add HostMemory to InTopK kernel based on the review feedback

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2017-07-19 14:34:03 -07:00 committed by drpngx
parent bad0e44c17
commit 4d5ff2ee95
5 changed files with 139 additions and 5 deletions

View File

@ -17,11 +17,11 @@ limitations under the License.
#define EIGEN_USE_THREADS #define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/bounds_check.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow { namespace tensorflow {
@ -29,12 +29,29 @@ template <typename T, typename TARGET_T>
class InTopK : public OpKernel { class InTopK : public OpKernel {
public: public:
explicit InTopK(OpKernelConstruction* context) : OpKernel(context) { explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); if (context->num_inputs() == 2) {
OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
}
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
const auto& predictions_in = context->input(0); const auto& predictions_in = context->input(0);
const auto& targets_in = context->input(1); const auto& targets_in = context->input(1);
int64 k_val = k_;
if (context->num_inputs() == 3) {
const auto& k_in = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
errors::InvalidArgument("k must be 0-D, got shape ",
k_in.shape().DebugString()));
if (k_in.dtype() == DT_INT32) {
k_val = k_in.scalar<int32>()();
} else {
k_val = k_in.scalar<int64>()();
}
}
OP_REQUIRES(context, predictions_in.dims() == 2, OP_REQUIRES(context, predictions_in.dims() == 2,
errors::InvalidArgument("predictions must be 2-dimensional")); errors::InvalidArgument("predictions must be 2-dimensional"));
OP_REQUIRES(context, targets_in.dims() == 1, OP_REQUIRES(context, targets_in.dims() == 1,
@ -73,7 +90,7 @@ class InTopK : public OpKernel {
} }
} }
} }
out(b) = cannot_say ? false : (more_probable_classes < k_); out(b) = cannot_say ? false : (more_probable_classes < k_val);
} }
} }
@ -82,10 +99,35 @@ class InTopK : public OpKernel {
}; };
REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER(
Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int32>("T"), Name("InTopK").Device(DEVICE_CPU)
.HostMemory("predictions")
.HostMemory("targets")
.HostMemory("precision")
.TypeConstraint<int32>("T"),
InTopK<float, int32>); InTopK<float, int32>);
REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER(
Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int64>("T"), Name("InTopK").Device(DEVICE_CPU)
.HostMemory("predictions")
.HostMemory("targets")
.HostMemory("precision")
.TypeConstraint<int64>("T"),
InTopK<float, int64>);
REGISTER_KERNEL_BUILDER(
Name("InTopKV2").Device(DEVICE_CPU)
.HostMemory("predictions")
.HostMemory("targets")
.HostMemory("k")
.HostMemory("precision")
.TypeConstraint<int32>("T"),
InTopK<float, int32>);
REGISTER_KERNEL_BUILDER(
Name("InTopKV2").Device(DEVICE_CPU)
.HostMemory("predictions")
.HostMemory("targets")
.HostMemory("k")
.HostMemory("precision")
.TypeConstraint<int64>("T"),
InTopK<float, int64>); InTopK<float, int64>);
} // namespace tensorflow } // namespace tensorflow

View File

@ -1979,6 +1979,49 @@ precision: Computed Precision at `k` as a `bool Tensor`.
)doc"); )doc");
// This is the same as `InTopK`, but takes `k` as in input rather than an attr.
REGISTER_OP("InTopKV2")
.Input("predictions: float")
.Input("targets: T")
.Input("k: T")
.Output("precision: bool")
.Attr("T: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle predictions;
ShapeHandle targets;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
DimensionHandle batch_size;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
c->set_output(0, c->Vector(batch_size));
return Status::OK();
})
.Doc(R"doc(
Says whether the targets are in the top `K` predictions.
This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
prediction for the target class is among the top `k` predictions among
all predictions for example `i`. Note that the behavior of `InTopK` differs
from the `TopK` op in its handling of ties; if multiple classes have the
same prediction value and straddle the top-`k` boundary, all of those
classes are considered to be in the top `k`.
More formally, let
\\(predictions_i\\) be the predictions for all classes for example `i`,
\\(targets_i\\) be the target class for example `i`,
\\(out_i\\) be the output for example `i`,
$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
predictions: A `batch_size` x `classes` tensor.
targets: A `batch_size` vector of class ids.
k: Number of top elements to look at for computing precision.
precision: Computed precision at `k` as a `bool Tensor`.
)doc");
namespace { namespace {
Status TopKShapeFn(InferenceContext* c) { Status TopKShapeFn(InferenceContext* c) {

View File

@ -20,7 +20,9 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -69,6 +71,18 @@ class InTopKTest(test.TestCase):
"target.*out of range"): "target.*out of range"):
nn_ops.in_top_k(predictions, target, 2).eval() nn_ops.in_top_k(predictions, target, 2).eval()
def testTensorK(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = [0, 2]
k = constant_op.constant(3)
np_ans = np.array([False, True])
with self.test_session():
# TODO (yongtang): The test will be switch to nn_ops.in_top
# once nn_ops.in_top points to _in_top_kv2 later
precision = gen_nn_ops._in_top_kv2(predictions, target, k)
out = precision.eval()
self.assertAllClose(np_ans, out)
self.assertShapeEqual(np_ans, precision)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -302,6 +302,8 @@ Softmax
LogSoftmax LogSoftmax
FractionalAvgPoolGrad FractionalAvgPoolGrad
FractionalMaxPoolGrad FractionalMaxPoolGrad
InTopK
InTopKV2
# parsing_ops # parsing_ops
ParseExample ParseExample

View File

@ -2085,3 +2085,36 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
rates=rates, rates=rates,
padding=padding, padding=padding,
name=name)) name=name))
def in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.
This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
prediction for the target class is among the top `k` predictions among
all predictions for example `i`. Note that the behavior of `InTopK` differs
from the `TopK` op in its handling of ties; if multiple classes have the
same prediction value and straddle the top-`k` boundary, all of those
classes are considered to be in the top `k`.
More formally, let
\\(predictions_i\\) be the predictions for all classes for example `i`,
\\(targets_i\\) be the target class for example `i`,
\\(out_i\\) be the output for example `i`,
$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
Args:
predictions: A `Tensor` of type `float32`.
A `batch_size` x `classes` tensor.
targets: A `Tensor`. Must be one of the following types: `int32`, `int64`.
A `batch_size` vector of class ids.
k: An `int`. Number of top elements to look at for computing precision.
name: A name for the operation (optional).
Returns:
A `Tensor` of type `bool`. Computed Precision at `k` as a `bool Tensor`.
"""
with ops.name_scope(name, 'in_top_k'):
# TODO (yongtang): Need to switch to v2 after 3 weeks.
return gen_nn_ops._in_top_kv2(predictions, targets, k, name=name)