Add utility API to implement strided dynamic slice.

PiperOrigin-RevId: 301039665
Change-Id: Ib7bf50a7c27bb531561db482e33bab7b54b1769d
This commit is contained in:
Davide Libenzi 2020-03-15 11:02:14 -07:00 committed by TensorFlower Gardener
parent d08e6aeb49
commit 49dc8bb1d6
2 changed files with 21 additions and 0 deletions

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include <algorithm>
#include <limits>
#include <vector>
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
@ -24,6 +26,18 @@ limitations under the License.
namespace xla {
XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
absl::Span<const int64> window_sizes,
absl::Span<const int64> strides) {
XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes);
if (std::any_of(strides.begin(), strides.end(),
[](int64 stride) { return stride != 1; })) {
sliced_input = Slice(sliced_input, std::vector<int64>(window_sizes.size()),
window_sizes, strides);
}
return sliced_input;
}
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
absl::Span<const int64> end) {
XlaBuilder* builder = x.builder();

View File

@ -22,6 +22,13 @@ limitations under the License.
namespace xla {
// Slices input starting from the base_indices and within the window_sizes,
// using the supplied strides. This is the equivalent of the Python slicing op
// [base_indices : base_indices+window_sizes : stride].
XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
absl::Span<const int64> window_sizes,
absl::Span<const int64> strides);
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start);