[XLA] Allow integral dot results to be accumulated as a wider type.

PiperOrigin-RevId: 342923991
Change-Id: Ic3fee996c9a1c5fd10281b65264a2cef4c0cfcb4
Ce Zheng 2020-11-17 12:42:31 -08:00 committed by TensorFlower Gardener
parent 4fe5106a8f
commit aec80b2420
14 changed files with 329 additions and 2 deletions
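In effect, the new IntegralUpcaster pass (registered in the CPU and GPU pipelines below) converts the integral operands of a dot or convolution to the instruction's wider declared result type, and the verifier and shape utilities now accept that wider result type. A rough sketch of the rewrite on a dot (module and instruction names here are illustrative, not taken from the change):

  HloModule upcast_sketch

  ENTRY main {
    lhs = s8[4,3]{1,0} parameter(0)
    rhs = s8[3,2]{1,0} parameter(1)
    ROOT dot = s32[4,2]{1,0} dot(lhs, rhs), lhs_contracting_dims={1},
        rhs_contracting_dims={0}
  }

after the pass becomes:

  ENTRY main {
    lhs = s8[4,3]{1,0} parameter(0)
    rhs = s8[3,2]{1,0} parameter(1)
    convert.0 = s32[4,3]{1,0} convert(lhs)
    convert.1 = s32[3,2]{1,0} convert(rhs)
    ROOT dot = s32[4,2]{1,0} dot(convert.0, convert.1), lhs_contracting_dims={1},
        rhs_contracting_dims={0}
  }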

View File

@@ -5223,3 +5223,27 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
cc_library(
name = "integral_upcaster",
srcs = ["integral_upcaster.cc"],
hdrs = ["integral_upcaster.h"],
deps = [
":hlo",
":op_expander_pass",
":shape_inference",
],
)
tf_cc_test(
name = "integral_upcaster_test",
srcs = ["integral_upcaster_test.cc"],
deps = [
":hlo_matchers",
":integral_upcaster",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/strings",
],
)

View File

@@ -1800,6 +1800,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
: dot->mutable_operand(1);
TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums,
dot->precision_config()));
// TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
} else {
@@ -4678,6 +4680,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
TF_ASSIGN_OR_RETURN(
auto new_dot, MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config()));
dot->SetupDerivedInstruction(new_dot);
// TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
if (reduce_dims.empty()) {
return ReplaceInstruction(hlo, new_dot);
}

View File

@@ -169,6 +169,7 @@ cc_library(
"//tensorflow/compiler/xla/service:comparison_expander",
"//tensorflow/compiler/xla/service:slice_sinker",
"//tensorflow/compiler/xla:cpu_function_runtime",
"//tensorflow/compiler/xla/service:integral_upcaster",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",

View File

@@ -102,6 +102,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/integral_upcaster.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/logistic_expander.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
@@ -269,6 +270,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
HloPassPipeline pipeline("HLO passes through layout assignment");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<IntegralUpcaster>();
// Expand random number generation.
pipeline.AddPass<RngExpander>();
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

View File

@@ -1238,6 +1238,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_proto_util",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:integral_upcaster",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:logistic_expander",
"//tensorflow/compiler/xla/service:loop_schedule_linearizer",

View File

@@ -90,6 +90,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/integral_upcaster.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/logistic_expander.h"
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
@@ -142,6 +143,8 @@ Status GpuCompiler::OptimizeHloModule(
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<IntegralUpcaster>();
// Expand random number generation.
pipeline.AddPass<RngExpander>();
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

View File

@@ -4567,5 +4567,46 @@ TEST_F(HloEvaluatorTest, MapBF16) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
TEST_F(HloEvaluatorTest, DotUpcast) {
const absl::string_view hlo_text = R"(
HloModule test
ENTRY DotUpcast {
l = s16[4,3]{1,0} parameter(0)
r = s8[3,2]{1,0} parameter(1)
ROOT result = s32[4,2] dot(l, r), lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
// lhs:
// s16[4,3] {
// { 1, 2, 3 },
// { 5, 6, 7 },
// { 9, 10, 11 },
// { 13, 14, 15 },
// }
auto lhs_array = absl::make_unique<Array2D<int16>>(4, 3);
lhs_array->FillUnique(1);
auto lhs_literal = LiteralUtil::CreateR2FromArray2D<int16>(*lhs_array);
// rhs:
// s8[3,2] {
// { 1, 2 },
// { 3, 4 },
// { 5, 6 },
// }
auto rhs_array = absl::make_unique<Array2D<int8>>(3, 2);
rhs_array->FillUnique(1);
auto rhs_literal = LiteralUtil::CreateR2FromArray2D<int8>(*rhs_array);
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(Literal result,
Evaluate({&lhs_literal, &rhs_literal}));
auto expected_array =
Array2D<int32>({{22, 28}, {58, 76}, {94, 124}, {130, 172}});
auto expected = LiteralUtil::CreateR2FromArray2D<int32>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
} // namespace
} // namespace xla

View File

@@ -136,20 +136,26 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
}
Status ShapeVerifier::HandleDot(HloInstruction* dot) {
-  TF_ASSIGN_OR_RETURN(const Shape expected,
+  TF_ASSIGN_OR_RETURN(Shape expected,
ShapeInference::InferDotOpShape(
dot->operand(0)->shape(), dot->operand(1)->shape(),
dot->dot_dimension_numbers()));
if (ShapeUtil::CanUpcastIntegral(expected, dot->shape())) {
expected.set_element_type(dot->shape().element_type());
}
return CheckShape(dot, expected);
}
Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
TF_ASSIGN_OR_RETURN(
-      const Shape expected,
+      Shape expected,
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
convolution->feature_group_count(), convolution->batch_group_count(),
convolution->window(), convolution->convolution_dimension_numbers()));
if (ShapeUtil::CanUpcastIntegral(expected, convolution->shape())) {
expected.set_element_type(convolution->shape().element_type());
}
return CheckShape(convolution, expected);
}

View File

@@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/integral_upcaster.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
namespace xla {
namespace {
StatusOr<absl::optional<Shape>> MaybeInferShape(
const HloInstruction* instruction) {
switch (instruction->opcode()) {
case HloOpcode::kDot:
return ShapeInference::InferDotOpShape(
instruction->operand(0)->shape(), instruction->operand(1)->shape(),
instruction->dot_dimension_numbers());
case HloOpcode::kConvolution:
return ShapeInference::InferConvolveShape(
instruction->operand(0)->shape(), instruction->operand(1)->shape(),
instruction->feature_group_count(), instruction->batch_group_count(),
instruction->window(), instruction->convolution_dimension_numbers());
default:
return absl::make_optional<Shape>();
}
}
} // namespace
bool IntegralUpcaster::InstructionMatchesPattern(HloInstruction* instruction) {
if (!ShapeUtil::ElementIsIntegral(instruction->shape())) {
return false;
}
auto status_or_inferred_shape = MaybeInferShape(instruction);
if (!status_or_inferred_shape.ok() ||
!status_or_inferred_shape->has_value()) {
return false;
}
const Shape& inferred_shape = status_or_inferred_shape.ValueOrDie().value();
return inferred_shape.element_type() != instruction->shape().element_type() &&
ShapeUtil::CanUpcastIntegral(inferred_shape, instruction->shape());
}
StatusOr<HloInstruction*> IntegralUpcaster::ExpandInstruction(
HloInstruction* instruction) {
auto* computation = instruction->parent();
auto type = instruction->shape().element_type();
for (int i = 0; i < instruction->operand_count(); ++i) {
auto* operand = instruction->mutable_operand(i);
if (operand->shape().element_type() == type) {
continue;
}
auto upcast_shape = operand->shape();
upcast_shape.set_element_type(type);
auto* convert_inst = computation->AddInstruction(
HloInstruction::CreateConvert(upcast_shape, operand));
TF_RETURN_IF_ERROR(
instruction->ReplaceOperandWithDifferentShape(i, convert_inst));
}
return nullptr;
}
} // namespace xla

View File

@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
namespace xla {
// Inserts Convert instructions on the integral operands of instructions whose
// results are accumulated as wider integral types.
class IntegralUpcaster : public OpExpanderPass {
public:
absl::string_view name() const override { return "integral_upcaster"; }
protected:
bool InstructionMatchesPattern(HloInstruction* instruction) override;
StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* instruction) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_

View File

@@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/integral_upcaster.h"
#include "absl/strings/substitute.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace {
namespace op = ::xla::testing::opcode_matchers;
class IntegralUpcasterTest
: public HloTestBase,
public ::testing::WithParamInterface<
std::tuple<PrimitiveType, PrimitiveType, PrimitiveType>> {};
bool ShouldUpcast(PrimitiveType operand_type, PrimitiveType result_type) {
return primitive_util::BitWidth(operand_type) <
primitive_util::BitWidth(result_type);
}
TEST_P(IntegralUpcasterTest, ConvertInserted) {
PrimitiveType lhs_type, rhs_type, result_type;
std::tie(lhs_type, rhs_type, result_type) = GetParam();
absl::string_view module_tmpl = R"(
HloModule module
ENTRY main {
p0 = $0[2,3]{1,0} parameter(0)
p1 = $1[3,2]{1,0} parameter(1)
ROOT dot = $2[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
rhs_contracting_dims={0}
})";
auto module_string = absl::Substitute(
module_tmpl, primitive_util::LowercasePrimitiveTypeName(lhs_type),
primitive_util::LowercasePrimitiveTypeName(rhs_type),
primitive_util::LowercasePrimitiveTypeName(result_type));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool upcasted, IntegralUpcaster().Run(module.get()));
EXPECT_EQ(upcasted, ShouldUpcast(lhs_type, result_type) ||
ShouldUpcast(rhs_type, result_type));
auto original_lhs = op::Parameter(0);
auto original_rhs = op::Parameter(1);
auto upcasted_lhs =
ShouldUpcast(lhs_type, result_type)
? AllOf(op::Convert(original_lhs),
op::Shape(absl::Substitute(
"$0[2,3]{1,0}",
primitive_util::LowercasePrimitiveTypeName(result_type))))
: original_lhs;
auto upcasted_rhs =
ShouldUpcast(rhs_type, result_type)
? AllOf(op::Convert(original_rhs),
op::Shape(absl::Substitute(
"$0[3,2]{1,0}",
primitive_util::LowercasePrimitiveTypeName(result_type))))
: original_rhs;
EXPECT_THAT(
module->entry_computation()->root_instruction(),
AllOf(op::Dot(upcasted_lhs, upcasted_rhs),
op::Shape(absl::Substitute(
"$0[2,2]{1,0}",
primitive_util::LowercasePrimitiveTypeName(result_type)))));
}
INSTANTIATE_TEST_SUITE_P(S16U16, IntegralUpcasterTest,
::testing::Values(std::make_tuple(S8, S8, S16),
std::make_tuple(U8, U8, U16)));
INSTANTIATE_TEST_SUITE_P(S32, IntegralUpcasterTest,
::testing::Combine(::testing::Values(S8, S16),
::testing::Values(S8, S16),
::testing::Values(S32)));
INSTANTIATE_TEST_SUITE_P(U32, IntegralUpcasterTest,
::testing::Combine(::testing::Values(U8, U16),
::testing::Values(U8, U16),
::testing::Values(U32)));
} // namespace
} // namespace xla

View File

@@ -1633,4 +1633,13 @@ Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
return s;
}
/*static*/ bool ShapeUtil::CanUpcastIntegral(const Shape& from,
const Shape& to) {
return ElementIsIntegral(from) && ElementIsIntegral(to) &&
ElementIsSigned(from) == ElementIsSigned(to) &&
primitive_util::BitWidth(from.element_type()) <=
primitive_util::BitWidth(to.element_type()) &&
CompatibleIgnoringElementType(from, to);
}
} // namespace xla

View File

@@ -787,6 +787,10 @@ class ShapeUtil {
// information, from a shape.
static Shape DeviceShapeToHostShape(Shape s);
// Returns true iff integral shape `from` can be safely upcasted to integral
// shape `to`.
static bool CanUpcastIntegral(const Shape& from, const Shape& to);
private:
// Validates the shape size is sane. This makes sure it's safe to do
// calculations in int64 without overflowing.

View File

@@ -1753,6 +1753,22 @@ XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) {
ComputeAndCompare(&builder, {}, error_spec_);
}
XLA_TEST_F(DotOperationTextTest, WiderIntegralResultAccumulation) {
absl::string_view hlo_string =
R"(
HloModule WiderIntegralAccumulation
ENTRY MatrixVectorComplex {
p0 = s8[5,5]{1,0} parameter(0)
p1 = s16[5,1]{0,1} parameter(1)
ROOT dot = s32[5,1]{1,0} dot(p0, p1), lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
// This benchmark is to show the performance impact of the following
// transformation:
// dot(reshape(transpose(A)), Const) ==>