Fix tests that fail with tf32

This commit is contained in:
Reed 2020-10-21 13:22:05 -07:00
parent 2638bb9920
commit 5445a05db7
5 changed files with 21 additions and 3 deletions

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
}
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
// Computing numerical gradients with TensorFloat-32 is numerically unstable
enable_tensor_float_32_execution(false);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -43,6 +44,10 @@ class CppGradients
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
// Computing numerical gradients with TensorFloat-32 is numerically unstable.
// Some forward pass tests also fail with TensorFloat-32 due to low tolerances
enable_tensor_float_32_execution(false);
}
};

View File

@ -1487,6 +1487,8 @@ class FunctionCaptureByValueTest(test.TestCase):
self.assertAllEqual(y, [[12.0]])
@test_util.run_all_without_tensor_float_32(
"Calls matmul in custom LSTM function")
class UnrollLSTMTest(test.TestCase):
BATCH_SIZE = 16
LSTM_DIMS = 32
@ -1593,7 +1595,6 @@ class UnrollLSTMTest(test.TestCase):
self.assertAllClose(mv0, mv2, rtol=1e-4)
self.assertAllClose(mv0, mv3, rtol=1e-4)
@test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function")
def testUnrollLSTMGrad(self):
# Run one step of the unrolled lstm graph.
def RunForwardBackward(mode, cfg=None):

View File

@ -1977,6 +1977,9 @@ def matmul_without_tf32(a, b, *args, **kwargs):
If a matmul itself is being tested, or some other op which uses matmul, use
`run_without_tensor_float_32` instead.
This also casts complex64 inputs to complex128, since TensorFloat-32 can also
be used with complex64
Args:
a: First input to tf.linalg.matmul
b: Second input to tf.linalg.matmul
@ -1991,6 +1994,11 @@ def matmul_without_tf32(a, b, *args, **kwargs):
b = math_ops.cast(b, "float64")
ret = math_ops.matmul(a, b, *args, **kwargs)
return math_ops.cast(ret, a.dtype)
elif config.tensor_float_32_execution_enabled() and a.dtype == "complex64":
a = math_ops.cast(a, "complex128")
b = math_ops.cast(b, "complex128")
ret = math_ops.matmul(a, b, *args, **kwargs)
return math_ops.cast(ret, a.dtype)
else:
return math_ops.matmul(a, b, *args, **kwargs)

View File

@ -112,12 +112,12 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
else:
tol = 1e-14
# Tests that a ~= q*r.
a_recon = math_ops.matmul(q, r)
a_recon = test_util.matmul_without_tf32(q, r)
self.assertAllClose(a_recon, a, rtol=tol, atol=tol)
def CheckUnitary(self, x):
# Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
xx = math_ops.matmul(x, x, adjoint_a=True)
xx = test_util.matmul_without_tf32(x, x, adjoint_a=True)
identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
if is_single:
tol = 1e-5