Adding total-order comparison support in proto and HloInstruction.
Specifically a comparison type attribute is added to Hlo proto so that total order comparison can be explicitly specified. A comparison expander pass is added to all compilers to expand total order comparison into equivalent implementations through type conversion. PiperOrigin-RevId: 325820826 Change-Id: I7beceb2f751ddc0be7c6b7a74037e562e7580b62
This commit is contained in:
parent
9cd6a52394
commit
fd87e24980
@ -55,9 +55,13 @@ xla_test(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "comparators",
|
name = "comparators",
|
||||||
srcs = ["comparators.cc"],
|
srcs = ["comparators.cc"],
|
||||||
hdrs = ["comparators.h"],
|
hdrs = [
|
||||||
|
"comparators.h",
|
||||||
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":constants",
|
":constants",
|
||||||
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
@ -32,85 +32,13 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using XlaOpGenerator = XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>);
|
using XlaCompareOp = XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>);
|
||||||
|
|
||||||
XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value,
|
|
||||||
int64 bit_width) {
|
|
||||||
PrimitiveType signed_type;
|
|
||||||
PrimitiveType unsigned_type;
|
|
||||||
XlaOp max_value;
|
|
||||||
switch (bit_width) {
|
|
||||||
case 16:
|
|
||||||
max_value =
|
|
||||||
ConstantR0(value.builder(),
|
|
||||||
static_cast<uint16>(std::numeric_limits<int16>::max()));
|
|
||||||
signed_type = S16;
|
|
||||||
unsigned_type = U16;
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
max_value =
|
|
||||||
ConstantR0(value.builder(),
|
|
||||||
static_cast<uint32>(std::numeric_limits<int32>::max()));
|
|
||||||
signed_type = S32;
|
|
||||||
unsigned_type = U32;
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
max_value =
|
|
||||||
ConstantR0(value.builder(),
|
|
||||||
static_cast<uint64>(std::numeric_limits<int64>::max()));
|
|
||||||
signed_type = S64;
|
|
||||||
unsigned_type = U64;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return value.builder()->ReportError(
|
|
||||||
InvalidArgument("Invalid bit width %lld for Comparator floating "
|
|
||||||
"point parameter.",
|
|
||||||
bit_width));
|
|
||||||
}
|
|
||||||
// Switch from a floating point value to a integer value in such a way that
|
|
||||||
// when using the integer value to compare, we get the same result for normal
|
|
||||||
// values, and -Nan is treated as the smallest value, and Nan is treated as
|
|
||||||
// the largest value.
|
|
||||||
// If f is a float, and
|
|
||||||
// x = bit_cast<int32>(f);
|
|
||||||
// y = x < 0 ? numeric_limits<int32>::max() - x : x;
|
|
||||||
// then y is ordered as an int32 such that finite values have the obvious
|
|
||||||
// order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
|
|
||||||
// and end of the ordering.
|
|
||||||
// Note that in order to avoid -x to overflow, we calculate
|
|
||||||
// numeric_limits<int32>::max() - x as unsigned, and then convert back to
|
|
||||||
// signed.
|
|
||||||
auto signed_value = BitcastConvertType(value, signed_type);
|
|
||||||
auto unsigned_value = BitcastConvertType(value, unsigned_type);
|
|
||||||
auto flipped_value =
|
|
||||||
BitcastConvertType(Sub(max_value, unsigned_value), signed_type);
|
|
||||||
auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type));
|
|
||||||
return Select(is_negative, flipped_value, signed_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ConvertFloatingPoint(const PrimitiveType& operand_type, XlaOp* lhs_param,
|
|
||||||
XlaOp* rhs_param) {
|
|
||||||
if (primitive_util::IsFloatingPointType(operand_type)) {
|
|
||||||
PrimitiveType compare_type = operand_type;
|
|
||||||
// Special-case handling for BF16. We currently do not support direct
|
|
||||||
// comparisons with BF16, so we convert to F32 and then use the F32
|
|
||||||
// comparison logic.
|
|
||||||
if (compare_type == BF16) {
|
|
||||||
compare_type = F32;
|
|
||||||
*lhs_param = ConvertElementType(*lhs_param, F32);
|
|
||||||
*rhs_param = ConvertElementType(*rhs_param, F32);
|
|
||||||
}
|
|
||||||
int64 bit_width = primitive_util::BitWidth(compare_type);
|
|
||||||
*lhs_param = BitcastConvertFloatingPointToIntegral(*lhs_param, bit_width);
|
|
||||||
*rhs_param = BitcastConvertFloatingPointToIntegral(*rhs_param, bit_width);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
XlaComputation CreateScalarComparisonComputation(
|
XlaComputation CreateScalarComparisonComputation(
|
||||||
const string& name, const std::vector<PrimitiveType>& operand_types,
|
const string& name, const std::vector<PrimitiveType>& operand_types,
|
||||||
XlaBuilder* builder, XlaOpGenerator generator) {
|
XlaBuilder* builder, XlaCompareOp generator) {
|
||||||
CHECK_NE(operand_types.size(), 0);
|
CHECK_NE(operand_types.size(), 0);
|
||||||
std::vector<absl::optional<XlaOpGenerator>> generators(operand_types.size());
|
std::vector<absl::optional<XlaCompareOp>> generators(operand_types.size());
|
||||||
generators[0] = generator;
|
generators[0] = generator;
|
||||||
return CreateScalarComparisonComputation(name, operand_types, generators,
|
return CreateScalarComparisonComputation(name, operand_types, generators,
|
||||||
builder);
|
builder);
|
||||||
@ -119,7 +47,7 @@ XlaComputation CreateScalarComparisonComputation(
|
|||||||
|
|
||||||
XlaComputation CreateScalarComparisonComputation(
|
XlaComputation CreateScalarComparisonComputation(
|
||||||
const string& name, const std::vector<PrimitiveType>& operand_types,
|
const string& name, const std::vector<PrimitiveType>& operand_types,
|
||||||
const std::vector<absl::optional<XlaOpGenerator>>& generators,
|
const std::vector<absl::optional<XlaCompareOp>>& generators,
|
||||||
XlaBuilder* builder) {
|
XlaBuilder* builder) {
|
||||||
// Create a default computation where we compare only the first two
|
// Create a default computation where we compare only the first two
|
||||||
// parameters of type 'operand_types[0]'.
|
// parameters of type 'operand_types[0]'.
|
||||||
@ -146,7 +74,6 @@ XlaComputation CreateScalarComparisonComputation(
|
|||||||
absl::StrCat("p.", parameter_count, ".lhs"));
|
absl::StrCat("p.", parameter_count, ".lhs"));
|
||||||
auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
|
auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
|
||||||
absl::StrCat("p.", parameter_count, ".rhs"));
|
absl::StrCat("p.", parameter_count, ".rhs"));
|
||||||
ConvertFloatingPoint(operand_type, &lhs_param, &rhs_param);
|
|
||||||
lhs_params.emplace_back(lhs_param);
|
lhs_params.emplace_back(lhs_param);
|
||||||
rhs_params.emplace_back(rhs_param);
|
rhs_params.emplace_back(rhs_param);
|
||||||
if (generators[parameter_count].has_value()) {
|
if (generators[parameter_count].has_value()) {
|
||||||
@ -169,7 +96,8 @@ XlaComputation CreateScalarComparisonComputation(
|
|||||||
generators[i].value()(lhs_params[i], rhs_params[i], {}),
|
generators[i].value()(lhs_params[i], rhs_params[i], {}),
|
||||||
result);
|
result);
|
||||||
if (i != last_generator_index) {
|
if (i != last_generator_index) {
|
||||||
param_equal = And(param_equal, Eq(lhs_params[i], rhs_params[i]));
|
param_equal =
|
||||||
|
And(param_equal, EqTotalOrder(lhs_params[i], rhs_params[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -181,14 +109,14 @@ XlaComputation CreateScalarComparisonComputation(
|
|||||||
XlaComputation CreateScalarLtComputation(
|
XlaComputation CreateScalarLtComputation(
|
||||||
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
|
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
|
||||||
return CreateScalarComparisonComputation("compare-less-than", operand_types,
|
return CreateScalarComparisonComputation("compare-less-than", operand_types,
|
||||||
builder, Lt);
|
builder, LtTotalOrder);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a scalar greater-than computation and returns it.
|
// Creates a scalar greater-than computation and returns it.
|
||||||
XlaComputation CreateScalarGtComputation(
|
XlaComputation CreateScalarGtComputation(
|
||||||
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
|
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
|
||||||
return CreateScalarComparisonComputation("compare-greater-than",
|
return CreateScalarComparisonComputation(
|
||||||
operand_types, builder, Gt);
|
"compare-greater-than", operand_types, builder, GtTotalOrder);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -43,14 +43,13 @@ XlaComputation CreateScalarGtComputation(
|
|||||||
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder);
|
const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder);
|
||||||
|
|
||||||
// Creates a scalar comparison computation and returns it. This function takes
|
// Creates a scalar comparison computation and returns it. This function takes
|
||||||
// an std::vector<absl::optional<XlaOpGenerator>> and compare the operands
|
// a vector of comparator functions to compare the operands where the function
|
||||||
// where the generator isn't nullopt with the specified comparator
|
// isn't nullopt with the specified comparator at that location.
|
||||||
// at that location.
|
|
||||||
XlaComputation CreateScalarComparisonComputation(
|
XlaComputation CreateScalarComparisonComputation(
|
||||||
const string& name, const std::vector<PrimitiveType>& operand_types,
|
const string& name, const std::vector<PrimitiveType>& operand_types,
|
||||||
const std::vector<
|
const std::vector<
|
||||||
absl::optional<XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>)>>&
|
absl::optional<XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64>)>>&
|
||||||
generators,
|
comparators,
|
||||||
XlaBuilder* builder);
|
XlaBuilder* builder);
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -577,7 +577,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
|
|||||||
|
|
||||||
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions,
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
absl::optional<ComparisonDirection> direction) {
|
absl::optional<ComparisonDirection> direction,
|
||||||
|
absl::optional<Comparison::Type> type) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
|
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
|
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
|
||||||
@ -635,7 +636,11 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
|||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"kCompare expects a ComparisonDirection, but none provided.");
|
"kCompare expects a ComparisonDirection, but none provided.");
|
||||||
}
|
}
|
||||||
return Compare(shape, updated_lhs, updated_rhs, *direction);
|
if (type == absl::nullopt) {
|
||||||
|
return Compare(shape, updated_lhs, updated_rhs, *direction);
|
||||||
|
} else {
|
||||||
|
return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (direction.has_value()) {
|
if (direction.has_value()) {
|
||||||
@ -658,8 +663,16 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
|
|||||||
|
|
||||||
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||||
ComparisonDirection direction) {
|
ComparisonDirection direction) {
|
||||||
|
return Compare(shape, lhs, rhs, direction,
|
||||||
|
Comparison::DefaultComparisonType(shape.element_type()));
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||||
|
ComparisonDirection direction,
|
||||||
|
Comparison::Type type) {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
instr.set_comparison_direction(ComparisonDirectionToString(direction));
|
instr.set_comparison_direction(ComparisonDirectionToString(direction));
|
||||||
|
instr.set_comparison_type(ComparisonTypeToString(type));
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
|
return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs});
|
||||||
}
|
}
|
||||||
@ -3512,31 +3525,71 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
|
|||||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
|
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||||
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq,
|
||||||
|
compare_type);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
|
XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
|
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||||
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe,
|
||||||
|
compare_type);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
|
XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
|
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||||
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe,
|
||||||
|
compare_type);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
|
XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
|
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||||
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt,
|
||||||
|
compare_type);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Le(const XlaOp lhs, const XlaOp rhs,
|
XlaOp Le(const XlaOp lhs, const XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
|
auto compare_type = Comparison::Type::kFloatTotalOrder;
|
||||||
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe,
|
||||||
|
compare_type);
|
||||||
|
}
|
||||||
XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
|
XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt,
|
||||||
|
Comparison::Type::kFloatTotalOrder);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
|
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions,
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
ComparisonDirection direction) {
|
ComparisonDirection direction) {
|
||||||
@ -3544,6 +3597,13 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
|
|||||||
broadcast_dimensions, direction);
|
broadcast_dimensions, direction);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
ComparisonDirection direction, Comparison::Type compare_type) {
|
||||||
|
return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
|
||||||
|
broadcast_dimensions, direction, compare_type);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
|
XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
|
||||||
return Compare(lhs, rhs, {}, direction);
|
return Compare(lhs, rhs, {}, direction);
|
||||||
}
|
}
|
||||||
|
@ -792,14 +792,17 @@ class XlaBuilder {
|
|||||||
// broadcast_dimensions specifies which dimensions to use for broadcasting
|
// broadcast_dimensions specifies which dimensions to use for broadcasting
|
||||||
// when the operation is between tensors of different ranks. The direction is
|
// when the operation is between tensors of different ranks. The direction is
|
||||||
// only used if opcode is kCompare.
|
// only used if opcode is kCompare.
|
||||||
XlaOp BinaryOp(
|
XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
||||||
HloOpcode binop, XlaOp lhs, XlaOp rhs,
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
absl::Span<const int64> broadcast_dimensions,
|
absl::optional<ComparisonDirection> direction = absl::nullopt,
|
||||||
absl::optional<Comparison::Direction> direction = absl::nullopt);
|
absl::optional<Comparison::Type> type = absl::nullopt);
|
||||||
|
|
||||||
// Internal helper method for binary op compare without broadcast dimensions.
|
// Internal helper method for binary op compare without broadcast dimensions.
|
||||||
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||||
Comparison::Direction direction);
|
ComparisonDirection direction);
|
||||||
|
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||||
|
ComparisonDirection direction,
|
||||||
|
Comparison::Type type);
|
||||||
|
|
||||||
// Internal helper method that does the building for an arbitrary binary op
|
// Internal helper method that does the building for an arbitrary binary op
|
||||||
// with same ranked operands that doesn't broadcast.
|
// with same ranked operands that doesn't broadcast.
|
||||||
@ -965,22 +968,13 @@ class XlaBuilder {
|
|||||||
friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
|
friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
|
||||||
friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
|
friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
|
||||||
friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
|
friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
|
||||||
friend XlaOp Eq(XlaOp lhs, XlaOp rhs,
|
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
|
||||||
friend XlaOp Ne(XlaOp lhs, XlaOp rhs,
|
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
|
||||||
friend XlaOp Ge(XlaOp lhs, XlaOp rhs,
|
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
|
||||||
friend XlaOp Gt(XlaOp lhs, XlaOp rhs,
|
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
|
||||||
friend XlaOp Lt(XlaOp lhs, XlaOp rhs,
|
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
|
||||||
friend XlaOp Le(XlaOp lhs, XlaOp rhs,
|
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
|
||||||
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions,
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
ComparisonDirection direction);
|
ComparisonDirection direction);
|
||||||
friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
|
friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
ComparisonDirection direction,
|
||||||
|
Comparison::Type compare_type);
|
||||||
friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
|
friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
|
||||||
const PrecisionConfig* precision_config);
|
const PrecisionConfig* precision_config);
|
||||||
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
|
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
|
||||||
@ -1574,29 +1568,44 @@ XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
|
|||||||
// Enqueues an equal-to comparison instruction onto the computation.
|
// Enqueues an equal-to comparison instruction onto the computation.
|
||||||
XlaOp Eq(XlaOp lhs, XlaOp rhs,
|
XlaOp Eq(XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a not-equal comparison instruction onto the computation.
|
// Enqueues a not-equal comparison instruction onto the computation.
|
||||||
XlaOp Ne(XlaOp lhs, XlaOp rhs,
|
XlaOp Ne(XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a greater-or-equal comparison instruction onto the computation.
|
// Enqueues a greater-or-equal comparison instruction onto the computation.
|
||||||
XlaOp Ge(XlaOp lhs, XlaOp rhs,
|
XlaOp Ge(XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a greater-than comparison instruction onto the computation.
|
// Enqueues a greater-than comparison instruction onto the computation.
|
||||||
XlaOp Gt(XlaOp lhs, XlaOp rhs,
|
XlaOp Gt(XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a less-than comparison instruction onto the computation.
|
// Enqueues a less-than comparison instruction onto the computation.
|
||||||
XlaOp Lt(XlaOp lhs, XlaOp rhs,
|
XlaOp Lt(XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a less-or-equal comparison instruction onto the computation.
|
// Enqueues a less-or-equal comparison instruction onto the computation.
|
||||||
XlaOp Le(XlaOp lhs, XlaOp rhs,
|
XlaOp Le(XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
// Enqueues a comparison instruction onto the computation (optionally without
|
// Enqueues a comparison instruction onto the computation (optionally without
|
||||||
// broadcast_dimensions for consistency with others).
|
// broadcast_dimensions for consistency with others).
|
||||||
|
XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
ComparisonDirection direction, Comparison::Type compare_type);
|
||||||
XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
XlaOp Compare(XlaOp lhs, XlaOp rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions,
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
ComparisonDirection direction);
|
ComparisonDirection direction);
|
||||||
|
@ -54,32 +54,59 @@ StatusOr<Comparison::Direction> StringToComparisonDirection(
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
Comparison::Comparison(Direction dir, PrimitiveType type) : dir_(dir) {
|
StatusOr<Comparison::Type> StringToComparisonType(
|
||||||
|
absl::string_view compare_type_name) {
|
||||||
|
static auto* type_map = new absl::flat_hash_map<string, Comparison::Type>({
|
||||||
|
{"FLOAT", Comparison::Type::kFloat},
|
||||||
|
{"TOTALORDER", Comparison::Type::kFloatTotalOrder},
|
||||||
|
{"SIGNED", Comparison::Type::kSigned},
|
||||||
|
{"UNSIGNED", Comparison::Type::kUnsigned},
|
||||||
|
});
|
||||||
|
auto it = type_map->find(compare_type_name);
|
||||||
|
if (it == type_map->end()) {
|
||||||
|
return InvalidArgument("Unknown comparison type: %s", compare_type_name);
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ComparisonTypeToString(Comparison::Type type) {
|
||||||
|
switch (type) {
|
||||||
|
case Comparison::Type::kFloat:
|
||||||
|
return "FLOAT";
|
||||||
|
case Comparison::Type::kFloatTotalOrder:
|
||||||
|
return "TOTALORDER";
|
||||||
|
case Comparison::Type::kSigned:
|
||||||
|
return "SIGNED";
|
||||||
|
case Comparison::Type::kUnsigned:
|
||||||
|
return "UNSIGNED";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Comparison::Comparison(Direction dir, PrimitiveType type)
|
||||||
|
: dir_(dir), type_(DefaultComparisonType(type)) {}
|
||||||
|
|
||||||
|
Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case S8:
|
case S8:
|
||||||
case S16:
|
case S16:
|
||||||
case S32:
|
case S32:
|
||||||
case S64:
|
case S64:
|
||||||
type_ = Type::kSigned;
|
return Type::kSigned;
|
||||||
break;
|
|
||||||
case PRED:
|
case PRED:
|
||||||
case U8:
|
case U8:
|
||||||
case U16:
|
case U16:
|
||||||
case U32:
|
case U32:
|
||||||
case U64:
|
case U64:
|
||||||
type_ = Type::kUnsigned;
|
return Type::kUnsigned;
|
||||||
break;
|
|
||||||
case F16:
|
case F16:
|
||||||
case F32:
|
case F32:
|
||||||
case BF16:
|
case BF16:
|
||||||
case F64:
|
case F64:
|
||||||
case C64:
|
case C64:
|
||||||
case C128:
|
case C128:
|
||||||
type_ = Type::kFloat;
|
return Type::kFloat;
|
||||||
break;
|
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unsupported comparison mode."
|
LOG(FATAL) << "Unsupported comparison mode."
|
||||||
<< ComparisonDirectionToString(dir) << ":"
|
|
||||||
<< PrimitiveType_Name(type) << "\n";
|
<< PrimitiveType_Name(type) << "\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -164,20 +191,6 @@ bool Comparison::IsAntireflexive() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ const char* Comparison::ComparisonTypeToString(
|
|
||||||
Comparison::Type type) {
|
|
||||||
switch (type) {
|
|
||||||
case Type::kFloat:
|
|
||||||
return "f";
|
|
||||||
case Type::kFloatTotalOrder:
|
|
||||||
return "ft";
|
|
||||||
case Type::kSigned:
|
|
||||||
return "s";
|
|
||||||
case Type::kUnsigned:
|
|
||||||
return "u";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string Comparison::ToString(std::string prefix1,
|
std::string Comparison::ToString(std::string prefix1,
|
||||||
std::string prefix2) const {
|
std::string prefix2) const {
|
||||||
return prefix1 + std::string(ComparisonDirectionToString(dir_)) + prefix2 +
|
return prefix1 + std::string(ComparisonDirectionToString(dir_)) + prefix2 +
|
||||||
|
@ -103,11 +103,11 @@ class Comparison {
|
|||||||
bool Compare(const T a, const T b) const {
|
bool Compare(const T a, const T b) const {
|
||||||
return GetComparator<T>()(a, b);
|
return GetComparator<T>()(a, b);
|
||||||
}
|
}
|
||||||
|
static Type DefaultComparisonType(PrimitiveType t);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static Direction Converse(Direction dir);
|
static Direction Converse(Direction dir);
|
||||||
static Direction Inverse(Direction dir);
|
static Direction Inverse(Direction dir);
|
||||||
static const char* ComparisonTypeToString(Type type);
|
|
||||||
|
|
||||||
const Direction dir_;
|
const Direction dir_;
|
||||||
Type type_;
|
Type type_;
|
||||||
@ -117,10 +117,14 @@ inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) {
|
|||||||
return os << cmp.ToString();
|
return os << cmp.ToString();
|
||||||
}
|
}
|
||||||
string ComparisonDirectionToString(Comparison::Direction direction);
|
string ComparisonDirectionToString(Comparison::Direction direction);
|
||||||
|
std::string ComparisonTypeToString(Comparison::Type type);
|
||||||
|
|
||||||
StatusOr<Comparison::Direction> StringToComparisonDirection(
|
StatusOr<Comparison::Direction> StringToComparisonDirection(
|
||||||
absl::string_view direction_name);
|
absl::string_view direction_name);
|
||||||
|
|
||||||
|
StatusOr<Comparison::Type> StringToComparisonType(
|
||||||
|
absl::string_view compare_type_name);
|
||||||
|
|
||||||
using ComparisonDirection = Comparison::Direction;
|
using ComparisonDirection = Comparison::Direction;
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -1235,7 +1235,10 @@ floating-point types.
|
|||||||
|
|
||||||
Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
|
Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
|
||||||
(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
|
(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
|
||||||
(less-than).
|
(less-than). Another set of operators, EqTotalOrder, NeTotalOrder, GeTotalOrder,
|
||||||
|
GtTotalOrder, LeTotalOrder, and LtTotalOrder, provide the same functionalities,
|
||||||
|
except that they additionally support a total order over the floating point
|
||||||
|
numbers, by enforcing -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN.
|
||||||
|
|
||||||
Arguments | Type | Semantics
|
Arguments | Type | Semantics
|
||||||
--------- | ------- | ----------------------------------------
|
--------- | ------- | ----------------------------------------
|
||||||
|
@ -112,6 +112,21 @@ xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xla::PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth) {
|
||||||
|
switch (src_bitwidth) {
|
||||||
|
case 8:
|
||||||
|
return xla::S8;
|
||||||
|
case 16:
|
||||||
|
return xla::S16;
|
||||||
|
case 32:
|
||||||
|
return xla::S32;
|
||||||
|
case 64:
|
||||||
|
return xla::S64;
|
||||||
|
default:
|
||||||
|
return xla::PRIMITIVE_TYPE_INVALID;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
|
PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
|
||||||
switch (complex_type) {
|
switch (complex_type) {
|
||||||
case C64:
|
case C64:
|
||||||
|
@ -153,6 +153,8 @@ int BitWidth(PrimitiveType type);
|
|||||||
|
|
||||||
PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth);
|
PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth);
|
||||||
|
|
||||||
|
PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth);
|
||||||
|
|
||||||
// Returns the real, imag component type underlying the given complex type.
|
// Returns the real, imag component type underlying the given complex type.
|
||||||
// LOG(FATAL)'s if complex_type is not complex.
|
// LOG(FATAL)'s if complex_type is not complex.
|
||||||
PrimitiveType ComplexComponentType(PrimitiveType complex_type);
|
PrimitiveType ComplexComponentType(PrimitiveType complex_type);
|
||||||
|
@ -1700,7 +1700,10 @@ cc_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "hlo_creation_utils",
|
name = "hlo_creation_utils",
|
||||||
srcs = ["hlo_creation_utils.cc"],
|
srcs = ["hlo_creation_utils.cc"],
|
||||||
hdrs = ["hlo_creation_utils.h"],
|
hdrs = [
|
||||||
|
"hlo_creation_utils.h",
|
||||||
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_module_config",
|
":hlo_module_config",
|
||||||
@ -1816,6 +1819,21 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "comparison_expander",
|
||||||
|
srcs = ["comparison_expander.cc"],
|
||||||
|
hdrs = ["comparison_expander.h"],
|
||||||
|
deps = [
|
||||||
|
":hlo",
|
||||||
|
":hlo_creation_utils",
|
||||||
|
":hlo_pass",
|
||||||
|
":op_expander_pass",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla:util",
|
||||||
|
"//tensorflow/compiler/xla/client/lib:comparators",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "scatter_expander",
|
name = "scatter_expander",
|
||||||
srcs = ["scatter_expander.cc"],
|
srcs = ["scatter_expander.cc"],
|
||||||
|
133
tensorflow/compiler/xla/service/comparison_expander.cc
Normal file
133
tensorflow/compiler/xla/service/comparison_expander.cc
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
HloInstruction* BitcastConvertFloatingPointToIntegral(
|
||||||
|
HloComputation* computation, HloInstruction* value,
|
||||||
|
const Shape& signed_shape, const Shape& unsigned_shape,
|
||||||
|
HloInstruction* zero, HloInstruction* max_value) {
|
||||||
|
// Switch from a floating point value to a integer value in such a way that
|
||||||
|
// when using the integer value to compare, we get the same result for normal
|
||||||
|
// values, and -Nan is treated as the smallest value, and Nan is treated as
|
||||||
|
// the largest value.
|
||||||
|
// If f is a float, and
|
||||||
|
// x = bit_cast<int32>(f);
|
||||||
|
// y = x < 0 ? numeric_limits<int32>::max() - x : x;
|
||||||
|
// then y is ordered as an int32 such that finite values have the obvious
|
||||||
|
// order, -0 is ordered before 0, and -NaN and NaN appear at the beginning
|
||||||
|
// and end of the ordering.
|
||||||
|
// Note that in order to avoid -x to overflow, we calculate
|
||||||
|
// numeric_limits<int32>::max() - x as unsigned, and then convert back to
|
||||||
|
// signed.
|
||||||
|
auto signed_value = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateBitcastConvert(signed_shape, value));
|
||||||
|
auto unsigned_value = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateBitcastConvert(unsigned_shape, value));
|
||||||
|
auto flipped_value = computation->AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
unsigned_shape, HloOpcode::kSubtract, max_value, unsigned_value));
|
||||||
|
flipped_value = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateBitcastConvert(signed_shape, flipped_value));
|
||||||
|
auto compare_shape = signed_shape;
|
||||||
|
compare_shape.set_element_type(PRED);
|
||||||
|
auto is_negative = computation->AddInstruction(HloInstruction::CreateCompare(
|
||||||
|
compare_shape, signed_value, zero, ComparisonDirection::kLt));
|
||||||
|
return computation->AddInstruction(
|
||||||
|
HloInstruction::CreateTernary(signed_shape, HloOpcode::kSelect,
|
||||||
|
is_negative, flipped_value, signed_value));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ComparisonExpander::InstructionMatchesPattern(
|
||||||
|
HloInstruction* instruction) {
|
||||||
|
if (HloCompareInstruction* compare =
|
||||||
|
dynamic_cast<HloCompareInstruction*>(instruction)) {
|
||||||
|
HloInstruction* lhs = instruction->operands()[0];
|
||||||
|
if (compare->type() == Comparison::Type::kFloatTotalOrder &&
|
||||||
|
primitive_util::IsFloatingPointType(lhs->shape().element_type())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<HloInstruction*> ComparisonExpander::ExpandInstruction(
|
||||||
|
HloInstruction* instruction) {
|
||||||
|
CHECK(instruction->opcode() == HloOpcode::kCompare);
|
||||||
|
HloCompareInstruction* compare =
|
||||||
|
static_cast<HloCompareInstruction*>(instruction);
|
||||||
|
CHECK(compare->type() == Comparison::Type::kFloatTotalOrder);
|
||||||
|
HloComputation* computation = instruction->parent();
|
||||||
|
HloInstruction* lhs = instruction->operands()[0];
|
||||||
|
HloInstruction* rhs = instruction->operands()[1];
|
||||||
|
Shape compare_shape = lhs->shape();
|
||||||
|
PrimitiveType compare_type = compare_shape.element_type();
|
||||||
|
CHECK(primitive_util::IsFloatingPointType(compare_type));
|
||||||
|
// Special-case handling for BF16. We currently do not support direct
|
||||||
|
// comparisons with BF16, so we convert to F32 and then use the F32
|
||||||
|
// comparison logic.
|
||||||
|
if (compare_type == BF16) {
|
||||||
|
compare_type = F32;
|
||||||
|
compare_shape.set_element_type(compare_type);
|
||||||
|
lhs = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateConvert(compare_shape, lhs));
|
||||||
|
rhs = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateConvert(compare_shape, rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 bit_width = primitive_util::BitWidth(compare_type);
|
||||||
|
PrimitiveType signed_type =
|
||||||
|
primitive_util::SignedIntegralTypeForBitWidth(bit_width);
|
||||||
|
PrimitiveType unsigned_type =
|
||||||
|
primitive_util::UnsignedIntegralTypeForBitWidth(bit_width);
|
||||||
|
auto signed_shape = compare_shape;
|
||||||
|
signed_shape.set_element_type(signed_type);
|
||||||
|
auto unsigned_shape = compare_shape;
|
||||||
|
unsigned_shape.set_element_type(unsigned_type);
|
||||||
|
auto zero_value = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type)));
|
||||||
|
zero_value = computation->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
|
signed_shape, zero_value, zero_value->shape().dimensions()));
|
||||||
|
auto max_signed = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type)));
|
||||||
|
auto max_shape = max_signed->shape();
|
||||||
|
max_shape.set_element_type(unsigned_type);
|
||||||
|
auto max_unsigned = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateConvert(max_shape, max_signed));
|
||||||
|
auto max_value = computation->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
|
unsigned_shape, max_unsigned, max_shape.dimensions()));
|
||||||
|
lhs = BitcastConvertFloatingPointToIntegral(
|
||||||
|
computation, lhs, signed_shape, unsigned_shape, zero_value, max_value);
|
||||||
|
rhs = BitcastConvertFloatingPointToIntegral(
|
||||||
|
computation, rhs, signed_shape, unsigned_shape, zero_value, max_value);
|
||||||
|
auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare(
|
||||||
|
instruction->shape(), lhs, rhs, compare->direction(),
|
||||||
|
Comparison::Type::kSigned));
|
||||||
|
VLOG(2) << "New comparison instruction for total order:"
|
||||||
|
<< new_compare->ToString() << "\n";
|
||||||
|
return new_compare;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
47
tensorflow/compiler/xla/service/comparison_expander.h
Normal file
47
tensorflow/compiler/xla/service/comparison_expander.h
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// A pass which performs expansion of the comparison operator to support total
|
||||||
|
// order comparison of floating point numbers.
|
||||||
|
class ComparisonExpander : public OpExpanderPass {
|
||||||
|
public:
|
||||||
|
explicit ComparisonExpander() = default;
|
||||||
|
~ComparisonExpander() override = default;
|
||||||
|
absl::string_view name() const override { return "comparison-expander"; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Returns `true` if `instruction` should be expanded by this pass.
|
||||||
|
bool InstructionMatchesPattern(HloInstruction* instruction) override;
|
||||||
|
// Returns a replacement for `instruction`, or nullptr if no replacement is
|
||||||
|
// needed (e.g. only the to_apply subcomputation of the instruction was
|
||||||
|
// modified).
|
||||||
|
StatusOr<HloInstruction*> ExpandInstruction(
|
||||||
|
HloInstruction* instruction) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_
|
@ -145,6 +145,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:conditional_to_select",
|
"//tensorflow/compiler/xla/service:conditional_to_select",
|
||||||
"//tensorflow/compiler/xla/service:slow_operation_alarm",
|
"//tensorflow/compiler/xla/service:slow_operation_alarm",
|
||||||
"//tensorflow/compiler/xla/service:scatter_expander",
|
"//tensorflow/compiler/xla/service:scatter_expander",
|
||||||
|
"//tensorflow/compiler/xla/service:comparison_expander",
|
||||||
"//tensorflow/compiler/xla/service:slice_sinker",
|
"//tensorflow/compiler/xla/service:slice_sinker",
|
||||||
"//tensorflow/compiler/xla:cpu_function_runtime",
|
"//tensorflow/compiler/xla:cpu_function_runtime",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
|
@ -54,6 +54,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||||
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
|
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||||
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
|
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
|
||||||
@ -261,6 +262,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
|||||||
pipeline.AddPass<ConditionalToSelect>();
|
pipeline.AddPass<ConditionalToSelect>();
|
||||||
pipeline.AddPass<MapInliner>();
|
pipeline.AddPass<MapInliner>();
|
||||||
|
|
||||||
|
pipeline.AddPass<ComparisonExpander>();
|
||||||
pipeline.AddPass<CholeskyExpander>();
|
pipeline.AddPass<CholeskyExpander>();
|
||||||
pipeline.AddPass<TriangularSolveExpander>();
|
pipeline.AddPass<TriangularSolveExpander>();
|
||||||
|
|
||||||
|
@ -1168,6 +1168,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:batchnorm_expander",
|
"//tensorflow/compiler/xla/service:batchnorm_expander",
|
||||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||||
"//tensorflow/compiler/xla/service:call_inliner",
|
"//tensorflow/compiler/xla/service:call_inliner",
|
||||||
|
"//tensorflow/compiler/xla/service:comparison_expander",
|
||||||
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
"//tensorflow/compiler/xla/service:conditional_canonicalizer",
|
||||||
"//tensorflow/compiler/xla/service:conditional_simplifier",
|
"//tensorflow/compiler/xla/service:conditional_simplifier",
|
||||||
"//tensorflow/compiler/xla/service:convolution_4d_expander",
|
"//tensorflow/compiler/xla/service:convolution_4d_expander",
|
||||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
|
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||||
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
|
||||||
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
|
||||||
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
|
#include "tensorflow/compiler/xla/service/convolution_4d_expander.h"
|
||||||
@ -140,6 +141,9 @@ Status GpuCompiler::OptimizeHloModule(
|
|||||||
pipeline.AddPass<RngExpander>();
|
pipeline.AddPass<RngExpander>();
|
||||||
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
|
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
|
||||||
|
|
||||||
|
// Comparison total order expander
|
||||||
|
pipeline.AddPass<ComparisonExpander>();
|
||||||
|
|
||||||
// Remove zero-sized HLO from the input so that other passes don't have to
|
// Remove zero-sized HLO from the input so that other passes don't have to
|
||||||
// handle it.
|
// handle it.
|
||||||
pipeline.AddPass<ZeroSizedHloElimination>();
|
pipeline.AddPass<ZeroSizedHloElimination>();
|
||||||
|
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
|||||||
option cc_enable_arenas = true;
|
option cc_enable_arenas = true;
|
||||||
|
|
||||||
// Serialization of HloInstruction.
|
// Serialization of HloInstruction.
|
||||||
// Next ID: 72
|
// Next ID: 73
|
||||||
message HloInstructionProto {
|
message HloInstructionProto {
|
||||||
reserved 10;
|
reserved 10;
|
||||||
reserved "parameter_name";
|
reserved "parameter_name";
|
||||||
@ -248,6 +248,9 @@ message HloInstructionProto {
|
|||||||
|
|
||||||
// RNG algorithm used by kRngBitGenerator.
|
// RNG algorithm used by kRngBitGenerator.
|
||||||
xla.RandomAlgorithm rng_algorithm = 70;
|
xla.RandomAlgorithm rng_algorithm = 70;
|
||||||
|
|
||||||
|
// The comparison type used for kCompare.
|
||||||
|
string comparison_type = 72;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialization of HloComputation.
|
// Serialization of HloComputation.
|
||||||
|
@ -174,8 +174,19 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
comparison_direction,
|
comparison_direction,
|
||||||
StringToComparisonDirection(proto.comparison_direction()));
|
StringToComparisonDirection(proto.comparison_direction()));
|
||||||
}
|
}
|
||||||
instruction =
|
auto comparison_type_str = proto.comparison_type();
|
||||||
CreateCompare(shape, operands(0), operands(1), *comparison_direction);
|
if (!comparison_type_str.empty()) {
|
||||||
|
// If a comparison type is specified, it *must* be valid.
|
||||||
|
TF_ASSIGN_OR_RETURN(auto comparison_type,
|
||||||
|
StringToComparisonType(comparison_type_str));
|
||||||
|
instruction = CreateCompare(shape, operands(0), operands(1),
|
||||||
|
*comparison_direction, comparison_type);
|
||||||
|
} else {
|
||||||
|
// Allow the specify of comparison type to be optional.
|
||||||
|
// The comparison type will be determined by the types of the operands.
|
||||||
|
instruction = CreateCompare(shape, operands(0), operands(1),
|
||||||
|
*comparison_direction);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case HloOpcode::kTriangularSolve: {
|
case HloOpcode::kTriangularSolve: {
|
||||||
@ -926,8 +937,9 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
|
|||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
|
||||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||||
ComparisonDirection direction) {
|
ComparisonDirection direction, absl::optional<Comparison::Type> type) {
|
||||||
return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
|
return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
|
||||||
|
type);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction>
|
/* static */ std::unique_ptr<HloInstruction>
|
||||||
|
@ -595,7 +595,8 @@ class HloInstruction {
|
|||||||
// Creates a compare op, performing the comparison specified in direction.
|
// Creates a compare op, performing the comparison specified in direction.
|
||||||
static std::unique_ptr<HloInstruction> CreateCompare(
|
static std::unique_ptr<HloInstruction> CreateCompare(
|
||||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||||
Comparison::Direction direction);
|
Comparison::Direction direction,
|
||||||
|
absl::optional<Comparison::Type> type = absl::nullopt);
|
||||||
|
|
||||||
static std::unique_ptr<HloInstruction> CreateTriangularSolve(
|
static std::unique_ptr<HloInstruction> CreateTriangularSolve(
|
||||||
const Shape& shape, HloInstruction* a, HloInstruction* b,
|
const Shape& shape, HloInstruction* a, HloInstruction* b,
|
||||||
|
@ -204,12 +204,13 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
|
|||||||
fft_length_);
|
fft_length_);
|
||||||
}
|
}
|
||||||
|
|
||||||
HloCompareInstruction::HloCompareInstruction(const Shape& shape,
|
HloCompareInstruction::HloCompareInstruction(
|
||||||
HloInstruction* lhs,
|
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||||
HloInstruction* rhs,
|
ComparisonDirection direction, absl::optional<Comparison::Type> type)
|
||||||
ComparisonDirection direction)
|
|
||||||
: HloInstruction(HloOpcode::kCompare, shape),
|
: HloInstruction(HloOpcode::kCompare, shape),
|
||||||
compare_(direction, lhs->shape().element_type()) {
|
compare_(direction, type ? (*type)
|
||||||
|
: Comparison::DefaultComparisonType(
|
||||||
|
lhs->shape().element_type())) {
|
||||||
AppendOperand(lhs);
|
AppendOperand(lhs);
|
||||||
AppendOperand(rhs);
|
AppendOperand(rhs);
|
||||||
}
|
}
|
||||||
@ -218,12 +219,21 @@ HloInstructionProto HloCompareInstruction::ToProto() const {
|
|||||||
HloInstructionProto proto = HloInstruction::ToProto();
|
HloInstructionProto proto = HloInstruction::ToProto();
|
||||||
proto.set_comparison_direction(
|
proto.set_comparison_direction(
|
||||||
ComparisonDirectionToString(compare_.GetDirection()));
|
ComparisonDirectionToString(compare_.GetDirection()));
|
||||||
|
proto.set_comparison_type(ComparisonTypeToString(compare_.GetType()));
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
|
std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
|
||||||
const HloPrintOptions& options) const {
|
const HloPrintOptions& options) const {
|
||||||
return {StrCat("direction=", ComparisonDirectionToString(direction()))};
|
std::vector<string> result;
|
||||||
|
result.push_back(
|
||||||
|
StrCat("direction=", ComparisonDirectionToString(direction())));
|
||||||
|
if (compare_.GetType() !=
|
||||||
|
Comparison::DefaultComparisonType(operand(0)->shape().element_type())) {
|
||||||
|
result.push_back(
|
||||||
|
StrCat("type=", ComparisonTypeToString(compare_.GetType())));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloCompareInstruction::IdenticalSlowPath(
|
bool HloCompareInstruction::IdenticalSlowPath(
|
||||||
@ -238,8 +248,8 @@ std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
|
|||||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||||
HloCloneContext* context) const {
|
HloCloneContext* context) const {
|
||||||
CHECK_EQ(new_operands.size(), 2);
|
CHECK_EQ(new_operands.size(), 2);
|
||||||
return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
|
return absl::make_unique<HloCompareInstruction>(
|
||||||
new_operands[1], direction());
|
shape, new_operands[0], new_operands[1], direction(), type());
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -136,8 +136,10 @@ class HloCompareInstruction : public HloInstruction {
|
|||||||
public:
|
public:
|
||||||
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
|
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
|
||||||
HloInstruction* rhs,
|
HloInstruction* rhs,
|
||||||
ComparisonDirection direction);
|
ComparisonDirection direction,
|
||||||
|
absl::optional<Comparison::Type> type);
|
||||||
ComparisonDirection direction() const { return compare_.GetDirection(); }
|
ComparisonDirection direction() const { return compare_.GetDirection(); }
|
||||||
|
Comparison::Type type() const { return compare_.GetType(); }
|
||||||
HloInstructionProto ToProto() const override;
|
HloInstructionProto ToProto() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -194,6 +194,7 @@ class HloParserImpl : public HloParser {
|
|||||||
kBracedHloComputationList,
|
kBracedHloComputationList,
|
||||||
kFftType,
|
kFftType,
|
||||||
kComparisonDirection,
|
kComparisonDirection,
|
||||||
|
kComparisonType,
|
||||||
kWindow,
|
kWindow,
|
||||||
kConvolutionDimensionNumbers,
|
kConvolutionDimensionNumbers,
|
||||||
kSharding,
|
kSharding,
|
||||||
@ -327,6 +328,7 @@ class HloParserImpl : public HloParser {
|
|||||||
bool ParseOpcode(HloOpcode* result);
|
bool ParseOpcode(HloOpcode* result);
|
||||||
bool ParseFftType(FftType* result);
|
bool ParseFftType(FftType* result);
|
||||||
bool ParseComparisonDirection(ComparisonDirection* result);
|
bool ParseComparisonDirection(ComparisonDirection* result);
|
||||||
|
bool ParseComparisonType(Comparison::Type* result);
|
||||||
bool ParseFusionKind(HloInstruction::FusionKind* result);
|
bool ParseFusionKind(HloInstruction::FusionKind* result);
|
||||||
bool ParseRandomDistribution(RandomDistribution* result);
|
bool ParseRandomDistribution(RandomDistribution* result);
|
||||||
bool ParseRandomAlgorithm(RandomAlgorithm* result);
|
bool ParseRandomAlgorithm(RandomAlgorithm* result);
|
||||||
@ -1362,14 +1364,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
}
|
}
|
||||||
case HloOpcode::kCompare: {
|
case HloOpcode::kCompare: {
|
||||||
optional<ComparisonDirection> direction;
|
optional<ComparisonDirection> direction;
|
||||||
|
optional<Comparison::Type> type;
|
||||||
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
|
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
|
||||||
&direction};
|
&direction};
|
||||||
|
attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type};
|
||||||
if (!ParseOperands(&operands, /*expected_size=*/2) ||
|
if (!ParseOperands(&operands, /*expected_size=*/2) ||
|
||||||
!ParseAttributes(attrs)) {
|
!ParseAttributes(attrs)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
|
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
|
||||||
shape, operands[0], operands[1], *direction));
|
shape, operands[0], operands[1], *direction, type));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case HloOpcode::kCholesky: {
|
case HloOpcode::kCholesky: {
|
||||||
@ -3018,6 +3022,14 @@ bool HloParserImpl::ParseAttributeHelper(
|
|||||||
->emplace(result);
|
->emplace(result);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
case AttrTy::kComparisonType: {
|
||||||
|
Comparison::Type result;
|
||||||
|
if (!ParseComparisonType(&result)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
static_cast<optional<Comparison::Type>*>(attr_out_ptr)->emplace(result);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
case AttrTy::kEnum: {
|
case AttrTy::kEnum: {
|
||||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||||
return TokenError("expects an enumeration value");
|
return TokenError("expects an enumeration value");
|
||||||
@ -4145,6 +4157,21 @@ bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloParserImpl::ParseComparisonType(Comparison::Type* result) {
|
||||||
|
VLOG(1) << "ParseComparisonType";
|
||||||
|
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||||
|
return TokenError("expects comparison type");
|
||||||
|
}
|
||||||
|
std::string val = lexer_.GetStrVal();
|
||||||
|
auto status_or_result = StringToComparisonType(val);
|
||||||
|
if (!status_or_result.ok()) {
|
||||||
|
return TokenError(StrFormat("expects comparison type but sees: %s", val));
|
||||||
|
}
|
||||||
|
*result = status_or_result.ValueOrDie();
|
||||||
|
lexer_.Lex();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
|
bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) {
|
||||||
VLOG(3) << "ParseFusionKind";
|
VLOG(3) << "ParseFusionKind";
|
||||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||||
|
@ -230,7 +230,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
|
|||||||
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
|
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
|
||||||
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
|
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
|
||||||
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
|
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
|
||||||
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
|
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, type=TOTALORDER, sharding={replicated}
|
||||||
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
|
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -512,7 +512,7 @@ R"(HloModule R4F32OverlapSmall_module
|
|||||||
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
||||||
%lhs = f32[] parameter(0)
|
%lhs = f32[] parameter(0)
|
||||||
%rhs = f32[] parameter(1)
|
%rhs = f32[] parameter(1)
|
||||||
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
|
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER
|
||||||
}
|
}
|
||||||
|
|
||||||
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
||||||
|
@ -34,6 +34,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla/service:algebraic_simplifier",
|
"//tensorflow/compiler/xla/service:algebraic_simplifier",
|
||||||
"//tensorflow/compiler/xla/service:cholesky_expander",
|
"//tensorflow/compiler/xla/service:cholesky_expander",
|
||||||
|
"//tensorflow/compiler/xla/service:comparison_expander",
|
||||||
"//tensorflow/compiler/xla/service:compiler",
|
"//tensorflow/compiler/xla/service:compiler",
|
||||||
"//tensorflow/compiler/xla/service:computation_placer",
|
"//tensorflow/compiler/xla/service:computation_placer",
|
||||||
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
"//tensorflow/compiler/xla/service:custom_call_target_registry",
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
|
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
|
||||||
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
|
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/comparison_expander.h"
|
||||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||||
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
|
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
|
||||||
@ -81,6 +82,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
|
|||||||
|
|
||||||
pipeline.AddPass<DynamicIndexSplitter>();
|
pipeline.AddPass<DynamicIndexSplitter>();
|
||||||
pipeline.AddPass<CholeskyExpander>();
|
pipeline.AddPass<CholeskyExpander>();
|
||||||
|
pipeline.AddPass<ComparisonExpander>();
|
||||||
pipeline.AddPass<TriangularSolveExpander>();
|
pipeline.AddPass<TriangularSolveExpander>();
|
||||||
pipeline.AddPass<LayoutAssignment>(
|
pipeline.AddPass<LayoutAssignment>(
|
||||||
hlo_module->mutable_entry_computation_layout(),
|
hlo_module->mutable_entry_computation_layout(),
|
||||||
|
@ -1203,6 +1203,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
|
|||||||
ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
|
ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) {
|
||||||
|
SetFastMathDisabled(true);
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
|
||||||
|
auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN});
|
||||||
|
EqTotalOrder(lhs, rhs);
|
||||||
|
|
||||||
|
ComputeAndCompareR1<bool>(&builder, {false, false, true, true, false}, {});
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
|
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
auto lhs = ConstantR1<float>(&builder, {});
|
auto lhs = ConstantR1<float>(&builder, {});
|
||||||
@ -1222,6 +1232,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
|
|||||||
ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
|
ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) {
|
||||||
|
SetFastMathDisabled(true);
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
auto lhs =
|
||||||
|
ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f, 6.0f});
|
||||||
|
auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN, -NAN});
|
||||||
|
GeTotalOrder(lhs, rhs);
|
||||||
|
|
||||||
|
ComputeAndCompareR1<bool>(&builder, {false, true, true, true, false, true},
|
||||||
|
{});
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
|
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
|
||||||
SetFastMathDisabled(true);
|
SetFastMathDisabled(true);
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
|
Loading…
Reference in New Issue
Block a user