[TF2XLA] Make tf.argmin stable on XLA:TPU
PiperOrigin-RevId: 324851314 Change-Id: Icdecbe87c545d4254bcdb508f76e31de30bc8f86
This commit is contained in:
parent
4e03f13e6a
commit
6acd86d539
@ -71,7 +71,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
|
||||
if (is_gpu_) {
|
||||
output = xla::ArgMinTwoPass(input, index_xla_type, axis);
|
||||
} else {
|
||||
output = xla::ArgMin(input, index_xla_type, axis);
|
||||
output = xla::ArgMin(input, index_xla_type, axis, /*stable=*/true);
|
||||
}
|
||||
} else {
|
||||
if (is_gpu_) {
|
||||
|
||||
@ -293,9 +293,6 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
@test_util.disable_mlir_bridge('TODO(b/162271237): argmax gives different'
|
||||
' results in MLIR-based bridge')
|
||||
def testArgMinMax(self):
|
||||
if 'tpu' in self.device.lower():
|
||||
self.skipTest('b/162800904: Tie resolution is wrong on TPU for tf.func')
|
||||
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user