[XLA] Replace individual comparison HLO ops with a single compare op.

PiperOrigin-RevId: 237318727
This commit is contained in:
A. Unique TensorFlower 2019-03-07 13:59:10 -08:00 committed by TensorFlower Gardener
parent add7a1a911
commit 57467ada28
78 changed files with 1041 additions and 537 deletions

View File

@ -57,6 +57,24 @@ xla_proto_library(
],
)
# ComparisonDirection enum plus its string conversion helpers
# (comparison_util.h / comparison_util.cc).
cc_library(
name = "comparison_util",
srcs = [
"comparison_util.cc",
],
hdrs = [
"comparison_util.h",
],
visibility = [":friends"],
deps = [
":statusor",
":types",
":util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
name = "execution_options_util",
srcs = [

View File

@ -212,6 +212,7 @@ cc_library(
":padding",
":sharding_builder",
":xla_computation",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",

View File

@ -480,7 +480,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
}
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
absl::Span<const int64> broadcast_dimensions,
absl::optional<ComparisonDirection> direction) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@ -489,6 +490,17 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
ShapeInference::InferBinaryOpShape(
binop, lhs_shape, rhs_shape, broadcast_dimensions));
*instr.mutable_shape() = shape.ToProto();
if (binop == HloOpcode::kCompare) {
if (!direction.has_value()) {
return InvalidArgument(
"kCompare expects a ComparisonDirection, but none provided.");
}
instr.set_comparison_direction(ComparisonDirectionToString(*direction));
} else if (direction.has_value()) {
return InvalidArgument(
"A comparison direction is provided for a non-compare opcode: %s.",
HloOpcodeString(binop));
}
const int64 lhs_rank = lhs_shape.rank();
const int64 rhs_rank = rhs_shape.rank();
@ -2908,38 +2920,39 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->BinaryOp(HloOpcode::kEq, lhs, rhs,
broadcast_dimensions);
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
}
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->BinaryOp(HloOpcode::kNe, lhs, rhs,
broadcast_dimensions);
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
}
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->BinaryOp(HloOpcode::kGe, lhs, rhs,
broadcast_dimensions);
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
}
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->BinaryOp(HloOpcode::kGt, lhs, rhs,
broadcast_dimensions);
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
}
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->BinaryOp(HloOpcode::kLe, lhs, rhs,
broadcast_dimensions);
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
}
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->BinaryOp(HloOpcode::kLt, lhs, rhs,
broadcast_dimensions);
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
}
// Enqueues a kCompare binary instruction on the builder that owns `lhs`.
// `direction` is forwarded to BinaryOp together with the operands and the
// broadcast dimensions.
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction) {
  auto* owning_builder = lhs.builder();
  return owning_builder->BinaryOp(HloOpcode::kCompare, lhs, rhs,
                                  broadcast_dimensions, direction);
}
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
@ -596,9 +597,11 @@ class XlaBuilder {
// Internal helper method that does the building for an arbitrary binary op.
// broadcast_dimensions specifies which dimensions to use for broadcasting
// when the operation is between tensors of different ranks.
// when the operation is between tensors of different ranks. The direction is
// only used if opcode is kCompare.
XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions);
absl::Span<const int64> broadcast_dimensions,
absl::optional<ComparisonDirection> direction = absl::nullopt);
// Internal helper method that does the building for an arbitrary ternary op.
XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
@ -767,6 +770,9 @@ class XlaBuilder {
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction);
friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
const PrecisionConfig* precision_config);
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
@ -1279,6 +1285,11 @@ XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {});
// Enqueues a comparison instruction onto the computation. `direction` selects
// which comparison is performed. `broadcast_dimensions` specifies which
// dimensions to use for broadcasting when the operands have different ranks;
// unlike the Eq/Ne/... helpers it has no default here and must be passed
// explicitly.
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions,
ComparisonDirection direction);
// Enqueues a dot instruction onto the computation.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
const PrecisionConfig* precision_config = nullptr);

View File

@ -0,0 +1,57 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/comparison_util.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
// Returns the textual form of `direction`: "EQ", "NE", "GE", "GT", "LE" or
// "LT". These are the same names StringToComparisonDirection accepts.
//
// If `direction` holds a value that is not one of the declared enumerators
// (for example, after a cast from an arbitrary integer), "UNKNOWN" is
// returned. That name is not accepted by StringToComparisonDirection, so a
// corrupted value surfaces as a parse error instead of undefined behaviour.
std::string ComparisonDirectionToString(ComparisonDirection direction) {
  // Deliberately no `default:` label, so the compiler can still warn when a
  // new enumerator is added without a matching case.
  switch (direction) {
    case ComparisonDirection::kEq:
      return "EQ";
    case ComparisonDirection::kNe:
      return "NE";
    case ComparisonDirection::kGe:
      return "GE";
    case ComparisonDirection::kGt:
      return "GT";
    case ComparisonDirection::kLe:
      return "LE";
    case ComparisonDirection::kLt:
      return "LT";
  }
  // Reached only for out-of-range values; without this return, control would
  // flow off the end of a non-void function.
  return "UNKNOWN";
}
// Converts a direction name ("EQ", "NE", "GE", "GT", "LE", "LT") into the
// corresponding ComparisonDirection. Any other name yields an InvalidArgument
// error that includes the offending text.
StatusOr<ComparisonDirection> StringToComparisonDirection(
    absl::string_view direction_name) {
  // Lookup table built once, on the first call; it is never deleted here.
  static const auto* const name_to_direction =
      new absl::flat_hash_map<string, ComparisonDirection>({
          {"EQ", ComparisonDirection::kEq},
          {"NE", ComparisonDirection::kNe},
          {"GE", ComparisonDirection::kGe},
          {"GT", ComparisonDirection::kGt},
          {"LE", ComparisonDirection::kLe},
          {"LT", ComparisonDirection::kLt},
      });
  const auto match = name_to_direction->find(direction_name);
  if (match != name_to_direction->end()) {
    return match->second;
  }
  return InvalidArgument("Unknown comparison direction: %s", direction_name);
}
} // namespace xla

View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
// Represents the different comparison operations a compare op can perform.
// The trailing comment on each enumerator is its textual form, as produced by
// ComparisonDirectionToString and accepted by StringToComparisonDirection.
enum class ComparisonDirection : uint8 {
kEq,  // "EQ"
kNe,  // "NE"
kGe,  // "GE"
kGt,  // "GT"
kLe,  // "LE"
kLt,  // "LT"
};
// Returns the two-letter upper-case name of `direction` ("EQ", "NE", "GE",
// "GT", "LE" or "LT").
string ComparisonDirectionToString(ComparisonDirection direction);
// Parses a name produced by ComparisonDirectionToString back into the enum.
// The name must match one of the six names exactly; anything else returns an
// InvalidArgument error.
StatusOr<ComparisonDirection> StringToComparisonDirection(
absl::string_view direction_name);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_

View File

@ -334,6 +334,7 @@ cc_library(
":hlo_proto",
":name_uniquer",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
@ -3301,6 +3302,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",

View File

@ -873,9 +873,9 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
computation, a->shape().element_type(), a->shape().dimensions());
auto* dividend_is_negative =
computation->AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a,
zero_like_a));
computation->AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
ComparisonDirection::kLt));
auto* negated_dividend = computation->AddInstruction(
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
@ -2475,9 +2475,9 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
computation, a->shape().element_type(), a->shape().dimensions());
auto* dividend_is_negative =
computation->AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a,
zero_like_a));
computation->AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
ComparisonDirection::kLt));
auto* negated_dividend = computation->AddInstruction(
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));

View File

@ -221,7 +221,7 @@ HloModule foobar
%x = (f32[2,2], f32[2,2]) parameter(0)
%constant.0 = s32[] constant(0)
%constant.1 = s32[] constant(1)
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
@ -258,7 +258,7 @@ HloModule foobar
%x = (f32[2,2], f32[2,2]) parameter(0)
%constant.0 = s32[] constant(0)
%constant.1 = s32[] constant(1)
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
@ -296,7 +296,7 @@ HloModule foobar
%x = (f32[2,2], f32[2,2]) parameter(0)
%constant.0 = s32[] constant(0)
%constant.1 = s32[] constant(1)
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {

View File

@ -109,8 +109,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b));
HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq));
HloInstruction* sel = builder.AddInstruction(
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
HloInstruction* xpose =
@ -574,8 +574,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
HloInstruction::CreateParameter(0, shape, "cond_param"));
auto cond_dot =
builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}),
builder_cond.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond.AddInstruction(
@ -583,9 +583,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
cond_dot, {0, 0}, {1, 1}, {1, 1})))),
builder_cond.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
{1, 1}))))));
builder_cond.AddInstruction(
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
cond_dot, {1, 1}, {2, 2}, {1, 1})))),
ComparisonDirection::kGt));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@ -631,8 +632,8 @@ TEST_F(BFloat16PropagationTest,
auto builder_cond = HloComputation::Builder("cond");
auto cond_param = builder_cond.AddInstruction(
HloInstruction::CreateParameter(0, shape, "cond_param"));
builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
builder_cond.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}),
builder_cond.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond.AddInstruction(HloInstruction::CreateSlice(
@ -642,7 +643,8 @@ TEST_F(BFloat16PropagationTest,
ShapeUtil::MakeShape(F32, {}),
builder_cond.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
{1, 1}))))));
{1, 1})))),
ComparisonDirection::kGt));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@ -705,8 +707,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
auto cond_dot =
builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
builder_cond.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}),
builder_cond.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond.AddInstruction(
@ -714,9 +716,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
cond_dot, {0, 0}, {1, 1}, {1, 1})))),
builder_cond.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
{1, 1}))))));
builder_cond.AddInstruction(
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
cond_dot, {1, 1}, {2, 2}, {1, 1})))),
ComparisonDirection::kGt));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@ -800,8 +803,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
auto cond0_dot =
builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
builder_cond0.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}),
builder_cond0.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond0.AddInstruction(
@ -809,9 +812,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
builder_cond0.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond0.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2},
{1, 1}))))));
builder_cond0.AddInstruction(
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
cond0_dot, {1, 1}, {2, 2}, {1, 1})))),
ComparisonDirection::kGt));
auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
// Condition computation for the second while.
@ -828,8 +832,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
auto cond1_dot =
builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
builder_cond1.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}),
builder_cond1.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond1.AddInstruction(
@ -837,9 +841,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
builder_cond1.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {}),
builder_cond1.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2},
{1, 1}))))));
builder_cond1.AddInstruction(
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
cond1_dot, {1, 1}, {2, 2}, {1, 1})))),
ComparisonDirection::kGt));
auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
// Body computation shared by both whiles.

View File

@ -190,8 +190,9 @@ class BufferAssignmentTest : public HloTestBase {
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto index = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4));
builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
const4, ComparisonDirection::kLt));
return builder.Build();
}
@ -1863,8 +1864,8 @@ class WhileBufferAssignmentTest : public HloTestBase {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
auto ten = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
return builder.Build();
}
@ -2135,8 +2136,9 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4));
builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
const4, ComparisonDirection::kLt));
return builder.Build();
};
@ -2530,7 +2532,7 @@ while_condition {
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
get-tuple-element = s32[] get-tuple-element(state), index=0
get-tuple-element.1 = s32[] constant(3)
ROOT less-than.339.338 = pred[] less-than(get-tuple-element, get-tuple-element.1)
ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
}
ENTRY entry_computation {

View File

@ -83,8 +83,9 @@ class CallGraphTest : public HloTestBase {
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
zero, ComparisonDirection::kGt));
return builder.Build();
}

View File

@ -191,8 +191,9 @@ HloInstruction* GetExpandedFilterMask(
// linspace to create a diagonal predicate.
Shape predicate_shape = ShapeUtil::MakeShape(
PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
return add_instruction(HloInstruction::CreateBinary(
predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2));
return add_instruction(HloInstruction::CreateCompare(
predicate_shape, broadcasted_mask1, broadcasted_mask2,
ComparisonDirection::kEq));
}
// This function handles batch_group_counts which are relevant only for

View File

@ -420,9 +420,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
limit_const->shape(), loop_state, 0));
builder.AddInstruction(
HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
induction_variable, limit_const));
builder.AddInstruction(HloInstruction::CreateCompare(
condition_result_shape_, induction_variable, limit_const,
ComparisonDirection::kLt));
return builder.Build();
}
@ -1842,7 +1842,7 @@ HloModule TokensShouldNotBeCopied
%param = (s32[], token[]) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
%constant = s32[] constant(42)
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}
ENTRY %TokensShouldNotBeCopied () -> s32[] {
@ -2060,7 +2060,7 @@ if-condition.v4 {
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
constant.4 = s32[] constant(0)
ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}
_functionalize_body_1__.v28 {
@ -2070,7 +2070,7 @@ _functionalize_body_1__.v28 {
add.4 = s32[] add(get-tuple-element.68, constant.7)
get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70)
less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
constant.8 = s32[] constant(0)
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
@ -2087,7 +2087,7 @@ cond_wrapper.v3.1 {
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
constant.11 = s32[] constant(7)
ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11)
ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
}
_functionalize_body_2__.v25 {
@ -2110,7 +2110,7 @@ cond_wrapper.v3.2 {
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
constant.13 = s32[] constant(5)
ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13)
ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
}
ENTRY TestComputation {
@ -2142,7 +2142,7 @@ if-condition.v4 {
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
constant.4 = s32[] constant(0)
ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}
if-body.v5.1 {
@ -2159,7 +2159,7 @@ if-condition.v4.1 {
p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
constant.6 = s32[] constant(1)
ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6)
ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
}
_functionalize_body_1__.v28 {
@ -2169,7 +2169,7 @@ _functionalize_body_1__.v28 {
add.4 = s32[] add(get-tuple-element.72, constant.7)
get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74)
less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
constant.8 = s32[] constant(0)
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
@ -2187,7 +2187,7 @@ cond_wrapper.v3.1 {
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
constant.11 = s32[] constant(7)
ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11)
ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
}
_functionalize_body_2__.v25 {
@ -2210,7 +2210,7 @@ cond_wrapper.v3.2 {
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
constant.13 = s32[] constant(5)
ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13)
ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
}
ENTRY TestComputation {

View File

@ -29,7 +29,7 @@ HloModule KeyValueSort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY main {

View File

@ -75,7 +75,7 @@ ENTRY TestComputation {
broadcast = f32[42] broadcast(add), dimensions={}
slice = f32[1] slice(broadcast), slice={[1:2]}
copy = f32[] copy(arg)
eq = pred[] equal-to(arg, gte)
eq = pred[] compare(arg, gte), direction=EQ
neg = f32[] negate(arg)
ROOT convert = f64[] convert(f32[] arg)
})";

View File

@ -68,8 +68,8 @@ class DynamicDimensionInferenceTest : public HloTestBase {
0, ShapeUtil::MakeShape(F32, {}), "lhs"));
auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {}), "rhs"));
embedded_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs));
embedded_builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
return module_->AddEmbeddedComputation(embedded_builder.Build());
}

View File

@ -160,9 +160,10 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
HloInstruction* broadcasted_effective_size =
computation->AddInstruction(HloInstruction::CreateBroadcast(
mask_shape, dynamic_size, {}));
HloInstruction* pred = computation->AddInstruction(
HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota,
broadcasted_effective_size));
HloInstruction* pred =
computation->AddInstruction(HloInstruction::CreateCompare(
pred_shape, iota, broadcasted_effective_size,
ComparisonDirection::kLt));
HloInstruction* broadcasted_identity_value =
computation->AddInstruction(HloInstruction::CreateBroadcast(

View File

@ -719,25 +719,28 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
// We use ordered comparisons for everything except kNe, where we use an
// unordered comparison. This makes x != y equivalent to !(x == y), and
// matches C++'s semantics.
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
rhs_value, b_);
case HloOpcode::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
rhs_value, b_);
case HloOpcode::kLt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
rhs_value, b_);
case HloOpcode::kGt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
rhs_value, b_);
case HloOpcode::kLe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
rhs_value, b_);
case HloOpcode::kGe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
rhs_value, b_);
case HloOpcode::kCompare: {
switch (op->comparison_direction()) {
case ComparisonDirection::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
rhs_value, b_);
case ComparisonDirection::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
rhs_value, b_);
case ComparisonDirection::kLt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
rhs_value, b_);
case ComparisonDirection::kGt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
rhs_value, b_);
case ComparisonDirection::kLe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
rhs_value, b_);
case ComparisonDirection::kGe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
rhs_value, b_);
}
}
case HloOpcode::kMaximum:
return EmitFloatMax(lhs_value, rhs_value);
case HloOpcode::kMinimum:
@ -839,21 +842,28 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
// We use ordered comparisons for everything except kNe, where we use an
// unordered comparison. This makes x != y equivalent to !(x == y), and
// matches C++'s semantics.
case HloOpcode::kEq:
return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
EmitExtractReal(lhs_value),
EmitExtractReal(rhs_value), b_),
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
EmitExtractImag(lhs_value),
EmitExtractImag(rhs_value), b_));
case HloOpcode::kNe:
return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
EmitExtractReal(lhs_value),
EmitExtractReal(rhs_value), b_),
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
EmitExtractImag(lhs_value),
EmitExtractImag(rhs_value), b_));
case HloOpcode::kCompare: {
switch (op->comparison_direction()) {
case ComparisonDirection::kEq:
return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
EmitExtractReal(lhs_value),
EmitExtractReal(rhs_value), b_),
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
EmitExtractImag(lhs_value),
EmitExtractImag(rhs_value), b_));
case ComparisonDirection::kNe:
return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
EmitExtractReal(lhs_value),
EmitExtractReal(rhs_value), b_),
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
EmitExtractImag(lhs_value),
EmitExtractImag(rhs_value), b_));
default:
return Unimplemented(
"complex comparison '%s'",
ComparisonDirectionToString(op->comparison_direction()));
}
}
case HloOpcode::kPower: {
auto a = EmitExtractReal(lhs_value);
auto b = EmitExtractImag(lhs_value);
@ -1278,28 +1288,32 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
case HloOpcode::kRemainder:
return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
rhs_value, b_);
case HloOpcode::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
rhs_value, b_);
case HloOpcode::kLt:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
lhs_value, rhs_value, b_);
case HloOpcode::kGt:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
lhs_value, rhs_value, b_);
case HloOpcode::kLe:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
lhs_value, rhs_value, b_);
case HloOpcode::kGe:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
lhs_value, rhs_value, b_);
// Integer comparison. The comparison direction selects the LLVM integer
// predicate: equality and inequality use the same predicate regardless of
// signedness, while the four ordering directions pick the signed or unsigned
// predicate according to is_signed.
case HloOpcode::kCompare: {
switch (op->comparison_direction()) {
case ComparisonDirection::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
rhs_value, b_);
case ComparisonDirection::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
rhs_value, b_);
case ComparisonDirection::kLt:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
lhs_value, rhs_value, b_);
case ComparisonDirection::kGt:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
lhs_value, rhs_value, b_);
case ComparisonDirection::kLe:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
lhs_value, rhs_value, b_);
case ComparisonDirection::kGe:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
lhs_value, rhs_value, b_);
}
}
case HloOpcode::kMinimum:
return EmitIntegralMin(lhs_value, rhs_value, is_signed);
case HloOpcode::kMaximum:
@ -2197,17 +2211,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kAdd:
case HloOpcode::kAnd:
case HloOpcode::kAtan2:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kPower:

View File

@ -81,8 +81,9 @@ class FlattenCallGraphTest : public HloTestBase {
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
zero, ComparisonDirection::kGt));
return builder.Build();
}
@ -158,9 +159,9 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
0, ShapeUtil::MakeShape(PRED, {}), "param0"));
HloInstruction* false_constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
HloOpcode::kEq, param0, false_constant));
builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), param0, false_constant,
ComparisonDirection::kEq));
cond_computation = module->AddEmbeddedComputation(builder.Build());
}

View File

@ -62,7 +62,7 @@ TEST_F(GpuFusibleTest,
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
c0 = f16[] constant(0)
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
}
fused_reduce {
@ -122,7 +122,7 @@ TEST_F(GpuFusibleTest,
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
c0 = f16[] constant(0)
broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={}
greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast)
greater-than = pred[128,1024,32,32]{3,2,1,0} compare(p1.1, broadcast), direction=GT
select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast)
ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select)
}
@ -507,7 +507,7 @@ TEST_F(GpuFusibleTest,
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
c0 = f32[] constant(0)
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
}

View File

@ -374,7 +374,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
p.0.rhs = f32[] parameter(1)
p.1.lhs = f32[] parameter(2)
p.1.rhs = f32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort {

View File

@ -437,7 +437,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
c0 = f32[] constant(0)
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
}
@ -505,7 +505,7 @@ TEST_F(MultiOutputFusionTest,
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
c0 = f16[] constant(0)
broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast)
greater-than = pred[2,2,2]{2,1,0} compare(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast), direction=GT
p0.1 = f16[2,2,2]{2,1,0} parameter(0)
ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
}
@ -548,7 +548,7 @@ TEST_F(MultiOutputFusionTest,
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
c0 = f16[] constant(0)
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
}
fused_reduce {

View File

@ -48,8 +48,9 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) {
HloInstruction::CreateParameter(0, param_shape, "x"));
HloInstruction* param_y = builder.AddInstruction(
HloInstruction::CreateParameter(1, param_shape, "y"));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y));
builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {5, 7, 2}), param_x, param_y,
ComparisonDirection::kGe));
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
@ -73,7 +74,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) {
x = f32[5,7,2]{2,1,0} parameter(0)
y = f32[5,14]{1,0} parameter(1)
reshape = f32[5,7,2]{2,1,0} reshape(y)
ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape)
ROOT gte = pred[5,7,2]{2,1,0} compare(x, reshape), direction=GE
})",
config)
.ValueOrDie();
@ -98,7 +99,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) {
y = f32[14]{0} parameter(1)
reshape = f32[7,2]{1,0} reshape(y)
broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2}
ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast)
ROOT gte = pred[5,7,2]{2,1,0} compare(x, broadcast), direction=GE
})",
config)
.ValueOrDie();

View File

@ -44,9 +44,9 @@ class WhileTransformerTest : public HloTestBase {
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
limit_const->shape(), loop_state, tuple_index));
builder.AddInstruction(
HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
induction_variable, limit_const));
builder.AddInstruction(HloInstruction::CreateCompare(
condition_result_shape_, induction_variable, limit_const,
ComparisonDirection::kLt));
return builder.Build();
}

View File

@ -54,8 +54,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
// Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
HloInstruction* cond_lt = cond_builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
HloOpcode::kLt, cond_iter, cond_data));
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
cond_data, ComparisonDirection::kLt));
HloComputation* cond_computation =
module->AddEmbeddedComputation(cond_builder.Build());
@ -113,7 +113,8 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
// %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
// %reshape = f32[] reshape(f32[1]{0} %slice)
// %constant = f32[] constant(0)
// ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant)
// ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
// direction=NE
// }
// ENTRY %SubcomputationAccounting () -> f32[2,4] {
@ -143,9 +144,9 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
HloInstruction* zero = cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
HloInstruction* cond_comparison =
cond_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero));
HloInstruction* cond_comparison = cond_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
zero, ComparisonDirection::kNe));
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
// param - 1
@ -703,8 +704,8 @@ TEST_F(HeapSimulatorTest, WholeModule) {
HloInstruction* cond_data = cond_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
HloInstruction* cond_lt = cond_builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
HloOpcode::kLt, cond_iter, cond_data));
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
cond_data, ComparisonDirection::kLt));
HloComputation* cond_computation =
tracker.module()->AddEmbeddedComputation(cond_builder.Build());

View File

@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
// Next ID: 63
// Next ID: 64
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@ -146,6 +146,9 @@ message HloInstructionProto {
// FFT length.
repeated int64 fft_length = 32;
// Comparison direction only used for kCompare.
string comparison_direction = 63;
// Gather dimension numbers.
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
repeated int64 gather_slice_sizes = 34;

View File

@ -509,8 +509,9 @@ TEST_F(HloComputationTest, CloneWithReplacements) {
HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs"));
auto param2 =
builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1"));
auto lt = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1));
auto lt = builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
param1, ComparisonDirection::kLt));
auto module = CreateNewVerifiedModule();
auto computation =
module->AddEntryComputation(builder.Build(/*root_instruction=*/lt));

View File

@ -42,6 +42,18 @@ StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
}
// Builds a compare instruction for `lhs` and `rhs` with the given `direction`
// and adds it to the computation that owns `lhs`. Both operands must belong to
// the same computation (CHECK-enforced). The result shape is inferred as for a
// binary op with opcode kCompare; a shape-inference failure is returned to the
// caller as an error status.
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs) {
  HloComputation* owner = lhs->parent();
  CHECK_EQ(owner, rhs->parent());

  // Let shape inference decide the output shape of the comparison.
  TF_ASSIGN_OR_RETURN(
      Shape compare_shape,
      ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));

  return owner->AddInstruction(
      HloInstruction::CreateCompare(compare_shape, lhs, rhs, direction));
}
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
HloInstruction* padding_value,
const PaddingConfig& padding_config) {

View File

@ -32,6 +32,12 @@ namespace xla {
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
HloInstruction* rhs);
// Creates a compare HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
HloInstruction* lhs,
HloInstruction* rhs);
// Creates a pad HLO instruction and adds it to the computation containing
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
// same computation).

View File

@ -2239,8 +2239,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
HloInstruction::CreateParameter(0, in_shape, "param0"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, in_shape, "param1"));
auto result = builder.AddInstruction(
HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
auto result = builder.AddInstruction(HloInstruction::CreateCompare(
out_shape, param0, param1, ComparisonDirection::kEq));
BuildModuleAndRunAnalysis(builder.Build());
@ -2563,8 +2563,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq));
return builder.Build();
};

View File

@ -223,8 +223,9 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
HloInstruction::CreateParameter(0, shape, "cond_param"));
auto constant = cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
cond_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant));
cond_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
constant, ComparisonDirection::kLt));
}
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

View File

@ -56,43 +56,40 @@ namespace xla {
namespace {
template <typename OperandT>
StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
switch (direction) {
case ComparisonDirection::kEq:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el == rhs_el;
};
break;
case HloOpcode::kNe:
case ComparisonDirection::kNe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el != rhs_el;
};
break;
case HloOpcode::kGe:
case ComparisonDirection::kGe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el >= rhs_el;
};
break;
case HloOpcode::kGt:
case ComparisonDirection::kGt:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el > rhs_el;
};
break;
case HloOpcode::kLe:
case ComparisonDirection::kLe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el <= rhs_el;
};
break;
case HloOpcode::kLt:
case ComparisonDirection::kLt:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el < rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
<< HloOpcodeString(opcode);
}
Literal result(shape);
@ -106,24 +103,25 @@ StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
}
template <>
StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
StatusOr<Literal> Compare<complex64>(const Shape& shape,
ComparisonDirection direction,
LiteralSlice lhs_literal,
LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
switch (direction) {
case ComparisonDirection::kEq:
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
return lhs_el == rhs_el;
};
break;
case HloOpcode::kNe:
case ComparisonDirection::kNe:
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
return lhs_el != rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
<< HloOpcodeString(opcode);
LOG(FATAL) << "unhandled direction for conversion to Comparison: "
<< ComparisonDirectionToString(direction);
}
Literal result(shape);
@ -137,24 +135,25 @@ StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
}
template <>
StatusOr<Literal> Compare<complex128>(const Shape& shape, HloOpcode opcode,
StatusOr<Literal> Compare<complex128>(const Shape& shape,
ComparisonDirection direction,
LiteralSlice lhs_literal,
LiteralSlice rhs_literal) {
std::function<bool(complex128, complex128)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
switch (direction) {
case ComparisonDirection::kEq:
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
return lhs_el == rhs_el;
};
break;
case HloOpcode::kNe:
case ComparisonDirection::kNe:
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
return lhs_el != rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
<< HloOpcodeString(opcode);
LOG(FATAL) << "unhandled direction for conversion to Comparison: "
<< ComparisonDirectionToString(direction);
}
Literal result(shape);
@ -671,7 +670,7 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) {
}
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
HloOpcode opcode = compare->opcode();
ComparisonDirection direction = compare->comparison_direction();
auto lhs = compare->operand(0);
auto rhs = compare->operand(1);
DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
@ -687,76 +686,76 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
case PRED: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
} break;
case U8: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<uint8>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case U16: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint16>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<uint16>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case U32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<uint32>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case U64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<uint64>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case S8: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
} break;
case S16: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int16>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<int16>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case S32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<int32>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case S64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<int64>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case F16: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<half>(compare->shape(), opcode, lhs_literal, rhs_literal));
Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
} break;
case BF16: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<bfloat16>(compare->shape(), opcode,
Compare<bfloat16>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case F32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<float>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case F64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<double>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case C64: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<complex64>(compare->shape(), opcode,
Compare<complex64>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
case C128: {
TF_ASSIGN_OR_RETURN(evaluated_[compare],
Compare<complex128>(compare->shape(), opcode,
Compare<complex128>(compare->shape(), direction,
lhs_literal, rhs_literal));
} break;
default:

View File

@ -2848,8 +2848,15 @@ TEST_F(HloEvaluatorTest, DoesCompareBF16) {
{bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
auto expected =
LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs),
std::move(rhs));
HloComputation::Builder b(TestName());
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2,
ComparisonDirection::kGe));
m_->AddEntryComputation(b.Build());
EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate()));
}
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {

View File

@ -258,14 +258,16 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
// param0), check that the operation being performed is commutative.
if (root->operand(0) == param1) {
CHECK_EQ(root->operand(1), param0);
switch (root->opcode()) {
case HloOpcode::kLe:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLt:
return nullopt;
default:
break;
// Ordering comparisons are not commutative, so a computation that applies one
// with its parameters swapped must not be given a trivial label. The direction
// is only meaningful for compare instructions, hence the explicit kCompare
// guard. (Comparing against a value-initialized `HloOpcode()` would test for
// whichever enumerator has value zero rather than for kCompare.)
if (root->opcode() == HloOpcode::kCompare) {
  switch (root->comparison_direction()) {
    case ComparisonDirection::kLe:
    case ComparisonDirection::kGe:
    case ComparisonDirection::kGt:
    case ComparisonDirection::kLt:
      return nullopt;
    default:
      // Remaining directions are not rejected here.
      break;
  }
}
}
@ -279,18 +281,22 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
return "min";
case HloOpcode::kMaximum:
return "max";
case HloOpcode::kLe:
return "less-or-equal";
case HloOpcode::kGe:
return "greater-or-equal";
case HloOpcode::kGt:
return "greater-than";
case HloOpcode::kLt:
return "less-than";
case HloOpcode::kEq:
return "equal-to";
case HloOpcode::kNe:
return "not-equal-to";
// Compare roots are labeled according to their comparison direction.
// NOTE(review): there is no return after the inner switch, so a direction
// value not listed below would fall through into the `default` label that
// follows this case.
case HloOpcode::kCompare: {
switch (root->comparison_direction()) {
case ComparisonDirection::kLe:
return "less-or-equal";
case ComparisonDirection::kGe:
return "greater-or-equal";
case ComparisonDirection::kGt:
return "greater-than";
case ComparisonDirection::kLt:
return "less-than";
case ComparisonDirection::kEq:
return "equal-to";
case ComparisonDirection::kNe:
return "not-equal-to";
}
}
default:
return nullopt;
}
@ -922,27 +928,22 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kConvert:
case HloOpcode::kCos:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kImag:
case HloOpcode::kIota:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
@ -31,6 +32,8 @@ namespace {
using absl::StrCat;
using ::testing::HasSubstr;
using HloGraphDumperTest = HloTestBase;
string TestName() {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
@ -48,7 +51,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
TEST(HloGraphDumperTest, NestedFusion) {
TEST_F(HloGraphDumperTest, NestedFusion) {
HloComputation::Builder b("b");
// Build param0 + param1 + param2 + param3 + param4.
@ -118,7 +121,7 @@ TEST(HloGraphDumperTest, NestedFusion) {
HasSubstr(inner_sum->name()));
}
TEST(HloGraphDumperTest, Constant) {
TEST_F(HloGraphDumperTest, Constant) {
HloComputation::Builder b("b");
auto instruction = b.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-42)));
@ -132,7 +135,7 @@ TEST(HloGraphDumperTest, Constant) {
EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
}
TEST(HloGraphDumperTest, TupleConstant) {
TEST_F(HloGraphDumperTest, TupleConstant) {
Shape tuple_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})});
HloComputation::Builder b("b");
@ -150,5 +153,21 @@ TEST(HloGraphDumperTest, TupleConstant) {
EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])"));
}
// Checks that the dumped graph of a module whose root is a compare instruction
// contains the comparison direction text.
TEST_F(HloGraphDumperTest, Compare) {
// Minimal module: a single LT comparison of two f32[10] parameters.
const char* hlo_string = R"(
HloModule comp
ENTRY comp {
param.0 = f32[10] parameter(0)
param.1 = f32[10] parameter(1)
ROOT lt = pred[10] compare(param.0, param.1), direction=LT
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
// Dump the entry computation and look for the direction in the output.
string graph = hlo_graph_dumper::DumpGraph(*module->entry_computation(),
/*label=*/"comp", DebugOptions());
EXPECT_THAT(graph, HasSubstr("direction=LT"));
}
} // anonymous namespace
} // namespace xla

View File

@ -64,7 +64,35 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
// Resolve the opcode string from the proto. `opcode` is assigned on every path
// that does not return an error.
HloOpcode opcode;
auto opcode_or = StringToHloOpcode(proto.opcode());
// Set only when the proto carries one of the deprecated per-direction
// comparison opcode strings handled below.
absl::optional<ComparisonDirection> comparison_direction;
if (opcode_or.ok()) {
opcode = opcode_or.ConsumeValueOrDie();
} else {
// Unknown opcode. Try auto-upgrading deprecated "less-than",
// "greater-than", etc opcodes, which are now rolled into the kCompare
// opcode.
if (proto.opcode() == "equal-to") {
comparison_direction = ComparisonDirection::kEq;
} else if (proto.opcode() == "not-equal-to") {
comparison_direction = ComparisonDirection::kNe;
} else if (proto.opcode() == "greater-than-or-equal-to") {
comparison_direction = ComparisonDirection::kGe;
} else if (proto.opcode() == "greater-than") {
comparison_direction = ComparisonDirection::kGt;
} else if (proto.opcode() == "less-than-or-equal-to") {
comparison_direction = ComparisonDirection::kLe;
} else if (proto.opcode() == "less-than") {
comparison_direction = ComparisonDirection::kLt;
}
// A recognized deprecated name becomes kCompare with the matching direction;
// anything else is rejected as an invalid argument.
if (comparison_direction) {
opcode = HloOpcode::kCompare;
} else {
return InvalidArgument("Unknown opcode: %s", proto.opcode());
}
}
TF_RET_CHECK(proto.has_shape());
std::unique_ptr<HloInstruction> instruction;
@ -136,6 +164,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
absl::Span<const int64>(fft_length));
break;
}
case HloOpcode::kCompare: {
// If the direction was already derived from a deprecated opcode name, keep
// it; otherwise parse it from the proto's comparison_direction field
// (a parse failure is returned as an error).
if (!comparison_direction) {
TF_ASSIGN_OR_RETURN(
comparison_direction,
StringToComparisonDirection(proto.comparison_direction()));
}
instruction =
CreateCompare(shape, operands(0), operands(1), *comparison_direction);
break;
}
case HloOpcode::kTriangularSolve: {
instruction = CreateTriangularSolve(shape, operands(0), operands(1),
proto.triangular_solve_options());
@ -688,15 +727,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
@ -761,6 +794,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
fft_length);
}
// Factory for a compare instruction: constructs an HloCompareInstruction with
// the given result shape, operands and comparison direction, and hands
// ownership to the caller.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    ComparisonDirection direction) {
  std::unique_ptr<HloInstruction> compare =
      absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
  return compare;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
HloInstruction* b,
@ -1311,6 +1350,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kFft:
case HloOpcode::kCompare:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
@ -1384,12 +1424,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
@ -1705,26 +1739,20 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kCos:
case HloOpcode::kDivide:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kEq:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kPower:
case HloOpcode::kReal:
@ -1772,6 +1800,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kFft:
case HloOpcode::kCompare:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
@ -2119,17 +2148,12 @@ bool HloInstruction::IsElementwiseImpl(
// Binary elementwise operations, the same as in IsElementwiseBinary().
case HloOpcode::kAdd:
case HloOpcode::kAtan2:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kDivide:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSubtract:
@ -2472,12 +2496,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleGetTupleElement(this);
case HloOpcode::kParameter:
return visitor->HandleParameter(this);
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
case HloOpcode::kCompare:
return visitor->HandleCompare(this);
case HloOpcode::kComplex:
return visitor->HandleComplex(this);
@ -3519,6 +3538,10 @@ const DomainMetadata& HloInstruction::user_side_metadata() const {
return Cast<HloDomainInstruction>(this)->user_side_metadata();
}
// Convenience accessor: casts this instruction to HloCompareInstruction and
// forwards to its direction() getter.
ComparisonDirection HloInstruction::comparison_direction() const {
  const HloCompareInstruction* compare = Cast<HloCompareInstruction>(this);
  return compare->direction();
}
// Convenience accessor: casts this instruction to
// HloTriangularSolveInstruction and forwards to its
// triangular_solve_options() getter.
const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
  const HloTriangularSolveInstruction* solve =
      Cast<HloTriangularSolveInstruction>(this);
  return solve->triangular_solve_options();
}

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
@ -444,6 +445,11 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, FftType fft_type,
absl::Span<const int64> fft_length);
// Creates a compare op, performing the comparison specified in direction.
static std::unique_ptr<HloInstruction> CreateCompare(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
ComparisonDirection direction);
static std::unique_ptr<HloInstruction> CreateTriangularSolve(
const Shape& shape, HloInstruction* a, HloInstruction* b,
const TriangularSolveOptions& options);
@ -1600,6 +1606,9 @@ class HloInstruction {
// Delegates to HloDomainInstruction::user_side_metadata().
const DomainMetadata& user_side_metadata() const;
// Delegates to HloCompareInstruction::direction().
ComparisonDirection comparison_direction() const;
// Delegates to HloTriangularSolveInstruction::triangular_solve_options().
const TriangularSolveOptions& triangular_solve_options() const;

View File

@ -1655,7 +1655,7 @@ body (bparam: s32[]) -> s32[] {
condition (cparam: s32[]) -> pred[] {
xconstant = s32[] constant(5)
cparam = s32[] parameter(0)
ROOT greater-than = pred[] greater-than(xconstant, cparam)
ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT
}
ENTRY entry (param: s32[]) -> s32[] {

View File

@ -202,6 +202,42 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
fft_length_);
}
// Constructs a kCompare instruction with result shape `shape`. `lhs` becomes
// operand 0 and `rhs` operand 1; `direction` is stored for later queries.
HloCompareInstruction::HloCompareInstruction(const Shape& shape,
                                             HloInstruction* lhs,
                                             HloInstruction* rhs,
                                             ComparisonDirection direction)
    : HloInstruction(HloOpcode::kCompare, shape), direction_(direction) {
  // Operand order matters: append the left-hand side first.
  HloInstruction* const ordered_operands[] = {lhs, rhs};
  for (HloInstruction* operand : ordered_operands) {
    AppendOperand(operand);
  }
}
// Serializes this instruction: starts from the base-class proto and adds the
// comparison direction in its string form.
HloInstructionProto HloCompareInstruction::ToProto() const {
  HloInstructionProto serialized = HloInstruction::ToProto();
  const auto direction_name = ComparisonDirectionToString(direction());
  serialized.set_comparison_direction(direction_name);
  return serialized;
}
std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {StrCat("direction=", ComparisonDirectionToString(direction()))};
}
// Compare-specific part of the identity check: two compare instructions are
// considered identical here when their directions are equal.
// `eq_computations` is not used.
// NOTE(review): the static_cast assumes `other` is an HloCompareInstruction;
// presumably the caller has already checked the opcode -- confirm.
bool HloCompareInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  const auto* other_compare =
      static_cast<const HloCompareInstruction*>(&other);
  return other_compare->direction() == direction();
}
std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
new_operands[1], direction());
}
namespace {
// Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector

View File

@ -131,6 +131,28 @@ class HloFftInstruction : public HloInstruction {
std::vector<int64> fft_length_;
};
// HLO instruction subclass for compare ops. In addition to the state held by
// HloInstruction it records the ComparisonDirection to apply to its operands.
class HloCompareInstruction : public HloInstruction {
public:
// Creates a compare of `lhs` and `rhs` with result shape `shape`, using
// `direction` as the comparison to perform.
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
HloInstruction* rhs,
ComparisonDirection direction);
// Returns the comparison direction stored in this instruction.
ComparisonDirection direction() const { return direction_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
private:
// Returns the extra attributes to print for this instruction.
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
// Subclass-specific part of the instruction identity check against `other`.
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
// Subclass-specific part of cloning this instruction onto `new_operands`.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
// The comparison this instruction performs.
ComparisonDirection direction_;
};
class HloTriangularSolveInstruction : public HloInstruction {
public:
explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,

View File

@ -255,7 +255,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(5)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(0)
@ -308,7 +308,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
constant.2 = s32[] constant(5)
ROOT less-than = pred[] less-than(add.1, constant.2)
ROOT less-than = pred[] compare(add.1, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(0)
@ -360,7 +360,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
loop_var.2 = (s32[], s32[], s32[]) parameter(0)
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
constant.1 = s32[] constant(5)
ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1)
ROOT less-than = pred[] compare(get-tuple-element.4, constant.1), direction=LT
}
ENTRY SimpleLoop {
constant.2 = s32[] constant(0)
@ -415,7 +415,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
cond_param = (s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
constant.2 = s32[] constant(10)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(0)
@ -448,13 +448,13 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
cond_param = (s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
constant.2 = s32[] constant(10)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
OuterWhileCondition {
cond_param.2 = (s32[]) parameter(0)
get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
constant.5 = s32[] constant(5)
ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5)
ROOT less-than.2 = pred[] compare(get-tuple-element.5, constant.5), direction=LT
}
OuterWhileBody {
body_param.2 = (s32[]) parameter(0)

View File

@ -89,6 +89,22 @@ bool HloParameterMatcher::MatchAndExplain(
return true;
}
bool HloComparisonMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
return false;
}
if (instruction->comparison_direction() != direction_) {
*listener << "has wrong comparison direction (got "
<< ComparisonDirectionToString(
instruction->comparison_direction())
<< ", want " << ComparisonDirectionToString(direction_) << ")";
return false;
}
return true;
}
bool HloGetTupleElementMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {

View File

@ -54,6 +54,21 @@ class HloParameterMatcher : public HloMatcher {
int64 parameter_number_;
};
// Custom matcher for comparisons, which accepts a comparison direction.
// The opcode (always HloOpcode::kCompare) and the operand matchers are
// forwarded to HloMatcher; the expected direction is kept in this class.
class HloComparisonMatcher : public HloMatcher {
public:
// `direction` is the comparison direction a matching instruction must have;
// `operands` are the matchers handed to the HloMatcher base.
explicit HloComparisonMatcher(
ComparisonDirection direction,
std::vector<::testing::Matcher<const HloInstruction*>> operands)
: HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
// Returns whether `instruction` matches; explanations for a failed match are
// written to `listener`.
bool MatchAndExplain(const HloInstruction* instruction,
::testing::MatchResultListener* listener) const override;
private:
// Direction that a matching instruction is required to have.
ComparisonDirection direction_;
};
// Custom matcher for get-tuple-element instructions, which accepts a tuple
// index to match.
class HloGetTupleElementMatcher : public HloMatcher {
@ -172,6 +187,7 @@ HLO_MATCHER(BatchNormGrad);
HLO_MATCHER(Call);
HLO_MATCHER(Ceil);
HLO_MATCHER(Clamp);
HLO_MATCHER(Compare);
HLO_MATCHER(Concatenate);
HLO_MATCHER(Conditional);
HLO_MATCHER(Constant);
@ -184,28 +200,22 @@ HLO_MATCHER(Divide);
HLO_MATCHER(Domain);
HLO_MATCHER(DynamicSlice);
HLO_MATCHER(DynamicUpdateSlice);
HLO_MATCHER(Eq);
HLO_MATCHER(Exp);
HLO_MATCHER(Floor);
HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
HLO_MATCHER(AfterAll);
HLO_MATCHER(Gt);
HLO_MATCHER(Iota);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
HLO_MATCHER(Le);
HLO_MATCHER(Log);
HLO_MATCHER(And);
HLO_MATCHER(Not);
HLO_MATCHER(Or);
HLO_MATCHER(Xor);
HLO_MATCHER(Lt);
HLO_MATCHER(Map);
HLO_MATCHER(Maximum);
HLO_MATCHER(Minimum);
HLO_MATCHER(Multiply);
HLO_MATCHER(Ne);
HLO_MATCHER(Negate);
HLO_MATCHER(Outfeed);
HLO_MATCHER(Pad);
@ -256,6 +266,38 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
}
// Comparison matchers below do not require any additional arguments.
// Matches a compare instruction with direction kEq whose operands satisfy the
// given operand matchers (if any).
template <typename... M>
inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
  auto* impl = new ::xla::testing::HloComparisonMatcher(
      ComparisonDirection::kEq, {operands...});
  return ::testing::MakeMatcher(impl);
}
// Matches a compare instruction with direction kNe whose operands satisfy the
// given operand matchers (if any).
template <typename... M>
inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
  auto* impl = new ::xla::testing::HloComparisonMatcher(
      ComparisonDirection::kNe, {operands...});
  return ::testing::MakeMatcher(impl);
}
// Matches a compare instruction with direction kGe whose operands satisfy the
// given operand matchers (if any).
template <typename... M>
inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
  auto* impl = new ::xla::testing::HloComparisonMatcher(
      ComparisonDirection::kGe, {operands...});
  return ::testing::MakeMatcher(impl);
}
// Matches a compare instruction with direction kGt whose operands satisfy the
// given operand matchers (if any).
template <typename... M>
inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
  auto* impl = new ::xla::testing::HloComparisonMatcher(
      ComparisonDirection::kGt, {operands...});
  return ::testing::MakeMatcher(impl);
}
// Matches a compare instruction with direction kLe whose operands satisfy the
// given operand matchers (if any).
template <typename... M>
inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
  auto* impl = new ::xla::testing::HloComparisonMatcher(
      ComparisonDirection::kLe, {operands...});
  return ::testing::MakeMatcher(impl);
}
// Matches a compare instruction with direction kLt whose operands satisfy the
// given operand matchers (if any).
template <typename... M>
inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
  auto* impl = new ::xla::testing::HloComparisonMatcher(
      ComparisonDirection::kLt, {operands...});
  return ::testing::MakeMatcher(impl);
}
// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
// tuple element of operand, while GetTupleElement(operand) matches any GTE
// operation on operand, and GetTupleElement() matches any GTE operation at all.

View File

@ -220,5 +220,33 @@ ENTRY DotOperationFusion_TransposeFusion {
"rhs_contracting_dimensions (got {0} want {1})");
}
// Exercises op::Compare() and the direction-specific matchers (op::Eq,
// op::Ne, op::Le) against instructions built with
// HloInstruction::CreateCompare, including the explanation emitted when the
// direction does not match.
TEST(HloMatchersTest, ComparisonMatcher) {
  // Operands are f32[1]; a comparison yields one predicate per element, so
  // the compare instructions themselves are given a pred[1] shape.
  auto shape = ShapeUtil::MakeShape(F32, {1});
  auto pred_shape = ShapeUtil::MakeShape(PRED, {1});
  auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
  auto eq = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kEq);
  auto ne = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kNe);
  auto add =
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
  auto le = HloInstruction::CreateCompare(pred_shape, p0.get(), add.get(),
                                          ComparisonDirection::kLe);

  // The generic matcher accepts any direction; the specific ones only their
  // own.
  EXPECT_THAT(eq.get(), op::Compare());
  EXPECT_THAT(eq.get(), op::Eq());
  EXPECT_THAT(ne.get(), op::Compare());
  EXPECT_THAT(ne.get(), op::Ne());

  // Operand matchers are honoured by both the generic and the
  // direction-specific form.
  EXPECT_THAT(le.get(),
              op::Compare(op::Parameter(0),
                          op::Add(op::Parameter(0), op::Parameter(1))));
  EXPECT_THAT(le.get(), op::Le(op::Parameter(0),
                               op::Add(op::Parameter(0), op::Parameter(1))));

  // Failure explanations: an opcode mismatch yields no extra text, a
  // direction mismatch names both directions.
  EXPECT_THAT(Explain(eq.get(), op::Add()), Eq(""));
  EXPECT_THAT(Explain(eq.get(), op::Ne()),
              Eq("has wrong comparison direction (got EQ, want NE)"));
}
} // namespace
} // namespace xla

View File

@ -254,8 +254,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
HloInstruction* zero_vector =
cond_builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({0, 0, 0, 0})));
cond_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
cond_builder.AddInstruction(
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_param,
zero_vector, ComparisonDirection::kNe));
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
// param - 1

View File

@ -86,7 +86,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(5)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(0)
@ -125,7 +125,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
loop_var.2 = (s32[], f32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.3 = s32[] constant(5)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.3), direction=LT
}
ENTRY SimpleLoop {
constant.4 = s32[] constant(0)
@ -163,7 +163,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(5)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(0)
@ -206,7 +206,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
loop_var.2 = (s32[], s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
constant.2 = s32[] constant(5)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(0)
@ -248,7 +248,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(5)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
SimpleLoop.body1 {
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
@ -263,7 +263,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
constant.4 = s32[] constant(5)
ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
}
ENTRY SimpleLoop {
constant.5 = s32[] constant(0)
@ -316,7 +316,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
loop_var.2 = (s32[3]{0}, s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
constant.2 = s32[] constant(5)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
SimpleLoop.body1 {
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
@ -331,7 +331,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
constant.4 = s32[] constant(5)
ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
}
ENTRY SimpleLoop {
constant.5 = s32[] constant(0)
@ -383,7 +383,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) {
cond_param = (s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
constant.2 = s32[] constant(10)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(0)
@ -418,7 +418,7 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
cond_param = (s32[], s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
constant.2 = s32[] constant(10)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
p0 = (s32[]) parameter(0)

View File

@ -44,21 +44,8 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
return it->second;
}
#define CHECK_DEFAULT(property_name, opcode_name) false
#define CHECK_PROPERTY(property_name, opcode_name, value) \
(value & property_name)
#define RESOLVE(_1, _2, target, ...) target
#define HAS_PROPERTY(property, ...) \
RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__)
bool HloOpcodeIsComparison(HloOpcode opcode) {
switch (opcode) {
#define CASE_IS_COMPARISON(enum_name, opcode_name, ...) \
case HloOpcode::enum_name: \
return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__);
HLO_OPCODE_LIST(CASE_IS_COMPARISON)
#undef CASE_IS_COMPARISON
}
return opcode == HloOpcode::kCompare;
}
bool HloOpcodeIsVariadic(HloOpcode opcode) {
@ -82,9 +69,4 @@ absl::optional<int> HloOpcodeArity(HloOpcode opcode) {
}
}
#undef HAS_PROPERTY
#undef RESOLVE
#undef CHECK_DEFAULT
#undef CHECK_PROPERTY
} // namespace xla

View File

@ -19,8 +19,10 @@ limitations under the License.
#include <iosfwd>
#include <string>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -65,6 +67,7 @@ namespace xla {
V(kClamp, "clamp", 3) \
V(kCollectivePermute, "collective-permute", 1) \
V(kClz, "count-leading-zeros", 1) \
V(kCompare, "compare", 2) \
V(kComplex, "complex", 2) \
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
V(kConditional, "conditional", kHloOpcodeIsVariadic) \
@ -79,34 +82,28 @@ namespace xla {
V(kDot, "dot", 2) \
V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \
V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \
V(kEq, "equal-to", 2, kHloOpcodeIsComparison) \
V(kExp, "exponential", 1) \
V(kExpm1, "exponential-minus-one", 1) \
V(kFft, "fft", 1) \
V(kFloor, "floor", 1) \
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
V(kGather, "gather", 2) \
V(kGe, "greater-than-or-equal-to", 2, kHloOpcodeIsComparison) \
V(kGetDimensionSize, "get-dimension-size", 1) \
V(kGetTupleElement, "get-tuple-element", 1) \
V(kGt, "greater-than", 2, kHloOpcodeIsComparison) \
V(kImag, "imag", 1) \
V(kInfeed, "infeed", 1) \
V(kIota, "iota", 0) \
V(kIsFinite, "is-finite", 1) \
V(kLe, "less-than-or-equal-to", 2, kHloOpcodeIsComparison) \
V(kLog, "log", 1) \
V(kLog1p, "log-plus-one", 1) \
V(kAnd, "and", 2) \
V(kNot, "not", 1) \
V(kOr, "or", 2) \
V(kXor, "xor", 2) \
V(kLt, "less-than", 2, kHloOpcodeIsComparison) \
V(kMap, "map", kHloOpcodeIsVariadic) \
V(kMaximum, "maximum", 2) \
V(kMinimum, "minimum", 2) \
V(kMultiply, "multiply", 2) \
V(kNe, "not-equal-to", 2, kHloOpcodeIsComparison) \
V(kNegate, "negate", 1) \
V(kOutfeed, "outfeed", 2) \
V(kPad, "pad", 2) \

View File

@ -42,12 +42,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
// Test some properties.
switch (opcode) {
case HloOpcode::kEq:
case HloOpcode::kNe:
case HloOpcode::kGt:
case HloOpcode::kLt:
case HloOpcode::kGe:
case HloOpcode::kLe:
case HloOpcode::kCompare:
EXPECT_TRUE(HloOpcodeIsComparison(opcode));
break;
default:

View File

@ -306,7 +306,7 @@ condition.v4 {
constant.2 = s32[] constant(2)
prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8)
ROOT greater-than = pred[] compare(constant.2, get-tuple-element.8), direction=GT
}
fused_computation {

View File

@ -183,6 +183,7 @@ class HloParser {
kHloComputation,
kBracedHloComputationList,
kFftType,
kComparisonDirection,
kWindow,
kConvolutionDimensionNumbers,
kSharding,
@ -300,6 +301,7 @@ class HloParser {
bool ParseTiles(std::vector<Tile>* tiles);
bool ParseOpcode(HloOpcode* result);
bool ParseFftType(FftType* result);
bool ParseComparisonDirection(ComparisonDirection* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
bool ParsePrecision(PrecisionConfig::Precision* result);
@ -763,12 +765,6 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kSubtract:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
@ -1133,6 +1129,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
shape, operands[0], operands[1], options));
break;
}
case HloOpcode::kCompare: {
optional<ComparisonDirection> direction;
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
&direction};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
shape, operands[0], operands[1], *direction));
break;
}
case HloOpcode::kCholesky: {
CholeskyOptions options;
if (!ParseOperands(&operands, /*expected_size=*/1) ||
@ -2728,6 +2736,15 @@ bool HloParser::ParseAttributeHelper(
static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kComparisonDirection: {
ComparisonDirection result;
if (!ParseComparisonDirection(&result)) {
return false;
}
static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kWindow: {
Window result;
if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
@ -3756,6 +3773,22 @@ bool HloParser::ParseFftType(FftType* result) {
return true;
}
// Parses the current token as a comparison direction. The token must be an
// identifier that StringToComparisonDirection accepts; on success the parsed
// value is stored in `*result`, the lexer advances past the token, and true
// is returned. Otherwise the result of TokenError() is returned and the lexer
// is left where it was.
bool HloParser::ParseComparisonDirection(ComparisonDirection* result) {
  VLOG(1) << "ParseComparisonDirection";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects comparison direction");
  }

  const string token = lexer_.GetStrVal();
  auto parsed = StringToComparisonDirection(token);
  if (parsed.ok()) {
    *result = parsed.ValueOrDie();
    lexer_.Lex();
    return true;
  }
  return TokenError(
      StrFormat("expects comparison direction but sees: %s", token));
}
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
VLOG(1) << "ParseFusionKind";
if (lexer_.GetKind() != TokKind::kIdent) {

View File

@ -222,7 +222,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
}
@ -292,7 +292,7 @@ R"(HloModule WhileWithScalarS32Result_module
%condition.v3 (prev.2: s32[]) -> pred[] {
%constant.1 = s32[] constant(5)
%prev.2 = s32[] parameter(0)
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2)
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
}
ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
@ -474,7 +474,7 @@ R"(HloModule R4F32OverlapSmall_module
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
}
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
@ -500,7 +500,7 @@ R"(HloModule select_and_scatter_scalar
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
}
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
@ -1037,7 +1037,7 @@ R"(HloModule TupleReduce
max_argmax {
value = f32[] parameter(2)
prev_max = f32[] parameter(0)
is_next_larger = pred[] greater-than-or-equal-to(value, prev_max)
is_next_larger = pred[] compare(value, prev_max), direction=GE
max = f32[] select(is_next_larger, value, prev_max)
index = s32[] parameter(3)
prev_argmax = s32[] parameter(1)
@ -1106,7 +1106,7 @@ R"(HloModule sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
@ -1126,7 +1126,7 @@ compare {
p.1.rhs = s32[] parameter(3)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
@ -1145,7 +1145,7 @@ R"(HloModule sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
@ -1165,7 +1165,7 @@ compare {
p.1.rhs = s32[] parameter(3)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
@ -1190,7 +1190,7 @@ compare {
p.3.rhs = f32[] parameter(7)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
@ -1211,7 +1211,7 @@ R"(HloModule sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
@ -1469,7 +1469,7 @@ compare {
p.1.rhs = s32[] parameter(3)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lhs = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
@ -1656,7 +1656,7 @@ TEST_F(HloParserTest, WrongOperandsSize) {
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} equal-to(f32[]{} %x)
%eq = pred[]{} compare(f32[]{} %x), direction=EQ
}
)";
@ -1668,7 +1668,7 @@ TEST_F(HloParserTest, OperandNotFound) {
const string original = R"(HloModule operand_not_found:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
%eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
}
)";
auto result = ParseHloString(original);

View File

@ -228,7 +228,7 @@ HloModule UpdateScheduleWithMultipleComputations
%param = (s32[], token[]) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
%constant = s32[] constant(42)
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}
ENTRY %WhileLoop () -> s32[] {
@ -297,7 +297,7 @@ HloModule UpdateScheduleWithMultipleComputations
%param = (s32[], token[]) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
%constant = s32[] constant(42)
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}
ENTRY %WhileLoop () -> s32[] {

View File

@ -65,6 +65,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kConcatenate:
case HloOpcode::kConstant:
@ -72,21 +73,15 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kCopy:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kEq:
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGetTupleElement:
case HloOpcode::kGt:
case HloOpcode::kImag:
case HloOpcode::kInfeed:
case HloOpcode::kIota:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:

View File

@ -2001,6 +2001,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kConcatenate:
case HloOpcode::kConditional:
@ -2012,24 +2013,18 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kDivide:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kEq:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFft:
case HloOpcode::kFloor:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLe:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kLt:
case HloOpcode::kMap:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:

View File

@ -1084,7 +1084,7 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
counter.1 = s32[] get-tuple-element(tup.1), index=0
five = s32[] constant(5)
ROOT lt = pred[] less-than(counter.1, five)
ROOT lt = pred[] compare(counter.1, five), direction=LT
}
body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {

View File

@ -46,7 +46,7 @@ condition {
condition.state = f32[] parameter(0)
addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget"
add = f32[] add(addend, condition.state)
ROOT greater-than = pred[] greater-than(const.100, add)
ROOT greater-than = pred[] compare(const.100, add), direction=GT
}
ENTRY while3 {

View File

@ -67,6 +67,7 @@ namespace xla {
// - WithOneUse: Instruction is used as an operand exactly once.
// - WithOneUser: Instruction is used by exactly one other instruction, but
// is possibly used more than once as an operand (e.g. multiply(x,x)).
// - WithComparisonDirection: Instruction is a kCompare with the given
// comparison direction.
//
// Shape():
// - EqualTo
@ -1671,6 +1672,40 @@ class HloInstructionPatternOneUserImpl
}
};
// An HloInstructionPattern implementation that accepts only kCompare
// instructions whose comparison direction equals the one supplied at
// construction time.
class HloInstructionPatternComparisonDirectionImpl {
 public:
  explicit constexpr HloInstructionPatternComparisonDirectionImpl(
      ComparisonDirection direction)
      : direction_(direction) {}

  // Both overloads forward to MatchImpl; they differ only in the constness
  // of the instruction pointer they accept.
  bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  bool Match(::xla::HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  // Writes a human-readable description of this constraint to `os`.
  // `indent` is accepted but not used by this implementation.
  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    std::ostream& out = *os;
    out << "which has comparison direction ";
    out << ComparisonDirectionToString(direction_);
  }

 private:
  template <typename HloInstructionType>
  bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
    // The direction is only queried once the opcode is known to be kCompare.
    const bool is_wanted_compare = inst->opcode() == HloOpcode::kCompare &&
                                   inst->comparison_direction() == direction_;
    if (is_wanted_compare) {
      return true;
    }
    EXPLAIN << "HloInstruction is not comparison "
            << ComparisonDirectionToString(direction_);
    return false;
  }

  // The comparison direction a matching instruction must have.
  ComparisonDirection direction_;
};
// Matches a constant scalar or effective scalar, optionally with a given value.
template <typename ScalarTy>
class HloConstantScalarImpl {
@ -1956,6 +1991,14 @@ class HloInstructionPattern {
return AppendImpl(HloInstructionPatternOneUserImpl());
}
// Modifies the pattern to match only if the instruction has the given
// comparison direction.
//
// Appends an HloInstructionPatternComparisonDirectionImpl, which also rejects
// any instruction whose opcode is not kCompare. The trailing return type
// repeats the return expression so that both name the same pattern type.
auto WithComparisonDirection(ComparisonDirection direction) const
-> decltype(this->AppendImpl(
HloInstructionPatternComparisonDirectionImpl(direction))) {
return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
}
void DescribeTo(std::ostream* os, int64 indent = 0) const {
impl_.DescribeTo(os, indent);
}
@ -2118,18 +2161,13 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Add)
XLA_BINOP_PATTERN(Atan2)
XLA_BINOP_PATTERN(Divide)
XLA_BINOP_PATTERN(Complex)
XLA_BINOP_PATTERN(Compare)
XLA_BINOP_PATTERN(Convolution)
XLA_BINOP_PATTERN(Dot)
XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
XLA_BINOP_PATTERN(Gather)
XLA_BINOP_PATTERN(Ge)
XLA_BINOP_PATTERN(Gt)
XLA_BINOP_PATTERN(Le)
XLA_BINOP_PATTERN(Lt)
XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
XLA_COMMUTATIVE_BINOP_PATTERN(Ne)
XLA_BINOP_PATTERN(Outfeed)
XLA_BINOP_PATTERN(Pad)
XLA_BINOP_PATTERN(Power)
@ -2242,6 +2280,73 @@ XLA_VARIADIC_OP_PATTERN(Reduce);
XLA_VARIADIC_OP_PATTERN(Sort);
XLA_VARIADIC_OP_PATTERN(Tuple);
// Helpers for comparison instructions.
//
// XLA_COMPARE_PATTERN(NAME) defines three overloads of a matcher factory
// called NAME. Each one matches a kCompare instruction whose comparison
// direction is ComparisonDirection::k##NAME:
//   - NAME(): places no constraint on the operands.
//   - NAME(lhs, rhs): additionally requires operand 0 to match `lhs` and
//     operand 1 to match `rhs`, in that order.
//   - NAME(matched_inst, lhs, rhs): same constraints as NAME(lhs, rhs), but
//     built from Op(matched_inst) instead of Op().
// In every overload the trailing decltype repeats the returned expression, so
// the two must be kept in sync when editing.
#define XLA_COMPARE_PATTERN(NAME) \
inline auto NAME()->decltype( \
Op().WithOpcode(HloOpcode::kCompare) \
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
return Op() \
.WithOpcode(HloOpcode::kCompare) \
.WithComparisonDirection(ComparisonDirection::k##NAME); \
} \
\
template <typename Lhs, typename Rhs> \
inline auto NAME(Lhs&& lhs, Rhs&& rhs) \
->decltype(Op().WithOpcode(HloOpcode::kCompare) \
.WithOperand(0, std::forward<Lhs>(lhs)) \
.WithOperand(1, std::forward<Rhs>(rhs)) \
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
return Op() \
.WithOpcode(HloOpcode::kCompare) \
.WithOperand(0, std::forward<Lhs>(lhs)) \
.WithOperand(1, std::forward<Rhs>(rhs)) \
.WithComparisonDirection(ComparisonDirection::k##NAME); \
} \
\
template <typename HloInstructionType, typename Lhs, typename Rhs> \
inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
->decltype(Op(matched_inst) \
.WithOpcode(HloOpcode::kCompare) \
.WithOperand(0, std::forward<Lhs>(lhs)) \
.WithOperand(1, std::forward<Rhs>(rhs)) \
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
return Op(matched_inst) \
.WithOpcode(HloOpcode::kCompare) \
.WithOperand(0, std::forward<Lhs>(lhs)) \
.WithOperand(1, std::forward<Rhs>(rhs)) \
.WithComparisonDirection(ComparisonDirection::k##NAME); \
}
// Defines the XLA_COMPARE_PATTERN(NAME) matchers plus NAME##AnyOrder variants
// for comparisons whose result does not depend on operand order.
//
// NAME##AnyOrder(lhs, rhs) and NAME##AnyOrder(matched_inst, lhs, rhs) accept
// the two operand patterns in either order. Like the ordered matchers, they
// require the instruction to be a kCompare whose comparison direction is
// ComparisonDirection::k##NAME; without the direction check, e.g. EqAnyOrder
// would also match a not-equal or less-than comparison.
//
// The two-argument overload forwards to the three-argument one with a null
// capture pointer, so it inherits the same constraints.
#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME)                                 \
  XLA_COMPARE_PATTERN(NAME)                                                   \
                                                                              \
  template <typename HloInstructionType, typename Lhs, typename Rhs>          \
  inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs,    \
                             Rhs&& rhs)                                       \
      ->decltype(Op(matched_inst)                                             \
                     .WithOpcode(HloOpcode::kCompare)                         \
                     .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),      \
                                                 std::forward<Rhs>(rhs))      \
                     .WithComparisonDirection(                                \
                         ComparisonDirection::k##NAME)) {                     \
    return Op(matched_inst)                                                   \
        .WithOpcode(HloOpcode::kCompare)                                      \
        .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                   \
                                    std::forward<Rhs>(rhs))                   \
        .WithComparisonDirection(ComparisonDirection::k##NAME);               \
  }                                                                           \
  template <typename Lhs, typename Rhs>                                       \
  inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs)                            \
      ->decltype(NAME##AnyOrder<const HloInstruction>(                        \
          nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) {         \
    return NAME##AnyOrder<const HloInstruction>(                              \
        nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));             \
  }
// Eq and Ne are instantiated through the commutative macro, so they get
// EqAnyOrder / NeAnyOrder in addition to the ordered Eq / Ne matchers. The
// remaining directions only get the ordered matchers.
XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
XLA_COMPARE_PATTERN(Ge);
XLA_COMPARE_PATTERN(Gt);
XLA_COMPARE_PATTERN(Le);
XLA_COMPARE_PATTERN(Lt);
// Helpers for matching non-constant instructions.
inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
return Op().IsNonConstant();

View File

@ -931,5 +931,48 @@ TEST(PatternMatcherTest, OneUseAndOneUser) {
"in p0 = f32[] parameter(0)");
}
// Exercises the kCompare matchers: m::Compare() accepts any comparison
// direction, whereas the per-direction helpers (m::Eq, m::Ne, m::Le, ...)
// also check the direction and, when given, the operand patterns.
TEST(HloMatchersTest, Comparison) {
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {1});
  auto param0 = HloInstruction::CreateParameter(0, vec_shape, "param.0");
  auto param1 = HloInstruction::CreateParameter(1, vec_shape, "param.1");
  auto cmp_eq = HloInstruction::CreateCompare(
      vec_shape, param0.get(), param1.get(), ComparisonDirection::kEq);
  auto cmp_ne = HloInstruction::CreateCompare(
      vec_shape, param0.get(), param1.get(), ComparisonDirection::kNe);
  auto sum = HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd,
                                          param0.get(), param1.get());
  auto cmp_le = HloInstruction::CreateCompare(
      vec_shape, param0.get(), sum.get(), ComparisonDirection::kLe);

  // Equality comparison: generic, direction-only, ordered and any-order forms.
  EXPECT_TRUE(Match(cmp_eq.get(), m::Compare()));
  EXPECT_TRUE(Match(cmp_eq.get(), m::Eq()));
  EXPECT_TRUE(Match(cmp_eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(
      Match(cmp_eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));

  // Inequality comparison.
  EXPECT_TRUE(Match(cmp_ne.get(), m::Compare()));
  EXPECT_TRUE(Match(cmp_ne.get(), m::Ne()));

  // Less-or-equal comparison with a nested operand pattern.
  EXPECT_TRUE(Match(cmp_le.get(),
                    m::Compare(m::Parameter(0),
                               m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(
      cmp_le.get(),
      m::Le(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));

  // Wrong opcode, wrong direction, and wrong operand order must not match.
  EXPECT_FALSE(Match(cmp_eq.get(), m::Add()));
  EXPECT_FALSE(Match(cmp_eq.get(), m::Ne()));
  EXPECT_FALSE(Match(
      cmp_le.get(),
      m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(cmp_eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));

  // The description and the mismatch explanation both mention the direction.
  EXPECT_DESC_AND_EXPLANATION(
      cmp_eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
} // namespace
} // namespace xla

View File

@ -181,8 +181,9 @@ static StatusOr<HloInstruction*> CheckIndexValidity(
HloInstruction* zero_index =
BroadcastZeros(computation, index->shape().element_type(),
AsInt64Slice(index->shape().dimensions()));
TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check,
MakeBinaryHlo(HloOpcode::kLe, zero_index, index));
TF_ASSIGN_OR_RETURN(
HloInstruction * negative_index_check,
MakeCompareHlo(ComparisonDirection::kLe, zero_index, index));
// Check if the index is OOB w.r.t. the operand dimensions and window sizes.
std::vector<int64> max_valid_index(operand_dims.size());
@ -193,9 +194,9 @@ static StatusOr<HloInstruction*> CheckIndexValidity(
HloInstruction * max_valid_index_constant,
MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
max_valid_index));
TF_ASSIGN_OR_RETURN(
HloInstruction * oob_index_check,
MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index));
TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check,
MakeCompareHlo(ComparisonDirection::kGe,
max_valid_index_constant, index));
// Combine the results of the two checks above.
TF_ASSIGN_OR_RETURN(

View File

@ -988,12 +988,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe: {
case HloOpcode::kCompare: {
TF_ASSIGN_OR_RETURN(const Shape& shape,
InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions));

View File

@ -918,55 +918,10 @@ TEST_F(ShapeInferenceTest, InferPowShape) {
ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
TEST_F(ShapeInferenceTest, InferCompareShape) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {});
auto inferred_status = ShapeInference::InferBinaryOpShape(
HloOpcode::kCompare, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));

View File

@ -39,7 +39,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) {
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort_computation {
@ -73,7 +73,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) {
p.1.rhs = s32[] parameter(3)
p.2.lhs = u32[] parameter(4)
p.2.rhs = u32[] parameter(5)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort_computation {
@ -109,7 +109,7 @@ TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) {
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort_computation {
@ -134,7 +134,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) {
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.1.lhs, p.1.rhs)
ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT
}
ENTRY sort_computation {

View File

@ -180,13 +180,13 @@ StatusOr<HloInstruction*> StableSortExpander::ExpandInstruction(
CHECK_NE(cloned_root, nullptr);
Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
HloInstruction* same =
comparator->AddInstruction(HloInstruction::CreateBinary(
scalar_pred, HloOpcode::kEq, old_root, cloned_root));
comparator->AddInstruction(HloInstruction::CreateCompare(
scalar_pred, old_root, cloned_root, ComparisonDirection::kEq));
HloInstruction* tie_breaker =
comparator->AddInstruction(HloInstruction::CreateBinary(
scalar_pred, HloOpcode::kLt,
comparator->parameter_instruction(2 * iota_index),
comparator->parameter_instruction(2 * iota_index + 1)));
comparator->AddInstruction(HloInstruction::CreateCompare(
scalar_pred, comparator->parameter_instruction(2 * iota_index),
comparator->parameter_instruction(2 * iota_index + 1),
ComparisonDirection::kLt));
HloInstruction* new_root =
comparator->AddInstruction(HloInstruction::CreateTernary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker,

View File

@ -65,7 +65,8 @@ void CheckComputationHasTieBreaker(const HloInstruction* root,
// the copied comparison function where the parameters are reversed. Lt() is
// the tie breaker comparison using the Iota operand.
ASSERT_EQ(root->opcode(), HloOpcode::kSelect);
ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq);
ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kCompare);
ASSERT_EQ(root->operand(0)->comparison_direction(), ComparisonDirection::kEq);
// Check that the tie breaker instruction is correct.
EXPECT_THAT(root->operand(1),
@ -88,7 +89,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) {
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort_computation {
@ -126,15 +127,15 @@ TEST_F(StableSortExpanderTest,
lhs.unsigned = u32[] bitcast-convert(p.0.lhs)
lhs.flipped = u32[] subtract(max, lhs.unsigned)
lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped)
lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero)
lhs.is_negative = pred[] compare(lhs.flipped.signed, zero), direction=LT
lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed)
rhs.signed = s32[] bitcast-convert(p.0.rhs)
rhs.unsigned = u32[] bitcast-convert(p.0.rhs)
rhs.flipped = u32[] subtract(max, rhs.unsigned)
rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped)
rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero)
rhs.is_negative = pred[] compare(rhs.flipped.signed, zero), direction=LT
rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed)
ROOT lt = pred[] less-than(lhs.converted, rhs.converted)
ROOT lt = pred[] compare(lhs.converted, rhs.converted), direction=LT
}
ENTRY sort_computation {
@ -165,7 +166,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) {
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort_computation {
@ -200,7 +201,7 @@ TEST_F(StableSortExpanderTest, HonorIsStableFlag) {
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort_computation {
@ -227,7 +228,7 @@ TEST_F(StableSortExpanderTest,
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort_computation {
@ -264,7 +265,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) {
p.0.rhs = f32[] parameter(1)
p.1.lhs = f32[] parameter(2)
p.1.rhs = f32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY sort_computation {
@ -302,7 +303,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1) {
mask = s32[] constant(65535)
lhs = s32[] and(p.0.lhs, mask)
rhs = s32[] and(p.0.rhs, mask)
ROOT lt = pred[] less-than(lhs, rhs)
ROOT lt = pred[] compare(lhs, rhs), direction=LT
}
ENTRY sort_computation {
@ -332,7 +333,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) {
mask = s32[] constant(65535)
lhs = s32[] and(p.0.lhs, mask)
rhs = s32[] and(p.0.rhs, mask)
ROOT lt = pred[] less-than(lhs, rhs)
ROOT lt = pred[] compare(lhs, rhs), direction=LT
}
ENTRY sort_computation {

View File

@ -934,8 +934,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
HloInstruction::CreateParameter(0, in_shape, "param0"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, in_shape, "param1"));
auto result = builder.AddInstruction(
HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
auto result = builder.AddInstruction(HloInstruction::CreateCompare(
out_shape, param0, param1, ComparisonDirection::kEq));
BuildModuleAndRunAnalysis(builder.Build());
@ -1185,8 +1185,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
builder.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq));
return builder.Build();
};

View File

@ -286,7 +286,7 @@ static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
// Handle `i = K; i < N; ++i`.
if (Match(while_cond_root,
m::Op()
.WithOpcode(HloOpcode::kLt)
.WithComparisonDirection(ComparisonDirection::kLt)
.WithOperand(0, m::Op().Is(while_cond_indvar)))) {
VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
<< while_cond_root->ToString();
@ -303,7 +303,7 @@ static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
// Handle `i = K; i <= N; ++i`.
if (Match(while_cond_root,
m::Op()
.WithOpcode(HloOpcode::kLe)
.WithComparisonDirection(ComparisonDirection::kLe)
.WithOperand(0, m::Op().Is(while_cond_indvar)))) {
VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
<< while_cond_root->ToString();

View File

@ -40,7 +40,7 @@ TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) {
p_cond = (f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=1
const = s32[] constant(42)
ROOT result = pred[] equal-to(gte, const)
ROOT result = pred[] compare(gte, const), direction=EQ
}
ENTRY entry {
@ -71,7 +71,7 @@ TEST_F(WhileLoopAnalysisTest, NoUpperBound) {
p_cond = (f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=1
const = s32[] constant(42)
ROOT result = pred[] equal-to(gte, const)
ROOT result = pred[] compare(gte, const), direction=EQ
}
ENTRY entry {
@ -104,7 +104,7 @@ TEST_F(WhileLoopAnalysisTest, ExactBound) {
p_cond = (f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=1
const = s32[] constant(42)
ROOT result = pred[] less-than(gte, const)
ROOT result = pred[] compare(gte, const), direction=LT
}
ENTRY entry {

View File

@ -260,7 +260,7 @@ condition {
p_cond = (f32[],f32[]) parameter(0)
p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0
p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1
ROOT result = pred[] less-than(p_cond.0, p_cond.1)
ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
}
ENTRY entry {
@ -300,7 +300,7 @@ condition {
p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0
p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1
p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1
ROOT result = pred[] less-than(p_c.0, p_c.1.1)
ROOT result = pred[] compare(p_c.0, p_c.1.1), direction=LT
}
ENTRY entry {
@ -342,7 +342,7 @@ condition {
p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
ROOT result = pred[] less-than(p_cond.0, p_cond.1)
ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
}
ENTRY entry {
@ -389,10 +389,10 @@ condition {
p_cond = (f32[],f32[],f32[]) parameter(0)
p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
lt.0 = pred[] less-than(p_cond.0, p_cond.2)
lt.0 = pred[] compare(p_cond.0, p_cond.2), direction=LT
p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
lt.1 = pred[] less-than(p_cond.1, p_cond.2.c)
lt.1 = pred[] compare(p_cond.1, p_cond.2.c), direction=LT
ROOT result = pred[] and(lt.0, lt.1)
}

View File

@ -556,7 +556,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) {
p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=3
const = s32[] constant(42)
ROOT result = pred[] equal-to(gte, const)
ROOT result = pred[] compare(gte, const), direction=EQ
}
ENTRY entry {

View File

@ -72,7 +72,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant({{LOOP_BOUND}})
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(42)
@ -107,7 +107,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4)
ROOT less-than = pred[] compare(get-tuple-element.3, get-tuple-element.4), direction=LT
}
ENTRY SimpleLoopWithIndirectLoopBound {
constant.3 = s32[] constant(42)
@ -237,7 +237,7 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) {
NonTupleShapedLoop.condition {
loop_var = s32[] parameter(0)
constant = s32[] constant(100)
ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant)
ROOT less-than = pred[] compare(s32[] loop_var, s32[] constant), direction=LT
}
ENTRY INonTupleShapedLoop {
constant.2 = s32[] constant(42)
@ -387,7 +387,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
param0 = (s32[], s32[], s32[]) parameter(0)
get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0),
index=2
ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element)
ROOT equal-to = pred[] compare(s32[] constant.2, s32[] get-tuple-element), direction=EQ
}
ENTRY RemoveUnusedOperands {
x = s32[] parameter(0)
@ -471,7 +471,7 @@ TEST_F(WhileLoopSimplifierTest,
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(44)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(42)
@ -503,7 +503,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) {
loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0)
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(47)
ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2)
ROOT less-than = pred[] compare(get-tuple-element.4, constant.2), direction=LT
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(42)
@ -679,7 +679,7 @@ const char* const kSimpleMergeInductionVariablesModule = R"(
b = TYPE[] get-tuple-element(param), index=1
sum = TYPE[] power(a, b)
ten = TYPE[] constant(10)
ROOT cond = pred[] less-than(sum, ten)
ROOT cond = pred[] compare(sum, ten), direction=LT
}
ENTRY Loop {
a = TYPE[] constant(10)

View File

@ -41,7 +41,7 @@ TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) {
param = (s32[]) parameter(0)
i = s32[] get-tuple-element(param), index=0
trip_count = s32[] constant(10)
ROOT done = pred[] less-than(i, trip_count)
ROOT done = pred[] compare(i, trip_count), direction=LT
}
ENTRY test {
@ -77,7 +77,7 @@ TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) {
param = (s32[]) parameter(0)
i = s32[] get-tuple-element(param), index=0
trip_count = s32[] constant(1000000)
ROOT done = pred[] less-than(i, trip_count)
ROOT done = pred[] compare(i, trip_count), direction=LT
}
ENTRY test {
@ -113,7 +113,7 @@ TEST_F(TripCountAnnotatorTest, NonzeroStart) {
param = (s32[]) parameter(0)
i = s32[] get-tuple-element(param), index=0
trip_count = s32[] constant(1000000)
ROOT done = pred[] less-than(i, trip_count)
ROOT done = pred[] compare(i, trip_count), direction=LT
}
ENTRY test {
@ -149,7 +149,7 @@ TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) {
param = (s32[]) parameter(0)
i = s32[] get-tuple-element(param), index=0
trip_count = s32[] constant(1000000)
ROOT done = pred[] less-than-or-equal-to(i, trip_count)
ROOT done = pred[] compare(i, trip_count), direction=LE
}
ENTRY test {
@ -188,7 +188,7 @@ TEST_F(TripCountAnnotatorTest, Int64Overflow) {
param = (s64[]) parameter(0)
i = s64[] get-tuple-element(param), index=0
trip_count = s64[] constant(9223372036854775807) // 2^63-1
ROOT done = pred[] less-than-or-equal-to(i, trip_count)
ROOT done = pred[] compare(i, trip_count), direction=LE
}
ENTRY test {

View File

@ -166,7 +166,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
TF_ASSIGN_OR_RETURN(
HloInstruction * compare,
MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant));
MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
cond_computation->set_root_instruction(compare);
return std::move(cond_computation);
}

View File

@ -63,7 +63,11 @@ const float test_float_vals[3][test_width][test_height] = {
class FusionTest : public HloTestBase {
protected:
template <typename T, int Arity>
void TestElementwise2D(HloOpcode opcode) {
void TestElementwise2D(
HloOpcode opcode,
absl::optional<ComparisonDirection> direction = absl::nullopt) {
// Create a variable for comparisons since they require the direction.
bool is_compare = std::is_same<T, bool>::value;
Array2D<float> operand_data[Arity];
for (int i = 0; i < Arity; ++i) {
new (&operand_data[i]) Array2D<float>(test_width, test_height);
@ -76,7 +80,11 @@ class FusionTest : public HloTestBase {
xs[k] = test_float_vals[k][i][j];
operand_data[k](i, j) = xs[k];
}
answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs);
if (is_compare) {
answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs);
} else {
answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs);
}
}
}
@ -98,8 +106,13 @@ class FusionTest : public HloTestBase {
root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
break;
case 2:
root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
hlos[2]);
if (is_compare) {
root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1],
hlos[2], *direction);
} else {
root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
hlos[2]);
}
break;
case 3:
root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
@ -124,13 +137,14 @@ class FusionTest : public HloTestBase {
}
private:
template <typename T>
T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span<const float> xs);
float ComputeElementwiseAnswerFloat(HloOpcode opcode,
absl::Span<const float> xs);
bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
absl::Span<const float> xs);
};
template <>
float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
absl::Span<const float> xs) {
float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode,
absl::Span<const float> xs) {
switch (opcode) {
case HloOpcode::kAdd:
return xs[0] + xs[1];
@ -153,24 +167,21 @@ float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
}
}
template <>
bool FusionTest::ComputeElementwiseAnswer<bool>(HloOpcode opcode,
absl::Span<const float> xs) {
switch (opcode) {
case HloOpcode::kEq:
bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction,
absl::Span<const float> xs) {
switch (direction) {
case ComparisonDirection::kEq:
return xs[0] == xs[1];
case HloOpcode::kNe:
case ComparisonDirection::kNe:
return xs[0] != xs[1];
case HloOpcode::kGt:
case ComparisonDirection::kGt:
return xs[0] > xs[1];
case HloOpcode::kLt:
case ComparisonDirection::kLt:
return xs[0] < xs[1];
case HloOpcode::kGe:
case ComparisonDirection::kGe:
return xs[0] >= xs[1];
case HloOpcode::kLe:
case ComparisonDirection::kLe:
return xs[0] <= xs[1];
default:
LOG(FATAL) << "No comparatory opcode: " << opcode;
}
}
@ -740,24 +751,28 @@ XLA_TEST_F(FusionTest, Maximum2D) {
TestElementwise2D<float, 2>(HloOpcode::kMaximum);
}
XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<bool, 2>(HloOpcode::kEq); }
XLA_TEST_F(FusionTest, Equal2D) {
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kEq);
}
XLA_TEST_F(FusionTest, Inequal2D) {
TestElementwise2D<bool, 2>(HloOpcode::kNe);
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kNe);
}
XLA_TEST_F(FusionTest, Greater2D) {
TestElementwise2D<bool, 2>(HloOpcode::kGt);
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGt);
}
XLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D<bool, 2>(HloOpcode::kLt); }
XLA_TEST_F(FusionTest, Lesser2D) {
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLt);
}
XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
TestElementwise2D<bool, 2>(HloOpcode::kGe);
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGe);
}
XLA_TEST_F(FusionTest, LesserOrEqual2D) {
TestElementwise2D<bool, 2>(HloOpcode::kLe);
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLe);
}
XLA_TEST_F(FusionTest, Clamp2D) {

View File

@ -227,7 +227,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
fused_computation {
p = f32[4] parameter(0)
multiply = f32[4] multiply(p, p)
less-than = pred[4] less-than(p, multiply)
less-than = pred[4] compare(p, multiply), direction=LT
ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
}
@ -252,7 +252,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
fused_computation {
p = f32[] parameter(0)
multiply = f32[] multiply(p, p)
less-than = pred[] less-than(p, multiply)
less-than = pred[] compare(p, multiply), direction=LT
ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
}

View File

@ -143,7 +143,7 @@ compare {
p.0.rhs = f32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) {
@ -174,7 +174,7 @@ compare {
p.0.rhs = s32[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) {
@ -205,7 +205,7 @@ compare {
p.0.rhs = bf16[] parameter(1)
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) {

View File

@ -129,7 +129,7 @@ HloModule TokenInWhileLoop
%param = (s32[], token[]) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
%constant = s32[] constant(42)
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}
ENTRY %TokenInWhileLoop () -> s32[] {