Refactor ExpectEqual() function
1. Rename the argument name from `expect_items_equal` to `compare_order` 2. Enable the function to return tensorflow::errors::Internal if the two input tensor (vectors) are not equal.
This commit is contained in:
parent
1dfaeb0028
commit
1fa2e34d11
@ -18,12 +18,39 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
template <typename T>
|
||||
Status IsEqual(const Tensor& t1, const Tensor& t2) {
|
||||
if (t1.dtype() != t2.dtype()) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
|
||||
" vs. ", DataTypeString(t2.dtype()));
|
||||
}
|
||||
if (!t1.IsSameSize(t2)) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Two tensors have different shapes: ", t1.shape().DebugString(),
|
||||
" vs. ", t2.shape().DebugString());
|
||||
}
|
||||
|
||||
auto flat_t1 = t1.flat<T>();
|
||||
auto flat_t2 = t2.flat<T>();
|
||||
auto length = flat_t1.size();
|
||||
|
||||
for (int i = 0; i < length; ++i) {
|
||||
if (flat_t1(i) != flat_t2(i)) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Two tensors have different values "
|
||||
"at [",
|
||||
i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
|
||||
EXPECT_EQ(a.dtype(), b.dtype());
|
||||
switch (a.dtype()) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: \
|
||||
test::ExpectTensorEqual<type>(a, b); \
|
||||
#define CASE(DT) \
|
||||
case DataTypeToEnum<DT>::value: \
|
||||
TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
|
||||
break;
|
||||
TF_CALL_NUMBER_TYPES(CASE);
|
||||
TF_CALL_string(CASE);
|
||||
@ -36,7 +63,7 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool compare(Tensor t1, Tensor t2) {
|
||||
bool compare(const Tensor& t1, const Tensor& t2) {
|
||||
auto flat_t1 = t1.flat<T>();
|
||||
auto flat_t2 = t2.flat<T>();
|
||||
auto length = std::min(flat_t1.size(), flat_t2.size());
|
||||
@ -49,7 +76,7 @@ bool compare(Tensor t1, Tensor t2) {
|
||||
|
||||
Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
|
||||
std::vector<Tensor> expected_tensors,
|
||||
bool expect_items_equal) {
|
||||
bool compare_order) {
|
||||
if (produced_tensors.size() != expected_tensors.size()) {
|
||||
return Status(tensorflow::errors::Internal(
|
||||
"The two tensor vectors have different size (", produced_tensors.size(),
|
||||
@ -64,7 +91,7 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
|
||||
")"));
|
||||
}
|
||||
|
||||
if (expect_items_equal) {
|
||||
if (!compare_order) {
|
||||
const DataType& dtype = produced_tensors[0].dtype();
|
||||
switch (dtype) {
|
||||
#define CASE(DT) \
|
||||
|
@ -52,11 +52,11 @@ class DatasetOpsTestBase : public ::testing::Test {
|
||||
static Status ExpectEqual(const Tensor& a, const Tensor& b);
|
||||
|
||||
// The method validates whether the two tensor vectors have the same tensors.
|
||||
// If `expect_items_equal` is true, the method will only evaluate the two
|
||||
// If `compare_order` is false, the method will only evaluate the two
|
||||
// vectors have the same elements regardless of order.
|
||||
static Status ExpectEqual(std::vector<Tensor> produced_tensors,
|
||||
std::vector<Tensor> expected_tensors,
|
||||
bool expect_items_equal);
|
||||
bool compare_order);
|
||||
|
||||
// Creates a tensor with the specified dtype, shape, and value.
|
||||
template <typename T>
|
||||
|
@ -494,7 +494,7 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, GetNext) {
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
|
||||
/*expect_items_equal*/ test_case.sloppy));
|
||||
/*compare_order*/ !test_case.sloppy));
|
||||
}
|
||||
|
||||
TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
|
||||
@ -949,7 +949,7 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Roundtrip) {
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
|
||||
/*expect_items_equal*/ test_case.sloppy));
|
||||
/*compare_order*/ !test_case.sloppy));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
|
@ -334,7 +334,7 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, GetNext) {
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
|
||||
/*expect_items_equal*/ test_case.sloppy));
|
||||
/*compare_order*/ !test_case.sloppy));
|
||||
}
|
||||
|
||||
TEST_F(ParallelMapDatasetOpTest, DatasetNodeName) {
|
||||
@ -769,7 +769,7 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, Roundtrip) {
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
|
||||
/*expect_items_equal*/ test_case.sloppy));
|
||||
/*compare_order*/ !test_case.sloppy));
|
||||
}
|
||||
|
||||
TEST_F(ParallelMapDatasetOpTest, InvalidNumParallelCalls) {
|
||||
|
Loading…
Reference in New Issue
Block a user