diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc index 54de6888508..6765a5af74d 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.cc +++ b/tensorflow/core/kernels/data/dataset_test_base.cc @@ -18,12 +18,39 @@ limitations under the License. namespace tensorflow { namespace data { +template +Status IsEqual(const Tensor& t1, const Tensor& t2) { + if (t1.dtype() != t2.dtype()) { + return tensorflow::errors::Internal( + "Two tensors have different dtypes: ", DataTypeString(t1.dtype()), + " vs. ", DataTypeString(t2.dtype())); + } + if (!t1.IsSameSize(t2)) { + return tensorflow::errors::Internal( + "Two tensors have different shapes: ", t1.shape().DebugString(), + " vs. ", t2.shape().DebugString()); + } + + auto flat_t1 = t1.flat(); + auto flat_t2 = t2.flat(); + auto length = flat_t1.size(); + + for (int i = 0; i < length; ++i) { + if (flat_t1(i) != flat_t2(i)) { + return tensorflow::errors::Internal( + "Two tensors have different values " + "at [", + i, "]: ", flat_t1(i), " vs. ", flat_t2(i)); + } + } + return Status::OK(); +} + Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) { - EXPECT_EQ(a.dtype(), b.dtype()); switch (a.dtype()) { -#define CASE(type) \ - case DataTypeToEnum::value: \ - test::ExpectTensorEqual(a, b); \ +#define CASE(DT) \ + case DataTypeToEnum
::value: \ + TF_RETURN_IF_ERROR(IsEqual
(a, b)); \ break; TF_CALL_NUMBER_TYPES(CASE); TF_CALL_string(CASE); @@ -36,7 +63,7 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) { } template -bool compare(Tensor t1, Tensor t2) { +bool compare(const Tensor& t1, const Tensor& t2) { auto flat_t1 = t1.flat(); auto flat_t2 = t2.flat(); auto length = std::min(flat_t1.size(), flat_t2.size()); @@ -49,7 +76,7 @@ bool compare(Tensor t1, Tensor t2) { Status DatasetOpsTestBase::ExpectEqual(std::vector produced_tensors, std::vector expected_tensors, - bool expect_items_equal) { + bool compare_order) { if (produced_tensors.size() != expected_tensors.size()) { return Status(tensorflow::errors::Internal( "The two tensor vectors have different size (", produced_tensors.size(), @@ -64,7 +91,7 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector produced_tensors, ")")); } - if (expect_items_equal) { + if (!compare_order) { const DataType& dtype = produced_tensors[0].dtype(); switch (dtype) { #define CASE(DT) \ diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h index ca2be6b9258..803ae9055a1 100644 --- a/tensorflow/core/kernels/data/dataset_test_base.h +++ b/tensorflow/core/kernels/data/dataset_test_base.h @@ -52,11 +52,11 @@ class DatasetOpsTestBase : public ::testing::Test { static Status ExpectEqual(const Tensor& a, const Tensor& b); // The method validates whether the two tensor vectors have the same tensors. - // If `expect_items_equal` is true, the method will only evaluate the two + // If `compare_order` is false, the method will only evaluate the two // vectors have the same elements regardless of order. static Status ExpectEqual(std::vector produced_tensors, std::vector expected_tensors, - bool expect_items_equal); + bool compare_order); // Creates a tensor with the specified dtype, shape, and value. template diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc index 3c3d1dec2b0..6f30cce3fe1 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc @@ -494,7 +494,7 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, GetNext) { } TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, - /*expect_items_equal*/ test_case.sloppy)); + /*compare_order*/ !test_case.sloppy)); } TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) { @@ -949,7 +949,7 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Roundtrip) { } TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, - /*expect_items_equal*/ test_case.sloppy)); + /*compare_order*/ !test_case.sloppy)); } INSTANTIATE_TEST_SUITE_P( diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc index dc1ff9f5094..abb6e81aff6 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc @@ -334,7 +334,7 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, GetNext) { } TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, - /*expect_items_equal*/ test_case.sloppy)); + /*compare_order*/ !test_case.sloppy)); } TEST_F(ParallelMapDatasetOpTest, DatasetNodeName) { @@ -769,7 +769,7 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, Roundtrip) { } TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs, - /*expect_items_equal*/ test_case.sloppy)); + /*compare_order*/ !test_case.sloppy)); } TEST_F(ParallelMapDatasetOpTest, InvalidNumParallelCalls) {