[XLA] Enable HloEvaluator for constant folding; also merge a few operations

from hlo_constant_folding to hlo_evaluator.

Additionally:
- In ShapeUtil::ForEachIndex:
    * Fix a bug where the visitor is called even when the shape has zero elements (e.g., F32{1,0}); see the ForEachIndex sketch after this list.
    * Add a test case for ForEachIndex.

- In HloEvaluator:
    * Instead of copying and caching a Constant instruction's literal, return the literal directly if the instruction is a constant.
    * Fix an issue where the TUPLE and OPAQUE primitive types are not keyed in the templated typed_visitor map.
    * Use the (fixed) LiteralUtil::Populate to populate the resulting literal; this fixes a preexisting bug in the evaluator where R0 literals and shapes with zero-size dimensions were not handled.
    * Refactor ElementWiseUnaryOp and HandleCompare to be templatized on the operand's type.
    * Refactor IsFinite into a top-level handler, since it applies only to floats and its return type is always boolean.
    * Change from std::remainder to std::fmod for kRemainder, to be compliant with existing XLA behavior.
    * Change from std::max and std::min to std::fmax and std::fmin to handle NaNs (see the numeric-semantics sketch after this list).
    * Minor comment fixes.

- Disables constant_folding and reshape-motion for ClientLibraryTestBase so that constant folding does not affect the code paths the tests are intended to exercise. In the longer term we plan to change all Constants to Parameters and re-enable constant_folding in tests.
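
A minimal, self-contained C++ sketch of the ForEachIndex fix described above. This is not the actual ShapeUtil implementation (which operates on xla::Shape and takes base/count/incr arguments); the names ForEachIndexSketch and Index are invented for illustration. It demonstrates the two behaviors the change and its new test pin down: a shape with any zero-sized dimension never invokes the visitor, and a rank-0 (scalar) shape invokes it exactly once.

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using Index = std::vector<int64_t>;

// Calls `visitor` once per multi-index of a dense array with the given
// dimensions, in row-major order.
void ForEachIndexSketch(const std::vector<int64_t>& dims,
                        const std::function<void(const Index&)>& visitor) {
  // The fix: a shape with any zero-sized dimension (e.g. {1, 0}) has no
  // elements, so the visitor must never be called.
  for (int64_t d : dims) {
    if (d == 0) return;
  }
  Index index(dims.size(), 0);
  while (true) {
    visitor(index);  // Visit the current multi-index.
    // Odometer-style increment; for rank-0 shapes the loop body runs once.
    int64_t i = static_cast<int64_t>(dims.size()) - 1;
    for (; i >= 0; --i) {
      if (++index[i] < dims[i]) break;
      index[i] = 0;
    }
    if (i < 0) return;  // Wrapped around: all indices visited.
  }
}

int main() {
  int count = 0;
  auto counter = [&count](const Index&) { ++count; };

  ForEachIndexSketch({1, 0}, counter);  // Zero elements: never visited.
  std::cout << "{1,0}: " << count << " invocations\n";  // 0

  count = 0;
  ForEachIndexSketch({4, 16}, counter);
  std::cout << "{4,16}: " << count << " invocations\n";  // 64

  count = 0;
  ForEachIndexSketch({}, counter);  // Rank-0 scalar: exactly one visit.
  std::cout << "{}: " << count << " invocations\n";  // 1
  return 0;
}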
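
The kRemainder and kMaximum/kMinimum bullets above hinge on standard-library semantics that are easy to misremember, so here is a small standalone C++ program (not XLA code) demonstrating them: std::remainder rounds the quotient to the nearest integer and can return a negative result for positive operands, while std::fmod truncates toward zero; std::max/std::min are defined via operator< and therefore mishandle NaN depending on argument order, while std::fmax/std::fmin return the non-NaN operand.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>

int main() {
  // Remainder semantics: std::remainder(5.5, 2.0) rounds 5.5/2.0 = 2.75 to the
  // nearest integer (3) and returns 5.5 - 3*2.0 = -0.5, whereas std::fmod
  // truncates the quotient toward zero and returns 1.5, matching kRemainder.
  std::cout << std::remainder(5.5, 2.0) << "\n";  // -0.5
  std::cout << std::fmod(5.5, 2.0) << "\n";       //  1.5

  const double nan = std::numeric_limits<double>::quiet_NaN();
  // std::max(a, b) returns (a < b) ? b : a. Every comparison with NaN is
  // false, so the result silently depends on which argument holds the NaN.
  std::cout << std::max(nan, 1.0) << "\n";   // nan
  std::cout << std::max(1.0, nan) << "\n";   // 1
  // std::fmax/std::fmin return the non-NaN operand regardless of order.
  std::cout << std::fmax(nan, 1.0) << "\n";  // 1
  std::cout << std::fmax(1.0, nan) << "\n";  // 1
  return 0;
}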

PiperOrigin-RevId: 157174708
Kay Zhu 2017-05-25 17:32:49 -07:00 committed by TensorFlower Gardener
parent 2b546d85db
commit 405f70c6de
12 changed files with 571 additions and 371 deletions


@ -1071,6 +1071,7 @@ template <typename NativeT>
ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions,
stride_config.step, init_function);
} else {
// For scalars.
data.at(0) = generator({});
}
return Status::OK();


@ -807,7 +807,9 @@ TEST_F(LiteralUtilTest, Populate) {
std::vector<int64> layout;
} populate_data[] = {
{{}, {}},
{{0}, {0}},
{{16}, {0}},
{{2, 0}, {1, 0}},
{{4, 16}, {1, 0}},
{{21, 12}, {0, 1}},
{{6, 11, 17}, {2, 0, 1}},


@ -98,10 +98,12 @@ cc_test(
":hlo_evaluator",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
@ -1447,7 +1449,9 @@ cc_library(
hdrs = ["hlo_constant_folding.h"],
deps = [
":hlo",
":hlo_evaluator",
":hlo_pass",
":hlo_query",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",


@ -24,230 +24,57 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
namespace {
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
static std::unique_ptr<Literal> ConvertIfTypesMatch(
const Literal& src_literal) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
return LiteralUtil::Convert<
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
typename primitive_util::PrimitiveTypeToNative<
primitive_dest_type>::type>(src_literal);
}
template <PrimitiveType primitive_src_type>
static std::unique_ptr<Literal> ConvertIfDestTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
case (type): \
return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
CONVERT_IF_TYPES_MATCH(PRED)
CONVERT_IF_TYPES_MATCH(S8)
CONVERT_IF_TYPES_MATCH(S32)
CONVERT_IF_TYPES_MATCH(S64)
CONVERT_IF_TYPES_MATCH(U8)
CONVERT_IF_TYPES_MATCH(U32)
CONVERT_IF_TYPES_MATCH(U64)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
// Other types are not yet supported.
default:
LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
<< PrimitiveType_Name(src_literal.shape().element_type());
}
}
static std::unique_ptr<Literal> ConvertIfSrcTypeMatches(
const Literal& src_literal, PrimitiveType primitive_dest_type) {
switch (src_literal.shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
case (type): \
return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
CONVERT_IF_DEST_TYPE_MATCHES(PRED)
CONVERT_IF_DEST_TYPE_MATCHES(S8)
CONVERT_IF_DEST_TYPE_MATCHES(S32)
CONVERT_IF_DEST_TYPE_MATCHES(S64)
CONVERT_IF_DEST_TYPE_MATCHES(U8)
CONVERT_IF_DEST_TYPE_MATCHES(U32)
CONVERT_IF_DEST_TYPE_MATCHES(U64)
CONVERT_IF_DEST_TYPE_MATCHES(F32)
CONVERT_IF_DEST_TYPE_MATCHES(F64)
#undef CONVERT_IF_DEST_TYPE_MATCHES
// Other types are not yet supported.
default:
LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
<< PrimitiveType_Name(src_literal.shape().element_type());
}
}
} // namespace
// ConstantFolderVisitor traverses the HLO computation and reduces certain
// constant graph sections, to literals.
class ConstantFolderVisitor : public DfsHloVisitorWithDefault {
public:
// Default visitor action is to do nothing and return OK.
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
return Status::OK();
}
Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status HandleConvert(HloInstruction* convert,
HloInstruction* operand) override;
Status HandleReshape(HloInstruction* reshape) override;
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
Status HandleTranspose(HloInstruction* transpose) override;
// Returns whether a constant folding operation has occurred.
const bool changed() const { return changed_; }
// Runs the visitor on a computation and returns whether any changes were
// performed.
static StatusOr<bool> Run(HloComputation* computation);
private:
ConstantFolderVisitor() = default;
// Replaces the existing HLO instruction old_instruction, with a literal,
// and marks the optimizer status as changed.
// Returns the Status representing the result of the replace operation.
Status ReplaceWithConstant(HloInstruction* old_instruction,
std::unique_ptr<Literal> literal) {
TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
old_instruction, HloInstruction::CreateConstant(std::move(literal))));
changed_ = true;
return Status::OK();
}
// Whether any constant folding operations have occurred.
bool changed_ = false;
};
StatusOr<bool> ConstantFolderVisitor::Run(HloComputation* computation) {
ConstantFolderVisitor visitor;
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.changed();
}
StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
auto evaluator = MakeUnique<HloEvaluator>();
XLA_VLOG_LINES(2,
"HloConstantFolding::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto& comp : module->computations()) {
TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get()));
changed = changed || result;
for (auto& computation : module->computations()) {
for (auto instruction : computation->MakeInstructionPostOrder()) {
// Skip dead code.
if (instruction->user_count() == 0 &&
computation->root_instruction() != instruction) {
continue;
}
// Skip Constant and Parameter operation.
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant) {
continue;
}
// Skip instructions with non-constant operands.
if (!hlo_query::AllOperandsAreConstants(*instruction)) {
continue;
}
std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
// Currently we skip unimplemented operations.
// TODO(b/35975797): Fold constant computations for more operations.
if (result == nullptr) {
VLOG(2) << "Constant folding failed for instruction: "
<< instruction->ToString();
continue;
}
TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateConstant(std::move(result))));
changed = true;
}
}
XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
return changed;
}
Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) {
if (reshape->operand(0)->opcode() == HloOpcode::kConstant) {
TF_ASSIGN_OR_RETURN(
auto reshaped_literal,
LiteralUtil::Reshape(reshape->operand(0)->literal(),
AsInt64Slice(reshape->shape().dimensions())));
return ReplaceWithConstant(reshape, std::move(reshaped_literal));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) {
if (transpose->operand(0)->opcode() == HloOpcode::kConstant) {
auto transposed_literal = LiteralUtil::Transpose(
transpose->operand(0)->literal(), transpose->dimensions());
return ReplaceWithConstant(transpose, std::move(transposed_literal));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
if (operands[0]->opcode() == HloOpcode::kConstant) {
// If all the operands of a concatenate are constant, fold them into a
// single constant tensor.
// The result concatenate dimension is going to be the sum of all the
// concatenate dimensions of the arrays taking part of the operation.
int64 concat_dim = concatenate->dimensions()[0];
const Shape& reference_shape = operands[0]->shape();
CHECK(!ShapeUtil::IsTuple(reference_shape));
int64 rank = ShapeUtil::Rank(reference_shape);
std::vector<int64> concat_dimensions(reference_shape.dimensions().begin(),
reference_shape.dimensions().end());
if (concat_dim < 0) {
concat_dim += rank;
}
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
CHECK(!ShapeUtil::IsTuple(operand_shape));
if (operands[i]->opcode() != HloOpcode::kConstant) {
return Status::OK();
}
// Accumulate the concat dimension from all tensors taking part to the
// operation.
concat_dimensions[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
auto literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
std::vector<int64> source_indices(rank, 0);
std::vector<int64> dest_indices(concat_dimensions.size(), 0);
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
operand->literal(), source_indices, literal.get(), dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
return ReplaceWithConstant(concatenate, std::move(literal));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice,
HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kConstant) {
const Shape& shape = slice->shape();
auto literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
std::vector<int64> dest_indices(slice->slice_starts().size(), 0);
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
operand->literal(), slice->slice_starts(), literal.get(), dest_indices,
AsInt64Slice(shape.dimensions())));
TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal)));
}
return Status::OK();
}
Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert,
HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kConstant) {
const Literal& src_literal = operand->literal();
std::unique_ptr<Literal> new_constant =
ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type());
return ReplaceWithConstant(convert, std::move(new_constant));
}
return Status::OK();
}
} // namespace xla


@ -46,6 +46,89 @@ limitations under the License.
namespace xla {
namespace {
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
const Literal& lhs_literal,
const Literal& rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el == rhs_el;
};
break;
case HloOpcode::kNe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el != rhs_el;
};
break;
case HloOpcode::kGe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el >= rhs_el;
};
break;
case HloOpcode::kGt:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el > rhs_el;
};
break;
case HloOpcode::kLe:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el <= rhs_el;
};
break;
case HloOpcode::kLt:
compare_op = [](OperandT lhs_el, OperandT rhs_el) {
return lhs_el < rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
<< HloOpcodeString(opcode);
}
auto result = LiteralUtil::CreateFromShape(shape);
TF_RETURN_IF_ERROR(LiteralUtil::Populate<bool>(
result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return compare_op(LiteralUtil::Get<OperandT>(lhs_literal, multi_index),
LiteralUtil::Get<OperandT>(rhs_literal, multi_index));
}));
return std::move(result);
}
template <typename ReturnT, typename NativeT>
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
HloInstruction* instruction,
const std::function<ReturnT(NativeT)>& unary_op,
const Literal& operand_literal) {
const auto shape = instruction->shape();
const auto* operand = instruction->operand(0);
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s",
ShapeUtil::HumanString(shape).c_str(),
ShapeUtil::HumanString(operand->shape()).c_str());
}
auto result = LiteralUtil::CreateFromShape(shape);
TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return unary_op(
LiteralUtil::Get<NativeT>(operand_literal, multi_index));
}));
return std::move(result);
}
} // namespace
template <typename ReturnT>
class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
public:
@ -68,7 +151,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return elem_operand;
}));
return Status::OK();
};
}
template <
typename NativeT,
@ -79,7 +162,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return std::abs(elem_operand);
}));
return Status::OK();
};
}
Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override {
return HandleAbs<ReturnT>(abs, operand);
@ -101,6 +184,45 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
template <PrimitiveType src_type, PrimitiveType dest_type>
std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
DCHECK_EQ(src_type, src_literal.shape().element_type());
return LiteralUtil::Convert<
typename primitive_util::PrimitiveTypeToNative<src_type>::type,
typename primitive_util::PrimitiveTypeToNative<dest_type>::type>(
src_literal);
}
Status HandleConvert(HloInstruction* convert,
HloInstruction* operand) override {
auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
switch (operand->shape().element_type()) {
#define CONVERT_IF_TYPES_MATCH(src_type) \
case (src_type): \
parent_->evaluated_[convert] = LiteralUtil::Convert< \
typename primitive_util::PrimitiveTypeToNative<src_type>::type, \
ReturnT>(operand_literal); \
break;
CONVERT_IF_TYPES_MATCH(PRED)
CONVERT_IF_TYPES_MATCH(S8)
CONVERT_IF_TYPES_MATCH(S32)
CONVERT_IF_TYPES_MATCH(S64)
CONVERT_IF_TYPES_MATCH(U8)
CONVERT_IF_TYPES_MATCH(U32)
CONVERT_IF_TYPES_MATCH(U64)
CONVERT_IF_TYPES_MATCH(F32)
CONVERT_IF_TYPES_MATCH(F64)
#undef CONVERT_IF_TYPES_MATCH
// Other types are not yet supported.
default:
LOG(FATAL) << "unimplemented operand type for HandleCovert: "
<< PrimitiveType_Name(operand->shape().element_type());
}
return Status::OK();
}
Status HandleExp(HloInstruction* exp, HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
ElementWiseUnaryOp(exp, [](ReturnT elem_operand) {
@ -117,15 +239,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
Status HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite],
ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) {
return std::isfinite(elem_operand);
}));
return Status::OK();
};
Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ReturnT elem_operand) {
@ -209,77 +322,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) override {
std::function<bool(ReturnT, ReturnT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el == rhs_el;
};
break;
case HloOpcode::kNe:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el != rhs_el;
};
break;
case HloOpcode::kGe:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el >= rhs_el;
};
break;
case HloOpcode::kGt:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el > rhs_el;
};
break;
case HloOpcode::kLe:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el <= rhs_el;
};
break;
case HloOpcode::kLt:
compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
return lhs_el < rhs_el;
};
break;
default:
LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
<< HloOpcodeString(opcode);
}
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
return Unimplemented(
"Compare operation with mismatched dimensions, likely due to "
"broadcasting is unsupported.");
}
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
auto result = LiteralUtil::CreateFromShape(compare->shape());
std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
do {
LiteralUtil::Set<bool>(
result.get(), multi_index,
compare_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
LiteralUtil::Get<ReturnT>(rhs_literal, multi_index)));
} while (IndexUtil::BumpIndices(result->shape(), &multi_index));
parent_->evaluated_[compare] = std::move(result);
return Status::OK();
};
Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
HloInstruction* rhs) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[maximum],
ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) {
return std::max(lhs, rhs);
return std::fmax(lhs, rhs);
}));
return Status::OK();
};
@ -289,7 +337,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[minimum],
ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) {
return std::min(lhs_el, rhs_el);
return std::fmin(lhs_el, rhs_el);
}));
return Status::OK();
};
@ -309,7 +357,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[remainder],
ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) {
return std::remainder(lhs_el, rhs_el);
return std::fmod(lhs_el, rhs_el);
}));
return Status::OK();
};
@ -338,7 +386,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
HloInstruction* arg, HloInstruction* max) override {
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
[](ReturnT low, ReturnT high, ReturnT value) {
return std::max(low, std::min(value, high));
return std::fmax(low, std::fmin(value, high));
};
TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp],
ElementWiseTernaryOp(clamp, std::move(clamp_op)));
@ -370,32 +418,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ReturnT(ReturnT)>& unary_op) {
const auto shape = instruction->shape();
const auto* operand = instruction->operand(0);
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
if (!ShapeUtil::SameDimensions(shape, operand->shape())) {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s",
ShapeUtil::HumanString(shape).c_str(),
ShapeUtil::HumanString(operand->shape()).c_str());
}
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto result = LiteralUtil::CreateFromShape(shape);
std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
do {
LiteralUtil::Set<ReturnT>(
result.get(), multi_index,
unary_op(LiteralUtil::Get<ReturnT>(operand_literal, multi_index)));
} while (IndexUtil::BumpIndices(result->shape(), &multi_index));
return std::move(result);
};
const Literal& operand_literal =
parent_->GetEvaluatedLiteralFor(instruction->operand(0));
return ElementWiseUnaryOpImpl<ReturnT, ReturnT>(instruction, unary_op,
operand_literal);
}
StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
HloInstruction* instruction,
@ -420,16 +447,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
auto result = LiteralUtil::CreateFromShape(shape);
std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
do {
LiteralUtil::Set<ReturnT>(
result.get(), multi_index,
binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
LiteralUtil::Get<ReturnT>(rhs_literal, multi_index)));
} while (IndexUtil::BumpIndices(result->shape(), &multi_index));
TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return binary_op(LiteralUtil::Get<ReturnT>(lhs_literal, multi_index),
LiteralUtil::Get<ReturnT>(rhs_literal, multi_index));
}));
return std::move(result);
};
}
template <typename LhsType, typename RhsType, typename EhsType>
StatusOr<std::unique_ptr<Literal>> ElementWiseTernaryOp(
@ -459,17 +484,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
auto result = LiteralUtil::CreateFromShape(shape);
std::vector<int64> multi_index(ShapeUtil::Rank(result->shape()), 0);
do {
LiteralUtil::Set<ReturnT>(
result.get(), multi_index,
ternary_op(LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
LiteralUtil::Get<EhsType>(ehs_literal, multi_index)));
} while (IndexUtil::BumpIndices(result->shape(), &multi_index));
TF_RETURN_IF_ERROR(LiteralUtil::Populate<ReturnT>(
result.get(), [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
return ternary_op(
LiteralUtil::Get<LhsType>(lhs_literal, multi_index),
LiteralUtil::Get<RhsType>(rhs_literal, multi_index),
LiteralUtil::Get<EhsType>(ehs_literal, multi_index));
}));
return std::move(result);
};
}
HloEvaluator* parent_;
};
@ -493,6 +518,12 @@ HloEvaluator::HloEvaluator() {
});
typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: TUPLE.");
});
typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
return Unimplemented("unhandled primitive type: OPAQUE.");
});
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
@ -502,15 +533,15 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
evaluated_.clear();
TF_RETURN_IF_ERROR(computation->Accept(this));
return std::move(FindOrDie(evaluated_, computation->root_instruction()));
return MakeUnique<Literal>(
GetEvaluatedLiteralFor(computation->root_instruction()));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const Literal*> operands) {
DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
Shape shape = instruction->shape();
TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
arg_literals_ = operands;
evaluated_.clear();
@ -525,13 +556,34 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
evaluated_[operand] = MakeUnique<Literal>(*input_literal);
} else if (operand->opcode() == HloOpcode::kConstant) {
evaluated_[operand] = MakeUnique<Literal>(operand->literal());
}
}
TF_RETURN_IF_ERROR(instruction->Visit(this));
return std::move(FindOrDie(evaluated_, instruction));
return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloInstruction* instruction) {
TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction));
TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter);
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
arg_literals_.clear();
evaluated_.clear();
TF_RETURN_IF_ERROR(instruction->Visit(this));
return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
}
std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
HloInstruction* instruction) {
auto result_or = Evaluate(instruction);
if (!result_or.ok()) {
LOG(ERROR) << "TryEvaluate failed:" << result_or.status();
return nullptr;
}
return result_or.ConsumeValueOrDie();
}
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
@ -548,9 +600,191 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
Status HloEvaluator::HandleConstant(HloInstruction* constant,
const Literal& literal) {
VLOG(2) << "HandleConstant: " << constant->ToString();
DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape()));
return Status::OK();
}
evaluated_[constant] = MakeUnique<Literal>(literal);
Status HloEvaluator::HandleReshape(HloInstruction* reshape) {
TF_ASSIGN_OR_RETURN(
evaluated_[reshape],
LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)),
AsInt64Slice(reshape->shape().dimensions())));
return Status::OK();
}
Status HloEvaluator::HandleTranspose(HloInstruction* transpose) {
evaluated_[transpose] = LiteralUtil::Transpose(
GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions());
return Status::OK();
}
Status HloEvaluator::HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
// The result concatenate dimension is going to be the sum of all concatenate
// dimensions of the operands taking part in the operation.
const Shape& reference_shape = operands[0]->shape();
CHECK(!ShapeUtil::IsTuple(reference_shape));
const int64 rank = ShapeUtil::Rank(reference_shape);
const int64 concat_dim = concatenate->dimensions()[0];
CHECK_GE(concat_dim, 0);
CHECK_LT(concat_dim, rank);
DimensionVector concat_dimensions(reference_shape.dimensions().begin(),
reference_shape.dimensions().end());
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
CHECK(!ShapeUtil::IsTuple(operand_shape));
// Accumulate the concat dimension from all tensors taking part in the
// operation.
concat_dimensions[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
auto result_literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
DimensionVector source_indices(rank, 0);
DimensionVector dest_indices(concat_dimensions.size(), 0);
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(),
dest_indices, AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
evaluated_[concatenate] = std::move(result_literal);
return Status::OK();
}
Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) {
if (!ShapeUtil::ElementIsFloating(operand->shape())) {
return InvalidArgument(
"expected element type in shape to be float for IsFinite op, got: %s",
PrimitiveType_Name(operand->shape().element_type()).c_str());
}
switch (operand->shape().element_type()) {
case F16:
return Unimplemented("unhandled primitive type: F16.");
case F32: {
auto result_or = ElementWiseUnaryOpImpl<bool, float>(
is_finite,
[](float elem_operand) { return std::isfinite(elem_operand); },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
break;
}
case F64: {
auto result_or = ElementWiseUnaryOpImpl<bool, double>(
is_finite,
[](double elem_operand) { return std::isfinite(elem_operand); },
GetEvaluatedLiteralFor(operand));
TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or));
break;
}
default:
LOG(FATAL) << "unknown/unhandled primitive type.";
}
return Status::OK();
}
Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) {
// TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
// removed.
if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) &&
ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) {
return Unimplemented(
"Implicit broadcasting is currently unsupported in HLO evaluator "
"Shape Mismatch: %s vs %s vs %s",
ShapeUtil::HumanString(compare->shape()).c_str(),
ShapeUtil::HumanString(lhs->shape()).c_str(),
ShapeUtil::HumanString(rhs->shape()).c_str());
}
TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type());
const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs);
// Note here we switch on the operand's type.
switch (lhs->shape().element_type()) {
case PRED: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<bool>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case U8: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint8>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case U16:
return Unimplemented("unhandled primitive type: U16.");
case U32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint32>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case U64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<uint64>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case S8: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int8>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case S16:
return Unimplemented("unhandled primitive type: S16.");
case S32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int32>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case S64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<int64>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case F16:
return Unimplemented("unhandled primitive type: F16.");
case F32: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<float>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
case F64: {
TF_ASSIGN_OR_RETURN(
evaluated_[compare],
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
default:
LOG(FATAL) << "unknown primitive type.";
}
return Status::OK();
}
Status HloEvaluator::HandleSlice(HloInstruction* slice,
HloInstruction* operand) {
const Shape& shape = slice->shape();
auto literal = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
DimensionVector dest_indices(slice->slice_starts().size(), 0);
TF_RETURN_IF_ERROR(LiteralUtil::Copy(
GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(),
dest_indices, AsInt64Slice(shape.dimensions())));
evaluated_[slice] = std::move(literal);
return Status::OK();
}


@ -57,21 +57,32 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Evaluates a single HLO instruction and an array of pointers to literals.
// Return the evaluated result as literal if successful.
// Precondition:
// 1. argument literals are corresponds to the input instruction's
// parameters in their post-orderring.
// 1. argument literals correspond to the input instruction's parameters in
// their post-ordering.
// 2. the instruction's operands must be of either Parameter or Constant type.
// TODO(b/35950897): implement more ops other than element-wise ops.
StatusOr<std::unique_ptr<Literal>> Evaluate(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
// Evaluates a single HLO instruction with constant operands.
// Returns the evaluated result as literal if successful.
// Precondition:
// 1. all operands of the input instruction are constants.
// 2. the instruction is not a Parameter operation.
StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
// Same as Evaluate, except returning nullptr on error.
std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
protected:
// Templated DfsHloVisitor. Typically ReturnT here indicates the resulting
// literal type of each evaluated Handle* method of a TypedVisitor. One
// exception to this is HandleCompare, where the resulting literal type is
// literal type of each evaluated Handle* method of a TypedVisitor.
// There are, however, a few notable exceptions to this rule:
// - HandleCompare and HandleIsFinite: where the resulting literal type is
// always boolean.
// Note the forward declaration here is necessary to enable TypedVisitor to
// access parent members.
// These operations are handled outside of the parent HloEvaluator handlers
// instead of from within TypedVisitor.
template <typename ReturnT>
class TypedVisitor;
@ -81,15 +92,38 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
}
// Operations that are type-agnostic.
//
Status HandleParameter(HloInstruction* parameter) override;
Status HandleConstant(HloInstruction* constant,
const Literal& literal) override;
Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
Status HandleReshape(HloInstruction* reshape) override;
Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
Status HandleTranspose(HloInstruction* transpose) override;
Status HandleIsFinite(HloInstruction* is_finite,
HloInstruction* operand) override;
Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) override;
private:
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
// Crash with log if the given instruction has not been evaluated previously.
const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
if (hlo->IsConstant()) {
return hlo->literal();
}
auto it = evaluated_.find(hlo);
CHECK(it != evaluated_.end())
<< "could not find evaluated value for: " << hlo->ToString();


@ -23,8 +23,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
@ -143,7 +145,7 @@ TEST_F(HloEvaluatorTest, DoesDivide) {
// element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest, DoesAbs) {
auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2});
auto c1 = HloInstruction::CreateConstant(std::move(operand));
auto instruction =
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get());
@ -154,7 +156,29 @@ TEST_F(HloEvaluatorTest, DoesAbs) {
auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}
// For R0 literal.
const Shape& r0 = ShapeUtil::MakeShape(F32, {});
operand = LiteralUtil::CreateR0<float>(-1.0f);
c1 = HloInstruction::CreateConstant(std::move(operand));
instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get());
result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
expected = LiteralUtil::CreateR0<float>(1.0f);
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
// For R1 literal with dimension of size 0.
Shape empty_r1 = ShapeUtil::MakeShape(F32, {0});
operand = LiteralUtil::CreateR1<float>({});
c1 = HloInstruction::CreateConstant(std::move(operand));
instruction =
HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get());
result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
expected = LiteralUtil::CreateR1<float>({});
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}
// Verifies that HloEvaluator evaluates an HLO computation containing operands
// that are neither parameters nor constants.
@ -187,5 +211,35 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) {
EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
}
// Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest, DoesReshape) {
HloComputation::Builder builder(
::testing::UnitTest::GetInstance()->current_test_info()->name());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSIGN_OR_ASSERT_OK(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = LiteralUtil::CloneToUnique(*literal);
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
const int64 permutation[] = {1, 2, 0, 4, 3};
builder.AddInstruction(
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
LiteralUtil::EachCell<NativeT>(
*result, [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
std::vector<int64> rindexes = Permute(permutation, indices);
EXPECT_TRUE(value ==
LiteralUtil::Get<NativeT>(*literal_clone, rindexes));
});
}
} // namespace
} // namespace xla


@ -756,27 +756,28 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
// and unmodified_dim_pair have size >1. Otherwise, returns true and appends
// the degenerate input/output dimensions in the gap to
// deleted_indices/inserted_indices respectively.
auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices,
&inserted_indices](
std::pair<int64, int64> prior_unmodified_dim_pair,
std::pair<int64, int64> unmodified_dim_pair) {
for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) {
if (shape_pre.dimensions(modified_input_dim) > 1) {
return false;
}
deleted_indices.push_back(modified_input_dim);
}
for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
modified_output_dim < unmodified_dim_pair.second;
++modified_output_dim) {
if (shape_post.dimensions(modified_output_dim) > 1) {
return false;
}
inserted_indices.push_back(modified_output_dim);
}
return true;
};
auto check_modified_dims =
[&shape_pre, &shape_post, &deleted_indices, &inserted_indices](
std::pair<int64, int64> prior_unmodified_dim_pair,
std::pair<int64, int64> unmodified_dim_pair) {
for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
modified_input_dim < unmodified_dim_pair.first;
++modified_input_dim) {
if (shape_pre.dimensions(modified_input_dim) > 1) {
return false;
}
deleted_indices.push_back(modified_input_dim);
}
for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
modified_output_dim < unmodified_dim_pair.second;
++modified_output_dim) {
if (shape_post.dimensions(modified_output_dim) > 1) {
return false;
}
inserted_indices.push_back(modified_output_dim);
}
return true;
};
std::vector<std::pair<int64, int64>> unmodified_dims =
DimensionsUnmodifiedByReshape(shape_pre, shape_post);
@ -1189,6 +1190,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
tensorflow::gtl::ArraySlice<int64> count,
tensorflow::gtl::ArraySlice<int64> incr,
const IndexVisitorFunction& visitor_function) {
if (ShapeUtil::HasZeroElements(shape)) {
return;
}
DCHECK_EQ(Rank(shape), base.size());
DCHECK_EQ(incr.size(), base.size());
DCHECK_EQ(count.size(), base.size());


@ -446,6 +446,34 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
}
TEST(ShapeUtilTest, ForEachIndex) {
struct ShapeDimensionAndNumberInvocations {
std::vector<int64> dimensions;
int invocations;
} test_data[] = {
{{}, 1}, {{0}, 0}, {{16}, 16}, {{3, 0}, 0},
{{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610},
};
for (const auto& data : test_data) {
Shape shape = ShapeUtil::MakeShape(F32, data.dimensions);
// Increments at every invocation.
int invocations = 0;
auto increment_func = [&invocations](const std::vector<int64>& indexes) {
invocations++;
return true;
};
std::vector<int64> zero_base(data.dimensions.size(), 0);
std::vector<int64> step(data.dimensions.size(), 1);
ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step,
increment_func);
EXPECT_EQ(invocations, data.invocations);
}
}
TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
// All output dimensions should be unmodified. One of the input dimensions is
// modified because the input rank is larger by one.


@ -46,12 +46,16 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
ClientLibraryTestBase::ClientLibraryTestBase(
se::Platform* platform,
tensorflow::gtl::ArraySlice<string> disabled_pass_names)
tensorflow::gtl::ArraySlice<string> disabled_pass_names,
bool disable_constant_folding)
: client_(GetOrCreateLocalClientOrDie(platform)) {
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
flags->xla_disable_hlo_passes =
tensorflow::str_util::Join(disabled_pass_names, ",");
if (disable_constant_folding) {
flags->xla_disable_hlo_passes += ",constant_folding";
}
}
string ClientLibraryTestBase::TestName() const {


@ -47,7 +47,14 @@ class ClientLibraryTestBase : public ::testing::Test {
protected:
explicit ClientLibraryTestBase(
perftools::gputools::Platform* platform = nullptr,
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {},
// Note: here we are disabling constant_folding paths so that the tests
// (usually written using Constants) will exercise the intended code
// paths, instead of being constant folded.
// TODO(b/38354253): Constant folding is currently disabled here. Change
// tests to Parameters instead of Constants, and re-enable constant
// folding by default.
bool disable_constant_folding = true);
// Returns the name of the test currently being run.
string TestName() const;


@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"