Fix tests that fail with tf32
parent 2638bb9920
commit 5445a05db7
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/c/tf_tensor.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"

 namespace tensorflow {
@@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
 }

 TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
+  // Computing numerical gradients with TensorFloat-32 is numerically unstable.
+  enable_tensor_float_32_execution(false);
+
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   AbstractContextPtr ctx;
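Aside (not part of the commit): a minimal NumPy sketch of why the gradient checker disables TF32. TensorFloat-32 rounds float32 matmul operands to 10 explicit mantissa bits, which inflates the relative error of a product from roughly 1e-7 to roughly 1e-3, and a numerical gradient built from finite differences of such products cannot meet the checker's tolerances. The helper below only simulates TF32 by truncating mantissa bits; it is hypothetical, not TensorFlow code.

import numpy as np

def round_to_tf32(x):
  # Simulate TF32 input rounding: keep sign, exponent, and the top 10 of
  # float32's 23 mantissa bits (mask off the low 13 bits).
  u = x.astype(np.float32).view(np.uint32) & np.uint32(0xFFFFE000)
  return u.view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)
exact = a.astype(np.float64) @ b.astype(np.float64)

def rel_err(approx):
  return np.linalg.norm(approx - exact) / np.linalg.norm(exact)

print(rel_err(a @ b))                                # ~1e-7 with true float32
print(rel_err(round_to_tf32(a) @ round_to_tf32(b)))  # ~1e-4 to 1e-3, TF32-like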
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/c/tf_tensor.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"

 namespace tensorflow {
@@ -43,6 +44,10 @@ class CppGradients
     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
     Status s = StatusFromTF_Status(status.get());
     CHECK_EQ(errors::OK, s.code()) << s.error_message();
+
+    // Computing numerical gradients with TensorFloat-32 is numerically unstable.
+    // Some forward pass tests also fail with TensorFloat-32 due to low tolerances.
+    enable_tensor_float_32_execution(false);
   }
 };
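The same guard written as a Python test fixture would look like the sketch below; this is not from the commit, and it assumes the public tf.config.experimental API available since TF 2.4. The class name is hypothetical.

import tensorflow as tf

class GradientsWithoutTF32Test(tf.test.TestCase):  # hypothetical name
  def setUp(self):
    super().setUp()
    # Mirror the C++ SetUp above: disable TF32 for every test in this case,
    # remembering the previous setting.
    self._old_tf32 = tf.config.experimental.tensor_float_32_execution_enabled()
    tf.config.experimental.enable_tensor_float_32_execution(False)

  def tearDown(self):
    # Restore the global flag so later test cases are unaffected.
    tf.config.experimental.enable_tensor_float_32_execution(self._old_tf32)
    super().tearDown()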
@@ -1487,6 +1487,8 @@ class FunctionCaptureByValueTest(test.TestCase):
     self.assertAllEqual(y, [[12.0]])


+@test_util.run_all_without_tensor_float_32(
+    "Calls matmul in custom LSTM function")
 class UnrollLSTMTest(test.TestCase):
   BATCH_SIZE = 16
   LSTM_DIMS = 32
@@ -1593,7 +1595,6 @@ class UnrollLSTMTest(test.TestCase):
     self.assertAllClose(mv0, mv2, rtol=1e-4)
     self.assertAllClose(mv0, mv3, rtol=1e-4)

-  @test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function")
   def testUnrollLSTMGrad(self):
     # Run one step of the unrolled lstm graph.
     def RunForwardBackward(mode, cfg=None):
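The deletion above works because run_all_without_tensor_float_32 now guards every test method of UnrollLSTMTest, making the per-method decorator on testUnrollLSTMGrad redundant. As a rough sketch (assumed structure, not TensorFlow's actual test_util implementation), such a class decorator can be built from a per-test wrapper:

import functools
import tensorflow as tf

def run_without_tf32(description):  # hypothetical stand-in for test_util's helper
  def decorator(fn):
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
      # Disable TF32 around the test body, then restore the old setting.
      old = tf.config.experimental.tensor_float_32_execution_enabled()
      tf.config.experimental.enable_tensor_float_32_execution(False)
      try:
        return fn(self, *args, **kwargs)
      finally:
        tf.config.experimental.enable_tensor_float_32_execution(old)
    return wrapper
  return decorator

def run_all_without_tf32(description):  # hypothetical stand-in
  def decorator(cls):
    # Wrap every test method; this is what makes per-method decoration redundant.
    for name in list(vars(cls)):
      if name.startswith("test") and callable(getattr(cls, name)):
        setattr(cls, name, run_without_tf32(description)(getattr(cls, name)))
    return cls
  return decorator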
@@ -1977,6 +1977,9 @@ def matmul_without_tf32(a, b, *args, **kwargs):
   If a matmul itself is being tested, or some other op which uses matmul, use
   `run_without_tensor_float_32` instead.

+  This also casts complex64 inputs to complex128, since TensorFloat-32 can also
+  be used with complex64.
+
   Args:
     a: First input to tf.linalg.matmul
     b: Second input to tf.linalg.matmul
@@ -1991,6 +1994,11 @@ def matmul_without_tf32(a, b, *args, **kwargs):
     b = math_ops.cast(b, "float64")
     ret = math_ops.matmul(a, b, *args, **kwargs)
     return math_ops.cast(ret, a.dtype)
+  elif config.tensor_float_32_execution_enabled() and a.dtype == "complex64":
+    a = math_ops.cast(a, "complex128")
+    b = math_ops.cast(b, "complex128")
+    ret = math_ops.matmul(a, b, *args, **kwargs)
+    return math_ops.cast(ret, a.dtype)
   else:
     return math_ops.matmul(a, b, *args, **kwargs)
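The new complex64 branch exists because a complex matmul is computed from real matmuls over the float32 real and imaginary parts, so TF32 rounding leaks into complex64 results as well. A small NumPy identity check illustrating the decomposition (illustrative only, not TensorFlow internals):

import numpy as np

rng = np.random.default_rng(0)
a = (rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))).astype(np.complex64)
b = (rng.standard_normal((4, 4)) + 1j * rng.standard_normal((4, 4))).astype(np.complex64)

# (ar + i*ai) @ (br + i*bi) = (ar@br - ai@bi) + i*(ar@bi + ai@br):
# four real float32 matmuls, each a target for TF32 on supporting GPUs.
re = a.real @ b.real - a.imag @ b.imag
im = a.real @ b.imag + a.imag @ b.real
assert np.allclose(re + 1j * im, a @ b, rtol=1e-5, atol=1e-5)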
@@ -112,12 +112,12 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
     else:
       tol = 1e-14
     # Tests that a ~= q*r.
-    a_recon = math_ops.matmul(q, r)
+    a_recon = test_util.matmul_without_tf32(q, r)
     self.assertAllClose(a_recon, a, rtol=tol, atol=tol)

   def CheckUnitary(self, x):
     # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
-    xx = math_ops.matmul(x, x, adjoint_a=True)
+    xx = test_util.matmul_without_tf32(x, x, adjoint_a=True)
     identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
     if is_single:
       tol = 1e-5
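For scale (not part of the commit): TF32 keeps 10 explicit mantissa bits versus float32's 23, so its unit roundoff sits well above the 1e-5 single-precision tolerance these checks use, which is why the matmuls are routed through matmul_without_tf32.

# Back-of-envelope unit roundoff, plain Python:
print(2.0 ** -11)  # ~4.9e-04: TF32 (11 bits of precision incl. implicit bit)
print(2.0 ** -24)  # ~6.0e-08: float32 (24 bits of precision incl. implicit bit)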