Bug fix for remapper test
This commit is contained in:
parent
7893e4bcc1
commit
e77349d7e2
@ -216,6 +216,8 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
|
||||
switch (x.dtype()) {
|
||||
case DT_HALF:
|
||||
return ExpectClose<Eigen::half>(x, y, atol, rtol);
|
||||
case DT_BFLOAT16:
|
||||
return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
|
||||
case DT_FLOAT:
|
||||
return ExpectClose<float>(x, y, atol, rtol);
|
||||
case DT_DOUBLE:
|
||||
|
@ -451,8 +451,10 @@ class RemapperFuseMatMulWithBiasTest : public RemapperTest {
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
typedef typename EnumToDataType<DTYPE>::Type T;
|
||||
test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
|
||||
if (DTYPE == DT_BFLOAT16)
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
|
||||
else
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
};
|
||||
|
||||
@ -704,8 +706,10 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
typedef typename EnumToDataType<DTYPE>::Type T;
|
||||
test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
|
||||
if (DTYPE == DT_BFLOAT16)
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
|
||||
else
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -18,12 +18,12 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/relu_op.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -68,7 +68,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
|
||||
SeluGradOp<CPUDevice, type>)
|
||||
|
||||
// Elu and Selu only make sense with float or double.
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
|
||||
TF_CALL_FLOAT_TYPES(REGISTER_ELU_KERNELS);
|
||||
#undef REGISTER_ELU_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
@ -208,5 +208,4 @@ REGISTER_KERNEL_BUILDER(
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user