Stops getting address of zero-sided matrix, which indeed dereferences null.

PiperOrigin-RevId: 329709701
Change-Id: I62648769c0066d0fa904edfbef82bdcfc414904c
This commit is contained in:
A. Unique TensorFlower 2020-09-02 07:34:57 -07:00 committed by TensorFlower Gardener
parent ea462c5eac
commit bed49a80b4
2 changed files with 19 additions and 5 deletions

View File

@ -476,12 +476,12 @@ inline SparseTensor SparseTensor::Concat(
// Fill in indices & values. // Fill in indices & values.
if (st_num_entries > 0) { if (st_num_entries > 0) {
std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset)); std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));
}
const auto* st_ix = &st.ix_.matrix<int64>()(0, 0); const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
auto* ix_out = &ix_t(offset, 0); auto* ix_out = &ix_t(offset, 0);
for (std::size_t i = 0; i < st_num_entries * dims; ++i) { for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
*ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0); *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
}
} }
offset += st_num_entries; offset += st_num_entries;

View File

@ -592,6 +592,20 @@ TEST(SparseTensorTest, Concat) {
EXPECT_EQ(conc_ooo.num_entries(), 4 * N); EXPECT_EQ(conc_ooo.num_entries(), 4 * N);
} }
TEST(SparseTensorTest, ConcatEmptyN) {
constexpr int N = 0;
constexpr int NDIM = 2;
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
Tensor vals(DT_STRING, TensorShape({N}));
TensorShape shape({10, 10});
SparseTensor st;
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 1}, &st));
SparseTensor concatted = SparseTensor::Concat<tstring>({st, st, st});
EXPECT_EQ(concatted.num_entries(), 0);
}
// TODO(ebrevdo): ReduceToDense(R={dim1,dim2,...}, reduce_fn, &output) // TODO(ebrevdo): ReduceToDense(R={dim1,dim2,...}, reduce_fn, &output)
// reduce_fn sees slices of resorted values based on generator (dim: DDIMS), and // reduce_fn sees slices of resorted values based on generator (dim: DDIMS), and
// slices of resorted indices on generator. // slices of resorted indices on generator.