Correctly handle S32 and U32 in the AR-CRS combiner

Some of the transforms we are doing are only valid on floating point types,
so we have to condition them on the element type of the operations.

PiperOrigin-RevId: 227836047
This commit is contained in:
A. Unique TensorFlower 2019-01-04 04:10:23 -08:00 committed by TensorFlower Gardener
parent 6965d80c13
commit 8e4ca33c8b
2 changed files with 20 additions and 5 deletions

View File

@ -3678,6 +3678,7 @@ cc_library(
":pattern_matcher", ":pattern_matcher",
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:types",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
@ -44,11 +45,24 @@ bool MatchesArCrsPattern(HloInstruction* instruction) {
if (instruction->user_count() != 1) { if (instruction->user_count() != 1) {
return false; return false;
} }
auto opcode = instruction->opcode(); switch (instruction->opcode()) {
return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose || case HloOpcode::kBitcast:
opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert || case HloOpcode::kTranspose:
opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract || case HloOpcode::kReshape:
opcode == HloOpcode::kMultiply; return true;
case HloOpcode::kConvert:
// Can be moved across if the input and output are both floating
// point or both integer (e.g. S32<->U32 or F32<->BF16).
return ShapeUtil::ElementIsFloating(instruction->shape()) ==
ShapeUtil::ElementIsFloating(instruction->operand(0)->shape());
case HloOpcode::kAdd:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply:
// Only supported for floating point operands.
return ShapeUtil::ElementIsFloating(instruction->shape());
default:
return false;
}
}; };
auto computation_is_addition = [](HloComputation* c) { auto computation_is_addition = [](HloComputation* c) {