From ad1434444bda1e8321dcc09965a5dce8da847eed Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 18 Jun 2020 01:51:06 -0700
Subject: [PATCH] Remove dynamic dimension of strided slice grad if input to
 strided slice is static.

If we slice a dynamically shaped tensor from a statically shaped tensor, the
output of the gradient should still be static. Unfortunately this cannot be
deduced by XLA alone, so extra information is needed from the tf2xla bridge.
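
For illustration, a minimal sketch of the dynamic-dimension round trip
involved (it mirrors the removed XlaBuilderTest cases; the builder setup and
variable names here are illustrative, not part of this change, and it assumes
the pre-removal RemoveDynamicDimension API):

    xla::XlaBuilder b("example");
    // p0 is a static f32[10] operand; p1 is a scalar s32 runtime size.
    auto p0 = xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {10}), "p0");
    auto p1 = xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "p1");
    // SetDimensionSize marks dimension 0 as dynamic: the shape becomes f32[<=10].
    auto dynamic = xla::SetDimensionSize(p0, p1, 0);
    // RemoveDynamicDimension strips that annotation again, restoring a static
    // f32[10]; StridedSliceGradOp relied on this when its input was static.
    auto back_to_static = xla::RemoveDynamicDimension(dynamic, 0);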

PiperOrigin-RevId: 317058146
Change-Id: I33e4895e169c238ad3d73a57ada11c4984d11dfb
---
 .../tf2xla/kernels/strided_slice_op.cc        | 22 ----------------
 tensorflow/compiler/xla/client/xla_builder.cc | 23 ----------------
 tensorflow/compiler/xla/client/xla_builder.h  |  6 -----
 .../compiler/xla/client/xla_builder_test.cc   | 26 -------------------
 4 files changed, 77 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 51764018df1..2684c982600 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -350,28 +350,6 @@ class StridedSliceGradOp : public XlaOpKernel {
       grad = xla::Rev(grad, dimensions_to_reverse);
     }
     grad = xla::Pad(grad, zero, padding_config);
-
-    xla::XlaOp dynamic_shape = ctx->Input(0);
-    xla::Shape grad_shape = ctx->builder()->GetShape(grad).ValueOrDie();
-    ctx->set_dynamic_dimension_is_minus_one(true);
-    std::vector<int64> dynamic_size;
-    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &dynamic_size));
-    // Input of strided_slice_op has to have the same shape as output.
-    DCHECK_EQ(grad_shape.rank(), input_shape.dims());
-    for (int64 dim = 0; dim < input_shape.dims(); ++dim) {
-      DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim));
-      if (dynamic_size[dim] == -1) {
-        // Input is a dynamic dimension, set the same dynamic dimension size in
-        // the output.
-        auto dim_size = xla::Slice(dynamic_shape, {dim}, {dim + 1}, {1});
-        grad = xla::SetDimensionSize(grad, dim_size, dim);
-      } else if (grad_shape.is_dynamic_dimension(dim)) {
-        // Input is static but output is dynamic, respect input and remove any
-        // dynamic dim in the output.
-        grad = xla::RemoveDynamicDimension(grad, dim);
-      }
-    }
-
     ctx->SetOutput(0, grad);
   }
 
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index c7b6a7f9491..bfba48862f6 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -2727,25 +2727,6 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
   });
 }
 
-XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64 dimension) {
-  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
-
-    Shape shape = *operand_shape;
-    shape.set_dynamic_dimension(dimension, false);
-    // Setting an op's dynamic dimension to its static size removes the dynamic
-    // dimension.
-    XlaOp static_size =
-        ConstantR0<int32>(this, operand_shape->dimensions(dimension));
-
-    *instr.mutable_shape() = shape.ToProto();
-    instr.add_dimensions(dimension);
-    return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
-                          {operand, static_size});
-  });
-}
-
 XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
@@ -3846,8 +3827,4 @@ XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) {
   return operand.builder()->SetDimensionSize(operand, val, dimension);
 }
 
-XlaOp RemoveDynamicDimension(const XlaOp operand, int64 dimension) {
-  return operand.builder()->RemoveDynamicDimension(operand, dimension);
-}
-
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index b8af180b83e..ffa6a7c3439 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -704,8 +704,6 @@ class XlaBuilder {
 
   XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
 
-  XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
-
   StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
                                  absl::Span<const XlaOp> operands = {});
@@ -1153,7 +1151,6 @@ class XlaBuilder {
   friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
-  friend XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
 
  protected:
   // Returns OK status if the given op was built using this builder. Otherwise,
@@ -2152,9 +2149,6 @@ XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
 
 XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
 
-// Returns the same op but with dynamic dimension removed.
-XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
-
 // Implementation details below this point.
 //
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 7011c946203..4fa47077fca 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -556,32 +556,6 @@ TEST_F(XlaBuilderTest, DynamicParameter) {
   EXPECT_TRUE(param_shape.is_dynamic_dimension(0));
 }
 
-TEST_F(XlaBuilderTest, SetDimensionSize) {
-  XlaBuilder b(TestName());
-  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
-  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
-  auto set_dim_size = SetDimensionSize(p0, p1, 0);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          BuildHloModule(&b, /*root=*/set_dim_size));
-  const Shape& root_shape =
-      module->entry_computation()->root_instruction()->shape();
-  EXPECT_TRUE(root_shape.is_dynamic_dimension(0));
-}
-
-TEST_F(XlaBuilderTest, RemoveDimensionSize) {
-  XlaBuilder b(TestName());
-  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
-  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
-  auto set_dim_size = SetDimensionSize(p0, p1, 0);
-  auto remove_dim_size = RemoveDynamicDimension(set_dim_size, 0);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          BuildHloModule(&b, /*root=*/remove_dim_size));
-  const Shape& root_shape =
-      module->entry_computation()->root_instruction()->shape();
-  // Dynamic dimension has been removed.
-  EXPECT_FALSE(root_shape.is_dynamic_dimension(0));
-}
-
 TEST_F(XlaBuilderTest, DynamicUnary) {
   XlaBuilder b(TestName());
   Shape tuple_param_shape = ShapeUtil::MakeTupleShape(