Stops getting address of zero-sided matrix, which indeed dereferences null.
PiperOrigin-RevId: 329709701 Change-Id: I62648769c0066d0fa904edfbef82bdcfc414904c
This commit is contained in:
parent
ea462c5eac
commit
bed49a80b4
@ -476,12 +476,12 @@ inline SparseTensor SparseTensor::Concat(
|
||||
// Fill in indices & values.
|
||||
if (st_num_entries > 0) {
|
||||
std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));
|
||||
}
|
||||
|
||||
const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
|
||||
auto* ix_out = &ix_t(offset, 0);
|
||||
for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
|
||||
*ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
|
||||
const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
|
||||
auto* ix_out = &ix_t(offset, 0);
|
||||
for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
|
||||
*ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
|
||||
}
|
||||
}
|
||||
|
||||
offset += st_num_entries;
|
||||
|
@ -592,6 +592,20 @@ TEST(SparseTensorTest, Concat) {
|
||||
EXPECT_EQ(conc_ooo.num_entries(), 4 * N);
|
||||
}
|
||||
|
||||
TEST(SparseTensorTest, ConcatEmptyN) {
|
||||
constexpr int N = 0;
|
||||
constexpr int NDIM = 2;
|
||||
Tensor ix(DT_INT64, TensorShape({N, NDIM}));
|
||||
Tensor vals(DT_STRING, TensorShape({N}));
|
||||
TensorShape shape({10, 10});
|
||||
SparseTensor st;
|
||||
TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, {0, 1}, &st));
|
||||
|
||||
SparseTensor concatted = SparseTensor::Concat<tstring>({st, st, st});
|
||||
|
||||
EXPECT_EQ(concatted.num_entries(), 0);
|
||||
}
|
||||
|
||||
// TODO(ebrevdo): ReduceToDense(R={dim1,dim2,...}, reduce_fn, &output)
|
||||
// reduce_fn sees slices of resorted values based on generator (dim: DDIMS), and
|
||||
// slices of resorted indices on generator.
|
||||
|
Loading…
Reference in New Issue
Block a user