[XLA] Add a pass to fold hlo(convert(a),convert(b)) into hlo(a,b) to enable an alternative way of specifying wider accumulation type than the shape inference result for HLOs that support it.
PiperOrigin-RevId: 344301743 Change-Id: If3b84a8369fd396012a6915237610e10f5e0d318
This commit is contained in:
parent
287d8116bf
commit
185e785f9a
@ -5263,3 +5263,26 @@ cc_library(
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "convert_operand_folding",
|
||||
srcs = ["convert_operand_folding.cc"],
|
||||
hdrs = ["convert_operand_folding.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":op_expander_pass",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "convert_operand_folding_test",
|
||||
srcs = ["convert_operand_folding_test.cc"],
|
||||
deps = [
|
||||
":convert_operand_folding",
|
||||
":hlo_matchers",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
59
tensorflow/compiler/xla/service/convert_operand_folding.cc
Normal file
59
tensorflow/compiler/xla/service/convert_operand_folding.cc
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/convert_operand_folding.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
bool IsUpcastConvert(const HloInstruction* hlo) {
|
||||
return hlo->opcode() == HloOpcode::kConvert &&
|
||||
ShapeUtil::CanUpcastIntegral(hlo->operand(0)->shape(), hlo->shape()) &&
|
||||
ShapeUtil::EqualIgnoringElementType(hlo->operand(0)->shape(),
|
||||
hlo->shape());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool ConvertOperandFolding::InstructionMatchesPattern(
|
||||
HloInstruction* instruction) {
|
||||
if (!ShapeUtil::ElementIsIntegral(instruction->shape())) {
|
||||
return false;
|
||||
}
|
||||
if (instruction->opcode() != HloOpcode::kDot &&
|
||||
instruction->opcode() != HloOpcode::kConvolution) {
|
||||
return false;
|
||||
}
|
||||
for (auto* operand : instruction->operands()) {
|
||||
if (IsUpcastConvert(operand)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> ConvertOperandFolding::ExpandInstruction(
|
||||
HloInstruction* instruction) {
|
||||
for (int i = 0; i < instruction->operand_count(); ++i) {
|
||||
auto* operand = instruction->mutable_operand(i);
|
||||
if (IsUpcastConvert(operand)) {
|
||||
TF_RETURN_IF_ERROR(instruction->ReplaceOperandWithDifferentShape(
|
||||
i, operand->mutable_operand(0)));
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace xla
|
41
tensorflow/compiler/xla/service/convert_operand_folding.h
Normal file
41
tensorflow/compiler/xla/service/convert_operand_folding.h
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Folds Convert operands to wider types into instructions that supports wider
|
||||
// result accumulation than the shape inference type.
|
||||
//
|
||||
// e.g. s32 hlo(s32 convert(s8), s32 convert(s8)) -> s32 hlo(s8, s8)
|
||||
class ConvertOperandFolding : public OpExpanderPass {
|
||||
public:
|
||||
absl::string_view name() const override { return "convert_operand_folding"; }
|
||||
|
||||
protected:
|
||||
bool InstructionMatchesPattern(HloInstruction* instruction) override;
|
||||
|
||||
StatusOr<HloInstruction*> ExpandInstruction(
|
||||
HloInstruction* instruction) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_
|
128
tensorflow/compiler/xla/service/convert_operand_folding_test.cc
Normal file
128
tensorflow/compiler/xla/service/convert_operand_folding_test.cc
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/convert_operand_folding.h"
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = ::xla::testing::opcode_matchers;
|
||||
|
||||
using ConvertOperandFoldingTest = HloTestBase;
|
||||
|
||||
TEST_F(ConvertOperandFoldingTest, UpcastConvertFolded) {
|
||||
absl::string_view module_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY main {
|
||||
p0 = s8[2,3]{1,0} parameter(0)
|
||||
p1 = s16[3,2]{0,1} parameter(1)
|
||||
c0 = s16[2,3]{1,0} convert(p0)
|
||||
c1 = s16[3,2]{0,1} convert(p1)
|
||||
ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
|
||||
rhs_contracting_dims={0}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool folded,
|
||||
ConvertOperandFolding().Run(module.get()));
|
||||
EXPECT_TRUE(folded);
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
AllOf(op::Dot(op::Parameter(0), op::Parameter(1)),
|
||||
op::Shape("s16[2,2]{1,0}")));
|
||||
}
|
||||
|
||||
TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) {
|
||||
absl::string_view module_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY main {
|
||||
p0 = s32[2,3]{1,0} parameter(0)
|
||||
p1 = s16[3,2]{0,1} parameter(1)
|
||||
c0 = s16[2,3]{1,0} convert(p0)
|
||||
c1 = s8[3,2]{0,1} convert(p1)
|
||||
ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
|
||||
rhs_contracting_dims={0}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool folded,
|
||||
ConvertOperandFolding().Run(module.get()));
|
||||
EXPECT_FALSE(folded);
|
||||
EXPECT_THAT(
|
||||
module->entry_computation()->root_instruction(),
|
||||
AllOf(
|
||||
op::Dot(
|
||||
AllOf(op::Convert(op::Parameter(0)), op::Shape("s16[2,3]{1,0}")),
|
||||
AllOf(op::Convert(op::Parameter(1)), op::Shape("s8[3,2]{0,1}"))),
|
||||
op::Shape("s16[2,2]{1,0}")));
|
||||
}
|
||||
|
||||
TEST_F(ConvertOperandFoldingTest, LayoutChangingConvertNotFolded) {
|
||||
absl::string_view module_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY main {
|
||||
p0 = s8[2,3]{1,0} parameter(0)
|
||||
p1 = s16[3,2]{0,1} parameter(1)
|
||||
c0 = s16[2,3]{0,1} convert(p0)
|
||||
c1 = s16[3,2]{1,0} convert(p1)
|
||||
ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
|
||||
rhs_contracting_dims={0}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool folded,
|
||||
ConvertOperandFolding().Run(module.get()));
|
||||
EXPECT_FALSE(folded);
|
||||
EXPECT_THAT(
|
||||
module->entry_computation()->root_instruction(),
|
||||
AllOf(
|
||||
op::Dot(
|
||||
AllOf(op::Convert(op::Parameter(0)), op::Shape("s16[2,3]{0,1}")),
|
||||
AllOf(op::Convert(op::Parameter(1)), op::Shape("s16[3,2]{1,0}"))),
|
||||
op::Shape("s16[2,2]{1,0}")));
|
||||
}
|
||||
|
||||
TEST_F(ConvertOperandFoldingTest, OneOperandFolded) {
|
||||
absl::string_view module_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY main {
|
||||
p0 = s8[2,3]{1,0} parameter(0)
|
||||
p1 = s16[3,2]{0,1} parameter(1)
|
||||
c0 = s16[2,3]{1,0} convert(p0)
|
||||
c1 = s8[3,2]{0,1} convert(p1)
|
||||
ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
|
||||
rhs_contracting_dims={0}
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool folded,
|
||||
ConvertOperandFolding().Run(module.get()));
|
||||
EXPECT_TRUE(folded);
|
||||
EXPECT_THAT(
|
||||
module->entry_computation()->root_instruction(),
|
||||
AllOf(op::Dot(op::Parameter(0), AllOf(op::Convert(op::Parameter(1)),
|
||||
op::Shape("s8[3,2]{0,1}"))),
|
||||
op::Shape("s16[2,2]{1,0}")));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user