Refactor ExpectEqual() function

1. Rename the argument name from `expect_items_equal` to `compare_order`
2. Enable the function to return tensorflow::errors::Internal if the two input tensor (vectors) are not equal.
This commit is contained in:
Fei Hu 2019-04-22 13:57:59 -07:00
parent 1dfaeb0028
commit 1fa2e34d11
4 changed files with 40 additions and 13 deletions

View File

@ -18,12 +18,39 @@ limitations under the License.
namespace tensorflow {
namespace data {
template <typename T>
Status IsEqual(const Tensor& t1, const Tensor& t2) {
if (t1.dtype() != t2.dtype()) {
return tensorflow::errors::Internal(
"Two tensors have different dtypes: ", DataTypeString(t1.dtype()),
" vs. ", DataTypeString(t2.dtype()));
}
if (!t1.IsSameSize(t2)) {
return tensorflow::errors::Internal(
"Two tensors have different shapes: ", t1.shape().DebugString(),
" vs. ", t2.shape().DebugString());
}
auto flat_t1 = t1.flat<T>();
auto flat_t2 = t2.flat<T>();
auto length = flat_t1.size();
for (int i = 0; i < length; ++i) {
if (flat_t1(i) != flat_t2(i)) {
return tensorflow::errors::Internal(
"Two tensors have different values "
"at [",
i, "]: ", flat_t1(i), " vs. ", flat_t2(i));
}
}
return Status::OK();
}
Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
EXPECT_EQ(a.dtype(), b.dtype());
switch (a.dtype()) {
#define CASE(type) \
case DataTypeToEnum<type>::value: \
test::ExpectTensorEqual<type>(a, b); \
#define CASE(DT) \
case DataTypeToEnum<DT>::value: \
TF_RETURN_IF_ERROR(IsEqual<DT>(a, b)); \
break;
TF_CALL_NUMBER_TYPES(CASE);
TF_CALL_string(CASE);
@ -36,7 +63,7 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
}
template <typename T>
bool compare(Tensor t1, Tensor t2) {
bool compare(const Tensor& t1, const Tensor& t2) {
auto flat_t1 = t1.flat<T>();
auto flat_t2 = t2.flat<T>();
auto length = std::min(flat_t1.size(), flat_t2.size());
@ -49,7 +76,7 @@ bool compare(Tensor t1, Tensor t2) {
Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
std::vector<Tensor> expected_tensors,
bool expect_items_equal) {
bool compare_order) {
if (produced_tensors.size() != expected_tensors.size()) {
return Status(tensorflow::errors::Internal(
"The two tensor vectors have different size (", produced_tensors.size(),
@ -64,7 +91,7 @@ Status DatasetOpsTestBase::ExpectEqual(std::vector<Tensor> produced_tensors,
")"));
}
if (expect_items_equal) {
if (!compare_order) {
const DataType& dtype = produced_tensors[0].dtype();
switch (dtype) {
#define CASE(DT) \

View File

@ -52,11 +52,11 @@ class DatasetOpsTestBase : public ::testing::Test {
static Status ExpectEqual(const Tensor& a, const Tensor& b);
// The method validates whether the two tensor vectors have the same tensors.
// If `expect_items_equal` is true, the method will only evaluate the two
// If `compare_order` is false, the method will only evaluate the two
// vectors have the same elements regardless of order.
static Status ExpectEqual(std::vector<Tensor> produced_tensors,
std::vector<Tensor> expected_tensors,
bool expect_items_equal);
bool compare_order);
// Creates a tensor with the specified dtype, shape, and value.
template <typename T>

View File

@ -494,7 +494,7 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, GetNext) {
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*expect_items_equal*/ test_case.sloppy));
/*compare_order*/ !test_case.sloppy));
}
TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
@ -949,7 +949,7 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Roundtrip) {
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*expect_items_equal*/ test_case.sloppy));
/*compare_order*/ !test_case.sloppy));
}
INSTANTIATE_TEST_SUITE_P(

View File

@ -334,7 +334,7 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, GetNext) {
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*expect_items_equal*/ test_case.sloppy));
/*compare_order*/ !test_case.sloppy));
}
TEST_F(ParallelMapDatasetOpTest, DatasetNodeName) {
@ -769,7 +769,7 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, Roundtrip) {
}
TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
/*expect_items_equal*/ test_case.sloppy));
/*compare_order*/ !test_case.sloppy));
}
TEST_F(ParallelMapDatasetOpTest, InvalidNumParallelCalls) {