Merge pull request #43553 from Intel-tensorflow:fix-3898
PiperOrigin-RevId: 334050084 Change-Id: I3fc5be7c80e43812e628216f629c5ff3bfeaf907
This commit is contained in:
commit
652291e217
@ -216,6 +216,8 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
|
||||
switch (x.dtype()) {
|
||||
case DT_HALF:
|
||||
return ExpectClose<Eigen::half>(x, y, atol, rtol);
|
||||
case DT_BFLOAT16:
|
||||
return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
|
||||
case DT_FLOAT:
|
||||
return ExpectClose<float>(x, y, atol, rtol);
|
||||
case DT_DOUBLE:
|
||||
|
@ -451,8 +451,10 @@ class RemapperFuseMatMulWithBiasTest : public RemapperTest {
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
typedef typename EnumToDataType<DTYPE>::Type T;
|
||||
test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
|
||||
if (DTYPE == DT_BFLOAT16)
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
|
||||
else
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
};
|
||||
|
||||
@ -704,8 +706,10 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
typedef typename EnumToDataType<DTYPE>::Type T;
|
||||
test::ExpectTensorNear<T>(tensors[0], tensors_expected[0], 1e-6);
|
||||
if (DTYPE == DT_BFLOAT16)
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
|
||||
else
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/relu_op.h"
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -68,7 +69,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
|
||||
SeluGradOp<CPUDevice, type>)
|
||||
|
||||
// Elu and Selu only make sense with float or double.
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS);
|
||||
TF_CALL_FLOAT_TYPES(REGISTER_ELU_KERNELS);
|
||||
#undef REGISTER_ELU_KERNELS
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
@ -208,5 +209,4 @@ REGISTER_KERNEL_BUILDER(
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user