diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index b7186c186f4..6ebbf622614 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -1602,6 +1602,17 @@ Status DynamicDimensionInference::AnalyzeDynamicDimensions() { custom_call_handler_); } +void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith( + HloInstruction* replace, HloInstruction* with) { + CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32))); + CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32))); + for (auto& kv : dynamic_mapping_) { + if (kv.second == replace) { + kv.second = with; + } + } +} + Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst, const ShapeIndex& index) { diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 417f0289143..607d68bd9c3 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -68,6 +68,11 @@ class DynamicDimensionInference { SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1)); } + // For all tensors whose dynamic dimension is `replace`, replace them with + // `with`. + void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace, + HloInstruction* with); + friend class DynamicDimensionInferenceVisitor; private: diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc index fdb311adb5d..9415e20af7b 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -28,7 +28,7 @@ namespace { StatusOr ReplaceGetSize( HloInstruction* instr, - const DynamicDimensionInference* dynamic_dimension_inference) { + DynamicDimensionInference* dynamic_dimension_inference) { if (instr->opcode() != HloOpcode::kGetDimensionSize) { return false; } @@ -47,11 +47,18 @@ StatusOr ReplaceGetSize( dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); if (dynamic_size != nullptr) { TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + // The dependency between a instruction and its dynamic dimensions is not + // modeled in the IR. As instr is being replaced by dynamic_size, also tell + // dynamic dimension inference that the instruction is being replaced. + dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith( + instr, dynamic_size); } else { int32 size = instr->operand(0)->shape().dimensions(dim); HloInstruction* new_instr = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr, + new_instr); } return true; } @@ -95,14 +102,14 @@ StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { // // This will get static size of the op, which is incorrect. for (auto* computation : module->computations()) { - for (auto instruction : computation->instructions()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool replaced_get_size, ReplaceGetSize(instruction, &inference)); changed = changed || replaced_get_size; } } for (auto* computation : module->computations()) { - for (auto instruction : computation->instructions()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); changed = changed || replaced_set_size; } diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc index d96f2db3c26..b1491e96095 100644 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" @@ -55,6 +56,24 @@ ENTRY gds { op::Multiply(op::Constant(), op::Constant())); } +TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = s32[] get-dimension-size(p), dimensions={0} + p_copy = s32[3,4] copy(p) + p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0} + size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0} + ROOT mul = s32[] multiply(size0, size1) +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { auto module = ParseAndReturnUnverifiedModule(R"( HloModule _