[TF2XLA] Make tf.argmin stable on XLA:TPU

PiperOrigin-RevId: 324851314
Change-Id: Icdecbe87c545d4254bcdb508f76e31de30bc8f86
Authored by George Karpenkov on 2020-08-04 11:18:29 -07:00; committed by TensorFlower Gardener
parent 4e03f13e6a
commit 6acd86d539
2 changed files with 1 addition and 4 deletions


@@ -71,7 +71,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
     if (is_gpu_) {
      output = xla::ArgMinTwoPass(input, index_xla_type, axis);
     } else {
-      output = xla::ArgMin(input, index_xla_type, axis);
+      output = xla::ArgMin(input, index_xla_type, axis, /*stable=*/true);
     }
   } else {
     if (is_gpu_) {
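
For context, a minimal Python sketch of the tie-breaking rule that the /*stable=*/true argument requests: when several elements share the minimum value, the reduction keeps the smallest index, so the first occurrence of the minimum wins. The function stable_argmin below is illustrative only and is not part of the XLA or TensorFlow API.

from functools import reduce

def stable_argmin(values):
  # Reduce over (value, index) pairs; on equal values, prefer the lower index,
  # which is the guarantee a stable argmin provides.
  pairs = zip(values, range(len(values)))
  best = reduce(
      lambda a, b: b if b[0] < a[0] or (b[0] == a[0] and b[1] < a[1]) else a,
      pairs)
  return best[1]

print(stable_argmin([3.0, 1.0, 1.0, 2.0]))  # prints 1, the first tied minimum

Without this guarantee, an argmin reduction may return any of the tied indices, which is the behavior difference the TPU test below had been skipping.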


@@ -293,9 +293,6 @@ class DefFunctionTest(xla_test.XLATestCase):
   @test_util.disable_mlir_bridge('TODO(b/162271237): argmax gives different'
                                  ' results in MLIR-based bridge')
   def testArgMinMax(self):
-    if 'tpu' in self.device.lower():
-      self.skipTest('b/162800904: Tie resolution is wrong on TPU for tf.func')
-
     with ops.device('device:{}:0'.format(self.device)):
       @def_function.function(experimental_compile=True)
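
With the TPU skip removed, testArgMinMax now exercises tie resolution under XLA compilation on TPU as well. A hedged usage sketch of the behavior it relies on, written against the public tf.function API rather than the test's internal imports (the input values here are made up):

import tensorflow as tf

@tf.function(experimental_compile=True)  # compile the function with XLA
def compiled_argmin(x):
  return tf.argmin(x)

x = tf.constant([2.0, 0.0, 0.0, 1.0])
# With a stable argmin, the tie between indices 1 and 2 resolves to the lower one.
print(compiled_argmin(x).numpy())  # expected: 1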