[TF:XLA] Fix tf2xla's in_topk lowering and update tests.

- Fix an off-by-1 comparison error.
- Model Tensorflow's in_topk behavior for non-finite values.
- Update Tensorflow's documentation to indicate behavior on non-finite values.
- Enable TF:Classic's in_topk tests to run on XLA.
- Correct test cases to check for off-by-1 errors and non-finite values.

PiperOrigin-RevId: 248763768
This commit is contained in:
A. Unique TensorFlower 2019-05-17 12:05:52 -07:00 committed by TensorFlower Gardener
parent 63eaabaeeb
commit 335915673e
4 changed files with 16 additions and 14 deletions

View File

@@ -81,20 +81,21 @@ class InTopKOp : public XlaOpKernel {
xla::CreateScalarAddComputation(xla::F32, xla_builder), {1});
// Calculate in each row of `predictions`, how many values are larger than
// the value of target class. Then return the result whether the count <= k,
// the value of target class. Then return the result whether the count < k,
// which indicates the target is in topk.
xla::XlaOp ge_r2 = xla::Ge(predictions_r2, targets_values_r1, {0});
xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0});
xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32);
xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes());
xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32);
xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes());
xla::XlaOp one_hot_r2 = xla::Select(ge_r2, one_r2, zero_r2);
xla::XlaOp num_ge_r1 = xla::Reduce(
xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2);
xla::XlaOp num_gt_r1 = xla::Reduce(
one_hot_r2, zero_r0,
xla::CreateScalarAddComputation(xla::S32, xla_builder), {1});
xla::XlaOp result =
xla::Le(num_ge_r1, xla::ConstantR0<int32>(xla_builder, k));
xla::And(xla::Lt(num_gt_r1, xla::ConstantR0<int32>(xla_builder, k)),
xla::IsFinite(targets_values_r1));
context->SetOutput(0, result);
}

View File

@@ -546,6 +546,7 @@ cuda_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:nn_ops",
],
xla_enable_strict_auto_jit = True,
)
tf_py_test(

View File

@@ -37,12 +37,12 @@ class InTopKTest(test.TestCase):
def testInTop1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = [3, 1]
target = [3, 2]
self._validateInTopK(predictions, target, 1, [True, False])
def testInTop2(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = [0, 2]
target = [2, 2]
self._validateInTopK(predictions, target, 2, [False, True])
def testInTop2Tie(self):
@@ -58,12 +58,12 @@ class InTopKTest(test.TestCase):
def testInTopNan(self):
predictions = [[0.1, float("nan"), 0.2, 0.4], [0.1, 0.2, 0.3, float("inf")]]
target = [0, 2]
target = [1, 3]
self._validateInTopK(predictions, target, 2, [False, False])
def testBadTarget(self):
predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]]
target = [2, 12345] # must return False for invalid target
target = [2, 4] # must return False for invalid target
self._validateInTopK(predictions, target, 2, [True, False])
def testTensorK(self):

View File

@@ -4755,11 +4755,11 @@ def in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.
This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
prediction for the target class is among the top `k` predictions among
all predictions for example `i`. Note that the behavior of `InTopK` differs
from the `TopK` op in its handling of ties; if multiple classes have the
same prediction value and straddle the top-`k` boundary, all of those
classes are considered to be in the top `k`.
prediction for the target class is finite (not inf, -inf, or nan) and among
the top `k` predictions among all predictions for example `i`. Note that the
behavior of `InTopK` differs from the `TopK` op in its handling of ties; if
multiple classes have the same prediction value and straddle the top-`k`
boundary, all of those classes are considered to be in the top `k`.
More formally, let