[TF:XLA] Increase tolerance testing XLA's AtrousConv2d gradient.
Auto tuning on the GPU caused this error tolerance to be exceeded very often with XLA. PiperOrigin-RevId: 231248158
This commit is contained in:
parent
54dc5baa26
commit
367f164359
@ -139,7 +139,6 @@ class AtrousConv2DTest(test.TestCase):
|
|||||||
y1.eval(), self.evaluate(y2), rtol=1e-2, atol=1e-2)
|
y1.eval(), self.evaluate(y2), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
@test_util.disable_xla("This test never passed for XLA") # larger error range
|
|
||||||
def testGradient(self):
|
def testGradient(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
# Input: [batch, height, width, input_depth]
|
# Input: [batch, height, width, input_depth]
|
||||||
@ -161,7 +160,7 @@ class AtrousConv2DTest(test.TestCase):
|
|||||||
[x_shape, f_shape],
|
[x_shape, f_shape],
|
||||||
output, y_shape)
|
output, y_shape)
|
||||||
print("atrous_conv2d gradient err = %g " % err)
|
print("atrous_conv2d gradient err = %g " % err)
|
||||||
err_tolerance = 1e-3
|
err_tolerance = 4e-3 if test_util.is_xla_enabled() else 1e-3
|
||||||
self.assertLess(err, err_tolerance)
|
self.assertLess(err, err_tolerance)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user