[XLA] Rename IntegralUpcaster to OperandUpcaster, allowing auto-upcasting of floating-point types.

PiperOrigin-RevId: 349445605
Change-Id: I87d081c7f784155ffb885fcc3561bcc46a32c163
commit 9ac993ae57 (parent aa92d1d7a9)
Author: Ce Zheng
Date:   2020-12-29 11:10:09 -08:00
Committed by: TensorFlower Gardener

12 changed files with 99 additions and 75 deletions
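
Note: functionally, the renamed pass slots into the compiler pipelines exactly
where IntegralUpcaster used to. A minimal sketch of the registration, assuming
only headers touched by this commit (AddUpcasterPass is a hypothetical helper,
not part of the change; the CPU and GPU pipelines below do the equivalent):

    #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
    #include "tensorflow/compiler/xla/service/operand_upcaster.h"

    void AddUpcasterPass(xla::HloPassPipeline& pipeline) {
      // Was pipeline.AddPass<IntegralUpcaster>(); the pass now also widens
      // floating-point operands (e.g. a bf16 x bf16 -> f32 dot), not just
      // integral ones.
      pipeline.AddPass<xla::OperandUpcaster>();
    }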

tensorflow/compiler/xla/service/BUILD

@@ -5222,9 +5222,9 @@ tf_cc_test(
 )

 cc_library(
-    name = "integral_upcaster",
-    srcs = ["integral_upcaster.cc"],
-    hdrs = ["integral_upcaster.h"],
+    name = "operand_upcaster",
+    srcs = ["operand_upcaster.cc"],
+    hdrs = ["operand_upcaster.h"],
     deps = [
         ":hlo",
         ":op_expander_pass",
@@ -5233,11 +5233,11 @@ cc_library(
 )

 tf_cc_test(
-    name = "integral_upcaster_test",
-    srcs = ["integral_upcaster_test.cc"],
+    name = "operand_upcaster_test",
+    srcs = ["operand_upcaster_test.cc"],
     deps = [
         ":hlo_matchers",
-        ":integral_upcaster",
+        ":operand_upcaster",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",

tensorflow/compiler/xla/service/convert_operand_folding.cc

@@ -20,18 +20,19 @@ namespace {

 bool IsUpcastConvert(const HloInstruction* hlo) {
   return hlo->opcode() == HloOpcode::kConvert &&
-         ShapeUtil::CanUpcastIntegral(hlo->operand(0)->shape(), hlo->shape()) &&
-         ShapeUtil::EqualIgnoringElementType(hlo->operand(0)->shape(),
-                                             hlo->shape());
+         ShapeUtil::ElementIsFloating(hlo->shape()) ==
+             ShapeUtil::ElementIsFloating(hlo->operand(0)->shape()) &&
+         ShapeUtil::ElementIsSigned(hlo->shape()) ==
+             ShapeUtil::ElementIsSigned(hlo->operand(0)->shape()) &&
+         ShapeUtil::HigherPrecisionElementType(hlo->operand(0)->shape(),
+                                               hlo->shape()) ==
+             hlo->shape().element_type();
 }

 }  // namespace

 bool ConvertOperandFolding::InstructionMatchesPattern(
     HloInstruction* instruction) {
-  if (!ShapeUtil::ElementIsIntegral(instruction->shape())) {
-    return false;
-  }
   if (instruction->opcode() != HloOpcode::kDot &&
       instruction->opcode() != HloOpcode::kConvolution) {
     return false;
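
The rewritten IsUpcastConvert above admits a convert only when float-ness and
signedness are preserved and precision does not drop. A self-contained sketch of
the same rule on simplified types (Ty and IsUpcast are illustrative stand-ins,
not XLA APIs; plain bit width stands in for HigherPrecisionElementType):

    #include <cstdio>

    enum Kind { kSigned, kUnsigned, kFloat };
    struct Ty { Kind kind; int bits; };

    // The three conditions mirrored from IsUpcastConvert.
    bool IsUpcast(Ty from, Ty to) {
      return (from.kind == kFloat) == (to.kind == kFloat) &&
             (from.kind == kSigned) == (to.kind == kSigned) &&
             from.bits <= to.bits;
    }

    int main() {
      std::printf("%d\n", IsUpcast({kSigned, 8}, {kSigned, 16}));   // 1: s8->s16 folds
      std::printf("%d\n", IsUpcast({kFloat, 16}, {kFloat, 32}));    // 1: f16->f32 now folds
      std::printf("%d\n", IsUpcast({kSigned, 8}, {kFloat, 16}));    // 0: s8->f16 kept
      std::printf("%d\n", IsUpcast({kSigned, 32}, {kSigned, 16}));  // 0: downcast kept
    }

The last two cases match the IntegralToFloatingConvertNotFolded and
DowncastConvertNotFolded tests below.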

tensorflow/compiler/xla/service/convert_operand_folding_test.cc

@@ -26,7 +26,7 @@ namespace op = ::xla::testing::opcode_matchers;

 using ConvertOperandFoldingTest = HloTestBase;

-TEST_F(ConvertOperandFoldingTest, UpcastConvertFolded) {
+TEST_F(ConvertOperandFoldingTest, IntegralUpcastConvertFolded) {
   absl::string_view module_string = R"(
 HloModule module

@@ -48,6 +48,54 @@ TEST_F(ConvertOperandFoldingTest, UpcastConvertFolded) {
           op::Shape("s16[2,2]{1,0}")));
 }

+TEST_F(ConvertOperandFoldingTest, FloatingUpcastConvertFolded) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  p0 = f16[2,3]{1,0} parameter(0)
+  p1 = bf16[3,2]{0,1} parameter(1)
+  c0 = f32[2,3]{1,0} convert(p0)
+  c1 = f32[3,2]{0,1} convert(p1)
+  ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
+                                        rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool folded,
+                          ConvertOperandFolding().Run(module.get()));
+  EXPECT_TRUE(folded);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              AllOf(op::Dot(op::Parameter(0), op::Parameter(1)),
+                    op::Shape("f32[2,2]{1,0}")));
+}
+
+TEST_F(ConvertOperandFoldingTest, IntegralToFloatingConvertNotFolded) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  p0 = s8[2,3]{1,0} parameter(0)
+  p1 = s16[3,2]{0,1} parameter(1)
+  c0 = f16[2,3]{1,0} convert(p0)
+  c1 = f32[3,2]{0,1} convert(p1)
+  ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
+                                        rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool folded,
+                          ConvertOperandFolding().Run(module.get()));
+  EXPECT_FALSE(folded);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      AllOf(
+          op::Dot(
+              AllOf(op::Convert(op::Parameter(0)), op::Shape("f16[2,3]{1,0}")),
+              AllOf(op::Convert(op::Parameter(1)), op::Shape("f32[3,2]{0,1}"))),
+          op::Shape("f32[2,2]{1,0}")));
+}
+
 TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) {
   absl::string_view module_string = R"(
 HloModule module

@@ -74,32 +122,6 @@ TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) {
           op::Shape("s16[2,2]{1,0}")));
 }

-TEST_F(ConvertOperandFoldingTest, LayoutChangingConvertNotFolded) {
-  absl::string_view module_string = R"(
-HloModule module
-
-ENTRY main {
-  p0 = s8[2,3]{1,0} parameter(0)
-  p1 = s16[3,2]{0,1} parameter(1)
-  c0 = s16[2,3]{0,1} convert(p0)
-  c1 = s16[3,2]{1,0} convert(p1)
-  ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
-                                        rhs_contracting_dims={0}
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(module_string));
-  TF_ASSERT_OK_AND_ASSIGN(bool folded,
-                          ConvertOperandFolding().Run(module.get()));
-  EXPECT_FALSE(folded);
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      AllOf(
-          op::Dot(
-              AllOf(op::Convert(op::Parameter(0)), op::Shape("s16[2,3]{0,1}")),
-              AllOf(op::Convert(op::Parameter(1)), op::Shape("s16[3,2]{1,0}"))),
-          op::Shape("s16[2,2]{1,0}")));
-}
-
 TEST_F(ConvertOperandFoldingTest, OneOperandFolded) {
   absl::string_view module_string = R"(
 HloModule module

tensorflow/compiler/xla/service/cpu/BUILD

@@ -169,7 +169,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:comparison_expander",
         "//tensorflow/compiler/xla/service:slice_sinker",
         "//tensorflow/compiler/xla:cpu_function_runtime",
-        "//tensorflow/compiler/xla/service:integral_upcaster",
+        "//tensorflow/compiler/xla/service:operand_upcaster",
         "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:status_macros",

tensorflow/compiler/xla/service/cpu/cpu_compiler.cc

@@ -102,10 +102,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
-#include "tensorflow/compiler/xla/service/integral_upcaster.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/logistic_expander.h"
 #include "tensorflow/compiler/xla/service/map_inliner.h"
+#include "tensorflow/compiler/xla/service/operand_upcaster.h"
 #include "tensorflow/compiler/xla/service/qr_expander.h"
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
@@ -271,7 +271,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                             /*allow_mixed_precision=*/false);

-  pipeline.AddPass<IntegralUpcaster>();
+  pipeline.AddPass<OperandUpcaster>();

   // Expand random number generation.
   pipeline.AddPass<RngExpander>();

tensorflow/compiler/xla/service/gpu/BUILD

@@ -1442,10 +1442,10 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_proto_util",
         "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
         "//tensorflow/compiler/xla/service:hlo_verifier",
-        "//tensorflow/compiler/xla/service:integral_upcaster",
         "//tensorflow/compiler/xla/service:llvm_compiler",
         "//tensorflow/compiler/xla/service:logistic_expander",
         "//tensorflow/compiler/xla/service:loop_schedule_linearizer",
+        "//tensorflow/compiler/xla/service:operand_upcaster",
         "//tensorflow/compiler/xla/service:qr_expander",
         "//tensorflow/compiler/xla/service:reshape_mover",
         "//tensorflow/compiler/xla/service:rng_bit_generator_expander",

tensorflow/compiler/xla/service/gpu/gpu_compiler.cc

@@ -94,10 +94,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
-#include "tensorflow/compiler/xla/service/integral_upcaster.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/logistic_expander.h"
 #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
+#include "tensorflow/compiler/xla/service/operand_upcaster.h"
 #include "tensorflow/compiler/xla/service/qr_expander.h"
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
@@ -149,7 +149,7 @@ Status GpuCompiler::OptimizeHloModule(
   pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                             /*allow_mixed_precision=*/false);

-  pipeline.AddPass<IntegralUpcaster>();
+  pipeline.AddPass<OperandUpcaster>();

   // Expand random number generation.
   pipeline.AddPass<RngExpander>();

tensorflow/compiler/xla/service/integral_upcaster.cc → tensorflow/compiler/xla/service/operand_upcaster.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/compiler/xla/service/integral_upcaster.h"
+#include "tensorflow/compiler/xla/service/operand_upcaster.h"

 #include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -35,28 +35,26 @@ StatusOr<absl::optional<Shape>> MaybeInferShape(
           instruction->window(), instruction->convolution_dimension_numbers(),
           /*preferred_element_type=*/absl::nullopt);
     default:
-      return absl::make_optional<Shape>();
+      return absl::optional<Shape>(absl::nullopt);
   }
 }

 }  // namespace

-bool IntegralUpcaster::InstructionMatchesPattern(HloInstruction* instruction) {
-  if (!ShapeUtil::ElementIsIntegral(instruction->shape())) {
-    return false;
-  }
+bool OperandUpcaster::InstructionMatchesPattern(HloInstruction* instruction) {
   auto status_or_inferred_shape = MaybeInferShape(instruction);
   if (!status_or_inferred_shape.ok() ||
       !status_or_inferred_shape->has_value()) {
     return false;
   }
   const Shape& inferred_shape = status_or_inferred_shape.ValueOrDie().value();
-  return inferred_shape.element_type() != instruction->shape().element_type() &&
-         ShapeUtil::CanUpcastIntegral(inferred_shape, instruction->shape());
+  if (inferred_shape.element_type() == instruction->shape().element_type()) {
+    return false;
+  }
+  return ShapeUtil::ElementCanUpcast(inferred_shape, instruction->shape());
 }

-StatusOr<HloInstruction*> IntegralUpcaster::ExpandInstruction(
+StatusOr<HloInstruction*> OperandUpcaster::ExpandInstruction(
     HloInstruction* instruction) {
   auto* computation = instruction->parent();
   auto type = instruction->shape().element_type();
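
A subtle fix rides along in MaybeInferShape's default case: it previously
returned an engaged optional holding a default-constructed Shape, which the
has_value() check downstream treats as a successfully inferred shape; it now
returns a disengaged optional. A minimal illustration (the local Shape struct
is a stand-in for xla::Shape):

    #include <cassert>
    #include "absl/types/optional.h"

    struct Shape {};  // stand-in for xla::Shape

    int main() {
      // Old default case: an optional that *contains* a default Shape.
      absl::optional<Shape> old_result = absl::make_optional<Shape>();
      assert(old_result.has_value());  // engaged -- looks like a real inference

      // New default case: an empty optional, correctly meaning "not inferable".
      absl::optional<Shape> new_result = absl::optional<Shape>(absl::nullopt);
      assert(!new_result.has_value());
    }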

tensorflow/compiler/xla/service/integral_upcaster.h → tensorflow/compiler/xla/service/operand_upcaster.h

@@ -21,11 +21,11 @@ limitations under the License.

 namespace xla {

-// Inserts Convert to integral operands of instructions that allows result
-// accumulation as wider integral types.
-class IntegralUpcaster : public OpExpanderPass {
+// Inserts Converts to operands of instructions to allow result accumulation
+// as wider types.
+class OperandUpcaster : public OpExpanderPass {
  public:
-  absl::string_view name() const override { return "integral_upcaster"; }
+  absl::string_view name() const override { return "operand_upcaster"; }

 protected:
  bool InstructionMatchesPattern(HloInstruction* instruction) override;
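
In HLO terms the pass turns a widening dot into explicit converts plus a dot,
as in this sketch (mirroring the parameterized test below, now instantiated for
BF16/F16 as well; instruction names in the "after" module are illustrative):

    Before:
      ENTRY main {
        p0 = bf16[2,3]{1,0} parameter(0)
        p1 = bf16[3,2]{0,1} parameter(1)
        ROOT dot = f32[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
                                              rhs_contracting_dims={0}
      }

    After OperandUpcaster:
      ENTRY main {
        p0 = bf16[2,3]{1,0} parameter(0)
        p1 = bf16[3,2]{0,1} parameter(1)
        c0 = f32[2,3]{1,0} convert(p0)
        c1 = f32[3,2]{0,1} convert(p1)
        ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                              rhs_contracting_dims={0}
      }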

tensorflow/compiler/xla/service/integral_upcaster_test.cc → tensorflow/compiler/xla/service/operand_upcaster_test.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/compiler/xla/service/integral_upcaster.h"
+#include "tensorflow/compiler/xla/service/operand_upcaster.h"

 #include "absl/strings/substitute.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
@@ -25,7 +25,7 @@ namespace {

 namespace op = ::xla::testing::opcode_matchers;

-class IntegralUpcasterTest
+class OperandUpcasterTest
     : public HloTestBase,
       public ::testing::WithParamInterface<
           std::tuple<PrimitiveType, PrimitiveType, PrimitiveType>> {};
@@ -35,7 +35,7 @@ bool ShouldUpcast(PrimitiveType operand_type, PrimitiveType result_type) {
          primitive_util::BitWidth(result_type);
 }

-TEST_P(IntegralUpcasterTest, ConvertInserted) {
+TEST_P(OperandUpcasterTest, ConvertInserted) {
   PrimitiveType lhs_type, rhs_type, result_type;
   std::tie(lhs_type, rhs_type, result_type) = GetParam();
   absl::string_view module_tmpl = R"(
@@ -53,7 +53,7 @@ TEST_P(IntegralUpcasterTest, ConvertInserted) {
       primitive_util::LowercasePrimitiveTypeName(result_type));
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(module_string));
-  TF_ASSERT_OK_AND_ASSIGN(bool upcasted, IntegralUpcaster().Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool upcasted, OperandUpcaster().Run(module.get()));
   EXPECT_EQ(upcasted, ShouldUpcast(lhs_type, result_type) ||
                           ShouldUpcast(rhs_type, result_type));
   auto original_lhs = op::Parameter(0);
@@ -80,20 +80,25 @@ TEST_P(IntegralUpcasterTest, ConvertInserted) {
           primitive_util::LowercasePrimitiveTypeName(result_type)))));
 }

-INSTANTIATE_TEST_SUITE_P(S16U16, IntegralUpcasterTest,
+INSTANTIATE_TEST_SUITE_P(S16U16, OperandUpcasterTest,
                          ::testing::Values(std::make_tuple(S8, S8, S16),
                                            std::make_tuple(U8, U8, U16)));

-INSTANTIATE_TEST_SUITE_P(S32, IntegralUpcasterTest,
+INSTANTIATE_TEST_SUITE_P(S32, OperandUpcasterTest,
                          ::testing::Combine(::testing::Values(S8, S16),
                                             ::testing::Values(S8, S16),
                                             ::testing::Values(S32)));

-INSTANTIATE_TEST_SUITE_P(U32, IntegralUpcasterTest,
+INSTANTIATE_TEST_SUITE_P(U32, OperandUpcasterTest,
                          ::testing::Combine(::testing::Values(U8, U16),
                                             ::testing::Values(U8, U16),
                                             ::testing::Values(U32)));

+INSTANTIATE_TEST_SUITE_P(F32, OperandUpcasterTest,
+                         ::testing::Combine(::testing::Values(BF16, F16),
+                                            ::testing::Values(BF16, F16),
+                                            ::testing::Values(F32)));
+
 }  // namespace
 }  // namespace xla

tensorflow/compiler/xla/shape_util.cc

@@ -1711,13 +1711,11 @@ Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
   return s;
 }

-/*static*/ bool ShapeUtil::CanUpcastIntegral(const Shape& from,
-                                             const Shape& to) {
-  return ElementIsIntegral(from) && ElementIsIntegral(to) &&
+/*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
+                                            const Shape& to) {
+  return ElementIsFloating(from) == ElementIsFloating(to) &&
          ElementIsSigned(from) == ElementIsSigned(to) &&
-         primitive_util::BitWidth(from.element_type()) <=
-             primitive_util::BitWidth(to.element_type()) &&
-         CompatibleIgnoringElementType(from, to);
+         HigherPrecisionElementType(from, to) == to.element_type();
 }

 }  // namespace xla
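
The generalized helper checks element types only; unlike CanUpcastIntegral it
no longer constrains dimensions (the CompatibleIgnoringElementType conjunct is
gone). A usage sketch against the APIs in this commit (Demo is illustrative;
the outcomes in the comments follow from the three conditions above):

    #include "tensorflow/compiler/xla/shape_util.h"

    void Demo() {
      using xla::ShapeUtil;
      const xla::Shape s8 = ShapeUtil::MakeShape(xla::S8, {2, 3});
      const xla::Shape s16 = ShapeUtil::MakeShape(xla::S16, {2, 3});
      const xla::Shape bf16 = ShapeUtil::MakeShape(xla::BF16, {2, 3});
      const xla::Shape f32 = ShapeUtil::MakeShape(xla::F32, {2, 3});

      ShapeUtil::ElementCanUpcast(s8, s16);    // true: integral widening
      ShapeUtil::ElementCanUpcast(bf16, f32);  // true: floating widening (new)
      ShapeUtil::ElementCanUpcast(s8, f32);    // false: float-ness differs
      ShapeUtil::ElementCanUpcast(f32, bf16);  // false: precision would drop
    }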

tensorflow/compiler/xla/shape_util.h

@@ -787,9 +787,9 @@ class ShapeUtil {
   // information, from a shape.
   static Shape DeviceShapeToHostShape(Shape s);

-  // Returns true iff integral shape `from` can be safely upcasted to integral
-  // shape `to`.
-  static bool CanUpcastIntegral(const Shape& from, const Shape& to);
+  // Returns true iff the element type of shape `from` can be safely upcasted
+  // to the element type of shape `to`.
+  static bool ElementCanUpcast(const Shape& from, const Shape& to);

  private:
   // Fills *shape. Returns true on success.