[XLA] Replace individual comparison HLO ops with a single compare op.
PiperOrigin-RevId: 237318727
This commit is contained in:
parent
add7a1a911
commit
57467ada28
@ -57,6 +57,24 @@ xla_proto_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "comparison_util",
|
||||||
|
srcs = [
|
||||||
|
"comparison_util.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"comparison_util.h",
|
||||||
|
],
|
||||||
|
visibility = [":friends"],
|
||||||
|
deps = [
|
||||||
|
":statusor",
|
||||||
|
":types",
|
||||||
|
":util",
|
||||||
|
"@com_google_absl//absl/base:core_headers",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "execution_options_util",
|
name = "execution_options_util",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -212,6 +212,7 @@ cc_library(
|
|||||||
":padding",
|
":padding",
|
||||||
":sharding_builder",
|
":sharding_builder",
|
||||||
":xla_computation",
|
":xla_computation",
|
||||||
|
"//tensorflow/compiler/xla:comparison_util",
|
||||||
"//tensorflow/compiler/xla:execution_options_util",
|
"//tensorflow/compiler/xla:execution_options_util",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
|
@ -480,7 +480,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
absl::optional<ComparisonDirection> direction) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
HloInstructionProto instr;
|
||||||
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
|
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
|
||||||
@ -489,6 +490,17 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
|||||||
ShapeInference::InferBinaryOpShape(
|
ShapeInference::InferBinaryOpShape(
|
||||||
binop, lhs_shape, rhs_shape, broadcast_dimensions));
|
binop, lhs_shape, rhs_shape, broadcast_dimensions));
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
|
if (binop == HloOpcode::kCompare) {
|
||||||
|
if (!direction.has_value()) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"kCompare expects a ComparisonDirection, but none provided.");
|
||||||
|
}
|
||||||
|
instr.set_comparison_direction(ComparisonDirectionToString(*direction));
|
||||||
|
} else if (direction.has_value()) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"A comparison direction is provided for a non-compare opcode: %s.",
|
||||||
|
HloOpcodeString(binop));
|
||||||
|
}
|
||||||
|
|
||||||
const int64 lhs_rank = lhs_shape.rank();
|
const int64 lhs_rank = lhs_shape.rank();
|
||||||
const int64 rhs_rank = rhs_shape.rank();
|
const int64 rhs_rank = rhs_shape.rank();
|
||||||
@ -2908,38 +2920,39 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
|
|||||||
|
|
||||||
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return lhs.builder()->BinaryOp(HloOpcode::kEq, lhs, rhs,
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq);
|
||||||
broadcast_dimensions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return lhs.builder()->BinaryOp(HloOpcode::kNe, lhs, rhs,
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe);
|
||||||
broadcast_dimensions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return lhs.builder()->BinaryOp(HloOpcode::kGe, lhs, rhs,
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe);
|
||||||
broadcast_dimensions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return lhs.builder()->BinaryOp(HloOpcode::kGt, lhs, rhs,
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt);
|
||||||
broadcast_dimensions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return lhs.builder()->BinaryOp(HloOpcode::kLe, lhs, rhs,
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe);
|
||||||
broadcast_dimensions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return lhs.builder()->BinaryOp(HloOpcode::kLt, lhs, rhs,
|
return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt);
|
||||||
broadcast_dimensions);
|
}
|
||||||
|
|
||||||
|
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
ComparisonDirection direction) {
|
||||||
|
return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs,
|
||||||
|
broadcast_dimensions, direction);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/client/padding.h"
|
#include "tensorflow/compiler/xla/client/padding.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
|
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
|
||||||
@ -596,9 +597,11 @@ class XlaBuilder {
|
|||||||
|
|
||||||
// Internal helper method that does the building for an arbitrary binary op.
|
// Internal helper method that does the building for an arbitrary binary op.
|
||||||
// broadcast_dimensions specifies which dimensions to use for broadcasting
|
// broadcast_dimensions specifies which dimensions to use for broadcasting
|
||||||
// when the operation is between tensors of different ranks.
|
// when the operation is between tensors of different ranks. The direction is
|
||||||
|
// only used if opcode is kCompare.
|
||||||
XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
absl::optional<ComparisonDirection> direction = absl::nullopt);
|
||||||
|
|
||||||
// Internal helper method that does the building for an arbitrary ternary op.
|
// Internal helper method that does the building for an arbitrary ternary op.
|
||||||
XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
|
||||||
@ -767,6 +770,9 @@ class XlaBuilder {
|
|||||||
absl::Span<const int64> broadcast_dimensions);
|
absl::Span<const int64> broadcast_dimensions);
|
||||||
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
absl::Span<const int64> broadcast_dimensions);
|
||||||
|
friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
ComparisonDirection direction);
|
||||||
friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
const PrecisionConfig* precision_config);
|
const PrecisionConfig* precision_config);
|
||||||
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
@ -1279,6 +1285,11 @@ XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
|
|||||||
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
|
|
||||||
|
// Enqueues a comparison instruction onto the computation.
|
||||||
|
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
|
absl::Span<const int64> broadcast_dimensions,
|
||||||
|
ComparisonDirection direction);
|
||||||
|
|
||||||
// Enqueues a dot instruction onto the computation.
|
// Enqueues a dot instruction onto the computation.
|
||||||
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
const PrecisionConfig* precision_config = nullptr);
|
const PrecisionConfig* precision_config = nullptr);
|
||||||
|
57
tensorflow/compiler/xla/comparison_util.cc
Normal file
57
tensorflow/compiler/xla/comparison_util.cc
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
std::string ComparisonDirectionToString(ComparisonDirection direction) {
|
||||||
|
switch (direction) {
|
||||||
|
case ComparisonDirection::kEq:
|
||||||
|
return "EQ";
|
||||||
|
case ComparisonDirection::kNe:
|
||||||
|
return "NE";
|
||||||
|
case ComparisonDirection::kGe:
|
||||||
|
return "GE";
|
||||||
|
case ComparisonDirection::kGt:
|
||||||
|
return "GT";
|
||||||
|
case ComparisonDirection::kLe:
|
||||||
|
return "LE";
|
||||||
|
case ComparisonDirection::kLt:
|
||||||
|
return "LT";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<ComparisonDirection> StringToComparisonDirection(
|
||||||
|
absl::string_view direction_name) {
|
||||||
|
static auto* direction_map =
|
||||||
|
new absl::flat_hash_map<string, ComparisonDirection>({
|
||||||
|
{"EQ", ComparisonDirection::kEq},
|
||||||
|
{"NE", ComparisonDirection::kNe},
|
||||||
|
{"GE", ComparisonDirection::kGe},
|
||||||
|
{"GT", ComparisonDirection::kGt},
|
||||||
|
{"LE", ComparisonDirection::kLe},
|
||||||
|
{"LT", ComparisonDirection::kLt},
|
||||||
|
});
|
||||||
|
auto it = direction_map->find(direction_name);
|
||||||
|
if (it == direction_map->end()) {
|
||||||
|
return InvalidArgument("Unknown comparison direction: %s", direction_name);
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
42
tensorflow/compiler/xla/comparison_util.h
Normal file
42
tensorflow/compiler/xla/comparison_util.h
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
|
||||||
|
|
||||||
|
#include "absl/base/macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Represents different comparison operations.
|
||||||
|
enum class ComparisonDirection : uint8 {
|
||||||
|
kEq,
|
||||||
|
kNe,
|
||||||
|
kGe,
|
||||||
|
kGt,
|
||||||
|
kLe,
|
||||||
|
kLt,
|
||||||
|
};
|
||||||
|
|
||||||
|
string ComparisonDirectionToString(ComparisonDirection direction);
|
||||||
|
|
||||||
|
StatusOr<ComparisonDirection> StringToComparisonDirection(
|
||||||
|
absl::string_view direction_name);
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
|
@ -334,6 +334,7 @@ cc_library(
|
|||||||
":hlo_proto",
|
":hlo_proto",
|
||||||
":name_uniquer",
|
":name_uniquer",
|
||||||
"//tensorflow/compiler/xla:array",
|
"//tensorflow/compiler/xla:array",
|
||||||
|
"//tensorflow/compiler/xla:comparison_util",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:protobuf_util",
|
"//tensorflow/compiler/xla:protobuf_util",
|
||||||
@ -3301,6 +3302,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:xla_proto",
|
"//tensorflow/compiler/xla:xla_proto",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:test_utils",
|
"//tensorflow/compiler/xla/tests:test_utils",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -873,9 +873,9 @@ std::unique_ptr<HloInstruction> TryDivideToShift(HloInstruction* divide,
|
|||||||
computation, a->shape().element_type(), a->shape().dimensions());
|
computation, a->shape().element_type(), a->shape().dimensions());
|
||||||
|
|
||||||
auto* dividend_is_negative =
|
auto* dividend_is_negative =
|
||||||
computation->AddInstruction(HloInstruction::CreateBinary(
|
computation->AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a,
|
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
|
||||||
zero_like_a));
|
ComparisonDirection::kLt));
|
||||||
|
|
||||||
auto* negated_dividend = computation->AddInstruction(
|
auto* negated_dividend = computation->AddInstruction(
|
||||||
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
|
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
|
||||||
@ -2475,9 +2475,9 @@ std::unique_ptr<HloInstruction> TryRemainderToAnd(HloInstruction* remainder,
|
|||||||
computation, a->shape().element_type(), a->shape().dimensions());
|
computation, a->shape().element_type(), a->shape().dimensions());
|
||||||
|
|
||||||
auto* dividend_is_negative =
|
auto* dividend_is_negative =
|
||||||
computation->AddInstruction(HloInstruction::CreateBinary(
|
computation->AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::ChangeElementType(a->shape(), PRED), HloOpcode::kLt, a,
|
ShapeUtil::ChangeElementType(a->shape(), PRED), a, zero_like_a,
|
||||||
zero_like_a));
|
ComparisonDirection::kLt));
|
||||||
|
|
||||||
auto* negated_dividend = computation->AddInstruction(
|
auto* negated_dividend = computation->AddInstruction(
|
||||||
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
|
HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
|
||||||
|
@ -221,7 +221,7 @@ HloModule foobar
|
|||||||
%x = (f32[2,2], f32[2,2]) parameter(0)
|
%x = (f32[2,2], f32[2,2]) parameter(0)
|
||||||
%constant.0 = s32[] constant(0)
|
%constant.0 = s32[] constant(0)
|
||||||
%constant.1 = s32[] constant(1)
|
%constant.1 = s32[] constant(1)
|
||||||
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
|
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
|
||||||
}
|
}
|
||||||
|
|
||||||
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
||||||
@ -258,7 +258,7 @@ HloModule foobar
|
|||||||
%x = (f32[2,2], f32[2,2]) parameter(0)
|
%x = (f32[2,2], f32[2,2]) parameter(0)
|
||||||
%constant.0 = s32[] constant(0)
|
%constant.0 = s32[] constant(0)
|
||||||
%constant.1 = s32[] constant(1)
|
%constant.1 = s32[] constant(1)
|
||||||
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
|
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
|
||||||
}
|
}
|
||||||
|
|
||||||
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
||||||
@ -296,7 +296,7 @@ HloModule foobar
|
|||||||
%x = (f32[2,2], f32[2,2]) parameter(0)
|
%x = (f32[2,2], f32[2,2]) parameter(0)
|
||||||
%constant.0 = s32[] constant(0)
|
%constant.0 = s32[] constant(0)
|
||||||
%constant.1 = s32[] constant(1)
|
%constant.1 = s32[] constant(1)
|
||||||
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0)
|
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
|
||||||
}
|
}
|
||||||
|
|
||||||
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
|
||||||
|
@ -109,8 +109,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
|
|||||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
|
||||||
HloInstruction* add1 = builder.AddInstruction(
|
HloInstruction* add1 = builder.AddInstruction(
|
||||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
|
||||||
HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary(
|
HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b));
|
ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq));
|
||||||
HloInstruction* sel = builder.AddInstruction(
|
HloInstruction* sel = builder.AddInstruction(
|
||||||
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
|
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
|
||||||
HloInstruction* xpose =
|
HloInstruction* xpose =
|
||||||
@ -574,8 +574,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
|
|||||||
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
||||||
auto cond_dot =
|
auto cond_dot =
|
||||||
builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
|
builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
|
||||||
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary(
|
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
ShapeUtil::MakeShape(PRED, {}),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond.AddInstruction(
|
builder_cond.AddInstruction(
|
||||||
@ -583,9 +583,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
|
|||||||
cond_dot, {0, 0}, {1, 1}, {1, 1})))),
|
cond_dot, {0, 0}, {1, 1}, {1, 1})))),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
builder_cond.AddInstruction(
|
||||||
ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
|
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
|
||||||
{1, 1}))))));
|
cond_dot, {1, 1}, {2, 2}, {1, 1})))),
|
||||||
|
ComparisonDirection::kGt));
|
||||||
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
||||||
|
|
||||||
auto builder_body = HloComputation::Builder("body");
|
auto builder_body = HloComputation::Builder("body");
|
||||||
@ -631,8 +632,8 @@ TEST_F(BFloat16PropagationTest,
|
|||||||
auto builder_cond = HloComputation::Builder("cond");
|
auto builder_cond = HloComputation::Builder("cond");
|
||||||
auto cond_param = builder_cond.AddInstruction(
|
auto cond_param = builder_cond.AddInstruction(
|
||||||
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateBinary(
|
builder_cond.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
ShapeUtil::MakeShape(PRED, {}),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
||||||
@ -642,7 +643,8 @@ TEST_F(BFloat16PropagationTest,
|
|||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
||||||
ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
|
ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
|
||||||
{1, 1}))))));
|
{1, 1})))),
|
||||||
|
ComparisonDirection::kGt));
|
||||||
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
||||||
|
|
||||||
auto builder_body = HloComputation::Builder("body");
|
auto builder_body = HloComputation::Builder("body");
|
||||||
@ -705,8 +707,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
|
|||||||
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
|
||||||
auto cond_dot =
|
auto cond_dot =
|
||||||
builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
|
builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateBinary(
|
builder_cond.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
ShapeUtil::MakeShape(PRED, {}),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond.AddInstruction(
|
builder_cond.AddInstruction(
|
||||||
@ -714,9 +716,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
|
|||||||
cond_dot, {0, 0}, {1, 1}, {1, 1})))),
|
cond_dot, {0, 0}, {1, 1}, {1, 1})))),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond.AddInstruction(HloInstruction::CreateSlice(
|
builder_cond.AddInstruction(
|
||||||
ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
|
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
|
||||||
{1, 1}))))));
|
cond_dot, {1, 1}, {2, 2}, {1, 1})))),
|
||||||
|
ComparisonDirection::kGt));
|
||||||
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
|
||||||
|
|
||||||
auto builder_body = HloComputation::Builder("body");
|
auto builder_body = HloComputation::Builder("body");
|
||||||
@ -800,8 +803,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
|
|||||||
shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
|
shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
|
||||||
auto cond0_dot =
|
auto cond0_dot =
|
||||||
builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
|
builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
|
||||||
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
|
builder_cond0.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
ShapeUtil::MakeShape(PRED, {}),
|
||||||
builder_cond0.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond0.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond0.AddInstruction(
|
builder_cond0.AddInstruction(
|
||||||
@ -809,9 +812,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
|
|||||||
cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
|
cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
|
||||||
builder_cond0.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond0.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond0.AddInstruction(HloInstruction::CreateSlice(
|
builder_cond0.AddInstruction(
|
||||||
ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2},
|
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
|
||||||
{1, 1}))))));
|
cond0_dot, {1, 1}, {2, 2}, {1, 1})))),
|
||||||
|
ComparisonDirection::kGt));
|
||||||
auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
|
auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
|
||||||
|
|
||||||
// Condition computation for the second while.
|
// Condition computation for the second while.
|
||||||
@ -828,8 +832,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
|
|||||||
shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
|
shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
|
||||||
auto cond1_dot =
|
auto cond1_dot =
|
||||||
builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
|
builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
|
||||||
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
|
builder_cond1.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
|
ShapeUtil::MakeShape(PRED, {}),
|
||||||
builder_cond1.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond1.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond1.AddInstruction(
|
builder_cond1.AddInstruction(
|
||||||
@ -837,9 +841,10 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
|
|||||||
cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
|
cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
|
||||||
builder_cond1.AddInstruction(HloInstruction::CreateReshape(
|
builder_cond1.AddInstruction(HloInstruction::CreateReshape(
|
||||||
ShapeUtil::MakeShape(F32, {}),
|
ShapeUtil::MakeShape(F32, {}),
|
||||||
builder_cond1.AddInstruction(HloInstruction::CreateSlice(
|
builder_cond1.AddInstruction(
|
||||||
ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2},
|
HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
|
||||||
{1, 1}))))));
|
cond1_dot, {1, 1}, {2, 2}, {1, 1})))),
|
||||||
|
ComparisonDirection::kGt));
|
||||||
auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
|
auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
|
||||||
|
|
||||||
// Body computation shared by both whiles.
|
// Body computation shared by both whiles.
|
||||||
|
@ -190,8 +190,9 @@ class BufferAssignmentTest : public HloTestBase {
|
|||||||
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
|
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
|
||||||
auto index = builder.AddInstruction(
|
auto index = builder.AddInstruction(
|
||||||
HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
|
HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
|
||||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
builder.AddInstruction(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4));
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
|
||||||
|
const4, ComparisonDirection::kLt));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1863,8 +1864,8 @@ class WhileBufferAssignmentTest : public HloTestBase {
|
|||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
|
||||||
auto ten = builder.AddInstruction(
|
auto ten = builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
|
||||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
|
ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2135,8 +2136,9 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
|
|||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
|
||||||
auto param =
|
auto param =
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
|
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
|
||||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
builder.AddInstruction(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4));
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
|
||||||
|
const4, ComparisonDirection::kLt));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2530,7 +2532,7 @@ while_condition {
|
|||||||
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
|
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
|
||||||
get-tuple-element = s32[] get-tuple-element(state), index=0
|
get-tuple-element = s32[] get-tuple-element(state), index=0
|
||||||
get-tuple-element.1 = s32[] constant(3)
|
get-tuple-element.1 = s32[] constant(3)
|
||||||
ROOT less-than.339.338 = pred[] less-than(get-tuple-element, get-tuple-element.1)
|
ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry_computation {
|
ENTRY entry_computation {
|
||||||
|
@ -83,8 +83,9 @@ class CallGraphTest : public HloTestBase {
|
|||||||
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
|
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
|
||||||
HloInstruction* zero = builder.AddInstruction(
|
HloInstruction* zero = builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
builder.AddInstruction(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
|
||||||
|
zero, ComparisonDirection::kGt));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,8 +191,9 @@ HloInstruction* GetExpandedFilterMask(
|
|||||||
// linspace to create a diagonal predicate.
|
// linspace to create a diagonal predicate.
|
||||||
Shape predicate_shape = ShapeUtil::MakeShape(
|
Shape predicate_shape = ShapeUtil::MakeShape(
|
||||||
PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
|
PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
|
||||||
return add_instruction(HloInstruction::CreateBinary(
|
return add_instruction(HloInstruction::CreateCompare(
|
||||||
predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2));
|
predicate_shape, broadcasted_mask1, broadcasted_mask2,
|
||||||
|
ComparisonDirection::kEq));
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function handles batch_group_counts which are relevant only for
|
// This function handles batch_group_counts which are relevant only for
|
||||||
|
@ -420,9 +420,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
|||||||
auto induction_variable =
|
auto induction_variable =
|
||||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||||
limit_const->shape(), loop_state, 0));
|
limit_const->shape(), loop_state, 0));
|
||||||
builder.AddInstruction(
|
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
|
condition_result_shape_, induction_variable, limit_const,
|
||||||
induction_variable, limit_const));
|
ComparisonDirection::kLt));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1842,7 +1842,7 @@ HloModule TokensShouldNotBeCopied
|
|||||||
%param = (s32[], token[]) parameter(0)
|
%param = (s32[], token[]) parameter(0)
|
||||||
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
||||||
%constant = s32[] constant(42)
|
%constant = s32[] constant(42)
|
||||||
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
|
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %TokensShouldNotBeCopied () -> s32[] {
|
ENTRY %TokensShouldNotBeCopied () -> s32[] {
|
||||||
@ -2060,7 +2060,7 @@ if-condition.v4 {
|
|||||||
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
||||||
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
|
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
|
||||||
constant.4 = s32[] constant(0)
|
constant.4 = s32[] constant(0)
|
||||||
ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
|
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
|
||||||
}
|
}
|
||||||
|
|
||||||
_functionalize_body_1__.v28 {
|
_functionalize_body_1__.v28 {
|
||||||
@ -2070,7 +2070,7 @@ _functionalize_body_1__.v28 {
|
|||||||
add.4 = s32[] add(get-tuple-element.68, constant.7)
|
add.4 = s32[] add(get-tuple-element.68, constant.7)
|
||||||
get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
|
get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
|
||||||
get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
|
get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
|
||||||
less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70)
|
less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
|
||||||
constant.8 = s32[] constant(0)
|
constant.8 = s32[] constant(0)
|
||||||
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
|
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
|
||||||
get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
|
get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
|
||||||
@ -2087,7 +2087,7 @@ cond_wrapper.v3.1 {
|
|||||||
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
|
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
|
||||||
get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
|
get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
|
||||||
constant.11 = s32[] constant(7)
|
constant.11 = s32[] constant(7)
|
||||||
ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11)
|
ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
_functionalize_body_2__.v25 {
|
_functionalize_body_2__.v25 {
|
||||||
@ -2110,7 +2110,7 @@ cond_wrapper.v3.2 {
|
|||||||
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
|
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
|
||||||
get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
|
get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
|
||||||
constant.13 = s32[] constant(5)
|
constant.13 = s32[] constant(5)
|
||||||
ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13)
|
ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY TestComputation {
|
ENTRY TestComputation {
|
||||||
@ -2142,7 +2142,7 @@ if-condition.v4 {
|
|||||||
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
||||||
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
|
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
|
||||||
constant.4 = s32[] constant(0)
|
constant.4 = s32[] constant(0)
|
||||||
ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4)
|
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
|
||||||
}
|
}
|
||||||
|
|
||||||
if-body.v5.1 {
|
if-body.v5.1 {
|
||||||
@ -2159,7 +2159,7 @@ if-condition.v4.1 {
|
|||||||
p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
|
||||||
get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
|
get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
|
||||||
constant.6 = s32[] constant(1)
|
constant.6 = s32[] constant(1)
|
||||||
ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6)
|
ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
|
||||||
}
|
}
|
||||||
|
|
||||||
_functionalize_body_1__.v28 {
|
_functionalize_body_1__.v28 {
|
||||||
@ -2169,7 +2169,7 @@ _functionalize_body_1__.v28 {
|
|||||||
add.4 = s32[] add(get-tuple-element.72, constant.7)
|
add.4 = s32[] add(get-tuple-element.72, constant.7)
|
||||||
get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
|
get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
|
||||||
get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
|
get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
|
||||||
less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74)
|
less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
|
||||||
constant.8 = s32[] constant(0)
|
constant.8 = s32[] constant(0)
|
||||||
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
|
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
|
||||||
get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
|
get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
|
||||||
@ -2187,7 +2187,7 @@ cond_wrapper.v3.1 {
|
|||||||
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
|
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
|
||||||
get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
|
get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
|
||||||
constant.11 = s32[] constant(7)
|
constant.11 = s32[] constant(7)
|
||||||
ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11)
|
ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
_functionalize_body_2__.v25 {
|
_functionalize_body_2__.v25 {
|
||||||
@ -2210,7 +2210,7 @@ cond_wrapper.v3.2 {
|
|||||||
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
|
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
|
||||||
get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
|
get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
|
||||||
constant.13 = s32[] constant(5)
|
constant.13 = s32[] constant(5)
|
||||||
ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13)
|
ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY TestComputation {
|
ENTRY TestComputation {
|
||||||
|
@ -29,7 +29,7 @@ HloModule KeyValueSort
|
|||||||
compare {
|
compare {
|
||||||
p.0.lhs = f32[] parameter(0)
|
p.0.lhs = f32[] parameter(0)
|
||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY main {
|
ENTRY main {
|
||||||
|
@ -75,7 +75,7 @@ ENTRY TestComputation {
|
|||||||
broadcast = f32[42] broadcast(add), dimensions={}
|
broadcast = f32[42] broadcast(add), dimensions={}
|
||||||
slice = f32[1] slice(broadcast), slice={[1:2]}
|
slice = f32[1] slice(broadcast), slice={[1:2]}
|
||||||
copy = f32[] copy(arg)
|
copy = f32[] copy(arg)
|
||||||
eq = pred[] equal-to(arg, gte)
|
eq = pred[] compare(arg, gte), direction=EQ
|
||||||
neg = f32[] negate(arg)
|
neg = f32[] negate(arg)
|
||||||
ROOT convert = f64[] convert(f32[] arg)
|
ROOT convert = f64[] convert(f32[] arg)
|
||||||
})";
|
})";
|
||||||
|
@ -68,8 +68,8 @@ class DynamicDimensionInferenceTest : public HloTestBase {
|
|||||||
0, ShapeUtil::MakeShape(F32, {}), "lhs"));
|
0, ShapeUtil::MakeShape(F32, {}), "lhs"));
|
||||||
auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
|
auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
|
||||||
1, ShapeUtil::MakeShape(F32, {}), "rhs"));
|
1, ShapeUtil::MakeShape(F32, {}), "rhs"));
|
||||||
embedded_builder.AddInstruction(HloInstruction::CreateBinary(
|
embedded_builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGe, lhs, rhs));
|
ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
|
||||||
return module_->AddEmbeddedComputation(embedded_builder.Build());
|
return module_->AddEmbeddedComputation(embedded_builder.Build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,9 +160,10 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
|
|||||||
HloInstruction* broadcasted_effective_size =
|
HloInstruction* broadcasted_effective_size =
|
||||||
computation->AddInstruction(HloInstruction::CreateBroadcast(
|
computation->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
mask_shape, dynamic_size, {}));
|
mask_shape, dynamic_size, {}));
|
||||||
HloInstruction* pred = computation->AddInstruction(
|
HloInstruction* pred =
|
||||||
HloInstruction::CreateBinary(pred_shape, HloOpcode::kLt, iota,
|
computation->AddInstruction(HloInstruction::CreateCompare(
|
||||||
broadcasted_effective_size));
|
pred_shape, iota, broadcasted_effective_size,
|
||||||
|
ComparisonDirection::kLt));
|
||||||
|
|
||||||
HloInstruction* broadcasted_identity_value =
|
HloInstruction* broadcasted_identity_value =
|
||||||
computation->AddInstruction(HloInstruction::CreateBroadcast(
|
computation->AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
|
@ -719,25 +719,28 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
|
|||||||
// We use ordered comparisons for everything except kNe, where we use an
|
// We use ordered comparisons for everything except kNe, where we use an
|
||||||
// unordered comparison. This makes x != y equivalent to !(x == y), and
|
// unordered comparison. This makes x != y equivalent to !(x == y), and
|
||||||
// matches C++'s semantics.
|
// matches C++'s semantics.
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kCompare: {
|
||||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
|
switch (op->comparison_direction()) {
|
||||||
rhs_value, b_);
|
case ComparisonDirection::kEq:
|
||||||
case HloOpcode::kNe:
|
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
|
||||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
|
rhs_value, b_);
|
||||||
rhs_value, b_);
|
case ComparisonDirection::kNe:
|
||||||
case HloOpcode::kLt:
|
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
|
||||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
|
rhs_value, b_);
|
||||||
rhs_value, b_);
|
case ComparisonDirection::kLt:
|
||||||
case HloOpcode::kGt:
|
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
|
||||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
|
rhs_value, b_);
|
||||||
rhs_value, b_);
|
case ComparisonDirection::kGt:
|
||||||
case HloOpcode::kLe:
|
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
|
||||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
|
rhs_value, b_);
|
||||||
rhs_value, b_);
|
case ComparisonDirection::kLe:
|
||||||
case HloOpcode::kGe:
|
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
|
||||||
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
|
rhs_value, b_);
|
||||||
rhs_value, b_);
|
case ComparisonDirection::kGe:
|
||||||
|
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
|
||||||
|
rhs_value, b_);
|
||||||
|
}
|
||||||
|
}
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
return EmitFloatMax(lhs_value, rhs_value);
|
return EmitFloatMax(lhs_value, rhs_value);
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
@ -839,21 +842,28 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
|
|||||||
// We use ordered comparisons for everything except kNe, where we use an
|
// We use ordered comparisons for everything except kNe, where we use an
|
||||||
// unordered comparison. This makes x != y equivalent to !(x == y), and
|
// unordered comparison. This makes x != y equivalent to !(x == y), and
|
||||||
// matches C++'s semantics.
|
// matches C++'s semantics.
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kCompare: {
|
||||||
return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
|
switch (op->comparison_direction()) {
|
||||||
EmitExtractReal(lhs_value),
|
case ComparisonDirection::kEq:
|
||||||
EmitExtractReal(rhs_value), b_),
|
return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
|
||||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
|
EmitExtractReal(lhs_value),
|
||||||
EmitExtractImag(lhs_value),
|
EmitExtractReal(rhs_value), b_),
|
||||||
EmitExtractImag(rhs_value), b_));
|
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
|
||||||
case HloOpcode::kNe:
|
EmitExtractImag(lhs_value),
|
||||||
return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
|
EmitExtractImag(rhs_value), b_));
|
||||||
EmitExtractReal(lhs_value),
|
case ComparisonDirection::kNe:
|
||||||
EmitExtractReal(rhs_value), b_),
|
return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
|
||||||
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
|
EmitExtractReal(lhs_value),
|
||||||
EmitExtractImag(lhs_value),
|
EmitExtractReal(rhs_value), b_),
|
||||||
EmitExtractImag(rhs_value), b_));
|
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
|
||||||
|
EmitExtractImag(lhs_value),
|
||||||
|
EmitExtractImag(rhs_value), b_));
|
||||||
|
default:
|
||||||
|
return Unimplemented(
|
||||||
|
"complex comparison '%s'",
|
||||||
|
ComparisonDirectionToString(op->comparison_direction()));
|
||||||
|
}
|
||||||
|
}
|
||||||
case HloOpcode::kPower: {
|
case HloOpcode::kPower: {
|
||||||
auto a = EmitExtractReal(lhs_value);
|
auto a = EmitExtractReal(lhs_value);
|
||||||
auto b = EmitExtractImag(lhs_value);
|
auto b = EmitExtractImag(lhs_value);
|
||||||
@ -1278,28 +1288,32 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
|
|||||||
return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
|
return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
|
||||||
case HloOpcode::kRemainder:
|
case HloOpcode::kRemainder:
|
||||||
return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
|
return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kCompare: {
|
||||||
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
|
switch (op->comparison_direction()) {
|
||||||
rhs_value, b_);
|
case ComparisonDirection::kEq:
|
||||||
case HloOpcode::kNe:
|
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
|
||||||
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
|
rhs_value, b_);
|
||||||
rhs_value, b_);
|
case ComparisonDirection::kNe:
|
||||||
case HloOpcode::kLt:
|
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
|
||||||
return llvm_ir::EmitComparison(
|
rhs_value, b_);
|
||||||
is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
|
case ComparisonDirection::kLt:
|
||||||
lhs_value, rhs_value, b_);
|
return llvm_ir::EmitComparison(
|
||||||
case HloOpcode::kGt:
|
is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
|
||||||
return llvm_ir::EmitComparison(
|
lhs_value, rhs_value, b_);
|
||||||
is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
|
case ComparisonDirection::kGt:
|
||||||
lhs_value, rhs_value, b_);
|
return llvm_ir::EmitComparison(
|
||||||
case HloOpcode::kLe:
|
is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
|
||||||
return llvm_ir::EmitComparison(
|
lhs_value, rhs_value, b_);
|
||||||
is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
|
case ComparisonDirection::kLe:
|
||||||
lhs_value, rhs_value, b_);
|
return llvm_ir::EmitComparison(
|
||||||
case HloOpcode::kGe:
|
is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
|
||||||
return llvm_ir::EmitComparison(
|
lhs_value, rhs_value, b_);
|
||||||
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
|
case ComparisonDirection::kGe:
|
||||||
lhs_value, rhs_value, b_);
|
return llvm_ir::EmitComparison(
|
||||||
|
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
|
||||||
|
lhs_value, rhs_value, b_);
|
||||||
|
}
|
||||||
|
}
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
return EmitIntegralMin(lhs_value, rhs_value, is_signed);
|
return EmitIntegralMin(lhs_value, rhs_value, is_signed);
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
@ -2197,17 +2211,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
case HloOpcode::kAnd:
|
case HloOpcode::kAnd:
|
||||||
case HloOpcode::kAtan2:
|
case HloOpcode::kAtan2:
|
||||||
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kComplex:
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kOr:
|
case HloOpcode::kOr:
|
||||||
case HloOpcode::kXor:
|
case HloOpcode::kXor:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
|
@ -81,8 +81,9 @@ class FlattenCallGraphTest : public HloTestBase {
|
|||||||
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
|
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
|
||||||
HloInstruction* zero = builder.AddInstruction(
|
HloInstruction* zero = builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
|
||||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
builder.AddInstruction(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
|
||||||
|
zero, ComparisonDirection::kGt));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,9 +159,9 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
|
|||||||
0, ShapeUtil::MakeShape(PRED, {}), "param0"));
|
0, ShapeUtil::MakeShape(PRED, {}), "param0"));
|
||||||
HloInstruction* false_constant = builder.AddInstruction(
|
HloInstruction* false_constant = builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
|
||||||
builder.AddInstruction(
|
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
|
ShapeUtil::MakeShape(PRED, {}), param0, false_constant,
|
||||||
HloOpcode::kEq, param0, false_constant));
|
ComparisonDirection::kEq));
|
||||||
cond_computation = module->AddEmbeddedComputation(builder.Build());
|
cond_computation = module->AddEmbeddedComputation(builder.Build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ TEST_F(GpuFusibleTest,
|
|||||||
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
|
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
|
||||||
c0 = f16[] constant(0)
|
c0 = f16[] constant(0)
|
||||||
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
|
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
|
||||||
greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
|
greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
|
||||||
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
|
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
|
||||||
}
|
}
|
||||||
fused_reduce {
|
fused_reduce {
|
||||||
@ -122,7 +122,7 @@ TEST_F(GpuFusibleTest,
|
|||||||
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
|
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
|
||||||
c0 = f16[] constant(0)
|
c0 = f16[] constant(0)
|
||||||
broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={}
|
broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={}
|
||||||
greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast)
|
greater-than = pred[128,1024,32,32]{3,2,1,0} compare(p1.1, broadcast), direction=GT
|
||||||
select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast)
|
select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast)
|
||||||
ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select)
|
ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select)
|
||||||
}
|
}
|
||||||
@ -507,7 +507,7 @@ TEST_F(GpuFusibleTest,
|
|||||||
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||||
c0 = f32[] constant(0)
|
c0 = f32[] constant(0)
|
||||||
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
|
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
|
||||||
greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
|
greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT
|
||||||
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
||||||
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
|
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
|
||||||
}
|
}
|
||||||
|
@ -374,7 +374,7 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = f32[] parameter(2)
|
p.1.lhs = f32[] parameter(2)
|
||||||
p.1.rhs = f32[] parameter(3)
|
p.1.rhs = f32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort {
|
ENTRY sort {
|
||||||
|
@ -437,7 +437,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
|
|||||||
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
|
||||||
c0 = f32[] constant(0)
|
c0 = f32[] constant(0)
|
||||||
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
|
broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
|
||||||
greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
|
greater-than = pred[2,2,2]{2,1,0} compare(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast), direction=GT
|
||||||
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
|
||||||
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
|
ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
|
||||||
}
|
}
|
||||||
@ -505,7 +505,7 @@ TEST_F(MultiOutputFusionTest,
|
|||||||
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
|
p1.1 = f16[2,2,2]{2,1,0} parameter(1)
|
||||||
c0 = f16[] constant(0)
|
c0 = f16[] constant(0)
|
||||||
broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
|
broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={}
|
||||||
greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast)
|
greater-than = pred[2,2,2]{2,1,0} compare(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast), direction=GT
|
||||||
p0.1 = f16[2,2,2]{2,1,0} parameter(0)
|
p0.1 = f16[2,2,2]{2,1,0} parameter(0)
|
||||||
ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
|
ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast)
|
||||||
}
|
}
|
||||||
@ -548,7 +548,7 @@ TEST_F(MultiOutputFusionTest,
|
|||||||
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
|
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
|
||||||
c0 = f16[] constant(0)
|
c0 = f16[] constant(0)
|
||||||
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
|
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
|
||||||
greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast)
|
greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
|
||||||
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
|
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
|
||||||
}
|
}
|
||||||
fused_reduce {
|
fused_reduce {
|
||||||
|
@ -48,8 +48,9 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) {
|
|||||||
HloInstruction::CreateParameter(0, param_shape, "x"));
|
HloInstruction::CreateParameter(0, param_shape, "x"));
|
||||||
HloInstruction* param_y = builder.AddInstruction(
|
HloInstruction* param_y = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(1, param_shape, "y"));
|
HloInstruction::CreateParameter(1, param_shape, "y"));
|
||||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y));
|
ShapeUtil::MakeShape(PRED, {5, 7, 2}), param_x, param_y,
|
||||||
|
ComparisonDirection::kGe));
|
||||||
|
|
||||||
auto hlo_module = CreateNewVerifiedModule();
|
auto hlo_module = CreateNewVerifiedModule();
|
||||||
hlo_module->AddEntryComputation(builder.Build());
|
hlo_module->AddEntryComputation(builder.Build());
|
||||||
@ -73,7 +74,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) {
|
|||||||
x = f32[5,7,2]{2,1,0} parameter(0)
|
x = f32[5,7,2]{2,1,0} parameter(0)
|
||||||
y = f32[5,14]{1,0} parameter(1)
|
y = f32[5,14]{1,0} parameter(1)
|
||||||
reshape = f32[5,7,2]{2,1,0} reshape(y)
|
reshape = f32[5,7,2]{2,1,0} reshape(y)
|
||||||
ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape)
|
ROOT gte = pred[5,7,2]{2,1,0} compare(x, reshape), direction=GE
|
||||||
})",
|
})",
|
||||||
config)
|
config)
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
@ -98,7 +99,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) {
|
|||||||
y = f32[14]{0} parameter(1)
|
y = f32[14]{0} parameter(1)
|
||||||
reshape = f32[7,2]{1,0} reshape(y)
|
reshape = f32[7,2]{1,0} reshape(y)
|
||||||
broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2}
|
broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2}
|
||||||
ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast)
|
ROOT gte = pred[5,7,2]{2,1,0} compare(x, broadcast), direction=GE
|
||||||
})",
|
})",
|
||||||
config)
|
config)
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
|
@ -44,9 +44,9 @@ class WhileTransformerTest : public HloTestBase {
|
|||||||
auto induction_variable =
|
auto induction_variable =
|
||||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||||
limit_const->shape(), loop_state, tuple_index));
|
limit_const->shape(), loop_state, tuple_index));
|
||||||
builder.AddInstruction(
|
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
|
condition_result_shape_, induction_variable, limit_const,
|
||||||
induction_variable, limit_const));
|
ComparisonDirection::kLt));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,8 +54,8 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
|
|||||||
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
||||||
// Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
|
// Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
|
||||||
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
||||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
|
||||||
HloOpcode::kLt, cond_iter, cond_data));
|
cond_data, ComparisonDirection::kLt));
|
||||||
HloComputation* cond_computation =
|
HloComputation* cond_computation =
|
||||||
module->AddEmbeddedComputation(cond_builder.Build());
|
module->AddEmbeddedComputation(cond_builder.Build());
|
||||||
|
|
||||||
@ -113,7 +113,8 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
|
|||||||
// %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
|
// %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
|
||||||
// %reshape = f32[] reshape(f32[1]{0} %slice)
|
// %reshape = f32[] reshape(f32[1]{0} %slice)
|
||||||
// %constant = f32[] constant(0)
|
// %constant = f32[] constant(0)
|
||||||
// ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant)
|
// ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
|
||||||
|
// direction=NE
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// ENTRY %SubcomputationAccounting () -> f32[2,4] {
|
// ENTRY %SubcomputationAccounting () -> f32[2,4] {
|
||||||
@ -143,9 +144,9 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
|
|||||||
cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
|
cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
|
||||||
HloInstruction* zero = cond_builder.AddInstruction(
|
HloInstruction* zero = cond_builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
|
||||||
HloInstruction* cond_comparison =
|
HloInstruction* cond_comparison = cond_builder.AddInstruction(
|
||||||
cond_builder.AddInstruction(HloInstruction::CreateBinary(
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero));
|
zero, ComparisonDirection::kNe));
|
||||||
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
||||||
|
|
||||||
// param - 1
|
// param - 1
|
||||||
@ -703,8 +704,8 @@ TEST_F(HeapSimulatorTest, WholeModule) {
|
|||||||
HloInstruction* cond_data = cond_builder.AddInstruction(
|
HloInstruction* cond_data = cond_builder.AddInstruction(
|
||||||
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
|
||||||
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
HloInstruction* cond_lt = cond_builder.AddInstruction(
|
||||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
|
||||||
HloOpcode::kLt, cond_iter, cond_data));
|
cond_data, ComparisonDirection::kLt));
|
||||||
HloComputation* cond_computation =
|
HloComputation* cond_computation =
|
||||||
tracker.module()->AddEmbeddedComputation(cond_builder.Build());
|
tracker.module()->AddEmbeddedComputation(cond_builder.Build());
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
|||||||
option cc_enable_arenas = true;
|
option cc_enable_arenas = true;
|
||||||
|
|
||||||
// Serialization of HloInstruction.
|
// Serialization of HloInstruction.
|
||||||
// Next ID: 63
|
// Next ID: 64
|
||||||
message HloInstructionProto {
|
message HloInstructionProto {
|
||||||
reserved 10;
|
reserved 10;
|
||||||
reserved "parameter_name";
|
reserved "parameter_name";
|
||||||
@ -146,6 +146,9 @@ message HloInstructionProto {
|
|||||||
// FFT length.
|
// FFT length.
|
||||||
repeated int64 fft_length = 32;
|
repeated int64 fft_length = 32;
|
||||||
|
|
||||||
|
// Comparison direction only used for kCompare.
|
||||||
|
string comparison_direction = 63;
|
||||||
|
|
||||||
// Gather dimension numbers.
|
// Gather dimension numbers.
|
||||||
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
|
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
|
||||||
repeated int64 gather_slice_sizes = 34;
|
repeated int64 gather_slice_sizes = 34;
|
||||||
|
@ -509,8 +509,9 @@ TEST_F(HloComputationTest, CloneWithReplacements) {
|
|||||||
HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs"));
|
HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs"));
|
||||||
auto param2 =
|
auto param2 =
|
||||||
builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1"));
|
builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1"));
|
||||||
auto lt = builder.AddInstruction(HloInstruction::CreateBinary(
|
auto lt = builder.AddInstruction(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param0, param1));
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0,
|
||||||
|
param1, ComparisonDirection::kLt));
|
||||||
auto module = CreateNewVerifiedModule();
|
auto module = CreateNewVerifiedModule();
|
||||||
auto computation =
|
auto computation =
|
||||||
module->AddEntryComputation(builder.Build(/*root_instruction=*/lt));
|
module->AddEntryComputation(builder.Build(/*root_instruction=*/lt));
|
||||||
|
@ -42,6 +42,18 @@ StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
|
|||||||
HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
|
HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
|
||||||
|
HloInstruction* lhs,
|
||||||
|
HloInstruction* rhs) {
|
||||||
|
HloComputation* computation = lhs->parent();
|
||||||
|
CHECK_EQ(computation, rhs->parent());
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
Shape binary_op_shape,
|
||||||
|
ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, lhs, rhs));
|
||||||
|
return computation->AddInstruction(
|
||||||
|
HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction));
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
|
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
|
||||||
HloInstruction* padding_value,
|
HloInstruction* padding_value,
|
||||||
const PaddingConfig& padding_config) {
|
const PaddingConfig& padding_config) {
|
||||||
|
@ -32,6 +32,12 @@ namespace xla {
|
|||||||
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
|
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
|
||||||
HloInstruction* rhs);
|
HloInstruction* rhs);
|
||||||
|
|
||||||
|
// Creates a compare HLO instruction and adds it to the computation containing
|
||||||
|
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
|
||||||
|
StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
|
||||||
|
HloInstruction* lhs,
|
||||||
|
HloInstruction* rhs);
|
||||||
|
|
||||||
// Creates a pad HLO instruction and adds it to the computation containing
|
// Creates a pad HLO instruction and adds it to the computation containing
|
||||||
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
|
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
|
||||||
// same computation).
|
// same computation).
|
||||||
|
@ -2239,8 +2239,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
|
|||||||
HloInstruction::CreateParameter(0, in_shape, "param0"));
|
HloInstruction::CreateParameter(0, in_shape, "param0"));
|
||||||
auto param1 = builder.AddInstruction(
|
auto param1 = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(1, in_shape, "param1"));
|
HloInstruction::CreateParameter(1, in_shape, "param1"));
|
||||||
auto result = builder.AddInstruction(
|
auto result = builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
|
out_shape, param0, param1, ComparisonDirection::kEq));
|
||||||
|
|
||||||
BuildModuleAndRunAnalysis(builder.Build());
|
BuildModuleAndRunAnalysis(builder.Build());
|
||||||
|
|
||||||
@ -2563,8 +2563,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
|
|||||||
auto builder = HloComputation::Builder(TestName() + ".Cond");
|
auto builder = HloComputation::Builder(TestName() + ".Cond");
|
||||||
auto data = builder.AddInstruction(
|
auto data = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(0, data_shape, "data"));
|
HloInstruction::CreateParameter(0, data_shape, "data"));
|
||||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
|
ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -223,8 +223,9 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
|
|||||||
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
HloInstruction::CreateParameter(0, shape, "cond_param"));
|
||||||
auto constant = cond_builder.AddInstruction(
|
auto constant = cond_builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
|
||||||
cond_builder.AddInstruction(HloInstruction::CreateBinary(
|
cond_builder.AddInstruction(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant));
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
|
||||||
|
constant, ComparisonDirection::kLt));
|
||||||
}
|
}
|
||||||
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
||||||
|
|
||||||
|
@ -56,43 +56,40 @@ namespace xla {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename OperandT>
|
template <typename OperandT>
|
||||||
StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
|
StatusOr<Literal> Compare(const Shape& shape, ComparisonDirection direction,
|
||||||
LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
|
LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
|
||||||
std::function<bool(OperandT, OperandT)> compare_op;
|
std::function<bool(OperandT, OperandT)> compare_op;
|
||||||
switch (opcode) {
|
switch (direction) {
|
||||||
case HloOpcode::kEq:
|
case ComparisonDirection::kEq:
|
||||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||||
return lhs_el == rhs_el;
|
return lhs_el == rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kNe:
|
case ComparisonDirection::kNe:
|
||||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||||
return lhs_el != rhs_el;
|
return lhs_el != rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kGe:
|
case ComparisonDirection::kGe:
|
||||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||||
return lhs_el >= rhs_el;
|
return lhs_el >= rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kGt:
|
case ComparisonDirection::kGt:
|
||||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||||
return lhs_el > rhs_el;
|
return lhs_el > rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kLe:
|
case ComparisonDirection::kLe:
|
||||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||||
return lhs_el <= rhs_el;
|
return lhs_el <= rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kLt:
|
case ComparisonDirection::kLt:
|
||||||
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
|
||||||
return lhs_el < rhs_el;
|
return lhs_el < rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
default:
|
|
||||||
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
|
||||||
<< HloOpcodeString(opcode);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Literal result(shape);
|
Literal result(shape);
|
||||||
@ -106,24 +103,25 @@ StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
|
StatusOr<Literal> Compare<complex64>(const Shape& shape,
|
||||||
|
ComparisonDirection direction,
|
||||||
LiteralSlice lhs_literal,
|
LiteralSlice lhs_literal,
|
||||||
LiteralSlice rhs_literal) {
|
LiteralSlice rhs_literal) {
|
||||||
std::function<bool(complex64, complex64)> compare_op;
|
std::function<bool(complex64, complex64)> compare_op;
|
||||||
switch (opcode) {
|
switch (direction) {
|
||||||
case HloOpcode::kEq:
|
case ComparisonDirection::kEq:
|
||||||
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
||||||
return lhs_el == rhs_el;
|
return lhs_el == rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kNe:
|
case ComparisonDirection::kNe:
|
||||||
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
compare_op = [](complex64 lhs_el, complex64 rhs_el) {
|
||||||
return lhs_el != rhs_el;
|
return lhs_el != rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
LOG(FATAL) << "unhandled direction for conversion to Comparison: "
|
||||||
<< HloOpcodeString(opcode);
|
<< ComparisonDirectionToString(direction);
|
||||||
}
|
}
|
||||||
|
|
||||||
Literal result(shape);
|
Literal result(shape);
|
||||||
@ -137,24 +135,25 @@ StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
StatusOr<Literal> Compare<complex128>(const Shape& shape, HloOpcode opcode,
|
StatusOr<Literal> Compare<complex128>(const Shape& shape,
|
||||||
|
ComparisonDirection direction,
|
||||||
LiteralSlice lhs_literal,
|
LiteralSlice lhs_literal,
|
||||||
LiteralSlice rhs_literal) {
|
LiteralSlice rhs_literal) {
|
||||||
std::function<bool(complex128, complex128)> compare_op;
|
std::function<bool(complex128, complex128)> compare_op;
|
||||||
switch (opcode) {
|
switch (direction) {
|
||||||
case HloOpcode::kEq:
|
case ComparisonDirection::kEq:
|
||||||
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
|
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
|
||||||
return lhs_el == rhs_el;
|
return lhs_el == rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kNe:
|
case ComparisonDirection::kNe:
|
||||||
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
|
compare_op = [](complex128 lhs_el, complex128 rhs_el) {
|
||||||
return lhs_el != rhs_el;
|
return lhs_el != rhs_el;
|
||||||
};
|
};
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
|
LOG(FATAL) << "unhandled direction for conversion to Comparison: "
|
||||||
<< HloOpcodeString(opcode);
|
<< ComparisonDirectionToString(direction);
|
||||||
}
|
}
|
||||||
|
|
||||||
Literal result(shape);
|
Literal result(shape);
|
||||||
@ -671,7 +670,7 @@ Status HloEvaluator::HandleComplex(HloInstruction* complex) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
|
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
|
||||||
HloOpcode opcode = compare->opcode();
|
ComparisonDirection direction = compare->comparison_direction();
|
||||||
auto lhs = compare->operand(0);
|
auto lhs = compare->operand(0);
|
||||||
auto rhs = compare->operand(1);
|
auto rhs = compare->operand(1);
|
||||||
DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
|
DCHECK(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
|
||||||
@ -687,76 +686,76 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) {
|
|||||||
case PRED: {
|
case PRED: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
evaluated_[compare],
|
evaluated_[compare],
|
||||||
Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
Compare<bool>(compare->shape(), direction, lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case U8: {
|
case U8: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<uint8>(compare->shape(), direction,
|
||||||
Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case U16: {
|
case U16: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<uint16>(compare->shape(), direction,
|
||||||
Compare<uint16>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case U32: {
|
case U32: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<uint32>(compare->shape(), direction,
|
||||||
Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case U64: {
|
case U64: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<uint64>(compare->shape(), direction,
|
||||||
Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case S8: {
|
case S8: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
evaluated_[compare],
|
evaluated_[compare],
|
||||||
Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
Compare<int8>(compare->shape(), direction, lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case S16: {
|
case S16: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<int16>(compare->shape(), direction,
|
||||||
Compare<int16>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case S32: {
|
case S32: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<int32>(compare->shape(), direction,
|
||||||
Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case S64: {
|
case S64: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<int64>(compare->shape(), direction,
|
||||||
Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case F16: {
|
case F16: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
evaluated_[compare],
|
evaluated_[compare],
|
||||||
Compare<half>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
Compare<half>(compare->shape(), direction, lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case BF16: {
|
case BF16: {
|
||||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
Compare<bfloat16>(compare->shape(), opcode,
|
Compare<bfloat16>(compare->shape(), direction,
|
||||||
lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case F32: {
|
case F32: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<float>(compare->shape(), direction,
|
||||||
Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case F64: {
|
case F64: {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
evaluated_[compare],
|
Compare<double>(compare->shape(), direction,
|
||||||
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case C64: {
|
case C64: {
|
||||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
Compare<complex64>(compare->shape(), opcode,
|
Compare<complex64>(compare->shape(), direction,
|
||||||
lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
case C128: {
|
case C128: {
|
||||||
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
TF_ASSIGN_OR_RETURN(evaluated_[compare],
|
||||||
Compare<complex128>(compare->shape(), opcode,
|
Compare<complex128>(compare->shape(), direction,
|
||||||
lhs_literal, rhs_literal));
|
lhs_literal, rhs_literal));
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
|
@ -2848,8 +2848,15 @@ TEST_F(HloEvaluatorTest, DoesCompareBF16) {
|
|||||||
{bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
|
{bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
|
||||||
auto expected =
|
auto expected =
|
||||||
LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
|
LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
|
||||||
TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs),
|
|
||||||
std::move(rhs));
|
HloComputation::Builder b(TestName());
|
||||||
|
auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
|
||||||
|
auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
|
||||||
|
b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2,
|
||||||
|
ComparisonDirection::kGe));
|
||||||
|
m_->AddEntryComputation(b.Build());
|
||||||
|
|
||||||
|
EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
|
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
|
||||||
|
@ -258,14 +258,16 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
|
|||||||
// param0), check that the operation being performed is commutative.
|
// param0), check that the operation being performed is commutative.
|
||||||
if (root->operand(0) == param1) {
|
if (root->operand(0) == param1) {
|
||||||
CHECK_EQ(root->operand(1), param0);
|
CHECK_EQ(root->operand(1), param0);
|
||||||
switch (root->opcode()) {
|
if (root->opcode() == HloOpcode()) {
|
||||||
case HloOpcode::kLe:
|
switch (root->comparison_direction()) {
|
||||||
case HloOpcode::kGe:
|
case ComparisonDirection::kLe:
|
||||||
case HloOpcode::kGt:
|
case ComparisonDirection::kGe:
|
||||||
case HloOpcode::kLt:
|
case ComparisonDirection::kGt:
|
||||||
return nullopt;
|
case ComparisonDirection::kLt:
|
||||||
default:
|
return nullopt;
|
||||||
break;
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,18 +281,22 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
|
|||||||
return "min";
|
return "min";
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
return "max";
|
return "max";
|
||||||
case HloOpcode::kLe:
|
case HloOpcode::kCompare: {
|
||||||
return "less-or-equal";
|
switch (root->comparison_direction()) {
|
||||||
case HloOpcode::kGe:
|
case ComparisonDirection::kLe:
|
||||||
return "greater-or-equal";
|
return "less-or-equal";
|
||||||
case HloOpcode::kGt:
|
case ComparisonDirection::kGe:
|
||||||
return "greater-than";
|
return "greater-or-equal";
|
||||||
case HloOpcode::kLt:
|
case ComparisonDirection::kGt:
|
||||||
return "less-than";
|
return "greater-than";
|
||||||
case HloOpcode::kEq:
|
case ComparisonDirection::kLt:
|
||||||
return "equal-to";
|
return "less-than";
|
||||||
case HloOpcode::kNe:
|
case ComparisonDirection::kEq:
|
||||||
return "not-equal-to";
|
return "equal-to";
|
||||||
|
case ComparisonDirection::kNe:
|
||||||
|
return "not-equal-to";
|
||||||
|
}
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return nullopt;
|
return nullopt;
|
||||||
}
|
}
|
||||||
@ -922,27 +928,22 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
|||||||
case HloOpcode::kCeil:
|
case HloOpcode::kCeil:
|
||||||
case HloOpcode::kClamp:
|
case HloOpcode::kClamp:
|
||||||
case HloOpcode::kClz:
|
case HloOpcode::kClz:
|
||||||
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kComplex:
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kConvert:
|
case HloOpcode::kConvert:
|
||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
case HloOpcode::kExpm1:
|
case HloOpcode::kExpm1:
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kImag:
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIota:
|
case HloOpcode::kIota:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
case HloOpcode::kLog1p:
|
case HloOpcode::kLog1p:
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kOr:
|
case HloOpcode::kOr:
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||||
|
|
||||||
@ -31,6 +32,8 @@ namespace {
|
|||||||
using absl::StrCat;
|
using absl::StrCat;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
|
using HloGraphDumperTest = HloTestBase;
|
||||||
|
|
||||||
string TestName() {
|
string TestName() {
|
||||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||||
}
|
}
|
||||||
@ -48,7 +51,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
|
|||||||
|
|
||||||
XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
|
XLA_REGISTER_GRAPH_RENDERER(DotRenderer);
|
||||||
|
|
||||||
TEST(HloGraphDumperTest, NestedFusion) {
|
TEST_F(HloGraphDumperTest, NestedFusion) {
|
||||||
HloComputation::Builder b("b");
|
HloComputation::Builder b("b");
|
||||||
|
|
||||||
// Build param0 + param1 + param2 + param3 + param4.
|
// Build param0 + param1 + param2 + param3 + param4.
|
||||||
@ -118,7 +121,7 @@ TEST(HloGraphDumperTest, NestedFusion) {
|
|||||||
HasSubstr(inner_sum->name()));
|
HasSubstr(inner_sum->name()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HloGraphDumperTest, Constant) {
|
TEST_F(HloGraphDumperTest, Constant) {
|
||||||
HloComputation::Builder b("b");
|
HloComputation::Builder b("b");
|
||||||
auto instruction = b.AddInstruction(
|
auto instruction = b.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-42)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-42)));
|
||||||
@ -132,7 +135,7 @@ TEST(HloGraphDumperTest, Constant) {
|
|||||||
EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
|
EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HloGraphDumperTest, TupleConstant) {
|
TEST_F(HloGraphDumperTest, TupleConstant) {
|
||||||
Shape tuple_shape = ShapeUtil::MakeTupleShape(
|
Shape tuple_shape = ShapeUtil::MakeTupleShape(
|
||||||
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})});
|
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})});
|
||||||
HloComputation::Builder b("b");
|
HloComputation::Builder b("b");
|
||||||
@ -150,5 +153,21 @@ TEST(HloGraphDumperTest, TupleConstant) {
|
|||||||
EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])"));
|
EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloGraphDumperTest, Compare) {
|
||||||
|
const char* hlo_string = R"(
|
||||||
|
HloModule comp
|
||||||
|
|
||||||
|
ENTRY comp {
|
||||||
|
param.0 = f32[10] parameter(0)
|
||||||
|
param.1 = f32[10] parameter(1)
|
||||||
|
ROOT lt = pred[10] compare(param.0, param.1), direction=LT
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
string graph = hlo_graph_dumper::DumpGraph(*module->entry_computation(),
|
||||||
|
/*label=*/"comp", DebugOptions());
|
||||||
|
EXPECT_THAT(graph, HasSubstr("direction=LT"));
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -64,7 +64,35 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
||||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
||||||
TF_RET_CHECK(!proto.opcode().empty());
|
TF_RET_CHECK(!proto.opcode().empty());
|
||||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
|
HloOpcode opcode;
|
||||||
|
auto opcode_or = StringToHloOpcode(proto.opcode());
|
||||||
|
absl::optional<ComparisonDirection> comparison_direction;
|
||||||
|
if (opcode_or.ok()) {
|
||||||
|
opcode = opcode_or.ConsumeValueOrDie();
|
||||||
|
} else {
|
||||||
|
// Unknown opcode. Try auto-upgrading deprecated "less-than",
|
||||||
|
// "greater-than", etc opcodes, which are now rolled into the kCompare
|
||||||
|
// opcode.
|
||||||
|
if (proto.opcode() == "equal-to") {
|
||||||
|
comparison_direction = ComparisonDirection::kEq;
|
||||||
|
} else if (proto.opcode() == "not-equal-to") {
|
||||||
|
comparison_direction = ComparisonDirection::kNe;
|
||||||
|
} else if (proto.opcode() == "greater-than-or-equal-to") {
|
||||||
|
comparison_direction = ComparisonDirection::kGe;
|
||||||
|
} else if (proto.opcode() == "greater-than") {
|
||||||
|
comparison_direction = ComparisonDirection::kGt;
|
||||||
|
} else if (proto.opcode() == "less-than-or-equal-to") {
|
||||||
|
comparison_direction = ComparisonDirection::kLe;
|
||||||
|
} else if (proto.opcode() == "less-than") {
|
||||||
|
comparison_direction = ComparisonDirection::kLt;
|
||||||
|
}
|
||||||
|
if (comparison_direction) {
|
||||||
|
opcode = HloOpcode::kCompare;
|
||||||
|
} else {
|
||||||
|
return InvalidArgument("Unknown opcode: %s", proto.opcode());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TF_RET_CHECK(proto.has_shape());
|
TF_RET_CHECK(proto.has_shape());
|
||||||
|
|
||||||
std::unique_ptr<HloInstruction> instruction;
|
std::unique_ptr<HloInstruction> instruction;
|
||||||
@ -136,6 +164,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
absl::Span<const int64>(fft_length));
|
absl::Span<const int64>(fft_length));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case HloOpcode::kCompare: {
|
||||||
|
// Auto-upgraded from deprecated opcode skips the following.
|
||||||
|
if (!comparison_direction) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
comparison_direction,
|
||||||
|
StringToComparisonDirection(proto.comparison_direction()));
|
||||||
|
}
|
||||||
|
instruction =
|
||||||
|
CreateCompare(shape, operands(0), operands(1), *comparison_direction);
|
||||||
|
break;
|
||||||
|
}
|
||||||
case HloOpcode::kTriangularSolve: {
|
case HloOpcode::kTriangularSolve: {
|
||||||
instruction = CreateTriangularSolve(shape, operands(0), operands(1),
|
instruction = CreateTriangularSolve(shape, operands(0), operands(1),
|
||||||
proto.triangular_solve_options());
|
proto.triangular_solve_options());
|
||||||
@ -688,15 +727,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
|||||||
case HloOpcode::kAtan2:
|
case HloOpcode::kAtan2:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kComplex:
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
case HloOpcode::kRemainder:
|
case HloOpcode::kRemainder:
|
||||||
case HloOpcode::kSubtract:
|
case HloOpcode::kSubtract:
|
||||||
@ -761,6 +794,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
|||||||
fft_length);
|
fft_length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
|
||||||
|
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||||
|
ComparisonDirection direction) {
|
||||||
|
return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction>
|
/* static */ std::unique_ptr<HloInstruction>
|
||||||
HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
|
HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
|
||||||
HloInstruction* b,
|
HloInstruction* b,
|
||||||
@ -1311,6 +1350,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
case HloOpcode::kBatchNormInference:
|
case HloOpcode::kBatchNormInference:
|
||||||
case HloOpcode::kBatchNormGrad:
|
case HloOpcode::kBatchNormGrad:
|
||||||
case HloOpcode::kFft:
|
case HloOpcode::kFft:
|
||||||
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kSend:
|
case HloOpcode::kSend:
|
||||||
case HloOpcode::kSendDone:
|
case HloOpcode::kSendDone:
|
||||||
case HloOpcode::kRecv:
|
case HloOpcode::kRecv:
|
||||||
@ -1384,12 +1424,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kSubtract:
|
case HloOpcode::kSubtract:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
@ -1705,26 +1739,20 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
case HloOpcode::kCos:
|
case HloOpcode::kCos:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kDynamicUpdateSlice:
|
case HloOpcode::kDynamicUpdateSlice:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
case HloOpcode::kExpm1:
|
case HloOpcode::kExpm1:
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kImag:
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
case HloOpcode::kLog1p:
|
case HloOpcode::kLog1p:
|
||||||
case HloOpcode::kAnd:
|
case HloOpcode::kAnd:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kOr:
|
case HloOpcode::kOr:
|
||||||
case HloOpcode::kXor:
|
case HloOpcode::kXor:
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
@ -1772,6 +1800,7 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
case HloOpcode::kBatchNormInference:
|
case HloOpcode::kBatchNormInference:
|
||||||
case HloOpcode::kBatchNormGrad:
|
case HloOpcode::kBatchNormGrad:
|
||||||
case HloOpcode::kFft:
|
case HloOpcode::kFft:
|
||||||
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kSend:
|
case HloOpcode::kSend:
|
||||||
case HloOpcode::kSendDone:
|
case HloOpcode::kSendDone:
|
||||||
case HloOpcode::kRecv:
|
case HloOpcode::kRecv:
|
||||||
@ -2119,17 +2148,12 @@ bool HloInstruction::IsElementwiseImpl(
|
|||||||
// Binary elementwise operations, the same as in IsElementwiseBinary().
|
// Binary elementwise operations, the same as in IsElementwiseBinary().
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
case HloOpcode::kAtan2:
|
case HloOpcode::kAtan2:
|
||||||
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kComplex:
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
case HloOpcode::kRemainder:
|
case HloOpcode::kRemainder:
|
||||||
case HloOpcode::kSubtract:
|
case HloOpcode::kSubtract:
|
||||||
@ -2472,12 +2496,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
|||||||
return visitor->HandleGetTupleElement(this);
|
return visitor->HandleGetTupleElement(this);
|
||||||
case HloOpcode::kParameter:
|
case HloOpcode::kParameter:
|
||||||
return visitor->HandleParameter(this);
|
return visitor->HandleParameter(this);
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kNe:
|
|
||||||
return visitor->HandleCompare(this);
|
return visitor->HandleCompare(this);
|
||||||
case HloOpcode::kComplex:
|
case HloOpcode::kComplex:
|
||||||
return visitor->HandleComplex(this);
|
return visitor->HandleComplex(this);
|
||||||
@ -3519,6 +3538,10 @@ const DomainMetadata& HloInstruction::user_side_metadata() const {
|
|||||||
return Cast<HloDomainInstruction>(this)->user_side_metadata();
|
return Cast<HloDomainInstruction>(this)->user_side_metadata();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ComparisonDirection HloInstruction::comparison_direction() const {
|
||||||
|
return Cast<HloCompareInstruction>(this)->direction();
|
||||||
|
}
|
||||||
|
|
||||||
const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
|
const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
|
||||||
return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
|
return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
|
||||||
}
|
}
|
||||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
|
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||||
#include "tensorflow/compiler/xla/iterator_util.h"
|
#include "tensorflow/compiler/xla/iterator_util.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
@ -444,6 +445,11 @@ class HloInstruction {
|
|||||||
const Shape& shape, HloInstruction* operand, FftType fft_type,
|
const Shape& shape, HloInstruction* operand, FftType fft_type,
|
||||||
absl::Span<const int64> fft_length);
|
absl::Span<const int64> fft_length);
|
||||||
|
|
||||||
|
// Creates a compare op, performing the comparison specified in direction.
|
||||||
|
static std::unique_ptr<HloInstruction> CreateCompare(
|
||||||
|
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||||
|
ComparisonDirection direction);
|
||||||
|
|
||||||
static std::unique_ptr<HloInstruction> CreateTriangularSolve(
|
static std::unique_ptr<HloInstruction> CreateTriangularSolve(
|
||||||
const Shape& shape, HloInstruction* a, HloInstruction* b,
|
const Shape& shape, HloInstruction* a, HloInstruction* b,
|
||||||
const TriangularSolveOptions& options);
|
const TriangularSolveOptions& options);
|
||||||
@ -1600,6 +1606,9 @@ class HloInstruction {
|
|||||||
// Delegates to HloDomainInstruction::user_side_metadata().
|
// Delegates to HloDomainInstruction::user_side_metadata().
|
||||||
const DomainMetadata& user_side_metadata() const;
|
const DomainMetadata& user_side_metadata() const;
|
||||||
|
|
||||||
|
// Delegates to HloCompareInstruction::direction().
|
||||||
|
ComparisonDirection comparison_direction() const;
|
||||||
|
|
||||||
// Delegates to HloTriangularSolveInstruction::triangular_solve_options().
|
// Delegates to HloTriangularSolveInstruction::triangular_solve_options().
|
||||||
const TriangularSolveOptions& triangular_solve_options() const;
|
const TriangularSolveOptions& triangular_solve_options() const;
|
||||||
|
|
||||||
|
@ -1655,7 +1655,7 @@ body (bparam: s32[]) -> s32[] {
|
|||||||
condition (cparam: s32[]) -> pred[] {
|
condition (cparam: s32[]) -> pred[] {
|
||||||
xconstant = s32[] constant(5)
|
xconstant = s32[] constant(5)
|
||||||
cparam = s32[] parameter(0)
|
cparam = s32[] parameter(0)
|
||||||
ROOT greater-than = pred[] greater-than(xconstant, cparam)
|
ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry (param: s32[]) -> s32[] {
|
ENTRY entry (param: s32[]) -> s32[] {
|
||||||
|
@ -202,6 +202,42 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
|
|||||||
fft_length_);
|
fft_length_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloCompareInstruction::HloCompareInstruction(const Shape& shape,
|
||||||
|
HloInstruction* lhs,
|
||||||
|
HloInstruction* rhs,
|
||||||
|
ComparisonDirection direction)
|
||||||
|
: HloInstruction(HloOpcode::kCompare, shape), direction_(direction) {
|
||||||
|
AppendOperand(lhs);
|
||||||
|
AppendOperand(rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
HloInstructionProto HloCompareInstruction::ToProto() const {
|
||||||
|
HloInstructionProto proto = HloInstruction::ToProto();
|
||||||
|
proto.set_comparison_direction(ComparisonDirectionToString(direction_));
|
||||||
|
return proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> HloCompareInstruction::ExtraAttributesToStringImpl(
|
||||||
|
const HloPrintOptions& options) const {
|
||||||
|
return {StrCat("direction=", ComparisonDirectionToString(direction()))};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HloCompareInstruction::IdenticalSlowPath(
|
||||||
|
const HloInstruction& other,
|
||||||
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
|
eq_computations) const {
|
||||||
|
const auto& casted_other = static_cast<const HloCompareInstruction&>(other);
|
||||||
|
return direction() == casted_other.direction();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<HloInstruction> HloCompareInstruction::CloneWithNewOperandsImpl(
|
||||||
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||||
|
HloCloneContext* context) const {
|
||||||
|
CHECK_EQ(new_operands.size(), 2);
|
||||||
|
return absl::make_unique<HloCompareInstruction>(shape, new_operands[0],
|
||||||
|
new_operands[1], direction());
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
|
// Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
|
||||||
|
@ -131,6 +131,28 @@ class HloFftInstruction : public HloInstruction {
|
|||||||
std::vector<int64> fft_length_;
|
std::vector<int64> fft_length_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class HloCompareInstruction : public HloInstruction {
|
||||||
|
public:
|
||||||
|
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
|
||||||
|
HloInstruction* rhs,
|
||||||
|
ComparisonDirection direction);
|
||||||
|
ComparisonDirection direction() const { return direction_; }
|
||||||
|
HloInstructionProto ToProto() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<string> ExtraAttributesToStringImpl(
|
||||||
|
const HloPrintOptions& options) const override;
|
||||||
|
bool IdenticalSlowPath(
|
||||||
|
const HloInstruction& other,
|
||||||
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
|
eq_computations) const override;
|
||||||
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
||||||
|
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||||
|
HloCloneContext* context) const override;
|
||||||
|
|
||||||
|
ComparisonDirection direction_;
|
||||||
|
};
|
||||||
|
|
||||||
class HloTriangularSolveInstruction : public HloInstruction {
|
class HloTriangularSolveInstruction : public HloInstruction {
|
||||||
public:
|
public:
|
||||||
explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
|
explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
|
||||||
|
@ -255,7 +255,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
|
|||||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.2 = s32[] constant(5)
|
constant.2 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(0)
|
constant.3 = s32[] constant(0)
|
||||||
@ -308,7 +308,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
|
|||||||
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
|
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1
|
||||||
add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
|
add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4)
|
||||||
constant.2 = s32[] constant(5)
|
constant.2 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(add.1, constant.2)
|
ROOT less-than = pred[] compare(add.1, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(0)
|
constant.3 = s32[] constant(0)
|
||||||
@ -360,7 +360,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
|
|||||||
loop_var.2 = (s32[], s32[], s32[]) parameter(0)
|
loop_var.2 = (s32[], s32[], s32[]) parameter(0)
|
||||||
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.1 = s32[] constant(5)
|
constant.1 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1)
|
ROOT less-than = pred[] compare(get-tuple-element.4, constant.1), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.2 = s32[] constant(0)
|
constant.2 = s32[] constant(0)
|
||||||
@ -415,7 +415,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
|
|||||||
cond_param = (s32[]) parameter(0)
|
cond_param = (s32[]) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
||||||
constant.2 = s32[] constant(10)
|
constant.2 = s32[] constant(10)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(0)
|
constant.3 = s32[] constant(0)
|
||||||
@ -448,13 +448,13 @@ TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
|
|||||||
cond_param = (s32[]) parameter(0)
|
cond_param = (s32[]) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
||||||
constant.2 = s32[] constant(10)
|
constant.2 = s32[] constant(10)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
OuterWhileCondition {
|
OuterWhileCondition {
|
||||||
cond_param.2 = (s32[]) parameter(0)
|
cond_param.2 = (s32[]) parameter(0)
|
||||||
get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
|
get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
|
||||||
constant.5 = s32[] constant(5)
|
constant.5 = s32[] constant(5)
|
||||||
ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5)
|
ROOT less-than.2 = pred[] compare(get-tuple-element.5, constant.5), direction=LT
|
||||||
}
|
}
|
||||||
OuterWhileBody {
|
OuterWhileBody {
|
||||||
body_param.2 = (s32[]) parameter(0)
|
body_param.2 = (s32[]) parameter(0)
|
||||||
|
@ -89,6 +89,22 @@ bool HloParameterMatcher::MatchAndExplain(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloComparisonMatcher::MatchAndExplain(
|
||||||
|
const HloInstruction* instruction,
|
||||||
|
::testing::MatchResultListener* listener) const {
|
||||||
|
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (instruction->comparison_direction() != direction_) {
|
||||||
|
*listener << "has wrong comparison direction (got "
|
||||||
|
<< ComparisonDirectionToString(
|
||||||
|
instruction->comparison_direction())
|
||||||
|
<< ", want " << ComparisonDirectionToString(direction_) << ")";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool HloGetTupleElementMatcher::MatchAndExplain(
|
bool HloGetTupleElementMatcher::MatchAndExplain(
|
||||||
const HloInstruction* instruction,
|
const HloInstruction* instruction,
|
||||||
::testing::MatchResultListener* listener) const {
|
::testing::MatchResultListener* listener) const {
|
||||||
|
@ -54,6 +54,21 @@ class HloParameterMatcher : public HloMatcher {
|
|||||||
int64 parameter_number_;
|
int64 parameter_number_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Custom matcher for comparisons, which accepts a comparison direction.
|
||||||
|
class HloComparisonMatcher : public HloMatcher {
|
||||||
|
public:
|
||||||
|
explicit HloComparisonMatcher(
|
||||||
|
ComparisonDirection direction,
|
||||||
|
std::vector<::testing::Matcher<const HloInstruction*>> operands)
|
||||||
|
: HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
|
||||||
|
|
||||||
|
bool MatchAndExplain(const HloInstruction* instruction,
|
||||||
|
::testing::MatchResultListener* listener) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
ComparisonDirection direction_;
|
||||||
|
};
|
||||||
|
|
||||||
// Custom matcher for get-tuple-element instructions, which accepts a tuple
|
// Custom matcher for get-tuple-element instructions, which accepts a tuple
|
||||||
// index to match.
|
// index to match.
|
||||||
class HloGetTupleElementMatcher : public HloMatcher {
|
class HloGetTupleElementMatcher : public HloMatcher {
|
||||||
@ -172,6 +187,7 @@ HLO_MATCHER(BatchNormGrad);
|
|||||||
HLO_MATCHER(Call);
|
HLO_MATCHER(Call);
|
||||||
HLO_MATCHER(Ceil);
|
HLO_MATCHER(Ceil);
|
||||||
HLO_MATCHER(Clamp);
|
HLO_MATCHER(Clamp);
|
||||||
|
HLO_MATCHER(Compare);
|
||||||
HLO_MATCHER(Concatenate);
|
HLO_MATCHER(Concatenate);
|
||||||
HLO_MATCHER(Conditional);
|
HLO_MATCHER(Conditional);
|
||||||
HLO_MATCHER(Constant);
|
HLO_MATCHER(Constant);
|
||||||
@ -184,28 +200,22 @@ HLO_MATCHER(Divide);
|
|||||||
HLO_MATCHER(Domain);
|
HLO_MATCHER(Domain);
|
||||||
HLO_MATCHER(DynamicSlice);
|
HLO_MATCHER(DynamicSlice);
|
||||||
HLO_MATCHER(DynamicUpdateSlice);
|
HLO_MATCHER(DynamicUpdateSlice);
|
||||||
HLO_MATCHER(Eq);
|
|
||||||
HLO_MATCHER(Exp);
|
HLO_MATCHER(Exp);
|
||||||
HLO_MATCHER(Floor);
|
HLO_MATCHER(Floor);
|
||||||
HLO_MATCHER(Fusion);
|
HLO_MATCHER(Fusion);
|
||||||
HLO_MATCHER(Ge);
|
|
||||||
HLO_MATCHER(AfterAll);
|
HLO_MATCHER(AfterAll);
|
||||||
HLO_MATCHER(Gt);
|
|
||||||
HLO_MATCHER(Iota);
|
HLO_MATCHER(Iota);
|
||||||
HLO_MATCHER(Infeed);
|
HLO_MATCHER(Infeed);
|
||||||
HLO_MATCHER(IsFinite);
|
HLO_MATCHER(IsFinite);
|
||||||
HLO_MATCHER(Le);
|
|
||||||
HLO_MATCHER(Log);
|
HLO_MATCHER(Log);
|
||||||
HLO_MATCHER(And);
|
HLO_MATCHER(And);
|
||||||
HLO_MATCHER(Not);
|
HLO_MATCHER(Not);
|
||||||
HLO_MATCHER(Or);
|
HLO_MATCHER(Or);
|
||||||
HLO_MATCHER(Xor);
|
HLO_MATCHER(Xor);
|
||||||
HLO_MATCHER(Lt);
|
|
||||||
HLO_MATCHER(Map);
|
HLO_MATCHER(Map);
|
||||||
HLO_MATCHER(Maximum);
|
HLO_MATCHER(Maximum);
|
||||||
HLO_MATCHER(Minimum);
|
HLO_MATCHER(Minimum);
|
||||||
HLO_MATCHER(Multiply);
|
HLO_MATCHER(Multiply);
|
||||||
HLO_MATCHER(Ne);
|
|
||||||
HLO_MATCHER(Negate);
|
HLO_MATCHER(Negate);
|
||||||
HLO_MATCHER(Outfeed);
|
HLO_MATCHER(Outfeed);
|
||||||
HLO_MATCHER(Pad);
|
HLO_MATCHER(Pad);
|
||||||
@ -256,6 +266,38 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
|
|||||||
new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
|
new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Comparison matchers below do not require any additional arguments.
|
||||||
|
template <typename... M>
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
|
||||||
|
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||||
|
ComparisonDirection::kEq, {operands...}));
|
||||||
|
}
|
||||||
|
template <typename... M>
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
|
||||||
|
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||||
|
ComparisonDirection::kNe, {operands...}));
|
||||||
|
}
|
||||||
|
template <typename... M>
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
|
||||||
|
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||||
|
ComparisonDirection::kGe, {operands...}));
|
||||||
|
}
|
||||||
|
template <typename... M>
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
|
||||||
|
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||||
|
ComparisonDirection::kGt, {operands...}));
|
||||||
|
}
|
||||||
|
template <typename... M>
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
|
||||||
|
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||||
|
ComparisonDirection::kLe, {operands...}));
|
||||||
|
}
|
||||||
|
template <typename... M>
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
|
||||||
|
return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
|
||||||
|
ComparisonDirection::kLt, {operands...}));
|
||||||
|
}
|
||||||
|
|
||||||
// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
|
// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
|
||||||
// tuple element of operand, while GetTupleElement(operand) matches any GTE
|
// tuple element of operand, while GetTupleElement(operand) matches any GTE
|
||||||
// operation on operand, and GetTupleElement() matches any GTE operation at all.
|
// operation on operand, and GetTupleElement() matches any GTE operation at all.
|
||||||
|
@ -220,5 +220,33 @@ ENTRY DotOperationFusion_TransposeFusion {
|
|||||||
"rhs_contracting_dimensions (got {0} want {1})");
|
"rhs_contracting_dimensions (got {0} want {1})");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(HloMatchersTest, ComparisonMatcher) {
|
||||||
|
auto shape = ShapeUtil::MakeShape(F32, {1});
|
||||||
|
auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
|
||||||
|
auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
|
||||||
|
auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(),
|
||||||
|
ComparisonDirection::kEq);
|
||||||
|
auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(),
|
||||||
|
ComparisonDirection::kNe);
|
||||||
|
auto add =
|
||||||
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
|
||||||
|
auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(),
|
||||||
|
ComparisonDirection::kLe);
|
||||||
|
|
||||||
|
EXPECT_THAT(eq.get(), op::Compare());
|
||||||
|
EXPECT_THAT(eq.get(), op::Eq());
|
||||||
|
EXPECT_THAT(ne.get(), op::Compare());
|
||||||
|
EXPECT_THAT(ne.get(), op::Ne());
|
||||||
|
EXPECT_THAT(le.get(),
|
||||||
|
op::Compare(op::Parameter(0),
|
||||||
|
op::Add(op::Parameter(0), op::Parameter(1))));
|
||||||
|
EXPECT_THAT(le.get(), op::Le(op::Parameter(0),
|
||||||
|
op::Add(op::Parameter(0), op::Parameter(1))));
|
||||||
|
|
||||||
|
EXPECT_THAT(Explain(eq.get(), op::Add()), Eq(""));
|
||||||
|
EXPECT_THAT(Explain(eq.get(), op::Ne()),
|
||||||
|
Eq("has wrong comparison direction (got EQ, want NE)"));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -254,8 +254,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
|
|||||||
HloInstruction* zero_vector =
|
HloInstruction* zero_vector =
|
||||||
cond_builder.AddInstruction(HloInstruction::CreateConstant(
|
cond_builder.AddInstruction(HloInstruction::CreateConstant(
|
||||||
LiteralUtil::CreateR1<float>({0, 0, 0, 0})));
|
LiteralUtil::CreateR1<float>({0, 0, 0, 0})));
|
||||||
cond_builder.AddInstruction(HloInstruction::CreateBinary(
|
cond_builder.AddInstruction(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
|
HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_param,
|
||||||
|
zero_vector, ComparisonDirection::kNe));
|
||||||
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
|
||||||
|
|
||||||
// param - 1
|
// param - 1
|
||||||
|
@ -86,7 +86,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
|
|||||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.2 = s32[] constant(5)
|
constant.2 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(0)
|
constant.3 = s32[] constant(0)
|
||||||
@ -125,7 +125,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
|
|||||||
loop_var.2 = (s32[], f32[]) parameter(0)
|
loop_var.2 = (s32[], f32[]) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.3 = s32[] constant(5)
|
constant.3 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.3), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.4 = s32[] constant(0)
|
constant.4 = s32[] constant(0)
|
||||||
@ -163,7 +163,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
|
|||||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.2 = s32[] constant(5)
|
constant.2 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(0)
|
constant.3 = s32[] constant(0)
|
||||||
@ -206,7 +206,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
|
|||||||
loop_var.2 = (s32[], s32[]) parameter(0)
|
loop_var.2 = (s32[], s32[]) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
|
||||||
constant.2 = s32[] constant(5)
|
constant.2 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(0)
|
constant.3 = s32[] constant(0)
|
||||||
@ -248,7 +248,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
|
|||||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.2 = s32[] constant(5)
|
constant.2 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
SimpleLoop.body1 {
|
SimpleLoop.body1 {
|
||||||
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
|
||||||
@ -263,7 +263,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
|
|||||||
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
|
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
|
||||||
constant.4 = s32[] constant(5)
|
constant.4 = s32[] constant(5)
|
||||||
ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
|
ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.5 = s32[] constant(0)
|
constant.5 = s32[] constant(0)
|
||||||
@ -316,7 +316,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
|
|||||||
loop_var.2 = (s32[3]{0}, s32[]) parameter(0)
|
loop_var.2 = (s32[3]{0}, s32[]) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
|
||||||
constant.2 = s32[] constant(5)
|
constant.2 = s32[] constant(5)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
SimpleLoop.body1 {
|
SimpleLoop.body1 {
|
||||||
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
|
||||||
@ -331,7 +331,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
|
|||||||
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.4 = (s32[], s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
|
get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
|
||||||
constant.4 = s32[] constant(5)
|
constant.4 = s32[] constant(5)
|
||||||
ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
|
ROOT less-than.1 = pred[] compare(get-tuple-element.6, constant.4), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.5 = s32[] constant(0)
|
constant.5 = s32[] constant(0)
|
||||||
@ -383,7 +383,7 @@ TEST_F(HloModuleDceTest, WhileWithOutfeed) {
|
|||||||
cond_param = (s32[]) parameter(0)
|
cond_param = (s32[]) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
||||||
constant.2 = s32[] constant(10)
|
constant.2 = s32[] constant(10)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(0)
|
constant.3 = s32[] constant(0)
|
||||||
@ -418,7 +418,7 @@ TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
|
|||||||
cond_param = (s32[], s32[]) parameter(0)
|
cond_param = (s32[], s32[]) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
|
||||||
constant.2 = s32[] constant(10)
|
constant.2 = s32[] constant(10)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
p0 = (s32[]) parameter(0)
|
p0 = (s32[]) parameter(0)
|
||||||
|
@ -44,21 +44,8 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
|
|||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define CHECK_DEFAULT(property_name, opcode_name) false
|
|
||||||
#define CHECK_PROPERTY(property_name, opcode_name, value) \
|
|
||||||
(value & property_name)
|
|
||||||
#define RESOLVE(_1, _2, target, ...) target
|
|
||||||
#define HAS_PROPERTY(property, ...) \
|
|
||||||
RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__)
|
|
||||||
|
|
||||||
bool HloOpcodeIsComparison(HloOpcode opcode) {
|
bool HloOpcodeIsComparison(HloOpcode opcode) {
|
||||||
switch (opcode) {
|
return opcode == HloOpcode::kCompare;
|
||||||
#define CASE_IS_COMPARISON(enum_name, opcode_name, ...) \
|
|
||||||
case HloOpcode::enum_name: \
|
|
||||||
return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__);
|
|
||||||
HLO_OPCODE_LIST(CASE_IS_COMPARISON)
|
|
||||||
#undef CASE_IS_COMPARISON
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloOpcodeIsVariadic(HloOpcode opcode) {
|
bool HloOpcodeIsVariadic(HloOpcode opcode) {
|
||||||
@ -82,9 +69,4 @@ absl::optional<int> HloOpcodeArity(HloOpcode opcode) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#undef HAS_PROPERTY
|
|
||||||
#undef RESOLVE
|
|
||||||
#undef CHECK_DEFAULT
|
|
||||||
#undef CHECK_PROPERTY
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -19,8 +19,10 @@ limitations under the License.
|
|||||||
#include <iosfwd>
|
#include <iosfwd>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -65,6 +67,7 @@ namespace xla {
|
|||||||
V(kClamp, "clamp", 3) \
|
V(kClamp, "clamp", 3) \
|
||||||
V(kCollectivePermute, "collective-permute", 1) \
|
V(kCollectivePermute, "collective-permute", 1) \
|
||||||
V(kClz, "count-leading-zeros", 1) \
|
V(kClz, "count-leading-zeros", 1) \
|
||||||
|
V(kCompare, "compare", 2) \
|
||||||
V(kComplex, "complex", 2) \
|
V(kComplex, "complex", 2) \
|
||||||
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
|
V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \
|
||||||
V(kConditional, "conditional", kHloOpcodeIsVariadic) \
|
V(kConditional, "conditional", kHloOpcodeIsVariadic) \
|
||||||
@ -79,34 +82,28 @@ namespace xla {
|
|||||||
V(kDot, "dot", 2) \
|
V(kDot, "dot", 2) \
|
||||||
V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \
|
V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \
|
||||||
V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \
|
V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \
|
||||||
V(kEq, "equal-to", 2, kHloOpcodeIsComparison) \
|
|
||||||
V(kExp, "exponential", 1) \
|
V(kExp, "exponential", 1) \
|
||||||
V(kExpm1, "exponential-minus-one", 1) \
|
V(kExpm1, "exponential-minus-one", 1) \
|
||||||
V(kFft, "fft", 1) \
|
V(kFft, "fft", 1) \
|
||||||
V(kFloor, "floor", 1) \
|
V(kFloor, "floor", 1) \
|
||||||
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
|
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
|
||||||
V(kGather, "gather", 2) \
|
V(kGather, "gather", 2) \
|
||||||
V(kGe, "greater-than-or-equal-to", 2, kHloOpcodeIsComparison) \
|
|
||||||
V(kGetDimensionSize, "get-dimension-size", 1) \
|
V(kGetDimensionSize, "get-dimension-size", 1) \
|
||||||
V(kGetTupleElement, "get-tuple-element", 1) \
|
V(kGetTupleElement, "get-tuple-element", 1) \
|
||||||
V(kGt, "greater-than", 2, kHloOpcodeIsComparison) \
|
|
||||||
V(kImag, "imag", 1) \
|
V(kImag, "imag", 1) \
|
||||||
V(kInfeed, "infeed", 1) \
|
V(kInfeed, "infeed", 1) \
|
||||||
V(kIota, "iota", 0) \
|
V(kIota, "iota", 0) \
|
||||||
V(kIsFinite, "is-finite", 1) \
|
V(kIsFinite, "is-finite", 1) \
|
||||||
V(kLe, "less-than-or-equal-to", 2, kHloOpcodeIsComparison) \
|
|
||||||
V(kLog, "log", 1) \
|
V(kLog, "log", 1) \
|
||||||
V(kLog1p, "log-plus-one", 1) \
|
V(kLog1p, "log-plus-one", 1) \
|
||||||
V(kAnd, "and", 2) \
|
V(kAnd, "and", 2) \
|
||||||
V(kNot, "not", 1) \
|
V(kNot, "not", 1) \
|
||||||
V(kOr, "or", 2) \
|
V(kOr, "or", 2) \
|
||||||
V(kXor, "xor", 2) \
|
V(kXor, "xor", 2) \
|
||||||
V(kLt, "less-than", 2, kHloOpcodeIsComparison) \
|
|
||||||
V(kMap, "map", kHloOpcodeIsVariadic) \
|
V(kMap, "map", kHloOpcodeIsVariadic) \
|
||||||
V(kMaximum, "maximum", 2) \
|
V(kMaximum, "maximum", 2) \
|
||||||
V(kMinimum, "minimum", 2) \
|
V(kMinimum, "minimum", 2) \
|
||||||
V(kMultiply, "multiply", 2) \
|
V(kMultiply, "multiply", 2) \
|
||||||
V(kNe, "not-equal-to", 2, kHloOpcodeIsComparison) \
|
|
||||||
V(kNegate, "negate", 1) \
|
V(kNegate, "negate", 1) \
|
||||||
V(kOutfeed, "outfeed", 2) \
|
V(kOutfeed, "outfeed", 2) \
|
||||||
V(kPad, "pad", 2) \
|
V(kPad, "pad", 2) \
|
||||||
|
@ -42,12 +42,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
|
|||||||
|
|
||||||
// Test some properties.
|
// Test some properties.
|
||||||
switch (opcode) {
|
switch (opcode) {
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kLe:
|
|
||||||
EXPECT_TRUE(HloOpcodeIsComparison(opcode));
|
EXPECT_TRUE(HloOpcodeIsComparison(opcode));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -306,7 +306,7 @@ condition.v4 {
|
|||||||
constant.2 = s32[] constant(2)
|
constant.2 = s32[] constant(2)
|
||||||
prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
|
prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0)
|
||||||
get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
|
get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0
|
||||||
ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8)
|
ROOT greater-than = pred[] compare(constant.2, get-tuple-element.8), direction=GT
|
||||||
}
|
}
|
||||||
|
|
||||||
fused_computation {
|
fused_computation {
|
||||||
|
@ -183,6 +183,7 @@ class HloParser {
|
|||||||
kHloComputation,
|
kHloComputation,
|
||||||
kBracedHloComputationList,
|
kBracedHloComputationList,
|
||||||
kFftType,
|
kFftType,
|
||||||
|
kComparisonDirection,
|
||||||
kWindow,
|
kWindow,
|
||||||
kConvolutionDimensionNumbers,
|
kConvolutionDimensionNumbers,
|
||||||
kSharding,
|
kSharding,
|
||||||
@ -300,6 +301,7 @@ class HloParser {
|
|||||||
bool ParseTiles(std::vector<Tile>* tiles);
|
bool ParseTiles(std::vector<Tile>* tiles);
|
||||||
bool ParseOpcode(HloOpcode* result);
|
bool ParseOpcode(HloOpcode* result);
|
||||||
bool ParseFftType(FftType* result);
|
bool ParseFftType(FftType* result);
|
||||||
|
bool ParseComparisonDirection(ComparisonDirection* result);
|
||||||
bool ParseFusionKind(HloInstruction::FusionKind* result);
|
bool ParseFusionKind(HloInstruction::FusionKind* result);
|
||||||
bool ParseRandomDistribution(RandomDistribution* result);
|
bool ParseRandomDistribution(RandomDistribution* result);
|
||||||
bool ParsePrecision(PrecisionConfig::Precision* result);
|
bool ParsePrecision(PrecisionConfig::Precision* result);
|
||||||
@ -763,12 +765,6 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
case HloOpcode::kSubtract:
|
case HloOpcode::kSubtract:
|
||||||
case HloOpcode::kAtan2:
|
case HloOpcode::kAtan2:
|
||||||
case HloOpcode::kComplex:
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
@ -1133,6 +1129,18 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
shape, operands[0], operands[1], options));
|
shape, operands[0], operands[1], options));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case HloOpcode::kCompare: {
|
||||||
|
optional<ComparisonDirection> direction;
|
||||||
|
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
|
||||||
|
&direction};
|
||||||
|
if (!ParseOperands(&operands, /*expected_size=*/2) ||
|
||||||
|
!ParseAttributes(attrs)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
|
||||||
|
shape, operands[0], operands[1], *direction));
|
||||||
|
break;
|
||||||
|
}
|
||||||
case HloOpcode::kCholesky: {
|
case HloOpcode::kCholesky: {
|
||||||
CholeskyOptions options;
|
CholeskyOptions options;
|
||||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||||
@ -2728,6 +2736,15 @@ bool HloParser::ParseAttributeHelper(
|
|||||||
static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
|
static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
case AttrTy::kComparisonDirection: {
|
||||||
|
ComparisonDirection result;
|
||||||
|
if (!ParseComparisonDirection(&result)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
|
||||||
|
->emplace(result);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
case AttrTy::kWindow: {
|
case AttrTy::kWindow: {
|
||||||
Window result;
|
Window result;
|
||||||
if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
|
if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
|
||||||
@ -3756,6 +3773,22 @@ bool HloParser::ParseFftType(FftType* result) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloParser::ParseComparisonDirection(ComparisonDirection* result) {
|
||||||
|
VLOG(1) << "ParseComparisonDirection";
|
||||||
|
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||||
|
return TokenError("expects comparison direction");
|
||||||
|
}
|
||||||
|
string val = lexer_.GetStrVal();
|
||||||
|
auto status_or_result = StringToComparisonDirection(val);
|
||||||
|
if (!status_or_result.ok()) {
|
||||||
|
return TokenError(
|
||||||
|
StrFormat("expects comparison direction but sees: %s", val));
|
||||||
|
}
|
||||||
|
*result = status_or_result.ValueOrDie();
|
||||||
|
lexer_.Lex();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
|
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
|
||||||
VLOG(1) << "ParseFusionKind";
|
VLOG(1) << "ParseFusionKind";
|
||||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||||
|
@ -222,7 +222,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
|
|||||||
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
|
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
|
||||||
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
|
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
|
||||||
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
|
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
|
||||||
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
|
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
|
||||||
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
|
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -292,7 +292,7 @@ R"(HloModule WhileWithScalarS32Result_module
|
|||||||
%condition.v3 (prev.2: s32[]) -> pred[] {
|
%condition.v3 (prev.2: s32[]) -> pred[] {
|
||||||
%constant.1 = s32[] constant(5)
|
%constant.1 = s32[] constant(5)
|
||||||
%prev.2 = s32[] parameter(0)
|
%prev.2 = s32[] parameter(0)
|
||||||
ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2)
|
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
|
ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
|
||||||
@ -474,7 +474,7 @@ R"(HloModule R4F32OverlapSmall_module
|
|||||||
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
||||||
%lhs = f32[] parameter(0)
|
%lhs = f32[] parameter(0)
|
||||||
%rhs = f32[] parameter(1)
|
%rhs = f32[] parameter(1)
|
||||||
ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
|
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
|
||||||
}
|
}
|
||||||
|
|
||||||
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
||||||
@ -500,7 +500,7 @@ R"(HloModule select_and_scatter_scalar
|
|||||||
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
|
||||||
%lhs = f32[] parameter(0)
|
%lhs = f32[] parameter(0)
|
||||||
%rhs = f32[] parameter(1)
|
%rhs = f32[] parameter(1)
|
||||||
ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
|
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
|
||||||
}
|
}
|
||||||
|
|
||||||
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
|
||||||
@ -1037,7 +1037,7 @@ R"(HloModule TupleReduce
|
|||||||
max_argmax {
|
max_argmax {
|
||||||
value = f32[] parameter(2)
|
value = f32[] parameter(2)
|
||||||
prev_max = f32[] parameter(0)
|
prev_max = f32[] parameter(0)
|
||||||
is_next_larger = pred[] greater-than-or-equal-to(value, prev_max)
|
is_next_larger = pred[] compare(value, prev_max), direction=GE
|
||||||
max = f32[] select(is_next_larger, value, prev_max)
|
max = f32[] select(is_next_larger, value, prev_max)
|
||||||
index = s32[] parameter(3)
|
index = s32[] parameter(3)
|
||||||
prev_argmax = s32[] parameter(1)
|
prev_argmax = s32[] parameter(1)
|
||||||
@ -1106,7 +1106,7 @@ R"(HloModule sort
|
|||||||
compare {
|
compare {
|
||||||
p.0.lhs = f32[] parameter(0)
|
p.0.lhs = f32[] parameter(0)
|
||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY Sort {
|
ENTRY Sort {
|
||||||
@ -1126,7 +1126,7 @@ compare {
|
|||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
p.0.lhs = f32[] parameter(0)
|
p.0.lhs = f32[] parameter(0)
|
||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY Sort {
|
ENTRY Sort {
|
||||||
@ -1145,7 +1145,7 @@ R"(HloModule sort
|
|||||||
compare {
|
compare {
|
||||||
p.0.lhs = f32[] parameter(0)
|
p.0.lhs = f32[] parameter(0)
|
||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY Sort {
|
ENTRY Sort {
|
||||||
@ -1165,7 +1165,7 @@ compare {
|
|||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
p.0.lhs = f32[] parameter(0)
|
p.0.lhs = f32[] parameter(0)
|
||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY Sort {
|
ENTRY Sort {
|
||||||
@ -1190,7 +1190,7 @@ compare {
|
|||||||
p.3.rhs = f32[] parameter(7)
|
p.3.rhs = f32[] parameter(7)
|
||||||
p.0.lhs = f32[] parameter(0)
|
p.0.lhs = f32[] parameter(0)
|
||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY Sort {
|
ENTRY Sort {
|
||||||
@ -1211,7 +1211,7 @@ R"(HloModule sort
|
|||||||
compare {
|
compare {
|
||||||
p.0.lhs = f32[] parameter(0)
|
p.0.lhs = f32[] parameter(0)
|
||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY Sort {
|
ENTRY Sort {
|
||||||
@ -1469,7 +1469,7 @@ compare {
|
|||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
p.0.lhs = f32[] parameter(0)
|
p.0.lhs = f32[] parameter(0)
|
||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
ROOT lhs = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY Sort {
|
ENTRY Sort {
|
||||||
@ -1656,7 +1656,7 @@ TEST_F(HloParserTest, WrongOperandsSize) {
|
|||||||
|
|
||||||
ENTRY %blabla (x: f32[]) -> pred[] {
|
ENTRY %blabla (x: f32[]) -> pred[] {
|
||||||
%x = f32[]{} parameter(0)
|
%x = f32[]{} parameter(0)
|
||||||
%eq = pred[]{} equal-to(f32[]{} %x)
|
%eq = pred[]{} compare(f32[]{} %x), direction=EQ
|
||||||
}
|
}
|
||||||
|
|
||||||
)";
|
)";
|
||||||
@ -1668,7 +1668,7 @@ TEST_F(HloParserTest, OperandNotFound) {
|
|||||||
const string original = R"(HloModule operand_not_found:
|
const string original = R"(HloModule operand_not_found:
|
||||||
ENTRY %blabla (x: f32[]) -> pred[] {
|
ENTRY %blabla (x: f32[]) -> pred[] {
|
||||||
%x = f32[]{} parameter(0)
|
%x = f32[]{} parameter(0)
|
||||||
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
|
%eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
auto result = ParseHloString(original);
|
auto result = ParseHloString(original);
|
||||||
|
@ -228,7 +228,7 @@ HloModule UpdateScheduleWithMultipleComputations
|
|||||||
%param = (s32[], token[]) parameter(0)
|
%param = (s32[], token[]) parameter(0)
|
||||||
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
||||||
%constant = s32[] constant(42)
|
%constant = s32[] constant(42)
|
||||||
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
|
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %WhileLoop () -> s32[] {
|
ENTRY %WhileLoop () -> s32[] {
|
||||||
@ -297,7 +297,7 @@ HloModule UpdateScheduleWithMultipleComputations
|
|||||||
%param = (s32[], token[]) parameter(0)
|
%param = (s32[], token[]) parameter(0)
|
||||||
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
||||||
%constant = s32[] constant(42)
|
%constant = s32[] constant(42)
|
||||||
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
|
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %WhileLoop () -> s32[] {
|
ENTRY %WhileLoop () -> s32[] {
|
||||||
|
@ -65,6 +65,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
|||||||
case HloOpcode::kCeil:
|
case HloOpcode::kCeil:
|
||||||
case HloOpcode::kClamp:
|
case HloOpcode::kClamp:
|
||||||
case HloOpcode::kClz:
|
case HloOpcode::kClz:
|
||||||
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kComplex:
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kConcatenate:
|
case HloOpcode::kConcatenate:
|
||||||
case HloOpcode::kConstant:
|
case HloOpcode::kConstant:
|
||||||
@ -72,21 +73,15 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
|||||||
case HloOpcode::kCopy:
|
case HloOpcode::kCopy:
|
||||||
case HloOpcode::kDynamicSlice:
|
case HloOpcode::kDynamicSlice:
|
||||||
case HloOpcode::kDynamicUpdateSlice:
|
case HloOpcode::kDynamicUpdateSlice:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGetTupleElement:
|
case HloOpcode::kGetTupleElement:
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kImag:
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kInfeed:
|
case HloOpcode::kInfeed:
|
||||||
case HloOpcode::kIota:
|
case HloOpcode::kIota:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kOr:
|
case HloOpcode::kOr:
|
||||||
|
@ -2001,6 +2001,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
|||||||
case HloOpcode::kCeil:
|
case HloOpcode::kCeil:
|
||||||
case HloOpcode::kClamp:
|
case HloOpcode::kClamp:
|
||||||
case HloOpcode::kClz:
|
case HloOpcode::kClz:
|
||||||
|
case HloOpcode::kCompare:
|
||||||
case HloOpcode::kComplex:
|
case HloOpcode::kComplex:
|
||||||
case HloOpcode::kConcatenate:
|
case HloOpcode::kConcatenate:
|
||||||
case HloOpcode::kConditional:
|
case HloOpcode::kConditional:
|
||||||
@ -2012,24 +2013,18 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
|||||||
case HloOpcode::kDivide:
|
case HloOpcode::kDivide:
|
||||||
case HloOpcode::kDynamicSlice:
|
case HloOpcode::kDynamicSlice:
|
||||||
case HloOpcode::kDynamicUpdateSlice:
|
case HloOpcode::kDynamicUpdateSlice:
|
||||||
case HloOpcode::kEq:
|
|
||||||
case HloOpcode::kExp:
|
case HloOpcode::kExp:
|
||||||
case HloOpcode::kExpm1:
|
case HloOpcode::kExpm1:
|
||||||
case HloOpcode::kFft:
|
case HloOpcode::kFft:
|
||||||
case HloOpcode::kFloor:
|
case HloOpcode::kFloor:
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kImag:
|
case HloOpcode::kImag:
|
||||||
case HloOpcode::kIsFinite:
|
case HloOpcode::kIsFinite:
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLog:
|
case HloOpcode::kLog:
|
||||||
case HloOpcode::kLog1p:
|
case HloOpcode::kLog1p:
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kMap:
|
case HloOpcode::kMap:
|
||||||
case HloOpcode::kMaximum:
|
case HloOpcode::kMaximum:
|
||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNe:
|
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kOr:
|
case HloOpcode::kOr:
|
||||||
|
@ -1084,7 +1084,7 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
|
|||||||
tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
|
tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
|
||||||
counter.1 = s32[] get-tuple-element(tup.1), index=0
|
counter.1 = s32[] get-tuple-element(tup.1), index=0
|
||||||
five = s32[] constant(5)
|
five = s32[] constant(5)
|
||||||
ROOT lt = pred[] less-than(counter.1, five)
|
ROOT lt = pred[] compare(counter.1, five), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
|
body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
|
||||||
|
@ -46,7 +46,7 @@ condition {
|
|||||||
condition.state = f32[] parameter(0)
|
condition.state = f32[] parameter(0)
|
||||||
addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget"
|
addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget"
|
||||||
add = f32[] add(addend, condition.state)
|
add = f32[] add(addend, condition.state)
|
||||||
ROOT greater-than = pred[] greater-than(const.100, add)
|
ROOT greater-than = pred[] compare(const.100, add), direction=GT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY while3 {
|
ENTRY while3 {
|
||||||
|
@ -67,6 +67,7 @@ namespace xla {
|
|||||||
// - WithOneUse: Instruction is used as an operand exactly once.
|
// - WithOneUse: Instruction is used as an operand exactly once.
|
||||||
// - WithOneUser: Instruction is used by exactly one other instruction, but
|
// - WithOneUser: Instruction is used by exactly one other instruction, but
|
||||||
// is possibly used more than once as an operand (e.g. multiply(x,x)).
|
// is possibly used more than once as an operand (e.g. multiply(x,x)).
|
||||||
|
// - WithComparisonDirection: instr has the given direction
|
||||||
//
|
//
|
||||||
// Shape():
|
// Shape():
|
||||||
// - EqualTo
|
// - EqualTo
|
||||||
@ -1671,6 +1672,40 @@ class HloInstructionPatternOneUserImpl
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class HloInstructionPatternComparisonDirectionImpl {
|
||||||
|
public:
|
||||||
|
explicit constexpr HloInstructionPatternComparisonDirectionImpl(
|
||||||
|
ComparisonDirection direction)
|
||||||
|
: direction_(direction) {}
|
||||||
|
|
||||||
|
bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
|
||||||
|
return MatchImpl(inst, option);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Match(::xla::HloInstruction* inst, MatchOption option) const {
|
||||||
|
return MatchImpl(inst, option);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DescribeTo(std::ostream* os, int64 indent = 0) const {
|
||||||
|
*os << "which has comparison direction "
|
||||||
|
<< ComparisonDirectionToString(direction_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename HloInstructionType>
|
||||||
|
bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
|
||||||
|
if (inst->opcode() != HloOpcode::kCompare ||
|
||||||
|
inst->comparison_direction() != direction_) {
|
||||||
|
EXPLAIN << "HloInstruction is not comparison "
|
||||||
|
<< ComparisonDirectionToString(direction_);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
ComparisonDirection direction_;
|
||||||
|
};
|
||||||
|
|
||||||
// Matches a constant scalar or effective scalar, optionally with a given value.
|
// Matches a constant scalar or effective scalar, optionally with a given value.
|
||||||
template <typename ScalarTy>
|
template <typename ScalarTy>
|
||||||
class HloConstantScalarImpl {
|
class HloConstantScalarImpl {
|
||||||
@ -1956,6 +1991,14 @@ class HloInstructionPattern {
|
|||||||
return AppendImpl(HloInstructionPatternOneUserImpl());
|
return AppendImpl(HloInstructionPatternOneUserImpl());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Modifies the pattern to match only if the instruction has the given
|
||||||
|
// comparison direction.
|
||||||
|
auto WithComparisonDirection(ComparisonDirection direction) const
|
||||||
|
-> decltype(this->AppendImpl(
|
||||||
|
HloInstructionPatternComparisonDirectionImpl(direction))) {
|
||||||
|
return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
|
||||||
|
}
|
||||||
|
|
||||||
void DescribeTo(std::ostream* os, int64 indent = 0) const {
|
void DescribeTo(std::ostream* os, int64 indent = 0) const {
|
||||||
impl_.DescribeTo(os, indent);
|
impl_.DescribeTo(os, indent);
|
||||||
}
|
}
|
||||||
@ -2118,18 +2161,13 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Add)
|
|||||||
XLA_BINOP_PATTERN(Atan2)
|
XLA_BINOP_PATTERN(Atan2)
|
||||||
XLA_BINOP_PATTERN(Divide)
|
XLA_BINOP_PATTERN(Divide)
|
||||||
XLA_BINOP_PATTERN(Complex)
|
XLA_BINOP_PATTERN(Complex)
|
||||||
|
XLA_BINOP_PATTERN(Compare)
|
||||||
XLA_BINOP_PATTERN(Convolution)
|
XLA_BINOP_PATTERN(Convolution)
|
||||||
XLA_BINOP_PATTERN(Dot)
|
XLA_BINOP_PATTERN(Dot)
|
||||||
XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
|
|
||||||
XLA_BINOP_PATTERN(Gather)
|
XLA_BINOP_PATTERN(Gather)
|
||||||
XLA_BINOP_PATTERN(Ge)
|
|
||||||
XLA_BINOP_PATTERN(Gt)
|
|
||||||
XLA_BINOP_PATTERN(Le)
|
|
||||||
XLA_BINOP_PATTERN(Lt)
|
|
||||||
XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
|
XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
|
||||||
XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
|
XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
|
||||||
XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
|
XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
|
||||||
XLA_COMMUTATIVE_BINOP_PATTERN(Ne)
|
|
||||||
XLA_BINOP_PATTERN(Outfeed)
|
XLA_BINOP_PATTERN(Outfeed)
|
||||||
XLA_BINOP_PATTERN(Pad)
|
XLA_BINOP_PATTERN(Pad)
|
||||||
XLA_BINOP_PATTERN(Power)
|
XLA_BINOP_PATTERN(Power)
|
||||||
@ -2242,6 +2280,73 @@ XLA_VARIADIC_OP_PATTERN(Reduce);
|
|||||||
XLA_VARIADIC_OP_PATTERN(Sort);
|
XLA_VARIADIC_OP_PATTERN(Sort);
|
||||||
XLA_VARIADIC_OP_PATTERN(Tuple);
|
XLA_VARIADIC_OP_PATTERN(Tuple);
|
||||||
|
|
||||||
|
// Helpers for comparison instructions.
|
||||||
|
#define XLA_COMPARE_PATTERN(NAME) \
|
||||||
|
inline auto NAME()->decltype( \
|
||||||
|
Op().WithOpcode(HloOpcode::kCompare) \
|
||||||
|
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
|
||||||
|
return Op() \
|
||||||
|
.WithOpcode(HloOpcode::kCompare) \
|
||||||
|
.WithComparisonDirection(ComparisonDirection::k##NAME); \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
template <typename Lhs, typename Rhs> \
|
||||||
|
inline auto NAME(Lhs&& lhs, Rhs&& rhs) \
|
||||||
|
->decltype(Op().WithOpcode(HloOpcode::kCompare) \
|
||||||
|
.WithOperand(0, std::forward<Lhs>(lhs)) \
|
||||||
|
.WithOperand(1, std::forward<Rhs>(rhs)) \
|
||||||
|
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
|
||||||
|
return Op() \
|
||||||
|
.WithOpcode(HloOpcode::kCompare) \
|
||||||
|
.WithOperand(0, std::forward<Lhs>(lhs)) \
|
||||||
|
.WithOperand(1, std::forward<Rhs>(rhs)) \
|
||||||
|
.WithComparisonDirection(ComparisonDirection::k##NAME); \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
template <typename HloInstructionType, typename Lhs, typename Rhs> \
|
||||||
|
inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
|
||||||
|
->decltype(Op(matched_inst) \
|
||||||
|
.WithOpcode(HloOpcode::kCompare) \
|
||||||
|
.WithOperand(0, std::forward<Lhs>(lhs)) \
|
||||||
|
.WithOperand(1, std::forward<Rhs>(rhs)) \
|
||||||
|
.WithComparisonDirection(ComparisonDirection::k##NAME)) { \
|
||||||
|
return Op(matched_inst) \
|
||||||
|
.WithOpcode(HloOpcode::kCompare) \
|
||||||
|
.WithOperand(0, std::forward<Lhs>(lhs)) \
|
||||||
|
.WithOperand(1, std::forward<Rhs>(rhs)) \
|
||||||
|
.WithComparisonDirection(ComparisonDirection::k##NAME); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \
|
||||||
|
XLA_COMPARE_PATTERN(NAME) \
|
||||||
|
\
|
||||||
|
template <typename HloInstructionType, typename Lhs, typename Rhs> \
|
||||||
|
inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
|
||||||
|
Rhs&& rhs) \
|
||||||
|
->decltype(Op(matched_inst) \
|
||||||
|
.WithOpcode(HloOpcode::kCompare) \
|
||||||
|
.WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
|
||||||
|
std::forward<Rhs>(rhs))) { \
|
||||||
|
return Op(matched_inst) \
|
||||||
|
.WithOpcode(HloOpcode::kCompare) \
|
||||||
|
.WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
|
||||||
|
std::forward<Rhs>(rhs)); \
|
||||||
|
} \
|
||||||
|
template <typename Lhs, typename Rhs> \
|
||||||
|
inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
|
||||||
|
->decltype(NAME##AnyOrder<const HloInstruction>( \
|
||||||
|
nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) { \
|
||||||
|
return NAME##AnyOrder<const HloInstruction>( \
|
||||||
|
nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
|
||||||
|
XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
|
||||||
|
XLA_COMPARE_PATTERN(Ge);
|
||||||
|
XLA_COMPARE_PATTERN(Gt);
|
||||||
|
XLA_COMPARE_PATTERN(Le);
|
||||||
|
XLA_COMPARE_PATTERN(Lt);
|
||||||
|
|
||||||
// Helpers for matching non-constant instructions.
|
// Helpers for matching non-constant instructions.
|
||||||
inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
|
inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
|
||||||
return Op().IsNonConstant();
|
return Op().IsNonConstant();
|
||||||
|
@ -931,5 +931,48 @@ TEST(PatternMatcherTest, OneUseAndOneUser) {
|
|||||||
"in p0 = f32[] parameter(0)");
|
"in p0 = f32[] parameter(0)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(HloMatchersTest, Comparison) {
|
||||||
|
auto shape = ShapeUtil::MakeShape(F32, {1});
|
||||||
|
auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
|
||||||
|
auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
|
||||||
|
auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(),
|
||||||
|
ComparisonDirection::kEq);
|
||||||
|
auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(),
|
||||||
|
ComparisonDirection::kNe);
|
||||||
|
auto add =
|
||||||
|
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
|
||||||
|
auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(),
|
||||||
|
ComparisonDirection::kLe);
|
||||||
|
|
||||||
|
EXPECT_TRUE(Match(eq.get(), m::Compare()));
|
||||||
|
EXPECT_TRUE(Match(eq.get(), m::Eq()));
|
||||||
|
EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
|
||||||
|
EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
|
||||||
|
EXPECT_TRUE(Match(ne.get(), m::Compare()));
|
||||||
|
EXPECT_TRUE(Match(ne.get(), m::Ne()));
|
||||||
|
EXPECT_TRUE(Match(
|
||||||
|
le.get(),
|
||||||
|
m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
|
||||||
|
EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0),
|
||||||
|
m::Add(m::Parameter(0), m::Parameter(1)))));
|
||||||
|
|
||||||
|
EXPECT_FALSE(Match(eq.get(), m::Add()));
|
||||||
|
EXPECT_FALSE(Match(eq.get(), m::Ne()));
|
||||||
|
EXPECT_FALSE(
|
||||||
|
Match(le.get(),
|
||||||
|
m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
|
||||||
|
EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));
|
||||||
|
EXPECT_DESC_AND_EXPLANATION(
|
||||||
|
eq, m::Ne().WithOneUser(),
|
||||||
|
"an HloInstruction:\n"
|
||||||
|
" * with opcode compare AND\n"
|
||||||
|
" * which has comparison direction NE AND\n"
|
||||||
|
" * which has exactly one user (but possibly is used "
|
||||||
|
"multiple times by that instruction)",
|
||||||
|
"HloInstruction is not comparison NE\n"
|
||||||
|
"in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
|
||||||
|
"direction=EQ");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -181,8 +181,9 @@ static StatusOr<HloInstruction*> CheckIndexValidity(
|
|||||||
HloInstruction* zero_index =
|
HloInstruction* zero_index =
|
||||||
BroadcastZeros(computation, index->shape().element_type(),
|
BroadcastZeros(computation, index->shape().element_type(),
|
||||||
AsInt64Slice(index->shape().dimensions()));
|
AsInt64Slice(index->shape().dimensions()));
|
||||||
TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check,
|
TF_ASSIGN_OR_RETURN(
|
||||||
MakeBinaryHlo(HloOpcode::kLe, zero_index, index));
|
HloInstruction * negative_index_check,
|
||||||
|
MakeCompareHlo(ComparisonDirection::kLe, zero_index, index));
|
||||||
|
|
||||||
// Check if the index is OOB w.r.t. the operand dimensions and window sizes.
|
// Check if the index is OOB w.r.t. the operand dimensions and window sizes.
|
||||||
std::vector<int64> max_valid_index(operand_dims.size());
|
std::vector<int64> max_valid_index(operand_dims.size());
|
||||||
@ -193,9 +194,9 @@ static StatusOr<HloInstruction*> CheckIndexValidity(
|
|||||||
HloInstruction * max_valid_index_constant,
|
HloInstruction * max_valid_index_constant,
|
||||||
MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
|
MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
|
||||||
max_valid_index));
|
max_valid_index));
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(HloInstruction * oob_index_check,
|
||||||
HloInstruction * oob_index_check,
|
MakeCompareHlo(ComparisonDirection::kGe,
|
||||||
MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index));
|
max_valid_index_constant, index));
|
||||||
|
|
||||||
// Combine the results of the two checks above.
|
// Combine the results of the two checks above.
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
@ -988,12 +988,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
|||||||
}
|
}
|
||||||
return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
|
return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
|
||||||
broadcast_dimensions);
|
broadcast_dimensions);
|
||||||
case HloOpcode::kEq:
|
case HloOpcode::kCompare: {
|
||||||
case HloOpcode::kGe:
|
|
||||||
case HloOpcode::kGt:
|
|
||||||
case HloOpcode::kLe:
|
|
||||||
case HloOpcode::kLt:
|
|
||||||
case HloOpcode::kNe: {
|
|
||||||
TF_ASSIGN_OR_RETURN(const Shape& shape,
|
TF_ASSIGN_OR_RETURN(const Shape& shape,
|
||||||
InferElementwiseBinaryOpShape(opcode, lhs, rhs,
|
InferElementwiseBinaryOpShape(opcode, lhs, rhs,
|
||||||
broadcast_dimensions));
|
broadcast_dimensions));
|
||||||
|
@ -918,55 +918,10 @@ TEST_F(ShapeInferenceTest, InferPowShape) {
|
|||||||
ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
|
ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
|
TEST_F(ShapeInferenceTest, InferCompareShape) {
|
||||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
||||||
auto inferred_status =
|
auto inferred_status = ShapeInference::InferBinaryOpShape(
|
||||||
ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {});
|
HloOpcode::kCompare, ten_floats, f32_, {});
|
||||||
ASSERT_IS_OK(inferred_status.status());
|
|
||||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
|
||||||
inferred_status.ValueOrDie()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
|
|
||||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
|
||||||
auto inferred_status =
|
|
||||||
ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {});
|
|
||||||
ASSERT_IS_OK(inferred_status.status());
|
|
||||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
|
||||||
inferred_status.ValueOrDie()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
|
|
||||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
|
||||||
auto inferred_status =
|
|
||||||
ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {});
|
|
||||||
ASSERT_IS_OK(inferred_status.status());
|
|
||||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
|
||||||
inferred_status.ValueOrDie()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
|
|
||||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
|
||||||
auto inferred_status =
|
|
||||||
ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {});
|
|
||||||
ASSERT_IS_OK(inferred_status.status());
|
|
||||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
|
||||||
inferred_status.ValueOrDie()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
|
|
||||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
|
||||||
auto inferred_status =
|
|
||||||
ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {});
|
|
||||||
ASSERT_IS_OK(inferred_status.status());
|
|
||||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
|
||||||
inferred_status.ValueOrDie()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
|
|
||||||
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
|
|
||||||
auto inferred_status =
|
|
||||||
ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {});
|
|
||||||
ASSERT_IS_OK(inferred_status.status());
|
ASSERT_IS_OK(inferred_status.status());
|
||||||
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
|
||||||
inferred_status.ValueOrDie()));
|
inferred_status.ValueOrDie()));
|
||||||
|
@ -39,7 +39,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -73,7 +73,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedSortOperandTuple) {
|
|||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
p.2.lhs = u32[] parameter(4)
|
p.2.lhs = u32[] parameter(4)
|
||||||
p.2.rhs = u32[] parameter(5)
|
p.2.rhs = u32[] parameter(5)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -109,7 +109,7 @@ TEST_F(SortSimplifierTest, DontRemoveUnusedSortKey) {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -134,7 +134,7 @@ TEST_F(SortSimplifierTest, RemoveUnusedFirstOperand) {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.1.lhs, p.1.rhs)
|
ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
|
@ -180,13 +180,13 @@ StatusOr<HloInstruction*> StableSortExpander::ExpandInstruction(
|
|||||||
CHECK_NE(cloned_root, nullptr);
|
CHECK_NE(cloned_root, nullptr);
|
||||||
Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
|
Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});
|
||||||
HloInstruction* same =
|
HloInstruction* same =
|
||||||
comparator->AddInstruction(HloInstruction::CreateBinary(
|
comparator->AddInstruction(HloInstruction::CreateCompare(
|
||||||
scalar_pred, HloOpcode::kEq, old_root, cloned_root));
|
scalar_pred, old_root, cloned_root, ComparisonDirection::kEq));
|
||||||
HloInstruction* tie_breaker =
|
HloInstruction* tie_breaker =
|
||||||
comparator->AddInstruction(HloInstruction::CreateBinary(
|
comparator->AddInstruction(HloInstruction::CreateCompare(
|
||||||
scalar_pred, HloOpcode::kLt,
|
scalar_pred, comparator->parameter_instruction(2 * iota_index),
|
||||||
comparator->parameter_instruction(2 * iota_index),
|
comparator->parameter_instruction(2 * iota_index + 1),
|
||||||
comparator->parameter_instruction(2 * iota_index + 1)));
|
ComparisonDirection::kLt));
|
||||||
HloInstruction* new_root =
|
HloInstruction* new_root =
|
||||||
comparator->AddInstruction(HloInstruction::CreateTernary(
|
comparator->AddInstruction(HloInstruction::CreateTernary(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker,
|
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kSelect, same, tie_breaker,
|
||||||
|
@ -65,7 +65,8 @@ void CheckComputationHasTieBreaker(const HloInstruction* root,
|
|||||||
// the copied comparison function where the parameters are reversed. Lt() is
|
// the copied comparison function where the parameters are reversed. Lt() is
|
||||||
// the tie breaker comparison using the Iota operand.
|
// the tie breaker comparison using the Iota operand.
|
||||||
ASSERT_EQ(root->opcode(), HloOpcode::kSelect);
|
ASSERT_EQ(root->opcode(), HloOpcode::kSelect);
|
||||||
ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kEq);
|
ASSERT_EQ(root->operand(0)->opcode(), HloOpcode::kCompare);
|
||||||
|
ASSERT_EQ(root->operand(0)->comparison_direction(), ComparisonDirection::kEq);
|
||||||
|
|
||||||
// Check that the tie breaker instruction is correct.
|
// Check that the tie breaker instruction is correct.
|
||||||
EXPECT_THAT(root->operand(1),
|
EXPECT_THAT(root->operand(1),
|
||||||
@ -88,7 +89,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortReuseIotaOperand) {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -126,15 +127,15 @@ TEST_F(StableSortExpanderTest,
|
|||||||
lhs.unsigned = u32[] bitcast-convert(p.0.lhs)
|
lhs.unsigned = u32[] bitcast-convert(p.0.lhs)
|
||||||
lhs.flipped = u32[] subtract(max, lhs.unsigned)
|
lhs.flipped = u32[] subtract(max, lhs.unsigned)
|
||||||
lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped)
|
lhs.flipped.signed = s32[] bitcast-convert(lhs.flipped)
|
||||||
lhs.is_negative = pred[] less-than(lhs.flipped.signed, zero)
|
lhs.is_negative = pred[] compare(lhs.flipped.signed, zero), direction=LT
|
||||||
lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed)
|
lhs.converted = s32[] select(lhs.is_negative, lhs.flipped.signed, lhs.signed)
|
||||||
rhs.signed = s32[] bitcast-convert(p.0.rhs)
|
rhs.signed = s32[] bitcast-convert(p.0.rhs)
|
||||||
rhs.unsigned = u32[] bitcast-convert(p.0.rhs)
|
rhs.unsigned = u32[] bitcast-convert(p.0.rhs)
|
||||||
rhs.flipped = u32[] subtract(max, rhs.unsigned)
|
rhs.flipped = u32[] subtract(max, rhs.unsigned)
|
||||||
rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped)
|
rhs.flipped.signed = s32[] bitcast-convert(rhs.flipped)
|
||||||
rhs.is_negative = pred[] less-than(rhs.flipped.signed, zero)
|
rhs.is_negative = pred[] compare(rhs.flipped.signed, zero), direction=LT
|
||||||
rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed)
|
rhs.converted = s32[] select(rhs.is_negative, rhs.flipped.signed, rhs.signed)
|
||||||
ROOT lt = pred[] less-than(lhs.converted, rhs.converted)
|
ROOT lt = pred[] compare(lhs.converted, rhs.converted), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -165,7 +166,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortAddIotaOperandAndChangeRoot) {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -200,7 +201,7 @@ TEST_F(StableSortExpanderTest, HonorIsStableFlag) {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -227,7 +228,7 @@ TEST_F(StableSortExpanderTest,
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -264,7 +265,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortDontReuseIotaOperandWrongType) {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = f32[] parameter(2)
|
p.1.lhs = f32[] parameter(2)
|
||||||
p.1.rhs = f32[] parameter(3)
|
p.1.rhs = f32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -302,7 +303,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1) {
|
|||||||
mask = s32[] constant(65535)
|
mask = s32[] constant(65535)
|
||||||
lhs = s32[] and(p.0.lhs, mask)
|
lhs = s32[] and(p.0.lhs, mask)
|
||||||
rhs = s32[] and(p.0.rhs, mask)
|
rhs = s32[] and(p.0.rhs, mask)
|
||||||
ROOT lt = pred[] less-than(lhs, rhs)
|
ROOT lt = pred[] compare(lhs, rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
@ -332,7 +333,7 @@ TEST_F(StableSortExpanderTest, StabilizeSortR1NoRoot) {
|
|||||||
mask = s32[] constant(65535)
|
mask = s32[] constant(65535)
|
||||||
lhs = s32[] and(p.0.lhs, mask)
|
lhs = s32[] and(p.0.lhs, mask)
|
||||||
rhs = s32[] and(p.0.rhs, mask)
|
rhs = s32[] and(p.0.rhs, mask)
|
||||||
ROOT lt = pred[] less-than(lhs, rhs)
|
ROOT lt = pred[] compare(lhs, rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY sort_computation {
|
ENTRY sort_computation {
|
||||||
|
@ -934,8 +934,8 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
|
|||||||
HloInstruction::CreateParameter(0, in_shape, "param0"));
|
HloInstruction::CreateParameter(0, in_shape, "param0"));
|
||||||
auto param1 = builder.AddInstruction(
|
auto param1 = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(1, in_shape, "param1"));
|
HloInstruction::CreateParameter(1, in_shape, "param1"));
|
||||||
auto result = builder.AddInstruction(
|
auto result = builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
|
out_shape, param0, param1, ComparisonDirection::kEq));
|
||||||
|
|
||||||
BuildModuleAndRunAnalysis(builder.Build());
|
BuildModuleAndRunAnalysis(builder.Build());
|
||||||
|
|
||||||
@ -1185,8 +1185,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
|
|||||||
auto builder = HloComputation::Builder(TestName() + ".Cond");
|
auto builder = HloComputation::Builder(TestName() + ".Cond");
|
||||||
auto data = builder.AddInstruction(
|
auto data = builder.AddInstruction(
|
||||||
HloInstruction::CreateParameter(0, data_shape, "data"));
|
HloInstruction::CreateParameter(0, data_shape, "data"));
|
||||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
builder.AddInstruction(HloInstruction::CreateCompare(
|
||||||
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
|
ShapeUtil::MakeShape(PRED, {}), data, data, ComparisonDirection::kEq));
|
||||||
return builder.Build();
|
return builder.Build();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -286,7 +286,7 @@ static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
|
|||||||
// Handle `i = K; i < N; ++i`.
|
// Handle `i = K; i < N; ++i`.
|
||||||
if (Match(while_cond_root,
|
if (Match(while_cond_root,
|
||||||
m::Op()
|
m::Op()
|
||||||
.WithOpcode(HloOpcode::kLt)
|
.WithComparisonDirection(ComparisonDirection::kLt)
|
||||||
.WithOperand(0, m::Op().Is(while_cond_indvar)))) {
|
.WithOperand(0, m::Op().Is(while_cond_indvar)))) {
|
||||||
VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
|
VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
|
||||||
<< while_cond_root->ToString();
|
<< while_cond_root->ToString();
|
||||||
@ -303,7 +303,7 @@ static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
|
|||||||
// Handle `i = K; i <= N; ++i`.
|
// Handle `i = K; i <= N; ++i`.
|
||||||
if (Match(while_cond_root,
|
if (Match(while_cond_root,
|
||||||
m::Op()
|
m::Op()
|
||||||
.WithOpcode(HloOpcode::kLe)
|
.WithComparisonDirection(ComparisonDirection::kLe)
|
||||||
.WithOperand(0, m::Op().Is(while_cond_indvar)))) {
|
.WithOperand(0, m::Op().Is(while_cond_indvar)))) {
|
||||||
VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
|
VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
|
||||||
<< while_cond_root->ToString();
|
<< while_cond_root->ToString();
|
||||||
|
@ -40,7 +40,7 @@ TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) {
|
|||||||
p_cond = (f32[2], s32[]) parameter(0)
|
p_cond = (f32[2], s32[]) parameter(0)
|
||||||
gte = s32[] get-tuple-element(p_cond), index=1
|
gte = s32[] get-tuple-element(p_cond), index=1
|
||||||
const = s32[] constant(42)
|
const = s32[] constant(42)
|
||||||
ROOT result = pred[] equal-to(gte, const)
|
ROOT result = pred[] compare(gte, const), direction=EQ
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
@ -71,7 +71,7 @@ TEST_F(WhileLoopAnalysisTest, NoUpperBound) {
|
|||||||
p_cond = (f32[2], s32[]) parameter(0)
|
p_cond = (f32[2], s32[]) parameter(0)
|
||||||
gte = s32[] get-tuple-element(p_cond), index=1
|
gte = s32[] get-tuple-element(p_cond), index=1
|
||||||
const = s32[] constant(42)
|
const = s32[] constant(42)
|
||||||
ROOT result = pred[] equal-to(gte, const)
|
ROOT result = pred[] compare(gte, const), direction=EQ
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
@ -104,7 +104,7 @@ TEST_F(WhileLoopAnalysisTest, ExactBound) {
|
|||||||
p_cond = (f32[2], s32[]) parameter(0)
|
p_cond = (f32[2], s32[]) parameter(0)
|
||||||
gte = s32[] get-tuple-element(p_cond), index=1
|
gte = s32[] get-tuple-element(p_cond), index=1
|
||||||
const = s32[] constant(42)
|
const = s32[] constant(42)
|
||||||
ROOT result = pred[] less-than(gte, const)
|
ROOT result = pred[] compare(gte, const), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
|
@ -260,7 +260,7 @@ condition {
|
|||||||
p_cond = (f32[],f32[]) parameter(0)
|
p_cond = (f32[],f32[]) parameter(0)
|
||||||
p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0
|
p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0
|
||||||
p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1
|
p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1
|
||||||
ROOT result = pred[] less-than(p_cond.0, p_cond.1)
|
ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
@ -300,7 +300,7 @@ condition {
|
|||||||
p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0
|
p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0
|
||||||
p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1
|
p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1
|
||||||
p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1
|
p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1
|
||||||
ROOT result = pred[] less-than(p_c.0, p_c.1.1)
|
ROOT result = pred[] compare(p_c.0, p_c.1.1), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
@ -342,7 +342,7 @@ condition {
|
|||||||
p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
|
p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
|
||||||
p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
|
p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
|
||||||
p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
||||||
ROOT result = pred[] less-than(p_cond.0, p_cond.1)
|
ROOT result = pred[] compare(p_cond.0, p_cond.1), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
@ -389,10 +389,10 @@ condition {
|
|||||||
p_cond = (f32[],f32[],f32[]) parameter(0)
|
p_cond = (f32[],f32[],f32[]) parameter(0)
|
||||||
p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
|
p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0
|
||||||
p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
||||||
lt.0 = pred[] less-than(p_cond.0, p_cond.2)
|
lt.0 = pred[] compare(p_cond.0, p_cond.2), direction=LT
|
||||||
p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
|
p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1
|
||||||
p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2
|
||||||
lt.1 = pred[] less-than(p_cond.1, p_cond.2.c)
|
lt.1 = pred[] compare(p_cond.1, p_cond.2.c), direction=LT
|
||||||
ROOT result = pred[] and(lt.0, lt.1)
|
ROOT result = pred[] and(lt.0, lt.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -556,7 +556,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) {
|
|||||||
p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0)
|
p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0)
|
||||||
gte = s32[] get-tuple-element(p_cond), index=3
|
gte = s32[] get-tuple-element(p_cond), index=3
|
||||||
const = s32[] constant(42)
|
const = s32[] constant(42)
|
||||||
ROOT result = pred[] equal-to(gte, const)
|
ROOT result = pred[] compare(gte, const), direction=EQ
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
|
@ -72,7 +72,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) {
|
|||||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.2 = s32[] constant({{LOOP_BOUND}})
|
constant.2 = s32[] constant({{LOOP_BOUND}})
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(42)
|
constant.3 = s32[] constant(42)
|
||||||
@ -107,7 +107,7 @@ WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound(
|
|||||||
loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
|
loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
|
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4)
|
ROOT less-than = pred[] compare(get-tuple-element.3, get-tuple-element.4), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoopWithIndirectLoopBound {
|
ENTRY SimpleLoopWithIndirectLoopBound {
|
||||||
constant.3 = s32[] constant(42)
|
constant.3 = s32[] constant(42)
|
||||||
@ -237,7 +237,7 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) {
|
|||||||
NonTupleShapedLoop.condition {
|
NonTupleShapedLoop.condition {
|
||||||
loop_var = s32[] parameter(0)
|
loop_var = s32[] parameter(0)
|
||||||
constant = s32[] constant(100)
|
constant = s32[] constant(100)
|
||||||
ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant)
|
ROOT less-than = pred[] compare(s32[] loop_var, s32[] constant), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY INonTupleShapedLoop {
|
ENTRY INonTupleShapedLoop {
|
||||||
constant.2 = s32[] constant(42)
|
constant.2 = s32[] constant(42)
|
||||||
@ -387,7 +387,7 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
|
|||||||
param0 = (s32[], s32[], s32[]) parameter(0)
|
param0 = (s32[], s32[], s32[]) parameter(0)
|
||||||
get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0),
|
get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0),
|
||||||
index=2
|
index=2
|
||||||
ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element)
|
ROOT equal-to = pred[] compare(s32[] constant.2, s32[] get-tuple-element), direction=EQ
|
||||||
}
|
}
|
||||||
ENTRY RemoveUnusedOperands {
|
ENTRY RemoveUnusedOperands {
|
||||||
x = s32[] parameter(0)
|
x = s32[] parameter(0)
|
||||||
@ -471,7 +471,7 @@ TEST_F(WhileLoopSimplifierTest,
|
|||||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.2 = s32[] constant(44)
|
constant.2 = s32[] constant(44)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(42)
|
constant.3 = s32[] constant(42)
|
||||||
@ -503,7 +503,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) {
|
|||||||
loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0)
|
loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0)
|
||||||
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
|
get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0
|
||||||
constant.2 = s32[] constant(47)
|
constant.2 = s32[] constant(47)
|
||||||
ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2)
|
ROOT less-than = pred[] compare(get-tuple-element.4, constant.2), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY SimpleLoop {
|
ENTRY SimpleLoop {
|
||||||
constant.3 = s32[] constant(42)
|
constant.3 = s32[] constant(42)
|
||||||
@ -679,7 +679,7 @@ const char* const kSimpleMergeInductionVariablesModule = R"(
|
|||||||
b = TYPE[] get-tuple-element(param), index=1
|
b = TYPE[] get-tuple-element(param), index=1
|
||||||
sum = TYPE[] power(a, b)
|
sum = TYPE[] power(a, b)
|
||||||
ten = TYPE[] constant(10)
|
ten = TYPE[] constant(10)
|
||||||
ROOT cond = pred[] less-than(sum, ten)
|
ROOT cond = pred[] compare(sum, ten), direction=LT
|
||||||
}
|
}
|
||||||
ENTRY Loop {
|
ENTRY Loop {
|
||||||
a = TYPE[] constant(10)
|
a = TYPE[] constant(10)
|
||||||
|
@ -41,7 +41,7 @@ TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) {
|
|||||||
param = (s32[]) parameter(0)
|
param = (s32[]) parameter(0)
|
||||||
i = s32[] get-tuple-element(param), index=0
|
i = s32[] get-tuple-element(param), index=0
|
||||||
trip_count = s32[] constant(10)
|
trip_count = s32[] constant(10)
|
||||||
ROOT done = pred[] less-than(i, trip_count)
|
ROOT done = pred[] compare(i, trip_count), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY test {
|
ENTRY test {
|
||||||
@ -77,7 +77,7 @@ TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) {
|
|||||||
param = (s32[]) parameter(0)
|
param = (s32[]) parameter(0)
|
||||||
i = s32[] get-tuple-element(param), index=0
|
i = s32[] get-tuple-element(param), index=0
|
||||||
trip_count = s32[] constant(1000000)
|
trip_count = s32[] constant(1000000)
|
||||||
ROOT done = pred[] less-than(i, trip_count)
|
ROOT done = pred[] compare(i, trip_count), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY test {
|
ENTRY test {
|
||||||
@ -113,7 +113,7 @@ TEST_F(TripCountAnnotatorTest, NonzeroStart) {
|
|||||||
param = (s32[]) parameter(0)
|
param = (s32[]) parameter(0)
|
||||||
i = s32[] get-tuple-element(param), index=0
|
i = s32[] get-tuple-element(param), index=0
|
||||||
trip_count = s32[] constant(1000000)
|
trip_count = s32[] constant(1000000)
|
||||||
ROOT done = pred[] less-than(i, trip_count)
|
ROOT done = pred[] compare(i, trip_count), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY test {
|
ENTRY test {
|
||||||
@ -149,7 +149,7 @@ TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) {
|
|||||||
param = (s32[]) parameter(0)
|
param = (s32[]) parameter(0)
|
||||||
i = s32[] get-tuple-element(param), index=0
|
i = s32[] get-tuple-element(param), index=0
|
||||||
trip_count = s32[] constant(1000000)
|
trip_count = s32[] constant(1000000)
|
||||||
ROOT done = pred[] less-than-or-equal-to(i, trip_count)
|
ROOT done = pred[] compare(i, trip_count), direction=LE
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY test {
|
ENTRY test {
|
||||||
@ -188,7 +188,7 @@ TEST_F(TripCountAnnotatorTest, Int64Overflow) {
|
|||||||
param = (s64[]) parameter(0)
|
param = (s64[]) parameter(0)
|
||||||
i = s64[] get-tuple-element(param), index=0
|
i = s64[] get-tuple-element(param), index=0
|
||||||
trip_count = s64[] constant(9223372036854775807) // 2^63-1
|
trip_count = s64[] constant(9223372036854775807) // 2^63-1
|
||||||
ROOT done = pred[] less-than-or-equal-to(i, trip_count)
|
ROOT done = pred[] compare(i, trip_count), direction=LE
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY test {
|
ENTRY test {
|
||||||
|
@ -166,7 +166,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
|
|||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
HloInstruction * compare,
|
HloInstruction * compare,
|
||||||
MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant));
|
MakeCompareHlo(ComparisonDirection::kLt, indvar, trip_count_constant));
|
||||||
cond_computation->set_root_instruction(compare);
|
cond_computation->set_root_instruction(compare);
|
||||||
return std::move(cond_computation);
|
return std::move(cond_computation);
|
||||||
}
|
}
|
||||||
|
@ -63,7 +63,11 @@ const float test_float_vals[3][test_width][test_height] = {
|
|||||||
class FusionTest : public HloTestBase {
|
class FusionTest : public HloTestBase {
|
||||||
protected:
|
protected:
|
||||||
template <typename T, int Arity>
|
template <typename T, int Arity>
|
||||||
void TestElementwise2D(HloOpcode opcode) {
|
void TestElementwise2D(
|
||||||
|
HloOpcode opcode,
|
||||||
|
absl::optional<ComparisonDirection> direction = absl::nullopt) {
|
||||||
|
// Create a variable for comparisons since they require the direction.
|
||||||
|
bool is_compare = std::is_same<T, bool>::value;
|
||||||
Array2D<float> operand_data[Arity];
|
Array2D<float> operand_data[Arity];
|
||||||
for (int i = 0; i < Arity; ++i) {
|
for (int i = 0; i < Arity; ++i) {
|
||||||
new (&operand_data[i]) Array2D<float>(test_width, test_height);
|
new (&operand_data[i]) Array2D<float>(test_width, test_height);
|
||||||
@ -76,7 +80,11 @@ class FusionTest : public HloTestBase {
|
|||||||
xs[k] = test_float_vals[k][i][j];
|
xs[k] = test_float_vals[k][i][j];
|
||||||
operand_data[k](i, j) = xs[k];
|
operand_data[k](i, j) = xs[k];
|
||||||
}
|
}
|
||||||
answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs);
|
if (is_compare) {
|
||||||
|
answer_data(i, j) = ComputeElementwiseAnswerCompare(*direction, xs);
|
||||||
|
} else {
|
||||||
|
answer_data(i, j) = ComputeElementwiseAnswerFloat(opcode, xs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,8 +106,13 @@ class FusionTest : public HloTestBase {
|
|||||||
root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
|
root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
|
if (is_compare) {
|
||||||
hlos[2]);
|
root_hlo = HloInstruction::CreateCompare(answer_shape, hlos[1],
|
||||||
|
hlos[2], *direction);
|
||||||
|
} else {
|
||||||
|
root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
|
||||||
|
hlos[2]);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
|
root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
|
||||||
@ -124,13 +137,14 @@ class FusionTest : public HloTestBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename T>
|
float ComputeElementwiseAnswerFloat(HloOpcode opcode,
|
||||||
T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span<const float> xs);
|
absl::Span<const float> xs);
|
||||||
|
bool ComputeElementwiseAnswerCompare(ComparisonDirection direction,
|
||||||
|
absl::Span<const float> xs);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode,
|
||||||
float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
|
absl::Span<const float> xs) {
|
||||||
absl::Span<const float> xs) {
|
|
||||||
switch (opcode) {
|
switch (opcode) {
|
||||||
case HloOpcode::kAdd:
|
case HloOpcode::kAdd:
|
||||||
return xs[0] + xs[1];
|
return xs[0] + xs[1];
|
||||||
@ -153,24 +167,21 @@ float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction,
|
||||||
bool FusionTest::ComputeElementwiseAnswer<bool>(HloOpcode opcode,
|
absl::Span<const float> xs) {
|
||||||
absl::Span<const float> xs) {
|
switch (direction) {
|
||||||
switch (opcode) {
|
case ComparisonDirection::kEq:
|
||||||
case HloOpcode::kEq:
|
|
||||||
return xs[0] == xs[1];
|
return xs[0] == xs[1];
|
||||||
case HloOpcode::kNe:
|
case ComparisonDirection::kNe:
|
||||||
return xs[0] != xs[1];
|
return xs[0] != xs[1];
|
||||||
case HloOpcode::kGt:
|
case ComparisonDirection::kGt:
|
||||||
return xs[0] > xs[1];
|
return xs[0] > xs[1];
|
||||||
case HloOpcode::kLt:
|
case ComparisonDirection::kLt:
|
||||||
return xs[0] < xs[1];
|
return xs[0] < xs[1];
|
||||||
case HloOpcode::kGe:
|
case ComparisonDirection::kGe:
|
||||||
return xs[0] >= xs[1];
|
return xs[0] >= xs[1];
|
||||||
case HloOpcode::kLe:
|
case ComparisonDirection::kLe:
|
||||||
return xs[0] <= xs[1];
|
return xs[0] <= xs[1];
|
||||||
default:
|
|
||||||
LOG(FATAL) << "No comparatory opcode: " << opcode;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -740,24 +751,28 @@ XLA_TEST_F(FusionTest, Maximum2D) {
|
|||||||
TestElementwise2D<float, 2>(HloOpcode::kMaximum);
|
TestElementwise2D<float, 2>(HloOpcode::kMaximum);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<bool, 2>(HloOpcode::kEq); }
|
XLA_TEST_F(FusionTest, Equal2D) {
|
||||||
|
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kEq);
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Inequal2D) {
|
XLA_TEST_F(FusionTest, Inequal2D) {
|
||||||
TestElementwise2D<bool, 2>(HloOpcode::kNe);
|
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kNe);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Greater2D) {
|
XLA_TEST_F(FusionTest, Greater2D) {
|
||||||
TestElementwise2D<bool, 2>(HloOpcode::kGt);
|
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGt);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D<bool, 2>(HloOpcode::kLt); }
|
XLA_TEST_F(FusionTest, Lesser2D) {
|
||||||
|
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLt);
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
|
XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
|
||||||
TestElementwise2D<bool, 2>(HloOpcode::kGe);
|
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kGe);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, LesserOrEqual2D) {
|
XLA_TEST_F(FusionTest, LesserOrEqual2D) {
|
||||||
TestElementwise2D<bool, 2>(HloOpcode::kLe);
|
TestElementwise2D<bool, 2>(HloOpcode::kCompare, ComparisonDirection::kLe);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(FusionTest, Clamp2D) {
|
XLA_TEST_F(FusionTest, Clamp2D) {
|
||||||
|
@ -227,7 +227,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
|
|||||||
fused_computation {
|
fused_computation {
|
||||||
p = f32[4] parameter(0)
|
p = f32[4] parameter(0)
|
||||||
multiply = f32[4] multiply(p, p)
|
multiply = f32[4] multiply(p, p)
|
||||||
less-than = pred[4] less-than(p, multiply)
|
less-than = pred[4] compare(p, multiply), direction=LT
|
||||||
ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
|
ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -252,7 +252,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
|
|||||||
fused_computation {
|
fused_computation {
|
||||||
p = f32[] parameter(0)
|
p = f32[] parameter(0)
|
||||||
multiply = f32[] multiply(p, p)
|
multiply = f32[] multiply(p, p)
|
||||||
less-than = pred[] less-than(p, multiply)
|
less-than = pred[] compare(p, multiply), direction=LT
|
||||||
ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
|
ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ compare {
|
|||||||
p.0.rhs = f32[] parameter(1)
|
p.0.rhs = f32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) {
|
ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) {
|
||||||
@ -174,7 +174,7 @@ compare {
|
|||||||
p.0.rhs = s32[] parameter(1)
|
p.0.rhs = s32[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) {
|
ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) {
|
||||||
@ -205,7 +205,7 @@ compare {
|
|||||||
p.0.rhs = bf16[] parameter(1)
|
p.0.rhs = bf16[] parameter(1)
|
||||||
p.1.lhs = s32[] parameter(2)
|
p.1.lhs = s32[] parameter(2)
|
||||||
p.1.rhs = s32[] parameter(3)
|
p.1.rhs = s32[] parameter(3)
|
||||||
ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
|
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) {
|
ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) {
|
||||||
|
@ -129,7 +129,7 @@ HloModule TokenInWhileLoop
|
|||||||
%param = (s32[], token[]) parameter(0)
|
%param = (s32[], token[]) parameter(0)
|
||||||
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
||||||
%constant = s32[] constant(42)
|
%constant = s32[] constant(42)
|
||||||
ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
|
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
||||||
}
|
}
|
||||||
|
|
||||||
ENTRY %TokenInWhileLoop () -> s32[] {
|
ENTRY %TokenInWhileLoop () -> s32[] {
|
||||||
|
Loading…
Reference in New Issue
Block a user