Adding total-order comparison support in proto and HloInstruction.
Specifically, a comparison-type attribute is added to the HLO proto so that total-order comparison can be explicitly specified. A comparison expander pass is added to all compilers to expand total-order comparisons into equivalent implementations through type conversion. PiperOrigin-RevId: 325820826 Change-Id: I7beceb2f751ddc0be7c6b7a74037e562e7580b62
This commit is contained in:
parent
9cd6a52394
commit
fd87e24980
@ -55,9 +55,13 @@ xla_test(
|
||||
cc_library(
|
||||
name = "comparators",
|
||||
srcs = ["comparators.cc"],
|
||||
hdrs = ["comparators.h"],
|
||||
hdrs = [
|
||||
"comparators.h",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
],
|
||||
deps = [
|
||||
":constants",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
|
@ -32,85 +32,13 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using XlaOpGenerator = XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>);
|
||||
|
||||
XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value,
|
||||
int64 bit_width) {
|
||||
PrimitiveType signed_type;
|
||||
PrimitiveType unsigned_type;
|
||||
XlaOp max_value;
|
||||
switch (bit_width) {
|
||||
case 16:
|
||||
max_value =
|
||||
ConstantR0(value.builder(),
|
||||
static_cast<uint16>(std::numeric_limits<int16>::max()));
|
||||
signed_type = S16;
|
||||
unsigned_type = U16;
|
||||
break;
|
||||
case 32:
|
||||
max_value =
|
||||
ConstantR0(value.builder(),
|
||||
static_cast<uint32>(std::numeric_limits<int32>::max()));
|
||||
signed_type = S32;
|
||||
unsigned_type = U32;
|
||||
break;
|
||||
case 64:
|
||||
max_value =
|
||||
ConstantR0(value.builder(),
|
||||
static_cast<uint64>(std::numeric_limits<int64>::max()));
|
||||
signed_type = S64;
|
||||
unsigned_type = U64;
|
||||
break;
|
||||
default:
|
||||
return value.builder()->ReportError(
|
||||
InvalidArgument("Invalid bit width %lld for Comparator floating "
|
||||
"point parameter.",
|
||||
bit_width));
|
||||
}
|
||||
// Switch from a floating point value to an integer value in such a way that
|
||||
// when using the integer value to compare, we get the same result for normal
|
||||
// values, and -NaN is treated as the smallest value, and NaN is treated as
|
||||
// the largest value.
|
||||
// If f is a float, and
|
||||
// x = bit_cast<int32>(f);
|
||||
// y = x < 0 ? numeric_limits<int32>::max() - x : x;
|
||||
// then y is ordered as an int32 such that finite values have the obvious
|
||||
// order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
|
||||
// and end of the ordering.
|
||||
// Note that in order to keep -x from overflowing, we calculate
|
||||
// numeric_limits<int32>::max() - x as unsigned, and then convert back to
|
||||
// signed.
|
||||
auto signed_value = BitcastConvertType(value, signed_type);
|
||||
auto unsigned_value = BitcastConvertType(value, unsigned_type);
|
||||
auto flipped_value =
|
||||
BitcastConvertType(Sub(max_value, unsigned_value), signed_type);
|
||||
auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type));
|
||||
return Select(is_negative, flipped_value, signed_value);
|
||||
}
|
||||
|
||||
void ConvertFloatingPoint(const PrimitiveType& operand_type, XlaOp* lhs_param,
|
||||
XlaOp* rhs_param) {
|
||||
if (primitive_util::IsFloatingPointType(operand_type)) {
|
||||
PrimitiveType compare_type = operand_type;
|
||||
// Special-case handling for BF16. We currently do not support direct
|
||||
// comparisons with BF16, so we convert to F32 and then use the F32
|
||||
// comparison logic.
|
||||
if (compare_type == BF16) {
|
||||
compare_type = F32;
|
||||
*lhs_param = ConvertElementType(*lhs_param, F32);
|
||||
*rhs_param = ConvertElementType(*rhs_param, F32);
|
||||
}
|
||||
int64 bit_width = primitive_util::BitWidth(compare_type);
|
||||
*lhs_param = BitcastConvertFloatingPointToIntegral(*lhs_param, bit_width);
|
||||
*rhs_param = BitcastConvertFloatingPointToIntegral(*rhs_param, bit_width);
|
||||
}
|
||||
}
|
||||
using XlaCompareOp = XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>);
|
||||
|
||||
XlaComputation CreateScalarComparisonComputation(
|
||||
const string& name, const std::vector<PrimitiveType>& operand_types,
|
||||
XlaBuilder* builder, XlaOpGenerator generator) {
|
||||
XlaBuilder* builder, XlaCompareOp generator) {
|
||||
CHECK_NE(operand_types.size(), 0);
|
||||
std::vector<absl::optional<XlaOpGenerator>> generators(operand_types.size());
|
||||
std::vector<absl::optional<XlaCompareOp>> generators(operand_types.size());
|
||||
generators[0] = generator;
|
||||
return CreateScalarComparisonComputation(name, operand_types, generators,
|
||||
builder);
|
||||
@ -119,7 +47,7 @@ XlaComputation CreateScalarComparisonComputation(
|
||||
|
||||
XlaComputation CreateScalarComparisonComputation(
|
||||
const string& name, const std::vector<PrimitiveType>& operand_types,
|
||||
const std::vector<absl::optional<XlaOpGenerator>>& generators,
|
||||
const std::vector<absl::optional<XlaCompareOp>>& generators,
|
||||
XlaBuilder* builder) {
|
||||
// Create a default computation where we compare only the first two
|
||||
// parameters of type 'operand_types[0]'.
|
||||
@ -146,7 +74,6 @@ XlaComputation CreateScalarComparisonComputation(
|
||||
absl::StrCat("p.", parameter_count, ".lhs"));
|
||||
auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
|
||||
absl::StrCat("p.", parameter_count, ".rhs"));
|
||||
ConvertFloatingPoint(operand_type, &lhs_param, &rhs_param);
|
||||
lhs_params.emplace_back(lhs_param);
|
||||
rhs_params.emplace_back(rhs_param);
|
||||
if (generators[parameter_count].has_value()) {
|
||||
@ -169,7 +96,8 @@ XlaComputation CreateScalarComparisonComputation(
|
||||
generators[i].value()(lhs_params[i], rhs_params[i], {}),
|
||||
result);
|
||||
if (i != last_generator_index) {
|
||||
param_equal = And(param_equal, Eq(lhs_params[i], rhs_params[i]));
|
||||
param_equal =
|
||||
And(param_equal, EqTotalOrder(lhs_params[i], rhs_params[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -181,14 +109,14 @@ XlaComputation CreateScalarComparisonComputation(
|
||||
XlaComputation CreateScalarLtComputation(
|
||||
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
|
||||
return CreateScalarComparisonComputation("compare-less-than", operand_types,
|
||||
builder, Lt);
|
||||
builder, LtTotalOrder);
|
||||
}
|
||||
|
||||
// Creates a scalar greater-than computation and returns it.
|
||||
XlaComputation CreateScalarGtComputation(
|
||||
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
|
||||
return CreateScalarComparisonComputation("compare-greater-than",
|
||||
operand_types, builder, Gt);
|
||||
return CreateScalarComparisonComputation(
|
||||
"compare-greater-than", operand_types, builder, GtTotalOrder);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -43,14 +43,13 @@ XlaComputation CreateScalarGtComputation(
|
||||
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder);
|
||||
|
||||
// Creates a scalar comparison computation and returns it. This function takes
|
||||
// an std::vector<absl::optional<XlaOpGenerator>> and compare the operands
|
||||
// where the generator isn't nullopt with the specified comparator
|
||||
// at that location.
|
||||
// a vector of comparator functions to compare the operands where the function
|
||||
// isn't nullopt with the specified comparator at that location.
|
||||
XlaComputation CreateScalarComparisonComputation(
|
||||
const string& name, const std::vector<PrimitiveType>& operand_types,
|
||||
const std::vector<
|
||||
absl::optional<XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>)>>&
|
||||
generators,
|
||||
comparators,
|
||||
XlaBuilder* builder);
|
||||
|
||||
} // namespace xla
|
||||
|
@ -577,7 +577,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
|
||||
|
||||
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
absl::optional<ComparisonDirection> direction) {
|
||||
absl::optional<ComparisonDirection> direction,
|
||||
absl::optional<Comparison::Type> type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
|
||||
@ -635,7 +636,11 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
||||
return InvalidArgument(
|
||||
"kCompare expects a ComparisonDirection, but none provided.");
|
||||
}
|
||||
return Compare(shape, updated_lhs, updated_rhs, *direction);
|
||||
if (type == absl::nullopt) {
|
||||
return Compare(shape, updated_lhs, updated_rhs, *direction);
|
||||
} else {
|
||||
return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
|
||||
}
|
||||
}
|
||||
|
||||
if (direction.has_value()) {
|
||||
@ -658,8 +663,16 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
|
||||
|
||||
// Builds a kCompare instruction using the default comparison type for the
// element type of `shape`; forwards to the overload that takes an explicit
// Comparison::Type.
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction) {
  const Comparison::Type default_type =
      Comparison::DefaultComparisonType(shape.element_type());
  return Compare(shape, lhs, rhs, direction, default_type);
}
|
||||
|
||||
// Builds a kCompare instruction over `lhs` and `rhs` with result shape
// `shape`. Both the comparison direction and the comparison type are stored on
// the instruction proto in their string forms.
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction,
                                    Comparison::Type type) {
  HloInstructionProto compare_proto;
  *compare_proto.mutable_shape() = shape.ToProto();
  compare_proto.set_comparison_direction(
      ComparisonDirectionToString(direction));
  compare_proto.set_comparison_type(ComparisonTypeToString(type));
  return AddInstruction(std::move(compare_proto), HloOpcode::kCompare,
                        {lhs, rhs});
}
|
||||
@ -3512,31 +3525,71 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
|
||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
|
||||
}
|
||||
|
||||
// Enqueues an equal-to comparison tagged with the total-order comparison type.
// NOTE(review): the total-order type is attached regardless of the operands'
// element type -- confirm non-floating-point operands are handled downstream.
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
                 Comparison::Type::kFloatTotalOrder);
}
|
||||
|
||||
// Enqueues a not-equal comparison using the default comparison type.
XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kNe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
|
||||
|
||||
// Enqueues a not-equal comparison tagged with the total-order comparison type.
// NOTE(review): the total-order type is attached regardless of the operands'
// element type -- confirm non-floating-point operands are handled downstream.
XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
                 Comparison::Type::kFloatTotalOrder);
}
|
||||
|
||||
// Enqueues a greater-or-equal comparison using the default comparison type.
XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kGe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
|
||||
|
||||
// Enqueues a greater-or-equal comparison tagged with the total-order
// comparison type.
// NOTE(review): the total-order type is attached regardless of the operands'
// element type -- confirm non-floating-point operands are handled downstream.
XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
                 Comparison::Type::kFloatTotalOrder);
}
|
||||
|
||||
// Enqueues a greater-than comparison using the default comparison type.
XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kGt;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
|
||||
|
||||
// Enqueues a greater-than comparison tagged with the total-order comparison
// type.
// NOTE(review): the total-order type is attached regardless of the operands'
// element type -- confirm non-floating-point operands are handled downstream.
XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
                 Comparison::Type::kFloatTotalOrder);
}
|
||||
|
||||
// Enqueues a less-or-equal comparison using the default comparison type.
XlaOp Le(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kLe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
|
||||
|
||||
// Enqueues a less-or-equal comparison tagged with the total-order comparison
// type.
// NOTE(review): the total-order type is attached regardless of the operands'
// element type -- confirm non-floating-point operands are handled downstream.
XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
                 Comparison::Type::kFloatTotalOrder);
}
|
||||
// Enqueues a less-than comparison using the default comparison type.
XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kLt;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
|
||||
|
||||
// Enqueues a less-than comparison tagged with the total-order comparison type.
// NOTE(review): the total-order type is attached regardless of the operands'
// element type -- confirm non-floating-point operands are handled downstream.
XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions) {
  const Comparison::Type total_order = Comparison::Type::kFloatTotalOrder;
  return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
                 total_order);
}
|
||||
|
||||
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction) {
|
||||
@ -3544,6 +3597,13 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
|
||||
broadcast_dimensions, direction);
|
||||
}
|
||||
|
||||
// Enqueues a comparison with an explicit comparison type by delegating to the
// builder that owns `lhs`.
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction, Comparison::Type compare_type) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kCompare, lhs, rhs, broadcast_dimensions,
                         direction, compare_type);
}
|
||||
|
||||
// Convenience overload: compares `lhs` and `rhs` in the given direction with
// no broadcast dimensions.
XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
  return Compare(lhs, rhs, /*broadcast_dimensions=*/{}, direction);
}
|
||||
|
@ -792,14 +792,17 @@ class XlaBuilder {
|
||||
// broadcast_dimensions specifies which dimensions to use for broadcasting
|
||||
// when the operation is between tensors of different ranks. The direction and
|
||||
// type are only used if opcode is kCompare.
|
||||
XlaOp BinaryOp(
|
||||
HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
absl::optional<Comparison::Direction> direction = absl::nullopt);
|
||||
XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
absl::optional<ComparisonDirection> direction = absl::nullopt,
|
||||
absl::optional<Comparison::Type> type = absl::nullopt);
|
||||
|
||||
// Internal helper method for binary op compare without broadcast dimensions.
|
||||
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||
Comparison::Direction direction);
|
||||
ComparisonDirection direction);
|
||||
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||
ComparisonDirection direction,
|
||||
Comparison::Type type);
|
||||
|
||||
// Internal helper method that does the building for an arbitrary binary op
|
||||
// with same ranked operands that doesn't broadcast.
|
||||
@ -965,22 +968,13 @@ class XlaBuilder {
|
||||
friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
|
||||
friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
|
||||
friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
|
||||
friend XlaOp Eq(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Ne(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Ge(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Gt(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Lt(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Le(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction);
|
||||
friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
|
||||
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction,
|
||||
Comparison::Type compare_type);
|
||||
friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
|
||||
const PrecisionConfig* precision_config);
|
||||
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
|
||||
@ -1574,29 +1568,44 @@ XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
|
||||
// Enqueues an equal-to comparison instruction onto the computation.
|
||||
XlaOp Eq(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a not-equal comparison instruction onto the computation.
|
||||
XlaOp Ne(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a greater-or-equal comparison instruction onto the computation.
|
||||
XlaOp Ge(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a greater-than comparison instruction onto the computation.
|
||||
XlaOp Gt(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a less-than comparison instruction onto the computation.
|
||||
XlaOp Lt(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a less-or-equal comparison instruction onto the computation.
|
||||
XlaOp Le(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
|
||||
// Enqueues a comparison instruction onto the computation (optionally without
|
||||
// broadcast_dimensions for consistency with others).
|
||||
XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction, Comparison::Type compare_type);
|
||||
XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> broadcast_dimensions,
|
||||
ComparisonDirection direction);
|
||||
|
@ -54,32 +54,59 @@ StatusOr<Comparison::Direction> StringToComparisonDirection(
|
||||
return it->second;
|
||||
}
|
||||
|
||||
Comparison::Comparison(Direction dir, PrimitiveType type) : dir_(dir) {
|
||||
// Parses the upper-case textual name of a comparison type ("FLOAT",
// "TOTALORDER", "SIGNED" or "UNSIGNED") into the matching Comparison::Type.
// Any other name yields an InvalidArgument error.
StatusOr<Comparison::Type> StringToComparisonType(
    absl::string_view compare_type_name) {
  // Name -> enumerator table; allocated on first call and not deleted here.
  static auto* types_by_name =
      new absl::flat_hash_map<string, Comparison::Type>({
          {"FLOAT", Comparison::Type::kFloat},
          {"TOTALORDER", Comparison::Type::kFloatTotalOrder},
          {"SIGNED", Comparison::Type::kSigned},
          {"UNSIGNED", Comparison::Type::kUnsigned},
      });
  auto entry = types_by_name->find(compare_type_name);
  if (entry != types_by_name->end()) {
    return entry->second;
  }
  return InvalidArgument("Unknown comparison type: %s", compare_type_name);
}
|
||||
|
||||
// Returns the upper-case textual name of `type`: "FLOAT", "TOTALORDER",
// "SIGNED" or "UNSIGNED". These are the same names StringToComparisonType
// accepts.
//
// Aborts via LOG(FATAL) if `type` does not hold one of the enumerators handled
// below.
std::string ComparisonTypeToString(Comparison::Type type) {
  switch (type) {
    case Comparison::Type::kFloat:
      return "FLOAT";
    case Comparison::Type::kFloatTotalOrder:
      return "TOTALORDER";
    case Comparison::Type::kSigned:
      return "SIGNED";
    case Comparison::Type::kUnsigned:
      return "UNSIGNED";
  }
  // Only reachable when `type` carries a value outside the enumerators above
  // (e.g. the result of a bad cast). Fail loudly here: flowing off the end of
  // a value-returning function is undefined behavior. No `default:` label is
  // used so the compiler can still warn about unhandled enumerators.
  LOG(FATAL) << "Unknown comparison type: " << static_cast<int>(type);
}
|
||||
|
||||
Comparison::Comparison(Direction dir, PrimitiveType type)
|
||||
: dir_(dir), type_(DefaultComparisonType(type)) {}
|
||||
|
||||
Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) {
|
||||
switch (type) {
|
||||
case S8:
|
||||
case S16:
|
||||
case S32:
|
||||
case S64:
|
||||
type_ = Type::kSigned;
|
||||
break;
|
||||
return Type::kSigned;
|
||||
case PRED:
|
||||
case U8:
|
||||
case U16:
|
||||
case U32:
|
||||
case U64:
|
||||
type_ = Type::kUnsigned;
|
||||
break;
|
||||
return Type::kUnsigned;
|
||||
case F16:
|
||||
case F32:
|
||||
case BF16:
|
||||
case F64:
|
||||
case C64:
|
||||
case C128:
|
||||
type_ = Type::kFloat;
|
||||
break;
|
||||
return Type::kFloat;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported comparison mode."
|
||||
<< ComparisonDirectionToString(dir) << ":"
|
||||
<< PrimitiveType_Name(type) << "\n";
|
||||
}
|
||||
}
|
||||
@ -164,20 +191,6 @@ bool Comparison::IsAntireflexive() const {
|
||||
}
|
||||
}
|
||||
|
||||
/* static */ const char* Comparison::ComparisonTypeToString(
|
||||
Comparison::Type type) {
|
||||
switch (type) {
|
||||
case Type::kFloat:
|
||||
return "f";
|
||||
case Type::kFloatTotalOrder:
|
||||
return "ft";
|
||||
case Type::kSigned:
|
||||
return "s";
|
||||
case Type::kUnsigned:
|
||||
return "u";
|
||||
}
|
||||
}
|
||||
|
||||
std::string Comparison::ToString(std::string prefix1,
|
||||
std::string prefix2) const {
|
||||
return prefix1 + std::string(ComparisonDirectionToString(dir_)) + prefix2 +
|
||||
|
@ -103,11 +103,11 @@ class Comparison {
|
||||
bool Compare(const T a, const T b) const {
|
||||
return GetComparator<T>()(a, b);
|
||||
}
|
||||
static Type DefaultComparisonType(PrimitiveType t);
|
||||
|
||||
private:
|
||||
static Direction Converse(Direction dir);
|
||||
static Direction Inverse(Direction dir);
|
||||
static const char* ComparisonTypeToString(Type type);
|
||||
|
||||
const Direction dir_;
|
||||
Type type_;
|
||||
@ -117,10 +117,14 @@ inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) {
|
||||
return os << cmp.ToString();
|
||||
}
|
||||
string ComparisonDirectionToString(Comparison::Direction direction);
|
||||
std::string ComparisonTypeToString(Comparison::Type type);
|
||||
|
||||
StatusOr<Comparison::Direction> StringToComparisonDirection(
|
||||
absl::string_view direction_name);
|
||||
|
||||
StatusOr<Comparison::Type> StringToComparisonType(
|
||||
absl::string_view compare_type_name);
|
||||
|
||||
using ComparisonDirection = Comparison::Direction;
|
||||
|
||||
} // namespace xla
|
||||
|
@ -1235,7 +1235,10 @@ floating-point types.
|
||||
|
||||
Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
|
||||
(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
|
||||
(less-than).
|
||||
(less-than). Another set of operators, `EqTotalOrder`, `NeTotalOrder`, `GeTotalOrder`,
|
||||
`GtTotalOrder`, `LeTotalOrder`, and `LtTotalOrder`, provide the same functionality,
|
||||
except that they additionally support a total order over the floating point
|
||||
numbers, by enforcing `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
|
||||
|
||||
Arguments | Type | Semantics
|
||||
--------- | ------- | ----------------------------------------
|
||||
|
@ -112,6 +112,21 @@ xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) {
|
||||
}
|
||||
}
|
||||
|
||||
// Maps a bit width of 8, 16, 32 or 64 to the signed integral primitive type of
// that width. Any other width yields PRIMITIVE_TYPE_INVALID.
xla::PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth) {
  if (src_bitwidth == 8) {
    return xla::S8;
  }
  if (src_bitwidth == 16) {
    return xla::S16;
  }
  if (src_bitwidth == 32) {
    return xla::S32;
  }
  if (src_bitwidth == 64) {
    return xla::S64;
  }
  return xla::PRIMITIVE_TYPE_INVALID;
}
|
||||
|
||||
PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
|
||||
switch (complex_type) {
|
||||
case C64:
|
||||
|
@ -153,6 +153,8 @@ int BitWidth(PrimitiveType type);
|
||||
|
||||
PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth);
|
||||
|
||||
PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth);
|
||||
|
||||
// Returns the real, imag component type underlying the given complex type.
|
||||
// LOG(FATAL)'s if complex_type is not complex.
|
||||
PrimitiveType ComplexComponentType(PrimitiveType complex_type);
|
||||
|
@ -1700,7 +1700,10 @@ cc_library(
|
||||
cc_library(
|
||||
name = "hlo_creation_utils",
|
||||
srcs = ["hlo_creation_utils.cc"],
|
||||
hdrs = ["hlo_creation_utils.h"],
|
||||
hdrs = [
|
||||
"hlo_creation_utils.h",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_module_config",
|
||||
@ -1816,6 +1819,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "comparison_expander",
|
||||
srcs = ["comparison_expander.cc"],
|
||||
hdrs = ["comparison_expander.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_creation_utils",
|
||||
":hlo_pass",
|
||||
":op_expander_pass",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client/lib:comparators",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "scatter_expander",
|
||||
srcs = ["scatter_expander.cc"],
|
||||
|
133
tensorflow/compiler/xla/service/comparison_expander.cc
Normal file
133
tensorflow/compiler/xla/service/comparison_expander.cc
Normal file
@ -0,0 +1,133 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Adds instructions to `computation` that re-encode the floating-point
// `value` as a signed integer (of shape `signed_shape`) whose integer ordering
// gives a total order over the original floats.
//
// The encoding, shown for a 32-bit float f:
//   x = bit_cast<int32>(f);
//   y = x < 0 ? numeric_limits<int32>::max() - x : x;
// With this, finite values keep their usual order, -0 sorts before 0, -NaN is
// the smallest value and NaN is the largest.
//
// To keep -x from overflowing, numeric_limits<int32>::max() - x is computed on
// the unsigned reinterpretation of the bits and the result is then bitcast
// back to signed.
//
// `zero` is compared against the signed bits and `max_value` is the minuend of
// the unsigned subtraction, so they are expected to be of `signed_shape` and
// `unsigned_shape` respectively.
HloInstruction* BitcastConvertFloatingPointToIntegral(
    HloComputation* computation, HloInstruction* value,
    const Shape& signed_shape, const Shape& unsigned_shape,
    HloInstruction* zero, HloInstruction* max_value) {
  // The same bits, viewed once as signed and once as unsigned integers.
  HloInstruction* signed_bits = computation->AddInstruction(
      HloInstruction::CreateBitcastConvert(signed_shape, value));
  HloInstruction* unsigned_bits = computation->AddInstruction(
      HloInstruction::CreateBitcastConvert(unsigned_shape, value));

  // max_value - bits, evaluated as unsigned and then reinterpreted as signed.
  HloInstruction* unsigned_difference =
      computation->AddInstruction(HloInstruction::CreateBinary(
          unsigned_shape, HloOpcode::kSubtract, max_value, unsigned_bits));
  HloInstruction* flipped_bits = computation->AddInstruction(
      HloInstruction::CreateBitcastConvert(signed_shape, unsigned_difference));

  // Predicate with the dimensions of `signed_shape`: true where the signed
  // bits are below `zero`.
  Shape predicate_shape = signed_shape;
  predicate_shape.set_element_type(PRED);
  HloInstruction* sign_bit_set =
      computation->AddInstruction(HloInstruction::CreateCompare(
          predicate_shape, signed_bits, zero, ComparisonDirection::kLt));

  // Negative encodings take the flipped form; the rest pass through unchanged.
  return computation->AddInstruction(HloInstruction::CreateTernary(
      signed_shape, HloOpcode::kSelect, sign_bit_set, flipped_bits,
      signed_bits));
}
|
||||
|
||||
// Returns true only for compare instructions whose comparison type is
// kFloatTotalOrder and whose first operand has a floating-point element type.
bool ComparisonExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  auto* compare = dynamic_cast<HloCompareInstruction*>(instruction);
  if (compare == nullptr) {
    return false;
  }
  HloInstruction* first_operand = instruction->operands()[0];
  return compare->type() == Comparison::Type::kFloatTotalOrder &&
         primitive_util::IsFloatingPointType(
             first_operand->shape().element_type());
}
|
||||
|
||||
// Replaces a total-order floating-point compare with a signed-integer compare
// over integer re-encodings of both operands.
//
// `instruction` must be a kCompare whose comparison type is kFloatTotalOrder
// and whose first operand has a floating-point element type; all three
// conditions are CHECKed. BF16 operands are converted to F32 first. The
// returned compare keeps the original direction and result shape and uses
// Comparison::Type::kSigned.
StatusOr<HloInstruction*> ComparisonExpander::ExpandInstruction(
    HloInstruction* instruction) {
  CHECK(instruction->opcode() == HloOpcode::kCompare);
  HloCompareInstruction* compare =
      static_cast<HloCompareInstruction*>(instruction);
  CHECK(compare->type() == Comparison::Type::kFloatTotalOrder);

  HloComputation* computation = instruction->parent();
  HloInstruction* lhs = instruction->operands()[0];
  HloInstruction* rhs = instruction->operands()[1];
  Shape operand_shape = lhs->shape();
  PrimitiveType operand_type = operand_shape.element_type();
  CHECK(primitive_util::IsFloatingPointType(operand_type));

  // BF16 is not compared directly: widen both operands to F32 and run the F32
  // logic on them instead.
  if (operand_type == BF16) {
    operand_type = F32;
    operand_shape.set_element_type(operand_type);
    lhs = computation->AddInstruction(
        HloInstruction::CreateConvert(operand_shape, lhs));
    rhs = computation->AddInstruction(
        HloInstruction::CreateConvert(operand_shape, rhs));
  }

  // Signed and unsigned integer shapes with the operands' dimensions and the
  // operands' bit width.
  const int64 bit_width = primitive_util::BitWidth(operand_type);
  const PrimitiveType signed_type =
      primitive_util::SignedIntegralTypeForBitWidth(bit_width);
  const PrimitiveType unsigned_type =
      primitive_util::UnsignedIntegralTypeForBitWidth(bit_width);
  Shape signed_shape = operand_shape;
  signed_shape.set_element_type(signed_type);
  Shape unsigned_shape = operand_shape;
  unsigned_shape.set_element_type(unsigned_type);

  // Signed zero, broadcast to the operand dimensions.
  HloInstruction* zero_constant = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type)));
  HloInstruction* zero_value =
      computation->AddInstruction(HloInstruction::CreateBroadcast(
          signed_shape, zero_constant, zero_constant->shape().dimensions()));

  // Largest signed value, converted to the unsigned type and then broadcast to
  // the operand dimensions.
  HloInstruction* max_signed = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type)));
  Shape max_shape = max_signed->shape();
  max_shape.set_element_type(unsigned_type);
  HloInstruction* max_unsigned = computation->AddInstruction(
      HloInstruction::CreateConvert(max_shape, max_signed));
  HloInstruction* max_value =
      computation->AddInstruction(HloInstruction::CreateBroadcast(
          unsigned_shape, max_unsigned, max_shape.dimensions()));

  // Re-encode both operands, then compare the encodings as signed integers.
  lhs = BitcastConvertFloatingPointToIntegral(
      computation, lhs, signed_shape, unsigned_shape, zero_value, max_value);
  rhs = BitcastConvertFloatingPointToIntegral(
      computation, rhs, signed_shape, unsigned_shape, zero_value, max_value);
  HloInstruction* signed_compare =
      computation->AddInstruction(HloInstruction::CreateCompare(
          instruction->shape(), lhs, rhs, compare->direction(),
          Comparison::Type::kSigned));
  VLOG(2) << "New comparison instruction for total order:"
          << signed_compare->ToString() << "\n";
  return signed_compare;
}
|
||||
|
||||
} // namespace xla
|
47
tensorflow/compiler/xla/service/comparison_expander.h
Normal file
47
tensorflow/compiler/xla/service/comparison_expander.h
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A pass which performs expansion of the comparison operator to support total
|
||||
// order comparison of floating point numbers.
//
// Built on OpExpanderPass: InstructionMatchesPattern selects the instructions
// to rewrite and ExpandInstruction produces the replacement for each match.
|
||||
class ComparisonExpander : public OpExpanderPass {
|
||||
public:
|
||||
// The class declares no data members, so default construction and
// destruction are sufficient.
explicit ComparisonExpander() = default;
|
||||
~ComparisonExpander() override = default;
|
||||
// Name string reported for this pass.
absl::string_view name() const override { return "comparison-expander"; }
|
||||
|
||||
private:
|
||||
// Returns `true` if `instruction` should be expanded by this pass.
|
||||
bool InstructionMatchesPattern(HloInstruction* instruction) override;
|
||||
// Returns a replacement for `instruction`, or nullptr if no replacement is
|
||||
// needed (e.g. only the to_apply subcomputation of the instruction was
|
||||
// modified).
// A non-OK status in the returned StatusOr signals that the expansion failed.
|
||||
StatusOr<HloInstruction*> ExpandInstruction(
|
||||
HloInstruction* instruction) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_
|
@ -145,6 +145,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:conditional_to_select",
|
||||
"//tensorflow/compiler/xla/service:slow_operation_alarm",
|
||||
"//tensorflow/compiler/xla/service:scatter_expander",
|
||||
"//tensorflow/compiler/xla/service:comparison_expander",
|
||||
"//tensorflow/compiler/xla/service:slice_sinker",
|
||||
"//tensorflow/compiler/xla:cpu_function_runtime",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
|
@ -54,6 +54,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
|
||||
@ -261,6 +262,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
pipeline.AddPass<ConditionalToSelect>();
|
||||
pipeline.AddPass<MapInliner>();
|
||||
|
||||
pipeline.AddPass<ComparisonExpander>();
|
||||
pipeline.AddPass<CholeskyExpander>();
|
||||
pipeline.AddPass<TriangularSolveExpander>();
|
||||
|
||||
|
@ -1168,6 +1168,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:batchnorm_expander",
|
||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||
"//tensorflow/compiler/xla/service:call_inliner",
|
||||
"//tensorflow/compiler/xla/service:comparison_expander",
|
||||
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
||||
"//tensorflow/compiler/xla/service:conditional_simplifier",
|
||||
"//tensorflow/compiler/xla/service:convolution_4d_expander",
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
|
||||
@ -140,6 +141,9 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
pipeline.AddPass<RngExpander>();
|
||||
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
|
||||
|
||||
// Comparison total order expander
|
||||
pipeline.AddPass<ComparisonExpander>();
|
||||
|
||||
// Remove zero-sized HLO from the input so that other passes don't have to
|
||||
// handle it.
|
||||
pipeline.AddPass<ZeroSizedHloElimination>();
|
||||
|
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
// Serialization of HloInstruction.
|
||||
// Next ID: 72
|
||||
// Next ID: 73
|
||||
message HloInstructionProto {
|
||||
reserved 10;
|
||||
reserved "parameter_name";
|
||||
@ -248,6 +248,9 @@ message HloInstructionProto {
|
||||
|
||||
// RNG algorithm used by kRngBitGenerator.
|
||||
xla.RandomAlgorithm rng_algorithm = 70;
|
||||
|
||||
// The comparison type used for kCompare.
|
||||
string comparison_type = 72;
|
||||
}
|
||||
|
||||
// Serialization of HloComputation.
|
||||
|
@ -174,8 +174,19 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
comparison_direction,
|
||||
StringToComparisonDirection(proto.comparison_direction()));
|
||||
}
|
||||
instruction =
|
||||
CreateCompare(shape, operands(0), operands(1), *comparison_direction);
|
||||
auto comparison_type_str = proto.comparison_type();
|
||||
if (!comparison_type_str.empty()) {
|
||||
// If a comparison type is specified, it *must* be valid.
|
||||
TF_ASSIGN_OR_RETURN(auto comparison_type,
|
||||
StringToComparisonType(comparison_type_str));
|
||||
instruction = CreateCompare(shape, operands(0), operands(1),
|
||||
*comparison_direction, comparison_type);
|
||||
} else {
|
||||
// Specifying the comparison type is optional.
|
||||
// The comparison type will be determined by the types of the operands.
|
||||
instruction = CreateCompare(shape, operands(0), operands(1),
|
||||
*comparison_direction);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kTriangularSolve: {
|
||||
@ -926,8 +937,9 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
|
||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||
ComparisonDirection direction) {
|
||||
return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
|
||||
ComparisonDirection direction, absl::optional<Comparison::Type> type) {
|
||||
return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
|
||||
type);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction>
|
||||
|
@ -595,7 +595,8 @@ class HloInstruction {
|
||||
// Creates a compare op, performing the comparison specified in direction.
|
||||
static std::unique_ptr<HloInstruction> CreateCompare(
|
||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||
Comparison::Direction direction);
|
||||
Comparison::Direction direction,
|
||||
absl::optional<Comparison::Type> type = absl::nullopt);
|
||||
|
||||
static std::unique_ptr<HloInstruction> CreateTriangularSolve(
|
||||
const Shape& shape, HloInstruction* a, HloInstruction* b,
|
||||
|
@ -204,12 +204,13 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
|
||||
fft_length_);
|
||||
}
|
||||
|
||||
HloCompareInstruction::HloCompareInstruction(const Shape& shape,
|
||||
HloInstruction* lhs,
|
||||
HloInstruction* rhs,
|
||||
ComparisonDirection direction)
|
||||
HloCompareInstruction::HloCompareInstruction(
|
||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||
ComparisonDirection direction, absl::optional<Comparison::Type> type)
|
||||
: HloInstruction(HloOpcode::kCompare, shape),
|
||||
compare_(direction, lhs->shape().element_type()) {
|
||||
compare_(direction, type ? (*type)
|
||||
: Comparison::DefaultComparisonType(
|
||||
lhs->shape().element_type())) {
|
||||
AppendOperand(lhs);
|
||||
AppendOperand(rhs);
|
||||
}
|
||||
@ -218,12 +219,21 @@ HloInstructionProto HloCompareInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloInstruction::ToProto();
|
||||
proto.set_comparison_direction(
|
||||
ComparisonDirectionToString(compare_.GetDirection()));
|
||||
proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
|
||||
return proto;
|
||||
}
|
||||
|
||||
std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const {
|
||||
return {StrCat("direction=", ComparisonDirectionToString(direction()))};
|
||||
std::vector<string> result;
|
||||
result.push_back(
|
||||
StrCat("direction=", ComparisonDirectionToString(direction())));
|
||||
if (compare_.GetType() !=
|
||||
Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
|
||||
result.push_back(
|
||||
StrCat("type=", ComparisonTypeToString(compare_.GetType())));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool HloCompareInstruction::IdenticalSlowPath(
|
||||
@ -238,8 +248,8 @@ std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* context) const {
|
||||
CHECK_EQ(new_operands.size(), 2);
|
||||
return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
|
||||
new_operands[1], direction());
|
||||
return absl::make_unique<HloCompareInstruction>(
|
||||
shape, new_operands[0], new_operands[1], direction(), type());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -136,8 +136,10 @@ class HloCompareInstruction : public HloInstruction {
|
||||
public:
|
||||
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
|
||||
HloInstruction* rhs,
|
||||
ComparisonDirection direction);
|
||||
ComparisonDirection direction,
|
||||
absl::optional<Comparison::Type> type);
|
||||
ComparisonDirection direction() const { return compare_.GetDirection(); }
|
||||
Comparison::Type type() const { return compare_.GetType(); }
|
||||
HloInstructionProto ToProto() const override;
|
||||
|
||||
private:
|
||||
|
@ -194,6 +194,7 @@ class HloParserImpl : public HloParser {
|
||||
kBracedHloComputationList,
|
||||
kFftType,
|
||||
kComparisonDirection,
|
||||
kComparisonType,
|
||||
kWindow,
|
||||
kConvolutionDimensionNumbers,
|
||||
kSharding,
|
||||
@ -327,6 +328,7 @@ class HloParserImpl : public HloParser {
|
||||
bool ParseOpcode(HloOpcode* result);
|
||||
bool ParseFftType(FftType* result);
|
||||
bool ParseComparisonDirection(ComparisonDirection* result);
|
||||
bool ParseComparisonType(Comparison::Type* result);
|
||||
bool ParseFusionKind(HloInstruction::FusionKind* result);
|
||||
bool ParseRandomDistribution(RandomDistribution* result);
|
||||
bool ParseRandomAlgorithm(RandomAlgorithm* result);
|
||||
@ -1362,14 +1364,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
}
|
||||
case HloOpcode::kCompare: {
|
||||
optional<ComparisonDirection> direction;
|
||||
optional<Comparison::Type> type;
|
||||
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
|
||||
&direction};
|
||||
attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/2) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
|
||||
shape, operands[0], operands[1], *direction));
|
||||
shape, operands[0], operands[1], *direction, type));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCholesky: {
|
||||
@ -3018,6 +3022,14 @@ bool HloParserImpl::ParseAttributeHelper(
|
||||
->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kComparisonType: {
|
||||
Comparison::Type result;
|
||||
if (!ParseComparisonType(&result)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kEnum: {
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return TokenError("expects an enumeration value");
|
||||
@ -4145,6 +4157,21 @@ bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Parses the value of a compare instruction's comparison-type attribute.
//
// The current token must be an identifier accepted by StringToComparisonType.
// On success the parsed value is written to `*result`, the token is consumed
// and true is returned. On failure the problem is reported through TokenError,
// whose result is returned (callers treat it as false); the lexer is not
// advanced and `*result` is left untouched.
bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
  // Use the same verbosity as the sibling Parse* helpers (e.g.
  // ParseFusionKind) so that low-verbosity logs are not filled with one line
  // per parsed attribute.
  VLOG(3) << "ParseComparisonType";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects comparison type");
  }
  std::string val = lexer_.GetStrVal();
  auto status_or_result = StringToComparisonType(val);
  if (!status_or_result.ok()) {
    return TokenError(StrFormat("expects comparison type but sees: %s", val));
  }
  *result = status_or_result.ValueOrDie();
  // Consume the identifier only after it has been validated.
  lexer_.Lex();
  return true;
}
|
||||
|
||||
bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
|
||||
VLOG(3) << "ParseFusionKind";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
|
@ -230,7 +230,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
|
||||
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
|
||||
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
|
||||
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
|
||||
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
|
||||
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, type=TOTALORDER, sharding={replicated}
|
||||
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
|
||||
}
|
||||
|
||||
@ -512,7 +512,7 @@ R"(HloModule R4F32OverlapSmall_module
|
||||
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
||||
%lhs = f32[] parameter(0)
|
||||
%rhs = f32[] parameter(1)
|
||||
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
|
||||
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER
|
||||
}
|
||||
|
||||
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
||||
|
@ -34,6 +34,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:algebraic_simplifier",
|
||||
"//tensorflow/compiler/xla/service:cholesky_expander",
|
||||
"//tensorflow/compiler/xla/service:comparison_expander",
|
||||
"//tensorflow/compiler/xla/service:compiler",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
|
||||
@ -81,6 +82,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
|
||||
|
||||
pipeline.AddPass<DynamicIndexSplitter>();
|
||||
pipeline.AddPass<CholeskyExpander>();
|
||||
pipeline.AddPass<ComparisonExpander>();
|
||||
pipeline.AddPass<TriangularSolveExpander>();
|
||||
pipeline.AddPass<LayoutAssignment>(
|
||||
hlo_module->mutable_entry_computation_layout(),
|
||||
|
@ -1203,6 +1203,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
|
||||
ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
|
||||
}
|
||||
|
||||
// Element-wise total-order equality on f32 vectors. The expectation below
// treats NaN == NaN as true (index 3), while 6.0 == NaN stays false (index 4).
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) {
|
||||
// NOTE(review): the inputs contain NaN; fast-math is presumably disabled so
// that NaN handling is not optimized away - confirm against the test fixture.
SetFastMathDisabled(true);
|
||||
XlaBuilder builder(TestName());
|
||||
auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
|
||||
auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN});
|
||||
EqTotalOrder(lhs, rhs);
|
||||
|
||||
// One expected boolean per input pair, in order.
ComputeAndCompareR1<bool>(&builder, {false, false, true, true, false}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto lhs = ConstantR1<float>(&builder, {});
|
||||
@ -1222,6 +1232,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
|
||||
ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
|
||||
}
|
||||
|
||||
// Element-wise total-order >= on f32 vectors. The expectation below orders
// +NaN above finite values (NaN >= 10 is true, 6 >= NaN is false) and -NaN
// below them (6 >= -NaN is true).
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) {
|
||||
// NOTE(review): the inputs contain NaN; fast-math is presumably disabled so
// that NaN handling is not optimized away - confirm against the test fixture.
SetFastMathDisabled(true);
|
||||
XlaBuilder builder(TestName());
|
||||
auto lhs =
|
||||
ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f, 6.0f});
|
||||
auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN, -NAN});
|
||||
GeTotalOrder(lhs, rhs);
|
||||
|
||||
// One expected boolean per input pair, in order.
ComputeAndCompareR1<bool>(&builder, {false, true, true, true, false, true},
|
||||
{});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
|
||||
SetFastMathDisabled(true);
|
||||
XlaBuilder builder(TestName());
|
||||
|
Loading…
Reference in New Issue
Block a user