[XLA] Allow integral dot results to be accumulated as a wider type.
PiperOrigin-RevId: 342923991
Change-Id: Ic3fee996c9a1c5fd10281b65264a2cef4c0cfcb4
This commit is contained in:
parent 4fe5106a8f
commit aec80b2420
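In rough terms, the change lets an integral dot (or convolution) declare a result type wider than its operand types, and adds an IntegralUpcaster pass that rewrites such instructions so the operands are converted up front. A minimal sketch of the intended rewrite, using the operand shapes from the DotUpcast evaluator test below (the upcast_l / upcast_r names are illustrative, not taken from the code):

Before the pass:
  l = s16[4,3] parameter(0)
  r = s8[3,2] parameter(1)
  ROOT dot = s32[4,2] dot(l, r), lhs_contracting_dims={1}, rhs_contracting_dims={0}

After the pass, both operands are converted to the accumulation type:
  l = s16[4,3] parameter(0)
  r = s8[3,2] parameter(1)
  upcast_l = s32[4,3] convert(l)
  upcast_r = s32[3,2] convert(r)
  ROOT dot = s32[4,2] dot(upcast_l, upcast_r), lhs_contracting_dims={1}, rhs_contracting_dims={0}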
@@ -5223,3 +5223,27 @@ tf_cc_test(
        "//tensorflow/core:test",
    ],
)

cc_library(
    name = "integral_upcaster",
    srcs = ["integral_upcaster.cc"],
    hdrs = ["integral_upcaster.h"],
    deps = [
        ":hlo",
        ":op_expander_pass",
        ":shape_inference",
    ],
)

tf_cc_test(
    name = "integral_upcaster_test",
    srcs = ["integral_upcaster_test.cc"],
    deps = [
        ":hlo_matchers",
        ":integral_upcaster",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/strings",
    ],
)
@@ -1800,6 +1800,8 @@ StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
                          : dot->mutable_operand(1);
  TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums,
                                               dot->precision_config()));
  // TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
  new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
  if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
  } else {
@@ -4678,6 +4680,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
    TF_ASSIGN_OR_RETURN(
        auto new_dot, MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config()));
    dot->SetupDerivedInstruction(new_dot);
    // TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
    new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
    if (reduce_dims.empty()) {
      return ReplaceInstruction(hlo, new_dot);
    }
@@ -169,6 +169,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:comparison_expander",
        "//tensorflow/compiler/xla/service:slice_sinker",
        "//tensorflow/compiler/xla:cpu_function_runtime",
        "//tensorflow/compiler/xla/service:integral_upcaster",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:status_macros",
@@ -102,6 +102,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/integral_upcaster.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/logistic_expander.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
@@ -269,6 +270,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
  HloPassPipeline pipeline("HLO passes through layout assignment");
  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);

  pipeline.AddPass<IntegralUpcaster>();

  // Expand random number generation.
  pipeline.AddPass<RngExpander>();
  pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
@@ -1238,6 +1238,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:hlo_proto_util",
        "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
        "//tensorflow/compiler/xla/service:hlo_verifier",
        "//tensorflow/compiler/xla/service:integral_upcaster",
        "//tensorflow/compiler/xla/service:llvm_compiler",
        "//tensorflow/compiler/xla/service:logistic_expander",
        "//tensorflow/compiler/xla/service:loop_schedule_linearizer",
@@ -90,6 +90,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/integral_upcaster.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/logistic_expander.h"
#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
@@ -142,6 +143,8 @@ Status GpuCompiler::OptimizeHloModule(
  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);

  pipeline.AddPass<IntegralUpcaster>();

  // Expand random number generation.
  pipeline.AddPass<RngExpander>();
  pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
@@ -4567,5 +4567,46 @@ TEST_F(HloEvaluatorTest, MapBF16) {
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}

TEST_F(HloEvaluatorTest, DotUpcast) {
  const absl::string_view hlo_text = R"(
  HloModule test
  ENTRY DotUpcast {
    l = s16[4,3]{1,0} parameter(0)
    r = s8[3,2]{1,0} parameter(1)
    ROOT result = s32[4,2] dot(l, r), lhs_contracting_dims={1},
                                      rhs_contracting_dims={0}
  }
  )";
  // lhs:
  // s16[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  auto lhs_array = absl::make_unique<Array2D<int16>>(4, 3);
  lhs_array->FillUnique(1);
  auto lhs_literal = LiteralUtil::CreateR2FromArray2D<int16>(*lhs_array);

  // rhs:
  // s8[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  auto rhs_array = absl::make_unique<Array2D<int8>>(3, 2);
  rhs_array->FillUnique(1);
  auto rhs_literal = LiteralUtil::CreateR2FromArray2D<int8>(*rhs_array);
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          Evaluate({&lhs_literal, &rhs_literal}));

  auto expected_array =
      Array2D<int32>({{22, 28}, {58, 76}, {94, 124}, {130, 172}});
  auto expected = LiteralUtil::CreateR2FromArray2D<int32>(expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}

}  // namespace
}  // namespace xla
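As a quick sanity check on the expected literal: FillUnique(1) fills the lhs rows as {1, 2, 3}, {5, 6, 7}, {9, 10, 11}, {13, 14, 15} and the rhs rows as {1, 2}, {3, 4}, {5, 6}, so the first output row is 1*1 + 2*3 + 3*5 = 22 and 1*2 + 2*4 + 3*6 = 28, matching the {22, 28} row above; the remaining rows follow the same pattern.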
@@ -136,20 +136,26 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
 }
 
 Status ShapeVerifier::HandleDot(HloInstruction* dot) {
-  TF_ASSIGN_OR_RETURN(const Shape expected,
+  TF_ASSIGN_OR_RETURN(Shape expected,
                       ShapeInference::InferDotOpShape(
                           dot->operand(0)->shape(), dot->operand(1)->shape(),
                           dot->dot_dimension_numbers()));
+  if (ShapeUtil::CanUpcastIntegral(expected, dot->shape())) {
+    expected.set_element_type(dot->shape().element_type());
+  }
   return CheckShape(dot, expected);
 }
 
 Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
   TF_ASSIGN_OR_RETURN(
-      const Shape expected,
+      Shape expected,
       ShapeInference::InferConvolveShape(
           convolution->operand(0)->shape(), convolution->operand(1)->shape(),
           convolution->feature_group_count(), convolution->batch_group_count(),
           convolution->window(), convolution->convolution_dimension_numbers()));
+  if (ShapeUtil::CanUpcastIntegral(expected, convolution->shape())) {
+    expected.set_element_type(convolution->shape().element_type());
+  }
   return CheckShape(convolution, expected);
 }
tensorflow/compiler/xla/service/integral_upcaster.cc (new file, 76 lines)
@@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/integral_upcaster.h"

#include "tensorflow/compiler/xla/service/shape_inference.h"

namespace xla {
namespace {

StatusOr<absl::optional<Shape>> MaybeInferShape(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    case HloOpcode::kDot:
      return ShapeInference::InferDotOpShape(
          instruction->operand(0)->shape(), instruction->operand(1)->shape(),
          instruction->dot_dimension_numbers());
    case HloOpcode::kConvolution:
      return ShapeInference::InferConvolveShape(
          instruction->operand(0)->shape(), instruction->operand(1)->shape(),
          instruction->feature_group_count(), instruction->batch_group_count(),
          instruction->window(), instruction->convolution_dimension_numbers());
    default:
      return absl::make_optional<Shape>();
  }
}

}  // namespace

bool IntegralUpcaster::InstructionMatchesPattern(HloInstruction* instruction) {
  if (!ShapeUtil::ElementIsIntegral(instruction->shape())) {
    return false;
  }
  auto status_or_inferred_shape = MaybeInferShape(instruction);
  if (!status_or_inferred_shape.ok() ||
      !status_or_inferred_shape->has_value()) {
    return false;
  }
  const Shape& inferred_shape = status_or_inferred_shape.ValueOrDie().value();

  return inferred_shape.element_type() != instruction->shape().element_type() &&
         ShapeUtil::CanUpcastIntegral(inferred_shape, instruction->shape());
}

StatusOr<HloInstruction*> IntegralUpcaster::ExpandInstruction(
    HloInstruction* instruction) {
  auto* computation = instruction->parent();
  auto type = instruction->shape().element_type();
  for (int i = 0; i < instruction->operand_count(); ++i) {
    auto* operand = instruction->mutable_operand(i);
    if (operand->shape().element_type() == type) {
      continue;
    }
    auto upcast_shape = operand->shape();
    upcast_shape.set_element_type(type);
    auto* convert_inst = computation->AddInstruction(
        HloInstruction::CreateConvert(upcast_shape, operand));
    TF_RETURN_IF_ERROR(
        instruction->ReplaceOperandWithDifferentShape(i, convert_inst));
  }
  return nullptr;
}

}  // namespace xla
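Note that ExpandInstruction returns nullptr rather than a replacement instruction: the matched dot or convolution is rewritten in place by swapping its operands for the newly inserted converts, so there is nothing for the expander machinery to splice in (this reading assumes OpExpanderPass treats a null result as "already handled").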
tensorflow/compiler/xla/service/integral_upcaster.h (new file, 39 lines)
@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"

namespace xla {

// Inserts Converts on the integral operands of instructions whose results are
// accumulated as wider integral types.
class IntegralUpcaster : public OpExpanderPass {
 public:
  absl::string_view name() const override { return "integral_upcaster"; }

 protected:
  bool InstructionMatchesPattern(HloInstruction* instruction) override;

  StatusOr<HloInstruction*> ExpandInstruction(
      HloInstruction* instruction) override;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTEGRAL_UPCASTER_H_
tensorflow/compiler/xla/service/integral_upcaster_test.cc (new file, 99 lines)
@@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/integral_upcaster.h"

#include "absl/strings/substitute.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

namespace xla {
namespace {

namespace op = ::xla::testing::opcode_matchers;

class IntegralUpcasterTest
    : public HloTestBase,
      public ::testing::WithParamInterface<
          std::tuple<PrimitiveType, PrimitiveType, PrimitiveType>> {};

bool ShouldUpcast(PrimitiveType operand_type, PrimitiveType result_type) {
  return primitive_util::BitWidth(operand_type) <
         primitive_util::BitWidth(result_type);
}

TEST_P(IntegralUpcasterTest, ConvertInserted) {
  PrimitiveType lhs_type, rhs_type, result_type;
  std::tie(lhs_type, rhs_type, result_type) = GetParam();
  absl::string_view module_tmpl = R"(
  HloModule module

  ENTRY main {
    p0 = $0[2,3]{1,0} parameter(0)
    p1 = $1[3,2]{1,0} parameter(1)
    ROOT dot = $2[2,2]{1,0} dot(p0, p1), lhs_contracting_dims={1},
                                         rhs_contracting_dims={0}
  })";
  auto module_string = absl::Substitute(
      module_tmpl, primitive_util::LowercasePrimitiveTypeName(lhs_type),
      primitive_util::LowercasePrimitiveTypeName(rhs_type),
      primitive_util::LowercasePrimitiveTypeName(result_type));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_string));
  TF_ASSERT_OK_AND_ASSIGN(bool upcasted, IntegralUpcaster().Run(module.get()));
  EXPECT_EQ(upcasted, ShouldUpcast(lhs_type, result_type) ||
                          ShouldUpcast(rhs_type, result_type));
  auto original_lhs = op::Parameter(0);
  auto original_rhs = op::Parameter(1);
  auto upcasted_lhs =
      ShouldUpcast(lhs_type, result_type)
          ? AllOf(op::Convert(original_lhs),
                  op::Shape(absl::Substitute(
                      "$0[2,3]{1,0}",
                      primitive_util::LowercasePrimitiveTypeName(result_type))))
          : original_lhs;
  auto upcasted_rhs =
      ShouldUpcast(rhs_type, result_type)
          ? AllOf(op::Convert(original_rhs),
                  op::Shape(absl::Substitute(
                      "$0[3,2]{1,0}",
                      primitive_util::LowercasePrimitiveTypeName(result_type))))
          : original_rhs;
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      AllOf(op::Dot(upcasted_lhs, upcasted_rhs),
            op::Shape(absl::Substitute(
                "$0[2,2]{1,0}",
                primitive_util::LowercasePrimitiveTypeName(result_type)))));
}

INSTANTIATE_TEST_SUITE_P(S16U16, IntegralUpcasterTest,
                         ::testing::Values(std::make_tuple(S8, S8, S16),
                                           std::make_tuple(U8, U8, U16)));

INSTANTIATE_TEST_SUITE_P(S32, IntegralUpcasterTest,
                         ::testing::Combine(::testing::Values(S8, S16),
                                            ::testing::Values(S8, S16),
                                            ::testing::Values(S32)));

INSTANTIATE_TEST_SUITE_P(U32, IntegralUpcasterTest,
                         ::testing::Combine(::testing::Values(U8, U16),
                                            ::testing::Values(U8, U16),
                                            ::testing::Values(U32)));

}  // namespace

}  // namespace xla
@@ -1633,4 +1633,13 @@ Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
  return s;
}

/*static*/ bool ShapeUtil::CanUpcastIntegral(const Shape& from,
                                             const Shape& to) {
  return ElementIsIntegral(from) && ElementIsIntegral(to) &&
         ElementIsSigned(from) == ElementIsSigned(to) &&
         primitive_util::BitWidth(from.element_type()) <=
             primitive_util::BitWidth(to.element_type()) &&
         CompatibleIgnoringElementType(from, to);
}

}  // namespace xla
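For illustration, given two otherwise compatible 2x2 shapes: CanUpcastIntegral(s8, s32) holds (same signedness, wider element type), while CanUpcastIntegral(s8, u32) fails because the signedness differs, and CanUpcastIntegral(s32, s8) fails because the element type would narrow.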
@@ -787,6 +787,10 @@ class ShapeUtil {
  // information, from a shape.
  static Shape DeviceShapeToHostShape(Shape s);

  // Returns true iff integral shape `from` can be safely upcasted to integral
  // shape `to`.
  static bool CanUpcastIntegral(const Shape& from, const Shape& to);

 private:
  // Validates the shape size is sane. This makes sure it's safe to do
  // calculations in int64 without overflowing.
@@ -1753,6 +1753,22 @@ XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) {
  ComputeAndCompare(&builder, {}, error_spec_);
}

XLA_TEST_F(DotOperationTextTest, WiderIntegralResultAccumulation) {
  absl::string_view hlo_string =
      R"(
  HloModule WiderIntegralAccumulation

  ENTRY MatrixVectorComplex {
    p0 = s8[5,5]{1,0} parameter(0)
    p1 = s16[5,1]{0,1} parameter(1)
    ROOT dot = s32[5,1]{1,0} dot(p0, p1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  }
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}

// This benchmark is to show the performance impact of the following
// transformation:
//   dot(reshape(transpose(A)), Const) ==>