Support placeholder for parameter k in tf.nn.in_top_k (#11197)
* Support placeholder for parameter k in tf.nn.in_top_k This fix tries to address the issue raised in #9717 where it was not possible to have tensor for k in nn.in_top_k. This fix adds the implementation of InTopKV2Op, adds addition test cases, and following similiar workflow in #10840: 1. Register new kennel InTopKV2Op 2. Hide InTopK and InTopKV2 in python (tensorflow/python/ops/hidden_ops.txt) 3. Add a wrapper in_top_k (in tensorflow/python/ops/nn_ops.py) pointing to gen_nn_ops._in_top_k Another PR will be created after 3 weeks once this PR is merged: 1. Change the implementation of the wrapper in_top_k (in tensorflow/python/ops/nn_ops.py) pointing to gen_nn_ops._in_top_kv2 Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Address review comments Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add HostMemory to InTopK kernel Add HostMemory to InTopK kernel based on the review feedback Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
bad0e44c17
commit
4d5ff2ee95
@ -17,11 +17,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/bounds_check.h"
|
#include "tensorflow/core/kernels/bounds_check.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -29,12 +29,29 @@ template <typename T, typename TARGET_T>
|
|||||||
class InTopK : public OpKernel {
|
class InTopK : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
|
explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
|
if (context->num_inputs() == 2) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
const auto& predictions_in = context->input(0);
|
const auto& predictions_in = context->input(0);
|
||||||
const auto& targets_in = context->input(1);
|
const auto& targets_in = context->input(1);
|
||||||
|
int64 k_val = k_;
|
||||||
|
if (context->num_inputs() == 3) {
|
||||||
|
const auto& k_in = context->input(2);
|
||||||
|
|
||||||
|
OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
|
||||||
|
errors::InvalidArgument("k must be 0-D, got shape ",
|
||||||
|
k_in.shape().DebugString()));
|
||||||
|
|
||||||
|
if (k_in.dtype() == DT_INT32) {
|
||||||
|
k_val = k_in.scalar<int32>()();
|
||||||
|
} else {
|
||||||
|
k_val = k_in.scalar<int64>()();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
OP_REQUIRES(context, predictions_in.dims() == 2,
|
OP_REQUIRES(context, predictions_in.dims() == 2,
|
||||||
errors::InvalidArgument("predictions must be 2-dimensional"));
|
errors::InvalidArgument("predictions must be 2-dimensional"));
|
||||||
OP_REQUIRES(context, targets_in.dims() == 1,
|
OP_REQUIRES(context, targets_in.dims() == 1,
|
||||||
@ -73,7 +90,7 @@ class InTopK : public OpKernel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out(b) = cannot_say ? false : (more_probable_classes < k_);
|
out(b) = cannot_say ? false : (more_probable_classes < k_val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,10 +99,35 @@ class InTopK : public OpKernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
|
Name("InTopK").Device(DEVICE_CPU)
|
||||||
|
.HostMemory("predictions")
|
||||||
|
.HostMemory("targets")
|
||||||
|
.HostMemory("precision")
|
||||||
|
.TypeConstraint<int32>("T"),
|
||||||
InTopK<float, int32>);
|
InTopK<float, int32>);
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int64>("T"),
|
Name("InTopK").Device(DEVICE_CPU)
|
||||||
|
.HostMemory("predictions")
|
||||||
|
.HostMemory("targets")
|
||||||
|
.HostMemory("precision")
|
||||||
|
.TypeConstraint<int64>("T"),
|
||||||
|
InTopK<float, int64>);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("InTopKV2").Device(DEVICE_CPU)
|
||||||
|
.HostMemory("predictions")
|
||||||
|
.HostMemory("targets")
|
||||||
|
.HostMemory("k")
|
||||||
|
.HostMemory("precision")
|
||||||
|
.TypeConstraint<int32>("T"),
|
||||||
|
InTopK<float, int32>);
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("InTopKV2").Device(DEVICE_CPU)
|
||||||
|
.HostMemory("predictions")
|
||||||
|
.HostMemory("targets")
|
||||||
|
.HostMemory("k")
|
||||||
|
.HostMemory("precision")
|
||||||
|
.TypeConstraint<int64>("T"),
|
||||||
InTopK<float, int64>);
|
InTopK<float, int64>);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -1979,6 +1979,49 @@ precision: Computed Precision at `k` as a `bool Tensor`.
|
|||||||
|
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
// This is the same as `InTopK`, but takes `k` as in input rather than an attr.
|
||||||
|
REGISTER_OP("InTopKV2")
|
||||||
|
.Input("predictions: float")
|
||||||
|
.Input("targets: T")
|
||||||
|
.Input("k: T")
|
||||||
|
.Output("precision: bool")
|
||||||
|
.Attr("T: {int32, int64} = DT_INT32")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeHandle predictions;
|
||||||
|
ShapeHandle targets;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
|
||||||
|
DimensionHandle batch_size;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
|
||||||
|
c->set_output(0, c->Vector(batch_size));
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.Doc(R"doc(
|
||||||
|
Says whether the targets are in the top `K` predictions.
|
||||||
|
|
||||||
|
This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
|
||||||
|
prediction for the target class is among the top `k` predictions among
|
||||||
|
all predictions for example `i`. Note that the behavior of `InTopK` differs
|
||||||
|
from the `TopK` op in its handling of ties; if multiple classes have the
|
||||||
|
same prediction value and straddle the top-`k` boundary, all of those
|
||||||
|
classes are considered to be in the top `k`.
|
||||||
|
|
||||||
|
More formally, let
|
||||||
|
|
||||||
|
\\(predictions_i\\) be the predictions for all classes for example `i`,
|
||||||
|
\\(targets_i\\) be the target class for example `i`,
|
||||||
|
\\(out_i\\) be the output for example `i`,
|
||||||
|
|
||||||
|
$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
|
||||||
|
|
||||||
|
predictions: A `batch_size` x `classes` tensor.
|
||||||
|
targets: A `batch_size` vector of class ids.
|
||||||
|
k: Number of top elements to look at for computing precision.
|
||||||
|
precision: Computed precision at `k` as a `bool Tensor`.
|
||||||
|
|
||||||
|
)doc");
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
Status TopKShapeFn(InferenceContext* c) {
|
Status TopKShapeFn(InferenceContext* c) {
|
||||||
|
@ -20,7 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.ops import gen_nn_ops
|
||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -69,6 +71,18 @@ class InTopKTest(test.TestCase):
|
|||||||
"target.*out of range"):
|
"target.*out of range"):
|
||||||
nn_ops.in_top_k(predictions, target, 2).eval()
|
nn_ops.in_top_k(predictions, target, 2).eval()
|
||||||
|
|
||||||
|
def testTensorK(self):
|
||||||
|
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||||
|
target = [0, 2]
|
||||||
|
k = constant_op.constant(3)
|
||||||
|
np_ans = np.array([False, True])
|
||||||
|
with self.test_session():
|
||||||
|
# TODO (yongtang): The test will be switch to nn_ops.in_top
|
||||||
|
# once nn_ops.in_top points to _in_top_kv2 later
|
||||||
|
precision = gen_nn_ops._in_top_kv2(predictions, target, k)
|
||||||
|
out = precision.eval()
|
||||||
|
self.assertAllClose(np_ans, out)
|
||||||
|
self.assertShapeEqual(np_ans, precision)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -302,6 +302,8 @@ Softmax
|
|||||||
LogSoftmax
|
LogSoftmax
|
||||||
FractionalAvgPoolGrad
|
FractionalAvgPoolGrad
|
||||||
FractionalMaxPoolGrad
|
FractionalMaxPoolGrad
|
||||||
|
InTopK
|
||||||
|
InTopKV2
|
||||||
|
|
||||||
# parsing_ops
|
# parsing_ops
|
||||||
ParseExample
|
ParseExample
|
||||||
|
@ -2085,3 +2085,36 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
|
|||||||
rates=rates,
|
rates=rates,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
name=name))
|
name=name))
|
||||||
|
|
||||||
|
def in_top_k(predictions, targets, k, name=None):
|
||||||
|
r"""Says whether the targets are in the top `K` predictions.
|
||||||
|
|
||||||
|
This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
|
||||||
|
prediction for the target class is among the top `k` predictions among
|
||||||
|
all predictions for example `i`. Note that the behavior of `InTopK` differs
|
||||||
|
from the `TopK` op in its handling of ties; if multiple classes have the
|
||||||
|
same prediction value and straddle the top-`k` boundary, all of those
|
||||||
|
classes are considered to be in the top `k`.
|
||||||
|
|
||||||
|
More formally, let
|
||||||
|
|
||||||
|
\\(predictions_i\\) be the predictions for all classes for example `i`,
|
||||||
|
\\(targets_i\\) be the target class for example `i`,
|
||||||
|
\\(out_i\\) be the output for example `i`,
|
||||||
|
|
||||||
|
$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predictions: A `Tensor` of type `float32`.
|
||||||
|
A `batch_size` x `classes` tensor.
|
||||||
|
targets: A `Tensor`. Must be one of the following types: `int32`, `int64`.
|
||||||
|
A `batch_size` vector of class ids.
|
||||||
|
k: An `int`. Number of top elements to look at for computing precision.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` of type `bool`. Computed Precision at `k` as a `bool Tensor`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, 'in_top_k'):
|
||||||
|
# TODO (yongtang): Need to switch to v2 after 3 weeks.
|
||||||
|
return gen_nn_ops._in_top_kv2(predictions, targets, k, name=name)
|
||||||
|
Loading…
Reference in New Issue
Block a user