From 8e4ca33c8b5861526bcdc43eb1af6d73e1670e92 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Jan 2019 04:10:23 -0800 Subject: [PATCH] Correctly handle S32 and U32 in the AR-CRS combiner Some of the transforms we are doing are only valid on floating point types so we have to condition them on the element type of the operations. PiperOrigin-RevId: 227836047 --- tensorflow/compiler/xla/service/BUILD | 1 + .../compiler/xla/service/ar_crs_combiner.cc | 24 +++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b84792cfc3b..755c477a121 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3678,6 +3678,7 @@ cc_library( ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc index 47d2c7e3570..4a227d3b5c7 100644 --- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -44,11 +45,24 @@ bool MatchesArCrsPattern(HloInstruction* instruction) { if (instruction->user_count() != 1) { return false; } - auto opcode = instruction->opcode(); - return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose || - opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert || - opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract || - opcode == HloOpcode::kMultiply; + switch (instruction->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + return true; + case HloOpcode::kConvert: + // Can be moved across if both input and output is either float or + // integer (e.g. S32<->U32 or F32<->BF16) + return ShapeUtil::ElementIsFloating(instruction->shape()) == + ShapeUtil::ElementIsFloating(instruction->operand(0)->shape()); + case HloOpcode::kAdd: + case HloOpcode::kSubtract: + case HloOpcode::kMultiply: + // Only supported for floating point operands. + return ShapeUtil::ElementIsFloating(instruction->shape()); + default: + return false; + } }; auto computation_is_addition = [](HloComputation* c) {