Fix tests that fail with tf32
This commit is contained in:
parent
2638bb9920
commit
5445a05db7
tensorflow
c/eager
python
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -56,6 +57,9 @@ Status RegisterGradients(GradientRegistry* registry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
|
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
|
||||||
|
// Computing numerical gradients with TensorFloat-32 is numerically unstable
|
||||||
|
enable_tensor_float_32_execution(false);
|
||||||
|
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
AbstractContextPtr ctx;
|
AbstractContextPtr ctx;
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -43,6 +44,10 @@ class CppGradients
|
|||||||
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||||
Status s = StatusFromTF_Status(status.get());
|
Status s = StatusFromTF_Status(status.get());
|
||||||
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
|
||||||
|
// Computing numerical gradients with TensorFloat-32 is numerically unstable.
|
||||||
|
// Some forward pass tests also fail with TensorFloat-32 due to low tolerances
|
||||||
|
enable_tensor_float_32_execution(false);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1487,6 +1487,8 @@ class FunctionCaptureByValueTest(test.TestCase):
|
|||||||
self.assertAllEqual(y, [[12.0]])
|
self.assertAllEqual(y, [[12.0]])
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_without_tensor_float_32(
|
||||||
|
"Calls matmul in custom LSTM function")
|
||||||
class UnrollLSTMTest(test.TestCase):
|
class UnrollLSTMTest(test.TestCase):
|
||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
LSTM_DIMS = 32
|
LSTM_DIMS = 32
|
||||||
@ -1593,7 +1595,6 @@ class UnrollLSTMTest(test.TestCase):
|
|||||||
self.assertAllClose(mv0, mv2, rtol=1e-4)
|
self.assertAllClose(mv0, mv2, rtol=1e-4)
|
||||||
self.assertAllClose(mv0, mv3, rtol=1e-4)
|
self.assertAllClose(mv0, mv3, rtol=1e-4)
|
||||||
|
|
||||||
@test_util.run_without_tensor_float_32("Calls matmul in custom LSTM function")
|
|
||||||
def testUnrollLSTMGrad(self):
|
def testUnrollLSTMGrad(self):
|
||||||
# Run one step of the unrolled lstm graph.
|
# Run one step of the unrolled lstm graph.
|
||||||
def RunForwardBackward(mode, cfg=None):
|
def RunForwardBackward(mode, cfg=None):
|
||||||
|
@ -1977,6 +1977,9 @@ def matmul_without_tf32(a, b, *args, **kwargs):
|
|||||||
If a matmul itself is being tested, or some other op which uses matmul, use
|
If a matmul itself is being tested, or some other op which uses matmul, use
|
||||||
`run_without_tensor_float_32` instead.
|
`run_without_tensor_float_32` instead.
|
||||||
|
|
||||||
|
This also casts complex64 inputs to complex128, since TensorFloat-32 can also
|
||||||
|
be used with complex64
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: First input to tf.linalg.matmul
|
a: First input to tf.linalg.matmul
|
||||||
b: Second input to tf.linalg.matmul
|
b: Second input to tf.linalg.matmul
|
||||||
@ -1991,6 +1994,11 @@ def matmul_without_tf32(a, b, *args, **kwargs):
|
|||||||
b = math_ops.cast(b, "float64")
|
b = math_ops.cast(b, "float64")
|
||||||
ret = math_ops.matmul(a, b, *args, **kwargs)
|
ret = math_ops.matmul(a, b, *args, **kwargs)
|
||||||
return math_ops.cast(ret, a.dtype)
|
return math_ops.cast(ret, a.dtype)
|
||||||
|
elif config.tensor_float_32_execution_enabled() and a.dtype == "complex64":
|
||||||
|
a = math_ops.cast(a, "complex128")
|
||||||
|
b = math_ops.cast(b, "complex128")
|
||||||
|
ret = math_ops.matmul(a, b, *args, **kwargs)
|
||||||
|
return math_ops.cast(ret, a.dtype)
|
||||||
else:
|
else:
|
||||||
return math_ops.matmul(a, b, *args, **kwargs)
|
return math_ops.matmul(a, b, *args, **kwargs)
|
||||||
|
|
||||||
|
@ -112,12 +112,12 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
|
|||||||
else:
|
else:
|
||||||
tol = 1e-14
|
tol = 1e-14
|
||||||
# Tests that a ~= q*r.
|
# Tests that a ~= q*r.
|
||||||
a_recon = math_ops.matmul(q, r)
|
a_recon = test_util.matmul_without_tf32(q, r)
|
||||||
self.assertAllClose(a_recon, a, rtol=tol, atol=tol)
|
self.assertAllClose(a_recon, a, rtol=tol, atol=tol)
|
||||||
|
|
||||||
def CheckUnitary(self, x):
|
def CheckUnitary(self, x):
|
||||||
# Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
|
# Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
|
||||||
xx = math_ops.matmul(x, x, adjoint_a=True)
|
xx = test_util.matmul_without_tf32(x, x, adjoint_a=True)
|
||||||
identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
|
identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
|
||||||
if is_single:
|
if is_single:
|
||||||
tol = 1e-5
|
tol = 1e-5
|
||||||
|
Loading…
Reference in New Issue
Block a user