[XLA] Support strided slice grad in a while loop.
- The inputs to strided slice are stored in a stack for backward pass. - Popping items from the stack makes the input unknown in backward pass. - In the case where the begins and ends are unknown, lower the strided slice grads into dynamic update slice instead. PiperOrigin-RevId: 330987081 Change-Id: I0116a02f2fd7d660b49757622afc9934bb4b37e6
This commit is contained in:
parent
35d9474383
commit
6e71a34542
@ -26,11 +26,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/ops_util.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
@ -290,6 +292,83 @@ class StridedSliceGradOp : public XlaOpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
|
||||
}
|
||||
|
||||
// When the begin / end is unknown, compile the gradient into dynamic update
|
||||
// slice into a broadcasted 0s.
|
||||
//
|
||||
// Broadcasted 0
|
||||
// +----------------------+
|
||||
// | +----+ |
|
||||
// |<-begin->|grad|<-end->| <== Dynamic update grad into 0s.
|
||||
// | +----+ |
|
||||
// +----------------------+
|
||||
void CompileAsDynamicUpdateSlice(XlaOpKernelContext* ctx,
|
||||
const TensorShape& input_shape,
|
||||
const xla::Literal& strides_literal) {
|
||||
bool dummy = false;
|
||||
Tensor strides_tensor;
|
||||
PartialTensorShape processing_shape, final_shape;
|
||||
absl::InlinedVector<int64, 4> begin;
|
||||
absl::InlinedVector<int64, 4> end;
|
||||
absl::InlinedVector<int64, 4> strides;
|
||||
|
||||
absl::InlinedVector<int64, 4> output_to_sparse_mapping;
|
||||
absl::InlinedVector<int64, 4> output_to_processing_mapping;
|
||||
|
||||
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
|
||||
&strides_tensor));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ValidateStridedSliceOp(
|
||||
nullptr, nullptr, strides_tensor, input_shape, begin_mask_,
|
||||
end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
|
||||
&processing_shape, &final_shape, &dummy, &dummy, &dummy,
|
||||
&begin, &end, &strides, &output_to_sparse_mapping,
|
||||
&output_to_processing_mapping));
|
||||
for (int64 i = 0; i < processing_shape.dims(); ++i) {
|
||||
OP_REQUIRES(
|
||||
ctx, strides[i] == 1,
|
||||
errors::InvalidArgument("Strides in strided slice grad have to be "
|
||||
"one when inputs are not constant."));
|
||||
}
|
||||
|
||||
auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
|
||||
zero = xla::Broadcast(zero, input_shape.dim_sizes());
|
||||
xla::XlaOp grad = ctx->Input(4);
|
||||
xla::Shape grad_shape = ctx->InputXlaShape(4).ValueOrDie();
|
||||
// Undo any new/shrink axes.
|
||||
VLOG(1) << "xla grad shape" << grad_shape;
|
||||
VLOG(1) << "input_shape" << input_shape.DebugString();
|
||||
std::vector<xla::XlaOp> begins(processing_shape.dims(),
|
||||
xla::Zero(ctx->builder(), xla::S32));
|
||||
for (int64 i = 0; i < grad_shape.rank(); ++i) {
|
||||
// Use grad shape, which is known, to update unknown processing shape.
|
||||
// Grad shape is the output of the ValidateStridedSliceOp function in
|
||||
// forward pass, thus we use output_to_processing_mapping.
|
||||
if (output_to_processing_mapping[i] != -1) {
|
||||
processing_shape.set_dim(output_to_processing_mapping[i],
|
||||
grad_shape.dimensions(i));
|
||||
}
|
||||
|
||||
// Similarly, use output_to_sparse_mapping to find out corresponding
|
||||
// begin dim of the output, as indices for dynamic update slice.
|
||||
int64 begin_dim = output_to_sparse_mapping[i];
|
||||
if (begin_dim != -1) {
|
||||
auto begin_index =
|
||||
xla::Slice(ctx->Input(1), {begin_dim}, {begin_dim + 1}, {1});
|
||||
auto begin_index_scalar = xla::Reshape(
|
||||
xla::ShapeUtil::MakeScalarShape(xla::S32), begin_index);
|
||||
begins[output_to_sparse_mapping[i]] = begin_index_scalar;
|
||||
}
|
||||
}
|
||||
VLOG(1) << "processing_shape" << processing_shape.DebugString();
|
||||
TensorShape full_processing_shape;
|
||||
OP_REQUIRES(ctx, processing_shape.AsTensorShape(&full_processing_shape),
|
||||
errors::InvalidArgument(
|
||||
"Processing shape ", processing_shape.DebugString(),
|
||||
" can't be fully inferred from grad shape"));
|
||||
grad = xla::Reshape(grad, full_processing_shape.dim_sizes());
|
||||
grad = xla::DynamicUpdateSlice(zero, grad, begins);
|
||||
ctx->SetOutput(0, grad);
|
||||
}
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape processing_shape, final_shape;
|
||||
absl::InlinedVector<int64, 4> begin;
|
||||
@ -298,12 +377,15 @@ class StridedSliceGradOp : public XlaOpKernel {
|
||||
|
||||
TensorShape input_shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
|
||||
|
||||
xla::Literal begin_literal, end_literal, strides_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
|
||||
|
||||
bool begin_is_constant = ctx->ConstantInput(1, &begin_literal).ok();
|
||||
bool end_is_constant = ctx->ConstantInput(2, &end_literal).ok();
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
|
||||
if (!(begin_is_constant && end_is_constant)) {
|
||||
CompileAsDynamicUpdateSlice(ctx, input_shape, strides_literal);
|
||||
return;
|
||||
}
|
||||
Tensor begin_tensor, end_tensor, strides_tensor;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
|
||||
|
Loading…
Reference in New Issue
Block a user