Automated g4 rollback of changelist 157174708

PiperOrigin-RevId: 157253080
Kay Zhu 2017-05-26 13:00:03 -07:00 committed by TensorFlower Gardener
parent 70313342b6
commit 2ff1d7bf04
12 changed files with 371 additions and 571 deletions

View File

@@ -1071,7 +1071,6 @@ template <typename NativeT>
ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions,
stride_config.step, init_function);
} else {
// For scalars.
data.at(0) = generator({});
}
return Status::OK();

View File

@@ -807,9 +807,7 @@ TEST_F(LiteralUtilTest, Populate) {
std::vector<int64> layout;
} populate_data[] = {
{{}, {}},
{{0}, {0}},
{{16}, {0}},
{{2, 0}, {1, 0}},
{{4, 16}, {1, 0}},
{{21, 12}, {0, 1}},
{{6, 11, 17}, {2, 0, 1}},

View File

@@ -106,12 +106,10 @@ cc_test(
":hlo_evaluator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
@@ -1461,9 +1459,7 @@ cc_library(
hdrs = ["hlo_constant_folding.h"],
deps = [
":hlo",
":hlo_evaluator",
":hlo_pass",
":hlo_query",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",

View File

@@ -24,57 +24,230 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
namespace {
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
static std::unique_ptr<Literal> ConvertIfTypesMatch(
const Literal& src_literal) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
return LiteralUtil::Convert<
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
typename primitive_util::PrimitiveTypeToNative<
primitive_dest_type>::type>(src_literal);
}
template <PrimitiveType primitive_src_type>
static std::unique_ptr<Literal> ConvertIfDestTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
CONVERT_IF_TYPES_MATCH(PRED)
CONVERT_IF_TYPES_MATCH(S8)
CONVERT_IF_TYPES_MATCH(S32)
CONVERT_IF_TYPES_MATCH(S64)
CONVERT_IF_TYPES_MATCH(U8)
CONVERT_IF_TYPES_MATCH(U32)
CONVERT_IF_TYPES_MATCH(U64)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
// Other types are not yet supported.
default:
LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
<< PrimitiveType_Name(src_literal.shape().element_type());
}
}
static std::unique_ptr<Literal> ConvertIfSrcTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (src_literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
case (type): \
return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
CONVERT_IF_DEST_TYPE_MATCHES(PRED)
CONVERT_IF_DEST_TYPE_MATCHES(S8)
CONVERT_IF_DEST_TYPE_MATCHES(S32)
CONVERT_IF_DEST_TYPE_MATCHES(S64)
CONVERT_IF_DEST_TYPE_MATCHES(U8)
CONVERT_IF_DEST_TYPE_MATCHES(U32)
CONVERT_IF_DEST_TYPE_MATCHES(U64)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
#undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported.
default:
LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
<< PrimitiveType_Name(src_literal.shape().element_type());
}
}
} // namespace
// ConstantFolderVisitor traverses the HLO computation and reduces certain
// constant graph sections to literals.
class ConstantFolderVisitor : public DfsHloVisitorWithDefault {
public:
// Default visitor action is to do nothing and return OK.
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
return Status::OK();
}
Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status HandleConvert(HloInstruction* convert,
HloInstruction* operand) override;
Status HandleReshape(HloInstruction* reshape) override;
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
Status HandleTranspose(HloInstruction* transpose) override;
// Returns whether a constant folding operation has occurred.
const bool changed() const { return changed_; }
// Runs the visitor on a computation and returns whether any changes were
// performed.
static StatusOr<bool> Run(HloComputation* computation);
private:
ConstantFolderVisitor() = default;
// Replaces the existing HLO instruction old_instruction with a literal,
// and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceWithConstant(HloInstruction* old_instruction,
std::unique_ptr<Literal> literal) {
TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
old_instruction, HloInstruction::CreateConstant(std::move(literal))));
changed_ = true;
return Status::OK();
}
// Whether any constant folding operations have occurred.
bool changed_ = false;
};
StatusOr<bool> ConstantFolderVisitor::Run(HloComputation* computation) {
ConstantFolderVisitor visitor;
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.changed();
}
StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
auto evaluator = MakeUnique<HloEvaluator>();
XLA_VLOG_LINES(2,
"HloConstantFolding::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto& computation : module->computations()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
// Skip dead code.
if (instruction->user_count() == 0 &&
computation->root_instruction() != instruction) {
continue;
}
// Skip Constant and Parameter operations.
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant) {
continue;
}
// Skip instructions with non-constant operands.
if (!hlo_query::AllOperandsAreConstants(*instruction)) {
continue;
}
std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
if (result == nullptr) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
}
TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateConstant(std::move(result))));
changed = true;
}
for (auto& comp : module->computations()) {
TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get()));
changed = changed || result;
}
XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
return changed;
}
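// A minimal usage sketch for this pass: the helper name FoldConstantsOnce is
// hypothetical, and the sketch relies only on what is shown above, namely that
// HloConstantFolding::Run(HloModule*) returns StatusOr<bool>.
StatusOr<bool> FoldConstantsOnce(HloModule* module) {
  HloConstantFolding constant_folding;
  TF_ASSIGN_OR_RETURN(bool changed, constant_folding.Run(module));
  return changed;
}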
Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) {
if (reshape->operand(0)->opcode() == HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(
auto reshaped_literal,
LiteralUtil::Reshape(reshape->operand(0)->literal(),
AsInt64Slice(reshape->shape().dimensions())));
return ReplaceWithConstant(reshape, std::move(reshaped_literal));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) {
if (transpose->operand(0)->opcode() == HloOpcode::kConstant) {
auto transposed_literal = LiteralUtil::Transpose(
transpose->operand(0)->literal(), transpose->dimensions());
return ReplaceWithConstant(transpose, std::move(transposed_literal));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
if (operands[0]->opcode() == HloOpcode::kConstant) {
// If all the operands of a concatenate are constant, fold them into a
// single constant tensor.
// The result concatenate dimension is the sum of the concatenate
// dimensions of all the arrays taking part in the operation.
int64 concat_dim = concatenate->dimensions()[0];
const Shape& reference_shape = operands[0]->shape();
CHECK(!ShapeUtil::IsTuple(reference_shape));
int64 rank = ShapeUtil::Rank(reference_shape);
std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(),
reference_shape.dimensions().end());
if (concat_dim < 0) {
concat_dim += rank;
}
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
CHECK(!ShapeUtil::IsTuple(operand_shape));
if (operands[i]->opcode() != HloOpcode::kConstant) {
return Status::OK();
}
// Accumulate the concat dimension from all tensors taking part in the
// operation.
concat_dimensions[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
auto literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
std::vector<int64> source_indices(rank, 0);
std::vector<int64> dest_indices(concat_dimensions.size(), 0);
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
operand->literal(), source_indices, literal.get(), dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
return ReplaceWithConstant(concatenate, std::move(literal));
}
return Status::OK();
}
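// Worked example of the dimension bookkeeping above, as a hypothetical
// standalone sketch (shapes chosen arbitrarily): concatenating constants of
// shape {2, 3} and {2, 5} along dimension 1 yields shape {2, 8}, and each
// operand is copied at a destination offset that accumulates along the
// concat dimension, here columns 0 and 3.
static std::vector<int64> ConcatenatedDimensionsSketch() {
  const int64 concat_dim = 1;
  const std::vector<std::vector<int64>> operand_dims = {{2, 3}, {2, 5}};
  std::vector<int64> result_dims = operand_dims[0];  // Starts as {2, 3}.
  for (size_t i = 1; i < operand_dims.size(); ++i) {
    // Each further operand is copied at the offset accumulated so far, and
    // the result grows by the operand's extent along the concat dimension.
    result_dims[concat_dim] += operand_dims[i][concat_dim];
  }
  return result_dims;  // {2, 8}.
}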
Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice,
HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kConstant) {
const Shape& shape = slice->shape();
auto literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
std::vector<int64> dest_indices(slice->slice_starts().size(), 0);
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
operand->literal(), slice->slice_starts(), literal.get(), dest_indices,
AsInt64Slice(shape.dimensions())));
TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal)));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert,
HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kConstant) {
const Literal& src_literal = operand->literal();
std::unique_ptr<Literal> new_constant =
ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type());
return ReplaceWithConstant(convert, std::move(new_constant));
}
return Status::OK();
}
} // namespace xla

View File

@@ -46,89 +46,6 @@ limitations under the License.
namespace xla {
namespace {
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
const Literal& lhs_literal,
const Literal& rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el == rhs_el;
};
break;
case HloOpcode::kNe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el != rhs_el;
};
break;
case HloOpcode::kGe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el >= rhs_el;
};
break;
case HloOpcode::kGt:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el > rhs_el;
};
break;
case HloOpcode::kLe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el <= rhs_el;
};
break;
case HloOpcode::kLt:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el < rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
<< HloOpcodeString(opcode);
}
auto result = LiteralUtil::CreateFromShape(shape);
TF_RETURN_IF_ERROR(LiteralUtil::Populate<bool>(
result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return compare_op(LiteralUtil::Get<OperandT>(lhs_literal, multi_index),
LiteralUtil::Get<OperandT>(rhs_literal, multi_index));
}));
return std::move(result);
}
template <typename ReturnT, typename NativeT>
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const Literal& operand_literal) {
const auto shape = instruction->shape();
const auto* operand = instruction->operand(0);
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s",
ShapeUtil::HumanString(shape).c_str(),
ShapeUtil::HumanString(operand->shape()).c_str());
}
auto result = LiteralUtil::CreateFromShape(shape);
TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return unary_op(
LiteralUtil::Get<NativeT>(operand_literal, multi_index));
}));
return std::move(result);
}
} // namespace
template <typename ReturnT>
class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
public:
@@ -151,7 +68,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return elem_operand;
}));
return Status::OK();
}
};
template <
typename NativeT,
@@ -162,7 +79,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return std::abs(elem_operand);
}));
return Status::OK();
}
};
Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override {
return HandleAbs<ReturnT>(abs, operand);
@@ -184,45 +101,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
template <PrimitiveType src_type, PrimitiveType dest_type>
std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
DCHECK_EQ(src_type, src_literal.shape().element_type());
return LiteralUtil::Convert<
typename primitive_util::PrimitiveTypeToNative<src_type>::type,
typename primitive_util::PrimitiveTypeToNative<dest_type>::type>(
src_literal);
}
Status HandleConvert(HloInstruction* convert,
HloInstruction* operand) override {
auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
switch (operand->shape().element_type()) {
#define CONVERT_IF_TYPES_MATCH(src_type) \
case (src_type): \
parent_->evaluated_[convert] = LiteralUtil::Convert< \
typename primitive_util::PrimitiveTypeToNative<src_type>::type, \
ReturnT>(operand_literal); \
break;
CONVERT_IF_TYPES_MATCH(PRED)
CONVERT_IF_TYPES_MATCH(S8)
CONVERT_IF_TYPES_MATCH(S32)
CONVERT_IF_TYPES_MATCH(S64)
CONVERT_IF_TYPES_MATCH(U8)
CONVERT_IF_TYPES_MATCH(U32)
CONVERT_IF_TYPES_MATCH(U64)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
// Other types are not yet supported.
default:
LOG(FATAL) << "unimplemented operand type for HandleCovert: "
<< PrimitiveType_Name(operand->shape().element_type());
}
return Status::OK();
}
Status HandleExp(HloInstruction* exp, HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
ElementWiseUnaryOp(exp, [](ReturnT elem_operand) {
@@ -239,6 +117,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
Status HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite],
ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) {
return std::isfinite(elem_operand);
}));
return Status::OK();
};
Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ReturnT elem_operand) {
@@ -322,12 +209,77 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) override {
std::function<bool(ReturnT, ReturnT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el == rhs_el;
};
break;
case HloOpcode::kNe:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el != rhs_el;
};
break;
case HloOpcode::kGe:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el >= rhs_el;
};
break;
case HloOpcode::kGt:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el > rhs_el;
};
break;
case HloOpcode::kLe:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el <= rhs_el;
};
break;
case HloOpcode::kLt:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el < rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
<< HloOpcodeString(opcode);
}
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
return Unimplemented(
"Compare operation with mismatched dimensions, likely due to "
"broadcasting is unsupported.");
}
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
auto result = LiteralUtil::CreateFromShape(compare->shape());
std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
do {
LiteralUtil::Set<bool>(
result.get(), multi_index,
compare_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
LiteralUtil::Get<ReturnT>(rhs_literal, multi_index)));
} while (IndexUtil::BumpIndices(result->shape(), &multi_index));
parent_->evaluated_[compare] = std::move(result);
return Status::OK();
};
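// A standalone sketch (not the XLA implementation) of the multi-index loop
// pattern used above: repeatedly "bump" the index vector until it wraps,
// visiting every element of the shape exactly once. The dimension order used
// here is chosen only for illustration.
static bool BumpIndexSketch(const std::vector<int64>& dims,
                            std::vector<int64>* index) {
  for (int64 i = static_cast<int64>(dims.size()) - 1; i >= 0; --i) {
    ++(*index)[i];
    if ((*index)[i] < dims[i]) {
      return true;  // No carry needed; keep iterating.
    }
    (*index)[i] = 0;  // Carry into the next dimension.
  }
  return false;  // Wrapped around; a do/while loop terminates here.
}
// For dims = {2, 3}, a do/while over BumpIndexSketch visits (0,0), (0,1),
// (0,2), (1,0), (1,1), (1,2); a rank-0 shape is visited exactly once.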
Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
HloInstruction* rhs) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[maximum],
ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
return std::fmax(lhs, rhs);
return std::max(lhs, rhs);
}));
return Status::OK();
};
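// Standalone numerical aside on how the two calls in the hunk above can
// differ when ReturnT is an integral type: std::fmax converts its arguments
// to double, whose 53-bit mantissa cannot represent every 64-bit integer,
// while std::max compares the original values. The helper name is
// hypothetical and the values are chosen only for illustration.
static void MaxVsFmaxSketch() {
  const int64 a = (int64{1} << 62) + 1;
  const int64 b = int64{1} << 62;
  const int64 via_max = std::max(a, b);                        // Keeps the +1.
  const int64 via_fmax = static_cast<int64>(std::fmax(a, b));  // Rounds it away.
  CHECK_NE(via_max, via_fmax);
}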
@@ -337,7 +289,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[minimum],
ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
return std::fmin(lhs_el, rhs_el);
return std::min(lhs_el, rhs_el);
}));
return Status::OK();
};
@@ -357,7 +309,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[remainder],
ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
return std::fmod(lhs_el, rhs_el);
return std::remainder(lhs_el, rhs_el);
}));
return Status::OK();
};
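// Standalone numerical aside on how the two calls in the hunk above differ
// (hypothetical helper, values chosen for illustration): std::fmod truncates
// the quotient toward zero, while std::remainder rounds it to the nearest
// integer, so the results can differ in both magnitude and sign.
static void FmodVsRemainderSketch() {
  CHECK_EQ(std::fmod(5.0, 3.0), 2.0);        // 5 - 1 * 3, quotient truncated to 1.
  CHECK_EQ(std::remainder(5.0, 3.0), -1.0);  // 5 - 2 * 3, quotient rounded to 2.
}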
@@ -386,7 +338,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
HloInstruction* arg, HloInstruction* max) override {
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
[](ReturnT low, ReturnT high, ReturnT value) {
return std::fmax(low, std::fmin(value, high));
return std::max(low, std::min(value, high));
};
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
ElementWiseTernaryOp(clamp, std::move(clamp_op)));
@@ -418,11 +370,32 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ReturnT(ReturnT)>& unary_op) {
const Literal& operand_literal =
parent_->GetEvaluatedLiteralFor(instruction->operand(0));
return ElementWiseUnaryOpImpl<ReturnT, ReturnT>(instruction, unary_op,
operand_literal);
}
const auto shape = instruction->shape();
const auto* operand = instruction->operand(0);
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s",
ShapeUtil::HumanString(shape).c_str(),
ShapeUtil::HumanString(operand->shape()).c_str());
}
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto result = LiteralUtil::CreateFromShape(shape);
std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
do {
LiteralUtil::Set<ReturnT>(
result.get(), multi_index,
unary_op(LiteralUtil::Get<ReturnT>(operand_literal, multi_index)));
} while (IndexUtil::BumpIndices(result->shape(), &multi_index));
return std::move(result);
};
StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
HloInstruction* instruction,
@@ -447,14 +420,16 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
auto result = LiteralUtil::CreateFromShape(shape);
std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
do {
LiteralUtil::Set<ReturnT>(
result.get(), multi_index,
binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
LiteralUtil::Get<ReturnT>(rhs_literal, multi_index)));
} while (IndexUtil::BumpIndices(result->shape(), &multi_index));
TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
LiteralUtil::Get<ReturnT>(rhs_literal, multi_index));
}));
return std::move(result);
}
};
template <typename LhsType, typename RhsType, typename EhsType>
StatusOr<std::unique_ptr<Literal>> ElementWiseTernaryOp(
@@ -484,17 +459,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
auto result = LiteralUtil::CreateFromShape(shape);
TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return ternary_op(
LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
LiteralUtil::Get<EhsType>(ehs_literal, multi_index));
}));
std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
do {
LiteralUtil::Set<ReturnT>(
result.get(), multi_index,
ternary_op(LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
LiteralUtil::Get<EhsType>(ehs_literal, multi_index)));
} while (IndexUtil::BumpIndices(result->shape(), &multi_index));
return std::move(result);
}
};
HloEvaluator* parent_;
};
@@ -518,12 +493,6 @@ HloEvaluator::HloEvaluator() {
});
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: TUPLE.");
});
typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: OPAQUE.");
});
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
@@ -533,15 +502,15 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
evaluated_.clear();
TF_RETURN_IF_ERROR(computation->Accept(this));
return MakeUnique<Literal>(
GetEvaluatedLiteralFor(computation->root_instruction()));
return std::move(FindOrDie(evaluated_, computation->root_instruction()));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const Literal*> operands) {
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
Shape shape = instruction->shape();
TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
arg_literals_ = operands;
evaluated_.clear();
@@ -556,34 +525,13 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
evaluated_[operand] = MakeUnique<Literal>(*input_literal);
} else if (operand->opcode() == HloOpcode::kConstant) {
evaluated_[operand] = MakeUnique<Literal>(operand->literal());
}
}
TF_RETURN_IF_ERROR(instruction->Visit(this));
return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction) {
TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction));
TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter);
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
arg_literals_.clear();
evaluated_.clear();
TF_RETURN_IF_ERROR(instruction->Visit(this));
return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
}
std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
HloInstruction* instruction) {
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
LOG(ERROR) << "TryEvaluate failed:" << result_or.status();
return nullptr;
}
return result_or.ConsumeValueOrDie();
return std::move(FindOrDie(evaluated_, instruction));
}
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
@@ -600,191 +548,9 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
Status HloEvaluator::HandleConstant(HloInstruction* constant,
const Literal& literal) {
VLOG(2) << "HandleConstant: " << constant->ToString();
return Status::OK();
}
DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape()));
Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
TF_ASSIGN_OR_RETURN(
evaluated_[reshape],
LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)),
AsInt64Slice(reshape->shape().dimensions())));
return Status::OK();
}
Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
evaluated_[transpose] = LiteralUtil::Transpose(
GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions());
return Status::OK();
}
Status HloEvaluator::HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
// The result concatenate dimension is the sum of the concatenate
// dimensions of all the operands taking part in the operation.
const Shape& reference_shape = operands[0]->shape();
CHECK(!ShapeUtil::IsTuple(reference_shape));
const int64 rank = ShapeUtil::Rank(reference_shape);
const int64 concat_dim = concatenate->dimensions()[0];
CHECK_GE(concat_dim, 0);
CHECK_LT(concat_dim, rank);
DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
reference_shape.dimensions().end());
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
CHECK(!ShapeUtil::IsTuple(operand_shape));
// Accumulate the concat dimension from all tensors taking part in the
// operation.
concat_dimensions[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
auto result_literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
DimensionVector source_indices(rank, 0);
DimensionVector dest_indices(concat_dimensions.size(), 0);
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(),
dest_indices, AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
evaluated_[concatenate] = std::move(result_literal);
return Status::OK();
}
Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) {
if (!ShapeUtil::ElementIsFloating(operand->shape())) {
return InvalidArgument(
"expected element type in shape to be float for IsFinite op, got: %s",
PrimitiveType_Name(operand->shape().element_type()).c_str());
}
switch (operand->shape().element_type()) {
case F16:
return Unimplemented("unhandled primitive type: F16.");
case F32: {
auto result_or = ElementWiseUnaryOpImpl<bool, float>(
is_finite,
[](float elem_operand) { return std::isfinite(elem_operand); },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
break;
}
case F64: {
auto result_or = ElementWiseUnaryOpImpl<bool, double>(
is_finite,
[](double elem_operand) { return std::isfinite(elem_operand); },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
break;
}
default:
LOG(FATAL) << "unknown/unhandled primitive type.";
}
return Status::OK();
}
Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) {
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s vs %s",
ShapeUtil::HumanString(compare->shape()).c_str(),
ShapeUtil::HumanString(lhs->shape()).c_str(),
ShapeUtil::HumanString(rhs->shape()).c_str());
}
TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
// Note here we switch on the operand's type.
switch (lhs->shape().element_type()) {
case PRED: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case U8: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case U16:
return Unimplemented("unhandled primitive type: U16.");
case U32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case U64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case S8: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case S16:
return Unimplemented("unhandled primitive type: S16.");
case S32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case S64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case F16:
return Unimplemented("unhandled primitive type: F16.");
case F32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case F64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
default:
LOG(FATAL) << "unknown primitive type.";
}
return Status::OK();
}
Status HloEvaluator::HandleSlice(HloInstruction* slice,
HloInstruction* operand) {
const Shape& shape = slice->shape();
auto literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
DimensionVector dest_indices(slice->slice_starts().size(), 0);
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(),
dest_indices, AsInt64Slice(shape.dimensions())));
evaluated_[slice] = std::move(literal);
evaluated_[constant] = MakeUnique<Literal>(literal);
return Status::OK();
}

View File

@@ -57,32 +57,21 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Evaluates a single HLO instruction, given an array of pointers to literals
// as its arguments.
// Returns the evaluated result as a literal if successful.
// Precondition:
// 1. argument literals correspond to the input instruction's parameters in
// their post-ordering.
// 1. argument literals correspond to the input instruction's
// parameters in their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
// TODO(b/35950897): implement more ops other than element-wise ops.
StatusOr<std::unique_ptr<Literal>> Evaluate(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as a literal if successful.
// Precondition:
// 1. all operands of the input instruction are constants.
// 2. the instruction is not a Parameter operation.
StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
// Same as Evaluate, except returning nullptr on error.
std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
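// A minimal sketch of the TryEvaluate contract, mirroring the call site in
// HloConstantFolding::Run (the evaluator and instruction names below are
// placeholders):
//
//   HloEvaluator evaluator;
//   std::unique_ptr<Literal> folded = evaluator.TryEvaluate(instruction);
//   if (folded == nullptr) {
//     // Evaluation failed or is unimplemented; leave the instruction alone.
//   }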
protected:
// Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
// literal type of each evaluated Handle* method of a TypedVisitor.
// There are, however, a few notable exceptions to this rule:
// - HandleCompare and HandleIsFinite: where the resulting literal type is
// literal type of each evaluated Handle* method of a TypedVisitor. One
// exception to this is HandleCompare, where the resulting literal type is
// always boolean.
// These operations are handled outside of the parent HloEvaluator handlers
// instead of from within TypedVisitor.
// Note the forward declaration here is necessary to enable TypedVisitor to
// access parent members.
template <typename ReturnT>
class TypedVisitor;
@@ -92,38 +81,15 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
}
// Operations that are type-agnostic.
//
Status HandleParameter(HloInstruction* parameter) override;
Status HandleConstant(HloInstruction* constant,
const Literal& literal) override;
Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status HandleReshape(HloInstruction* reshape) override;
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
Status HandleTranspose(HloInstruction* transpose) override;
Status HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) override;
Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) override;
private:
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
// Crashes with a log message if the given instruction has not been evaluated
// previously.
const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
if (hlo->IsConstant()) {
return hlo->literal();
}
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();

View File

@@ -23,10 +23,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
@@ -145,7 +143,7 @@ TEST_F(HloEvaluatorTest, DoesDivide) {
// element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest, DoesAbs) {
auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2});
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
auto c1 = HloInstruction::CreateConstant(std::move(operand));
auto instruction =
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get());
@@ -156,29 +154,7 @@ TEST_F(HloEvaluatorTest, DoesAbs) {
auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
// For R0 literal.
const Shape& r0 = ShapeUtil::MakeShape(F32, {});
operand = LiteralUtil::CreateR0<float>(-1.0f);
c1 = HloInstruction::CreateConstant(std::move(operand));
instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get());
result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
expected = LiteralUtil::CreateR0<float>(1.0f);
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
// For R1 literal with dimension of size 0.
Shape empty_r1 = ShapeUtil::MakeShape(F32, {0});
operand = LiteralUtil::CreateR1<float>({});
c1 = HloInstruction::CreateConstant(std::move(operand));
instruction =
HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get());
result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
expected = LiteralUtil::CreateR1<float>({});
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
} // namespace
}
// Verifies that HloEvaluator evaluates an HLO Computation whose operands are
// neither parameters nor constants.
@@ -211,35 +187,5 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) {
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}
// Verifies that the Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest, DoesReshape) {
HloComputation::Builder builder(
::testing::UnitTest::GetInstance()->current_test_info()->name());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSIGN_OR_ASSERT_OK(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = LiteralUtil::CloneToUnique(*literal);
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
const int64 permutation[] = {1, 2, 0, 4, 3};
builder.AddInstruction(
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
LiteralUtil::EachCell<NativeT>(
*result, [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
EXPECT_TRUE(value ==
LiteralUtil::Get<NativeT>(*literal_clone, rindexes));
});
}
} // namespace
} // namespace xla

View File

@@ -756,28 +756,27 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
// and unmodified_dim_pair have size >1. Otherwise, returns true and appends
// the degenerate input/output dimensions in the gap to
// deleted_indices/inserted_indices respectively.
auto check_modified_dims =
[&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
std::pair<int64, int64> prior_unmodified_dim_pair,
std::pair<int64, int64> unmodified_dim_pair) {
for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
modified_input_dim < unmodified_dim_pair.first;
++modified_input_dim) {
if (shape_pre.dimensions(modified_input_dim) > 1) {
return false;
}
deleted_indices.push_back(modified_input_dim);
}
for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
modified_output_dim < unmodified_dim_pair.second;
++modified_output_dim) {
if (shape_post.dimensions(modified_output_dim) > 1) {
return false;
}
inserted_indices.push_back(modified_output_dim);
}
return true;
};
auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices,
&inserted_indices](
std::pair<int64, int64> prior_unmodified_dim_pair,
std::pair<int64, int64> unmodified_dim_pair) {
for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) {
if (shape_pre.dimensions(modified_input_dim) > 1) {
return false;
}
deleted_indices.push_back(modified_input_dim);
}
for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
modified_output_dim < unmodified_dim_pair.second;
++modified_output_dim) {
if (shape_post.dimensions(modified_output_dim) > 1) {
return false;
}
inserted_indices.push_back(modified_output_dim);
}
return true;
};
std::vector<std::pair<int64, int64>> unmodified_dims =
DimensionsUnmodifiedByReshape(shape_pre, shape_post);
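// For example, a reshape from {2, 1, 3} to {2, 3} only deletes the degenerate
// input dimension 1, and a reshape from {4, 5} to {4, 1, 5, 1} only inserts
// degenerate output dimensions 1 and 3; a reshape from {2, 3} to {3, 2}, by
// contrast, modifies dimensions of size greater than one, so the check above
// reports that it is not a pure insertion/deletion of 1-sized dimensions.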
@@ -1190,9 +1189,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
tensorflow::gtl::ArraySlice<int64> count,
tensorflow::gtl::ArraySlice<int64> incr,
const IndexVisitorFunction& visitor_function) {
if (ShapeUtil::HasZeroElements(shape)) {
return;
}
DCHECK_EQ(Rank(shape), base.size());
DCHECK_EQ(incr.size(), base.size());
DCHECK_EQ(count.size(), base.size());

View File

@@ -446,34 +446,6 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
}
TEST(ShapeUtilTest, ForEachIndex) {
struct ShapeDimensionAndNumberInvocations {
std::vector<int64> dimensions;
int invocations;
} test_data[] = {
{{}, 1}, {{0}, 0}, {{16}, 16}, {{3, 0}, 0},
{{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610},
};
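// The expected invocation counts above are simply the element counts of the
// shapes, i.e. the products of their dimensions: 4 * 16 = 64, 6 * 11 * 17 =
// 1122, and 6 * 11 * 5 * 17 = 5610; a rank-0 shape has exactly one element,
// and any shape with a zero-sized dimension has none, so its visitor function
// is never invoked.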
for (const auto& data : test_data) {
Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
// Increments at every invocation.
int invocations = 0;
auto increment_func = [&invocations](const std::vector<int64>& indexes) {
invocations++;
return true;
};
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step,
increment_func);
EXPECT_EQ(invocations, data.invocations);
}
}
TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
// All output dimensions should be unmodified. One of the input dimensions is
// modified because the input rank is larger by one.

View File

@@ -46,16 +46,12 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
ClientLibraryTestBase::ClientLibraryTestBase(
se::Platform* platform,
tensorflow::gtl::ArraySlice<string> disabled_pass_names,
bool disable_constant_folding)
tensorflow::gtl::ArraySlice<string> disabled_pass_names)
: client_(GetOrCreateLocalClientOrDie(platform)) {
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
flags->xla_disable_hlo_passes =
tensorflow::str_util::Join(disabled_pass_names, ",");
if (disable_constant_folding) {
flags->xla_disable_hlo_passes += ",constant_folding";
}
}
string ClientLibraryTestBase::TestName() const {

View File

@@ -47,14 +47,7 @@ class ClientLibraryTestBase : public ::testing::Test {
protected:
explicit ClientLibraryTestBase(
perftools::gputools::Platform* platform = nullptr,
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {},
// Note: here we are disabling constant_folding paths so that the tests
// (usually written using Constants) will exercise the intended code
// paths, instead of being constant folded.
// TODO(b/38354253): Constant folding is currently disabled here. Change
// tests to Parameters instead of Constants, and re-enable constant
// folding by default.
bool disable_constant_folding = true);
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});
// Returns the name of the test currently being run.
string TestName() const;

View File

@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"