[XLA] Update dynamic dimension inference when replacing a node.
Otherwise dynamic dimension inference won't have the latest view of the graph. PiperOrigin-RevId: 320667881 Change-Id: I75f8e993904385fc516f046c96343fe54419e27f
This commit is contained in:
parent
49750fb8ac
commit
5abbeeec7e
@ -1602,6 +1602,17 @@ Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
|
|||||||
custom_call_handler_);
|
custom_call_handler_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
|
||||||
|
HloInstruction* replace, HloInstruction* with) {
|
||||||
|
CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32)));
|
||||||
|
CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32)));
|
||||||
|
for (auto& kv : dynamic_mapping_) {
|
||||||
|
if (kv.second == replace) {
|
||||||
|
kv.second = with;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
|
Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
|
||||||
HloInstruction* new_inst,
|
HloInstruction* new_inst,
|
||||||
const ShapeIndex& index) {
|
const ShapeIndex& index) {
|
||||||
|
|||||||
@ -68,6 +68,11 @@ class DynamicDimensionInference {
|
|||||||
SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1));
|
SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For all tensors whose dynamic dimension is `replace`, replace them with
|
||||||
|
// `with`.
|
||||||
|
void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
|
||||||
|
HloInstruction* with);
|
||||||
|
|
||||||
friend class DynamicDimensionInferenceVisitor;
|
friend class DynamicDimensionInferenceVisitor;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -28,7 +28,7 @@ namespace {
|
|||||||
|
|
||||||
StatusOr<bool> ReplaceGetSize(
|
StatusOr<bool> ReplaceGetSize(
|
||||||
HloInstruction* instr,
|
HloInstruction* instr,
|
||||||
const DynamicDimensionInference* dynamic_dimension_inference) {
|
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||||
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
|
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -47,11 +47,18 @@ StatusOr<bool> ReplaceGetSize(
|
|||||||
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
|
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
|
||||||
if (dynamic_size != nullptr) {
|
if (dynamic_size != nullptr) {
|
||||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
|
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
|
||||||
|
// The dependency between a instruction and its dynamic dimensions is not
|
||||||
|
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
|
||||||
|
// dynamic dimension inference that the instruction is being replaced.
|
||||||
|
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
|
||||||
|
instr, dynamic_size);
|
||||||
} else {
|
} else {
|
||||||
int32 size = instr->operand(0)->shape().dimensions(dim);
|
int32 size = instr->operand(0)->shape().dimensions(dim);
|
||||||
HloInstruction* new_instr = computation->AddInstruction(
|
HloInstruction* new_instr = computation->AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
|
||||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
|
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
|
||||||
|
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
|
||||||
|
new_instr);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -95,14 +102,14 @@ StatusOr<bool> HloGetDimensionSizeRewriter::Run(HloModule* module) {
|
|||||||
//
|
//
|
||||||
// This will get static size of the op, which is incorrect.
|
// This will get static size of the op, which is incorrect.
|
||||||
for (auto* computation : module->computations()) {
|
for (auto* computation : module->computations()) {
|
||||||
for (auto instruction : computation->instructions()) {
|
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||||
TF_ASSIGN_OR_RETURN(bool replaced_get_size,
|
TF_ASSIGN_OR_RETURN(bool replaced_get_size,
|
||||||
ReplaceGetSize(instruction, &inference));
|
ReplaceGetSize(instruction, &inference));
|
||||||
changed = changed || replaced_get_size;
|
changed = changed || replaced_get_size;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto* computation : module->computations()) {
|
for (auto* computation : module->computations()) {
|
||||||
for (auto instruction : computation->instructions()) {
|
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||||
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
||||||
changed = changed || replaced_set_size;
|
changed = changed || replaced_set_size;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -55,6 +56,24 @@ ENTRY gds {
|
|||||||
op::Multiply(op::Constant(), op::Constant()));
|
op::Multiply(op::Constant(), op::Constant()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) {
|
||||||
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
|
HloModule _
|
||||||
|
ENTRY gds {
|
||||||
|
p = s32[3,4] parameter(0)
|
||||||
|
size0 = s32[] get-dimension-size(p), dimensions={0}
|
||||||
|
p_copy = s32[3,4] copy(p)
|
||||||
|
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
|
||||||
|
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
|
||||||
|
ROOT mul = s32[] multiply(size0, size1)
|
||||||
|
})")
|
||||||
|
.ValueOrDie();
|
||||||
|
HloGetDimensionSizeRewriter pass;
|
||||||
|
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Multiply(op::Constant(), op::Constant()));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) {
|
TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) {
|
||||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||||
HloModule _
|
HloModule _
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user