Correctly handle S32 and U32 in the AR-CRS combiner
Some of the transforms we are doing are only valid on floating-point types, so we have to condition them on the element type of the operations. PiperOrigin-RevId: 227836047
This commit is contained in:
parent
6965d80c13
commit
8e4ca33c8b
@ -3678,6 +3678,7 @@ cc_library(
|
||||
":pattern_matcher",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -44,11 +45,24 @@ bool MatchesArCrsPattern(HloInstruction* instruction) {
|
||||
if (instruction->user_count() != 1) {
|
||||
return false;
|
||||
}
|
||||
auto opcode = instruction->opcode();
|
||||
return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose ||
|
||||
opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert ||
|
||||
opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract ||
|
||||
opcode == HloOpcode::kMultiply;
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kTranspose:
|
||||
case HloOpcode::kReshape:
|
||||
return true;
|
||||
case HloOpcode::kConvert:
|
||||
// Can be moved across if both input and output are either float or
|
||||
// integer (e.g. S32<->U32 or F32<->BF16)
|
||||
return ShapeUtil::ElementIsFloating(instruction->shape()) ==
|
||||
ShapeUtil::ElementIsFloating(instruction->operand(0)->shape());
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kMultiply:
|
||||
// Only supported for floating point operands.
|
||||
return ShapeUtil::ElementIsFloating(instruction->shape());
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
auto computation_is_addition = [](HloComputation* c) {
|
||||
|
Loading…
Reference in New Issue
Block a user