[XLA] IntegralUpcaster -> OperandUpcaster, to allow auto-upcasting of floating-point types.
PiperOrigin-RevId: 349445605
Change-Id: I87d081c7f784155ffb885fcc3561bcc46a32c163
parent aa92d1d7a9
commit 9ac993ae57
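In short: the IntegralUpcaster pass becomes OperandUpcaster, and the upcast check it relies on is generalized from integers to any same-class element types. ShapeUtil::CanUpcastIntegral is replaced by ShapeUtil::ElementCanUpcast, ConvertOperandFolding's upcast predicate is rewritten in the same terms, new tests cover bf16/f16 -> f32 folding and upcasting, and the CPU and GPU pipelines pick up the renamed pass.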
tensorflow/compiler/xla/service/BUILD
@@ -5222,9 +5222,9 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "integral_upcaster",
-    srcs = ["integral_upcaster.cc"],
-    hdrs = ["integral_upcaster.h"],
+    name = "operand_upcaster",
+    srcs = ["operand_upcaster.cc"],
+    hdrs = ["operand_upcaster.h"],
     deps = [
         ":hlo",
         ":op_expander_pass",
@@ -5233,11 +5233,11 @@ cc_library(
 )
 
 tf_cc_test(
-    name = "integral_upcaster_test",
-    srcs = ["integral_upcaster_test.cc"],
+    name = "operand_upcaster_test",
+    srcs = ["operand_upcaster_test.cc"],
     deps = [
         ":hlo_matchers",
-        ":integral_upcaster",
+        ":operand_upcaster",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
tensorflow/compiler/xla/service/convert_operand_folding.cc
@@ -20,18 +20,19 @@ namespace {
 
 bool IsUpcastConvert(const HloInstruction* hlo) {
   return hlo->opcode() == HloOpcode::kConvert &&
-         ShapeUtil::CanUpcastIntegral(hlo->operand(0)->shape(), hlo->shape()) &&
-         ShapeUtil::EqualIgnoringElementType(hlo->operand(0)->shape(),
-                                             hlo->shape());
+         ShapeUtil::ElementIsFloating(hlo->shape()) ==
+             ShapeUtil::ElementIsFloating(hlo->operand(0)->shape()) &&
+         ShapeUtil::ElementIsSigned(hlo->shape()) ==
+             ShapeUtil::ElementIsSigned(hlo->operand(0)->shape()) &&
+         ShapeUtil::HigherPrecisionElementType(hlo->operand(0)->shape(),
+                                               hlo->shape()) ==
+             hlo->shape().element_type();
 }
 
 }  // namespace
 
 bool ConvertOperandFolding::InstructionMatchesPattern(
     HloInstruction* instruction) {
-  if (!ShapeUtil::ElementIsIntegral(instruction->shape())) {
-    return false;
-  }
   if (instruction->opcode() != HloOpcode::kDot &&
       instruction->opcode() != HloOpcode::kConvolution) {
     return false;
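The rewritten predicate drops the integral-only and matching-shape requirements: a convert now counts as an upcast whenever source and result stay in the same numeric class (floating vs. integral), keep the same signedness, and the result is the higher-precision type of the pair. Below is a minimal standalone C++ sketch of that rule; the enum and helper functions are simplified stand-ins for XLA's PrimitiveType and ShapeUtil, and the bit-width comparison only approximates HigherPrecisionElementType (which also orders equal-width types such as bf16 and f16):

// Illustration only: simplified stand-ins for XLA's PrimitiveType/ShapeUtil,
// not the real API.
#include <cstdio>

enum Type { S8, S16, S32, U8, U16, U32, F16, BF16, F32 };

bool IsFloating(Type t) { return t == F16 || t == BF16 || t == F32; }
bool IsSigned(Type t) { return t == S8 || t == S16 || t == S32; }
int BitWidth(Type t) {
  switch (t) {
    case S8: case U8: return 8;
    case S16: case U16: case F16: case BF16: return 16;
    default: return 32;
  }
}

// Mirrors the new rule: `from` upcasts to `to` only within the same numeric
// class (floating vs. integral, signed vs. unsigned), and only when `to` is
// at least as precise. BitWidth is a rough proxy for XLA's
// HigherPrecisionElementType.
bool ElementCanUpcast(Type from, Type to) {
  return IsFloating(from) == IsFloating(to) &&
         IsSigned(from) == IsSigned(to) &&
         BitWidth(from) <= BitWidth(to);
}

int main() {
  std::printf("s8 ->s16: %d\n", ElementCanUpcast(S8, S16));   // 1: integral widening
  std::printf("f16->f32: %d\n", ElementCanUpcast(F16, F32));  // 1: the newly folded case
  std::printf("s8 ->f16: %d\n", ElementCanUpcast(S8, F16));   // 0: crosses int/float
  std::printf("u8 ->s16: %d\n", ElementCanUpcast(U8, S16));   // 0: signedness differs
}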
tensorflow/compiler/xla/service/convert_operand_folding_test.cc
@@ -26,7 +26,7 @@ namespace op = ::xla::testing::opcode_matchers;
 
 using ConvertOperandFoldingTest = HloTestBase;
 
-TEST_F(ConvertOperandFoldingTest, UpcastConvertFolded) {
+TEST_F(ConvertOperandFoldingTest, IntegralUpcastConvertFolded) {
   absl::string_view module_string = R"(
 HloModule module
 
@@ -48,6 +48,54 @@ TEST_F(ConvertOperandFoldingTest, UpcastConvertFolded) {
               op::Shape("s16[2,2]{1,0}")));
 }
 
+TEST_F(ConvertOperandFoldingTest, FloatingUpcastConvertFolded) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  p0 = f16[2,3]{1,0} parameter(0)
+  p1 = bf16[3,2]{0,1} parameter(1)
+  c0 = f32[2,3]{1,0} convert(p0)
+  c1 = f32[3,2]{0,1} convert(p1)
+  ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
+  rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool folded,
+                          ConvertOperandFolding().Run(module.get()));
+  EXPECT_TRUE(folded);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              AllOf(op::Dot(op::Parameter(0), op::Parameter(1)),
+                    op::Shape("f32[2,2]{1,0}")));
+}
+
+TEST_F(ConvertOperandFoldingTest, IntegralToFloatingConvertNotFolded) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  p0 = s8[2,3]{1,0} parameter(0)
+  p1 = s16[3,2]{0,1} parameter(1)
+  c0 = f16[2,3]{1,0} convert(p0)
+  c1 = f32[3,2]{0,1} convert(p1)
+  ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
+  rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool folded,
+                          ConvertOperandFolding().Run(module.get()));
+  EXPECT_FALSE(folded);
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      AllOf(
+          op::Dot(
+              AllOf(op::Convert(op::Parameter(0)), op::Shape("f16[2,3]{1,0}")),
+              AllOf(op::Convert(op::Parameter(1)), op::Shape("f32[3,2]{0,1}"))),
+          op::Shape("f32[2,2]{1,0}")));
+}
+
 TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) {
   absl::string_view module_string = R"(
 HloModule module
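Taken together, the two added tests pin down the edge of the generalization: converts that widen purely within floating point (f16 -> f32, bf16 -> f32) now fold into the dot, while converts that cross from integral to floating point (s8 -> f16, s16 -> f32) are left in place, because the rewritten IsUpcastConvert requires ElementIsFloating to agree on both sides.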
@@ -74,32 +122,6 @@ TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) {
               op::Shape("s16[2,2]{1,0}")));
 }
 
-TEST_F(ConvertOperandFoldingTest, LayoutChangingConvertNotFolded) {
-  absl::string_view module_string = R"(
-HloModule module
-
-ENTRY main {
-  p0 = s8[2,3]{1,0} parameter(0)
-  p1 = s16[3,2]{0,1} parameter(1)
-  c0 = s16[2,3]{0,1} convert(p0)
-  c1 = s16[3,2]{1,0} convert(p1)
-  ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
-  rhs_contracting_dims={0}
-})";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseAndReturnVerifiedModule(module_string));
-  TF_ASSERT_OK_AND_ASSIGN(bool folded,
-                          ConvertOperandFolding().Run(module.get()));
-  EXPECT_FALSE(folded);
-  EXPECT_THAT(
-      module->entry_computation()->root_instruction(),
-      AllOf(
-          op::Dot(
-              AllOf(op::Convert(op::Parameter(0)), op::Shape("s16[2,3]{0,1}")),
-              AllOf(op::Convert(op::Parameter(1)), op::Shape("s16[3,2]{1,0}"))),
-          op::Shape("s16[2,2]{1,0}")));
-}
-
 TEST_F(ConvertOperandFoldingTest, OneOperandFolded) {
   absl::string_view module_string = R"(
 HloModule module
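The LayoutChangingConvertNotFolded test is deleted outright rather than updated: the rewritten IsUpcastConvert no longer calls ShapeUtil::EqualIgnoringElementType, so a layout-changing convert is no longer disqualified from folding, and the behavior this test asserted does not hold anymore.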
tensorflow/compiler/xla/service/cpu/BUILD
@@ -169,7 +169,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:comparison_expander",
         "//tensorflow/compiler/xla/service:slice_sinker",
         "//tensorflow/compiler/xla:cpu_function_runtime",
-        "//tensorflow/compiler/xla/service:integral_upcaster",
+        "//tensorflow/compiler/xla/service:operand_upcaster",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:status_macros",
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -102,10 +102,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
-#include "tensorflow/compiler/xla/service/integral_upcaster.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/logistic_expander.h"
 #include "tensorflow/compiler/xla/service/map_inliner.h"
+#include "tensorflow/compiler/xla/service/operand_upcaster.h"
 #include "tensorflow/compiler/xla/service/qr_expander.h"
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
@@ -271,7 +271,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                             /*allow_mixed_precision=*/false);
 
-  pipeline.AddPass<IntegralUpcaster>();
+  pipeline.AddPass<OperandUpcaster>();
 
   // Expand random number generation.
   pipeline.AddPass<RngExpander>();
tensorflow/compiler/xla/service/gpu/BUILD
@@ -1442,10 +1442,10 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_proto_util",
         "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
         "//tensorflow/compiler/xla/service:hlo_verifier",
-        "//tensorflow/compiler/xla/service:integral_upcaster",
         "//tensorflow/compiler/xla/service:llvm_compiler",
         "//tensorflow/compiler/xla/service:logistic_expander",
         "//tensorflow/compiler/xla/service:loop_schedule_linearizer",
+        "//tensorflow/compiler/xla/service:operand_upcaster",
         "//tensorflow/compiler/xla/service:qr_expander",
         "//tensorflow/compiler/xla/service:reshape_mover",
         "//tensorflow/compiler/xla/service:rng_bit_generator_expander",
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -94,10 +94,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
-#include "tensorflow/compiler/xla/service/integral_upcaster.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/logistic_expander.h"
 #include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"
+#include "tensorflow/compiler/xla/service/operand_upcaster.h"
 #include "tensorflow/compiler/xla/service/qr_expander.h"
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
@@ -149,7 +149,7 @@ Status GpuCompiler::OptimizeHloModule(
   pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                             /*allow_mixed_precision=*/false);
 
-  pipeline.AddPass<IntegralUpcaster>();
+  pipeline.AddPass<OperandUpcaster>();
 
   // Expand random number generation.
   pipeline.AddPass<RngExpander>();
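Both backends make the identical one-line swap: OperandUpcaster takes IntegralUpcaster's slot at the same point in each pipeline, right before random-number-generation expansion.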
tensorflow/compiler/xla/service/operand_upcaster.cc (renamed from integral_upcaster.cc)
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/integral_upcaster.h"
+#include "tensorflow/compiler/xla/service/operand_upcaster.h"
 
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 
@@ -35,28 +35,26 @@ StatusOr<absl::optional<Shape>> MaybeInferShape(
           instruction->window(), instruction->convolution_dimension_numbers(),
           /*preferred_element_type=*/absl::nullopt);
     default:
-      return absl::make_optional<Shape>();
+      return absl::optional<Shape>(absl::nullopt);
   }
 }
 
 }  // namespace
 
-bool IntegralUpcaster::InstructionMatchesPattern(HloInstruction* instruction) {
-  if (!ShapeUtil::ElementIsIntegral(instruction->shape())) {
-    return false;
-  }
+bool OperandUpcaster::InstructionMatchesPattern(HloInstruction* instruction) {
   auto status_or_inferred_shape = MaybeInferShape(instruction);
   if (!status_or_inferred_shape.ok() ||
       !status_or_inferred_shape->has_value()) {
     return false;
   }
   const Shape& inferred_shape = status_or_inferred_shape.ValueOrDie().value();
 
-  return inferred_shape.element_type() != instruction->shape().element_type() &&
-         ShapeUtil::CanUpcastIntegral(inferred_shape, instruction->shape());
+  if (inferred_shape.element_type() == instruction->shape().element_type()) {
+    return false;
+  }
+  return ShapeUtil::ElementCanUpcast(inferred_shape, instruction->shape());
 }
 
-StatusOr<HloInstruction*> IntegralUpcaster::ExpandInstruction(
+StatusOr<HloInstruction*> OperandUpcaster::ExpandInstruction(
    HloInstruction* instruction) {
   auto* computation = instruction->parent();
   auto type = instruction->shape().element_type();
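The hunk ends inside ExpandInstruction because only its signature changed; everything past the rename is unchanged context. For orientation, here is a hypothetical sketch of what an expansion of this shape does, reconstructed from the converts the tests expect rather than copied from the file; ShapeUtil::ChangeElementType, HloInstruction::CreateConvert, and ReplaceOperandWithDifferentShape are existing XLA helpers, but this body is an illustration, not the commit's actual code:

// Hypothetical reconstruction of the expansion step, not the file's body.
StatusOr<HloInstruction*> OperandUpcaster::ExpandInstruction(
    HloInstruction* instruction) {
  auto* computation = instruction->parent();
  auto type = instruction->shape().element_type();
  for (int i = 0; i < instruction->operand_count(); ++i) {
    auto* operand = instruction->mutable_operand(i);
    if (operand->shape().element_type() == type) continue;
    // Widen only the element type; dimensions and layout stay the same.
    Shape upcast_shape = ShapeUtil::ChangeElementType(operand->shape(), type);
    auto* convert = computation->AddInstruction(
        HloInstruction::CreateConvert(upcast_shape, operand));
    TF_RETURN_IF_ERROR(
        instruction->ReplaceOperandWithDifferentShape(i, convert));
  }
  // Returning nullptr signals to OpExpanderPass that the instruction was
  // rewritten in place instead of being replaced by a new root.
  return nullptr;
}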
tensorflow/compiler/xla/service/operand_upcaster.h (renamed from integral_upcaster.h)
@@ -21,11 +21,11 @@ limitations under the License.
 
 namespace xla {
 
-// Inserts Convert to integral operands of instructions that allows result
-// accumulation as wider integral types.
-class IntegralUpcaster : public OpExpanderPass {
+// Inserts Convert to operands of instructions that allows result accumulation
+// as wider integral types.
+class OperandUpcaster : public OpExpanderPass {
  public:
-  absl::string_view name() const override { return "integral_upcaster"; }
+  absl::string_view name() const override { return "operand_upcaster"; }
 
  protected:
  bool InstructionMatchesPattern(HloInstruction* instruction) override;
tensorflow/compiler/xla/service/operand_upcaster_test.cc (renamed from integral_upcaster_test.cc)
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/integral_upcaster.h"
+#include "tensorflow/compiler/xla/service/operand_upcaster.h"
 
 #include "absl/strings/substitute.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
@@ -25,7 +25,7 @@ namespace {
 
 namespace op = ::xla::testing::opcode_matchers;
 
-class IntegralUpcasterTest
+class OperandUpcasterTest
     : public HloTestBase,
       public ::testing::WithParamInterface<
          std::tuple<PrimitiveType, PrimitiveType, PrimitiveType>> {};
@@ -35,7 +35,7 @@ bool ShouldUpcast(PrimitiveType operand_type, PrimitiveType result_type) {
          primitive_util::BitWidth(result_type);
 }
 
-TEST_P(IntegralUpcasterTest, ConvertInserted) {
+TEST_P(OperandUpcasterTest, ConvertInserted) {
   PrimitiveType lhs_type, rhs_type, result_type;
   std::tie(lhs_type, rhs_type, result_type) = GetParam();
   absl::string_view module_tmpl = R"(
@@ -53,7 +53,7 @@ TEST_P(IntegralUpcasterTest, ConvertInserted) {
       primitive_util::LowercasePrimitiveTypeName(result_type));
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseAndReturnVerifiedModule(module_string));
-  TF_ASSERT_OK_AND_ASSIGN(bool upcasted, IntegralUpcaster().Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool upcasted, OperandUpcaster().Run(module.get()));
   EXPECT_EQ(upcasted, ShouldUpcast(lhs_type, result_type) ||
                           ShouldUpcast(rhs_type, result_type));
   auto original_lhs = op::Parameter(0);
@@ -80,20 +80,25 @@ TEST_P(IntegralUpcasterTest, ConvertInserted) {
           primitive_util::LowercasePrimitiveTypeName(result_type)))));
 }
 
-INSTANTIATE_TEST_SUITE_P(S16U16, IntegralUpcasterTest,
+INSTANTIATE_TEST_SUITE_P(S16U16, OperandUpcasterTest,
                          ::testing::Values(std::make_tuple(S8, S8, S16),
                                            std::make_tuple(U8, U8, U16)));
 
-INSTANTIATE_TEST_SUITE_P(S32, IntegralUpcasterTest,
+INSTANTIATE_TEST_SUITE_P(S32, OperandUpcasterTest,
                          ::testing::Combine(::testing::Values(S8, S16),
                                             ::testing::Values(S8, S16),
                                             ::testing::Values(S32)));
 
-INSTANTIATE_TEST_SUITE_P(U32, IntegralUpcasterTest,
+INSTANTIATE_TEST_SUITE_P(U32, OperandUpcasterTest,
                          ::testing::Combine(::testing::Values(U8, U16),
                                             ::testing::Values(U8, U16),
                                             ::testing::Values(U32)));
 
+INSTANTIATE_TEST_SUITE_P(F32, OperandUpcasterTest,
+                         ::testing::Combine(::testing::Values(BF16, F16),
+                                            ::testing::Values(BF16, F16),
+                                            ::testing::Values(F32)));
+
 }  // namespace
 
 }  // namespace xla
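The added F32 suite is the behavioral payoff of the rename: every {bf16, f16} x {bf16, f16} operand combination accumulating into f32 should now have converts inserted, mirroring what the S32 and U32 suites already assert for integers.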
tensorflow/compiler/xla/shape_util.cc
@@ -1711,13 +1711,11 @@ Shape ShapeUtil::DeviceShapeToHostShape(Shape s) {
   return s;
 }
 
-/*static*/ bool ShapeUtil::CanUpcastIntegral(const Shape& from,
-                                             const Shape& to) {
-  return ElementIsIntegral(from) && ElementIsIntegral(to) &&
-         primitive_util::BitWidth(from.element_type()) <=
-             primitive_util::BitWidth(to.element_type()) &&
-         CompatibleIgnoringElementType(from, to);
+/*static*/ bool ShapeUtil::ElementCanUpcast(const Shape& from,
+                                            const Shape& to) {
+  return ElementIsFloating(from) == ElementIsFloating(to) &&
+         ElementIsSigned(from) == ElementIsSigned(to) &&
+         HigherPrecisionElementType(from, to) == to.element_type();
 }
 
 }  // namespace xla
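Note the semantic shift in the helper itself. CanUpcastIntegral required both shapes to be integral, compared raw bit widths, and demanded whole-shape compatibility via CompatibleIgnoringElementType. ElementCanUpcast inspects element types only: same numeric class, same signedness, and the target must be the higher-precision type per HigherPrecisionElementType, which, unlike a bit-width test, also orders equal-width types such as bf16 and f16. Shape compatibility is now the caller's responsibility.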
tensorflow/compiler/xla/shape_util.h
@@ -787,9 +787,9 @@ class ShapeUtil {
   // information, from a shape.
   static Shape DeviceShapeToHostShape(Shape s);
 
-  // Returns true iff integral shape `from` can be safely upcasted to integral
-  // shape `to`.
-  static bool CanUpcastIntegral(const Shape& from, const Shape& to);
+  // Returns true iff element type of shape `from` can be safely upcasted to
+  // element type of shape `to`.
+  static bool ElementCanUpcast(const Shape& from, const Shape& to);
 
  private:
  // Fills *shape. Returns true on success.