Bug fix for remapper test
This commit is contained in:
parent
7893e4bcc1
commit
e77349d7e2
@ -216,6 +216,8 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
|
|||||||
switch (x.dtype()) {
|
switch (x.dtype()) {
|
||||||
case DT_HALF:
|
case DT_HALF:
|
||||||
return ExpectClose<Eigen::half>(x, y, atol, rtol);
|
return ExpectClose<Eigen::half>(x, y, atol, rtol);
|
||||||
|
case DT_BFLOAT16:
|
||||||
|
return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
|
||||||
case DT_FLOAT:
|
case DT_FLOAT:
|
||||||
return ExpectClose<float>(x, y, atol, rtol);
|
return ExpectClose<float>(x, y, atol, rtol);
|
||||||
case DT_DOUBLE:
|
case DT_DOUBLE:
|
||||||
|
@ -451,8 +451,10 @@ class RemapperFuseMatMulWithBiasTest : public RemapperTest {
|
|||||||
ASSERT_EQ(tensors_expected.size(), 1);
|
ASSERT_EQ(tensors_expected.size(), 1);
|
||||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||||
ASSERT_EQ(tensors.size(), 1);
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
typedef typename EnumToDataType<DTYPE>::Type T;
|
if (DTYPE == DT_BFLOAT16)
|
||||||
test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
|
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
|
||||||
|
else
|
||||||
|
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -704,8 +706,10 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
|
|||||||
ASSERT_EQ(tensors_expected.size(), 1);
|
ASSERT_EQ(tensors_expected.size(), 1);
|
||||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||||
ASSERT_EQ(tensors.size(), 1);
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
typedef typename EnumToDataType<DTYPE>::Type T;
|
if (DTYPE == DT_BFLOAT16)
|
||||||
test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
|
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
|
||||||
|
else
|
||||||
|
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -18,12 +18,12 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/relu_op.h"
|
#include "tensorflow/core/kernels/relu_op.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
||||||
#include "tensorflow/core/framework/numeric_op.h"
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
|
|||||||
SeluGradOp<CPUDevice, type>)
|
SeluGradOp<CPUDevice, type>)
|
||||||
|
|
||||||
// Elu and Selu only make sense with float or double.
|
// Elu and Selu only make sense with float or double.
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
|
TF_CALL_FLOAT_TYPES(REGISTER_ELU_KERNELS);
|
||||||
#undef REGISTER_ELU_KERNELS
|
#undef REGISTER_ELU_KERNELS
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
@ -208,5 +208,4 @@ REGISTER_KERNEL_BUILDER(
|
|||||||
|
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user