[XLA] Extend the HLO verifier to check that non-layout-changing instructions
preserve operand layouts. Add an std::function member to the HloVerifier for a backend to specify the function object used to determine whether an instruction can change layouts. Use the function object to find out the non-layout-changing instructions and check that such instructions should produce results with the same layouts as its operands. Add test cases. PiperOrigin-RevId: 215941282
This commit is contained in:
parent
496bc15898
commit
03b4161326
@ -2450,6 +2450,7 @@ tf_cc_test(
|
||||
":hlo",
|
||||
":hlo_parser",
|
||||
":hlo_verifier",
|
||||
":layout_assignment",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
|
||||
@ -327,8 +327,13 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
|
||||
{
|
||||
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
|
||||
"simplification after layout assignement");
|
||||
pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false);
|
||||
// TODO(b/117156505): When the bug is fixed, the CPU backend should not
|
||||
// produce layout changing elementwise operations. We will then pass
|
||||
// LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to
|
||||
// enable stricter verification.
|
||||
pass.AddInvariantChecker<HloVerifier>(
|
||||
/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false);
|
||||
pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
|
||||
/*is_layout_sensitive=*/true,
|
||||
[](const Shape&, const Shape&) { return true; },
|
||||
|
||||
@ -239,8 +239,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
|
||||
{
|
||||
HloPassPipeline pipeline("post-layout_assignment");
|
||||
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false);
|
||||
pipeline.AddInvariantChecker<HloVerifier>(
|
||||
/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false,
|
||||
LayoutAssignment::InstructionCanChangeLayout);
|
||||
|
||||
// The LayoutAssignment pass may leave behind kCopy instructions which are
|
||||
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
|
||||
@ -286,8 +288,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
|
||||
{
|
||||
HloPassFix<HloPassPipeline> fusion("fusion");
|
||||
fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false);
|
||||
fusion.AddInvariantChecker<HloVerifier>(
|
||||
/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false,
|
||||
LayoutAssignment::InstructionCanChangeLayout);
|
||||
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
|
||||
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
|
||||
fusion.AddPass<FusionMerger>();
|
||||
@ -299,7 +303,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
|
||||
HloPassPipeline reduce_pipeline("reduce-precision");
|
||||
reduce_pipeline.AddInvariantChecker<HloVerifier>(
|
||||
/*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false);
|
||||
/*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false,
|
||||
LayoutAssignment::InstructionCanChangeLayout);
|
||||
ReducePrecisionInsertion::AddPasses(
|
||||
&reduce_pipeline, hlo_module->config().debug_options(),
|
||||
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
|
||||
@ -325,8 +330,10 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
|
||||
// (b/27180329). Therefore, in that case, we set the output to be a copy of
|
||||
// the parameter.
|
||||
HloPassPipeline pipeline("GPU-ir-emit-prepare");
|
||||
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false);
|
||||
pipeline.AddInvariantChecker<HloVerifier>(
|
||||
/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false,
|
||||
LayoutAssignment::InstructionCanChangeLayout);
|
||||
|
||||
// Copy insertion should be performed immediately before IR emission to avoid
|
||||
// inserting unnecessary copies (later pass adds an instruction which
|
||||
|
||||
@ -1042,7 +1042,10 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) {
|
||||
// not check result shape as that is checked in the ShapeVerifier.
|
||||
class InstructionVerifier : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
InstructionVerifier() {}
|
||||
explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func)
|
||||
: instruction_can_change_layout_func_(
|
||||
instruction_can_change_layout_func) {}
|
||||
|
||||
Status DefaultAction(HloInstruction*) override { return Status::OK(); }
|
||||
|
||||
@ -1143,8 +1146,34 @@ class InstructionVerifier : public DfsHloVisitorWithDefault {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Postprocess(HloInstruction* instruction) override {
|
||||
if (instruction_can_change_layout_func_ &&
|
||||
LayoutUtil::IsDenseArray(instruction->shape()) &&
|
||||
!instruction_can_change_layout_func_(instruction)) {
|
||||
const Shape& result_shape = instruction->shape();
|
||||
const Layout& result_layout = result_shape.layout();
|
||||
for (HloInstruction* operand : instruction->operands()) {
|
||||
const Shape& operand_shape = operand->shape();
|
||||
if (LayoutUtil::IsDenseArray(operand_shape) &&
|
||||
ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) {
|
||||
const Layout& operand_layout = operand_shape.layout();
|
||||
TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
|
||||
<< "Instruction shouldn't change layouts "
|
||||
<< instruction->ToString() << " From "
|
||||
<< ShapeUtil::HumanString(result_shape) << " To "
|
||||
<< ShapeUtil::HumanString(operand_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
|
||||
// Determines whether an instruction can change layouts.
|
||||
std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -1158,7 +1187,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
||||
std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
|
||||
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
|
||||
|
||||
InstructionVerifier instruction_verifier;
|
||||
InstructionVerifier instruction_verifier(
|
||||
instruction_can_change_layout_func_);
|
||||
TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
|
||||
}
|
||||
|
||||
|
||||
@ -155,11 +155,17 @@ class HloVerifier : public HloModulePass {
|
||||
public:
|
||||
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
|
||||
|
||||
explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
|
||||
explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision,
|
||||
std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func = {})
|
||||
: shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
|
||||
return absl::make_unique<ShapeVerifier>(layout_sensitive,
|
||||
allow_mixed_precision);
|
||||
}) {}
|
||||
}),
|
||||
instruction_can_change_layout_func_(
|
||||
std::move(instruction_can_change_layout_func)) {
|
||||
CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive);
|
||||
}
|
||||
|
||||
// Uses custom shape verification.
|
||||
explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
|
||||
@ -177,6 +183,10 @@ class HloVerifier : public HloModulePass {
|
||||
// being a DfsHloVisitor, is stateful. We want a clean object
|
||||
// for each run of the verifier.
|
||||
ShapeVerifierFactory shape_verifier_factory_;
|
||||
|
||||
// Determines whether an instruction can change layouts.
|
||||
std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/service/layout_assignment.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase {
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
|
||||
};
|
||||
|
||||
class HloVerifierTestLayoutSensitive : public HloTestBase {
|
||||
public:
|
||||
HloVerifierTestLayoutSensitive()
|
||||
: HloTestBase(/*verifier_layout_sensitive=*/true,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/false,
|
||||
LayoutAssignment::InstructionCanChangeLayout) {}
|
||||
};
|
||||
|
||||
TEST_F(HloVerifierTest, NullInstructionParent) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
|
||||
@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
|
||||
HasSubstr("non-positive base area dilation factor"));
|
||||
}
|
||||
|
||||
static const char* const kAddWithLayoutChangeHlo = R"(
|
||||
HloModule AddWithLayoutChange
|
||||
ENTRY AddWithLayoutChange {
|
||||
par0 = f32[3,4]{1,0} parameter(0)
|
||||
par1 = f32[3,4]{0,1} parameter(1)
|
||||
ROOT add0 = f32[3,4]{1,0} add(par0,par1)
|
||||
}
|
||||
)";
|
||||
|
||||
TEST_F(HloVerifierTest, AddWithLayoutChange) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_TRUE(status.ok());
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Instruction shouldn't change layouts"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
|
||||
const char* const kSliceWithLayoutChangeHlo = R"(
|
||||
HloModule SliceWithLayoutChange
|
||||
ENTRY SliceWithLayoutChange {
|
||||
par0 = f32[4,5]{0,1} parameter(0)
|
||||
par1 = s32[2] parameter(1)
|
||||
ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1),
|
||||
dynamic_slice_sizes={3,4}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseHloString(kSliceWithLayoutChangeHlo));
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Instruction shouldn't change layouts"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
|
||||
const char* const kConcatWithLayoutChangeHlo = R"(
|
||||
HloModule ConcatWithLayoutChange
|
||||
ENTRY ConcatWithLayoutChange {
|
||||
par0 = f32[3,5]{0,1} parameter(0)
|
||||
par1 = f32[3,3]{1,0} parameter(1)
|
||||
ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
|
||||
dimensions={1}
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseHloString(kConcatWithLayoutChangeHlo));
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Instruction shouldn't change layouts"));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
|
||||
} // namespace
|
||||
|
||||
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
|
||||
bool allow_mixed_precision_in_hlo_verifier)
|
||||
bool allow_mixed_precision_in_hlo_verifier,
|
||||
std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func)
|
||||
: HloTestBase(GetTestPlatform(), GetReferencePlatform(),
|
||||
verifier_layout_sensitive,
|
||||
allow_mixed_precision_in_hlo_verifier) {}
|
||||
allow_mixed_precision_in_hlo_verifier,
|
||||
instruction_can_change_layout_func) {}
|
||||
|
||||
HloTestBase::HloTestBase(se::Platform* test_platform,
|
||||
se::Platform* reference_platform,
|
||||
bool verifier_layout_sensitive,
|
||||
bool allow_mixed_precision_in_hlo_verifier)
|
||||
bool allow_mixed_precision_in_hlo_verifier,
|
||||
std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func)
|
||||
: test_runner_(test_platform), reference_runner_(reference_platform) {
|
||||
hlo_verifier_ = absl::make_unique<HloVerifier>(
|
||||
/*layout_sensitive=*/verifier_layout_sensitive,
|
||||
/*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
|
||||
/*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
|
||||
instruction_can_change_layout_func);
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
|
||||
|
||||
@ -88,14 +88,18 @@ class HloTestBase : public ::testing::Test {
|
||||
// interpreter is the only supported backend, it will be both the test backend
|
||||
// and the reference backend.
|
||||
HloTestBase(bool verifier_layout_sensitive = false,
|
||||
bool allow_mixed_precision_in_hlo_verifier = true);
|
||||
bool allow_mixed_precision_in_hlo_verifier = true,
|
||||
std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func = {});
|
||||
|
||||
// If your test doesn't use interpreter as the reference backend, you can use
|
||||
// this constructor. Note that your test target is responsible for linking in
|
||||
// both needed backends.
|
||||
HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
|
||||
bool verifier_layout_sensitive = false,
|
||||
bool allow_mixed_precision_in_hlo_verifier = true);
|
||||
bool allow_mixed_precision_in_hlo_verifier = true,
|
||||
std::function<bool(const HloInstruction*)>
|
||||
instruction_can_change_layout_func = {});
|
||||
|
||||
~HloTestBase() override {}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user