Using unaligned_flat() instead of flat() for Expectors.

Otherwise, the Expectors may crash on unaligned Tensors such as what's returned by Tensor::Slice().

PiperOrigin-RevId: 299956868
Change-Id: I2fd7f7db2f03c4d536b6923eb2f876df931e9b09
This commit is contained in:
A. Unique TensorFlower 2020-03-09 15:58:02 -07:00 committed by TensorFlower Gardener
parent 8f6e6c541f
commit 97e1e8091a
2 changed files with 16 additions and 6 deletions

View File

@ -201,8 +201,8 @@ struct Expector<T, false> {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
AssertSameTypeDims(x, y);
const auto size = x.NumElements();
const T* a = x.flat<T>().data();
const T* b = y.flat<T>().data();
const T* a = x.unaligned_flat<T>().data();
const T* b = y.unaligned_flat<T>().data();
for (int i = 0; i < size; ++i) {
ExpectEqual(a[i], b[i]);
}
@ -218,8 +218,8 @@ struct Expector<T, true> {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
AssertSameTypeDims(x, y);
const auto size = x.NumElements();
const T* a = x.flat<T>().data();
const T* b = y.flat<T>().data();
const T* a = x.unaligned_flat<T>().data();
const T* b = y.unaligned_flat<T>().data();
for (int i = 0; i < size; ++i) {
ExpectEqual(a[i], b[i]);
}
@ -235,8 +235,8 @@ struct Expector<T, true> {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
AssertSameTypeDims(x, y);
const auto size = x.NumElements();
const T* a = x.flat<T>().data();
const T* b = y.flat<T>().data();
const T* a = x.unaligned_flat<T>().data();
const T* b = y.unaligned_flat<T>().data();
for (int i = 0; i < size; ++i) {
EXPECT_TRUE(Near(a[i], b[i], abs_err))
<< "a = " << a[i] << " b = " << b[i] << " index = " << i;

View File

@ -184,6 +184,16 @@ TEST(TensorTestUtilTest, ExpectTensorNearDouble) {
TestEdgeCasesNear<T>();
}
// Tensor::Slice() and Tensor::SubSlice() may return unaligned Tensor.
TEST(TensorTestUtilTest, ExpectTensorNearSlice) {
Tensor x(DT_FLOAT, TensorShape({7, 3}));
test::FillFn<float>(&x, [](int i) -> float { return 1.0; });
test::ExpectTensorNear<float>(
x.SubSlice(3), test::AsTensor<float>({1.0, 1.0, 1.0}, TensorShape({3})),
1e-10);
}
static const double kSlackFactor = 5.0;
template <typename T>