Automated g4 rollback of changelist 157174708
PiperOrigin-RevId: 157253080
This commit is contained in:
parent 70313342b6
commit 2ff1d7bf04
@@ -1071,7 +1071,6 @@ template <typename NativeT>
    ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions,
                            stride_config.step, init_function);
  } else {
    // For scalars.
    data.at(0) = generator({});
  }
  return Status::OK();

@@ -807,9 +807,7 @@ TEST_F(LiteralUtilTest, Populate) {
    std::vector<int64> layout;
  } populate_data[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},

@@ -106,12 +106,10 @@ cc_test(
        ":hlo_evaluator",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:test_main",
    ],

@@ -1461,9 +1459,7 @@ cc_library(
    hdrs = ["hlo_constant_folding.h"],
    deps = [
        ":hlo",
        ":hlo_evaluator",
        ":hlo_pass",
        ":hlo_query",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",

@@ -24,57 +24,230 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
static std::unique_ptr<Literal> ConvertIfTypesMatch(
    const Literal& src_literal) {
  CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
  return LiteralUtil::Convert<
      typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
      typename primitive_util::PrimitiveTypeToNative<
          primitive_dest_type>::type>(src_literal);
}

template <PrimitiveType primitive_src_type>
static std::unique_ptr<Literal> ConvertIfDestTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
  case (type):                       \
    return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
    CONVERT_IF_TYPES_MATCH(PRED)
    CONVERT_IF_TYPES_MATCH(S8)
    CONVERT_IF_TYPES_MATCH(S32)
    CONVERT_IF_TYPES_MATCH(S64)
    CONVERT_IF_TYPES_MATCH(U8)
    CONVERT_IF_TYPES_MATCH(U32)
    CONVERT_IF_TYPES_MATCH(U64)
    CONVERT_IF_TYPES_MATCH(F32)
    CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
    // Other types are not yet supported.
    default:
      LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
                 << PrimitiveType_Name(src_literal.shape().element_type());
  }
}

static std::unique_ptr<Literal> ConvertIfSrcTypeMatches(
    const Literal& src_literal, PrimitiveType primitive_dest_type) {
  switch (src_literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
  case (type):                             \
    return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
    CONVERT_IF_DEST_TYPE_MATCHES(PRED)
    CONVERT_IF_DEST_TYPE_MATCHES(S8)
    CONVERT_IF_DEST_TYPE_MATCHES(S32)
    CONVERT_IF_DEST_TYPE_MATCHES(S64)
    CONVERT_IF_DEST_TYPE_MATCHES(U8)
    CONVERT_IF_DEST_TYPE_MATCHES(U32)
    CONVERT_IF_DEST_TYPE_MATCHES(U64)
    CONVERT_IF_DEST_TYPE_MATCHES(F32)
    CONVERT_IF_DEST_TYPE_MATCHES(F64)
#undef CONVERT_IF_DEST_TYPE_MATCHES
    // Other types are not yet supported.
    default:
      LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
                 << PrimitiveType_Name(src_literal.shape().element_type());
  }
}

}  // namespace
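
The two helpers above form a double dispatch over runtime PrimitiveType values: the outer switch pins the source type down as a template argument, the inner switch pins the destination type, and only then can the statically typed LiteralUtil::Convert be instantiated. A minimal standalone sketch of the same pattern, with illustrative names rather than the XLA API:

#include <cstdio>

enum Type { kInt, kFloat };

// Inner dispatch: the source type is already a compile-time parameter here.
template <Type src>
void ConvertTo(Type dest) {
  switch (dest) {
    case kInt:   std::printf("convert src=%d -> int\n", src);   break;
    case kFloat: std::printf("convert src=%d -> float\n", src); break;
  }
}

// Outer dispatch: turns the runtime source type into a template argument.
void Convert(Type src, Type dest) {
  switch (src) {
    case kInt:   ConvertTo<kInt>(dest);   break;
    case kFloat: ConvertTo<kFloat>(dest); break;
  }
}

int main() { Convert(kInt, kFloat); }  // prints "convert src=0 -> float"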

// ConstantFolderVisitor traverses the HLO computation and reduces certain
// constant graph sections to literals.
class ConstantFolderVisitor : public DfsHloVisitorWithDefault {
 public:
  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleConcatenate(
      HloInstruction* concatenate,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;

  Status HandleConvert(HloInstruction* convert,
                       HloInstruction* operand) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  // Returns whether a constant folding operation has occurred.
  const bool changed() const { return changed_; }

  // Runs the visitor on a computation and returns whether any changes were
  // performed.
  static StatusOr<bool> Run(HloComputation* computation);

 private:
  ConstantFolderVisitor() = default;

  // Replaces the existing HLO instruction old_instruction, with a literal,
  // and marks the optimizer status as changed.
  // Returns the Status representing the result of the replace operation.
  Status ReplaceWithConstant(HloInstruction* old_instruction,
                             std::unique_ptr<Literal> literal) {
    TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
        old_instruction, HloInstruction::CreateConstant(std::move(literal))));
    changed_ = true;
    return Status::OK();
  }

  // Whether any constant folding operations have occurred.
  bool changed_ = false;
};

StatusOr<bool> ConstantFolderVisitor::Run(HloComputation* computation) {
  ConstantFolderVisitor visitor;
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  return visitor.changed();
}

StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
  auto evaluator = MakeUnique<HloEvaluator>();

  XLA_VLOG_LINES(2,
                 "HloConstantFolding::Run(), before:\n" + module->ToString());
  bool changed = false;

  for (auto& computation : module->computations()) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      // Skip dead code.
      if (instruction->user_count() == 0 &&
          computation->root_instruction() != instruction) {
        continue;
      }
      // Skip Constant and Parameter operations.
      if (instruction->opcode() == HloOpcode::kParameter ||
          instruction->opcode() == HloOpcode::kConstant) {
        continue;
      }
      // Skip instructions with non-constant operands.
      if (!hlo_query::AllOperandsAreConstants(*instruction)) {
        continue;
      }

      std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
      // Currently we skip unimplemented operations.
      // TODO(b/35975797): Fold constant computations for more operations.
      if (result == nullptr) {
        VLOG(2) << "Constant folding failed for instruction: "
                << instruction->ToString();
        continue;
      }

      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
          instruction, HloInstruction::CreateConstant(std::move(result))));
      changed = true;
    }
  }
  for (auto& comp : module->computations()) {
    TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get()));
    changed = changed || result;
  }
  XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
  return changed;
}
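
For context, HloConstantFolding is a standard HLO pass; a minimal sketch of driving it, assuming a populated HloModule* named module and a surrounding function that returns Status (error handling elided):

  HloConstantFolding constant_folding;
  TF_ASSIGN_OR_RETURN(bool changed, constant_folding.Run(module));
  if (changed) {
    VLOG(1) << "Constant folding simplified the module.";
  }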

Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) {
  if (reshape->operand(0)->opcode() == HloOpcode::kConstant) {
    TF_ASSIGN_OR_RETURN(
        auto reshaped_literal,
        LiteralUtil::Reshape(reshape->operand(0)->literal(),
                             AsInt64Slice(reshape->shape().dimensions())));
    return ReplaceWithConstant(reshape, std::move(reshaped_literal));
  }
  return Status::OK();
}

Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) {
  if (transpose->operand(0)->opcode() == HloOpcode::kConstant) {
    auto transposed_literal = LiteralUtil::Transpose(
        transpose->operand(0)->literal(), transpose->dimensions());
    return ReplaceWithConstant(transpose, std::move(transposed_literal));
  }
  return Status::OK();
}

Status ConstantFolderVisitor::HandleConcatenate(
    HloInstruction* concatenate,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
  if (operands[0]->opcode() == HloOpcode::kConstant) {
    // If all the operands of a concatenate are constant, fold them into a
    // single constant tensor.
    // The result concatenate dimension is going to be the sum of all the
    // concatenate dimensions of the arrays taking part in the operation.
    int64 concat_dim = concatenate->dimensions()[0];
    const Shape& reference_shape = operands[0]->shape();
    CHECK(!ShapeUtil::IsTuple(reference_shape));
    int64 rank = ShapeUtil::Rank(reference_shape);
    std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(),
                                         reference_shape.dimensions().end());
    if (concat_dim < 0) {
      concat_dim += rank;
    }
    for (int64 i = 1; i < operands.size(); ++i) {
      const Shape& operand_shape = operands[i]->shape();
      CHECK(!ShapeUtil::IsTuple(operand_shape));
      if (operands[i]->opcode() != HloOpcode::kConstant) {
        return Status::OK();
      }
      // Accumulate the concat dimension from all tensors taking part in the
      // operation.
      concat_dimensions[concat_dim] +=
          ShapeUtil::GetDimension(operand_shape, concat_dim);
    }

    auto literal = LiteralUtil::CreateFromDimensions(
        reference_shape.element_type(), concat_dimensions);
    std::vector<int64> source_indices(rank, 0);
    std::vector<int64> dest_indices(concat_dimensions.size(), 0);
    for (auto operand : operands) {
      const Shape& operand_shape = operand->shape();
      TF_RETURN_IF_ERROR(LiteralUtil::Copy(
          operand->literal(), source_indices, literal.get(), dest_indices,
          AsInt64Slice(operand_shape.dimensions())));
      dest_indices[concat_dim] +=
          ShapeUtil::GetDimension(operand_shape, concat_dim);
    }
    return ReplaceWithConstant(concatenate, std::move(literal));
  }
  return Status::OK();
}
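
As a concrete check of the dimension arithmetic above (hypothetical shapes): concatenating constants of shape f32[2,16] and f32[4,16] along dimension 0 starts with concat_dimensions = {2, 16}, accumulates the second operand to reach {6, 16}, and the copy loop then writes the first operand at dest_indices = {0, 0} and the second at {2, 0}.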

Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice,
                                          HloInstruction* operand) {
  if (operand->opcode() == HloOpcode::kConstant) {
    const Shape& shape = slice->shape();
    auto literal = LiteralUtil::CreateFromDimensions(
        shape.element_type(), AsInt64Slice(shape.dimensions()));
    std::vector<int64> dest_indices(slice->slice_starts().size(), 0);
    TF_RETURN_IF_ERROR(LiteralUtil::Copy(
        operand->literal(), slice->slice_starts(), literal.get(), dest_indices,
        AsInt64Slice(shape.dimensions())));
    TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal)));
  }
  return Status::OK();
}

Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert,
                                            HloInstruction* operand) {
  if (operand->opcode() == HloOpcode::kConstant) {
    const Literal& src_literal = operand->literal();
    std::unique_ptr<Literal> new_constant =
        ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type());
    return ReplaceWithConstant(convert, std::move(new_constant));
  }
  return Status::OK();
}
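
For instance (a hypothetical fold): an s32[2] constant {1, 2} feeding a convert to F32 would be replaced in place by the f32[2] constant {1.0, 2.0}, with ConvertIfSrcTypeMatches performing the S32-to-F32 element conversion.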

}  // namespace xla

@@ -46,89 +46,6 @@ limitations under the License.

namespace xla {

namespace {

template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
                                           const Literal& lhs_literal,
                                           const Literal& rhs_literal) {
  std::function<bool(OperandT, OperandT)> compare_op;
  switch (opcode) {
    case HloOpcode::kEq:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el == rhs_el;
      };
      break;
    case HloOpcode::kNe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el != rhs_el;
      };
      break;
    case HloOpcode::kGe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el >= rhs_el;
      };
      break;
    case HloOpcode::kGt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el > rhs_el;
      };
      break;
    case HloOpcode::kLe:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el <= rhs_el;
      };
      break;
    case HloOpcode::kLt:
      compare_op = [](OperandT lhs_el, OperandT rhs_el) {
        return lhs_el < rhs_el;
      };
      break;
    default:
      LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
                 << HloOpcodeString(opcode);
  }

  auto result = LiteralUtil::CreateFromShape(shape);
  TF_RETURN_IF_ERROR(LiteralUtil::Populate<bool>(
      result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
        return compare_op(LiteralUtil::Get<OperandT>(lhs_literal, multi_index),
                          LiteralUtil::Get<OperandT>(rhs_literal, multi_index));
      }));

  return std::move(result);
}
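
For example, a hypothetical call Compare<float>(shape, HloOpcode::kGt, lhs_literal, rhs_literal) produces a PRED-shaped literal whose element at each multi_index is Get<float>(lhs_literal, multi_index) > Get<float>(rhs_literal, multi_index); the opcode is translated into a std::function once, outside the per-element loop.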

template <typename ReturnT, typename NativeT>
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
    HloInstruction* instruction,
    const std::function<ReturnT(NativeT)>& unary_op,
    const Literal& operand_literal) {
  const auto shape = instruction->shape();
  const auto* operand = instruction->operand(0);

  // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
  // removed.
  if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
    return Unimplemented(
        "Implicit broadcasting is currently unsupported in HLO evaluator "
        "Shape Mismatch: %s vs %s",
        ShapeUtil::HumanString(shape).c_str(),
        ShapeUtil::HumanString(operand->shape()).c_str());
  }

  auto result = LiteralUtil::CreateFromShape(shape);

  TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
      result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
        return unary_op(
            LiteralUtil::Get<NativeT>(operand_literal, multi_index));
      }));
  return std::move(result);
}

}  // namespace

template <typename ReturnT>
class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
 public:
@@ -151,7 +68,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
          return elem_operand;
        }));
    return Status::OK();
  }
  };

  template <
      typename NativeT,
@@ -162,7 +79,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
          return std::abs(elem_operand);
        }));
    return Status::OK();
  }
  };

  Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override {
    return HandleAbs<ReturnT>(abs, operand);
@@ -184,45 +101,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return Status::OK();
  };

  template <PrimitiveType src_type, PrimitiveType dest_type>
  std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
    DCHECK_EQ(src_type, src_literal.shape().element_type());
    return LiteralUtil::Convert<
        typename primitive_util::PrimitiveTypeToNative<src_type>::type,
        typename primitive_util::PrimitiveTypeToNative<dest_type>::type>(
        src_literal);
  }

  Status HandleConvert(HloInstruction* convert,
                       HloInstruction* operand) override {
    auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);

    switch (operand->shape().element_type()) {
#define CONVERT_IF_TYPES_MATCH(src_type)                                \
  case (src_type):                                                      \
    parent_->evaluated_[convert] = LiteralUtil::Convert<                \
        typename primitive_util::PrimitiveTypeToNative<src_type>::type, \
        ReturnT>(operand_literal);                                      \
    break;
      CONVERT_IF_TYPES_MATCH(PRED)
      CONVERT_IF_TYPES_MATCH(S8)
      CONVERT_IF_TYPES_MATCH(S32)
      CONVERT_IF_TYPES_MATCH(S64)
      CONVERT_IF_TYPES_MATCH(U8)
      CONVERT_IF_TYPES_MATCH(U32)
      CONVERT_IF_TYPES_MATCH(U64)
      CONVERT_IF_TYPES_MATCH(F32)
      CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
      // Other types are not yet supported.
      default:
        LOG(FATAL) << "unimplemented operand type for HandleConvert: "
                   << PrimitiveType_Name(operand->shape().element_type());
    }

    return Status::OK();
  }

  Status HandleExp(HloInstruction* exp, HloInstruction* operand) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
                        ElementWiseUnaryOp(exp, [](ReturnT elem_operand) {
@@ -239,6 +117,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return Status::OK();
  };

  Status HandleIsFinite(HloInstruction* is_finite,
                        HloInstruction* operand) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite],
                        ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) {
                          return std::isfinite(elem_operand);
                        }));
    return Status::OK();
  };

  Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
                        ElementWiseUnaryOp(log, [](ReturnT elem_operand) {
@@ -322,12 +209,77 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    return Status::OK();
  };

  Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
                       HloInstruction* lhs, HloInstruction* rhs) override {
    std::function<bool(ReturnT, ReturnT)> compare_op;
    switch (opcode) {
      case HloOpcode::kEq:
        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
          return lhs_el == rhs_el;
        };
        break;
      case HloOpcode::kNe:
        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
          return lhs_el != rhs_el;
        };
        break;
      case HloOpcode::kGe:
        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
          return lhs_el >= rhs_el;
        };
        break;
      case HloOpcode::kGt:
        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
          return lhs_el > rhs_el;
        };
        break;
      case HloOpcode::kLe:
        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
          return lhs_el <= rhs_el;
        };
        break;
      case HloOpcode::kLt:
        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
          return lhs_el < rhs_el;
        };
        break;
      default:
        LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
                   << HloOpcodeString(opcode);
    }

    // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
    // removed.
    if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
          ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
      return Unimplemented(
          "Compare operation with mismatched dimensions, likely due to "
          "broadcasting is unsupported.");
    }

    const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    auto result = LiteralUtil::CreateFromShape(compare->shape());
    std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
    do {
      LiteralUtil::Set<bool>(
          result.get(), multi_index,
          compare_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
                     LiteralUtil::Get<ReturnT>(rhs_literal, multi_index)));
    } while (IndexUtil::BumpIndices(result->shape(), &multi_index));

    parent_->evaluated_[compare] = std::move(result);

    return Status::OK();
  };

  Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
                       HloInstruction* rhs) override {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[maximum],
        ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
          return std::fmax(lhs, rhs);
          return std::max(lhs, rhs);
        }));
    return Status::OK();
  };
@@ -337,7 +289,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[minimum],
        ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
          return std::fmin(lhs_el, rhs_el);
          return std::min(lhs_el, rhs_el);
        }));
    return Status::OK();
  };
@@ -357,7 +309,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    TF_ASSIGN_OR_RETURN(
        parent_->evaluated_[remainder],
        ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
          return std::fmod(lhs_el, rhs_el);
          return std::remainder(lhs_el, rhs_el);
        }));
    return Status::OK();
  };
@@ -386,7 +338,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
                     HloInstruction* arg, HloInstruction* max) override {
    std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
        [](ReturnT low, ReturnT high, ReturnT value) {
          return std::fmax(low, std::fmin(value, high));
          return std::max(low, std::min(value, high));
        };
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
                        ElementWiseTernaryOp(clamp, std::move(clamp_op)));
@@ -418,11 +370,32 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
  StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
      HloInstruction* instruction,
      const std::function<ReturnT(ReturnT)>& unary_op) {
    const Literal& operand_literal =
        parent_->GetEvaluatedLiteralFor(instruction->operand(0));
    return ElementWiseUnaryOpImpl<ReturnT, ReturnT>(instruction, unary_op,
                                                    operand_literal);
  }
    const auto shape = instruction->shape();
    const auto* operand = instruction->operand(0);

    // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
    // removed.
    if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
      return Unimplemented(
          "Implicit broadcasting is currently unsupported in HLO evaluator "
          "Shape Mismatch: %s vs %s",
          ShapeUtil::HumanString(shape).c_str(),
          ShapeUtil::HumanString(operand->shape()).c_str());
    }

    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);

    auto result = LiteralUtil::CreateFromShape(shape);

    std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
    do {
      LiteralUtil::Set<ReturnT>(
          result.get(), multi_index,
          unary_op(LiteralUtil::Get<ReturnT>(operand_literal, multi_index)));
    } while (IndexUtil::BumpIndices(result->shape(), &multi_index));

    return std::move(result);
  };

  StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
      HloInstruction* instruction,
@@ -447,14 +420,16 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);

    auto result = LiteralUtil::CreateFromShape(shape);
    std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
    do {
      LiteralUtil::Set<ReturnT>(
          result.get(), multi_index,
          binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
                    LiteralUtil::Get<ReturnT>(rhs_literal, multi_index)));
    } while (IndexUtil::BumpIndices(result->shape(), &multi_index));

    TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
        result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          return binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
                           LiteralUtil::Get<ReturnT>(rhs_literal, multi_index));
        }));
    return std::move(result);
  }
  };
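
Both loop styles in this hunk enumerate every multi-index of the result shape: LiteralUtil::Populate drives the iteration internally, while the do/while form advances the index explicitly via IndexUtil::BumpIndices. A minimal standalone sketch of odometer-style index bumping, illustrative only (the real IndexUtil accounts for layout; this version assumes all dimensions are positive and bumps the last dimension fastest):

#include <cstdint>
#include <vector>

// Advances `index` within `dims` like an odometer; returns false once every
// position has been visited. Assumes dims[i] > 0 for all i.
bool BumpIndices(const std::vector<int64_t>& dims,
                 std::vector<int64_t>* index) {
  for (int64_t d = static_cast<int64_t>(dims.size()) - 1; d >= 0; --d) {
    if (++(*index)[d] < dims[d]) return true;
    (*index)[d] = 0;
  }
  return false;
}

// Usage: start from an all-zero index, then
//   do { visit(index); } while (BumpIndices(dims, &index));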

  template <typename LhsType, typename RhsType, typename EhsType>
  StatusOr<std::unique_ptr<Literal>> ElementWiseTernaryOp(
@@ -484,17 +459,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
    const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);

    auto result = LiteralUtil::CreateFromShape(shape);

    TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
        result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
          return ternary_op(
              LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
              LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
              LiteralUtil::Get<EhsType>(ehs_literal, multi_index));
        }));
    std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
    do {
      LiteralUtil::Set<ReturnT>(
          result.get(), multi_index,
          ternary_op(LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
                     LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
                     LiteralUtil::Get<EhsType>(ehs_literal, multi_index)));
    } while (IndexUtil::BumpIndices(result->shape(), &multi_index));

    return std::move(result);
  }
  };

  HloEvaluator* parent_;
};
@@ -518,12 +493,6 @@ HloEvaluator::HloEvaluator() {
  });
  typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
  typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
  typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
    return Unimplemented("unhandled primitive type: TUPLE.");
  });
  typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
    return Unimplemented("unhandled primitive type: OPAQUE.");
  });
}

StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
@@ -533,15 +502,15 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
  evaluated_.clear();

  TF_RETURN_IF_ERROR(computation->Accept(this));
  return MakeUnique<Literal>(
      GetEvaluatedLiteralFor(computation->root_instruction()));
  return std::move(FindOrDie(evaluated_, computation->root_instruction()));
}

StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
    HloInstruction* instruction,
    tensorflow::gtl::ArraySlice<const Literal*> operands) {
  TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
  DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
  Shape shape = instruction->shape();
  TF_CHECK_OK(ShapeUtil::ValidateShape(shape));

  arg_literals_ = operands;
  evaluated_.clear();
@@ -556,34 +525,13 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
    TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));

    evaluated_[operand] = MakeUnique<Literal>(*input_literal);
  } else if (operand->opcode() == HloOpcode::kConstant) {
    evaluated_[operand] = MakeUnique<Literal>(operand->literal());
  }
  }

  TF_RETURN_IF_ERROR(instruction->Visit(this));
  return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
}

StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
    HloInstruction* instruction) {
  TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction));
  TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter);
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

  arg_literals_.clear();
  evaluated_.clear();
  TF_RETURN_IF_ERROR(instruction->Visit(this));
  return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
}

std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
    HloInstruction* instruction) {
  auto result_or = Evaluate(instruction);
  if (!result_or.ok()) {
    LOG(ERROR) << "TryEvaluate failed: " << result_or.status();
    return nullptr;
  }

  return result_or.ConsumeValueOrDie();
  return std::move(FindOrDie(evaluated_, instruction));
}
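
TryEvaluate is the non-failing entry point used by the constant-folding pass above; a call-site sketch, assuming `evaluator` and `instruction` are in scope:

  std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
  if (result == nullptr) {
    // Evaluation is unimplemented (or failed) for this op; leave it alone.
  } else {
    // `result` owns the computed literal and can replace the instruction.
  }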

Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
@@ -600,191 +548,9 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
Status HloEvaluator::HandleConstant(HloInstruction* constant,
                                    const Literal& literal) {
  VLOG(2) << "HandleConstant: " << constant->ToString();
  return Status::OK();
}
  DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape()));

Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
  TF_ASSIGN_OR_RETURN(
      evaluated_[reshape],
      LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)),
                           AsInt64Slice(reshape->shape().dimensions())));
  return Status::OK();
}

Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
  evaluated_[transpose] = LiteralUtil::Transpose(
      GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions());
  return Status::OK();
}

Status HloEvaluator::HandleConcatenate(
    HloInstruction* concatenate,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
  // The result concatenate dimension is going to be the sum of all concatenate
  // dimensions of the operands taking part in the operation.
  const Shape& reference_shape = operands[0]->shape();
  CHECK(!ShapeUtil::IsTuple(reference_shape));
  const int64 rank = ShapeUtil::Rank(reference_shape);
  const int64 concat_dim = concatenate->dimensions()[0];
  CHECK_GE(concat_dim, 0);
  CHECK_LT(concat_dim, rank);

  DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
                                    reference_shape.dimensions().end());

  for (int64 i = 1; i < operands.size(); ++i) {
    const Shape& operand_shape = operands[i]->shape();
    CHECK(!ShapeUtil::IsTuple(operand_shape));
    // Accumulate the concat dimension from all tensors taking part in the
    // operation.
    concat_dimensions[concat_dim] +=
        ShapeUtil::GetDimension(operand_shape, concat_dim);
  }

  auto result_literal = LiteralUtil::CreateFromDimensions(
      reference_shape.element_type(), concat_dimensions);
  DimensionVector source_indices(rank, 0);
  DimensionVector dest_indices(concat_dimensions.size(), 0);

  for (auto operand : operands) {
    const Shape& operand_shape = operand->shape();
    TF_RETURN_IF_ERROR(LiteralUtil::Copy(
        GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(),
        dest_indices, AsInt64Slice(operand_shape.dimensions())));
    dest_indices[concat_dim] +=
        ShapeUtil::GetDimension(operand_shape, concat_dim);
  }

  evaluated_[concatenate] = std::move(result_literal);
  return Status::OK();
}

Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite,
                                    HloInstruction* operand) {
  if (!ShapeUtil::ElementIsFloating(operand->shape())) {
    return InvalidArgument(
        "expected element type in shape to be float for IsFinite op, got: %s",
        PrimitiveType_Name(operand->shape().element_type()).c_str());
  }

  switch (operand->shape().element_type()) {
    case F16:
      return Unimplemented("unhandled primitive type: F16.");
    case F32: {
      auto result_or = ElementWiseUnaryOpImpl<bool, float>(
          is_finite,
          [](float elem_operand) { return std::isfinite(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    case F64: {
      auto result_or = ElementWiseUnaryOpImpl<bool, double>(
          is_finite,
          [](double elem_operand) { return std::isfinite(elem_operand); },
          GetEvaluatedLiteralFor(operand));
      TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
      break;
    }
    default:
      LOG(FATAL) << "unknown/unhandled primitive type.";
  }

  return Status::OK();
}

Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
                                   HloInstruction* lhs, HloInstruction* rhs) {
  // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
  // removed.
  if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
        ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
    return Unimplemented(
        "Implicit broadcasting is currently unsupported in HLO evaluator "
        "Shape Mismatch: %s vs %s vs %s",
        ShapeUtil::HumanString(compare->shape()).c_str(),
        ShapeUtil::HumanString(lhs->shape()).c_str(),
        ShapeUtil::HumanString(rhs->shape()).c_str());
  }

  TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());

  const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
  const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);

  // Note here we switch on the operand's type.
  switch (lhs->shape().element_type()) {
    case PRED: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    case U8: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    case U16:
      return Unimplemented("unhandled primitive type: U16.");
    case U32: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    case U64: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    case S8: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    case S16:
      return Unimplemented("unhandled primitive type: S16.");
    case S32: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    case S64: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    case F16:
      return Unimplemented("unhandled primitive type: F16.");
    case F32: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    case F64: {
      TF_ASSIGN_OR_RETURN(
          evaluated_[compare],
          Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
    } break;
    default:
      LOG(FATAL) << "unknown primitive type.";
  }

  return Status::OK();
}

Status HloEvaluator::HandleSlice(HloInstruction* slice,
                                 HloInstruction* operand) {
  const Shape& shape = slice->shape();
  auto literal = LiteralUtil::CreateFromDimensions(
      shape.element_type(), AsInt64Slice(shape.dimensions()));

  DimensionVector dest_indices(slice->slice_starts().size(), 0);

  TF_RETURN_IF_ERROR(LiteralUtil::Copy(
      GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(),
      dest_indices, AsInt64Slice(shape.dimensions())));

  evaluated_[slice] = std::move(literal);
  evaluated_[constant] = MakeUnique<Literal>(literal);
  return Status::OK();
}

@@ -57,32 +57,21 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
  // Evaluates a single HLO instruction and an array of pointers to literals.
  // Return the evaluated result as literal if successful.
  // Precondition:
  // 1. argument literals correspond to the input instruction's parameters in
  // their post-ordering.
  // 1. argument literals are corresponds to the input instruction's
  // parameters in their post-orderring.
  // 2. the instruction's operands must be of either Parameter or Constant type.
  // TODO(b/35950897): implement more ops other than element-wise ops.
  StatusOr<std::unique_ptr<Literal>> Evaluate(
      HloInstruction* instruction,
      tensorflow::gtl::ArraySlice<const Literal*> arg_literals);

  // Evaluates a single HLO instruction with constant operands.
  // Returns the evaluated result as literal if successful.
  // Precondition:
  // 1. all operands of the input instruction are constants.
  // 2. the instruction is not a Parameter operation.
  StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);

  // Same as Evaluate, except returning nullptr on error.
  std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);

 protected:
  // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
  // literal type of each evaluated Handle* method of a TypedVisitor.
  // There are however a few notable exceptions to this rule, notably:
  // - HandleCompare and HandleIsFinite: where the resulting literal type is
  // literal type of each evaluated Handle* method of a TypedVisitor. One
  // exception to this is HandleCompare, where the resulting literal type is
  // always boolean.
  // These operations are handled outside of the parent HloEvaluator handlers
  // instead of from within TypedVisitor.
  // Note the forward declaration here is necessary to enable TypedVisitor to
  // access parent members.
  template <typename ReturnT>
  class TypedVisitor;

@@ -92,38 +81,15 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
    return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
  }

  // Operations that are type-agnostic.
  //
  Status HandleParameter(HloInstruction* parameter) override;

  Status HandleConstant(HloInstruction* constant,
                        const Literal& literal) override;

  Status HandleConcatenate(
      HloInstruction* concatenate,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleIsFinite(HloInstruction* is_finite,
                        HloInstruction* operand) override;

  Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
                       HloInstruction* lhs, HloInstruction* rhs) override;

 private:
  // Returns the already-evaluated literal result for the instruction.
  // A Constant instruction is considered evaluated and its literal will be
  // returned directly without looking up the cache.
  // Crashes with log if the given instruction has not been evaluated previously.
  const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
    if (hlo->IsConstant()) {
      return hlo->literal();
    }
    auto it = evaluated_.find(hlo);
    CHECK(it != evaluated_.end())
        << "could not find evaluated value for: " << hlo->ToString();

@@ -23,10 +23,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

@@ -145,7 +143,7 @@ TEST_F(HloEvaluatorTest, DoesDivide) {
// element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest, DoesAbs) {
  auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
  const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2});
  Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
  auto c1 = HloInstruction::CreateConstant(std::move(operand));
  auto instruction =
      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get());

@@ -156,29 +154,7 @@ TEST_F(HloEvaluatorTest, DoesAbs) {
  auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});

  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));

  // For R0 literal.
  const Shape& r0 = ShapeUtil::MakeShape(F32, {});
  operand = LiteralUtil::CreateR0<float>(-1.0f);
  c1 = HloInstruction::CreateConstant(std::move(operand));
  instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get());
  result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
  expected = LiteralUtil::CreateR0<float>(1.0f);

  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));

  // For R1 literal with dimension of size 0.
  Shape empty_r1 = ShapeUtil::MakeShape(F32, {0});
  operand = LiteralUtil::CreateR1<float>({});
  c1 = HloInstruction::CreateConstant(std::move(operand));
  instruction =
      HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get());

  result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
  expected = LiteralUtil::CreateR1<float>({});

  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}  // namespace
}

// Verifies that HloEvaluator evaluates a HLO Computation with non-parameter,
// non-constant operands.
@@ -211,35 +187,5 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) {
  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}

// Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest, DoesReshape) {
  HloComputation::Builder builder(
      ::testing::UnitTest::GetInstance()->current_test_info()->name());

  const int64 dimensions[] = {11, 8, 7, 5, 9};
  TF_ASSIGN_OR_ASSERT_OK(auto literal,
                         LiteralTestUtil::CreateRandomLiteral<F32>(
                             ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
  auto literal_clone = LiteralUtil::CloneToUnique(*literal);
  HloInstruction* literal_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(literal)));

  Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
  const int64 permutation[] = {1, 2, 0, 4, 3};
  builder.AddInstruction(
      HloInstruction::CreateTranspose(shape, literal_instruction, permutation));

  std::unique_ptr<Literal> result =
      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();

  using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
  LiteralUtil::EachCell<NativeT>(
      *result, [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
        std::vector<int64> rindexes = Permute(permutation, indices);
        EXPECT_TRUE(value ==
                    LiteralUtil::Get<NativeT>(*literal_clone, rindexes));
      });
}

}  // namespace
}  // namespace xla

@@ -756,28 +756,27 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
  // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
  // the degenerate input/output dimensions in the gap to
  // deleted_indices/inserted_indices respectively.
  auto check_modified_dims =
      [&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
          std::pair<int64, int64> prior_unmodified_dim_pair,
          std::pair<int64, int64> unmodified_dim_pair) {
        for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
             modified_input_dim < unmodified_dim_pair.first;
             ++modified_input_dim) {
          if (shape_pre.dimensions(modified_input_dim) > 1) {
            return false;
          }
          deleted_indices.push_back(modified_input_dim);
        }
        for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
             modified_output_dim < unmodified_dim_pair.second;
             ++modified_output_dim) {
          if (shape_post.dimensions(modified_output_dim) > 1) {
            return false;
          }
          inserted_indices.push_back(modified_output_dim);
        }
        return true;
      };
  auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices,
                              &inserted_indices](
      std::pair<int64, int64> prior_unmodified_dim_pair,
      std::pair<int64, int64> unmodified_dim_pair) {
    for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
         modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) {
      if (shape_pre.dimensions(modified_input_dim) > 1) {
        return false;
      }
      deleted_indices.push_back(modified_input_dim);
    }
    for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
         modified_output_dim < unmodified_dim_pair.second;
         ++modified_output_dim) {
      if (shape_post.dimensions(modified_output_dim) > 1) {
        return false;
      }
      inserted_indices.push_back(modified_output_dim);
    }
    return true;
  };

  std::vector<std::pair<int64, int64>> unmodified_dims =
      DimensionsUnmodifiedByReshape(shape_pre, shape_post);

@@ -1190,9 +1189,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
    tensorflow::gtl::ArraySlice<int64> count,
    tensorflow::gtl::ArraySlice<int64> incr,
    const IndexVisitorFunction& visitor_function) {
  if (ShapeUtil::HasZeroElements(shape)) {
    return;
  }
  DCHECK_EQ(Rank(shape), base.size());
  DCHECK_EQ(incr.size(), base.size());
  DCHECK_EQ(count.size(), base.size());

@@ -446,34 +446,6 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
      ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
}

TEST(ShapeUtilTest, ForEachIndex) {
  struct ShapeDimensionAndNumberInvocations {
    std::vector<int64> dimensions;
    int invocations;
  } test_data[] = {
      {{}, 1},     {{0}, 0},      {{16}, 16},          {{3, 0}, 0},
      {{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610},
  };

  for (const auto& data : test_data) {
    Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
    // Increments at every invocation.
    int invocations = 0;
    auto increment_func = [&invocations](const std::vector<int64>& indexes) {
      invocations++;
      return true;
    };

    std::vector<int64> zero_base(data.dimensions.size(), 0);
    std::vector<int64> step(data.dimensions.size(), 1);

    ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step,
                            increment_func);

    EXPECT_EQ(invocations, data.invocations);
  }
}
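
The expected invocation counts are simply each shape's element counts: the empty dimension list (a scalar) is visited once, any shape with a zero dimension is never visited, and the rest are products of their dimensions, e.g. 4 * 16 = 64, 6 * 11 * 17 = 1122, and 6 * 11 * 5 * 17 = 5610.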

TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
  // All output dimensions should be unmodified. One of the input dimensions is
  // modified because the input rank is larger by one.

@@ -46,16 +46,12 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {

ClientLibraryTestBase::ClientLibraryTestBase(
    se::Platform* platform,
    tensorflow::gtl::ArraySlice<string> disabled_pass_names,
    bool disable_constant_folding)
    tensorflow::gtl::ArraySlice<string> disabled_pass_names)
    : client_(GetOrCreateLocalClientOrDie(platform)) {
  legacy_flags::HloPassPipelineFlags* flags =
      legacy_flags::GetHloPassPipelineFlags();
  flags->xla_disable_hlo_passes =
      tensorflow::str_util::Join(disabled_pass_names, ",");
  if (disable_constant_folding) {
    flags->xla_disable_hlo_passes += ",constant_folding";
  }
}

string ClientLibraryTestBase::TestName() const {

@@ -47,14 +47,7 @@ class ClientLibraryTestBase : public ::testing::Test {
 protected:
  explicit ClientLibraryTestBase(
      perftools::gputools::Platform* platform = nullptr,
      tensorflow::gtl::ArraySlice<string> disabled_pass_names = {},
      // Note: here we are disabling constant_folding paths so that the tests
      // (usually written using Constants) will exercise the intended code
      // paths, instead of being constant folded.
      // TODO(b/38354253): Constant folding is currently disabled here. Change
      // tests to Parameters instead of Constants, and re-enable constant
      // folding by default.
      bool disable_constant_folding = true);
      tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});

  // Returns the name of the test currently being run.
  string TestName() const;

@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"