diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 5e93bb2b3ba..42a8ae5b996 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -237,12 +237,12 @@ class XlaBuilder {
   // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the
   // dynamic dimensions information when XLA backend can handle dynamic
   // dimensions.
-  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = true);
+  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);
 
   // Overload of Build which specifies a particular root instruction for the
   // computation.
   StatusOr<XlaComputation> Build(XlaOp root,
-                                 bool remove_dynamic_dimensions = true);
+                                 bool remove_dynamic_dimensions = false);
 
   // Builds the computation with the requested operations, or notes an error in
   // the parent XlaBuilder and returns an empty computation if building failed.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index d19cf4fb015..366fdca442f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -298,10 +298,12 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
       const Shape& expected_shape =
           entry_comp->parameter_instruction(i)->shape();
       const Shape& actual_shape = arguments[i].shape();
-      CHECK(expected_shape == actual_shape) << absl::StreamFormat(
-          "Shape mismatch on argument %d. Expected %s, but was %s.", i,
-          expected_shape.ToString(/*print_layout=*/true),
-          actual_shape.ToString(/*print_layout=*/true));
+      CHECK(
+          Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
+          << absl::StreamFormat(
+                 "Shape mismatch on argument %d. Expected %s, but was %s.", i,
+                 expected_shape.ToString(/*print_layout=*/true),
+                 actual_shape.ToString(/*print_layout=*/true));
     }
   }
 
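Not part of the patch, but as a reading aid for the relaxed argument check above, here is a minimal sketch of what Shape::Equal().IgnoreDynamicDimension() accepts that plain operator== rejects; the f32[<=8] parameter shape is made up for illustration.

  // Hypothetical shapes: the compiled entry parameter carries a dynamic bound,
  // while the caller hands in the statically padded form of the same array.
  Shape expected = ShapeUtil::MakeShape(F32, {8});
  expected.set_dynamic_dimension(0, true);        // f32[<=8]
  Shape actual = ShapeUtil::MakeShape(F32, {8});  // f32[8]
  CHECK(!(expected == actual));  // strict equality still fails
  CHECK(Shape::Equal().IgnoreDynamicDimension()(expected, actual));  // accepted

With Build() now keeping dynamic dimensions by default, this is the situation the CPU argument check has to tolerate.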
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 14ea6f988cb..84f93106474 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -178,15 +178,32 @@ Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
 }
 
 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "PadToStatic") {
+    for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
+      if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
+        HloInstruction* dynamic_size =
+            hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+                ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
+        // PadToStatic converts a dynamic dimension to a static dimension. It
+        // then returns the padded data output and the dynamic sizes of the
+        // input dimensions.
+        ShapeIndex data_output = {0};
+        parent_->SetDynamicSize(hlo, data_output, i, dynamic_size,
+                                {.stride = 1, .multiple_of = 1});
+      }
+    }
+    return Status::OK();
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
                int64 operand_index, HloInstruction* dynamic_size,
                DimensionConstraint constraint) {
-        if (hlo->custom_call_target() != "Unpad" ||
+        if (hlo->custom_call_target() != "SliceToDynamic" ||
             absl::StartsWith(hlo->custom_call_target(), "Resize")) {
           return Unimplemented(
               "CustomCall is not supported to have a dynamic dimension");
         }
+        parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
         return Status::OK();
       });
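A sketch of the PadToStatic contract the new handler assumes; the f32[<=4,8] operand and the `computation`/`pad_to_static` names are invented, not taken from the patch. Tuple element 0 is the padded data, and the size of operand dimension i sits at tuple index i + 1, which is exactly where the get-tuple-element above points.

  // Operand: f32[<=4,8]  =>  PadToStatic result: (f32[4,8], s32[], s32[]).
  // Dimension 0 is the only dynamic one, so its runtime size is read at
  // tuple index 0 + 1 = 1 and recorded against tuple element {0} of the call.
  HloInstruction* dynamic_size = computation->AddInstruction(
      HloInstruction::CreateGetTupleElement(ShapeUtil::MakeScalarShape(S32),
                                            pad_to_static, /*index=*/1));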
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index c94a2594f3b..7de4c9f01a4 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -23,7 +23,9 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -169,7 +171,7 @@ HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
   return padded;
 }
 
-// In a reshape if a dynamci dimension is splitted into multiple output
+// In a reshape if a dynamic dimension is split into multiple output
 // dimensions, we need to rewrite the input of the reshape.
 //
 // The reason for this is that a continuous input may not be evenly reshaped
@@ -641,9 +643,77 @@ StatusOr<bool> RewriteDynamicReshape(
   return changed;
 }
 
-// For all dynamic outputs that live out of the computation, add unpad
-// operations.
-Status InsertUnpadsForModuleOutputs(
+// Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it.
+// Recurse into tuple instructions.
+StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
+  if (inst->shape().is_static()) {
+    return inst;
+  }
+  HloComputation* comp = inst->parent();
+  if (!inst->shape().IsTuple()) {
+    // The output shape of pad-to-static is a tuple. The 0th element is the
+    // data output, matching the input shape but with no dynamic dimensions;
+    // the i-th element is the dynamic size of the (i-1)th input dimension.
+    Shape data_output_shape = inst->shape();  // 0th element.
+    data_output_shape.clear_dynamic_dimensions();
+    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
+    for (int64 i = 0; i < inst->shape().rank(); ++i) {
+      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
+                                    &output_shape);
+    }
+    HloInstruction* pad_to_static =
+        comp->AddInstruction(HloInstruction::CreateCustomCall(
+            output_shape, {inst}, "PadToStatic", ""));
+    HloInstruction* data_output =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            data_output_shape, pad_to_static, 0));
+    return data_output;
+  }
+
+  TF_RET_CHECK(inst->shape().IsTuple());
+  std::vector<HloInstruction*> static_tuple_elements;
+  for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
+    // For each tuple element, if it is static, pass it through. If it is
+    // dynamic, recursively call this function again.
+    HloInstruction* gte =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            inst->shape().tuple_shapes(i), inst, i));
+
+    if (gte->shape().is_static()) {
+      static_tuple_elements.push_back(gte);
+    } else {
+      TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
+                          InsertPadToStaticOnInstruction(gte));
+      static_tuple_elements.push_back(static_gte);
+    }
+  }
+
+  return comp->AddInstruction(
+      HloInstruction::CreateTuple(static_tuple_elements));
+}
+
+Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
+  std::vector<HloInstruction*> params;
+  HloComputation* entry = module->entry_computation();
+  for (int64 i = 0; i < entry->num_parameters(); ++i) {
+    HloInstruction* param =
+        module->entry_computation()->parameter_instruction(i);
+    auto users = param->users();
+    TF_ASSIGN_OR_RETURN(HloInstruction * static_param,
+                        InsertPadToStaticOnInstruction(param));
+    for (auto* user : users) {
+      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, static_param));
+    }
+    if (param == entry->root_instruction()) {
+      module->entry_computation()->set_root_instruction(static_param);
+    }
+  }
+  return Status::OK();
+}
+
+// For all dynamic outputs that live out of the computation, add
+// slice-to-dynamic operations.
+Status InsertSliceToDynamicBeforeModuleOutputs(
     const DynamicDimensionInference& dynamic_dimension_inference,
     HloModule* module) {
   auto root = module->entry_computation()->root_instruction();
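To make the input rewrite concrete, here is a hand-written sketch of how the PadToStatic result shape is assembled for one non-tuple instruction, mirroring the loop above; the f32[<=4,8] parameter is invented, not taken from a test.

  Shape param_shape =
      ShapeUtil::MakeShape(F32, {4, 8}, /*dynamic_dimensions=*/{true, false});
  Shape data_shape = param_shape;          // element 0: same dimensions,
  data_shape.clear_dynamic_dimensions();   // but statically shaped: f32[4,8]
  Shape result_shape = ShapeUtil::MakeTupleShape({data_shape});
  for (int64 i = 0; i < param_shape.rank(); ++i) {
    // One s32[] per input dimension: (f32[4,8], s32[], s32[]).
    ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
                                  &result_shape);
  }

Every former user of the parameter is then rewired to get-tuple-element 0 of the custom call, so the rest of the graph only ever sees the static f32[4,8] form.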
@@ -656,7 +726,7 @@ Status InsertUnpadsForModuleOutputs(
         if (dynamic_dimension_inference.GetDynamicSize(root, index, dim) !=
             nullptr) {
           CHECK_LE(index.size(), 1) << "XLA doesn't support nested output "
-                                       "dimensions that has dynamic size";
+                                       "dimensions that have dynamic sizes";
           has_dynamic_output = true;
         }
       }
@@ -674,30 +744,36 @@ Status InsertUnpadsForModuleOutputs(
         if (!subshape.IsArray()) {
           return;
         }
+
         auto gte = module->entry_computation()->AddInstruction(
-            HloInstruction::CreateGetTupleElement(subshape, root, index[0]));
+            HloInstruction::CreateGetTupleElement(
+                ShapeUtil::MakeShapeWithStaticDimensions(subshape), root,
+                index[0]));
         if (dynamic_outputs.contains(index)) {
           CHECK_EQ(index.size(), 1)
               << "XLA only support 1 layer nested output tuple";
-          // For dynamic outputs, creates an unpad operation.
-          std::vector<HloInstruction*> unpad_operands;
+          // For dynamic outputs, creates a slice operation.
+          std::vector<HloInstruction*> slice_operands;
           // First operand is the original input. Rest are dimension values.
-          unpad_operands.push_back(gte);
+          slice_operands.push_back(gte);
+          // Keep a dynamic version of the subshape as we are removing the
+          // dynamic dimension in the original root and gte.
+          Shape dynamic_subshape = subshape;
           for (int64 dim = 0; dim < subshape.rank(); ++dim) {
             HloInstruction* dynamic_size =
                 dynamic_dimension_inference.GetDynamicSize(root, index, dim);
             if (dynamic_size != nullptr) {
-              unpad_operands.push_back(dynamic_size);
+              slice_operands.push_back(dynamic_size);
            } else {
              auto const_size = HloInstruction::CreateConstant(
                  LiteralUtil::CreateR0<int32>(subshape.dimensions(dim)));
-              unpad_operands.push_back(
+              slice_operands.push_back(
                  module->entry_computation()->AddInstruction(
                      std::move(const_size)));
            }
          }
-          // This is a dynamic output, add unpad operation.
+          // This is a dynamic output, add a slice operation.
          //
          // Write the backend config in the format of
          // 'dynamic_index'-'output_index'.
@@ -707,11 +783,11 @@ Status InsertUnpadsForModuleOutputs(
           //
           // output_index indicates the position of this output in all outputs
           // (including static inputs).
-          auto unpad = HloInstruction::CreateCustomCall(
-              subshape, unpad_operands, "Unpad",
+          auto slice = HloInstruction::CreateCustomCall(
+              dynamic_subshape, slice_operands, "SliceToDynamic",
               absl::StrFormat("%d-%d", dynamic_index++, index[0]));
           new_root_operands.push_back(
-              module->entry_computation()->AddInstruction(std::move(unpad)));
+              module->entry_computation()->AddInstruction(std::move(slice)));
         } else {
           new_root_operands.push_back(gte);
         }
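A hedged illustration of the custom call this branch emits for a single dynamic output, reusing the variable names from the hunk above but with invented indices: operand 0 is the statically shaped data, one operand per dimension carries the runtime or constant size, and the opaque config encodes 'dynamic_index'-'output_index'.

  // Suppose output 2 of the module is f32[<=16] and is the first dynamic
  // output seen (dynamic_index 0). The emitted call is roughly:
  //   f32[<=16] custom-call(f32[16] %gte, s32[] %size),
  //       custom_call_target="SliceToDynamic", backend_config="0-2"
  auto slice = HloInstruction::CreateCustomCall(
      dynamic_subshape, {gte, dynamic_size}, "SliceToDynamic",
      absl::StrFormat("%d-%d", /*dynamic_index=*/0, /*output_index=*/2));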
@@ -721,37 +797,125 @@
           HloInstruction::CreateTuple(new_root_operands));
     module->entry_computation()->set_root_instruction(new_root);
   } else {
-    std::vector<HloInstruction*> unpad_operands;
+    std::vector<HloInstruction*> slice_operands;
     // First operand is the original input. Rest are dimension values.
-    unpad_operands.push_back(root);
+    slice_operands.push_back(root);
     for (int64 dim = 0; dim < root->shape().rank(); ++dim) {
       HloInstruction* dynamic_size =
           dynamic_dimension_inference.GetDynamicSize(root, {}, dim);
       if (dynamic_size != nullptr) {
-        unpad_operands.push_back(dynamic_size);
+        slice_operands.push_back(dynamic_size);
       } else {
         auto const_size = HloInstruction::CreateConstant(
             LiteralUtil::CreateR0<int32>(root->shape().dimensions(dim)));
-        unpad_operands.push_back(module->entry_computation()->AddInstruction(
+        slice_operands.push_back(module->entry_computation()->AddInstruction(
             std::move(const_size)));
       }
-      // This is a dynamic output, add unpad operation.
-      auto unpad = module->entry_computation()->AddInstruction(
-          HloInstruction::CreateCustomCall(root->shape(), unpad_operands,
-                                           "Unpad", "0-0"));
-      module->entry_computation()->set_root_instruction(unpad);
+      // This is a dynamic output, add a slice operation.
+      auto slice = module->entry_computation()->AddInstruction(
+          HloInstruction::CreateCustomCall(root->shape(), slice_operands,
+                                           "SliceToDynamic", "0-0"));
+      module->entry_computation()->set_root_instruction(slice);
     }
   }
 }
 return Status::OK();
}
 
+// Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
+//
+// After this visitor the entry computation then looks like:
+//  Param(dynamic)
+//   |
+//  GTE (dynamic)
+//   |
+//  PadToStatic(static)
+//   |
+//   .... regular computation with static shapes.
+//   |
+//  SliceToDynamic(dynamic)
+//   |
+// ROOT tuple (dynamic)
+class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
+ public:
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  Status HandleCustomCall(HloInstruction* hlo) override;
+
+  Status HandleParameter(HloInstruction* hlo) override;
+
+  static Status Run(HloComputation* computation) {
+    DynamicShapeRemovingVisitor visitor;
+    return computation->Accept(&visitor);
+  }
+};
+
+Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
+  // Default rule: if every input to an op is static, remove dynamism in the
+  // output.
+  bool input_is_dynamic = false;
+  for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    if (!hlo->operand(i)->shape().is_static()) {
+      input_is_dynamic = true;
+    }
+  }
+
+  if (!input_is_dynamic) {
+    hlo->mutable_shape()->clear_dynamic_dimensions();
+  }
+  return Status::OK();
+}
+
+Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "SliceToDynamic") {
+    // Don't remove the slice-to-dynamic instruction.
+    return Status::OK();
+  }
+  return DefaultAction(hlo);
+}
+
+Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
+  return Status::OK();
+}
+
 }  // namespace
 
 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
   bool changed = false;
   VLOG(2) << "Pre DynamicPadder HLO:";
-  XLA_VLOG_LINES(2, module->ToString());
+
+  // Removes dynamic dimensions on parameters if there is already a binding for
+  // it. We do this because we have two different APIs to express a dynamic
+  // dimension:
+  //
+  // 1. Dynamic dimension as specified directly in the shape -- Needed for
+  //    PyTorch.
+  //
+  // 2. Dynamic dimension using dynamic parameter binding object. This
+  //    is needed for TensorFlow.
+  //
+  // For case 1, we will insert a "pad-to-static" instruction at the
+  // beginning of XLA execution, to make it into a static layout.
+  //
+  // For case 2, since it already has a static layout, we remove the
+  // dynamic dimension.
+  //
+  // TODO(b/145140571): Convert all API invocations to case 1.
+  //
+  TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().ForEachBinding(
+      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
+          const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
+          -> Status {
+        HloInstruction* parameter =
+            module->entry_computation()->parameter_instruction(
+                dynamic_dimension.parameter_num);
+        ShapeUtil::UpdateDynamicDimension(parameter->mutable_shape(),
+                                          dynamic_dimension.parameter_index,
+                                          dynamic_dimension.dimension, false);
+        return Status::OK();
+      }));
+
+  TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                       DynamicDimensionInference::Run(module));
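For the "case 2" path handled at the top of Run(), a small sketch of the TensorFlow-style setup, where dynamism lives in a DynamicParameterBinding rather than in the shape. The Bind() call and the param0 pointer are assumed here for illustration; only the UpdateDynamicDimension line mirrors the patch.

  // Parameter 0 is logically f32[<=1024]; parameter 1 is its s32[] size.
  DynamicParameterBinding binding;
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{/*parameter_num=*/1,
                                                /*parameter_index=*/{}},
      DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                /*parameter_index=*/{},
                                                /*dimension=*/0}));
  // Because the binding already carries the information, the pass clears the
  // shape-level dynamic bit on parameter 0 before inference runs:
  ShapeUtil::UpdateDynamicDimension(param0->mutable_shape(), ShapeIndex{},
                                    /*dim=*/0, /*is_dynamic=*/false);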
@@ -806,8 +970,28 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
     }
   }
 
-  TF_RETURN_IF_ERROR(
-      InsertUnpadsForModuleOutputs(dynamic_dimension_inference, module));
+  TF_RETURN_IF_ERROR(InsertSliceToDynamicBeforeModuleOutputs(
+      dynamic_dimension_inference, module));
+
+  // Remove all dynamic dimensions after entry parameters and root instruction
+  // -- the dynamic padder will produce an equivalent static shaped graph.
+  for (HloComputation* computation : module->computations()) {
+    if (computation == module->entry_computation()) {
+      TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(computation));
+    } else {
+      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
+        bool operand_is_dynamic = false;
+        for (auto* operand : inst->operands()) {
+          if (!operand->shape().is_static()) {
+            operand_is_dynamic = true;
+          }
+        }
+        if (!operand_is_dynamic) {
+          inst->mutable_shape()->clear_dynamic_dimensions();
+        }
+      }
+    }
+  }
 
   HloDCE dce;
   TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 9f415c8fbae..9cc344be06c 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -60,7 +60,6 @@ class GenericTransferManager : public TransferManager {
 
   int64 GetByteSizeRequirement(const Shape& shape) const override;
 
- protected:
   Status WriteSingleTupleIndexTable(
       se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
       const Shape& shape, se::DeviceMemoryBase* region) override;
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index defaf4cd7ab..6da22ff9393 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -942,8 +942,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
         const Shape& instruction_subshape =
             ShapeUtil::GetSubshape(instruction->shape(), index);
         for (const LogicalBuffer* buffer : buffers) {
-          if (!Shape::Equal().MinorToMajorOnlyInLayout()(
-                  instruction_subshape, buffer->shape())) {
+          if (!Shape::Equal()
+                   .IgnoreDynamicDimension()
+                   .MinorToMajorOnlyInLayout()(instruction_subshape,
+                                               buffer->shape())) {
             return InternalError(
                 "Layout of instruction %s at index {%s} does not match "
                 "source LogicalBuffer %s: %s vs %s",
@@ -1005,8 +1007,9 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
       FindOrDie(computation_layouts_, module->entry_computation())
           .result_layout();
   if (result_layout.LayoutIsSet()) {
-    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-        module->result_shape(), result_layout.shape()));
+    TF_RET_CHECK(
+        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
+            module->result_shape(), result_layout.shape()));
   }
   return Status::OK();
 }
@@ -1993,9 +1996,10 @@ Status LayoutAssignment::PropagateComputationLayouts(
             << ": " << computed_computation_layout.result_layout().ToString();
     *result_layout = computed_computation_layout.result_layout();
   } else {
-    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-        computed_computation_layout.result_layout().shape(),
-        result_layout->shape()));
+    TF_RET_CHECK(
+        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
+            computed_computation_layout.result_layout().shape(),
+            result_layout->shape()));
   }
   return Status::OK();
 }
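A rough restatement of the dynamism-clearing rule that both DynamicShapeRemovingVisitor::DefaultAction and the loop over non-entry computations apply; the `add` instruction here is hypothetical. Once every operand is statically shaped, only the dynamic bits are dropped in place; the dimension bounds themselves are untouched.

  bool operand_is_dynamic = false;
  for (const HloInstruction* operand : add->operands()) {
    operand_is_dynamic |= !operand->shape().is_static();
  }
  if (!operand_is_dynamic) {
    // f32[<=4,8] becomes f32[4,8]; the bound 4 stays as the static extent.
    add->mutable_shape()->clear_dynamic_dimensions();
  }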
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index f08862bff26..40fda188fe3 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -270,6 +270,13 @@ class TransferManager {
   static StatusOr<TransferManager*> GetForPlatform(
       const se::Platform* platform);
 
+  // Writes the given device-memory pointers in 'elements' to the given region
+  // to construct a tuple index table in the platform-specific tuple
+  // representation.
+  virtual Status WriteSingleTupleIndexTable(
+      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
+      const Shape& shape, se::DeviceMemoryBase* region) = 0;
+
  protected:
   // Transfer a memory block of the given size from the device source into the
   // 'destination' buffer.
@@ -287,13 +294,6 @@ class TransferManager {
       const void* source, se::DeviceMemoryBase* destination);
 
-  // Writes the given device-memory pointers in 'elements' to the given region
-  // to construct a tuple index table in the platform-specific tuple
-  // representation.
-  virtual Status WriteSingleTupleIndexTable(
-      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
-      const Shape& shape, se::DeviceMemoryBase* region) = 0;
-
  private:
   // The mutex that guards the platform-to-transfer manager map.
   static tensorflow::mutex platform_transfer_manager_mutex_;
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc
index 44ed3181162..05401156270 100644
--- a/tensorflow/compiler/xla/shape_layout.cc
+++ b/tensorflow/compiler/xla/shape_layout.cc
@@ -48,7 +48,7 @@ void ShapeLayout::SetToDefaultLayout() {
 
 bool ShapeLayout::MatchesLayoutInShape(const Shape& shape,
                                        bool minor_to_major_only) const {
-  auto equal = Shape::Equal();
+  auto equal = Shape::Equal().IgnoreDynamicDimension();
   if (minor_to_major_only) {
     equal.MinorToMajorOnlyInLayout();
   }
@@ -81,11 +81,11 @@ void ShapeLayout::ResetLayout(const Layout& layout,
 }
 
 bool ShapeLayout::operator==(const ShapeLayout& other) const {
-  return ShapeUtil::Equal(shape_, other.shape_);
+  return Shape::Equal().IgnoreDynamicDimension()(shape_, other.shape_);
 }
 
 bool ShapeLayout::operator!=(const ShapeLayout& other) const {
-  return !ShapeUtil::Equal(shape_, other.shape_);
+  return !Shape::Equal().IgnoreDynamicDimension()(shape_, other.shape_);
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 668274ae714..769094b1f0b 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -388,6 +388,9 @@ class ShapeUtil {
   static Shape MakeShape(PrimitiveType element_type,
                          absl::Span<const int64> dimensions);
 
+  // Make a scalar shape with given primitive type.
+  static Shape MakeScalarShape(PrimitiveType element_type);
+
   // Constructs a new shape with the given element type and sequence of
   // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
   // with a true value that the respective dimension is dynamic. If the
@@ -398,9 +401,6 @@ class ShapeUtil {
       absl::Span<const int64> dimensions,
       const std::vector<bool>& dynamic_dimensions);
 
-  // Make a scalar shape with given primitive type.
-  static Shape MakeScalarShape(PrimitiveType element_type);
-
   // Constructs a new shape with the given element type and sequence of
   // dimensions. Method checks if the element type is valid and the shape's
   // size fits in std::numeric_limits<int64>::max().
@@ -430,7 +430,6 @@ class ShapeUtil {
   static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
                                          absl::Span<const int64> dimensions,
                                          int64 max_sparse_elements);
-
   // Returns the same shape except with all dimensions set to be static.
   static Shape MakeShapeWithStaticDimensions(const Shape& shape);
 
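A small usage sketch (shapes invented) of the ShapeUtil helpers this change moves or leans on: MakeScalarShape is the declaration relocated above, and MakeShapeWithStaticDimensions is what the output rewrite uses to strip dynamism from a get-tuple-element shape.

  Shape size_shape = ShapeUtil::MakeScalarShape(S32);              // s32[]
  Shape dynamic_shape = ShapeUtil::MakeShape(F32, {16}, {true});   // f32[<=16]
  Shape padded_shape =
      ShapeUtil::MakeShapeWithStaticDimensions(dynamic_shape);     // f32[16]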
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index 4558a7d9f80..cb8d9a1d4da 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -648,7 +648,8 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
 
 xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
 XRTTupleAllocation::ToDeviceMemoryTree(
-    const std::function<bool(const xla::ShapeIndex&)>& release_checker) {
+    const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
+        release_checker) {
   xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
   for (const auto& index_buffer : buffers_) {
     if (index_buffer.second == nullptr ||
@@ -657,7 +658,9 @@ XRTTupleAllocation::ToDeviceMemoryTree(
           index_buffer.first.ToString(), " has been released");
     }
-    if (!release_checker(index_buffer.first)) {
+    TF_ASSIGN_OR_RETURN(bool should_release,
+                        release_checker(index_buffer.first));
+    if (!should_release) {
       *shaped_tree.mutable_element(index_buffer.first) =
           index_buffer.second->allocation();
     } else {
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 810c6128cad..8b87a12cfd6 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -229,7 +229,8 @@ class XRTTupleAllocation : public core::RefCounted {
   // ScopedShapedBuffer, which wants ownership and does not allow sharing.
   xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
   ToDeviceMemoryTree(
-      const std::function<bool(const xla::ShapeIndex&)>& release_checker);
+      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
+          release_checker);
 
  private:
   // Creates a new handle with (tuple) shape.
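One way a caller might adapt to release_checker now returning xla::StatusOr<bool>; the tuple_allocation pointer and the known_indices/donated_indices sets are invented stand-ins, not code from this change.

  auto device_tree = tuple_allocation->ToDeviceMemoryTree(
      [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
        // Hypothetical policy: donate buffers the client marked for aliasing,
        // and surface malformed requests as errors instead of CHECK-failing.
        if (!known_indices.contains(index)) {
          return errors::InvalidArgument("unexpected output index ",
                                         index.ToString());
        }
        return donated_indices.contains(index);
      });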