[TF:XLA] Fix tf2xla's in_topk lowering and update tests.
- Fix an off by 1 comparison error. - Model Tensorflow's in_topk behavior for non-finite values. - Update Tensorflow's documentation to indicate behavior on non-finite values. - Enable TF:Classic's in_topk tests to run on XLA. - Correct test cases to check for off-by-1 errors and non-finite values. PiperOrigin-RevId: 248763768
This commit is contained in:
parent
63eaabaeeb
commit
335915673e
@ -81,20 +81,21 @@ class InTopKOp : public XlaOpKernel {
|
||||
xla::CreateScalarAddComputation(xla::F32, xla_builder), {1});
|
||||
|
||||
// Calculate in each row of `predictions`, how many values are larger than
|
||||
// the value of target class. Then return the result whether the count <= k,
|
||||
// the value of target class. Then return the result whether the count < k,
|
||||
// which indicates the target is in topk.
|
||||
xla::XlaOp ge_r2 = xla::Ge(predictions_r2, targets_values_r1, {0});
|
||||
xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0});
|
||||
xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32);
|
||||
xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes());
|
||||
xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32);
|
||||
xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes());
|
||||
xla::XlaOp one_hot_r2 = xla::Select(ge_r2, one_r2, zero_r2);
|
||||
xla::XlaOp num_ge_r1 = xla::Reduce(
|
||||
xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2);
|
||||
xla::XlaOp num_gt_r1 = xla::Reduce(
|
||||
one_hot_r2, zero_r0,
|
||||
xla::CreateScalarAddComputation(xla::S32, xla_builder), {1});
|
||||
|
||||
xla::XlaOp result =
|
||||
xla::Le(num_ge_r1, xla::ConstantR0<int32>(xla_builder, k));
|
||||
xla::And(xla::Lt(num_gt_r1, xla::ConstantR0<int32>(xla_builder, k)),
|
||||
xla::IsFinite(targets_values_r1));
|
||||
|
||||
context->SetOutput(0, result);
|
||||
}
|
||||
|
@ -546,6 +546,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:nn_ops",
|
||||
],
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
|
@ -37,12 +37,12 @@ class InTopKTest(test.TestCase):
|
||||
|
||||
def testInTop1(self):
|
||||
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||
target = [3, 1]
|
||||
target = [3, 2]
|
||||
self._validateInTopK(predictions, target, 1, [True, False])
|
||||
|
||||
def testInTop2(self):
|
||||
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
|
||||
target = [0, 2]
|
||||
target = [2, 2]
|
||||
self._validateInTopK(predictions, target, 2, [False, True])
|
||||
|
||||
def testInTop2Tie(self):
|
||||
@ -58,12 +58,12 @@ class InTopKTest(test.TestCase):
|
||||
|
||||
def testInTopNan(self):
|
||||
predictions = [[0.1, float("nan"), 0.2, 0.4], [0.1, 0.2, 0.3, float("inf")]]
|
||||
target = [0, 2]
|
||||
target = [1, 3]
|
||||
self._validateInTopK(predictions, target, 2, [False, False])
|
||||
|
||||
def testBadTarget(self):
|
||||
predictions = [[0.1, 0.3, 0.2, 0.2], [0.1, 0.3, 0.2, 0.2]]
|
||||
target = [2, 12345] # must return False for invalid target
|
||||
target = [2, 4] # must return False for invalid target
|
||||
self._validateInTopK(predictions, target, 2, [True, False])
|
||||
|
||||
def testTensorK(self):
|
||||
|
@ -4755,11 +4755,11 @@ def in_top_k(predictions, targets, k, name=None):
|
||||
r"""Says whether the targets are in the top `K` predictions.
|
||||
|
||||
This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the
|
||||
prediction for the target class is among the top `k` predictions among
|
||||
all predictions for example `i`. Note that the behavior of `InTopK` differs
|
||||
from the `TopK` op in its handling of ties; if multiple classes have the
|
||||
same prediction value and straddle the top-`k` boundary, all of those
|
||||
classes are considered to be in the top `k`.
|
||||
prediction for the target class is finite (not inf, -inf, or nan) and among
|
||||
the top `k` predictions among all predictions for example `i`. Note that the
|
||||
behavior of `InTopK` differs from the `TopK` op in its handling of ties; if
|
||||
multiple classes have the same prediction value and straddle the top-`k`
|
||||
boundary, all of those classes are considered to be in the top `k`.
|
||||
|
||||
More formally, let
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user