[XLA] Support strided slice grad in a while loop.

- The inputs to strided slice are stored in a stack for the backward pass.
- Popping items from the stack makes those inputs unknown (non-constant) in the backward pass.
- When the begins and ends are unknown, lower the strided slice grad into a dynamic update slice instead.

PiperOrigin-RevId: 330987081
Change-Id: I0116a02f2fd7d660b49757622afc9934bb4b37e6
This commit is contained in:
Yunxing Dai 2020-09-10 12:04:41 -07:00 committed by TensorFlower Gardener
parent 35d9474383
commit 6e71a34542

View File

@ -26,11 +26,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
@ -290,6 +292,83 @@ class StridedSliceGradOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
}
// When begin / end are unknown, compile the gradient as a dynamic update
// slice into a tensor of broadcasted 0s.
//
// Broadcasted 0
// +----------------------+
// | +----+ |
// |<-begin->|grad|<-end->| <== Dynamic update grad into 0s.
// | +----+ |
// +----------------------+
void CompileAsDynamicUpdateSlice(XlaOpKernelContext* ctx,
                                 const TensorShape& input_shape,
                                 const xla::Literal& strides_literal) {
  // Lowers the strided-slice gradient without constant begin / end values:
  // the incoming grad (input 4) is written, via DynamicUpdateSlice, into a
  // zero tensor shaped like `input_shape`, at offsets read from the begin
  // operand (input 1).
  //
  //   ctx:             kernel context; supplies the inputs and receives
  //                    output 0 or an error status.
  //   input_shape:     shape of the forward op's input (and of output 0).
  //   strides_literal: constant strides; every stride must be 1 here.
  //
  // Every failure is reported through the OP_REQUIRES* macros on `ctx`.
  bool dummy = false;
  Tensor strides_tensor;
  PartialTensorShape processing_shape, final_shape;
  absl::InlinedVector<int64, 4> begin;
  absl::InlinedVector<int64, 4> end;
  absl::InlinedVector<int64, 4> strides;
  absl::InlinedVector<int64, 4> output_to_sparse_mapping;
  absl::InlinedVector<int64, 4> output_to_processing_mapping;
  OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                          &strides_tensor));
  // begin / end tensors are passed as nullptr because their values are not
  // available on this path; the shapes produced are therefore partial.
  OP_REQUIRES_OK(
      ctx, ValidateStridedSliceOp(
               nullptr, nullptr, strides_tensor, input_shape, begin_mask_,
               end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
               &processing_shape, &final_shape, &dummy, &dummy, &dummy,
               &begin, &end, &strides, &output_to_sparse_mapping,
               &output_to_processing_mapping));
  for (int64 i = 0; i < processing_shape.dims(); ++i) {
    OP_REQUIRES(
        ctx, strides[i] == 1,
        errors::InvalidArgument("Strides in strided slice grad have to be "
                                "one when inputs are not constant."));
  }

  auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
  zero = xla::Broadcast(zero, input_shape.dim_sizes());
  xla::XlaOp grad = ctx->Input(4);
  // Surface a failed shape lookup as an op error rather than aborting the
  // process through ValueOrDie() on a non-OK status.
  auto grad_shape_or = ctx->InputXlaShape(4);
  OP_REQUIRES_OK(ctx, grad_shape_or.status());
  const xla::Shape& grad_shape = grad_shape_or.ValueOrDie();
  VLOG(1) << "xla grad shape" << grad_shape;
  VLOG(1) << "input_shape" << input_shape.DebugString();

  // One start index per processing dimension; dimensions without a matching
  // begin entry keep offset 0.
  std::vector<xla::XlaOp> begins(processing_shape.dims(),
                                 xla::Zero(ctx->builder(), xla::S32));
  const xla::XlaOp begin_operand = ctx->Input(1);
  for (int64 i = 0; i < grad_shape.rank(); ++i) {
    // Use grad shape, which is known, to update unknown processing shape.
    // Grad shape is the output of the ValidateStridedSliceOp function in
    // forward pass, thus we use output_to_processing_mapping.
    const int64 processing_dim = output_to_processing_mapping[i];
    if (processing_dim != -1) {
      processing_shape.set_dim(processing_dim, grad_shape.dimensions(i));
    }

    // Similarly, use output_to_sparse_mapping to find out corresponding
    // begin dim of the output, as indices for dynamic update slice.
    // NOTE(review): `begins` is sized by processing dims but indexed by the
    // sparse (begin-operand) index — confirm the two coincide when ellipsis
    // or new-axis masks are set.
    const int64 begin_dim = output_to_sparse_mapping[i];
    if (begin_dim != -1) {
      auto begin_index =
          xla::Slice(begin_operand, {begin_dim}, {begin_dim + 1}, {1});
      auto begin_index_scalar = xla::Reshape(
          xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index);
      begins[begin_dim] = begin_index_scalar;
    }
  }
  VLOG(1) << "processing_shape" << processing_shape.DebugString();

  TensorShape full_processing_shape;
  OP_REQUIRES(ctx, processing_shape.AsTensorShape(&full_processing_shape),
              errors::InvalidArgument(
                  "Processing shape ", processing_shape.DebugString(),
                  " can't be fully inferred from grad shape"));
  // Undo any new/shrink axes by reshaping grad to the processing shape, then
  // write it into the zero tensor at the dynamic offsets.
  grad = xla::Reshape(grad, full_processing_shape.dim_sizes());
  grad = xla::DynamicUpdateSlice(zero, grad, begins);
  ctx->SetOutput(0, grad);
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape processing_shape, final_shape;
absl::InlinedVector<int64, 4> begin;
@ -298,12 +377,15 @@ class StridedSliceGradOp : public XlaOpKernel {
TensorShape input_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
xla::Literal begin_literal, end_literal, strides_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok();
bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok();
OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
if (!(begin_is_constant && end_is_constant)) {
CompileAsDynamicUpdateSlice(ctx, input_shape, strides_literal);
return;
}
Tensor begin_tensor, end_tensor, strides_tensor;
OP_REQUIRES_OK(
ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));