[XLA] In HloEvaluator, fix an issue where the return type and native type are assumed to be the same for HandleImag and HandleReal, when in fact they should be float and complex64 (or float for HandleReal's case), respectively.
PiperOrigin-RevId: 214548051
This commit is contained in:
parent
7f1d70d97f
commit
6666516f39
@ -496,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleReal(HloInstruction* real) {
|
||||
auto operand = real->operand(0);
|
||||
switch (operand->shape().element_type()) {
|
||||
case BF16: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
|
||||
real, [](bfloat16 elem_operand) { return elem_operand; },
|
||||
GetEvaluatedLiteralFor(operand));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
case C64: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
|
||||
real, [](complex64 elem_operand) { return std::real(elem_operand); },
|
||||
GetEvaluatedLiteralFor(operand));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
case F16: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
|
||||
real, [](Eigen::half elem_operand) { return elem_operand; },
|
||||
GetEvaluatedLiteralFor(operand));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
case F32: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<float, float>(
|
||||
real, [](float elem_operand) { return elem_operand; },
|
||||
GetEvaluatedLiteralFor(operand));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
case F64: {
|
||||
auto result_or = ElementWiseUnaryOpImpl<double, double>(
|
||||
real, [](double elem_operand) { return elem_operand; },
|
||||
GetEvaluatedLiteralFor(operand));
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
|
||||
<< PrimitiveType_Name(operand->shape().element_type());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleImag(HloInstruction* imag) {
|
||||
auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
|
||||
imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
|
||||
GetEvaluatedLiteralFor(imag->operand(0)));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
|
||||
HloOpcode opcode = compare->opcode();
|
||||
auto lhs = compare->operand(0);
|
||||
|
@ -184,6 +184,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||
|
||||
Status HandleSort(HloInstruction* sort) override;
|
||||
|
||||
Status HandleReal(HloInstruction* real) override;
|
||||
|
||||
Status HandleImag(HloInstruction* imag) override;
|
||||
|
||||
Status HandleReduce(HloInstruction* reduce) override;
|
||||
|
||||
// Returns the already-evaluated literal result for the instruction.
|
||||
|
@ -89,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
|
||||
// to this rule, notably:
|
||||
// - HandleCompare and HandleIsFinite: where the resulting literal type is
|
||||
// always boolean.
|
||||
// - HandleImag and HandleReal: where the resulting literal type is always float
|
||||
// and the operand is always complex, or real in the case of HandleReal.
|
||||
// These operations are handled outside of the parent HloEvaluator handlers
|
||||
// instead of from within TypedVisitor.
|
||||
//
|
||||
@ -329,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return HandleFloor<ReturnT>(floor);
|
||||
}
|
||||
|
||||
Status HandleImag(HloInstruction* imag) override {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag],
|
||||
ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) {
|
||||
return std::imag(elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HandleLog(HloInstruction* log) override {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
|
||||
ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
|
||||
@ -684,14 +678,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HandleReal(HloInstruction* real) override {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[real],
|
||||
ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) {
|
||||
return std::real(elem_operand);
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename NativeT, typename std::enable_if<std::is_floating_point<
|
||||
NativeT>::value>::type* = nullptr>
|
||||
Status HandleRemainder(HloInstruction* remainder) {
|
||||
|
Loading…
Reference in New Issue
Block a user