[TF2XLA] Make tf.argmin stable on XLA:TPU
PiperOrigin-RevId: 324851314 Change-Id: Icdecbe87c545d4254bcdb508f76e31de30bc8f86
This commit is contained in:
parent
4e03f13e6a
commit
6acd86d539
@ -71,7 +71,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
|
|||||||
if (is_gpu_) {
|
if (is_gpu_) {
|
||||||
output = xla::ArgMinTwoPass(input, index_xla_type, axis);
|
output = xla::ArgMinTwoPass(input, index_xla_type, axis);
|
||||||
} else {
|
} else {
|
||||||
output = xla::ArgMin(input, index_xla_type, axis);
|
output = xla::ArgMin(input, index_xla_type, axis, /*stable=*/true);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (is_gpu_) {
|
if (is_gpu_) {
|
||||||
|
|||||||
@ -293,9 +293,6 @@ class DefFunctionTest(xla_test.XLATestCase):
|
|||||||
@test_util.disable_mlir_bridge('TODO(b/162271237): argmax gives different'
|
@test_util.disable_mlir_bridge('TODO(b/162271237): argmax gives different'
|
||||||
' results in MLIR-based bridge')
|
' results in MLIR-based bridge')
|
||||||
def testArgMinMax(self):
|
def testArgMinMax(self):
|
||||||
if 'tpu' in self.device.lower():
|
|
||||||
self.skipTest('b/162800904: Tie resolution is wrong on TPU for tf.func')
|
|
||||||
|
|
||||||
with ops.device('device:{}:0'.format(self.device)):
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
@def_function.function(experimental_compile=True)
|
@def_function.function(experimental_compile=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user