diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 2684c982600..784b790767c 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -350,6 +350,30 @@ class StridedSliceGradOp : public XlaOpKernel {
       grad = xla::Rev(grad, dimensions_to_reverse);
     }
     grad = xla::Pad(grad, zero, padding_config);
+
+    xla::XlaOp dynamic_shape = ctx->Input(0);
+    xla::Shape grad_shape = ctx->builder()->GetShape(grad).ValueOrDie();
+    ctx->set_dynamic_dimension_is_minus_one(true);
+    std::vector<int64> dynamic_size;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &dynamic_size));
+    // The input of strided_slice_op has to have the same shape as the output.
+    DCHECK_EQ(grad_shape.rank(), input_shape.dims());
+    for (int64 dim = 0; dim < input_shape.dims(); ++dim) {
+      DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim));
+      if (dynamic_size[dim] == -1) {
+        // Input has a dynamic dimension; set the same dynamic dimension size
+        // in the output.
+        auto dim_size = xla::Slice(dynamic_shape, {dim}, {dim + 1}, {1});
+        auto dim_size_scalar =
+            xla::Reshape(xla::ShapeUtil::MakeScalarShape(xla::S32), dim_size);
+        grad = xla::SetDimensionSize(grad, dim_size_scalar, dim);
+      } else if (grad_shape.is_dynamic_dimension(dim)) {
+        // Input is static but output is dynamic; respect the input and remove
+        // any dynamic dimension in the output.
+        grad = xla::RemoveDynamicDimension(grad, dim);
+      }
+    }
+
     ctx->SetOutput(0, grad);
   }
 
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index bfba48862f6..c7b6a7f9491 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -2727,6 +2727,25 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
   });
 }
 
+XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64 dimension) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
+
+    Shape shape = *operand_shape;
+    shape.set_dynamic_dimension(dimension, false);
+    // Setting an op's dynamic dimension to its static size removes the
+    // dynamic dimension.
+    XlaOp static_size =
+        ConstantR0<int32>(this, operand_shape->dimensions(dimension));
+
+    *instr.mutable_shape() = shape.ToProto();
+    instr.add_dimensions(dimension);
+    return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
+                          {operand, static_size});
+  });
+}
+
 XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
@@ -3827,4 +3846,8 @@ XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) {
   return operand.builder()->SetDimensionSize(operand, val, dimension);
 }
 
+XlaOp RemoveDynamicDimension(const XlaOp operand, int64 dimension) {
+  return operand.builder()->RemoveDynamicDimension(operand, dimension);
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index ffa6a7c3439..b8af180b83e 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -704,6 +704,8 @@ class XlaBuilder {
 
   XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
 
+  XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
+
   StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
                                  absl::Span<const XlaOp> operands = {});
 
@@ -1151,6 +1153,7 @@ class XlaBuilder {
 
   friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
+  friend XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
 
  protected:
   // Returns OK status if the given op was built using this builder. Otherwise,
@@ -2149,6 +2152,9 @@
 XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
 
 XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
 
+// Returns the same op but with the dynamic dimension removed.
+XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
+
 // Implementation details below this point.
 //
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 4fa47077fca..7011c946203 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -556,6 +556,32 @@ TEST_F(XlaBuilderTest, DynamicParameter) {
   EXPECT_TRUE(param_shape.is_dynamic_dimension(0));
 }
 
+TEST_F(XlaBuilderTest, SetDimensionSize) {
+  XlaBuilder b(TestName());
+  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
+  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
+  auto set_dim_size = SetDimensionSize(p0, p1, 0);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          BuildHloModule(&b, /*root=*/set_dim_size));
+  const Shape& root_shape =
+      module->entry_computation()->root_instruction()->shape();
+  EXPECT_TRUE(root_shape.is_dynamic_dimension(0));
+}
+
+TEST_F(XlaBuilderTest, RemoveDimensionSize) {
+  XlaBuilder b(TestName());
+  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
+  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
+  auto set_dim_size = SetDimensionSize(p0, p1, 0);
+  auto remove_dim_size = RemoveDynamicDimension(set_dim_size, 0);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          BuildHloModule(&b, /*root=*/remove_dim_size));
+  const Shape& root_shape =
+      module->entry_computation()->root_instruction()->shape();
+  // The dynamic dimension has been removed.
+  EXPECT_FALSE(root_shape.is_dynamic_dimension(0));
+}
+
 TEST_F(XlaBuilderTest, DynamicUnary) {
   XlaBuilder b(TestName());
   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
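
For reviewers, a minimal usage sketch of the new builder API, assuming only the declarations added in xla_builder.h above; the helper name BuildStaticRoot is hypothetical and not part of this change. It mirrors the RemoveDimensionSize test: dimension 0 is first made dynamic with SetDimensionSize, then pinned back to its static bound with RemoveDynamicDimension, so consumers of the root see a fully static shape.

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Hypothetical helper, for illustration only: builds a computation whose
    // root is static in dimension 0 even though an intermediate value was
    // dynamic there.
    xla::XlaOp BuildStaticRoot(xla::XlaBuilder* b) {
      // f32[10] operand and a scalar s32 runtime size.
      xla::XlaOp data =
          xla::Parameter(b, 0, xla::ShapeUtil::MakeShape(xla::F32, {10}), "data");
      xla::XlaOp size =
          xla::Parameter(b, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "size");
      // Mark dimension 0 of `data` as dynamic with runtime size `size`.
      xla::XlaOp dynamic = xla::SetDimensionSize(data, size, 0);
      // Pin dimension 0 back to its static bound (10). Per the builder change
      // above, this lowers to kSetDimensionSize with a constant operand.
      return xla::RemoveDynamicDimension(dynamic, 0);
    }

This is also the per-dimension pattern StridedSliceGradOp follows in the first hunk: SetDimensionSize when the input dimension is dynamic, RemoveDynamicDimension when the input is static but the padded gradient picked up a dynamic dimension.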