Introduce dynamic reshape op.

- Exposes a dynamic reshape op through the XLA builder.
- As in TF and other frameworks, users now have the ability to specify the dimension sizes as operands when building the op.
- Dynamic reshape is rewritten to a static reshape in the dynamic padder.
- This can, and will, remove a lot of the complexity in handling dynamic reshapes in XLA today, where we are forced to derive the dynamic reshape sizes ourselves.

PiperOrigin-RevId: 326744137
Change-Id: I57b5b40abab2972e0e1e3df1577bf89146ebd7cc
parent 7b56de0366
commit 73b40908a4

Changed files:
- tensorflow/compiler/tf2xla/kernels
- tensorflow/compiler/xla/client
- tensorflow/compiler/xla/service: dfs_hlo_visitor.h, dfs_hlo_visitor_with_default.h, dynamic_dimension_inference.cc, dynamic_dimension_inference_test.cc, dynamic_padder.cc, dynamic_padder_test.cc, hlo_cost_analysis.cc, hlo_cost_analysis.h, hlo_graph_dumper.cc, hlo_instruction.cc, hlo_instruction.h, hlo_instructions.cc, hlo_instructions.h, hlo_opcode.h, hlo_opcode_test.cc, hlo_parser.cc, hlo_verifier.cc, hlo_verifier.h, instruction_fusion.cc, layout_assignment.cc, shape_inference.cc, shape_inference.h, sharding_propagation.cc
- tensorflow/compiler/xla: shape_layout.h
- tensorflow/python/distribute
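Before the diff, a minimal usage sketch of the builder API this change adds (not part of the commit; the helper name and concrete values are assumptions), reshaping an f32[<=9] operand into f32[3, <=3] with runtime dimension sizes:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp BuildDynamicReshapeExample(xla::XlaBuilder* b) {
  // f32[9] data whose dimension 0 is made dynamic with runtime size `size`.
  xla::XlaOp data =
      xla::Parameter(b, 0, xla::ShapeUtil::MakeShape(xla::F32, {9}), "data");
  xla::XlaOp size =
      xla::Parameter(b, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "size");
  data = xla::SetDimensionSize(data, size, /*dimension=*/0);
  // Runtime sizes for each output dimension are passed as XlaOps; the bounds
  // and the dynamic/static split are passed as static metadata.
  xla::XlaOp three = xla::ConstantR0<int32_t>(b, 3);
  return xla::DynamicReshape(data,
                             /*dim_sizes=*/{three, xla::Div(size, three)},
                             /*new_size_bounds=*/{3, 3},
                             /*dims_are_dynamic=*/{false, true});
}

Here DynamicReshape is the free function declared in xla_builder.h below; the dynamic padder later rewrites the resulting kDynamicReshape HLO into a static reshape.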
@@ -19,8 +19,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -108,38 +110,73 @@ class ReshapeOp : public XlaOpKernel {

    VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
            << shape.DebugString() << ", unknown_index=" << unknown_index;
    auto input_xla_shape = ctx->InputXlaShape(0);
    if (input_xla_shape->is_static()) {
      ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes()));
      return;
    }
    // Handle dynamic reshapes if the input contains a dynamic dimension.
    std::vector<xla::XlaOp> output_dim_sizes;
    std::vector<bool> dims_are_dynamic;
    for (int64 i = 0; i < shape.dims(); ++i) {
      output_dim_sizes.push_back(
          xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {}));
    }
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic));
    if (unknown_index == -1) {
      // No unknown index.
      ctx->SetOutput(0,
                     xla::DynamicReshape(ctx->Input(0), output_dim_sizes,
                                         shape.dim_sizes(), dims_are_dynamic));
      return;
    }
    auto common_factors =
        xla::CommonFactors(input_shape.dim_sizes(), shape.dim_sizes());

    int dynamic_dimension = -1;
    if (ctx->InputXlaShape(0)->is_dynamic()) {
      std::vector<bool> dynamic_dims;
      OP_REQUIRES_OK(ctx,
                     ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims));
      for (int d = 0; d < num_dims; ++d) {
        const bool dim_is_dynamic = dynamic_dims[d];
        if (dim_is_dynamic) {
          dynamic_dimension = d;
    // Find the common_factors group that the input belongs to.
    for (int64 i = 0; i < common_factors.size() - 1; ++i) {
      auto start = common_factors[i];
      auto end = common_factors[i + 1];
      bool input_is_dynamic = false;
      // Product of all input dims in this group. E.g., in
      // reshape(Tensor([2, 3, 3]), [3, -1, 3]) the product of the group
      // containing -1 will be 6.
      xla::XlaOp product = xla::One(ctx->builder(), xla::S32);
      for (int64 dim = start.first; dim < end.first; ++dim) {
        if (input_xla_shape->is_dynamic_dimension(dim)) {
          input_is_dynamic = true;
        }
        product = xla::Mul(product, xla::GetDimensionSize(ctx->Input(0), dim));
      }
      bool unknown_dim_in_group = false;
      // The real size for the -1 dimension in a reshape. E.g., in
      // reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2.
      xla::XlaOp unknown_dim_size = product;
      for (int64 dim = start.second; dim < end.second; ++dim) {
        if (dim == unknown_index) {
          unknown_dim_in_group = true;
        } else {
          unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]);
        }
      }

      // When reshaping from a dynamic dimension, the unknown index is
      // considered dynamic. E.g.,
      //  [<=10]
      //    |
      // Reshape
      //    |
      // [2, -1]
      // The second dimension is dynamic.
      if (dynamic_dimension == -1) {
        dynamic_dimension = unknown_index;
      if (unknown_dim_in_group) {
        // If the input dim is dynamic, the output dim at the -1 position must
        // be dynamic. Similarly, if the input dim is static, the output dim
        // has to be static at the -1 dimension.
        dims_are_dynamic[unknown_index] = input_is_dynamic;
        output_dim_sizes[unknown_index] = unknown_dim_size;

        ctx->SetOutput(
            0, xla::DynamicReshape(ctx->Input(0), output_dim_sizes,
                                   shape.dim_sizes(), dims_are_dynamic));
        VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString()
                << " to " << xla::VectorString(shape.dim_sizes())
                << ", dynamic_dims=" << xla::VectorString(dims_are_dynamic);
        return;
      }
    VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to "
            << xla::VectorString(shape.dim_sizes())
            << ", dynamic_dim=" << dynamic_dimension;
    }
    // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference
    // in XLA to know which output dimension is dynamic.
    ctx->SetOutput(0, xla::ReshapeWithInferredDimension(
                          ctx->Input(0), shape.dim_sizes(), dynamic_dimension));
  }
};
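As a worked example of the grouping step above (dimension values assumed, mirroring the comments in the kernel), for reshape([2, 3, 3] -> [3, -1, 3]) the -1 resolves to 2 and the common-factor boundaries are:

std::vector<std::pair<int64, int64>> groups =
    xla::CommonFactors(/*a=*/{2, 3, 3}, /*b=*/{3, 2, 3});
// groups == {{0, 0}, {2, 2}, {3, 3}}: input dims [0, 2) (product 2 * 3 = 6)
// pair with output dims [0, 2) (3 and -1), so the runtime size of the -1
// dimension is the runtime product of input dims 0..1 divided by 3.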
@@ -1083,6 +1083,36 @@ XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
  });
}

XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
                                 absl::Span<const XlaOp> dim_sizes,
                                 absl::Span<const int64> new_size_bounds,
                                 const std::vector<bool>& dims_are_dynamic) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> dim_size_shape_ptrs;
    TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
                        GetOperandShapes(dim_sizes));

    absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs),
                      [](const Shape& shape) { return &shape; });
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferDynamicReshapeShape(
                            *operand_shape, dim_size_shape_ptrs,
                            new_size_bounds, dims_are_dynamic));
    TF_RETURN_IF_ERROR(first_error_);
    std::vector<XlaOp> operands;
    operands.reserve(1 + dim_sizes.size());
    operands.push_back(operand);
    for (const XlaOp& dim_size : dim_sizes) {
      operands.push_back(dim_size);
    }
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape,
                          operands);
  });
}

XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {

@@ -3466,6 +3496,13 @@ XlaOp Reshape(const Shape& shape, XlaOp operand) {
  return operand.builder()->Reshape(shape, operand);
}

XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic) {
  return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds,
                                           dims_are_dynamic);
}

XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64> new_sizes,
                                   int64 inferred_dimension) {
@@ -454,6 +454,10 @@ class XlaBuilder {
  XlaOp Reshape(const Shape& shape, XlaOp operand,
                int64 inferred_dimension = -1);

  XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                       absl::Span<const int64> new_size_bounds,
                       const std::vector<bool>& dims_are_dynamic);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,

@@ -940,6 +944,10 @@ class XlaBuilder {

  friend XlaOp Reshape(const Shape& shape, XlaOp operand);

  friend XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                              absl::Span<const int64> new_size_bounds,
                              const std::vector<bool>& dims_are_dynamic);

  friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                            absl::Span<const int64> new_sizes,
                                            int64 inferred_dimension);

@@ -1453,9 +1461,16 @@ XlaOp Pad(XlaOp operand, XlaOp padding_value,
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes);

// Enqueues an operation onto the computation that collapses the operand, from
// first to last dimension (C order), then reshapes it to the given dimension
// sizes. Conceptually, this is a limited form of "shape casting".
// Enqueues a dynamic reshape operation. The dynamic reshape takes additional
// XlaOps as sizes for the result dimensions. Result dim i is a dynamic
// dimension if dims_are_dynamic[i] is true.
XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic);

// Enqueues an operation onto the computation that collapses the operand,
// from first to last dimension (C order), then reshapes it to the given
// dimension sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

// Enqueues a Reshape op that uses an explicit target shape.
@@ -245,6 +245,7 @@ class DfsHloVisitorBase {
  virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
  virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
  virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
  virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0;
  virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
  virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
  virtual Status HandleFusion(HloInstructionPtr hlo) = 0;
@@ -198,6 +198,9 @@ class DfsHloVisitorWithDefaultBase
  Status HandlePad(HloInstructionPtr pad) override {
    return DefaultAction(pad);
  }
  Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override {
    return DefaultAction(dynamic_reshape);
  }
  Status HandleReshape(HloInstructionPtr reshape) override {
    return DefaultAction(reshape);
  }
@@ -97,6 +97,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {

  Status HandleTranspose(HloInstruction* hlo) override;

  Status HandleDynamicReshape(HloInstruction* hlo) override;

  Status HandleReshape(HloInstruction* hlo) override;

  Status HandleSort(HloInstruction* hlo) override;

@@ -621,6 +623,18 @@ Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
  return PassThroughDynamicDimension(hlo);
}

Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
    HloInstruction* hlo) {
  HloDynamicReshapeInstruction* dynamic_reshape =
      Cast<HloDynamicReshapeInstruction>(hlo);
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    if (hlo->shape().is_dynamic_dimension(i)) {
      parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
    }
  }
  return Status::OK();
}

Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo,
@@ -1248,5 +1248,34 @@ TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) {
  EXPECT_TRUE(handler_called);
}

TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) {
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9}), "data_input"));
  auto six = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(6)));
  // Creates an input of shape [<=9], dynamic size is 6.
  auto dynamic_input =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9}, {true}), input, six, 0));
  auto dynamic_size = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(S32, {}), "size_param"));
  auto three = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));

  // Reshape [<=9] into [3, <=3].
  auto dynamic_reshape =
      builder.AddInstruction(HloInstruction::CreateDynamicReshape(
          ShapeUtil::MakeShape(F32, {3, 3}, {false, true}), dynamic_input,
          {three, dynamic_size}));

  module_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), dynamic_size);
}

}  // namespace
}  // namespace xla
@@ -1290,6 +1290,18 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
          changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
      continue;
    }

    if (inst->opcode() == HloOpcode::kDynamicReshape) {
      TF_ASSIGN_OR_RETURN(
          changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
      auto* static_reshape =
          computation->AddInstruction(HloInstruction::CreateReshape(
              inst->shape(), inst->mutable_operand(0)));
      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(static_reshape));
      TF_RETURN_IF_ERROR(dynamic_dimension_inference.ForwardDynamicSize(
          inst, static_reshape, {}));
      continue;
    }
    for (int64 operand_num = 0; operand_num < inst->operand_count();
         ++operand_num) {
      HloInstruction* original_operand = inst->mutable_operand(operand_num);
@@ -379,6 +379,13 @@ class ExecutionTest : public HloTestBase {
  Literal PadAndExecute(std::unique_ptr<HloModule> module,
                        absl::Span<Literal* const> arguments,
                        bool slice_dynamic_output = true) {
    if (!slice_dynamic_output) {
      auto new_config = module->config();
      new_config.mutable_entry_computation_layout()
          ->mutable_result_layout()
          ->ClearDynamicShape();
      module->set_config(new_config);
    }
    DynamicPadder padder(slice_dynamic_output);
    TF_CHECK_OK(padder.Run(module.get()).status());
    HloDCE dce;

@@ -1176,6 +1183,84 @@ ENTRY main {
  EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, DynamicReshapeDoubleDynamicDimensions) {
  const string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[2, 3, 3] parameter(0)
  size = s32[] constant(2)
  param_padded_partial = s32[2, <=3, 3] set-dimension-size(param, size),
    dimensions={1}
  param_padded = s32[2, <=3, <=3] set-dimension-size(param_padded_partial, size),
    dimensions={2}
  result_size = s32[] constant(8)
  ROOT reshaped = s32[<=18] dynamic-reshape(param_padded, result_size)
}
)";

  // Dimensions 1 and 2 of the input are dynamic, each with runtime size 2.
  Literal operand = LiteralUtil::CreateR3<int32>(
      {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}});
  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand}, false);
  result.SetDynamicSize(0, 8);
  // Padded data looks like this (P is padding which is ignored).
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // Reshaping (with correct reshape rewriting) produces:
  // [0, 1, 3, 4, 0, 1, 3, 4]
  Literal expected = LiteralUtil::CreateR1<int32>({0, 1, 3, 4, 0, 1, 3, 4});

  EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, DynamicReshapeOutputDoubleDynamicDimensions) {
  const string hlo_text = R"(
HloModule TensorFlowScatterV1

ENTRY main {
  param = s32[18] parameter(0)
  eight = s32[] constant(8)
  param_dynamic = s32[<=18] set-dimension-size(param, eight), dimensions={0}
  two = s32[] constant(2)
  // every dimension has dynamic size two.
  ROOT reshaped = s32[2, <=3, <=3] dynamic-reshape(param_dynamic, two, two, two)
}
)";
  Literal operand = LiteralUtil::CreateR1<int32>(
      {0, 1, 3, 4, 0, 1, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1});

  auto module = GetHloModule(hlo_text);

  Literal result = PadAndExecute(std::move(module), {&operand}, false);

  result.SetDynamicSize(1, 2);
  result.SetDynamicSize(2, 2);
  // Padded operand is:
  // [0, 1, 3, 4, 0, 1, 3, 4, P, P ....]
  //
  // Reshaping it should produce:
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  //
  // [[0, 1, P]
  //  [3, 4, P]
  //  [P, P, P]]
  Literal expected =
      LiteralUtil::CreateR3<int32>({{{0, 1}, {3, 4}}, {{0, 1}, {3, 4}}});

  EXPECT_EQ(result, expected);
}

XLA_TEST_F(ExecutionTest, SetGetDimensionSize) {
  const string hlo_text = R"(
HloModule TensorFlowScatterV1
@@ -486,6 +486,10 @@ Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) {
  return Status::OK();
}

Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) {
  // TODO(b/62294698): Implement cost analysis for batch-norm-training.
  return Status::OK();
@@ -113,6 +113,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
  Status HandleBroadcast(const HloInstruction* broadcast) override;
  Status HandlePad(const HloInstruction* pad) override;
  Status HandleReshape(const HloInstruction* reshape) override;
  Status HandleDynamicReshape(const HloInstruction* reshape) override;
  Status HandleAddDependency(const HloInstruction* add_dependency) override;
  Status HandleAfterAll(const HloInstruction* token) override;
  Status HandleTranspose(const HloInstruction* transpose) override;
@@ -1012,6 +1012,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
    case HloOpcode::kGather:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kTranspose:
@@ -700,6 +700,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
      instruction = CreateReshape(shape, operands(0), inferred_dimension);
      break;
    }
    case HloOpcode::kDynamicReshape: {
      TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
                   ShapeUtil::ElementsIn(shape) ==
                       ShapeUtil::ElementsIn(operands(0)->shape()))
          << "shape: " << ShapeUtil::HumanString(shape)
          << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
      const auto& operand_vector = all_operands();
      instruction = CreateDynamicReshape(
          shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
      break;
    }
    default: {
      instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
      for (const int64 operand_id : proto.operand_ids()) {

@@ -1373,6 +1384,19 @@ HloInstruction::CreateBroadcastSequence(
                                         inferred_dimension);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicReshape(
    const Shape& shape, HloInstruction* data_operand,
    absl::Span<HloInstruction* const> dim_sizes) {
  CHECK_EQ(ShapeUtil::ElementsIn(shape),
           ShapeUtil::ElementsIn(data_operand[0].shape()))
      << "shape: " << ShapeUtil::HumanString(shape)
      << " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
  CHECK_EQ(shape.rank(), dim_sizes.size());
  return absl::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
                                                         dim_sizes);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> dimensions) {

@@ -1569,6 +1593,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:

@@ -2007,6 +2032,7 @@ bool HloInstruction::IdenticalSlowPath(
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kReplicaId:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:

@@ -2812,7 +2838,8 @@ HloInstructionProto HloInstruction::ToProto() const {

string HloInstruction::ToCategory() const {
  if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
      opcode() == HloOpcode::kReshape) {
      opcode() == HloOpcode::kReshape ||
      opcode() == HloOpcode::kDynamicReshape) {
    return "data formatting";
  }

@@ -3033,6 +3060,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kDynamicReshape:
      return visitor->HandleDynamicReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
@@ -879,6 +879,14 @@ class HloInstruction {
      const Shape& shape, HloInstruction* operand,
      int64 inferred_dimension = -1);

  // Creates a dynamic reshape instruction. Similar to reshape, but the dynamic
  // dimension sizes are provided as additional variadic arguments.
  //
  // Precondition: dim_sizes.size() == shape.rank()
  static std::unique_ptr<HloInstruction> CreateDynamicReshape(
      const Shape& shape, HloInstruction* data_operand,
      absl::Span<HloInstruction* const> dim_sizes);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
@@ -1027,6 +1027,16 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl(
                                                    dimensions());
}

HloDynamicReshapeInstruction::HloDynamicReshapeInstruction(
    const Shape& shape, HloInstruction* data_operand,
    absl::Span<HloInstruction* const> dim_sizes)
    : HloInstruction(HloOpcode::kDynamicReshape, shape) {
  AppendOperand(data_operand);
  for (auto operand : dim_sizes) {
    AppendOperand(operand);
  }
}

HloReshapeInstruction::HloReshapeInstruction(const Shape& shape,
                                             HloInstruction* operand,
                                             int64 inferred_dimension)
@@ -679,6 +679,21 @@ class HloBroadcastInstruction : public HloInstruction {
  std::vector<int64> dimensions_;
};

class HloDynamicReshapeInstruction : public HloInstruction {
 public:
  explicit HloDynamicReshapeInstruction(
      const Shape& shape, HloInstruction* data_operand,
      absl::Span<HloInstruction* const> dim_sizes);

  // Returns the dim size operands, which are operands[1:].
  absl::Span<HloInstruction* const> dim_sizes() const {
    return absl::MakeSpan(operands()).subspan(1, operand_count());
  }

  // Returns the i-th dim size operand, which is operands[1 + i].
  HloInstruction* dim_sizes(int64 i) const { return operands()[i + 1]; }
};

class HloReshapeInstruction : public HloInstruction {
 public:
  explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,
@@ -123,6 +123,7 @@ namespace xla {
  V(kRemainder, "remainder", 2) \
  V(kReplicaId, "replica-id", 0) \
  V(kReshape, "reshape", 1) \
  V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \
  V(kReverse, "reverse", 1) \
  V(kRng, "rng", kHloOpcodeIsVariadic) \
  V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \

@@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
      case HloOpcode::kCustomCall:
      case HloOpcode::kDynamicSlice:
      case HloOpcode::kDynamicUpdateSlice:
      case HloOpcode::kDynamicReshape:
      case HloOpcode::kFusion:
      case HloOpcode::kMap:
      case HloOpcode::kReduce:
@@ -1108,6 +1108,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
          builder->AddInstruction(HloInstruction::CreatePartitionId());
      break;
    }
    case HloOpcode::kDynamicReshape: {
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateDynamicReshape(
              shape, operands[0],
              absl::Span<HloInstruction* const>(operands).subspan(1)));
      break;
    }
    case HloOpcode::kReshape: {
      optional<int64> inferred_dimension;
      attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
@@ -703,6 +703,20 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
  return Status::OK();
}

Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) {
  // Check for mixed precision.
  const Shape& operand_shape = dynamic_reshape->operand(0)->shape();
  TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape));
  TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) ==
               ShapeUtil::ElementsIn(operand_shape));
  TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 ==
               dynamic_reshape->operand_count());
  for (int64 i = 1; i < dynamic_reshape->operand_count(); ++i) {
    TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32);
  }
  return Status::OK();
}

Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
  // Check for mixed precision.
  const Shape& operand_shape = reshape->operand(0)->shape();
@@ -78,6 +78,7 @@ class ShapeVerifier : public DfsHloVisitor {
  Status HandleBitcast(HloInstruction* bitcast) override;
  Status HandleBroadcast(HloInstruction* broadcast) override;
  Status HandleReshape(HloInstruction* reshape) override;
  Status HandleDynamicReshape(HloInstruction* dynamic_reshape) override;
  Status HandleTranspose(HloInstruction* transpose) override;
  Status HandleParameter(HloInstruction*) override;
  Status HandleFusion(HloInstruction*) override;
@@ -102,6 +102,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kSelect:
@@ -2278,6 +2278,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
@@ -2825,6 +2825,38 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  return output_shape;
}

/* static */ StatusOr<Shape> ShapeInference::InferDynamicReshapeShape(
    const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
    absl::Span<const int64> new_size_bounds,
    const std::vector<bool>& dims_are_dynamic) {
  if (new_size_bounds.size() != dims_are_dynamic.size()) {
    return InvalidArgument(
        "DynamicReshape has to have the same number of elements in new_sizes "
        "(%d) and dims_are_dynamic (%d)",
        new_size_bounds.size(), dims_are_dynamic.size());
  }

  for (const Shape* dim_size_shape : dim_size_shapes) {
    if (dim_size_shape->element_type() != S32 && dim_size_shape->rank() != 0) {
      return InvalidArgument(
          "DynamicReshape's dim size has to be scalar S32, got (%s): ",
          dim_size_shape->ToString());
    }
  }

  Shape inferred_shape = ShapeUtil::MakeShape(
      operand.element_type(), new_size_bounds, dims_are_dynamic);
  if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
    return InvalidArgument(
        "Reshape operation has mismatched element counts: from=%d (%s) "
        "to=%d (%s).",
        ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand),
        ShapeUtil::ElementsIn(inferred_shape),
        ShapeUtil::HumanString(inferred_shape));
  }
  return inferred_shape;
}

/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
    const Shape& operand, absl::Span<const int64> dimensions,
    absl::Span<const int64> new_sizes, int64 inferred_dimension) {
@@ -241,6 +241,15 @@ class ShapeInference {
                                            absl::Span<const int64> new_sizes,
                                            int64 inferred_dimension);

  // Infers the shape produced by a dynamic reshape operation from the element
  // type of its operand and the new dimension sizes specified. The result
  // shape will have dynamic dimensions as specified in `dims_are_dynamic`,
  // with bounds given by `new_size_bounds`.
  static StatusOr<Shape> InferDynamicReshapeShape(
      const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
      absl::Span<const int64> new_size_bounds,
      const std::vector<bool>& dims_are_dynamic);

  // Infers the shape produced by a transpose operation from the element type of
  // its operand and its dimensions field.
  static StatusOr<Shape> InferTransposeShape(
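A sketch of what the new inference entry point is expected to produce for the [<=9] -> [3, <=3] case used in the tests above (an assumed call site, not code from this change):

xla::Shape operand_shape = xla::ShapeUtil::MakeShape(xla::F32, {9}, {true});
xla::Shape s32_scalar = xla::ShapeUtil::MakeShape(xla::S32, {});
std::vector<const xla::Shape*> dim_size_shapes = {&s32_scalar, &s32_scalar};
xla::StatusOr<xla::Shape> inferred =
    xla::ShapeInference::InferDynamicReshapeShape(
        operand_shape, dim_size_shapes,
        /*new_size_bounds=*/{3, 3},
        /*dims_are_dynamic=*/{false, true});
// Element counts match (9 == 3 * 3), so inferred.ValueOrDie() is f32[3,<=3].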
@@ -387,6 +387,7 @@ const HloInstruction* PickRepresentativeOperand(
    case HloOpcode::kDot:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kFft:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
@@ -61,6 +61,10 @@ class ShapeLayout {
  // Returns the shape (with layouts).
  const Shape& shape() const { return shape_; }

  // Clears the dynamic dimensions of this shape, pretending the module creates
  // static results. Useful for inspecting full outputs when testing.
  void ClearDynamicShape() { shape_.clear_dynamic_dimensions(); }

  // Checks that a layout is set for the shape, and returns a reference to the
  // layout directly on the shape. Shape must not be a tuple.
  const Layout& layout() const;
@@ -635,8 +635,57 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
          mode=["eager"]))
  def testReshapeWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((5, 1, 2), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.experimental_distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      # example: [<=3, 1, 2]
      # tile: [<=3, <=3, 2]
      tile = array_ops.tile(example, [1, array_ops.shape(example)[0], 1])
      # reshape1: [<=(3*3 = 9), 2]
      reshape1 = array_ops.reshape(tile, [-1, 2])

      # reshape2: [<=3, <=3, 2]
      reshape2 = array_ops.reshape(
          reshape1,
          [array_ops.shape(example)[0],
           array_ops.shape(example)[0], 2])

      # reshape3: [<=3, -1, 2]
      reshape3 = array_ops.reshape(reshape1,
                                   [array_ops.shape(example)[0], -1, 2])
      # reshape4: [-1, <=3, 2]
      reshape4 = array_ops.reshape(reshape1,
                                   [-1, array_ops.shape(example)[0], 2])
      return [reshape1, reshape2, reshape3, reshape4]

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((9, 2), outputs[0][0].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][1].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][2].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][3].values[0].shape)

    self.assertAllEqual((4, 2), outputs[0][0].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][1].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][2].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][3].values[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
    def dataset_fn(_):
      dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])