From 2f28e151b3af2cebebedd199e774c16ef693a5e1 Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Thu, 12 Dec 2019 10:27:36 -0800
Subject: [PATCH] [XLA] Implement dynamic input and output in DynamicPadder.
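
Parameters that carry dynamic dimensions now get a "PadToStatic" custom
call inserted right after them: it returns a tuple of the operand padded
to its static bounds plus one s32[] scalar per dimension holding the
runtime size. Dynamic outputs get a matching "SliceToDynamic" custom
call in front of the root, which re-attaches the dynamic sizes at the
boundary. Between the two, the body of the entry computation is
rewritten to purely static shapes:

  Param (dynamic)
    |
  PadToStatic (static data + s32[] sizes)
    |
  ... regular computation with static shapes ...
    |
  SliceToDynamic (dynamic)
    |
  ROOT tuple (dynamic)

Shape equality checks in the CPU executable, layout assignment, and
ShapeLayout now use Shape::Equal().IgnoreDynamicDimension() so that
static and dynamic variants of the same bounded shape compare equal
along the pipeline.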

PiperOrigin-RevId: 285218938
Change-Id: Ia5f5e5ad62154b1427d2b56640eed3a443a50c1d
---
 tensorflow/compiler/xla/client/xla_builder.h  |   4 +-
 .../xla/service/cpu/cpu_executable.cc         |  10 +-
 .../service/dynamic_dimension_inference.cc    |  19 +-
 .../compiler/xla/service/dynamic_padder.cc    | 238 ++++++++++++++++--
 .../xla/service/generic_transfer_manager.h    |   1 -
 .../compiler/xla/service/layout_assignment.cc |  18 +-
 .../compiler/xla/service/transfer_manager.h   |  14 +-
 tensorflow/compiler/xla/shape_layout.cc       |   6 +-
 tensorflow/compiler/xla/shape_util.h          |   7 +-
 tensorflow/compiler/xrt/xrt_state.cc          |   7 +-
 tensorflow/compiler/xrt/xrt_state.h           |   3 +-
 11 files changed, 268 insertions(+), 59 deletions(-)

diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 5e93bb2b3ba..42a8ae5b996 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -237,12 +237,12 @@ class XlaBuilder {
   // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the
   // dynamic dimensions information when XLA backend can handle dynamic
   // dimensions.
-  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = true);
+  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);
 
   // Overload of Build which specifies a particular root instruction for the
   // computation.
   StatusOr<XlaComputation> Build(XlaOp root,
-                                 bool remove_dynamic_dimensions = true);
+                                 bool remove_dynamic_dimensions = false);
 
   // Builds the computation with the requested operations, or notes an error in
   // the parent XlaBuilder and returns an empty computation if building failed.
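
With the default flipped, dynamic dimension markers now survive
XlaBuilder::Build() instead of being stripped. A minimal client-side
sketch of the effect (hypothetical example; SetDimensionSize is a
pre-existing builder op, used here only to produce a dynamic dimension):

  XlaBuilder b("dynamic_example");
  XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p");
  XlaOp n = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "n");
  XlaOp d = SetDimensionSize(p, n, 0);  // d has shape f32[<=10]
  // Previously Build() removed the <=10 marker by default; now the built
  // XlaComputation keeps it for the dynamic padder to consume.
  StatusOr<XlaComputation> computation = b.Build(d);
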
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index d19cf4fb015..366fdca442f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -298,10 +298,12 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
       const Shape& expected_shape =
           entry_comp->parameter_instruction(i)->shape();
       const Shape& actual_shape = arguments[i].shape();
-      CHECK(expected_shape == actual_shape) << absl::StreamFormat(
-          "Shape mismatch on argument %d.  Expected %s, but was %s.", i,
-          expected_shape.ToString(/*print_layout=*/true),
-          actual_shape.ToString(/*print_layout=*/true));
+      CHECK(
+          Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
+          << absl::StreamFormat(
+                 "Shape mismatch on argument %d.  Expected %s, but was %s.", i,
+                 expected_shape.ToString(/*print_layout=*/true),
+                 actual_shape.ToString(/*print_layout=*/true));
     }
   }
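
The relaxed check accepts arguments whose dynamic-dimension markers
differ from the entry parameter shapes, which is exactly what happens
once the dynamic padder rewrites the module. A small sketch of the
comparator semantics (using the MakeShape overload that takes a
dynamic_dimensions vector, declared in shape_util.h below):

  Shape dynamic_shape = ShapeUtil::MakeShape(F32, {10}, {true});  // f32[<=10]
  Shape static_shape = ShapeUtil::MakeShape(F32, {10});           // f32[10]
  CHECK(Shape::Equal().IgnoreDynamicDimension()(dynamic_shape, static_shape));
  CHECK(!Shape::Equal()(dynamic_shape, static_shape));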
 
diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
index 14ea6f988cb..84f93106474 100644
--- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
+++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
@@ -178,15 +178,32 @@ Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
 }
 
 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "PadToStatic") {
+    for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
+      if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
+        HloInstruction* dynamic_size =
+            hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+                ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
+        // PadToStatic converts a dynamic dimension into a static dimension.
+        // It returns a tuple of the padded data output and the dynamic sizes
+        // of the input dimensions.
+        ShapeIndex data_output = {0};
+        parent_->SetDynamicSize(hlo, data_output, i, dynamic_size,
+                                {.stride = 1, .multiple_of = 1});
+      }
+    }
+    return Status::OK();
+  }
   return ForEachOperandDynamicDimension(
       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
                int64 operand_index, HloInstruction* dynamic_size,
                DimensionConstraint constraint) {
-        if (hlo->custom_call_target() != "Unpad" ||
-            absl::StartsWith(hlo->custom_call_target(), "Resize")) {
+        if (hlo->custom_call_target() != "SliceToDynamic" &&
+            !absl::StartsWith(hlo->custom_call_target(), "Resize")) {
           return Unimplemented(
               "CustomCall is not supported to have a dynamic dimension");
         }
+
         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
         return Status::OK();
       });
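
For reference, PadToStatic applied to an operand of shape f32[<=10, 5]
yields the tuple (f32[10, 5], s32[], s32[]): element 0 is the padded
data and element i + 1 is the runtime size of input dimension i, which
is why the handler above reads dimension i from tuple index i + 1. A
minimal sketch of that read (comp and pad_to_static assumed in scope):

  HloInstruction* size_of_dim0 =
      comp->AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::MakeScalarShape(S32), pad_to_static, /*index=*/1));
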
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc
index c94a2594f3b..7de4c9f01a4 100644
--- a/tensorflow/compiler/xla/service/dynamic_padder.cc
+++ b/tensorflow/compiler/xla/service/dynamic_padder.cc
@@ -23,7 +23,9 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -169,7 +171,7 @@ HloInstruction* PadWithScalar(HloInstruction* inst, int64 dim,
   return padded;
 }
 
-// In a reshape if a dynamci dimension is splitted into multiple output
+// In a reshape, if a dynamic dimension is split into multiple output
 // dimensions, we need to rewrite the input of the reshape.
 //
 // The reason for this is that a continuous input may not be evenly reshaped
@@ -641,9 +643,77 @@ StatusOr<bool> RewriteDynamicReshape(
   return changed;
 }
 
-// For all dynamic outputs that live out of the computation, add unpad
-// operations.
-Status InsertUnpadsForModuleOutputs(
+// Insert pad-to-static after `inst` if `inst` has dynamic dimensions.
+// Recurses into tuple instructions.
+StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
+  if (inst->shape().is_static()) {
+    return inst;
+  }
+  HloComputation* comp = inst->parent();
+  if (!inst->shape().IsTuple()) {
+    // The output shape of PadToStatic is a tuple. The 0th element is the data
+    // output, which has the same shape as the input but with all dimensions
+    // static; the i-th element (i >= 1) is the dynamic size of the (i-1)-th
+    // input dimension.
+    Shape data_output_shape = inst->shape();  // 0th element.
+    data_output_shape.clear_dynamic_dimensions();
+    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
+    for (int64 i = 0; i < inst->shape().rank(); ++i) {
+      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
+                                    &output_shape);
+    }
+    HloInstruction* pad_to_static =
+        comp->AddInstruction(HloInstruction::CreateCustomCall(
+            output_shape, {inst}, "PadToStatic", ""));
+    HloInstruction* data_output =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            data_output_shape, pad_to_static, 0));
+    return data_output;
+  }
+
+  TF_RET_CHECK(inst->shape().IsTuple());
+  std::vector<HloInstruction*> static_tuple_elements;
+  for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
+    // For each tuple element, if it is static, pass it through. If it is
+    // dynamic, recursively call this function again.
+    HloInstruction* gte =
+        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
+            inst->shape().tuple_shapes(i), inst, i));
+
+    if (gte->shape().is_static()) {
+      static_tuple_elements.push_back(gte);
+    } else {
+      TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
+                          InsertPadToStaticOnInstruction(gte));
+      static_tuple_elements.push_back(static_gte);
+    }
+  }
+
+  return comp->AddInstruction(
+      HloInstruction::CreateTuple(static_tuple_elements));
+}
+
+Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
+  HloComputation* entry = module->entry_computation();
+  for (int64 i = 0; i < entry->num_parameters(); ++i) {
+    HloInstruction* param =
+        module->entry_computation()->parameter_instruction(i);
+    auto users = param->users();
+    TF_ASSIGN_OR_RETURN(HloInstruction * static_param,
+                        InsertPadToStaticOnInstruction(param));
+    for (auto* user : users) {
+      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, static_param));
+    }
+    if (param == entry->root_instruction()) {
+      module->entry_computation()->set_root_instruction(static_param);
+    }
+  }
+  return Status::OK();
+}
+
+// For all dynamic outputs that live out of the computation, add
+// slice-to-dynamic operations.
+Status InsertSliceToDynamicBeforeModuleOutputs(
     const DynamicDimensionInference& dynamic_dimension_inference,
     HloModule* module) {
   auto root = module->entry_computation()->root_instruction();
@@ -656,7 +726,7 @@ Status InsertUnpadsForModuleOutputs(
             if (dynamic_dimension_inference.GetDynamicSize(root, index, dim) !=
                 nullptr) {
               CHECK_LE(index.size(), 1) << "XLA doesn't support nested output "
-                                           "dimensions that has dynamic size";
+                                           "dimensions that have dynamic size";
               has_dynamic_output = true;
             }
           }
@@ -674,30 +744,36 @@ Status InsertUnpadsForModuleOutputs(
         if (!subshape.IsArray()) {
           return;
         }
+
         auto gte = module->entry_computation()->AddInstruction(
-            HloInstruction::CreateGetTupleElement(subshape, root, index[0]));
+            HloInstruction::CreateGetTupleElement(
+                ShapeUtil::MakeShapeWithStaticDimensions(subshape), root,
+                index[0]));
 
         if (dynamic_outputs.contains(index)) {
           CHECK_EQ(index.size(), 1)
               << "XLA only support 1 layer nested output tuple";
-          // For dynamic outputs, creates an unpad operation.
-          std::vector<HloInstruction*> unpad_operands;
+          // For dynamic outputs, create a slice-to-dynamic operation.
+          std::vector<HloInstruction*> slice_operands;
           // First operand is the original input. Rest are dimension values.
-          unpad_operands.push_back(gte);
+          slice_operands.push_back(gte);
+          // Keep a dynamic version of the subshape as we are removing the
+          // dynamic dimension in the original root and gte.
+          Shape dynamic_subshape = subshape;
           for (int64 dim = 0; dim < subshape.rank(); ++dim) {
             HloInstruction* dynamic_size =
                 dynamic_dimension_inference.GetDynamicSize(root, index, dim);
             if (dynamic_size != nullptr) {
-              unpad_operands.push_back(dynamic_size);
+              slice_operands.push_back(dynamic_size);
             } else {
               auto const_size = HloInstruction::CreateConstant(
                   LiteralUtil::CreateR0<int32>(subshape.dimensions(dim)));
-              unpad_operands.push_back(
+              slice_operands.push_back(
                   module->entry_computation()->AddInstruction(
                       std::move(const_size)));
             }
           }
-          // This is a dynamic output, add unpad operation.
+          // This is a dynamic output, add a slice-to-dynamic operation.
           //
           // Write the backend config in the format of
           // 'dynamic_index'-'output_index'.
@@ -707,11 +783,11 @@ Status InsertUnpadsForModuleOutputs(
           //
           // output_index indicates the position of this output in all outputs
           // (including static inputs).
-          auto unpad = HloInstruction::CreateCustomCall(
-              subshape, unpad_operands, "Unpad",
+          auto slice = HloInstruction::CreateCustomCall(
+              dynamic_subshape, slice_operands, "SliceToDynamic",
               absl::StrFormat("%d-%d", dynamic_index++, index[0]));
           new_root_operands.push_back(
-              module->entry_computation()->AddInstruction(std::move(unpad)));
+              module->entry_computation()->AddInstruction(std::move(slice)));
         } else {
           new_root_operands.push_back(gte);
         }
@@ -721,37 +797,125 @@ Status InsertUnpadsForModuleOutputs(
           HloInstruction::CreateTuple(new_root_operands));
       module->entry_computation()->set_root_instruction(new_root);
     } else {
-      std::vector<HloInstruction*> unpad_operands;
+      std::vector<HloInstruction*> slice_operands;
       // First operand is the original input. Rest are dimension values.
-      unpad_operands.push_back(root);
+      slice_operands.push_back(root);
       for (int64 dim = 0; dim < root->shape().rank(); ++dim) {
         HloInstruction* dynamic_size =
             dynamic_dimension_inference.GetDynamicSize(root, {}, dim);
         if (dynamic_size != nullptr) {
-          unpad_operands.push_back(dynamic_size);
+          slice_operands.push_back(dynamic_size);
         } else {
           auto const_size = HloInstruction::CreateConstant(
               LiteralUtil::CreateR0<int32>(root->shape().dimensions(dim)));
-          unpad_operands.push_back(module->entry_computation()->AddInstruction(
+          slice_operands.push_back(module->entry_computation()->AddInstruction(
               std::move(const_size)));
         }
-        // This is a dynamic output, add unpad operation.
-        auto unpad = module->entry_computation()->AddInstruction(
-            HloInstruction::CreateCustomCall(root->shape(), unpad_operands,
-                                             "Unpad", "0-0"));
-        module->entry_computation()->set_root_instruction(unpad);
       }
+      // This is a dynamic output: create the slice-to-dynamic operation once
+      // all dimension sizes have been collected.
+      auto slice = module->entry_computation()->AddInstruction(
+          HloInstruction::CreateCustomCall(root->shape(), slice_operands,
+                                           "SliceToDynamic", "0-0"));
+      module->entry_computation()->set_root_instruction(slice);
     }
   }
   return Status::OK();
 }
 
+// Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
+//
+// After this visitor, the entry computation looks like:
+//  Param(dynamic)
+//    |
+//   GTE (dynamic)
+//    |
+//  PadToStatic(static)
+//    |
+//   .... regular computation with static shapes.
+//    |
+//  SliceToDynamic(dynamic)
+//    |
+// ROOT tuple (dynamic)
+class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
+ public:
+  Status DefaultAction(HloInstruction* hlo) override;
+
+  Status HandleCustomCall(HloInstruction* hlo) override;
+
+  Status HandleParameter(HloInstruction* hlo) override;
+
+  static Status Run(HloComputation* computation) {
+    DynamicShapeRemovingVisitor visitor;
+    return computation->Accept(&visitor);
+  }
+};
+
+Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
+  // Default rule: if all inputs to an op are static, remove dynamism in the
+  // output.
+  bool input_is_dynamic = false;
+  for (int64 i = 0; i < hlo->operand_count(); ++i) {
+    if (!hlo->operand(i)->shape().is_static()) {
+      input_is_dynamic = true;
+    }
+  }
+
+  if (!input_is_dynamic) {
+    hlo->mutable_shape()->clear_dynamic_dimensions();
+  }
+  return Status::OK();
+}
+
+Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
+  if (hlo->custom_call_target() == "SliceToDynamic") {
+    // Don't remove dynamism from a slice-to-dynamic instruction -- its output
+    // is intentionally dynamic.
+    return Status::OK();
+  }
+  return DefaultAction(hlo);
+}
+
+Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
+  // Keep parameters dynamic; a pad-to-static custom call is inserted right
+  // after each dynamic parameter.
+  return Status::OK();
+}
+
 }  // namespace
 
 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
   bool changed = false;
   VLOG(2) << "Pre DynamicPadder HLO:";
   XLA_VLOG_LINES(2, module->ToString());
+
+  // Removes dynamic dimensions on parameters if there is already a binding
+  // for them. We do this because we have two different APIs to express a
+  // dynamic dimension:
+  //
+  // 1. Dynamic dimension specified directly in the shape -- needed for
+  // PyTorch.
+  //
+  // 2. Dynamic dimension using a dynamic parameter binding object. This
+  // is needed for TensorFlow.
+  //
+  // For case 1, we will insert a "pad-to-static" instruction at the
+  // beginning of XLA execution to convert the input into a static shape.
+  //
+  // For case 2, since the parameter already has a static shape, we remove
+  // the dynamic dimension.
+  //
+  // TODO(b/145140571): Convert all API invocations to case 1.
+  //
+  TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().ForEachBinding(
+      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
+          const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
+          -> Status {
+        HloInstruction* parameter =
+            module->entry_computation()->parameter_instruction(
+                dynamic_dimension.parameter_num);
+        ShapeUtil::UpdateDynamicDimension(parameter->mutable_shape(),
+                                          dynamic_dimension.parameter_index,
+                                          dynamic_dimension.dimension, false);
+        return Status::OK();
+      }));
+
+  TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                       DynamicDimensionInference::Run(module));
 
@@ -806,8 +970,28 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
     }
   }
 
-  TF_RETURN_IF_ERROR(
-      InsertUnpadsForModuleOutputs(dynamic_dimension_inference, module));
+  TF_RETURN_IF_ERROR(InsertSliceToDynamicBeforeModuleOutputs(
+      dynamic_dimension_inference, module));
+
+  // Remove all dynamic dimensions between the entry parameters and the root
+  // instruction -- the dynamic padder produces an equivalent static-shaped
+  // graph.
+  for (HloComputation* computation : module->computations()) {
+    if (computation == module->entry_computation()) {
+      TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(computation));
+    } else {
+      for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
+        bool operand_is_dynamic = false;
+        for (auto* operand : inst->operands()) {
+          if (!operand->shape().is_static()) {
+            operand_is_dynamic = true;
+          }
+        }
+        if (!operand_is_dynamic) {
+          inst->mutable_shape()->clear_dynamic_dimensions();
+        }
+      }
+    }
+  }
 
   HloDCE dce;
   TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
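
As a worked example of the backend config format documented in the pass:
a root tuple (s32[10], f32[<=4], f32[<=8]) receives two SliceToDynamic
custom calls with opaque strings "0-1" and "1-2", since dynamic_index
counts only dynamic outputs while output_index counts all outputs. The
final "remove dynamism" sweep amounts to the following per-shape
operation (sketch):

  Shape s = ShapeUtil::MakeShape(F32, {10, 5}, {true, false});  // f32[<=10, 5]
  s.clear_dynamic_dimensions();                                 // f32[10, 5]
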
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 9f415c8fbae..9cc344be06c 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -60,7 +60,6 @@ class GenericTransferManager : public TransferManager {
 
   int64 GetByteSizeRequirement(const Shape& shape) const override;
 
- protected:
   Status WriteSingleTupleIndexTable(
       se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
       const Shape& shape, se::DeviceMemoryBase* region) override;
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index defaf4cd7ab..6da22ff9393 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -942,8 +942,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
               const Shape& instruction_subshape =
                   ShapeUtil::GetSubshape(instruction->shape(), index);
               for (const LogicalBuffer* buffer : buffers) {
-                if (!Shape::Equal().MinorToMajorOnlyInLayout()(
-                        instruction_subshape, buffer->shape())) {
+                if (!Shape::Equal()
+                         .IgnoreDynamicDimension()
+                         .MinorToMajorOnlyInLayout()(instruction_subshape,
+                                                     buffer->shape())) {
                   return InternalError(
                       "Layout of instruction %s at index {%s} does not match "
                       "source LogicalBuffer %s: %s vs %s",
@@ -1005,8 +1007,9 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
       FindOrDie(computation_layouts_, module->entry_computation())
           .result_layout();
   if (result_layout.LayoutIsSet()) {
-    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-        module->result_shape(), result_layout.shape()));
+    TF_RET_CHECK(
+        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
+            module->result_shape(), result_layout.shape()));
   }
   return Status::OK();
 }
@@ -1993,9 +1996,10 @@ Status LayoutAssignment::PropagateComputationLayouts(
             << ": " << computed_computation_layout.result_layout().ToString();
     *result_layout = computed_computation_layout.result_layout();
   } else {
-    TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
-        computed_computation_layout.result_layout().shape(),
-        result_layout->shape()));
+    TF_RET_CHECK(
+        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
+            computed_computation_layout.result_layout().shape(),
+            result_layout->shape()));
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index f08862bff26..40fda188fe3 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -270,6 +270,13 @@ class TransferManager {
   static StatusOr<TransferManager*> GetForPlatform(
       const se::Platform* platform);
 
+  // Writes the given device-memory pointers in 'elements' to the given region
+  // to construct a tuple index table in the platform-specific tuple
+  // representation.
+  virtual Status WriteSingleTupleIndexTable(
+      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
+      const Shape& shape, se::DeviceMemoryBase* region) = 0;
+
  protected:
   // Transfer a memory block of the given size from the device source into the
   // 'destination' buffer.
@@ -287,13 +294,6 @@ class TransferManager {
                                         const void* source,
                                         se::DeviceMemoryBase* destination);
 
-  // Writes the given device-memory pointers in 'elements' to the given region
-  // to construct a tuple index table in the platform-specific tuple
-  // representation.
-  virtual Status WriteSingleTupleIndexTable(
-      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
-      const Shape& shape, se::DeviceMemoryBase* region) = 0;
-
  private:
   // The mutex that guards the platform-to-transfer manager map.
   static tensorflow::mutex platform_transfer_manager_mutex_;
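
Hoisting WriteSingleTupleIndexTable out of the protected section (here
and in GenericTransferManager above) makes it callable through a plain
TransferManager*, presumably so runtime code outside the class hierarchy
can rebuild a tuple index table for dynamically shaped outputs. A hedged
caller sketch (stream, elements, tuple_shape, and tuple_region are
assumed to be set up by the caller):

  TF_RETURN_IF_ERROR(transfer_manager->WriteSingleTupleIndexTable(
      stream, elements, tuple_shape, &tuple_region));
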
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc
index 44ed3181162..05401156270 100644
--- a/tensorflow/compiler/xla/shape_layout.cc
+++ b/tensorflow/compiler/xla/shape_layout.cc
@@ -48,7 +48,7 @@ void ShapeLayout::SetToDefaultLayout() {
 
 bool ShapeLayout::MatchesLayoutInShape(const Shape& shape,
                                        bool minor_to_major_only) const {
-  auto equal = Shape::Equal();
+  auto equal = Shape::Equal().IgnoreDynamicDimension();
   if (minor_to_major_only) {
     equal.MinorToMajorOnlyInLayout();
   }
@@ -81,11 +81,11 @@ void ShapeLayout::ResetLayout(const Layout& layout,
 }
 
 bool ShapeLayout::operator==(const ShapeLayout& other) const {
-  return ShapeUtil::Equal(shape_, other.shape_);
+  return Shape::Equal().IgnoreDynamicDimension()(shape_, other.shape_);
 }
 
 bool ShapeLayout::operator!=(const ShapeLayout& other) const {
-  return !ShapeUtil::Equal(shape_, other.shape_);
+  return !Shape::Equal().IgnoreDynamicDimension()(shape_, other.shape_);
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 668274ae714..769094b1f0b 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -388,6 +388,9 @@ class ShapeUtil {
   static Shape MakeShape(PrimitiveType element_type,
                          absl::Span<const int64> dimensions);
 
+  // Make a scalar shape with given primitive type.
+  static Shape MakeScalarShape(PrimitiveType element_type);
+
   // Constructs a new shape with the given element type and sequence of
   // potentially dynamic dimensions. The argument 'dynamic_dimensions' indicates
   // with a true value that the respective dimension is dynamic. If the
@@ -398,9 +401,6 @@ class ShapeUtil {
                          absl::Span<const int64> dimensions,
                          const std::vector<bool>& dynamic_dimensions);
 
-  // Make a scalar shape with given primitive type.
-  static Shape MakeScalarShape(PrimitiveType element_type);
-
   // Constructs a new shape with the given element type and sequence of
   // dimensions. Method checks if the element type is valid and the shape's
   // size fits in std::numeric_limits<int64>::max().
@@ -430,7 +430,6 @@ class ShapeUtil {
   static Shape MakeShapeWithSparseLayout(PrimitiveType element_type,
                                          absl::Span<const int64> dimensions,
                                          int64 max_sparse_elements);
-
   // Returns the same shape except with all dimensions set to be static.
   static Shape MakeShapeWithStaticDimensions(const Shape& shape);
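
Both helpers touched by this hunk are used by the dynamic padder above.
A quick sketch of what they produce:

  Shape size_shape = ShapeUtil::MakeScalarShape(S32);    // s32[]
  Shape padded_shape = ShapeUtil::MakeShapeWithStaticDimensions(
      ShapeUtil::MakeShape(F32, {10}, {true}));          // f32[<=10] -> f32[10]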
 
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index 4558a7d9f80..cb8d9a1d4da 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -648,7 +648,8 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
 
 xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
 XRTTupleAllocation::ToDeviceMemoryTree(
-    const std::function<bool(const xla::ShapeIndex&)>& release_checker) {
+    const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
+        release_checker) {
   xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
   for (const auto& index_buffer : buffers_) {
     if (index_buffer.second == nullptr ||
@@ -657,7 +658,9 @@ XRTTupleAllocation::ToDeviceMemoryTree(
                                      index_buffer.first.ToString(),
                                      " has been released");
     }
-    if (!release_checker(index_buffer.first)) {
+    TF_ASSIGN_OR_RETURN(bool should_release,
+                        release_checker(index_buffer.first));
+    if (!should_release) {
       *shaped_tree.mutable_element(index_buffer.first) =
           index_buffer.second->allocation();
     } else {
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 810c6128cad..8b87a12cfd6 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -229,7 +229,8 @@ class XRTTupleAllocation : public core::RefCounted {
   // ScopedShapedBuffer, which wants ownership and does not allow sharing.
   xla::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>>
   ToDeviceMemoryTree(
-      const std::function<bool(const xla::ShapeIndex&)>& release_checker);
+      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
+          release_checker);
 
  private:
   // Creates a new handle with (tuple) shape.
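
With release_checker returning xla::StatusOr<bool>, the callback can now
propagate an error instead of being forced into a boolean answer. A
hedged caller sketch (tuple is an XRTTupleAllocation*; the release
policy is a made-up example, and TF_ASSIGN_OR_RETURN requires the
enclosing function to return a Status or StatusOr):

  auto release_checker =
      [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
    // Hypothetical policy: release every leaf buffer, keep the root table.
    return !index.empty();
  };
  TF_ASSIGN_OR_RETURN(auto device_memory_tree,
                      tuple->ToDeviceMemoryTree(release_checker));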