[XLA] Implement dynamic input and output in DynamicPadder.
PiperOrigin-RevId: 285218938
Change-Id: Ia5f5e5ad62154b1427d2b56640eed3a443a50c1d

commit 2f28e151b3 (parent 267488da06)
@@ -237,12 +237,12 @@ class XlaBuilder {
   // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the
   // dynamic dimensions information when XLA backend can handle dynamic
   // dimensions.
-  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = true);
+  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);
 
   // Overload of Build which specifies a particular root instruction for the
   // computation.
   StatusOr<XlaComputation> Build(XlaOp root,
-                                 bool remove_dynamic_dimensions = true);
+                                 bool remove_dynamic_dimensions = false);
 
   // Builds the computation with the requested operations, or notes an error in
   // the parent XlaBuilder and returns an empty computation if building failed.
 
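The flipped default above means dynamic-dimension metadata now survives Build() unless a caller explicitly asks for the old behavior. A minimal sketch of both call styles, assuming the usual XLA client headers and build environment; the builder name, parameter shape, and op are illustrative only:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::StatusOr<xla::XlaComputation> BuildExample() {
  xla::XlaBuilder b("dynamic_example");
  // Parameter with dimension 0 marked dynamic (hypothetical shape).
  xla::XlaOp p = xla::Parameter(
      &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {8}, {true}), "p");
  xla::Neg(p);
  // New default: dynamic-dimension information is kept in the built computation.
  return b.Build();
  // Old behavior, if still needed: b.Build(/*remove_dynamic_dimensions=*/true);
}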
@@ -298,10 +298,12 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
       const Shape& expected_shape =
           entry_comp->parameter_instruction(i)->shape();
       const Shape& actual_shape = arguments[i].shape();
-      CHECK(expected_shape == actual_shape) << absl::StreamFormat(
-          "Shape mismatch on argument %d. Expected %s, but was %s.", i,
-          expected_shape.ToString(/*print_layout=*/true),
-          actual_shape.ToString(/*print_layout=*/true));
+      CHECK(
+          Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
+          << absl::StreamFormat(
+                 "Shape mismatch on argument %d. Expected %s, but was %s.", i,
+                 expected_shape.ToString(/*print_layout=*/true),
+                 actual_shape.ToString(/*print_layout=*/true));
     }
   }
 
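For context on the relaxed comparator introduced above: a caller may pass a buffer whose dimensions match the parameter exactly but whose shape is not marked dynamic, which strict equality rejects. A small hedged sketch, assuming the XLA headers; the results noted in comments are expectations, not verified output:

#include "tensorflow/compiler/xla/shape_util.h"

void CompareStaticAndDynamic() {
  // Parameter declared with dimension 0 dynamic; argument arrives fully static.
  xla::Shape expected =
      xla::ShapeUtil::MakeShape(xla::F32, {2, 3}, {true, false});
  xla::Shape actual = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  bool strict = xla::Shape::Equal()(expected, actual);  // Expected: false.
  bool relaxed =
      xla::Shape::Equal().IgnoreDynamicDimension()(expected, actual);  // Expected: true.
}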
@@ -178,15 +178,32 @@ Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
 }
 
 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "PadToStatic") {
+    for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
+      if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
+        HloInstruction* dynamic_size =
+            hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+                ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
+        // PadToStatic converts a dynamic dimension to a static dimension. It
+        // then returns the padded data output and the dynamic sizes of input
+        // dimensions.
+        ShapeIndex data_output = {0};
+        parent_->SetDynamicSize(hlo, data_output, i, dynamic_size,
+                                {.stride = 1, .multiple_of = 1});
+      }
+    }
+    return Status::OK();
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
                int64 operand_index, HloInstruction* dynamic_size,
                DimensionConstraint constraint) {
-        if (hlo->custom_call_target() != "Unpad" ||
+        if (hlo->custom_call_target() != "SliceToDynamic" ||
             absl::StartsWith(hlo->custom_call_target(), "Resize")) {
           return Unimplemented(
               "CustomCall is not supported to have a dynamic dimension");
         }
 
         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
         return Status::OK();
       });
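A hedged sketch of the PadToStatic result shape the handler above assumes, built with the same ShapeUtil calls used elsewhere in this change; the rank-2 operand shape is made up for illustration:

#include "tensorflow/compiler/xla/shape_util.h"

xla::Shape PadToStaticResultShape() {
  // Hypothetical dynamic operand: f32[<=10, 5] (dimension 0 dynamic).
  xla::Shape operand =
      xla::ShapeUtil::MakeShape(xla::F32, {10, 5}, {true, false});
  xla::Shape data = operand;
  data.clear_dynamic_dimensions();  // Element 0: padded static data, f32[10,5].
  xla::Shape result = xla::ShapeUtil::MakeTupleShape({data});
  for (int64_t i = 0; i < operand.rank(); ++i) {
    // Element i + 1: runtime size of operand dimension i, as an s32 scalar.
    xla::ShapeUtil::AppendShapeToTuple(
        xla::ShapeUtil::MakeScalarShape(xla::S32), &result);
  }
  return result;  // (f32[10,5], s32[], s32[]).
}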
@@ -23,7 +23,9 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -169,7 +171,7 @@ HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
   return padded;
 }
 
-// In a reshape if a dynamci dimension is splitted into multiple output
+// In a reshape if a dynamic dimension is split into multiple output
 // dimensions, we need to rewrite the input of the reshape.
 //
 // The reason for this is that a continuous input may not be evenly reshaped
@@ -641,9 +643,77 @@ StatusOr<bool> RewriteDynamicReshape(
   return changed;
 }
 
-// For all dynamic outputs that live out of the computation, add unpad
-// operations.
-Status InsertUnpadsForModuleOutputs(
+// Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it.
+// Recurse into tuple instructions.
+StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
+  if (inst->shape().is_static()) {
+    return inst;
+  }
+  HloComputation* comp = inst->parent();
+  if (!inst->shape().IsTuple()) {
+    // The output shape of PadToStatic is a tuple. The 0th element is the data
+    // output, which is the same as the input shape but without dynamic
+    // dimensions; the i-th element is the dynamic size of the (i-1)th input
+    // dimension.
+    Shape data_output_shape = inst->shape();  // 0th element.
+    data_output_shape.clear_dynamic_dimensions();
+    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
+    for (int64 i = 0; i < inst->shape().rank(); ++i) {
+      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
+                                    &output_shape);
+    }
+    HloInstruction* pad_to_static =
+        comp->AddInstruction(HloInstruction::CreateCustomCall(
+            output_shape, {inst}, "PadToStatic", ""));
+    HloInstruction* data_output =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            data_output_shape, pad_to_static, 0));
+    return data_output;
+  }
+
+  TF_RET_CHECK(inst->shape().IsTuple());
+  std::vector<HloInstruction*> static_tuple_elements;
+  for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
+    // For each tuple element, if it is static, pass it through. If it is
+    // dynamic, recursively call this function again.
+    HloInstruction* gte =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            inst->shape().tuple_shapes(i), inst, i));
+
+    if (gte->shape().is_static()) {
+      static_tuple_elements.push_back(gte);
+    } else {
+      TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
+                          InsertPadToStaticOnInstruction(gte));
+      static_tuple_elements.push_back(static_gte);
+    }
+  }
+
+  return comp->AddInstruction(
+      HloInstruction::CreateTuple(static_tuple_elements));
+}
+
+Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
+  std::vector<HloInstruction*> params;
+  HloComputation* entry = module->entry_computation();
+  for (int64 i = 0; i < entry->num_parameters(); ++i) {
+    HloInstruction* param =
+        module->entry_computation()->parameter_instruction(i);
+    auto users = param->users();
+    TF_ASSIGN_OR_RETURN(HloInstruction * static_param,
+                        InsertPadToStaticOnInstruction(param));
+    for (auto* user : users) {
+      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, static_param));
+    }
+    if (param == entry->root_instruction()) {
+      module->entry_computation()->set_root_instruction(static_param);
+    }
+  }
+  return Status::OK();
+}
+
+// For all dynamic outputs that live out of the computation, add
+// slice-to-dynamic operations.
+Status InsertSliceToDynamicBeforeModuleOutputs(
     const DynamicDimensionInference& dynamic_dimension_inference,
     HloModule* module) {
   auto root = module->entry_computation()->root_instruction();
@@ -656,7 +726,7 @@ Status InsertUnpadsForModuleOutputs(
     if (dynamic_dimension_inference.GetDynamicSize(root, index, dim) !=
         nullptr) {
       CHECK_LE(index.size(), 1) << "XLA doesn't support nested output "
-                                   "dimensions that has dynamic size";
+                                   "dimension that has dynamic size";
       has_dynamic_output = true;
     }
   }
@@ -674,30 +744,36 @@ Status InsertUnpadsForModuleOutputs(
       if (!subshape.IsArray()) {
         return;
       }
 
       auto gte = module->entry_computation()->AddInstruction(
-          HloInstruction::CreateGetTupleElement(subshape, root, index[0]));
+          HloInstruction::CreateGetTupleElement(
+              ShapeUtil::MakeShapeWithStaticDimensions(subshape), root,
+              index[0]));
 
       if (dynamic_outputs.contains(index)) {
         CHECK_EQ(index.size(), 1)
             << "XLA only support 1 layer nested output tuple";
-        // For dynamic outputs, creates an unpad operation.
-        std::vector<HloInstruction*> unpad_operands;
+        // For dynamic outputs, creates a slice operation.
+        std::vector<HloInstruction*> slice_operands;
         // First operand is the original input. Rest are dimension values.
-        unpad_operands.push_back(gte);
+        slice_operands.push_back(gte);
+        // Keep a dynamic version of the subshape as we are removing the
+        // dynamic dimension in the original root and gte.
+        Shape dynamic_subshape = subshape;
         for (int64 dim = 0; dim < subshape.rank(); ++dim) {
           HloInstruction* dynamic_size =
              dynamic_dimension_inference.GetDynamicSize(root, index, dim);
          if (dynamic_size != nullptr) {
-            unpad_operands.push_back(dynamic_size);
+            slice_operands.push_back(dynamic_size);
          } else {
            auto const_size = HloInstruction::CreateConstant(
                LiteralUtil::CreateR0<int32>(subshape.dimensions(dim)));
-            unpad_operands.push_back(
+            slice_operands.push_back(
                module->entry_computation()->AddInstruction(
                    std::move(const_size)));
          }
        }
-        // This is a dynamic output, add unpad operation.
+        // This is a dynamic output, add slice operation.
        //
        // Write the backend config in the format of
        // 'dynamic_index'-'output_index'.
@@ -707,11 +783,11 @@ Status InsertUnpadsForModuleOutputs(
        //
        // output_index indicates the position of this output in all outputs
        // (including static inputs).
-        auto unpad = HloInstruction::CreateCustomCall(
-            subshape, unpad_operands, "Unpad",
+        auto slice = HloInstruction::CreateCustomCall(
+            dynamic_subshape, slice_operands, "SliceToDynamic",
            absl::StrFormat("%d-%d", dynamic_index++, index[0]));
        new_root_operands.push_back(
-            module->entry_computation()->AddInstruction(std::move(unpad)));
+            module->entry_computation()->AddInstruction(std::move(slice)));
      } else {
        new_root_operands.push_back(gte);
      }
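For readers following the custom call built above, a hedged sketch of the SliceToDynamic operand and result convention; the instruction pointers are hypothetical placeholders and the shape is illustrative:

#include <memory>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"

std::unique_ptr<xla::HloInstruction> MakeSliceToDynamic(
    xla::HloInstruction* padded_data,  // Statically shaped data, e.g. f32[10,5].
    xla::HloInstruction* size_dim0,    // s32 scalar: runtime size of dim 0.
    xla::HloInstruction* size_dim1) {  // s32 scalar: runtime size of dim 1.
  // The result shape re-introduces the dynamic-dimension markers that the
  // statically padded data lost.
  xla::Shape dynamic_shape =
      xla::ShapeUtil::MakeShape(xla::F32, {10, 5}, {true, false});
  // Backend config "dynamic_index-output_index", as described in the comment
  // above.
  return xla::HloInstruction::CreateCustomCall(
      dynamic_shape, {padded_data, size_dim0, size_dim1}, "SliceToDynamic",
      "0-0");
}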
@@ -721,37 +797,125 @@ Status InsertUnpadsForModuleOutputs(
           HloInstruction::CreateTuple(new_root_operands));
       module->entry_computation()->set_root_instruction(new_root);
     } else {
-      std::vector<HloInstruction*> unpad_operands;
+      std::vector<HloInstruction*> slice_operands;
       // First operand is the original input. Rest are dimension values.
-      unpad_operands.push_back(root);
+      slice_operands.push_back(root);
       for (int64 dim = 0; dim < root->shape().rank(); ++dim) {
         HloInstruction* dynamic_size =
             dynamic_dimension_inference.GetDynamicSize(root, {}, dim);
         if (dynamic_size != nullptr) {
-          unpad_operands.push_back(dynamic_size);
+          slice_operands.push_back(dynamic_size);
         } else {
           auto const_size = HloInstruction::CreateConstant(
               LiteralUtil::CreateR0<int32>(root->shape().dimensions(dim)));
-          unpad_operands.push_back(module->entry_computation()->AddInstruction(
+          slice_operands.push_back(module->entry_computation()->AddInstruction(
               std::move(const_size)));
         }
-        // This is a dynamic output, add unpad operation.
-        auto unpad = module->entry_computation()->AddInstruction(
-            HloInstruction::CreateCustomCall(root->shape(), unpad_operands,
-                                             "Unpad", "0-0"));
-        module->entry_computation()->set_root_instruction(unpad);
+        // This is a dynamic output, add slice operation.
+        auto slice = module->entry_computation()->AddInstruction(
+            HloInstruction::CreateCustomCall(root->shape(), slice_operands,
+                                             "SliceToDynamic", "0-0"));
+        module->entry_computation()->set_root_instruction(slice);
       }
     }
   }
   return Status::OK();
 }
 
+// Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
+//
+// After this visitor the entry computation then looks like:
+//  Param(dynamic)
+//   |
+//  GTE (dynamic)
+//   |
+//  PadToStatic(static)
+//   |
+//  .... regular computation with static shapes.
+//   |
+//  SliceToDynamic(dynamic)
+//   |
+//  ROOT tuple (dynamic)
+class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
+ public:
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  Status HandleCustomCall(HloInstruction* hlo) override;
+
+  Status HandleParameter(HloInstruction* hlo) override;
+
+  static Status Run(HloComputation* computation) {
+    DynamicShapeRemovingVisitor visitor;
+    return computation->Accept(&visitor);
+  }
+};
+
+Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
+  // Default rule: If input to an op is static, remove dynamism in output.
+  bool input_is_dynamic = false;
+  // Default rule:
+  for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    if (!hlo->operand(i)->shape().is_static()) {
+      input_is_dynamic = true;
+    }
+  }
+
+  if (!input_is_dynamic) {
+    hlo->mutable_shape()->clear_dynamic_dimensions();
+  }
+  return Status::OK();
+}
+
+Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "SliceToDynamic") {
+    // Don't remove slice-to-dynamic instruction.
+    return Status::OK();
+  }
+  return DefaultAction(hlo);
+}
+
+Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
+  return Status::OK();
+}
+
 }  // namespace
 
 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
   bool changed = false;
   VLOG(2) << "Pre DynamicPadder HLO:";
   XLA_VLOG_LINES(2, module->ToString());
 
+  // Removes dynamic dimensions on parameters if there is already a binding for
+  // it. We do this because we have two different APIs to express a dynamic
+  // dimension:
+  //
+  // 1. Dynamic dimension as specified directly in the shape -- Needed for
+  // Pytorch.
+  //
+  // 2. Dynamic dimension using dynamic parameter binding object. This
+  // is needed for tensorflow.
+  //
+  // For case 1, we will insert "pad-to-static" instruction in the
+  // beginning of xla execution, to make it into a static layout.
+  //
+  // For case 2, since it already has a static layout, we remove the
+  // dynamic dimension.
+  //
+  // TODO(b/145140571): Convert all API invocations to case 1.
+  //
+  TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().ForEachBinding(
+      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
+          const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
+          -> Status {
+        HloInstruction* parameter =
+            module->entry_computation()->parameter_instruction(
+                dynamic_dimension.parameter_num);
+        ShapeUtil::UpdateDynamicDimension(parameter->mutable_shape(),
+                                          dynamic_dimension.parameter_index,
+                                          dynamic_dimension.dimension, false);
+        return Status::OK();
+      }));
+
+  TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                       DynamicDimensionInference::Run(module));
 
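A hedged illustration of the two ways to express a dynamic dimension mentioned in the comment above. The shapes are made up, and the DynamicParameterBinding call is only sketched in a comment because its exact signature is an assumption here:

#include "tensorflow/compiler/xla/shape_util.h"

void TwoWaysToExpressADynamicDimension() {
  // Case 1: the dynamic dimension is encoded directly in the parameter shape.
  xla::Shape shape_encoded =
      xla::ShapeUtil::MakeShape(xla::F32, {128, 64}, {true, false});
  (void)shape_encoded;

  // Case 2: the parameter shape stays static and a separate scalar parameter
  // is bound as the size of dimension 0, roughly (signature assumed):
  //   DynamicParameterBinding binding;
  //   binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
  //                DynamicParameterBinding::DynamicDimension{0, {}, 0});
  // The pass converts case 2 back to a static shape and handles case 1 with
  // the pad-to-static / slice-to-dynamic pair.
}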
@@ -806,8 +970,28 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
     }
   }
 
-  TF_RETURN_IF_ERROR(
-      InsertUnpadsForModuleOutputs(dynamic_dimension_inference, module));
+  TF_RETURN_IF_ERROR(InsertSliceToDynamicBeforeModuleOutputs(
+      dynamic_dimension_inference, module));
+
+  // Remove all dynamic dimensions after entry parameter and root instruction --
+  // Dynamic padder will produce an equivalent static shaped graph.
+  for (HloComputation* computation : module->computations()) {
+    if (computation == module->entry_computation()) {
+      TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(computation));
+    } else {
+      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
+        bool operand_is_dynamic = false;
+        for (auto* operand : inst->operands()) {
+          if (!operand->shape().is_static()) {
+            operand_is_dynamic = true;
+          }
+        }
+        if (!operand_is_dynamic) {
+          inst->mutable_shape()->clear_dynamic_dimensions();
+        }
+      }
+    }
+  }
 
   HloDCE dce;
   TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
 
@@ -60,7 +60,6 @@ class GenericTransferManager : public TransferManager {
 
   int64 GetByteSizeRequirement(const Shape& shape) const override;
 
- protected:
   Status WriteSingleTupleIndexTable(
       se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
       const Shape& shape, se::DeviceMemoryBase* region) override;
 
@@ -942,8 +942,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
       const Shape& instruction_subshape =
           ShapeUtil::GetSubshape(instruction->shape(), index);
       for (const LogicalBuffer* buffer : buffers) {
-        if (!Shape::Equal().MinorToMajorOnlyInLayout()(
-                instruction_subshape, buffer->shape())) {
+        if (!Shape::Equal()
+                 .IgnoreDynamicDimension()
+                 .MinorToMajorOnlyInLayout()(instruction_subshape,
+                                             buffer->shape())) {
           return InternalError(
               "Layout of instruction %s at index {%s} does not match "
               "source LogicalBuffer %s: %s vs %s",
@@ -1005,8 +1007,9 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
       FindOrDie(computation_layouts_, module->entry_computation())
           .result_layout();
   if (result_layout.LayoutIsSet()) {
-    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-        module->result_shape(), result_layout.shape()));
+    TF_RET_CHECK(
+        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
+            module->result_shape(), result_layout.shape()));
   }
   return Status::OK();
 }
@@ -1993,9 +1996,10 @@ Status LayoutAssignment::PropagateComputationLayouts(
         << ": " << computed_computation_layout.result_layout().ToString();
     *result_layout = computed_computation_layout.result_layout();
   } else {
-    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-        computed_computation_layout.result_layout().shape(),
-        result_layout->shape()));
+    TF_RET_CHECK(
+        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
+            computed_computation_layout.result_layout().shape(),
+            result_layout->shape()));
   }
   return Status::OK();
 }
 
@@ -270,6 +270,13 @@ class TransferManager {
   static StatusOr<TransferManager*> GetForPlatform(
       const se::Platform* platform);
 
+  // Writes the given device-memory pointers in 'elements' to the given region
+  // to construct a tuple index table in the platform-specific tuple
+  // representation.
+  virtual Status WriteSingleTupleIndexTable(
+      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
+      const Shape& shape, se::DeviceMemoryBase* region) = 0;
+
  protected:
   // Transfer a memory block of the given size from the device source into the
   // 'destination' buffer.
@@ -287,13 +294,6 @@ class TransferManager {
                                         const void* source,
                                         se::DeviceMemoryBase* destination);
 
-  // Writes the given device-memory pointers in 'elements' to the given region
-  // to construct a tuple index table in the platform-specific tuple
-  // representation.
-  virtual Status WriteSingleTupleIndexTable(
-      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
-      const Shape& shape, se::DeviceMemoryBase* region) = 0;
-
  private:
   // The mutex that guards the platform-to-transfer manager map.
   static tensorflow::mutex platform_transfer_manager_mutex_;
@@ -48,7 +48,7 @@ void ShapeLayout::SetToDefaultLayout() {
 
 bool ShapeLayout::MatchesLayoutInShape(const Shape& shape,
                                        bool minor_to_major_only) const {
-  auto equal = Shape::Equal();
+  auto equal = Shape::Equal().IgnoreDynamicDimension();
   if (minor_to_major_only) {
     equal.MinorToMajorOnlyInLayout();
   }
@@ -81,11 +81,11 @@ void ShapeLayout::ResetLayout(const Layout& layout,
 }
 
 bool ShapeLayout::operator==(const ShapeLayout& other) const {
-  return ShapeUtil::Equal(shape_, other.shape_);
+  return Shape::Equal().IgnoreDynamicDimension()(shape_, other.shape_);
 }
 
 bool ShapeLayout::operator!=(const ShapeLayout& other) const {
-  return !ShapeUtil::Equal(shape_, other.shape_);
+  return !Shape::Equal().IgnoreDynamicDimension()(shape_, other.shape_);
 }
 
 }  // namespace xla
@@ -388,6 +388,9 @@ class ShapeUtil {
   static Shape MakeShape(PrimitiveType element_type,
                          absl::Span<const int64> dimensions);
 
+  // Make a scalar shape with given primitive type.
+  static Shape MakeScalarShape(PrimitiveType element_type);
+
   // Constructs a new shape with the given element type and sequence of
   // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
   // with a true value that the respective dimension is dynamic. If the
@@ -398,9 +401,6 @@ class ShapeUtil {
                          absl::Span<const int64> dimensions,
                          const std::vector<bool>& dynamic_dimensions);
 
-  // Make a scalar shape with given primitive type.
-  static Shape MakeScalarShape(PrimitiveType element_type);
-
   // Constructs a new shape with the given element type and sequence of
   // dimensions. Method checks if the element type is valid and the shape's
   // size fits in std::numeric_limits<int64>::max().
@@ -430,7 +430,6 @@ class ShapeUtil {
   static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
                                          absl::Span<const int64> dimensions,
                                          int64 max_sparse_elements);
 
   // Returns the same shape except with all dimensions set to be static.
   static Shape MakeShapeWithStaticDimensions(const Shape& shape);
-
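A small sketch of how the helper declared above is expected to behave, matching its use for the GetTupleElement shapes earlier in this change; the example shape is illustrative:

#include "tensorflow/compiler/xla/shape_util.h"

void StaticCopyOfDynamicShape() {
  xla::Shape dynamic =
      xla::ShapeUtil::MakeShape(xla::F32, {10, 5}, {true, false});  // f32[<=10,5]
  // Same extents, element type, and layout, but every dimension marked static.
  xla::Shape static_copy =
      xla::ShapeUtil::MakeShapeWithStaticDimensions(dynamic);
  (void)static_copy;  // f32[10,5]
}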
@@ -648,7 +648,8 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
 
 xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
 XRTTupleAllocation::ToDeviceMemoryTree(
-    const std::function<bool(const xla::ShapeIndex&)>& release_checker) {
+    const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
+        release_checker) {
   xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
   for (const auto& index_buffer : buffers_) {
     if (index_buffer.second == nullptr ||
@@ -657,7 +658,9 @@ XRTTupleAllocation::ToDeviceMemoryTree(
                                index_buffer.first.ToString(),
                                " has been released");
     }
-    if (!release_checker(index_buffer.first)) {
+    TF_ASSIGN_OR_RETURN(bool should_release,
+                        release_checker(index_buffer.first));
+    if (!should_release) {
       *shaped_tree.mutable_element(index_buffer.first) =
           index_buffer.second->allocation();
     } else {
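A hedged sketch of a caller after the signature change above: the release checker can now report an error through StatusOr instead of only returning a bool. The allocation pointer and the release policy are hypothetical:

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xrt/xrt_state.h"

xla::Status ReleaseTopLevelBuffer(tensorflow::XRTTupleAllocation* allocation) {
  auto tree_or = allocation->ToDeviceMemoryTree(
      [](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
        // Hypothetical policy: release only the top-level buffer; a real
        // checker could also return an error Status here.
        return index.empty();
      });
  return tree_or.status();
}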
@@ -229,7 +229,8 @@ class XRTTupleAllocation : public core::RefCounted {
   // ScopedShapedBuffer, which wants ownership and does not allow sharing.
   xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
   ToDeviceMemoryTree(
-      const std::function<bool(const xla::ShapeIndex&)>& release_checker);
+      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
+          release_checker);
 
  private:
   // Creates a new handle with (tuple) shape.