[XLA] Update dynamic dimension inference when replacing a node.
Otherwise dynamic dimension inference won't have the latest view of the graph. PiperOrigin-RevId: 320667881 Change-Id: I75f8e993904385fc516f046c96343fe54419e27f
This commit is contained in:
parent
49750fb8ac
commit
5abbeeec7e
@ -1602,6 +1602,17 @@ Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
|
||||
custom_call_handler_);
|
||||
}
|
||||
|
||||
void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
|
||||
HloInstruction* replace, HloInstruction* with) {
|
||||
CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32)));
|
||||
CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32)));
|
||||
for (auto& kv : dynamic_mapping_) {
|
||||
if (kv.second == replace) {
|
||||
kv.second = with;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
|
||||
HloInstruction* new_inst,
|
||||
const ShapeIndex& index) {
|
||||
|
@ -68,6 +68,11 @@ class DynamicDimensionInference {
|
||||
SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1));
|
||||
}
|
||||
|
||||
// For all tensors whose dynamic dimension is `replace`, replace them with
|
||||
// `with`.
|
||||
void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
|
||||
HloInstruction* with);
|
||||
|
||||
friend class DynamicDimensionInferenceVisitor;
|
||||
|
||||
private:
|
||||
|
@ -28,7 +28,7 @@ namespace {
|
||||
|
||||
StatusOr<bool> ReplaceGetSize(
|
||||
HloInstruction* instr,
|
||||
const DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
DynamicDimensionInference* dynamic_dimension_inference) {
|
||||
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
|
||||
return false;
|
||||
}
|
||||
@ -47,11 +47,18 @@ StatusOr<bool> ReplaceGetSize(
|
||||
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
|
||||
if (dynamic_size != nullptr) {
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
|
||||
// The dependency between a instruction and its dynamic dimensions is not
|
||||
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
|
||||
// dynamic dimension inference that the instruction is being replaced.
|
||||
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
|
||||
instr, dynamic_size);
|
||||
} else {
|
||||
int32 size = instr->operand(0)->shape().dimensions(dim);
|
||||
HloInstruction* new_instr = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
|
||||
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
|
||||
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
|
||||
new_instr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -95,14 +102,14 @@ StatusOr<bool> HloGetDimensionSizeRewriter::Run(HloModule* module) {
|
||||
//
|
||||
// This will get static size of the op, which is incorrect.
|
||||
for (auto* computation : module->computations()) {
|
||||
for (auto instruction : computation->instructions()) {
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool replaced_get_size,
|
||||
ReplaceGetSize(instruction, &inference));
|
||||
changed = changed || replaced_get_size;
|
||||
}
|
||||
}
|
||||
for (auto* computation : module->computations()) {
|
||||
for (auto instruction : computation->instructions()) {
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
|
||||
changed = changed || replaced_set_size;
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -55,6 +56,24 @@ ENTRY gds {
|
||||
op::Multiply(op::Constant(), op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule _
|
||||
ENTRY gds {
|
||||
p = s32[3,4] parameter(0)
|
||||
size0 = s32[] get-dimension-size(p), dimensions={0}
|
||||
p_copy = s32[3,4] copy(p)
|
||||
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
|
||||
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
|
||||
ROOT mul = s32[] multiply(size0, size1)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
HloGetDimensionSizeRewriter pass;
|
||||
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Multiply(op::Constant(), op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) {
|
||||
auto module = ParseAndReturnUnverifiedModule(R"(
|
||||
HloModule _
|
||||
|
Loading…
x
Reference in New Issue
Block a user