[XLA] Add HLO infrastructure to support dynamic op lowering.
This CL gives additional configuration to the dynamic padder, telling it which ops can accept dynamic tensors. If an op requires a dynamic tensor as input but a static tensor is provided, a "SliceToDynamic" op is inserted. If an op requires a static tensor as input but a dynamic tensor is provided, a "PadToStatic" op is inserted. If an op requires a static tensor and the tensor is already static, the dynamic padder rewrites the op so that it produces the same result as if the tensor were dynamic (this is the behavior we already have today).

PiperOrigin-RevId: 309326119
Change-Id: I5376674d6acf9905af1b7e09b127811b57517e97
parent 9c3f0435ad
commit f6bc68ba4c
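As an illustration of the new configuration surface, a backend could wire the padder up as sketched below. The handler and the "MyDynamicOp" custom-call target are hypothetical, not part of this change; the constructor signature is the one added to dynamic_padder.h in this diff, and the test code later in the diff uses the same pattern.

  // A minimal sketch: report that a hypothetical custom call only has a
  // dynamic lowering. Everything else falls back to kNoSupport, so the
  // padder rewrites it to operate on static (padded) tensors.
  OpDynamismSupport MyOpSupportsDynamism(HloInstruction* hlo) {
    if (hlo->opcode() == HloOpcode::kCustomCall &&
        hlo->custom_call_target() == "MyDynamicOp") {
      return OpDynamismSupport::kRequired;
    }
    return OpDynamismSupport::kNoSupport;
  }

  DynamicPadder padder(/*slice_dynamic_output=*/true,
                       /*custom_call_handler=*/nullptr,
                       MyOpSupportsDynamism);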
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -1620,6 +1620,24 @@ Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
   return Status::OK();
 }
 
+bool DynamicDimensionInference::HasDynamicDimension(
+    HloInstruction* inst) const {
+  bool has_dynamic_dim = false;
+  ShapeUtil::ForEachSubshape(
+      inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+        if (subshape.IsTuple()) {
+          return;
+        }
+        for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
+          HloInstruction* operand_dynamic_size = GetDynamicSize(inst, index, i);
+          if (operand_dynamic_size != nullptr) {
+            has_dynamic_dim = true;
+          }
+        }
+      });
+  return has_dynamic_dim;
+}
+
 HloInstruction* DynamicDimensionInference::GetDynamicSize(
     HloInstruction* inst, const ShapeIndex& index, int64 dim) const {
   auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h
@@ -51,6 +51,10 @@ class DynamicDimensionInference {
   HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index,
                                  int64 dim) const;
 
+  // Returns true if the current instruction contains any dynamic dimension.
+  // Recursively goes into tuples.
+  bool HasDynamicDimension(HloInstruction* inst) const;
+
   // Forward dynamic dimension size at `dim` and its constraint from `inst` to
   // `new_inst`.
   Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst,
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 
 namespace xla {
 
@@ -943,106 +944,6 @@ Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
   return Status::OK();
 }
 
-// For all dynamic outputs that live out of the computation, add
-// slice-to-dynamic operations.
-Status InsertSliceToDynamicBeforeModuleOutputs(
-    const DynamicDimensionInference& dynamic_dimension_inference,
-    HloModule* module) {
-  auto root = module->entry_computation()->root_instruction();
-  absl::flat_hash_set<ShapeIndex> dynamic_outputs;
-  ShapeUtil::ForEachSubshape(
-      root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
-        if (subshape.IsArray()) {
-          bool has_dynamic_output = false;
-          for (int64 dim = 0; dim < subshape.rank(); ++dim) {
-            if (dynamic_dimension_inference.GetDynamicSize(root, index, dim) !=
-                nullptr) {
-              CHECK_LE(index.size(), 1) << "XLA doesn't support nested output "
-                                           "dimension that has dynamic size";
-              has_dynamic_output = true;
-            }
-          }
-          if (has_dynamic_output) {
-            dynamic_outputs.insert(index);
-          }
-        }
-      });
-  if (!dynamic_outputs.empty()) {
-    if (root->shape().IsTuple()) {
-      std::vector<HloInstruction*> new_root_operands;
-      ShapeUtil::ForEachSubshape(root->shape(), [&](const Shape& subshape,
-                                                    const ShapeIndex& index) {
-        if (!subshape.IsArray()) {
-          return;
-        }
-
-        auto gte = module->entry_computation()->AddInstruction(
-            HloInstruction::CreateGetTupleElement(
-                ShapeUtil::MakeShapeWithStaticDimensions(subshape), root,
-                index[0]));
-
-        if (dynamic_outputs.contains(index)) {
-          CHECK_EQ(index.size(), 1)
-              << "XLA only support 1 layer nested output tuple";
-          // For dynamic outputs, creates an slice operation.
-          std::vector<HloInstruction*> slice_operands;
-          // First operand is the original input. Rest are dimension values.
-          slice_operands.push_back(gte);
-          // Keep a dynamic version of the subshape as we are removing the
-          // dynamic dimension in the original root and gte.
-          Shape dynamic_subshape = subshape;
-          for (int64 dim = 0; dim < subshape.rank(); ++dim) {
-            HloInstruction* dynamic_size =
-                dynamic_dimension_inference.GetDynamicSize(root, index, dim);
-            if (dynamic_size != nullptr) {
-              slice_operands.push_back(dynamic_size);
-            } else {
-              auto const_size = HloInstruction::CreateConstant(
-                  LiteralUtil::CreateR0<int32>(subshape.dimensions(dim)));
-              slice_operands.push_back(
-                  module->entry_computation()->AddInstruction(
-                      std::move(const_size)));
-            }
-          }
-          // This is a dynamic output, add slice operation.
-          auto slice = HloInstruction::CreateCustomCall(
-              dynamic_subshape, slice_operands, "SliceToDynamic");
-          new_root_operands.push_back(
-              module->entry_computation()->AddInstruction(std::move(slice)));
-        } else {
-          new_root_operands.push_back(gte);
-        }
-      });
-
-      auto new_root = module->entry_computation()->AddInstruction(
-          HloInstruction::CreateTuple(new_root_operands));
-      module->entry_computation()->set_root_instruction(new_root);
-    } else {
-      std::vector<HloInstruction*> slice_operands;
-      // First operand is the original input. Rest are dimension values.
-      slice_operands.push_back(root);
-      for (int64 dim = 0; dim < root->shape().rank(); ++dim) {
-        HloInstruction* dynamic_size =
-            dynamic_dimension_inference.GetDynamicSize(root, {}, dim);
-        if (dynamic_size != nullptr) {
-          slice_operands.push_back(dynamic_size);
-        } else {
-          auto const_size = HloInstruction::CreateConstant(
-              LiteralUtil::CreateR0<int32>(root->shape().dimensions(dim)));
-          slice_operands.push_back(module->entry_computation()->AddInstruction(
-              std::move(const_size)));
-        }
-        // This is a dynamic output, add slice operation.
-        auto slice = module->entry_computation()->AddInstruction(
-            HloInstruction::CreateCustomCall(root->shape(), slice_operands,
-                                             "SliceToDynamic", "0-0"));
-        module->entry_computation()->set_root_instruction(slice);
-      }
-    }
-  }
-  return Status::OK();
-}
-
 // Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
 //
 // After this visitor the entry computation then looks like:
@@ -1059,46 +960,217 @@ Status InsertSliceToDynamicBeforeModuleOutputs(
 // ROOT tuple (dynamic)
 class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
  public:
+  explicit DynamicShapeRemovingVisitor(
+      const DynamicPadder::OpSupportsDynamismHandler&
+          op_supports_dynamism_handler,
+      const DynamicDimensionInference& dynamic_dimension_inference)
+      : op_supports_dynamism_handler_(op_supports_dynamism_handler),
+        dynamic_dimension_inference_(dynamic_dimension_inference) {}
+
   Status DefaultAction(HloInstruction* hlo) override;
 
   Status HandleCustomCall(HloInstruction* hlo) override;
 
+  Status HandleTuple(HloInstruction* hlo) override;
+  Status HandleGetTupleElement(HloInstruction* hlo) override;
+
+  Status HandleParameter(HloInstruction* hlo) override;
+
-  static Status Run(HloComputation* computation) {
-    DynamicShapeRemovingVisitor visitor;
-    return computation->Accept(&visitor);
+  static Status Run(HloComputation* computation,
+                    const DynamicPadder::OpSupportsDynamismHandler&
+                        op_supports_dynamism_handler,
+                    const DynamicDimensionInference& dynamic_shape_inference,
+                    bool require_dynamic_output) {
+    DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler,
+                                        dynamic_shape_inference);
+    TF_RETURN_IF_ERROR(computation->Accept(&visitor));
+    // If the output is required to be in dynamic form, insert a static to
+    // dynamic conversion as root.
+    if (require_dynamic_output) {
+      HloInstruction* root = computation->root_instruction();
+      if (dynamic_shape_inference.HasDynamicDimension(root)) {
+        HloInstruction* new_root = visitor.ConvertToDynamic(root);
+        computation->set_root_instruction(new_root);
+      }
+    }
+    return Status::OK();
   }
 
  private:
+  // If a tensor produced by `inst` is in dynamic form, converts it to static
+  // and returns the new instruction.
+  HloInstruction* ConvertToStatic(HloInstruction* inst);
+
+  // If a tensor produced by `inst` is in static form, converts it to dynamic
+  // and returns the new instruction.
+  HloInstruction* ConvertToDynamic(HloInstruction* inst);
+
+  const DynamicPadder::OpSupportsDynamismHandler& op_supports_dynamism_handler_;
+
+  const DynamicDimensionInference& dynamic_dimension_inference_;
 };
 
+HloInstruction* DynamicShapeRemovingVisitor::ConvertToDynamic(
+    HloInstruction* inst) {
+  auto* comp = inst->parent();
+  const Shape& shape = inst->shape();
+  if (shape.IsTuple()) {
+    std::vector<HloInstruction*> dynamic_operands;
+    for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
+      auto operand = inst->mutable_operand(i);
+      if (dynamic_dimension_inference_.HasDynamicDimension(operand)) {
+        // Recurse.
+        dynamic_operands.push_back(ConvertToDynamic(operand));
+      } else {
+        dynamic_operands.push_back(operand);
+      }
+    }
+    return comp->AddInstruction(HloInstruction::CreateTuple(dynamic_operands));
+  } else {
+    // Collect the data input, as well as dimension sizes, and feed them to
+    // slice-to-dynamic to create a dynamic tensor.
+    Shape output_shape = shape;  // 0th element.
+    CHECK(output_shape.is_static());
+    std::vector<HloInstruction*> slice_operand;
+    slice_operand.push_back(inst);
+    for (int64 i = 0; i < output_shape.dimensions_size(); ++i) {
+      auto dimension_size =
+          dynamic_dimension_inference_.GetDynamicSize(inst, {}, i);
+      if (dimension_size == nullptr) {
+        dimension_size = comp->AddInstruction(HloInstruction::CreateConstant(
+            LiteralUtil::CreateR0<int32>(output_shape.dimensions(i))));
+      } else {
+        output_shape.set_dynamic_dimension(i, true);
+      }
+      slice_operand.push_back(dimension_size);
+    }
+    return comp->AddInstruction(HloInstruction::CreateCustomCall(
+        output_shape, slice_operand, "SliceToDynamic"));
+  }
+}
+
+HloInstruction* DynamicShapeRemovingVisitor::ConvertToStatic(
+    HloInstruction* inst) {
+  auto* comp = inst->parent();
+  const Shape& shape = inst->shape();
+  CHECK(shape.is_dynamic());
+  if (shape.IsTuple()) {
+    std::vector<HloInstruction*> static_operands;
+    for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
+      auto operand = inst->mutable_operand(i);
+      if (shape.tuple_shapes(i).is_dynamic()) {
+        static_operands.push_back(ConvertToStatic(operand));
+      } else {
+        static_operands.push_back(operand);
+      }
+    }
+    return comp->AddInstruction(HloInstruction::CreateTuple(static_operands));
+  } else {
+    // The output shape of PadToStatic is a tuple. The 0th element is the data
+    // output, which is the same as the input shape, but without dynamic
+    // dimensions. The i-th element is the dynamic dimension size of the
+    // (i-1)-th input dimension.
+    Shape data_output_shape = shape;  // 0th element.
+    data_output_shape.clear_dynamic_dimensions();
+    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
+    for (int64 i = 0; i < shape.rank(); ++i) {
+      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
+                                    &output_shape);
+    }
+    HloInstruction* pad_to_static =
+        comp->AddInstruction(HloInstruction::CreateCustomCall(
+            output_shape, {inst}, "PadToStatic", ""));
+    HloInstruction* data_output =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            data_output_shape, pad_to_static, 0));
+    return data_output;
+  }
+}
+
 Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
-  // Default rule: If input to an op is static, remove dynamism in output.
-  bool input_is_dynamic = false;
-  // Default rule:
-  for (int64 i = 0; i < hlo->operand_count(); ++i) {
-    if (!hlo->operand(i)->shape().is_static()) {
-      input_is_dynamic = true;
+  const bool input_is_dynamic = absl::c_any_of(
+      hlo->operands(),
+      [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); });
+
+  // By default, ops don't support dynamic lowering.
+  OpDynamismSupport op_support = OpDynamismSupport::kNoSupport;
+  if (op_supports_dynamism_handler_) {
+    op_support = op_supports_dynamism_handler_(hlo);
+  }
+  if (op_support == OpDynamismSupport::kNoSupport) {
+    for (auto* sub_computation : hlo->called_computations()) {
+      for (auto* param : sub_computation->parameter_instructions()) {
+        param->mutable_shape()->clear_dynamic_dimensions();
+      }
+    }
+  }
+
-  if (!input_is_dynamic) {
+  // If the input to an op is static and the op doesn't support dynamic output,
+  // remove dynamism in the output -- dynamic_padder should have rewritten it
+  // to support static shapes.
+  if (!input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
     hlo->mutable_shape()->clear_dynamic_dimensions();
     return Status::OK();
   }
 
+  // Op doesn't support dynamic tensors: for each operand, rewrite dynamic
+  // input into static input using pad-to-static.
+  if (input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
+    VLOG(1) << "op doesn't support dynamic tensor: " << hlo->ToString();
+    for (int64 i = 0; i < hlo->operand_count(); ++i) {
+      if (hlo->operand(i)->shape().is_dynamic()) {
+        auto static_operand = ConvertToStatic(hlo->mutable_operand(i));
+        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, static_operand));
+      }
+    }
+    // This op doesn't support dynamic lowering so the op has to be static.
+    hlo->mutable_shape()->clear_dynamic_dimensions();
+    return Status::OK();
+  }
+
+  // If the op requires a dynamic tensor and the input is static -- construct a
+  // dynamic tensor from the static tensor to feed it.
+  if (!input_is_dynamic && op_support == OpDynamismSupport::kRequired) {
+    VLOG(1) << "op doesn't support static tensor: " << hlo->ToString();
+    for (int64 i = 0; i < hlo->operand_count(); ++i) {
+      auto operand = hlo->mutable_operand(i);
+      if (dynamic_dimension_inference_.HasDynamicDimension(operand)) {
+        auto dynamic_operand = ConvertToDynamic(hlo->mutable_operand(i));
+        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, dynamic_operand));
+      }
+    }
+    return Status::OK();
+  }
+
   return Status::OK();
 }
 
-Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
-  if (hlo->custom_call_target() == "SliceToDynamic") {
-    // Don't remove slice-to-dynamic instruction.
-    return Status::OK();
+Status DynamicShapeRemovingVisitor::HandleGetTupleElement(HloInstruction* hlo) {
+  *hlo->mutable_shape() =
+      hlo->operand(0)->shape().tuple_shapes(hlo->tuple_index());
+  return Status::OK();
 }
 
+Status DynamicShapeRemovingVisitor::HandleTuple(HloInstruction* hlo) {
+  for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    *hlo->mutable_shape()->mutable_tuple_shapes(i) = hlo->operand(i)->shape();
+  }
-  return DefaultAction(hlo);
+  return Status::OK();
 }
 
+Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
+  return Status::OK();
+}
+
+Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "SliceToDynamic" ||
+      hlo->custom_call_target() == "PadToStatic") {
+    // These ops are created to handle dynamic tensors, so by their nature they
+    // support dynamic lowering.
+    return Status::OK();
+  }
+
+  return DefaultAction(hlo);
+}
+
 }  // namespace
 
 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
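Aside: to make the "PadToStatic" tuple convention used by ConvertToStatic above concrete, here is a minimal consumer-side sketch. The helper name is hypothetical; the layout it relies on (element 0 is the padded data, element dim + 1 is an s32[] scalar holding the runtime size of dimension dim) is the one built by ConvertToStatic.

  // Hypothetical helper: recover the runtime size of dimension `dim` from a
  // "PadToStatic" custom call by reading tuple element dim + 1.
  HloInstruction* GetPadToStaticDimSize(HloComputation* comp,
                                        HloInstruction* pad_to_static,
                                        int64 dim) {
    const Shape& size_shape = pad_to_static->shape().tuple_shapes(dim + 1);
    return comp->AddInstruction(HloInstruction::CreateGetTupleElement(
        size_shape, pad_to_static, dim + 1));
  }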
@@ -1137,11 +1209,20 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
   }));
 
   TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
-  TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
-                      DynamicDimensionInference::Run(module));
+  TF_ASSIGN_OR_RETURN(
+      DynamicDimensionInference dynamic_dimension_inference,
+      DynamicDimensionInference::Run(module, custom_call_handler_));
 
   for (HloComputation* computation : module->computations()) {
     for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
+      OpDynamismSupport has_dynamism_support = OpDynamismSupport::kNoSupport;
+      if (op_supports_dynamism_handler_ != nullptr) {
+        has_dynamism_support = op_supports_dynamism_handler_(inst);
+      }
+      // This op supports dynamic lowering, no padding is required.
+      if (has_dynamism_support != OpDynamismSupport::kNoSupport) {
+        continue;
+      }
       if (inst->opcode() == HloOpcode::kConcatenate) {
         TF_ASSIGN_OR_RETURN(
             changed, RewriteDynamicConcat(inst, &dynamic_dimension_inference));
@@ -1152,6 +1233,11 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
             changed, RewriteDynamicSort(inst, &dynamic_dimension_inference));
         continue;
       }
+      if (inst->opcode() == HloOpcode::kReshape) {
+        TF_ASSIGN_OR_RETURN(
+            changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
+        continue;
+      }
       for (int64 operand_num = 0; operand_num < inst->operand_count();
            ++operand_num) {
         HloInstruction* original_operand = inst->mutable_operand(operand_num);
@@ -1160,11 +1246,6 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
           continue;
         }
 
-        if (inst->opcode() == HloOpcode::kReshape) {
-          TF_ASSIGN_OR_RETURN(changed, RewriteDynamicReshape(
-                                           inst, &dynamic_dimension_inference));
-          continue;
-        }
         for (int64 input_dim = 0; input_dim < operand->shape().rank();
              ++input_dim) {
           HloInstruction* operand_dynamic_size =
@@ -1195,37 +1276,28 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
       }
     }
   }
-  if (slice_dynamic_output_) {
-    TF_RETURN_IF_ERROR(InsertSliceToDynamicBeforeModuleOutputs(
-        dynamic_dimension_inference, module));
-  }
 
-  // Remove all dynamic dimensions after entry parameter and root instruction --
-  // Dynamic padder will produce an equivalent static shaped graph.
-  for (HloComputation* computation : module->computations()) {
-    if (computation == module->entry_computation()) {
-      TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(computation));
-    } else {
-      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
-        bool operand_is_dynamic = false;
-        for (auto* operand : inst->operands()) {
-          if (!operand->shape().is_static()) {
-            operand_is_dynamic = true;
-          }
-        }
-        if (!operand_is_dynamic) {
-          inst->mutable_shape()->clear_dynamic_dimensions();
-        }
-      }
-    }
+  // There are ops that only support dynamic lowering and ops that only support
+  // static lowering; add dynamic<->static tensor conversions around the
+  // boundary between those ops, as well as around the root instruction.
+  auto computations = module->MakeComputationPostOrder();
+  // Reverse postorder so that if a caller doesn't support dynamic tensors
+  // (while, etc.), its called computations are changed to only take static
+  // tensors.
+  for (auto it = computations.rbegin(); it != computations.rend(); ++it) {
+    HloComputation* computation = *it;
+    // If slice_dynamic_output_ is set and this is the entry computation, we
+    // need the output tensor to be in dynamic form.
+    bool require_dynamic_output =
+        slice_dynamic_output_ && computation == module->entry_computation();
+    TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(
+        computation, op_supports_dynamism_handler_, dynamic_dimension_inference,
+        /*require_dynamic_output=*/require_dynamic_output));
   }
 
   HloDCE dce;
   TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
 
   VLOG(2) << "Post DynamicPadder HLO:";
   XLA_VLOG_LINES(2, module->ToString());
 
   return changed;
 }
 
--- a/tensorflow/compiler/xla/service/dynamic_padder.h
+++ b/tensorflow/compiler/xla/service/dynamic_padder.h
@@ -36,12 +36,38 @@ namespace xla {
 // Dynamic_padder removes dynamic shapes from the entry computation, and inserts
 // custom calls (with dynamic shapes), which are lowered by specialized
 // emitters: PadToStatic and SliceToDynamic.
+
+// Each instruction can have one of three modes of support for dynamic
+// lowering.
+enum OpDynamismSupport {
+  // There is no support for dynamic lowering -- dynamic padder will make sure
+  // the input to that op has a static bound by rewriting the op (e.g., extra
+  // space in reduce_sum will be padded with 0).
+  kNoSupport = 0,
+  // The op can take either dynamic input or static input.
+  kOptional,
+  // The op only has a dynamic lowering; dynamic padder will make sure the
+  // input to this op is in dynamic form.
+  kRequired,
+};
+
 class DynamicPadder : public HloModulePass {
  public:
+  // Indicates whether a given instruction supports native dynamic lowering.
+  // If so, dynamic padder will not attempt to pad it.
+  using OpSupportsDynamismHandler =
+      std::function<OpDynamismSupport(HloInstruction*)>;
+
   // If `slice_dynamic_output` is true, insert 'slice_to_dynamic' ops to all
   // outputs that are inferred to be dynamic.
-  explicit DynamicPadder(bool slice_dynamic_output = true)
-      : slice_dynamic_output_(slice_dynamic_output) {}
+  explicit DynamicPadder(
+      bool slice_dynamic_output = true,
+      DynamicDimensionInference::CustomCallInferenceHandler
+          custom_call_handler = nullptr,
+      OpSupportsDynamismHandler op_supports_dynamism_handler = nullptr)
+      : slice_dynamic_output_(slice_dynamic_output),
+        custom_call_handler_(custom_call_handler),
+        op_supports_dynamism_handler_(op_supports_dynamism_handler) {}
 
   absl::string_view name() const override { return "dynamic_padder"; }
 
@@ -51,6 +77,13 @@ class DynamicPadder : public HloModulePass {
   // Insert 'slice_to_dynamic' ops to all outputs that are inferred to be
   // dynamic.
   bool slice_dynamic_output_;
+
+  // A handler for dynamic dimension inference of custom calls.
+  DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_;
+
+  // A handler to indicate whether a given HLO instruction supports native
+  // dynamism lowering.
+  OpSupportsDynamismHandler op_supports_dynamism_handler_;
 };
 
 }  // namespace xla
 
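As a usage sketch (the pipeline wiring below is illustrative, not part of this change): with the default arguments the pass behaves as before, rewriting every op to its static lowering, since no handlers are supplied.

  HloPassPipeline pipeline("dynamic-padding");
  // No custom-call inference handler and no dynamism handler: every op falls
  // back to the static rewrite path.
  pipeline.AddPass<DynamicPadder>(/*slice_dynamic_output=*/true);
  TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));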
--- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc
@@ -44,12 +44,49 @@ namespace op = xla::testing::opcode_matchers;
 namespace xla {
 namespace {
 
+OpDynamismSupport OpHasDynamismSupport(HloInstruction* hlo) {
+  if (hlo->opcode() != HloOpcode::kCustomCall) {
+    return OpDynamismSupport::kNoSupport;
+  }
+  if (hlo->custom_call_target() == "OpWithDynamicLowering") {
+    return OpDynamismSupport::kRequired;
+  }
+  return OpDynamismSupport::kNoSupport;
+}
+
+Status CustomCallDynamicDimensionInference(
+    HloInstruction* hlo, DynamicDimensionInference* inferencer) {
+  if (hlo->custom_call_target() == "OpWithDynamicLowering") {
+    if (hlo->shape().IsTuple()) {
+      // Use the operand's dynamic size as the output dynamic size.
+      HloInstruction* dynamic_size =
+          inferencer->GetDynamicSize(hlo->mutable_operand(0), {1}, 0);
+      inferencer->SetDynamicSize(hlo, {1}, 0, dynamic_size);
+    } else {
+      // Use the operand's dynamic size as the output dynamic size.
+      HloInstruction* dynamic_size =
+          inferencer->GetDynamicSize(hlo->mutable_operand(0), {}, 0);
+      inferencer->SetDynamicSize(hlo, {}, 0, dynamic_size);
+    }
+  }
+
+  return Status::OK();
+}
+
 class DynamicPadderTest : public HloTestBase {
  protected:
  DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); }
 
   std::unique_ptr<HloModule> GetHloModule(const string& hlo_text) {
     std::unique_ptr<HloModule> module =
         ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
     return module;
   }
 
   StatusOr<bool> RunPadder() {
-    DynamicPadder padder;
+    DynamicPadder padder(/*slice_dynamic_output=*/true,
+                         CustomCallDynamicDimensionInference,
+                         OpHasDynamismSupport);
     return padder.Run(module_.get());
   }
 
@@ -105,6 +142,120 @@ TEST_F(DynamicPadderTest, ReduceTest) {
   ExpectPadded(reduce->operand(0));
 }
 
+TEST_F(DynamicPadderTest, DynamicLoweringTest) {
+  const string hlo_text = R"(
+HloModule DynamicLowering
+
+ENTRY main {
+  param = s32[5] parameter(0)
+  const = s32[] constant(3)
+  param_padded = s32[<=5] set-dimension-size(param, const),
+                 dimensions={0}
+  custom-call.1 = s32[<=5] custom-call(param_padded),
+    custom_call_target="OpWithDynamicLowering"
+  custom-call.2 = s32[<=5] custom-call(custom-call.1),
+    custom_call_target="OpWithDynamicLowering"
+  // Negate doesn't support dynamic lowering.
+  ROOT negate = s32[<=5] negate(custom-call.2)
+}
+)";
+
+  module_ = GetHloModule(hlo_text);
+
+  TF_ASSERT_OK(RunPadder().status());
+  // After the rewrite, we should have:
+  //
+  //   param
+  //     |
+  //   SliceToDynamic
+  //     |
+  //   OpWithDynamicLowering (custom_call_1)
+  //     |
+  //   OpWithDynamicLowering (custom_call_2)
+  //     |
+  //   PadToStatic
+  //     |
+  //   Negate
+  //     |
+  //   SliceToDynamic  // The root requires a dynamic-form tensor.
+  auto custom_call_1 =
+      module_->entry_computation()->GetInstructionWithName("custom-call.1");
+  auto custom_call_2 =
+      module_->entry_computation()->GetInstructionWithName("custom-call.2");
+  // Test that the input to custom-call.1 is a SliceToDynamic custom call.
+  HloInstruction* slice_to_dynamic = custom_call_1->mutable_operand(0);
+  ASSERT_THAT(slice_to_dynamic->opcode(), HloOpcode::kCustomCall);
+  ASSERT_THAT(slice_to_dynamic->custom_call_target(), "SliceToDynamic");
+  ASSERT_EQ(custom_call_2->user_count(), 1);
+  HloInstruction* pad_to_static = custom_call_2->users()[0];
+  ASSERT_THAT(pad_to_static->opcode(), HloOpcode::kCustomCall);
+  ASSERT_THAT(pad_to_static->custom_call_target(), "PadToStatic");
+  slice_to_dynamic = module_->entry_computation()->root_instruction();
+  ASSERT_THAT(slice_to_dynamic->opcode(), HloOpcode::kCustomCall);
+  ASSERT_THAT(slice_to_dynamic->custom_call_target(), "SliceToDynamic");
+}
+
+TEST_F(DynamicPadderTest, DynamicLoweringTestTupleInput) {
+  const string hlo_text = R"(
+HloModule DynamicLowering
+
+ENTRY main {
+  param = s32[5] parameter(0)
+  const = s32[] constant(3)
+  param_padded = s32[<=5] set-dimension-size(param, const),
+                 dimensions={0}
+  // Create a tuple with a static and a dynamic component.
+  tuple_arg = (s32[], s32[<=5]) tuple(const, param_padded)
+  custom-call.1 = (s32[], s32[<=5]) custom-call(tuple_arg),
+    custom_call_target="OpWithDynamicLowering"
+  custom-call.2 = (s32[], s32[<=5]) custom-call(custom-call.1),
+    custom_call_target="OpWithDynamicLowering"
+  data = s32[<=5]{0} get-tuple-element(custom-call.2), index=1
+  // Negate doesn't support dynamic lowering.
+  ROOT negate = s32[<=5] negate(data)
+}
+)";
+
+  module_ = GetHloModule(hlo_text);
+
+  TF_ASSERT_OK(RunPadder().status());
+  // After the rewrite, we should have:
+  //
+  //   param
+  //     |
+  //   SliceToDynamic
+  //     |
+  //   Tuple
+  //     |
+  //   OpWithDynamicLowering (custom_call_1)
+  //     |
+  //   OpWithDynamicLowering (custom_call_2)
+  //     |
+  //   GTE
+  //     |
+  //   PadToStatic
+  //     |
+  //   Negate
+  //     |
+  //   SliceToDynamic  // The root requires a dynamic-form tensor.
+
+  auto* root = module_->entry_computation()->root_instruction();
+  EXPECT_THAT(root,
+              op::CustomCall("SliceToDynamic", op::Negate(), op::Constant()));
+  HloInstruction* negate = root->mutable_operand(0);
+  EXPECT_THAT(
+      negate,
+      op::Negate(op::GetTupleElement(op::CustomCall(
+          "PadToStatic", op::GetTupleElement(op::CustomCall(
+                             "OpWithDynamicLowering", ::testing::_))))));
+  auto custom_call_1 =
+      module_->entry_computation()->GetInstructionWithName("custom-call.1");
+  EXPECT_THAT(custom_call_1,
+              op::CustomCall(
+                  "OpWithDynamicLowering",
+                  op::Tuple(op::Constant(), op::CustomCall("SliceToDynamic"))));
+}
+
 TEST_F(DynamicPadderTest, ConvolutionTest) {
   auto builder = HloComputation::Builder(TestName());
   constexpr int xdim = 3;
--- a/tensorflow/compiler/xla/shape.h
+++ b/tensorflow/compiler/xla/shape.h
@@ -63,6 +63,8 @@ class Shape {
   // shapes are traversed recursively.
   bool is_static() const;
 
+  bool is_dynamic() const { return !is_static(); }
+
   // Returns true if the given dimension is dynamically-sized.
   bool is_dynamic_dimension(int dimension) const {
     return dynamic_dimensions_.at(dimension);