Introduce dynamic reshape op.

- Exposes a dynamic reshape op through the XLA builder.
- As in TF and other frameworks, users can now specify the output dimension sizes as operands when building the op (see the builder sketch below).
- Dynamic reshape is rewritten to a static reshape by the dynamic padder.
- This removes a lot of the complexity in handling dynamic reshapes in XLA today, where we are forced to derive the dynamic output sizes ourselves.
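A minimal usage sketch of the new builder API (editorial illustration; the parameter shapes, variable names, and the Div-based size computation are assumptions, not part of this change). It reshapes an f32[<=10] operand into f32[2, <=5], keeping the second output dimension dynamic:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaBuilder b("dynamic_reshape_example");
// Data operand f32[10] plus a scalar S32 size that makes dimension 0 dynamic.
xla::XlaOp input = xla::Parameter(
    &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {10}), "input");
xla::XlaOp size = xla::Parameter(
    &b, 1, xla::ShapeUtil::MakeScalarShape(xla::S32), "size");
xla::XlaOp dyn_input = xla::SetDimensionSize(input, size, 0);  // f32[<=10]
// Output dimension sizes: a static 2 and the dynamic size divided by 2.
xla::XlaOp two = xla::ConstantR0<int32_t>(&b, 2);
xla::XlaOp cols = xla::Div(size, two);
// Result bounds are [2, 5]; only dimension 1 is marked dynamic.
xla::XlaOp reshaped = xla::DynamicReshape(
    dyn_input, /*dim_sizes=*/{two, cols},
    /*new_size_bounds=*/{2, 5}, /*dims_are_dynamic=*/{false, true});

The dynamic padder then rewrites this dynamic-reshape into a static reshape and uses the dim-size operands for dynamic dimension inference, as shown in the diffs below.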

PiperOrigin-RevId: 326744137
Change-Id: I57b5b40abab2972e0e1e3df1577bf89146ebd7cc
Yunxing Dai 2020-08-14 15:43:57 -07:00 committed by TensorFlower Gardener
parent 7b56de0366
commit 73b40908a4
28 changed files with 457 additions and 32 deletions


@ -19,8 +19,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@ -108,38 +110,73 @@ class ReshapeOp : public XlaOpKernel {
VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
<< shape.DebugString() << ", unknown_index=" << unknown_index;
auto input_xla_shape = ctx->InputXlaShape(0);
if (input_xla_shape->is_static()) {
ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes()));
return;
}
// Handle dynamic reshape when the input contains a dynamic dimension.
std::vector<xla::XlaOp> output_dim_sizes;
std::vector<bool> dims_are_dynamic;
for (int64 i = 0; i < shape.dims(); ++i) {
output_dim_sizes.push_back(
xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {}));
}
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic));
if (unknown_index == -1) {
// No unknown index.
ctx->SetOutput(0,
xla::DynamicReshape(ctx->Input(0), output_dim_sizes,
shape.dim_sizes(), dims_are_dynamic));
return;
}
auto common_factors =
xla::CommonFactors(input_shape.dim_sizes(), shape.dim_sizes());
int dynamic_dimension = -1;
if (ctx->InputXlaShape(0)->is_dynamic()) {
std::vector<bool> dynamic_dims;
OP_REQUIRES_OK(ctx,
ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
for (int d = 0; d < num_dims; ++d) {
const bool dim_is_dynamic = dynamic_dims[d];
if (dim_is_dynamic) {
dynamic_dimension = d;
// Walk the common-factor groups to find the one containing the unknown (-1)
// output dimension.
for (int64 i = 0; i < common_factors.size() - 1; ++i) {
auto start = common_factors[i];
auto end = common_factors[i + 1];
bool input_is_dynamic = false;
// Product of all input dims in this group. E.g., in
// reshape(Tensor([2, 3, 3]), [3, -1, 3]) the product of the group
// containing -1 is 6.
xla::XlaOp product = xla::One(ctx->builder(), xla::S32);
for (int64 dim = start.first; dim < end.first; ++dim) {
if (input_xla_shape->is_dynamic_dimension(dim)) {
input_is_dynamic = true;
}
product = xla::Mul(product, xla::GetDimensionSize(ctx->Input(0), dim));
}
bool unknown_dim_in_group = false;
// The real size for the -1 dimension in a reshape. E.g., in
// reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2.
xla::XlaOp unknown_dim_size = product;
for (int64 dim = start.second; dim < end.second; ++dim) {
if (dim == unknown_index) {
unknown_dim_in_group = true;
} else {
unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]);
}
}
// When reshaping from a dynamic dimension, the unknown index is
// considered dynamic. E.g.,
// [<=10]
// |
// Reshape
// |
// [2, -1]
// The second dimension is dynamic.
if (dynamic_dimension == -1) {
dynamic_dimension = unknown_index;
if (unknown_dim_in_group) {
// If input dim is dynamic, output dim at the -1 position must be
// dynamic. Similarly, if input dim is static, output dim has to be
// static at the -1 dimension.
dims_are_dynamic[unknown_index] = input_is_dynamic;
output_dim_sizes[unknown_index] = unknown_dim_size;
ctx->SetOutput(
0, xla::DynamicReshape(ctx->Input(0), output_dim_sizes,
shape.dim_sizes(), dims_are_dynamic));
VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString()
<< " to " << xla::VectorString(shape.dim_sizes())
<< ", dynamic_dims=" << xla::VectorString(dims_are_dynamic);
return;
}
VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to "
<< xla::VectorString(shape.dim_sizes())
<< ", dynamic_dim=" << dynamic_dimension;
}
// Pass unknown_index to xla::Reshape as a hint so that dynamic shape
// inference in XLA knows which output dimension is dynamic.
ctx->SetOutput(0, xla::ReshapeWithInferredDimension(
ctx->Input(0), shape.dim_sizes(), dynamic_dimension));
}
};


@ -1083,6 +1083,36 @@ XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
});
}
XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
absl::Span<const XlaOp> dim_sizes,
absl::Span<const int64> new_size_bounds,
const std::vector<bool>& dims_are_dynamic) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
std::vector<const Shape*> dim_size_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
GetOperandShapes(dim_sizes));
absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(const Shape shape,
ShapeInference::InferDynamicReshapeShape(
*operand_shape, dim_size_shape_ptrs,
new_size_bounds, dims_are_dynamic));
TF_RETURN_IF_ERROR(first_error_);
std::vector<XlaOp> operands;
operands.reserve(1 + dim_sizes.size());
operands.push_back(operand);
for (const XlaOp& dim_size : dim_sizes) {
operands.push_back(dim_size);
}
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape,
operands);
});
}
XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (dimensions.size() <= 1) {
@ -3466,6 +3496,13 @@ XlaOp Reshape(const Shape& shape, XlaOp operand) {
return operand.builder()->Reshape(shape, operand);
}
XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
absl::Span<const int64> new_size_bounds,
const std::vector<bool>& dims_are_dynamic) {
return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds,
dims_are_dynamic);
}
XlaOp ReshapeWithInferredDimension(XlaOp operand,
absl::Span<const int64> new_sizes,
int64 inferred_dimension) {


@ -454,6 +454,10 @@ class XlaBuilder {
XlaOp Reshape(const Shape& shape, XlaOp operand,
int64 inferred_dimension = -1);
XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
absl::Span<const int64> new_size_bounds,
const std::vector<bool>& dims_are_dynamic);
XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
@ -940,6 +944,10 @@ class XlaBuilder {
friend XlaOp Reshape(const Shape& shape, XlaOp operand);
friend XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
absl::Span<const int64> new_size_bounds,
const std::vector<bool>& dims_are_dynamic);
friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
absl::Span<const int64> new_sizes,
int64 inferred_dimension);
@ -1453,9 +1461,16 @@ XlaOp Pad(XlaOp operand, XlaOp padding_value,
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes);
// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
// Enqueues a dynamic reshape operation. The dynamic reshape takes additional
// XlaOps that specify the sizes of the result dimensions. Result dimension i
// is dynamic if dims_are_dynamic[i] is true.
XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
absl::Span<const int64> new_size_bounds,
const std::vector<bool>& dims_are_dynamic);
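// A hypothetical illustration (editorial sketch, not part of the original
// comment): reshaping an operand of shape f32[<=6] into f32[2, <=3], where
// `two` and `half` are previously built S32 scalar XlaOps:
//
//   XlaOp out = DynamicReshape(operand, /*dim_sizes=*/{two, half},
//                              /*new_size_bounds=*/{2, 3},
//                              /*dims_are_dynamic=*/{false, true});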
// Enqueues an operation onto the computation that collapses the operand,
// from first to last dimension (C order), then reshapes it to the given
// dimension sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
// Enqueues a Reshape op that uses an explicit target shape.


@ -245,6 +245,7 @@ class DfsHloVisitorBase {
virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0;
virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
virtual Status HandleFusion(HloInstructionPtr hlo) = 0;


@ -198,6 +198,9 @@ class DfsHloVisitorWithDefaultBase
Status HandlePad(HloInstructionPtr pad) override {
return DefaultAction(pad);
}
Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override {
return DefaultAction(dynamic_reshape);
}
Status HandleReshape(HloInstructionPtr reshape) override {
return DefaultAction(reshape);
}


@ -97,6 +97,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
Status HandleTranspose(HloInstruction* hlo) override;
Status HandleDynamicReshape(HloInstruction* hlo) override;
Status HandleReshape(HloInstruction* hlo) override;
Status HandleSort(HloInstruction* hlo) override;
@ -621,6 +623,18 @@ Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
return PassThroughDynamicDimension(hlo);
}
Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
HloInstruction* hlo) {
HloDynamicReshapeInstruction* dynamic_reshape =
Cast<HloDynamicReshapeInstruction>(hlo);
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
if (hlo->shape().is_dynamic_dimension(i)) {
parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
}
}
return Status::OK();
}
Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo,


@ -1248,5 +1248,34 @@ TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) {
EXPECT_TRUE(handler_called);
}
TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {9}), "data_input"));
auto six = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(6)));
// Creates an input of shape [<=9], dynamic size is 6.
auto dynamic_input =
builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
ShapeUtil::MakeShape(F32, {9}, {true}), input, six, 0));
auto dynamic_size = builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(S32, {}), "size_param"));
auto three = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));
// Reshape [<=9] into [3, <=3]
auto dynamic_reshape =
builder.AddInstruction(HloInstruction::CreateDynamicReshape(
ShapeUtil::MakeShape(F32, {3, 3}, {false, true}), dynamic_input,
{three, dynamic_size}));
module_->AddEntryComputation(builder.Build());
TF_ASSERT_OK(RunInference());
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), nullptr);
EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), dynamic_size);
}
} // namespace
} // namespace xla


@ -1290,6 +1290,18 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
continue;
}
if (inst->opcode() == HloOpcode::kDynamicReshape) {
TF_ASSIGN_OR_RETURN(
changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
auto* static_reshape =
computation->AddInstruction(HloInstruction::CreateReshape(
inst->shape(), inst->mutable_operand(0)));
TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(static_reshape));
TF_RETURN_IF_ERROR(dynamic_dimension_inference.ForwardDynamicSize(
inst, static_reshape, {}));
continue;
}
for (int64 operand_num = 0; operand_num < inst->operand_count();
++operand_num) {
HloInstruction* original_operand = inst->mutable_operand(operand_num);


@ -379,6 +379,13 @@ class ExecutionTest : public HloTestBase {
Literal PadAndExecute(std::unique_ptr<HloModule> module,
absl::Span<Literal* const> arguments,
bool slice_dynamic_output = true) {
if (!slice_dynamic_output) {
auto new_config = module->config();
new_config.mutable_entry_computation_layout()
->mutable_result_layout()
->ClearDynamicShape();
module->set_config(new_config);
}
DynamicPadder padder(slice_dynamic_output);
TF_CHECK_OK(padder.Run(module.get()).status());
HloDCE dce;
@ -1176,6 +1183,84 @@ ENTRY main {
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicReshapeDoubleDynamicDimensions) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1
ENTRY main {
param = s32[2, 3, 3] parameter(0)
size = s32[] constant(2)
param_padded_partial = s32[2, <=3, 3] set-dimension-size(param, size),
dimensions={1}
param_padded = s32[2, <=3, <=3] set-dimension-size(param_padded_partial, size),
dimensions={2}
result_size = s32[] constant(8)
ROOT reshaped = s32[<=18] dynamic-reshape(param_padded, result_size)
}
)";
// Dimensions 1 and 2 are dynamic, each with size 2, so the flattened result
// holds 2 * 2 * 2 = 8 elements.
Literal operand = LiteralUtil::CreateR3<int32>(
{{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}});
auto module = GetHloModule(hlo_text);
Literal result = PadAndExecute(std::move(module), {&operand}, false);
result.SetDynamicSize(0, 8);
// Padded data looks like this (P is padding which is ignored).
// [[0, 1, P]
// [3, 4, P]
// [P, P, P]]
//
// [[0, 1, P]
// [3, 4, P]
// [P, P, P]]
//
// Reshaping (with correct reshape rewriting) produces:
// [0, 1, 3, 4, 0, 1, 3, 4]
Literal expected = LiteralUtil::CreateR1<int32>({0, 1, 3, 4, 0, 1, 3, 4});
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, DynamicReshapeOutputDoubleDynamicDimensions) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1
ENTRY main {
param = s32[18] parameter(0)
eight = s32[] constant(8)
param_dynamic = s32[<=18] set-dimension-size(param, eight), dimensions={0}
two = s32[] constant(2)
// Output dimensions 1 and 2 are dynamic, each with size two.
ROOT reshaped = s32[2, <=3, <=3] dynamic-reshape(param_dynamic, two, two, two)
}
)";
Literal operand = LiteralUtil::CreateR1<int32>(
{0, 1, 3, 4, 0, 1, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1});
auto module = GetHloModule(hlo_text);
Literal result = PadAndExecute(std::move(module), {&operand}, false);
result.SetDynamicSize(1, 2);
result.SetDynamicSize(2, 2);
// Padded operand is:
// [0, 1, 3, 4, 0, 1, 3, 4, P, P ....]
//
// Reshaping it should produce:
// [[0, 1, P]
// [3, 4, P]
// [P, P, P]]
//
// [[0, 1, P]
// [3, 4, P]
// [P, P, P]]
Literal expected =
LiteralUtil::CreateR3<int32>({{{0, 1}, {3, 4}}, {{0, 1}, {3, 4}}});
EXPECT_EQ(result, expected);
}
XLA_TEST_F(ExecutionTest, SetGetDimensionSize) {
const string hlo_text = R"(
HloModule TensorFlowScatterV1


@ -486,6 +486,10 @@ Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) {
return Status::OK();
}
Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
// TODO(b/62294698): Implement cost analysis for batch-norm-training.
return Status::OK();


@ -113,6 +113,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleBroadcast(const HloInstruction* broadcast) override;
Status HandlePad(const HloInstruction* pad) override;
Status HandleReshape(const HloInstruction* reshape) override;
Status HandleDynamicReshape(const HloInstruction* reshape) override;
Status HandleAddDependency(const HloInstruction* add_dependency) override;
Status HandleAfterAll(const HloInstruction* token) override;
Status HandleTranspose(const HloInstruction* transpose) override;


@ -1012,6 +1012,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kGather:
case HloOpcode::kPad:
case HloOpcode::kReshape:
case HloOpcode::kDynamicReshape:
case HloOpcode::kReverse:
case HloOpcode::kTupleSelect:
case HloOpcode::kTranspose:


@ -700,6 +700,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction = CreateReshape(shape, operands(0), inferred_dimension);
break;
}
case HloOpcode::kDynamicReshape: {
TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
ShapeUtil::ElementsIn(shape) ==
ShapeUtil::ElementsIn(operands(0)->shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(operands(0)->shape());
const auto& operand_vector = all_operands();
instruction = CreateDynamicReshape(
shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
break;
}
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
for (const int64 operand_id : proto.operand_ids()) {
@ -1373,6 +1384,19 @@ HloInstruction::CreateBroadcastSequence(
inferred_dimension);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicReshape(
const Shape& shape, HloInstruction* data_operand,
absl::Span<HloInstruction* const> dim_sizes) {
CHECK_EQ(ShapeUtil::ElementsIn(shape),
ShapeUtil::ElementsIn(data_operand[0].shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
CHECK_EQ(shape.rank(), dim_sizes.size());
return absl::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
dim_sizes);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
const Shape& shape, HloInstruction* operand,
absl::Span<const int64> dimensions) {
@ -1569,6 +1593,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kTranspose:
case HloOpcode::kBroadcast:
case HloOpcode::kReshape:
case HloOpcode::kDynamicReshape:
case HloOpcode::kMap:
case HloOpcode::kSlice:
case HloOpcode::kConstant:
@ -2007,6 +2032,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kReshape:
case HloOpcode::kDynamicReshape:
case HloOpcode::kReplicaId:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRsqrt:
@ -2812,7 +2838,8 @@ HloInstructionProto HloInstruction::ToProto() const {
string HloInstruction::ToCategory() const {
if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
opcode() == HloOpcode::kReshape) {
opcode() == HloOpcode::kReshape ||
opcode() == HloOpcode::kDynamicReshape) {
return "data formatting";
}
@ -3033,6 +3060,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandlePad(this);
case HloOpcode::kReshape:
return visitor->HandleReshape(this);
case HloOpcode::kDynamicReshape:
return visitor->HandleDynamicReshape(this);
case HloOpcode::kTranspose:
return visitor->HandleTranspose(this);
case HloOpcode::kReverse:


@ -879,6 +879,14 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
int64 inferred_dimension = -1);
// Creates a dynamic reshape instruction. Similar to reshape but the dynamic
// dimension sizes are provided as additional variadic operands.
//
// Precondition: dim_sizes.size() == shape.rank()
static std::unique_ptr<HloInstruction> CreateDynamicReshape(
const Shape& shape, HloInstruction* data_operand,
absl::Span<HloInstruction* const> dim_sizes);
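// Hypothetical sketch (editorial illustration, not part of the original
// header): given an HloInstruction* `data` of shape f32[<=6] and S32 scalar
// instructions `s0` and `s1` holding the target sizes,
//
//   auto dyn_reshape = HloInstruction::CreateDynamicReshape(
//       ShapeUtil::MakeShape(F32, {2, 3}, {false, true}), data, {s0, s1});
//
// builds a dynamic-reshape whose output shape is f32[2, <=3].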
// Creates a transpose instruction which permutes the operand dimensions.
static std::unique_ptr<HloInstruction> CreateTranspose(
const Shape& shape, HloInstruction* operand,


@ -1027,6 +1027,16 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl(
dimensions());
}
HloDynamicReshapeInstruction::HloDynamicReshapeInstruction(
const Shape& shape, HloInstruction* data_operand,
absl::Span<HloInstruction* const> dim_sizes)
: HloInstruction(HloOpcode::kDynamicReshape, shape) {
AppendOperand(data_operand);
for (auto operand : dim_sizes) {
AppendOperand(operand);
}
}
HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
HloInstruction* operand,
int64 inferred_dimension)


@ -679,6 +679,21 @@ class HloBroadcastInstruction : public HloInstruction {
std::vector<int64> dimensions_;
};
class HloDynamicReshapeInstruction : public HloInstruction {
public:
explicit HloDynamicReshapeInstruction(
const Shape& shape, HloInstruction* data_operand,
absl::Span<HloInstruction* const> dim_sizes);
// Returns the dim-size operands, i.e. operands()[1:].
absl::Span<HloInstruction* const> dim_sizes() const {
return absl::MakeSpan(operands()).subspan(1, operand_count());
}
// Returns the i-th dim-size operand, i.e. operands()[1 + i].
HloInstruction* dim_sizes(int64 i) const { return operands()[i + 1]; }
};
class HloReshapeInstruction : public HloInstruction {
public:
explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,


@ -123,6 +123,7 @@ namespace xla {
V(kRemainder, "remainder", 2) \
V(kReplicaId, "replica-id", 0) \
V(kReshape, "reshape", 1) \
V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \
V(kReverse, "reverse", 1) \
V(kRng, "rng", kHloOpcodeIsVariadic) \
V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \


@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
case HloOpcode::kCustomCall:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kDynamicReshape:
case HloOpcode::kFusion:
case HloOpcode::kMap:
case HloOpcode::kReduce:


@ -1108,6 +1108,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
builder->AddInstruction(HloInstruction::CreatePartitionId());
break;
}
case HloOpcode::kDynamicReshape: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateDynamicReshape(
shape, operands[0],
absl::Span<HloInstruction* const>(operands).subspan(1)));
break;
}
case HloOpcode::kReshape: {
optional<int64> inferred_dimension;
attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,


@ -703,6 +703,20 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) {
// Check for mixed precision.
const Shape& operand_shape = dynamic_reshape->operand(0)->shape();
TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape));
TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) ==
ShapeUtil::ElementsIn(operand_shape));
TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 ==
dynamic_reshape->operand_count());
for (int64 i = 1; i < dynamic_reshape->operand_count(); ++i) {
TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32);
}
return Status::OK();
}
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
// Check for mixed precision.
const Shape& operand_shape = reshape->operand(0)->shape();


@ -78,6 +78,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleBroadcast(HloInstruction* broadcast) override;
Status HandleReshape(HloInstruction* reshape) override;
Status HandleDynamicReshape(HloInstruction* dynamic_reshape) override;
Status HandleTranspose(HloInstruction* transpose) override;
Status HandleParameter(HloInstruction*) override;
Status HandleFusion(HloInstruction*) override;


@ -102,6 +102,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kReducePrecision:
case HloOpcode::kReplicaId:
case HloOpcode::kReshape:
case HloOpcode::kDynamicReshape:
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSelect:


@ -2278,6 +2278,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kReduce:
case HloOpcode::kReplicaId:
case HloOpcode::kReshape:
case HloOpcode::kDynamicReshape:
case HloOpcode::kRng:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRngGetAndUpdateState:


@ -2825,6 +2825,38 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return output_shape;
}
/* static */ StatusOr<Shape> ShapeInference::InferDynamicReshapeShape(
const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
absl::Span<const int64> new_size_bounds,
const std::vector<bool>& dims_are_dynamic) {
if (new_size_bounds.size() != dims_are_dynamic.size()) {
return InvalidArgument(
"DynamicReshape has to have the same number of elements in new_sizes "
"(%d) and dims_are_dynamic (%d)",
new_size_bounds.size(), dims_are_dynamic.size());
}
for (const Shape* dim_size_shape : dim_size_shapes) {
// Each dim-size operand must be a scalar of type S32.
if (dim_size_shape->element_type() != S32 || dim_size_shape->rank() != 0) {
return InvalidArgument(
"DynamicReshape's dim size has to be scalar S32, got (%s).",
dim_size_shape->ToString());
}
}
Shape inferred_shape = ShapeUtil::MakeShape(
operand.element_type(), new_size_bounds, dims_are_dynamic);
if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
return InvalidArgument(
"Reshape operation has mismatched element counts: from=%d (%s) "
"to=%d (%s).",
ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
ShapeUtil::ElementsIn(inferred_shape),
ShapeUtil::HumanString(inferred_shape));
}
return inferred_shape;
}
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
const Shape& operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes, int64 inferred_dimension) {


@ -241,6 +241,15 @@ class ShapeInference {
absl::Span<const int64> new_sizes,
int64 inferred_dimension);
// Infers the shape produced by a dynamic reshape operation from the element
// type of its operand and the new dimension sizes specified. The result shape
// has dynamic dimensions as specified in `dims_are_dynamic`, with bounds given
// by `new_size_bounds`.
static StatusOr<Shape> InferDynamicReshapeShape(
const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
absl::Span<const int64> new_size_bounds,
const std::vector<bool>& dims_are_dynamic);
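// Worked example (editorial illustration, not part of the original header):
// for an operand of shape f32[<=6], new_size_bounds = {2, 3}, and
// dims_are_dynamic = {false, true}, the inferred shape is f32[2, <=3]. If the
// element counts of the operand and the result bounds differ, inference
// returns an InvalidArgument error.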
// Infers the shape produced by a transpose operation from the element type of
// its operand and its dimensions field.
static StatusOr<Shape> InferTransposeShape(


@ -387,6 +387,7 @@ const HloInstruction* PickRepresentativeOperand(
case HloOpcode::kDot:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kDynamicReshape:
case HloOpcode::kFft:
case HloOpcode::kFusion:
case HloOpcode::kGather:


@ -61,6 +61,10 @@ class ShapeLayout {
// Returns the shape (with layouts).
const Shape& shape() const { return shape_; }
// Clears the dynamic dimensions of this shape, pretending that the module
// produces a static result. Useful for inspecting full outputs when testing.
void ClearDynamicShape() { shape_.clear_dynamic_dimensions(); }
// Checks that a layout is set for the shape, and returns a reference to the
// layout directly on the shape. Shape must not be a tuple.
const Layout& layout() const;


@ -635,8 +635,57 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.multidevice_strategies,
mode=["eager"]
))
mode=["eager"]))
def testReshapeWithDynamicInputs(self, distribution):
def dataset_fn(_):
data = array_ops.zeros((5, 1, 2), dtype=dtypes.int32)
dataset = get_dataset_from_tensor_slices(data)
dataset = dataset.batch(3)
return dataset
input_iterator = iter(
distribution.experimental_distribute_datasets_from_function(dataset_fn))
@def_function.function
def step_fn(example):
# example: [<=3, 1, 2]
# tile: [<=3, <=3, 2]
tile = array_ops.tile(example, [1, array_ops.shape(example)[0], 1])
# reshape1: [<=(3*3 = 9), 2]
reshape1 = array_ops.reshape(tile, [-1, 2])
# reshape2: [<=3, <=3, 2]
reshape2 = array_ops.reshape(
reshape1,
[array_ops.shape(example)[0],
array_ops.shape(example)[0], 2])
# reshape3: [<=3, -1, 2]
reshape3 = array_ops.reshape(reshape1,
[array_ops.shape(example)[0], -1, 2])
# reshape4: [-1, <=3, 2]
reshape4 = array_ops.reshape(reshape1,
[-1, array_ops.shape(example)[0], 2])
return [reshape1, reshape2, reshape3, reshape4]
# This assumes that there are exactly 2 replicas
outputs = distribution.experimental_local_results(
distribution.run(step_fn, args=(next(input_iterator),)))
self.assertAllEqual((9, 2), outputs[0][0].values[0].shape)
self.assertAllEqual((3, 3, 2), outputs[0][1].values[0].shape)
self.assertAllEqual((3, 3, 2), outputs[0][2].values[0].shape)
self.assertAllEqual((3, 3, 2), outputs[0][3].values[0].shape)
self.assertAllEqual((4, 2), outputs[0][0].values[1].shape)
self.assertAllEqual((2, 2, 2), outputs[0][1].values[1].shape)
self.assertAllEqual((2, 2, 2), outputs[0][2].values[1].shape)
self.assertAllEqual((2, 2, 2), outputs[0][3].values[1].shape)
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.multidevice_strategies,
mode=["eager"]))
def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
def dataset_fn(_):
dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])