[XLA] Update dynamic dimension inference when replacing a node.

Otherwise dynamic dimension inference won't have the latest view of the graph.

PiperOrigin-RevId: 320667881
Change-Id: I75f8e993904385fc516f046c96343fe54419e27f
This commit is contained in:
Yunxing Dai 2020-07-10 13:44:14 -07:00 committed by TensorFlower Gardener
parent 49750fb8ac
commit 5abbeeec7e
4 changed files with 45 additions and 3 deletions

View File

@ -1602,6 +1602,17 @@ Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
custom_call_handler_);
}
void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
HloInstruction* replace, HloInstruction* with) {
CHECK(Shape::Equal()(replace->shape(), ShapeUtil::MakeScalarShape(S32)));
CHECK(Shape::Equal()(with->shape(), ShapeUtil::MakeScalarShape(S32)));
for (auto& kv : dynamic_mapping_) {
if (kv.second == replace) {
kv.second = with;
}
}
}
Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
HloInstruction* new_inst,
const ShapeIndex& index) {

View File

@ -68,6 +68,11 @@ class DynamicDimensionInference {
SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1));
}
// For all tensors whose dynamic dimension is `replace`, replace them with
// `with`.
void ReplaceAllDynamicDimensionUsesWith(HloInstruction* replace,
HloInstruction* with);
friend class DynamicDimensionInferenceVisitor;
private:

View File

@ -28,7 +28,7 @@ namespace {
StatusOr<bool> ReplaceGetSize(
HloInstruction* instr,
const DynamicDimensionInference* dynamic_dimension_inference) {
DynamicDimensionInference* dynamic_dimension_inference) {
if (instr->opcode() != HloOpcode::kGetDimensionSize) {
return false;
}
@ -47,11 +47,18 @@ StatusOr<bool> ReplaceGetSize(
dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
if (dynamic_size != nullptr) {
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
// The dependency between a instruction and its dynamic dimensions is not
// modeled in the IR. As instr is being replaced by dynamic_size, also tell
// dynamic dimension inference that the instruction is being replaced.
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
instr, dynamic_size);
} else {
int32 size = instr->operand(0)->shape().dimensions(dim);
HloInstruction* new_instr = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
new_instr);
}
return true;
}
@ -95,14 +102,14 @@ StatusOr<bool> HloGetDimensionSizeRewriter::Run(HloModule* module) {
//
// This will get static size of the op, which is incorrect.
for (auto* computation : module->computations()) {
for (auto instruction : computation->instructions()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool replaced_get_size,
ReplaceGetSize(instruction, &inference));
changed = changed || replaced_get_size;
}
}
for (auto* computation : module->computations()) {
for (auto instruction : computation->instructions()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
changed = changed || replaced_set_size;
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
@ -55,6 +56,24 @@ ENTRY gds {
op::Multiply(op::Constant(), op::Constant()));
}
TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule _
ENTRY gds {
p = s32[3,4] parameter(0)
size0 = s32[] get-dimension-size(p), dimensions={0}
p_copy = s32[3,4] copy(p)
p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0}
size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0}
ROOT mul = s32[] multiply(size0, size1)
})")
.ValueOrDie();
HloGetDimensionSizeRewriter pass;
EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Multiply(op::Constant(), op::Constant()));
}
TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) {
auto module = ParseAndReturnUnverifiedModule(R"(
HloModule _