Merge pull request #43553 from Intel-tensorflow:fix-3898

PiperOrigin-RevId: 334050084
Change-Id: I3fc5be7c80e43812e628216f629c5ff3bfeaf907
This commit is contained in:
TensorFlower Gardener 2020-09-27 15:59:12 -07:00
commit 652291e217
3 changed files with 12 additions and 6 deletions

View File

@ -216,6 +216,8 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
switch (x.dtype()) {
case DT_HALF:
return ExpectClose<Eigen::half>(x, y, atol, rtol);
case DT_BFLOAT16:
return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
case DT_FLOAT:
return ExpectClose<float>(x, y, atol, rtol);
case DT_DOUBLE:

View File

@ -451,8 +451,10 @@ class RemapperFuseMatMulWithBiasTest : public RemapperTest {
ASSERT_EQ(tensors_expected.size(), 1);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
ASSERT_EQ(tensors.size(), 1);
typedef typename EnumToDataType<DTYPE>::Type T;
test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
if (DTYPE == DT_BFLOAT16)
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
else
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
}
};
@ -704,8 +706,10 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
ASSERT_EQ(tensors_expected.size(), 1);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
ASSERT_EQ(tensors.size(), 1);
typedef typename EnumToDataType<DTYPE>::Type T;
test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
if (DTYPE == DT_BFLOAT16)
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
else
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
}
}
};

View File

@ -18,6 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/relu_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -68,7 +69,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
SeluGradOp<CPUDevice, type>)
// Elu and Selu only make sense with float or double.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
TF_CALL_FLOAT_TYPES(REGISTER_ELU_KERNELS);
#undef REGISTER_ELU_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@ -208,5 +209,4 @@ REGISTER_KERNEL_BUILDER(
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow