Allow subslicing Tensors with a single dimension.
PiperOrigin-RevId: 214553359
This commit is contained in:
parent
6666516f39
commit
f2b17b22e1
@ -813,7 +813,7 @@ Tensor Tensor::Slice(int64 start, int64 limit) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Tensor Tensor::SubSlice(int64 index) const {
|
Tensor Tensor::SubSlice(int64 index) const {
|
||||||
CHECK_GE(dims(), 2); // Crash ok.
|
CHECK_GE(dims(), 1); // Crash ok.
|
||||||
CHECK_LE(0, index); // Crash ok.
|
CHECK_LE(0, index); // Crash ok.
|
||||||
int64 dim0_size = shape_.dim_size(0);
|
int64 dim0_size = shape_.dim_size(0);
|
||||||
CHECK_LE(index, dim0_size); // Crash ok.
|
CHECK_LE(index, dim0_size); // Crash ok.
|
||||||
|
@ -219,7 +219,7 @@ class Tensor {
|
|||||||
/// must check the returned tensor's alignment before calling certain
|
/// must check the returned tensor's alignment before calling certain
|
||||||
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
|
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
|
||||||
///
|
///
|
||||||
/// REQUIRES: `dims()` >= 2
|
/// REQUIRES: `dims()` >= 1
|
||||||
/// REQUIRES: `0 <= dim0_start < dim_size(0)`
|
/// REQUIRES: `0 <= dim0_start < dim_size(0)`
|
||||||
Tensor SubSlice(int64 index) const;
|
Tensor SubSlice(int64 index) const;
|
||||||
|
|
||||||
|
@ -1246,6 +1246,9 @@ TEST(Tensor, SubSlice_Basic) {
|
|||||||
EXPECT_EQ(&tx(5, j, k), &ty(j, k));
|
EXPECT_EQ(&tx(5, j, k), &ty(j, k));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Tensor z = y.SubSlice(3).SubSlice(31);
|
||||||
|
auto tz = z.unaligned_flat<float>();
|
||||||
|
EXPECT_EQ(*tz.data(), 5.0);
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Test unaligned access via a SubSlice.
|
// Test unaligned access via a SubSlice.
|
||||||
|
Loading…
Reference in New Issue
Block a user