Add utility API to implement strided dynamic slice.
PiperOrigin-RevId: 301039665 Change-Id: Ib7bf50a7c27bb531561db482e33bab7b54b1769d
This commit is contained in:
parent
d08e6aeb49
commit
49dc8bb1d6
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
@ -24,6 +26,18 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
|
||||
absl::Span<const int64> window_sizes,
|
||||
absl::Span<const int64> strides) {
|
||||
XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes);
|
||||
if (std::any_of(strides.begin(), strides.end(),
|
||||
[](int64 stride) { return stride != 1; })) {
|
||||
sliced_input = Slice(sliced_input, std::vector<int64>(window_sizes.size()),
|
||||
window_sizes, strides);
|
||||
}
|
||||
return sliced_input;
|
||||
}
|
||||
|
||||
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
|
||||
absl::Span<const int64> end) {
|
||||
XlaBuilder* builder = x.builder();
|
||||
|
||||
@ -22,6 +22,13 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Slices input starting from the base_indices and within the window_sizes,
|
||||
// using the supplied strides. This is the equivalent of the Python slicing op
|
||||
// [base_indices : base_indices+window_sizes : stride].
|
||||
XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
|
||||
absl::Span<const int64> window_sizes,
|
||||
absl::Span<const int64> strides);
|
||||
|
||||
// Updates a slice of 'x', i.e.,
|
||||
// x[start[0], ..., start[n]] = update
|
||||
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user