[XLA] Implement dynamic input and output in DynamicPadder.

PiperOrigin-RevId: 285218938
Change-Id: Ia5f5e5ad62154b1427d2b56640eed3a443a50c1d
Yunxing Dai 2019-12-12 10:27:36 -08:00 committed by TensorFlower Gardener
parent 267488da06
commit 2f28e151b3
11 changed files with 268 additions and 59 deletions
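The hunks below implement a simple overall contract: dynamic-shaped module inputs are padded out to their static bounds right after the entry parameters ("PadToStatic"), the body of the computation then runs entirely on static shapes, and dynamic outputs are re-sliced to their runtime sizes just before the root ("SliceToDynamic", previously called "Unpad"). To orient the reader, here is a minimal standalone C++ sketch of that convention — this is not XLA code; `kStaticBound` and the two helper functions are illustrative stand-ins for the custom-call targets of the same names:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

constexpr int64_t kStaticBound = 8;  // static upper bound of the dynamic dimension

// "PadToStatic": pad the valid elements out to the static bound and return the
// padded buffer together with an explicit size scalar (cf. the (data, sizes...)
// tuple the real custom call produces).
std::pair<std::vector<float>, int32_t> PadToStatic(const std::vector<float>& dyn) {
  std::vector<float> padded(kStaticBound, 0.0f);  // padding identity value
  std::copy(dyn.begin(), dyn.end(), padded.begin());
  return {padded, static_cast<int32_t>(dyn.size())};
}

// "SliceToDynamic": keep only the first `size` elements as valid data.
std::vector<float> SliceToDynamic(const std::vector<float>& padded, int32_t size) {
  return std::vector<float>(padded.begin(), padded.begin() + size);
}

int main() {
  auto [data, size] = PadToStatic({1.f, 2.f, 3.f});  // size == 3
  // ... a static-shaped computation would run on `data` here ...
  for (float v : SliceToDynamic(data, size)) std::cout << v << ' ';
}
```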

View File

@@ -237,12 +237,12 @@ class XlaBuilder {
   // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the
   // dynamic dimensions information when XLA backend can handle dynamic
   // dimensions.
-  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = true);
+  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);

   // Overload of Build which specifies a particular root instruction for the
   // computation.
   StatusOr<XlaComputation> Build(XlaOp root,
-                                 bool remove_dynamic_dimensions = true);
+                                 bool remove_dynamic_dimensions = false);

   // Builds the computation with the requested operations, or notes an error in
   // the parent XlaBuilder and returns an empty computation if building failed.

View File

@@ -298,10 +298,12 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
       const Shape& expected_shape =
          entry_comp->parameter_instruction(i)->shape();
       const Shape& actual_shape = arguments[i].shape();
-      CHECK(expected_shape == actual_shape) << absl::StreamFormat(
-          "Shape mismatch on argument %d. Expected %s, but was %s.", i,
-          expected_shape.ToString(/*print_layout=*/true),
-          actual_shape.ToString(/*print_layout=*/true));
+      CHECK(
+          Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
+          << absl::StreamFormat(
+                 "Shape mismatch on argument %d. Expected %s, but was %s.", i,
+                 expected_shape.ToString(/*print_layout=*/true),
+                 actual_shape.ToString(/*print_layout=*/true));
     }
   }
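The argument check now uses `Shape::Equal`, the chainable comparator object, instead of `operator==`, so that a statically shaped buffer can be passed where a dynamic shape is expected. A small sketch of the difference (the three-argument `MakeShape` overload taking a `dynamic_dimensions` vector appears in the shape_util.h hunk further down):

```cpp
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

bool CompareIgnoringDynamism() {
  xla::Shape static_shape = xla::ShapeUtil::MakeShape(xla::F32, {10});
  xla::Shape dynamic_shape =
      xla::ShapeUtil::MakeShape(xla::F32, {10}, {true});  // f32[<=10]
  bool strict = xla::Shape::Equal()(static_shape, dynamic_shape);  // false
  bool loose = xla::Shape::Equal().IgnoreDynamicDimension()(
      static_shape, dynamic_shape);  // true
  return strict != loose;
}
```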

View File

@@ -178,15 +178,32 @@ Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
 }

 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "PadToStatic") {
+    for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
+      if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
+        HloInstruction* dynamic_size =
+            hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+                ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
+        // PadToStatic converts a dynamic dimension to a static dimension. It
+        // then returns the padded data output and the dynamic sizes of the
+        // input dimensions.
+        ShapeIndex data_output = {0};
+        parent_->SetDynamicSize(hlo, data_output, i, dynamic_size,
+                                {.stride = 1, .multiple_of = 1});
+      }
+    }
+    return Status::OK();
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
                int64 operand_index, HloInstruction* dynamic_size,
                DimensionConstraint constraint) {
-        if (hlo->custom_call_target() != "Unpad" ||
+        if (hlo->custom_call_target() != "SliceToDynamic" ||
             absl::StartsWith(hlo->custom_call_target(), "Resize")) {
           return Unimplemented(
               "CustomCall is not supported to have a dynamic dimension");
         }
         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
         return Status::OK();
       });
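For reference, the tuple shape that `PadToStatic` produces for a rank-2 operand, built with the same `ShapeUtil` helpers the new handler uses. Element 0 is the static data; element i (for i >= 1) holds the runtime size of input dimension i - 1, which is why the handler above emits a get-tuple-element at index i + 1:

```cpp
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: PadToStatic applied to an f32[<=4, 2] operand yields the tuple
// (f32[4,2], s32[], s32[]).
xla::Shape PadToStaticResultShape() {
  xla::Shape data = xla::ShapeUtil::MakeShape(xla::F32, {4, 2});
  return xla::ShapeUtil::MakeTupleShape(
      {data, xla::ShapeUtil::MakeScalarShape(xla::S32),
       xla::ShapeUtil::MakeScalarShape(xla::S32)});
}
```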

View File

@@ -23,7 +23,9 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -169,7 +171,7 @@ HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
   return padded;
 }

-// In a reshape if a dynamci dimension is splitted into multiple output
+// In a reshape, if a dynamic dimension is split into multiple output
 // dimensions, we need to rewrite the input of the reshape.
 //
 // The reason for this is that a continuous input may not be evenly reshaped
@@ -641,9 +643,77 @@ StatusOr<bool> RewriteDynamicReshape(
   return changed;
 }

-// For all dynamic outputs that live out of the computation, add unpad
-// operations.
-Status InsertUnpadsForModuleOutputs(
+// Insert a pad-to-static after `inst` if `inst` has dynamic dimensions in it.
+// Recurses into tuple instructions.
+StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
+  if (inst->shape().is_static()) {
+    return inst;
+  }
+  HloComputation* comp = inst->parent();
+  if (!inst->shape().IsTuple()) {
+    // The output shape of pad-to-static is a tuple. The 0th element is the
+    // data output, which has the same shape as the input but without dynamic
+    // dimensions; the i-th element is the dynamic size of the (i-1)th input
+    // dimension.
+    Shape data_output_shape = inst->shape();  // 0th element.
+    data_output_shape.clear_dynamic_dimensions();
+    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
+    for (int64 i = 0; i < inst->shape().rank(); ++i) {
+      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
+                                    &output_shape);
+    }
+    HloInstruction* pad_to_static =
+        comp->AddInstruction(HloInstruction::CreateCustomCall(
+            output_shape, {inst}, "PadToStatic", ""));
+    HloInstruction* data_output =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            data_output_shape, pad_to_static, 0));
+    return data_output;
+  }
+
+  TF_RET_CHECK(inst->shape().IsTuple());
+  std::vector<HloInstruction*> static_tuple_elements;
+  for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
+    // For each tuple element, if it is static, pass it through. If it is
+    // dynamic, recursively call this function again.
+    HloInstruction* gte =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            inst->shape().tuple_shapes(i), inst, i));
+    if (gte->shape().is_static()) {
+      static_tuple_elements.push_back(gte);
+    } else {
+      TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
+                          InsertPadToStaticOnInstruction(gte));
+      static_tuple_elements.push_back(static_gte);
+    }
+  }
+
+  return comp->AddInstruction(
+      HloInstruction::CreateTuple(static_tuple_elements));
+}
+
+Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
+  std::vector<HloInstruction*> params;
+  HloComputation* entry = module->entry_computation();
+  for (int64 i = 0; i < entry->num_parameters(); ++i) {
+    HloInstruction* param =
+        module->entry_computation()->parameter_instruction(i);
+    auto users = param->users();
+    TF_ASSIGN_OR_RETURN(HloInstruction * static_param,
+                        InsertPadToStaticOnInstruction(param));
+    for (auto* user : users) {
+      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, static_param));
+    }
+    if (param == entry->root_instruction()) {
+      module->entry_computation()->set_root_instruction(static_param);
+    }
+  }
+  return Status::OK();
+}
+
+// For all dynamic outputs that live out of the computation, add
+// slice-to-dynamic operations.
+Status InsertSliceToDynamicBeforeModuleOutputs(
     const DynamicDimensionInference& dynamic_dimension_inference,
     HloModule* module) {
   auto root = module->entry_computation()->root_instruction();
@@ -656,7 +726,7 @@ Status InsertUnpadsForModuleOutputs(
       if (dynamic_dimension_inference.GetDynamicSize(root, index, dim) !=
           nullptr) {
         CHECK_LE(index.size(), 1) << "XLA doesn't support nested output "
-                                     "dimensions that has dynamic size";
+                                     "dimension that has dynamic size";
         has_dynamic_output = true;
       }
     }
@@ -674,30 +744,36 @@ Status InsertUnpadsForModuleOutputs(
         if (!subshape.IsArray()) {
           return;
         }
         auto gte = module->entry_computation()->AddInstruction(
-            HloInstruction::CreateGetTupleElement(subshape, root, index[0]));
+            HloInstruction::CreateGetTupleElement(
+                ShapeUtil::MakeShapeWithStaticDimensions(subshape), root,
+                index[0]));
+
         if (dynamic_outputs.contains(index)) {
           CHECK_EQ(index.size(), 1)
               << "XLA only support 1 layer nested output tuple";
-          // For dynamic outputs, creates an unpad operation.
-          std::vector<HloInstruction*> unpad_operands;
+          // For dynamic outputs, create a slice operation.
+          std::vector<HloInstruction*> slice_operands;
           // First operand is the original input. Rest are dimension values.
-          unpad_operands.push_back(gte);
+          slice_operands.push_back(gte);
+          // Keep a dynamic version of the subshape as we are removing the
+          // dynamic dimension in the original root and gte.
+          Shape dynamic_subshape = subshape;
           for (int64 dim = 0; dim < subshape.rank(); ++dim) {
             HloInstruction* dynamic_size =
                 dynamic_dimension_inference.GetDynamicSize(root, index, dim);
             if (dynamic_size != nullptr) {
-              unpad_operands.push_back(dynamic_size);
+              slice_operands.push_back(dynamic_size);
             } else {
               auto const_size = HloInstruction::CreateConstant(
                   LiteralUtil::CreateR0<int32>(subshape.dimensions(dim)));
-              unpad_operands.push_back(
+              slice_operands.push_back(
                   module->entry_computation()->AddInstruction(
                       std::move(const_size)));
             }
           }
-          // This is a dynamic output, add unpad operation.
+          // This is a dynamic output, add a slice operation.
           //
           // Write the backend config in the format of
           // 'dynamic_index'-'output_index'.
@@ -707,11 +783,11 @@ Status InsertUnpadsForModuleOutputs(
           //
           // output_index indicates the position of this output in all outputs
           // (including static inputs).
-          auto unpad = HloInstruction::CreateCustomCall(
-              subshape, unpad_operands, "Unpad",
+          auto slice = HloInstruction::CreateCustomCall(
+              dynamic_subshape, slice_operands, "SliceToDynamic",
               absl::StrFormat("%d-%d", dynamic_index++, index[0]));
           new_root_operands.push_back(
-              module->entry_computation()->AddInstruction(std::move(unpad)));
+              module->entry_computation()->AddInstruction(std::move(slice)));
         } else {
           new_root_operands.push_back(gte);
         }
@@ -721,37 +797,125 @@ Status InsertUnpadsForModuleOutputs(
           HloInstruction::CreateTuple(new_root_operands));
       module->entry_computation()->set_root_instruction(new_root);
     } else {
-      std::vector<HloInstruction*> unpad_operands;
+      std::vector<HloInstruction*> slice_operands;
       // First operand is the original input. Rest are dimension values.
-      unpad_operands.push_back(root);
+      slice_operands.push_back(root);
       for (int64 dim = 0; dim < root->shape().rank(); ++dim) {
         HloInstruction* dynamic_size =
             dynamic_dimension_inference.GetDynamicSize(root, {}, dim);
         if (dynamic_size != nullptr) {
-          unpad_operands.push_back(dynamic_size);
+          slice_operands.push_back(dynamic_size);
         } else {
           auto const_size = HloInstruction::CreateConstant(
               LiteralUtil::CreateR0<int32>(root->shape().dimensions(dim)));
-          unpad_operands.push_back(module->entry_computation()->AddInstruction(
+          slice_operands.push_back(module->entry_computation()->AddInstruction(
               std::move(const_size)));
         }
-        // This is a dynamic output, add unpad operation.
-        auto unpad = module->entry_computation()->AddInstruction(
-            HloInstruction::CreateCustomCall(root->shape(), unpad_operands,
-                                             "Unpad", "0-0"));
-        module->entry_computation()->set_root_instruction(unpad);
+        // This is a dynamic output, add a slice operation.
+        auto slice = module->entry_computation()->AddInstruction(
+            HloInstruction::CreateCustomCall(root->shape(), slice_operands,
+                                             "SliceToDynamic", "0-0"));
+        module->entry_computation()->set_root_instruction(slice);
       }
     }
   }
   return Status::OK();
 }

+// Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
+//
+// After this visitor, the entry computation looks like:
+//   Param(dynamic)
+//     |
+//   GTE(dynamic)
+//     |
+//   PadToStatic(static)
+//     |
+//   ... regular computation with static shapes ...
+//     |
+//   SliceToDynamic(dynamic)
+//     |
+//   ROOT tuple(dynamic)
+class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
+ public:
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  Status HandleCustomCall(HloInstruction* hlo) override;
+
+  Status HandleParameter(HloInstruction* hlo) override;
+
+  static Status Run(HloComputation* computation) {
+    DynamicShapeRemovingVisitor visitor;
+    return computation->Accept(&visitor);
+  }
+};
+
+Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
+  // Default rule: if all inputs to an op are static, remove the dynamism in
+  // its output.
+  bool input_is_dynamic = false;
+  for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    if (!hlo->operand(i)->shape().is_static()) {
+      input_is_dynamic = true;
+    }
+  }
+
+  if (!input_is_dynamic) {
+    hlo->mutable_shape()->clear_dynamic_dimensions();
+  }
+  return Status::OK();
+}
+
+Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "SliceToDynamic") {
+    // Don't remove slice-to-dynamic instructions.
+    return Status::OK();
+  }
+  return DefaultAction(hlo);
+}
+
+Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
+  return Status::OK();
+}
+
 }  // namespace

 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
   bool changed = false;
   VLOG(2) << "Pre DynamicPadder HLO:";
   XLA_VLOG_LINES(2, module->ToString());
+
+  // Remove dynamic dimensions on parameters if there is already a binding for
+  // them. We do this because we have two different APIs to express a dynamic
+  // dimension:
+  //
+  // 1. Dynamic dimension specified directly in the shape -- needed for
+  //    PyTorch.
+  //
+  // 2. Dynamic dimension specified with a dynamic parameter binding object --
+  //    needed for TensorFlow.
+  //
+  // For case 1, we will insert a "pad-to-static" instruction at the beginning
+  // of XLA execution to make the shape static.
+  //
+  // For case 2, the shape is already static, so we remove the dynamic
+  // dimension.
+  //
+  // TODO(b/145140571): Convert all API invocations to case 1.
+  //
+  TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().ForEachBinding(
+      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
+          const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
+          -> Status {
+        HloInstruction* parameter =
+            module->entry_computation()->parameter_instruction(
+                dynamic_dimension.parameter_num);
+        ShapeUtil::UpdateDynamicDimension(parameter->mutable_shape(),
+                                          dynamic_dimension.parameter_index,
+                                          dynamic_dimension.dimension, false);
+        return Status::OK();
+      }));
+
+  TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                       DynamicDimensionInference::Run(module));
@@ -806,8 +970,28 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
     }
   }

-  TF_RETURN_IF_ERROR(
-      InsertUnpadsForModuleOutputs(dynamic_dimension_inference, module));
+  TF_RETURN_IF_ERROR(InsertSliceToDynamicBeforeModuleOutputs(
+      dynamic_dimension_inference, module));
+
+  // Remove all dynamic dimensions between the entry parameters and the root
+  // instruction -- the dynamic padder will produce an equivalent statically
+  // shaped graph.
+  for (HloComputation* computation : module->computations()) {
+    if (computation == module->entry_computation()) {
+      TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(computation));
+    } else {
+      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
+        bool operand_is_dynamic = false;
+        for (auto* operand : inst->operands()) {
+          if (!operand->shape().is_static()) {
+            operand_is_dynamic = true;
+          }
+        }
+        if (!operand_is_dynamic) {
+          inst->mutable_shape()->clear_dynamic_dimensions();
+        }
+      }
+    }
+  }
+
+  HloDCE dce;
+  TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
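Run() now reconciles the two ways a dynamic dimension can be expressed. Case 2, the binding object, pairs a scalar size parameter with the dimension it describes. A hedged sketch of declaring and walking such a binding — `ForEachBinding` and its lambda signature appear in the hunk above, while the `Bind` call and the struct field names are assumed from this era's dynamic_parameter_binding.h:

```cpp
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"

xla::Status DeclareBindingSketch() {
  xla::DynamicParameterBinding binding;
  // "Scalar parameter 1 holds the runtime size of dimension 0 of parameter 0."
  TF_RETURN_IF_ERROR(binding.Bind(
      xla::DynamicParameterBinding::DynamicParameter{/*parameter_num=*/1,
                                                     /*parameter_index=*/{}},
      xla::DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                     /*parameter_index=*/{},
                                                     /*dimension=*/0}));
  // Walk the bindings, the same way DynamicPadder::Run does above.
  return binding.ForEachBinding(
      [](const xla::DynamicParameterBinding::DynamicParameter& p,
         const xla::DynamicParameterBinding::DynamicDimension& d)
          -> xla::Status {
        VLOG(1) << "parameter " << p.parameter_num << " sizes dimension "
                << d.dimension << " of parameter " << d.parameter_num;
        return xla::Status::OK();
      });
}
```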

View File

@@ -60,7 +60,6 @@ class GenericTransferManager : public TransferManager {
   int64 GetByteSizeRequirement(const Shape& shape) const override;

- protected:
   Status WriteSingleTupleIndexTable(
       se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
       const Shape& shape, se::DeviceMemoryBase* region) override;

View File

@@ -942,8 +942,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
       const Shape& instruction_subshape =
           ShapeUtil::GetSubshape(instruction->shape(), index);
       for (const LogicalBuffer* buffer : buffers) {
-        if (!Shape::Equal().MinorToMajorOnlyInLayout()(
-                instruction_subshape, buffer->shape())) {
+        if (!Shape::Equal()
+                 .IgnoreDynamicDimension()
+                 .MinorToMajorOnlyInLayout()(instruction_subshape,
+                                             buffer->shape())) {
           return InternalError(
               "Layout of instruction %s at index {%s} does not match "
               "source LogicalBuffer %s: %s vs %s",
@@ -1005,8 +1007,9 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
       FindOrDie(computation_layouts_, module->entry_computation())
           .result_layout();
   if (result_layout.LayoutIsSet()) {
-    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-        module->result_shape(), result_layout.shape()));
+    TF_RET_CHECK(
+        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
+            module->result_shape(), result_layout.shape()));
   }
   return Status::OK();
 }
@@ -1993,9 +1996,10 @@ Status LayoutAssignment::PropagateComputationLayouts(
             << ": " << computed_computation_layout.result_layout().ToString();
     *result_layout = computed_computation_layout.result_layout();
   } else {
-    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-        computed_computation_layout.result_layout().shape(),
-        result_layout->shape()));
+    TF_RET_CHECK(
+        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
+            computed_computation_layout.result_layout().shape(),
+            result_layout->shape()));
   }
   return Status::OK();
 }
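All three layout checks now splice `IgnoreDynamicDimension` into the comparator chain ahead of the layout-only restriction. Spelled out on concrete shapes (`MakeShapeWithLayout` is the layout-carrying `ShapeUtil` factory of this era; treat the snippet as illustrative):

```cpp
#include "tensorflow/compiler/xla/shape_util.h"

bool LayoutOnlyComparisonSketch() {
  xla::Shape a = xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {1, 0});
  xla::Shape b = xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {0, 1});
  // false: same element type and dimensions, different minor-to-major order.
  return xla::Shape::Equal()
      .IgnoreDynamicDimension()
      .MinorToMajorOnlyInLayout()(a, b);
}
```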

View File

@@ -270,6 +270,13 @@ class TransferManager {
   static StatusOr<TransferManager*> GetForPlatform(
       const se::Platform* platform);

+  // Writes the given device-memory pointers in 'elements' to the given region
+  // to construct a tuple index table in the platform-specific tuple
+  // representation.
+  virtual Status WriteSingleTupleIndexTable(
+      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
+      const Shape& shape, se::DeviceMemoryBase* region) = 0;
+
  protected:
   // Transfer a memory block of the given size from the device source into the
   // 'destination' buffer.
@@ -287,13 +294,6 @@ class TransferManager {
                                       const void* source,
                                       se::DeviceMemoryBase* destination);

-  // Writes the given device-memory pointers in 'elements' to the given region
-  // to construct a tuple index table in the platform-specific tuple
-  // representation.
-  virtual Status WriteSingleTupleIndexTable(
-      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
-      const Shape& shape, se::DeviceMemoryBase* region) = 0;
-
  private:
   // The mutex that guards the platform-to-transfer manager map.
   static tensorflow::mutex platform_transfer_manager_mutex_;

View File

@@ -48,7 +48,7 @@ void ShapeLayout::SetToDefaultLayout() {
 bool ShapeLayout::MatchesLayoutInShape(const Shape& shape,
                                        bool minor_to_major_only) const {
-  auto equal = Shape::Equal();
+  auto equal = Shape::Equal().IgnoreDynamicDimension();
   if (minor_to_major_only) {
     equal.MinorToMajorOnlyInLayout();
   }
@@ -81,11 +81,11 @@ void ShapeLayout::ResetLayout(const Layout& layout,
 }

 bool ShapeLayout::operator==(const ShapeLayout& other) const {
-  return ShapeUtil::Equal(shape_, other.shape_);
+  return Shape::Equal().IgnoreDynamicDimension()(shape_, other.shape_);
 }

 bool ShapeLayout::operator!=(const ShapeLayout& other) const {
-  return !ShapeUtil::Equal(shape_, other.shape_);
+  return !Shape::Equal().IgnoreDynamicDimension()(shape_, other.shape_);
 }

 }  // namespace xla

View File

@@ -388,6 +388,9 @@ class ShapeUtil {
   static Shape MakeShape(PrimitiveType element_type,
                          absl::Span<const int64> dimensions);

+  // Make a scalar shape with given primitive type.
+  static Shape MakeScalarShape(PrimitiveType element_type);
+
   // Constructs a new shape with the given element type and sequence of
   // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
   // with a true value that the respective dimension is dynamic. If the
@@ -398,9 +401,6 @@ class ShapeUtil {
                          absl::Span<const int64> dimensions,
                          const std::vector<bool>& dynamic_dimensions);

-  // Make a scalar shape with given primitive type.
-  static Shape MakeScalarShape(PrimitiveType element_type);
-
   // Constructs a new shape with the given element type and sequence of
   // dimensions. Method checks if the element type is valid and the shape's
   // size fits in std::numeric_limits<int64>::max().
@@ -430,7 +430,6 @@ class ShapeUtil {
   static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
                                          absl::Span<const int64> dimensions,
                                          int64 max_sparse_elements);
-
   // Returns the same shape except with all dimensions set to be static.
   static Shape MakeShapeWithStaticDimensions(const Shape& shape);
View File

@@ -648,7 +648,8 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
 xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
 XRTTupleAllocation::ToDeviceMemoryTree(
-    const std::function<bool(const xla::ShapeIndex&)>& release_checker) {
+    const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
+        release_checker) {
   xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
   for (const auto& index_buffer : buffers_) {
     if (index_buffer.second == nullptr ||
@@ -657,7 +658,9 @@ XRTTupleAllocation::ToDeviceMemoryTree(
                                      index_buffer.first.ToString(),
                                      " has been released");
     }
-    if (!release_checker(index_buffer.first)) {
+    TF_ASSIGN_OR_RETURN(bool should_release,
+                        release_checker(index_buffer.first));
+    if (!should_release) {
       *shaped_tree.mutable_element(index_buffer.first) =
           index_buffer.second->allocation();
     } else {

View File

@@ -229,7 +229,8 @@ class XRTTupleAllocation : public core::RefCounted {
   // ScopedShapedBuffer, which wants ownership and does not allow sharing.
   xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
   ToDeviceMemoryTree(
-      const std::function<bool(const xla::ShapeIndex&)>& release_checker);
+      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
+          release_checker);

  private:
  // Creates a new handle with (tuple) shape.
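With the `StatusOr<bool>` signature, a release checker can now propagate an error out of `ToDeviceMemoryTree` instead of being forced into a yes/no answer. A hedged caller sketch — the release policy inside the lambda is invented purely for illustration:

```cpp
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

xla::Status ToDeviceMemoryTreeSketch(XRTTupleAllocation* allocation) {
  auto release_checker =
      [](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
    if (index.size() > 1) {
      // An error is now expressible; before, the checker could only say no.
      return tensorflow::errors::InvalidArgument(
          "nested tuple buffers are not releasable here");
    }
    return index.size() == 1;  // release only first-level buffers (illustrative)
  };
  TF_ASSIGN_OR_RETURN(auto device_memory_tree,
                      allocation->ToDeviceMemoryTree(release_checker));
  (void)device_memory_tree;
  return xla::Status::OK();
}
```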