Remove dynamic dimension of strided slice grad if input to strided slice is static.
If we slice a dynamic shaped tensor from a static tensor, the output of the gradient should still be static. Unfortunately this cannot be deduced alone by xla, so extra information is needed from the tf2xla bridge. PiperOrigin-RevId: 317058146 Change-Id: I33e4895e169c238ad3d73a57ada11c4984d11dfb
This commit is contained in:
parent
e4af590df8
commit
ad1434444b
@ -350,28 +350,6 @@ class StridedSliceGradOp : public XlaOpKernel {
|
|||||||
grad = xla::Rev(grad, dimensions_to_reverse);
|
grad = xla::Rev(grad, dimensions_to_reverse);
|
||||||
}
|
}
|
||||||
grad = xla::Pad(grad, zero, padding_config);
|
grad = xla::Pad(grad, zero, padding_config);
|
||||||
|
|
||||||
xla::XlaOp dynamic_shape = ctx->Input(0);
|
|
||||||
xla::Shape grad_shape = ctx->builder()->GetShape(grad).ValueOrDie();
|
|
||||||
ctx->set_dynamic_dimension_is_minus_one(true);
|
|
||||||
std::vector<int64> dynamic_size;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &dynamic_size));
|
|
||||||
// Input of strided_slice_op has to have the same shape as output.
|
|
||||||
DCHECK_EQ(grad_shape.rank(), input_shape.dims());
|
|
||||||
for (int64 dim = 0; dim < input_shape.dims(); ++dim) {
|
|
||||||
DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim));
|
|
||||||
if (dynamic_size[dim] == -1) {
|
|
||||||
// Input is a dynamic dimension, set the same dynamic dimension size in
|
|
||||||
// the output.
|
|
||||||
auto dim_size = xla::Slice(dynamic_shape, {dim}, {dim + 1}, {1});
|
|
||||||
grad = xla::SetDimensionSize(grad, dim_size, dim);
|
|
||||||
} else if (grad_shape.is_dynamic_dimension(dim)) {
|
|
||||||
// Input is static but output is dynamic, respect input and remove any
|
|
||||||
// dynamic dim in the output.
|
|
||||||
grad = xla::RemoveDynamicDimension(grad, dim);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx->SetOutput(0, grad);
|
ctx->SetOutput(0, grad);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2727,25 +2727,6 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64 dimension) {
|
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
|
||||||
HloInstructionProto instr;
|
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
|
||||||
|
|
||||||
Shape shape = *operand_shape;
|
|
||||||
shape.set_dynamic_dimension(dimension, false);
|
|
||||||
// Setting an op's dynamic dimension to its static size removes the dynamic
|
|
||||||
// dimension.
|
|
||||||
XlaOp static_size =
|
|
||||||
ConstantR0<int32>(this, operand_shape->dimensions(dimension));
|
|
||||||
|
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
|
||||||
instr.add_dimensions(dimension);
|
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize,
|
|
||||||
{operand, static_size});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
|
XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
@ -3846,8 +3827,4 @@ XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) {
|
|||||||
return operand.builder()->SetDimensionSize(operand, val, dimension);
|
return operand.builder()->SetDimensionSize(operand, val, dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp RemoveDynamicDimension(const XlaOp operand, int64 dimension) {
|
|
||||||
return operand.builder()->RemoveDynamicDimension(operand, dimension);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -704,8 +704,6 @@ class XlaBuilder {
|
|||||||
|
|
||||||
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
|
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
|
||||||
|
|
||||||
XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
|
|
||||||
|
|
||||||
StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
|
StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
|
||||||
absl::Span<const XlaOp> operands = {});
|
absl::Span<const XlaOp> operands = {});
|
||||||
|
|
||||||
@ -1153,7 +1151,6 @@ class XlaBuilder {
|
|||||||
|
|
||||||
friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
|
friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
|
||||||
friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
|
friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
|
||||||
friend XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Returns OK status if the given op was built using this builder. Otherwise,
|
// Returns OK status if the given op was built using this builder. Otherwise,
|
||||||
@ -2152,9 +2149,6 @@ XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
|
|||||||
|
|
||||||
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
|
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
|
||||||
|
|
||||||
// Returns the same op but with dynamic dimension removed.
|
|
||||||
XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
|
|
||||||
|
|
||||||
// Implementation details below this point.
|
// Implementation details below this point.
|
||||||
//
|
//
|
||||||
|
|
||||||
|
@ -556,32 +556,6 @@ TEST_F(XlaBuilderTest, DynamicParameter) {
|
|||||||
EXPECT_TRUE(param_shape.is_dynamic_dimension(0));
|
EXPECT_TRUE(param_shape.is_dynamic_dimension(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(XlaBuilderTest, SetDimensionSize) {
|
|
||||||
XlaBuilder b(TestName());
|
|
||||||
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
|
|
||||||
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
|
|
||||||
auto set_dim_size = SetDimensionSize(p0, p1, 0);
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
|
||||||
BuildHloModule(&b, /*root=*/set_dim_size));
|
|
||||||
const Shape& root_shape =
|
|
||||||
module->entry_computation()->root_instruction()->shape();
|
|
||||||
EXPECT_TRUE(root_shape.is_dynamic_dimension(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(XlaBuilderTest, RemoveDimensionSize) {
|
|
||||||
XlaBuilder b(TestName());
|
|
||||||
auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
|
|
||||||
auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
|
|
||||||
auto set_dim_size = SetDimensionSize(p0, p1, 0);
|
|
||||||
auto remove_dim_size = RemoveDynamicDimension(set_dim_size, 0);
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
|
||||||
BuildHloModule(&b, /*root=*/remove_dim_size));
|
|
||||||
const Shape& root_shape =
|
|
||||||
module->entry_computation()->root_instruction()->shape();
|
|
||||||
// Dynamic dimension has been removed.
|
|
||||||
EXPECT_FALSE(root_shape.is_dynamic_dimension(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(XlaBuilderTest, DynamicUnary) {
|
TEST_F(XlaBuilderTest, DynamicUnary) {
|
||||||
XlaBuilder b(TestName());
|
XlaBuilder b(TestName());
|
||||||
Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
|
Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
|
||||||
|
Loading…
Reference in New Issue
Block a user