[XLA] Add a memory space propagation pass.
PiperOrigin-RevId: 310371097 Change-Id: I7b58fc0c2a67d69a4b68136acc12e4ca9f16c464
This commit is contained in:
parent
75d1fdd15d
commit
0cc3e612bd
@ -3234,6 +3234,29 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Library target for the MemorySpacePropagation HLO pass, built from
# memory_space_propagation.cc / memory_space_propagation.h.
cc_library(
|
||||
name = "memory_space_propagation",
|
||||
srcs = ["memory_space_propagation.cc"],
|
||||
hdrs = ["memory_space_propagation.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_dataflow_analysis",
|
||||
":hlo_pass",
|
||||
],
|
||||
)
|
||||
|
||||
# Unit test target for the memory_space_propagation library; the test source
# parses HLO text and runs on top of HloTestBase.
tf_cc_test(
|
||||
name = "memory_space_propagation_test",
|
||||
srcs = ["memory_space_propagation_test.cc"],
|
||||
deps = [
|
||||
":hlo_parser",
|
||||
":memory_space_propagation",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_dce",
|
||||
srcs = ["hlo_dce.cc"],
|
||||
|
67
tensorflow/compiler/xla/service/memory_space_propagation.cc
Normal file
67
tensorflow/compiler/xla/service/memory_space_propagation.cc
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/memory_space_propagation.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Runs the pass over `module`: for every fusion instruction found in a
// non-fusion computation, pushes the memory spaces recorded on the fusion's
// operand shapes and output shape into the fused computation. Returns true if
// any shape was changed; a failure of the dataflow analysis is returned as the
// error status.
StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
  // Build the dataflow analysis up front; PropagateSubshapes() reads it through
  // the dataflow_analysis_ member.
  TF_ASSIGN_OR_RETURN(auto analysis, HloDataflowAnalysis::Run(*module));
  dataflow_analysis_ = std::move(analysis);

  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    for (HloInstruction* fusion : computation->instructions()) {
      if (fusion->opcode() != HloOpcode::kFusion) {
        continue;
      }

      // Each caller-side operand shape drives the matching fused parameter.
      for (int i = 0; i < fusion->operand_count(); ++i) {
        const Shape& operand_shape = fusion->operand(i)->shape();
        if (PropagateSubshapes(operand_shape, fusion->fused_parameter(i))) {
          changed = true;
        }
      }

      // The fusion's own (output) shape drives the fused root instruction.
      if (PropagateSubshapes(fusion->shape(),
                             fusion->fused_expression_root())) {
        changed = true;
      }
    }
  }
  return changed;
}
|
||||
|
||||
bool MemorySpacePropagation::PropagateSubshapes(
|
||||
const Shape& caller_shape, const HloInstruction* callee_instruction) const {
|
||||
bool modified = false;
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(caller_shape)) {
|
||||
int64 memory_space = indexed_shape.shape.layout().memory_space();
|
||||
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
|
||||
callee_instruction, indexed_shape.index);
|
||||
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
Shape* shape = ShapeUtil::GetMutableSubshape(
|
||||
position.instruction->mutable_shape(), position.index);
|
||||
if (shape->layout().memory_space() != memory_space) {
|
||||
shape->mutable_layout()->set_memory_space(memory_space);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
|
||||
} // namespace xla
|
46
tensorflow/compiler/xla/service/memory_space_propagation.h
Normal file
46
tensorflow/compiler/xla/service/memory_space_propagation.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// This is a legalization pass that propagates the memory space in the layout to
|
||||
// the fusion computations.
|
||||
class MemorySpacePropagation : public HloModulePass {
|
||||
public:
|
||||
~MemorySpacePropagation() override = default;
|
||||
absl::string_view name() const override { return "memory-space-propagation"; }
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
// Given the caller shape (operand or output) and its corresponding
|
||||
// instruction in the fused computation (parameter or root), propagates the
|
||||
// memory space to all the subshapes on the callee side. Returns true if the
|
||||
// module is modified.
|
||||
bool PropagateSubshapes(const Shape& caller_shape,
|
||||
const HloInstruction* callee_instruction) const;
|
||||
|
||||
// Dataflow analysis of the module being processed. Assigned at the start of
// Run() and consulted by PropagateSubshapes().
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_
|
203
tensorflow/compiler/xla/service/memory_space_propagation_test.cc
Normal file
203
tensorflow/compiler/xla/service/memory_space_propagation_test.cc
Normal file
@ -0,0 +1,203 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/memory_space_propagation.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// Test fixture that owns an HloVerifier (constructed layout-insensitive, with
// mixed precision disallowed) so tests can verify a module after the pass ran.
class MemorySpacePropagationTest : public HloTestBase {
|
||||
public:
|
||||
MemorySpacePropagationTest()
|
||||
: HloTestBase(),
|
||||
verifier_(/*layout_sensitive=*/false, /*allow_mixed_precision=*/ false) {
|
||||
}
|
||||
|
||||
// Runs the fixture's verifier on `module` and returns only the status.
Status Verify(HloModule* module) { return verifier_.Run(module).status(); }
|
||||
|
||||
private:
|
||||
HloVerifier verifier_;
|
||||
};
|
||||
|
||||
// No layout in the input carries an `S(n)` annotation, so the pass is expected
// to report no change and leave the module hash equal to a fresh parse of the
// same HLO text.
TEST_F(MemorySpacePropagationTest, NoMemorySpace) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule NoMemorySpace
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)} parameter(0)
|
||||
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)} copy(%param2)
|
||||
%fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
MemorySpacePropagation memory_space_propagation;
|
||||
EXPECT_FALSE(memory_space_propagation.Run(module.get()).ValueOrDie());
|
||||
// Compare against an untouched parse of the same text.
TF_ASSERT_OK_AND_ASSIGN(auto ref, ParseAndReturnVerifiedModule(hlo_string));
|
||||
EXPECT_EQ(module->Hash(), ref->Hash());
|
||||
}
|
||||
|
||||
// The entry computation marks %arg0, %arg2 and the fusion output with `S(1)`.
// After the pass, the expected module carries `S(1)` on the fused parameters
// %param_0.1 and %param_2.3 and on the fused root %add.0; %param_1.3 is
// unchanged.
TEST_F(MemorySpacePropagationTest, NonTupleOutput) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule NonTupleOutput
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)} parameter(0)
|
||||
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
|
||||
}
|
||||
)";
|
||||
absl::string_view expected_hlo_string = R"(
|
||||
HloModule NonTupleOutput
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
|
||||
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
|
||||
}
|
||||
)";
|
||||
// The input is parsed without verification and verified only after the pass.
// NOTE(review): presumably because caller and callee layouts disagree until
// the pass has run -- confirm.
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
MemorySpacePropagation memory_space_propagation;
|
||||
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
|
||||
TF_EXPECT_OK(Verify(module.get()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto ref,
|
||||
ParseAndReturnVerifiedModule(expected_hlo_string));
|
||||
EXPECT_EQ(module->Hash(), ref->Hash());
|
||||
}
|
||||
|
||||
// Same setup as the non-tuple case, but the fusion returns a two-element tuple
// whose element 0 is marked `S(1)` on the caller side. The expected module
// carries `S(1)` on %param_0.1, %param_2.3, %add.0 and element 0 of the root
// tuple shape; %multiply.0 and tuple element 1 are unchanged.
TEST_F(MemorySpacePropagationTest, TupleOutput) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule TupleOutput
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)} parameter(0)
|
||||
%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
%multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
ROOT %tuple = (s32[6]{0:T(128)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
%fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
%gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0
|
||||
%gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1
|
||||
ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1)
|
||||
}
|
||||
)";
|
||||
absl::string_view expected_hlo_string = R"(
|
||||
HloModule TupleOutput
|
||||
|
||||
%fused_computation {
|
||||
%param_1.3 = s32[1]{0:T(128)} parameter(1)
|
||||
%constant.2 = s32[]{:T(128)} constant(-2147483648)
|
||||
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
|
||||
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
|
||||
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
|
||||
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
|
||||
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
|
||||
%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
%multiply.0 = s32[6]{0:T(128)} multiply(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
|
||||
ROOT %tuple = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) tuple(%add.0, %multiply.0)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param0 = s32[6]{0:T(128)} parameter(0)
|
||||
%param1 = s32[1]{0:T(128)} parameter(1)
|
||||
%param2 = s32[5]{0:T(128)} parameter(2)
|
||||
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
|
||||
%arg1 = s32[1]{0:T(128)} copy(%param1)
|
||||
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
|
||||
%fusion = (s32[6]{0:T(128)S(1)}, s32[6]{0:T(128)}) fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
|
||||
%gte0 = s32[6]{0:T(128)S(1)} get-tuple-element(%fusion), index=0
|
||||
%gte1 = s32[6]{0:T(128)} get-tuple-element(%fusion), index=1
|
||||
ROOT %root = s32[6]{0:T(128)} add(%gte0, %gte1)
|
||||
}
|
||||
)";
|
||||
// The input is parsed without verification and verified only after the pass.
// NOTE(review): presumably because caller and callee layouts disagree until
// the pass has run -- confirm.
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string));
|
||||
MemorySpacePropagation memory_space_propagation;
|
||||
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
|
||||
TF_EXPECT_OK(Verify(module.get()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto ref,
|
||||
ParseAndReturnVerifiedModule(expected_hlo_string));
|
||||
EXPECT_EQ(module->Hash(), ref->Hash());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user