[XLA] Update dynamic dimension inference when replacing a node.
Otherwise dynamic dimension inference won't have the latest view of the graph. PiperOrigin-RevId: 320667881 Change-Id: I75f8e993904385fc516f046c96343fe54419e27f
This commit is contained in:
		
							parent
							
								
									49750fb8ac
								
							
						
					
					
						commit
						5abbeeec7e
					
				| @ -1602,6 +1602,17 @@ Status DynamicDimensionInference::AnalyzeDynamicDimensions() { | |||||||
|       custom_call_handler_); |       custom_call_handler_); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith( | ||||||
|  |     HloInstruction* replace, HloInstruction* with) { | ||||||
|  |   CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32))); | ||||||
|  |   CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32))); | ||||||
|  |   for (auto& kv : dynamic_mapping_) { | ||||||
|  |     if (kv.second == replace) { | ||||||
|  |       kv.second = with; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst, | Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst, | ||||||
|                                                      HloInstruction* new_inst, |                                                      HloInstruction* new_inst, | ||||||
|                                                      const ShapeIndex& index) { |                                                      const ShapeIndex& index) { | ||||||
|  | |||||||
| @ -68,6 +68,11 @@ class DynamicDimensionInference { | |||||||
|     SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1)); |     SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   // For all tensors whose dynamic dimension is `replace`, replace them with
 | ||||||
|  |   // `with`.
 | ||||||
|  |   void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace, | ||||||
|  |                                           HloInstruction* with); | ||||||
|  | 
 | ||||||
|   friend class DynamicDimensionInferenceVisitor; |   friend class DynamicDimensionInferenceVisitor; | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|  | |||||||
| @ -28,7 +28,7 @@ namespace { | |||||||
| 
 | 
 | ||||||
| StatusOr<bool> ReplaceGetSize( | StatusOr<bool> ReplaceGetSize( | ||||||
|     HloInstruction* instr, |     HloInstruction* instr, | ||||||
|     const DynamicDimensionInference* dynamic_dimension_inference) { |     DynamicDimensionInference* dynamic_dimension_inference) { | ||||||
|   if (instr->opcode() != HloOpcode::kGetDimensionSize) { |   if (instr->opcode() != HloOpcode::kGetDimensionSize) { | ||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
| @ -47,11 +47,18 @@ StatusOr<bool> ReplaceGetSize( | |||||||
|       dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); |       dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); | ||||||
|   if (dynamic_size != nullptr) { |   if (dynamic_size != nullptr) { | ||||||
|     TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); |     TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); | ||||||
|  |     // The dependency between a instruction and its dynamic dimensions is not
 | ||||||
|  |     // modeled in the IR. As instr is being replaced by dynamic_size, also tell
 | ||||||
|  |     // dynamic dimension inference that the instruction is being replaced.
 | ||||||
|  |     dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith( | ||||||
|  |         instr, dynamic_size); | ||||||
|   } else { |   } else { | ||||||
|     int32 size = instr->operand(0)->shape().dimensions(dim); |     int32 size = instr->operand(0)->shape().dimensions(dim); | ||||||
|     HloInstruction* new_instr = computation->AddInstruction( |     HloInstruction* new_instr = computation->AddInstruction( | ||||||
|         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size))); |         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size))); | ||||||
|     TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); |     TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); | ||||||
|  |     dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr, | ||||||
|  |                                                                     new_instr); | ||||||
|   } |   } | ||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
| @ -95,14 +102,14 @@ StatusOr<bool> HloGetDimensionSizeRewriter::Run(HloModule* module) { | |||||||
|   //
 |   //
 | ||||||
|   // This will get static size of the op, which is incorrect.
 |   // This will get static size of the op, which is incorrect.
 | ||||||
|   for (auto* computation : module->computations()) { |   for (auto* computation : module->computations()) { | ||||||
|     for (auto instruction : computation->instructions()) { |     for (auto instruction : computation->MakeInstructionPostOrder()) { | ||||||
|       TF_ASSIGN_OR_RETURN(bool replaced_get_size, |       TF_ASSIGN_OR_RETURN(bool replaced_get_size, | ||||||
|                           ReplaceGetSize(instruction, &inference)); |                           ReplaceGetSize(instruction, &inference)); | ||||||
|       changed = changed || replaced_get_size; |       changed = changed || replaced_get_size; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   for (auto* computation : module->computations()) { |   for (auto* computation : module->computations()) { | ||||||
|     for (auto instruction : computation->instructions()) { |     for (auto instruction : computation->MakeInstructionPostOrder()) { | ||||||
|       TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); |       TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); | ||||||
|       changed = changed || replaced_set_size; |       changed = changed || replaced_set_size; | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -26,6 +26,7 @@ limitations under the License. | |||||||
| #include "tensorflow/compiler/xla/tests/literal_test_util.h" | #include "tensorflow/compiler/xla/tests/literal_test_util.h" | ||||||
| #include "tensorflow/compiler/xla/tests/test_utils.h" | #include "tensorflow/compiler/xla/tests/test_utils.h" | ||||||
| #include "tensorflow/compiler/xla/types.h" | #include "tensorflow/compiler/xla/types.h" | ||||||
|  | #include "tensorflow/compiler/xla/util.h" | ||||||
| #include "tensorflow/core/lib/core/status_test_util.h" | #include "tensorflow/core/lib/core/status_test_util.h" | ||||||
| #include "tensorflow/core/platform/types.h" | #include "tensorflow/core/platform/types.h" | ||||||
| 
 | 
 | ||||||
| @ -55,6 +56,24 @@ ENTRY gds { | |||||||
|               op::Multiply(op::Constant(), op::Constant())); |               op::Multiply(op::Constant(), op::Constant())); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) { | ||||||
|  |   auto module = ParseAndReturnVerifiedModule(R"( | ||||||
|  | HloModule _ | ||||||
|  | ENTRY gds { | ||||||
|  |   p = s32[3,4] parameter(0) | ||||||
|  |   size0 = s32[] get-dimension-size(p), dimensions={0} | ||||||
|  |   p_copy = s32[3,4] copy(p) | ||||||
|  |   p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0} | ||||||
|  |   size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0} | ||||||
|  |   ROOT mul = s32[] multiply(size0, size1) | ||||||
|  | })") | ||||||
|  |                     .ValueOrDie(); | ||||||
|  |   HloGetDimensionSizeRewriter pass; | ||||||
|  |   EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); | ||||||
|  |   EXPECT_THAT(module->entry_computation()->root_instruction(), | ||||||
|  |               op::Multiply(op::Constant(), op::Constant())); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { | TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { | ||||||
|   auto module = ParseAndReturnUnverifiedModule(R"( |   auto module = ParseAndReturnUnverifiedModule(R"( | ||||||
| HloModule _ | HloModule _ | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user