Remove dynamic dimension of strided slice grad if input to strided slice is static.

If we slice a dynamically shaped tensor out of a statically shaped tensor, the
output of the gradient should still be static. Unfortunately this cannot be
deduced by XLA alone, so extra information is needed from the tf2xla
bridge.

PiperOrigin-RevId: 317051543
Change-Id: I7a8113c47a4aed145dfba7f7d12992ca35a9cf19
This commit is contained in:
Yunxing Dai 2020-06-18 00:52:32 -07:00 committed by TensorFlower Gardener
parent fabcd8f89c
commit b38d5d1889
4 changed files with 77 additions and 0 deletions

View File

@ -350,6 +350,28 @@ class StridedSliceGradOp : public XlaOpKernel {
grad = xla::Rev(grad, dimensions_to_reverse);
}
grad = xla::Pad(grad, zero, padding_config);
xla::XlaOp dynamic_shape = ctx->Input(0);
xla::Shape grad_shape = ctx->builder()->GetShape(grad).ValueOrDie();
ctx->set_dynamic_dimension_is_minus_one(true);
std::vector<int64> dynamic_size;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &dynamic_size));
// Input of strided_slice_op has to have the same shape as output.
DCHECK_EQ(grad_shape.rank(), input_shape.dims());
for (int64 dim = 0; dim < input_shape.dims(); ++dim) {
DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim));
if (dynamic_size[dim] == -1) {
// Input is a dynamic dimension, set the same dynamic dimension size in
// the output.
auto dim_size = xla::Slice(dynamic_shape, {dim}, {dim + 1}, {1});
grad = xla::SetDimensionSize(grad, dim_size, dim);
} else if (grad_shape.is_dynamic_dimension(dim)) {
// Input is static but output is dynamic, respect input and remove any
// dynamic dim in the output.
grad = xla::RemoveDynamicDimension(grad, dim);
}
}
ctx->SetOutput(0, grad);
}

View File

@ -2727,6 +2727,25 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
});
}
XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    // The result has the same shape as the operand, except that `dimension`
    // is marked static again.
    Shape result_shape = *operand_shape;
    result_shape.set_dynamic_dimension(dimension, false);

    // A kSetDimensionSize whose size operand equals the dimension's static
    // bound removes the dynamic dimension from the shape.
    XlaOp full_size =
        ConstantR0<int32>(this, operand_shape->dimensions(dimension));

    HloInstructionProto instr;
    *instr.mutable_shape() = result_shape.ToProto();
    instr.add_dimensions(dimension);
    return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
                          {operand, full_size});
  });
}
XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@ -3827,4 +3846,8 @@ XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) {
return operand.builder()->SetDimensionSize(operand, val, dimension);
}
XlaOp RemoveDynamicDimension(const XlaOp operand, int64 dimension) {
  // Free-function form: forward to the builder that owns this op.
  auto* builder = operand.builder();
  return builder->RemoveDynamicDimension(operand, dimension);
}
}  // namespace xla

View File

@ -704,6 +704,8 @@ class XlaBuilder {
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
                               absl::Span<const XlaOp> operands = {});
@ -1151,6 +1153,7 @@ class XlaBuilder {
friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
friend XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
protected:
// Returns OK status if the given op was built using this builder. Otherwise,
@ -2149,6 +2152,9 @@ XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
// Returns the same op but with dynamic dimension removed.
XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
// Implementation details below this point.
//

View File

@ -556,6 +556,32 @@ TEST_F(XlaBuilderTest, DynamicParameter) {
EXPECT_TRUE(param_shape.is_dynamic_dimension(0));
}
TEST_F(XlaBuilderTest, SetDimensionSize) {
  XlaBuilder b(TestName());
  auto operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
  auto dynamic_size = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  auto root = SetDimensionSize(operand, dynamic_size, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/root));
  // The root instruction's shape must carry the dynamic-dimension marker.
  EXPECT_TRUE(module->entry_computation()
                  ->root_instruction()
                  ->shape()
                  .is_dynamic_dimension(0));
}
// Renamed from RemoveDimensionSize: the test name should match the API it
// exercises (RemoveDynamicDimension). Making a dimension dynamic with
// SetDimensionSize and then calling RemoveDynamicDimension on it must yield
// a fully static root shape.
TEST_F(XlaBuilderTest, RemoveDynamicDimension) {
  XlaBuilder b(TestName());
  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  // Make dimension 0 of p0 dynamic, then undo it.
  auto set_dim_size = SetDimensionSize(p0, p1, 0);
  auto remove_dim_size = RemoveDynamicDimension(set_dim_size, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          BuildHloModule(&b, /*root=*/remove_dim_size));
  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  // Dynamic dimension has been removed.
  EXPECT_FALSE(root_shape.is_dynamic_dimension(0));
}
TEST_F(XlaBuilderTest, DynamicUnary) {
XlaBuilder b(TestName());
Shape tuple_param_shape = ShapeUtil::MakeTupleShape(