Bug fix for remapper test

ShengYang1 2020-09-25 14:36:52 +08:00
parent 7893e4bcc1
commit e77349d7e2
3 changed files with 12 additions and 7 deletions

View File

@@ -216,6 +216,8 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
   switch (x.dtype()) {
     case DT_HALF:
       return ExpectClose<Eigen::half>(x, y, atol, rtol);
+    case DT_BFLOAT16:
+      return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
     case DT_FLOAT:
       return ExpectClose<float>(x, y, atol, rtol);
     case DT_DOUBLE:
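For reference, a minimal sketch (not part of the commit) of how the new DT_BFLOAT16 branch gets exercised from a test. The tensor setup and the helper function name here are illustrative; only test::ExpectClose and the dtype switch above come from the source. Before this change the switch had no DT_BFLOAT16 case, so bfloat16 tensors presumably fell through to the unsupported-type branch.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"

namespace tensorflow {

// Build two small DT_BFLOAT16 tensors and compare them; with the case added
// above, ExpectClose now dispatches to ExpectClose<Eigen::bfloat16> using the
// given atol/rtol instead of rejecting the dtype.
void CompareBfloat16Tensors() {
  Tensor a(DT_BFLOAT16, TensorShape({2}));
  Tensor b(DT_BFLOAT16, TensorShape({2}));
  a.flat<bfloat16>().setConstant(bfloat16(1.5f));
  b.flat<bfloat16>().setConstant(bfloat16(1.5f));
  test::ExpectClose(a, b, /*atol=*/1e-2, /*rtol=*/1e-2);
}

}  // namespace tensorflow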

View File

@@ -451,8 +451,10 @@ class RemapperFuseMatMulWithBiasTest : public RemapperTest {
     ASSERT_EQ(tensors_expected.size(), 1);
     auto tensors = EvaluateNodes(output, item.fetch, item.feed);
     ASSERT_EQ(tensors.size(), 1);
-    typedef typename EnumToDataType<DTYPE>::Type T;
-    test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
+    if (DTYPE == DT_BFLOAT16)
+      test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
+    else
+      test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
   }
 };
@@ -704,8 +706,10 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
       ASSERT_EQ(tensors_expected.size(), 1);
       auto tensors = EvaluateNodes(output, item.fetch, item.feed);
       ASSERT_EQ(tensors.size(), 1);
-      typedef typename EnumToDataType<DTYPE>::Type T;
-      test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
+      if (DTYPE == DT_BFLOAT16)
+        test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
+      else
+        test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
     }
   }
 };
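Why the looser 1e-2 tolerance for DT_BFLOAT16: bfloat16 keeps only 8 significant bits, so a single rounding step already costs a relative error on the order of 1e-3, and the MatMul reduction accumulates such errors, which makes the float tolerance of 1e-6 unreachable. A standalone sketch (not from the commit) that shows the effect by round-tripping a float through Eigen::bfloat16:

#include <cstdio>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// Print the relative error of one float -> bfloat16 -> float round trip.
// The observed error is already several orders of magnitude above 1e-6,
// while 1e-2 leaves comfortable headroom for accumulated MatMul error.
int main() {
  const float x = 3.14159265f;
  const float roundtrip = static_cast<float>(Eigen::bfloat16(x));
  std::printf("x = %.7f, bfloat16(x) = %.7f, relative error = %.2e\n", x,
              roundtrip, (x - roundtrip) / x);
  return 0;
}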

View File

@@ -18,12 +18,12 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/relu_op.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
@ -68,7 +68,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
SeluGradOp<CPUDevice, type>) SeluGradOp<CPUDevice, type>)
// Elu and Selu only make sense with float or double. // Elu and Selu only make sense with float or double.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS); TF_CALL_FLOAT_TYPES(REGISTER_ELU_KERNELS);
#undef REGISTER_ELU_KERNELS #undef REGISTER_ELU_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
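Switching the fan-out macro from TF_CALL_GPU_NUMBER_TYPES to TF_CALL_FLOAT_TYPES widens the set of element types for which REGISTER_ELU_KERNELS is instantiated on CPU, presumably so the bfloat16 patterns exercised by the remapper test have Elu/Selu kernels to run. A standalone sketch of this TF_CALL_* pattern; the type lists below are illustrative stand-ins, not the real contents of tensorflow/core/framework/register_types.h:

#include <cstdio>

// A per-type registration macro is applied once for every type in a list
// macro, so swapping the list macro changes which kernels get built.
struct bfloat16_stub {};  // stand-in for the real bfloat16 type

#define CALL_float(m) m(float)
#define CALL_double(m) m(double)
#define CALL_bfloat16(m) m(bfloat16_stub)

// Old fan-out used for REGISTER_ELU_KERNELS: no bfloat16 entry.
#define CALL_GPU_NUMBER_TYPES(m) CALL_float(m) CALL_double(m)
// New fan-out: bfloat16 is included, so its kernels are registered too.
#define CALL_FLOAT_TYPES(m) CALL_bfloat16(m) CALL_float(m) CALL_double(m)

#define REGISTER_ELU(type) \
  std::printf("registering Elu/Selu CPU kernel for %s\n", #type);

int main() {
  std::printf("-- GPU_NUMBER_TYPES fan-out --\n");
  CALL_GPU_NUMBER_TYPES(REGISTER_ELU);
  std::printf("-- FLOAT_TYPES fan-out --\n");
  CALL_FLOAT_TYPES(REGISTER_ELU);  // now also prints a bfloat16 line
  return 0;
}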
@@ -208,5 +208,4 @@ REGISTER_KERNEL_BUILDER(
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow