From 405f70c6de296faf7561d6739e947d0fb9b26f14 Mon Sep 17 00:00:00 2001
From: Kay Zhu
Date: Thu, 25 May 2017 17:32:49 -0700
Subject: [PATCH] [XLA] Enable HloEvaluator for constant folding, and merge a
 few operations from hlo_constant_folding into hlo_evaluator.

Additionally:
- In ShapeUtil::ForEachIndex:
  * Fix a bug where the visitor is called even when the shape has zero
    elements (e.g., F32{1,0}).
  * Add a test case for ForEachIndex.

- In HloEvaluator:
  * Instead of copying and caching a Constant instruction, return its literal
    directly if the instruction is constant.
  * Fix an issue where the TUPLE and OPAQUE primitive types are not keyed in
    the templated typed_visitor.
  * Use the (fixed) LiteralUtil::Populate to populate the resulting literal,
    fixing a preexisting bug in the evaluator where R0 shapes and shapes with
    zero-sized dimensions are not handled.
  * Refactor ElementWiseUnaryOp and HandleCompare to be templatized on the
    operand's type.
  * Refactor IsFinite to be a top-level handler, since it only applies to
    floats and its return type is always boolean.
  * Change std::remainder to std::fmod for kRemainder, to be compliant with
    existing XLA behavior.
  * Change std::max and std::min to std::fmax and std::fmin, to handle NaNs.
  * Minor comment fixes.

- Disable constant_folding and reshape-motion for ClientLibraryTestBase so
  that constant folding does not affect the code paths intended to be
  exercised by the tests. In the longer term we plan to change all Constants
  to Parameters and re-enable constant_folding in tests.

PiperOrigin-RevId: 157174708
---
 tensorflow/compiler/xla/literal_util.h        |   1 +
 tensorflow/compiler/xla/literal_util_test.cc  |   2 +
 tensorflow/compiler/xla/service/BUILD         |   4 +
 .../xla/service/hlo_constant_folding.cc       | 243 ++-------
 .../compiler/xla/service/hlo_evaluator.cc     | 498 +++++++++++++-----
 .../compiler/xla/service/hlo_evaluator.h      |  46 +-
 .../xla/service/hlo_evaluator_test.cc         |  58 +-
 tensorflow/compiler/xla/shape_util.cc         |  46 +-
 tensorflow/compiler/xla/shape_util_test.cc    |  28 +
 .../xla/tests/client_library_test_base.cc     |   6 +-
 .../xla/tests/client_library_test_base.h      |   9 +-
 .../compiler/xla/tests/literal_test_util.h    |   1 +
 12 files changed, 571 insertions(+), 371 deletions(-)

diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index a05dc968ee5..0f4b077a300 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -1071,6 +1071,7 @@ template ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions, stride_config.step, init_function); } else { + // For scalars.
data.at(0) = generator({}); } return Status::OK(); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 9a09822174d..f8f22e2c398 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -807,7 +807,9 @@ TEST_F(LiteralUtilTest, Populate) { std::vector layout; } populate_data[] = { {{}, {}}, + {{0}, {0}}, {{16}, {0}}, + {{2, 0}, {1, 0}}, {{4, 16}, {1, 0}}, {{21, 12}, {0, 1}}, {{6, 11, 17}, {2, 0, 1}}, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0c1bab0463a..dca39285a3b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -98,10 +98,12 @@ cc_test( ":hlo_evaluator", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test_main", ], @@ -1447,7 +1449,9 @@ cc_library( hdrs = ["hlo_constant_folding.h"], deps = [ ":hlo", + ":hlo_evaluator", ":hlo_pass", + ":hlo_query", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index cb0a99d773c..762ceebf39d 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -24,230 +24,57 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { -namespace { - -template -static std::unique_ptr ConvertIfTypesMatch( - const Literal& src_literal) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< - typename primitive_util::PrimitiveTypeToNative::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); -} - -template -static std::unique_ptr ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. 
- default: - LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -static std::unique_ptr ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -} // namespace - -// ConstantFolderVisitor traverses the HLO computation and reduces certain -// constant graph sections, to literals. -class ConstantFolderVisitor : public DfsHloVisitorWithDefault { - public: - // Default visitor action is to do nothing and return OK. - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) override; - - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; - - Status HandleReshape(HloInstruction* reshape) override; - - Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; - - Status HandleTranspose(HloInstruction* transpose) override; - - // Returns whether a constant folding operation has occurred. - const bool changed() const { return changed_; } - - // Runs the visitor on a computation and returns whether any changes were - // performed. - static StatusOr Run(HloComputation* computation); - - private: - ConstantFolderVisitor() = default; - - // Replaces the existing HLO instruction old_instruction, with a literal, - // and marks the optimizer status as changed. - // Returns the Status representing the result of the replace operation. - Status ReplaceWithConstant(HloInstruction* old_instruction, - std::unique_ptr literal) { - TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( - old_instruction, HloInstruction::CreateConstant(std::move(literal)))); - changed_ = true; - return Status::OK(); - } - - // Whether any constant folding operations have occurred. - bool changed_ = false; -}; - -StatusOr ConstantFolderVisitor::Run(HloComputation* computation) { - ConstantFolderVisitor visitor; - TF_RETURN_IF_ERROR(computation->Accept(&visitor)); - return visitor.changed(); -} StatusOr HloConstantFolding::Run(HloModule* module) { + auto evaluator = MakeUnique(); + XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); bool changed = false; - for (auto& comp : module->computations()) { - TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get())); - changed = changed || result; + + for (auto& computation : module->computations()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { + // Skip dead code. + if (instruction->user_count() == 0 && + computation->root_instruction() != instruction) { + continue; + } + // Skip Constant and Parameter operation. 
+ if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant) { + continue; + } + // Skip instructions with non-constant operands. + if (!hlo_query::AllOperandsAreConstants(*instruction)) { + continue; + } + + std::unique_ptr result = evaluator->TryEvaluate(instruction); + // Currently we skip unimplemented operations. + // TODO(b/35975797): Fold constant computations for more operations. + if (result == nullptr) { + VLOG(2) << "Constant folding failed for instruction: " + << instruction->ToString(); + continue; + } + + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + instruction, HloInstruction::CreateConstant(std::move(result)))); + changed = true; + } } XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString()); return changed; } -Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) { - if (reshape->operand(0)->opcode() == HloOpcode::kConstant) { - TF_ASSIGN_OR_RETURN( - auto reshaped_literal, - LiteralUtil::Reshape(reshape->operand(0)->literal(), - AsInt64Slice(reshape->shape().dimensions()))); - return ReplaceWithConstant(reshape, std::move(reshaped_literal)); - } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) { - if (transpose->operand(0)->opcode() == HloOpcode::kConstant) { - auto transposed_literal = LiteralUtil::Transpose( - transpose->operand(0)->literal(), transpose->dimensions()); - return ReplaceWithConstant(transpose, std::move(transposed_literal)); - } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) { - if (operands[0]->opcode() == HloOpcode::kConstant) { - // If all the operands of a concatenate are constant, fold them into a - // single constant tensor. - // The result concatenate dimension is going to be the sum of all the - // concatenate dimensions of the arrays taking part of the operation. - int64 concat_dim = concatenate->dimensions()[0]; - const Shape& reference_shape = operands[0]->shape(); - CHECK(!ShapeUtil::IsTuple(reference_shape)); - int64 rank = ShapeUtil::Rank(reference_shape); - std::vector concat_dimensions(reference_shape.dimensions().begin(), - reference_shape.dimensions().end()); - if (concat_dim < 0) { - concat_dim += rank; - } - for (int64 i = 1; i < operands.size(); ++i) { - const Shape& operand_shape = operands[i]->shape(); - CHECK(!ShapeUtil::IsTuple(operand_shape)); - if (operands[i]->opcode() != HloOpcode::kConstant) { - return Status::OK(); - } - // Accumulate the concat dimension from all tensors taking part to the - // operation. 
- concat_dimensions[concat_dim] += - ShapeUtil::GetDimension(operand_shape, concat_dim); - } - - auto literal = LiteralUtil::CreateFromDimensions( - reference_shape.element_type(), concat_dimensions); - std::vector source_indices(rank, 0); - std::vector dest_indices(concat_dimensions.size(), 0); - for (auto operand : operands) { - const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - operand->literal(), source_indices, literal.get(), dest_indices, - AsInt64Slice(operand_shape.dimensions()))); - dest_indices[concat_dim] += - ShapeUtil::GetDimension(operand_shape, concat_dim); - } - return ReplaceWithConstant(concatenate, std::move(literal)); - } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice, - HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kConstant) { - const Shape& shape = slice->shape(); - auto literal = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - std::vector dest_indices(slice->slice_starts().size(), 0); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - operand->literal(), slice->slice_starts(), literal.get(), dest_indices, - AsInt64Slice(shape.dimensions()))); - TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal))); - } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kConstant) { - const Literal& src_literal = operand->literal(); - std::unique_ptr new_constant = - ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type()); - return ReplaceWithConstant(convert, std::move(new_constant)); - } - return Status::OK(); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e0447d69aa2..17f9416197d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -46,6 +46,89 @@ limitations under the License. 
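The rewritten HloConstantFolding::Run above no longer carries per-opcode folding logic: it walks each computation in post order, skips dead code, parameters, constants, and instructions with non-constant operands, and replaces anything TryEvaluate can fold with a new constant instruction. A rough standalone sketch of that control flow, using a toy expression node rather than real HloInstructions (Node, TryEvaluate, and FoldConstants below are made-up names, not XLA APIs), might look like:

#include <cstdio>
#include <optional>
#include <vector>

// Toy stand-in for an HLO instruction: a constant leaf, a parameter, or an op.
struct Node {
  enum Kind { kConstant, kParameter, kAdd, kMul } kind;
  double value = 0;                 // Only meaningful for kConstant.
  std::vector<Node*> operands;
};

// Stand-in for HloEvaluator::TryEvaluate: returns the folded value, or
// nullopt when the node cannot be evaluated (here: any non-constant operand).
std::optional<double> TryEvaluate(const Node& n) {
  switch (n.kind) {
    case Node::kConstant: return n.value;
    case Node::kParameter: return std::nullopt;
    case Node::kAdd:  // fallthrough
    case Node::kMul: {
      double result = (n.kind == Node::kAdd) ? 0 : 1;
      for (const Node* op : n.operands) {
        if (op->kind != Node::kConstant) return std::nullopt;
        result = (n.kind == Node::kAdd) ? result + op->value
                                        : result * op->value;
      }
      return result;
    }
  }
  return std::nullopt;
}

// The folding loop: visit nodes in post order and rewrite foldable ones into
// constants, mirroring the skip-then-fold structure of the new pass.
bool FoldConstants(std::vector<Node*>& post_order) {
  bool changed = false;
  for (Node* n : post_order) {
    if (n->kind == Node::kConstant || n->kind == Node::kParameter) continue;
    std::optional<double> folded = TryEvaluate(*n);
    if (!folded.has_value()) continue;  // Unevaluable: leave it alone.
    n->kind = Node::kConstant;          // Replace the op with a constant leaf.
    n->value = *folded;
    n->operands.clear();
    changed = true;
  }
  return changed;
}

int main() {
  Node a{Node::kConstant, 2.0}, b{Node::kConstant, 3.0};
  Node mul{Node::kMul, 0, {&a, &b}};
  Node p{Node::kParameter};
  Node add{Node::kAdd, 0, {&mul, &p}};  // Not foldable: has a parameter operand.
  std::vector<Node*> post_order = {&a, &b, &mul, &p, &add};
  FoldConstants(post_order);
  std::printf("mul folded to %g, add kind=%d\n", mul.value, add.kind);
}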
namespace xla { +namespace { + +template +StatusOr> Compare(const Shape& shape, HloOpcode opcode, + const Literal& lhs_literal, + const Literal& rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el != rhs_el; + }; + break; + case HloOpcode::kGe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el >= rhs_el; + }; + break; + case HloOpcode::kGt: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el > rhs_el; + }; + break; + case HloOpcode::kLe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el <= rhs_el; + }; + break; + case HloOpcode::kLt: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el < rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + auto result = LiteralUtil::CreateFromShape(shape); + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return compare_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index)); + })); + + return std::move(result); +} + +template +StatusOr> ElementWiseUnaryOpImpl( + HloInstruction* instruction, + const std::function& unary_op, + const Literal& operand_literal) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + auto result = LiteralUtil::CreateFromShape(shape); + + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op( + LiteralUtil::Get(operand_literal, multi_index)); + })); + return std::move(result); +} + +} // namespace + template class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { public: @@ -68,7 +151,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return elem_operand; })); return Status::OK(); - }; + } template < typename NativeT, @@ -79,7 +162,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::abs(elem_operand); })); return Status::OK(); - }; + } Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override { return HandleAbs(abs, operand); @@ -101,6 +184,45 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + template + std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { + DCHECK_EQ(src_type, src_literal.shape().element_type()); + return LiteralUtil::Convert< + typename primitive_util::PrimitiveTypeToNative::type, + typename primitive_util::PrimitiveTypeToNative::type>( + src_literal); + } + + Status HandleConvert(HloInstruction* convert, + HloInstruction* operand) override { + auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); + + switch (operand->shape().element_type()) { +#define CONVERT_IF_TYPES_MATCH(src_type) \ + case (src_type): \ + parent_->evaluated_[convert] = LiteralUtil::Convert< \ + typename 
primitive_util::PrimitiveTypeToNative::type, \ + ReturnT>(operand_literal); \ + break; + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) +#undef CONVERT_IF_TYPES_MATCH + // Other types are not yet supported. + default: + LOG(FATAL) << "unimplemented operand type for HandleCovert: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); + } + Status HandleExp(HloInstruction* exp, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], ElementWiseUnaryOp(exp, [](ReturnT elem_operand) { @@ -117,15 +239,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleIsFinite(HloInstruction* is_finite, - HloInstruction* operand) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite], - ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) { - return std::isfinite(elem_operand); - })); - return Status::OK(); - }; - Status HandleLog(HloInstruction* log, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ReturnT elem_operand) { @@ -209,77 +322,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleCompare(HloInstruction* compare, HloOpcode opcode, - HloInstruction* lhs, HloInstruction* rhs) override { - std::function compare_op; - switch (opcode) { - case HloOpcode::kEq: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el == rhs_el; - }; - break; - case HloOpcode::kNe: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el != rhs_el; - }; - break; - case HloOpcode::kGe: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el >= rhs_el; - }; - break; - case HloOpcode::kGt: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el > rhs_el; - }; - break; - case HloOpcode::kLe: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el <= rhs_el; - }; - break; - case HloOpcode::kLt: - compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { - return lhs_el < rhs_el; - }; - break; - default: - LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " - << HloOpcodeString(opcode); - } - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. 
- if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Compare operation with mismatched dimensions, likely due to " - "broadcasting is unsupported."); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = LiteralUtil::CreateFromShape(compare->shape()); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - compare_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); - - parent_->evaluated_[compare] = std::move(result); - - return Status::OK(); - }; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, HloInstruction* rhs) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { - return std::max(lhs, rhs); + return std::fmax(lhs, rhs); })); return Status::OK(); }; @@ -289,7 +337,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::min(lhs_el, rhs_el); + return std::fmin(lhs_el, rhs_el); })); return Status::OK(); }; @@ -309,7 +357,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[remainder], ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::remainder(lhs_el, rhs_el); + return std::fmod(lhs_el, rhs_el); })); return Status::OK(); }; @@ -338,7 +386,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { HloInstruction* arg, HloInstruction* max) override { std::function clamp_op = [](ReturnT low, ReturnT high, ReturnT value) { - return std::max(low, std::min(value, high)); + return std::fmax(low, std::fmin(value, high)); }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], ElementWiseTernaryOp(clamp, std::move(clamp_op))); @@ -370,32 +418,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> ElementWiseUnaryOp( HloInstruction* instruction, const std::function& unary_op) { - const auto shape = instruction->shape(); - const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. 
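The switches above from std::max/std::min to std::fmax/std::fmin and from std::remainder to std::fmod change behavior exactly where the commit message says: with a NaN operand, std::max and std::min return whichever argument the ordering comparison happens to pick, while std::fmax and std::fmin return the non-NaN operand; and std::remainder rounds the quotient to the nearest integer, whereas std::fmod truncates toward zero and keeps the dividend's sign, which is the existing kRemainder behavior the change aligns with. A small self-contained demonstration:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const double nan = std::numeric_limits<double>::quiet_NaN();

  // std::max is defined via operator<, so its result with a NaN operand
  // depends on argument order; std::fmax always returns the non-NaN operand.
  std::printf("std::max(NaN, 1.0)  = %f\n", std::max(nan, 1.0));   // nan
  std::printf("std::max(1.0, NaN)  = %f\n", std::max(1.0, nan));   // 1.0
  std::printf("std::fmax(NaN, 1.0) = %f\n", std::fmax(nan, 1.0));  // 1.0
  std::printf("std::fmax(1.0, NaN) = %f\n", std::fmax(1.0, nan));  // 1.0

  // std::remainder rounds the quotient to the nearest integer, so its result
  // can take the opposite sign of the dividend; std::fmod keeps the
  // dividend's sign.
  std::printf("std::remainder(5.0, 3.0)  = %f\n", std::remainder(5.0, 3.0));   // -1.0
  std::printf("std::fmod(5.0, 3.0)       = %f\n", std::fmod(5.0, 3.0));        //  2.0
  std::printf("std::remainder(-5.0, 3.0) = %f\n", std::remainder(-5.0, 3.0));  //  1.0
  std::printf("std::fmod(-5.0, 3.0)      = %f\n", std::fmod(-5.0, 3.0));       // -2.0
  return 0;
}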
- if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); - } - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - - auto result = LiteralUtil::CreateFromShape(shape); - - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - unary_op(LiteralUtil::Get(operand_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); - - return std::move(result); - }; + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(instruction->operand(0)); + return ElementWiseUnaryOpImpl(instruction, unary_op, + operand_literal); + } StatusOr> ElementWiseBinaryOp( HloInstruction* instruction, @@ -420,16 +447,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); auto result = LiteralUtil::CreateFromShape(shape); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - binary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return binary_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index)); + })); return std::move(result); - }; + } template StatusOr> ElementWiseTernaryOp( @@ -459,17 +484,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); auto result = LiteralUtil::CreateFromShape(shape); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - ternary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index), - LiteralUtil::Get(ehs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op( + LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index), + LiteralUtil::Get(ehs_literal, multi_index)); + })); return std::move(result); - }; + } HloEvaluator* parent_; }; @@ -493,6 +518,12 @@ HloEvaluator::HloEvaluator() { }); typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: TUPLE."); + }); + typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: OPAQUE."); + }); } StatusOr> HloEvaluator::Evaluate( @@ -502,15 +533,15 @@ StatusOr> HloEvaluator::Evaluate( evaluated_.clear(); TF_RETURN_IF_ERROR(computation->Accept(this)); - return std::move(FindOrDie(evaluated_, computation->root_instruction())); + return MakeUnique( + GetEvaluatedLiteralFor(computation->root_instruction())); } StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction, tensorflow::gtl::ArraySlice operands) { - DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); - Shape shape = instruction->shape(); - 
TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_ = operands; evaluated_.clear(); @@ -525,13 +556,34 @@ StatusOr> HloEvaluator::Evaluate( TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); evaluated_[operand] = MakeUnique(*input_literal); - } else if (operand->opcode() == HloOpcode::kConstant) { - evaluated_[operand] = MakeUnique(operand->literal()); } } TF_RETURN_IF_ERROR(instruction->Visit(this)); - return std::move(FindOrDie(evaluated_, instruction)); + return MakeUnique(GetEvaluatedLiteralFor(instruction)); +} + +StatusOr> HloEvaluator::Evaluate( + HloInstruction* instruction) { + TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction)); + TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); + + arg_literals_.clear(); + evaluated_.clear(); + TF_RETURN_IF_ERROR(instruction->Visit(this)); + return MakeUnique(GetEvaluatedLiteralFor(instruction)); +} + +std::unique_ptr HloEvaluator::TryEvaluate( + HloInstruction* instruction) { + auto result_or = Evaluate(instruction); + if (!result_or.ok()) { + LOG(ERROR) << "TryEvaluate failed:" << result_or.status(); + return nullptr; + } + + return result_or.ConsumeValueOrDie(); } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { @@ -548,9 +600,191 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { Status HloEvaluator::HandleConstant(HloInstruction* constant, const Literal& literal) { VLOG(2) << "HandleConstant: " << constant->ToString(); - DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape())); + return Status::OK(); +} - evaluated_[constant] = MakeUnique(literal); +Status HloEvaluator::HandleReshape(HloInstruction* reshape) { + TF_ASSIGN_OR_RETURN( + evaluated_[reshape], + LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)), + AsInt64Slice(reshape->shape().dimensions()))); + return Status::OK(); +} + +Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { + evaluated_[transpose] = LiteralUtil::Transpose( + GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions()); + return Status::OK(); +} + +Status HloEvaluator::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + // The result concatenate dimension is going to be the sum of all concatenate + // dimensions of the operands taking part of the operation. + const Shape& reference_shape = operands[0]->shape(); + CHECK(!ShapeUtil::IsTuple(reference_shape)); + const int64 rank = ShapeUtil::Rank(reference_shape); + const int64 concat_dim = concatenate->dimensions()[0]; + CHECK_GE(concat_dim, 0); + CHECK_LT(concat_dim, rank); + + DimensionVector concat_dimensions(reference_shape.dimensions().begin(), + reference_shape.dimensions().end()); + + for (int64 i = 1; i < operands.size(); ++i) { + const Shape& operand_shape = operands[i]->shape(); + CHECK(!ShapeUtil::IsTuple(operand_shape)); + // Accumulate the concat dimension from all tensors taking part to the + // operation. 
+ concat_dimensions[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + auto result_literal = LiteralUtil::CreateFromDimensions( + reference_shape.element_type(), concat_dimensions); + DimensionVector source_indices(rank, 0); + DimensionVector dest_indices(concat_dimensions.size(), 0); + + for (auto operand : operands) { + const Shape& operand_shape = operand->shape(); + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(), + dest_indices, AsInt64Slice(operand_shape.dimensions()))); + dest_indices[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + evaluated_[concatenate] = std::move(result_literal); + return Status::OK(); +} + +Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) { + if (!ShapeUtil::ElementIsFloating(operand->shape())) { + return InvalidArgument( + "expected element type in shape to be float for IsFinite op, got: %s", + PrimitiveType_Name(operand->shape().element_type()).c_str()); + } + + switch (operand->shape().element_type()) { + case F16: + return Unimplemented("unhandled primitive type: F16."); + case F32: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](float elem_operand) { return std::isfinite(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](double elem_operand) { return std::isfinite(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "unknown/unhandled primitive type."; + } + + return Status::OK(); +} + +Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) { + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s", + ShapeUtil::HumanString(compare->shape()).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); + + const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs); + + // Note here we switch on the operand's type. 
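Because a comparison always produces a PRED result, HandleCompare cannot rely on a TypedVisitor's ReturnT; instead it keys on the operands' element type and instantiates the templated Compare helper accordingly, as the switch below shows. A standalone sketch of the same pattern, using plain vectors and made-up names (CmpOp, Compare) rather than XLA's types:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

enum class CmpOp { kEq, kNe, kLt, kLe, kGt, kGe };

// Templated on the operand element type; the result type is always bool,
// mirroring how the evaluator's Compare<OperandT> always produces PRED.
template <typename OperandT>
std::vector<bool> Compare(CmpOp op, const std::vector<OperandT>& lhs,
                          const std::vector<OperandT>& rhs) {
  std::function<bool(OperandT, OperandT)> cmp;
  switch (op) {
    case CmpOp::kEq: cmp = [](OperandT a, OperandT b) { return a == b; }; break;
    case CmpOp::kNe: cmp = [](OperandT a, OperandT b) { return a != b; }; break;
    case CmpOp::kLt: cmp = [](OperandT a, OperandT b) { return a < b; }; break;
    case CmpOp::kLe: cmp = [](OperandT a, OperandT b) { return a <= b; }; break;
    case CmpOp::kGt: cmp = [](OperandT a, OperandT b) { return a > b; }; break;
    case CmpOp::kGe: cmp = [](OperandT a, OperandT b) { return a >= b; }; break;
  }
  std::vector<bool> result(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) result[i] = cmp(lhs[i], rhs[i]);
  return result;
}

int main() {
  // The element type of the operands selects the instantiation; the output
  // is boolean regardless.
  std::vector<int32_t> a = {1, 2, 3}, b = {3, 2, 1};
  std::vector<bool> lt = Compare<int32_t>(CmpOp::kLt, a, b);
  for (bool v : lt) std::printf("%d ", v ? 1 : 0);  // prints: 1 0 0
  std::printf("\n");
}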
+ switch (lhs->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U8: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U16: + return Unimplemented("unhandled primitive type: U16."); + case U32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S8: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S16: + return Unimplemented("unhandled primitive type: S16."); + case S32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case F16: + return Unimplemented("unhandled primitive type: F16."); + case F32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case F64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + default: + LOG(FATAL) << "unknown primitive type."; + } + + return Status::OK(); +} + +Status HloEvaluator::HandleSlice(HloInstruction* slice, + HloInstruction* operand) { + const Shape& shape = slice->shape(); + auto literal = LiteralUtil::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + + DimensionVector dest_indices(slice->slice_starts().size(), 0); + + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(), + dest_indices, AsInt64Slice(shape.dimensions()))); + + evaluated_[slice] = std::move(literal); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 50cb32eb85c..e6798a35a01 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -57,21 +57,32 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. // Precondition: - // 1. argument literals are corresponds to the input instruction's - // parameters in their post-orderring. + // 1. argument literals correspond to the input instruction's parameters in + // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. // TODO(b/35950897): implement more ops other than element-wise ops. StatusOr> Evaluate( HloInstruction* instruction, tensorflow::gtl::ArraySlice arg_literals); + // Evaluates a single HLO instruction with constant operands. + // Returns the evaluated result as literal if successful. + // Precondition: + // 1. all operands of the input instruction are constants. + // 2. the instruction is not a Parameter operation. + StatusOr> Evaluate(HloInstruction* instruction); + + // Same as Evaluate, except returning nullptr on error. + std::unique_ptr TryEvaluate(HloInstruction* instruction); + protected: // Templated DfsHloVisitor. 
Typically ReturnT here indicates the resulting - // literal type of each evaluated Handle* method of a TypedVisitor. One - // exception to this is HandleCompare, where the resulting literal type is + // literal type of each evaluated Handle* method of a TypedVisitor. + // There are however a few notable exceptions to this is rule, notably: + // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. - // Note the forward declaration here is necessary to enable TypedVisitor to - // access parent members. + // These operations are handled outside of the parent HloEvaluator handlers + // instead of from within TypedVisitor. template class TypedVisitor; @@ -81,15 +92,38 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); } + // Operations that are type-agnostic. + // Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) override; + + Status HandleReshape(HloInstruction* reshape) override; + + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + + Status HandleTranspose(HloInstruction* transpose) override; + + Status HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) override; + + Status HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) override; + private: // Returns the already-evaluated literal result for the instruction. + // A Constant instruction is considered evaluated and its literal will be + // returned directly without looking up the cache. // Crash with log if the given instruction has not been evaluated previously. const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { + if (hlo->IsConstant()) { + return hlo->literal(); + } auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 443e5ad4f42..b26ece28b75 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -23,8 +23,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" @@ -143,7 +145,7 @@ TEST_F(HloEvaluatorTest, DoesDivide) { // element-wise abs op with 1 operand. 
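The GetEvaluatedLiteralFor change in the header above short-circuits constants: a constant instruction's own literal is returned directly, and only non-constant results are memoized in evaluated_. A toy illustration of that lookup pattern, with entirely hypothetical names (Expr, MiniEvaluator) standing in for the real evaluator:

#include <cstdio>
#include <cstdlib>
#include <string>
#include <unordered_map>

// Toy node: a constant carries its value inline; anything else must have been
// evaluated and cached beforehand.
struct Expr {
  bool is_constant;
  double literal;  // Valid only when is_constant is true.
  std::string name;
};

class MiniEvaluator {
 public:
  // Mirrors GetEvaluatedLiteralFor: constants bypass the cache entirely,
  // everything else must already be present in the memo table.
  double GetEvaluatedValueFor(const Expr& e) const {
    if (e.is_constant) return e.literal;  // No cache lookup needed.
    auto it = evaluated_.find(e.name);
    if (it == evaluated_.end()) {
      std::fprintf(stderr, "no evaluated value for %s\n", e.name.c_str());
      std::abort();
    }
    return it->second;
  }

  void Memoize(const Expr& e, double value) { evaluated_[e.name] = value; }

 private:
  std::unordered_map<std::string, double> evaluated_;
};

int main() {
  MiniEvaluator ev;
  Expr c{true, 4.0, "c"};
  Expr neg{false, 0.0, "neg"};
  ev.Memoize(neg, -ev.GetEvaluatedValueFor(c));  // Evaluate neg(c) once.
  std::printf("%g %g\n", ev.GetEvaluatedValueFor(c),
              ev.GetEvaluatedValueFor(neg));  // prints: 4 -4
}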
TEST_F(HloEvaluatorTest, DoesAbs) { auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); - Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(operand)); auto instruction = HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get()); @@ -154,7 +156,29 @@ TEST_F(HloEvaluatorTest, DoesAbs) { auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); -} + + // For R0 literal. + const Shape& r0 = ShapeUtil::MakeShape(F32, {}); + operand = LiteralUtil::CreateR0(-1.0f); + c1 = HloInstruction::CreateConstant(std::move(operand)); + instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get()); + result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); + expected = LiteralUtil::CreateR0(1.0f); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + + // For R1 literal with dimension of size 0. + Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); + operand = LiteralUtil::CreateR1({}); + c1 = HloInstruction::CreateConstant(std::move(operand)); + instruction = + HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get()); + + result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); + expected = LiteralUtil::CreateR1({}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} // namespace // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. @@ -187,5 +211,35 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); } +// Verifies Reshape operation is correctly evaluated. +TEST_F(HloEvaluatorTest, DoesReshape) { + HloComputation::Builder builder( + ::testing::UnitTest::GetInstance()->current_test_info()->name()); + + const int64 dimensions[] = {11, 8, 7, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = LiteralUtil::CloneToUnique(*literal); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + + Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); + const int64 permutation[] = {1, 2, 0, 4, 3}; + builder.AddInstruction( + HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + LiteralUtil::EachCell( + *result, [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + EXPECT_TRUE(value == + LiteralUtil::Get(*literal_clone, rindexes)); + }); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ccc1dc63e78..d7c8881da8d 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -756,27 +756,28 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, // and unmodified_dim_pair have size >1. Otherwise, returns true and appends // the degerenate input/output dimensions in the gap to // deleted_indices/inserted_indices respectively. 
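The DoesReshape test added earlier in this hunk checks the evaluated transpose by walking every output index and comparing it against the source literal at the permuted index, rather than comparing flat buffers. The same verification idea in plain C++ on a small row-major array (none of the names below are XLA code):

#include <cassert>
#include <cstdio>
#include <vector>

// Row-major offset of a multi-index within the given dimensions.
size_t Offset(const std::vector<size_t>& dims, const std::vector<size_t>& idx) {
  size_t off = 0;
  for (size_t d = 0; d < dims.size(); ++d) off = off * dims[d] + idx[d];
  return off;
}

int main() {
  const std::vector<size_t> in_dims = {2, 3};
  const std::vector<float> input = {0, 1, 2, 3, 4, 5};  // Row-major 2x3.
  const std::vector<size_t> perm = {1, 0};  // Output dim d reads input dim perm[d].
  const std::vector<size_t> out_dims = {in_dims[perm[0]], in_dims[perm[1]]};

  // Transpose: route each output index through the permutation to find the
  // source element (in_idx[perm[d]] = out_idx[d]).
  std::vector<float> output(input.size());
  for (size_t i = 0; i < out_dims[0]; ++i)
    for (size_t j = 0; j < out_dims[1]; ++j) {
      std::vector<size_t> in_idx(2);
      in_idx[perm[0]] = i;
      in_idx[perm[1]] = j;
      output[Offset(out_dims, {i, j})] = input[Offset(in_dims, in_idx)];
    }

  // Verification in the spirit of the test: every output cell must equal the
  // input cell at the permuted index.
  for (size_t i = 0; i < out_dims[0]; ++i)
    for (size_t j = 0; j < out_dims[1]; ++j) {
      std::vector<size_t> in_idx(2);
      in_idx[perm[0]] = i;
      in_idx[perm[1]] = j;
      assert(output[Offset(out_dims, {i, j})] == input[Offset(in_dims, in_idx)]);
    }
  std::printf("transpose of a 2x3 array verified element by element\n");
}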
- auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices, - &inserted_indices]( - std::pair prior_unmodified_dim_pair, - std::pair unmodified_dim_pair) { - for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; - modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) { - if (shape_pre.dimensions(modified_input_dim) > 1) { - return false; - } - deleted_indices.push_back(modified_input_dim); - } - for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; - modified_output_dim < unmodified_dim_pair.second; - ++modified_output_dim) { - if (shape_post.dimensions(modified_output_dim) > 1) { - return false; - } - inserted_indices.push_back(modified_output_dim); - } - return true; - }; + auto check_modified_dims = + [&shape_pre, &shape_post, &deleted_indices, &inserted_indices]( + std::pair prior_unmodified_dim_pair, + std::pair unmodified_dim_pair) { + for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; + modified_input_dim < unmodified_dim_pair.first; + ++modified_input_dim) { + if (shape_pre.dimensions(modified_input_dim) > 1) { + return false; + } + deleted_indices.push_back(modified_input_dim); + } + for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; + modified_output_dim < unmodified_dim_pair.second; + ++modified_output_dim) { + if (shape_post.dimensions(modified_output_dim) > 1) { + return false; + } + inserted_indices.push_back(modified_output_dim); + } + return true; + }; std::vector> unmodified_dims = DimensionsUnmodifiedByReshape(shape_pre, shape_post); @@ -1189,6 +1190,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, tensorflow::gtl::ArraySlice count, tensorflow::gtl::ArraySlice incr, const IndexVisitorFunction& visitor_function) { + if (ShapeUtil::HasZeroElements(shape)) { + return; + } DCHECK_EQ(Rank(shape), base.size()); DCHECK_EQ(incr.size(), base.size()); DCHECK_EQ(count.size(), base.size()); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 73538b8b88e..eb5467f5e54 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -446,6 +446,34 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } +TEST(ShapeUtilTest, ForEachIndex) { + struct ShapeDimensionAndNumberInvocations { + std::vector dimensions; + int invocations; + } test_data[] = { + {{}, 1}, {{0}, 0}, {{16}, 16}, {{3, 0}, 0}, + {{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610}, + }; + + for (const auto& data : test_data) { + Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); + // Increments at every invocation. + int invocations = 0; + auto increment_func = [&invocations](const std::vector& indexes) { + invocations++; + return true; + }; + + std::vector zero_base(data.dimensions.size(), 0); + std::vector step(data.dimensions.size(), 1); + + ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step, + increment_func); + + EXPECT_EQ(invocations, data.invocations); + } +} + TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. 
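The HasZeroElements early return in ForEachIndex above is what makes the invocation counts in the new ShapeUtilTest.ForEachIndex cases come out right: a rank-0 shape is visited exactly once (with the empty index), while any shape containing a zero-sized dimension is visited zero times. A stripped-down version of that iteration over plain dimension vectors (not ShapeUtil itself) behaves the same way:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Calls visitor once per index in the dense index space of `dims`. Rank 0
// yields exactly one call (the empty index); any zero-sized dimension yields
// zero calls, mirroring the HasZeroElements early return.
void ForEachIndex(const std::vector<int64_t>& dims,
                  const std::function<void(const std::vector<int64_t>&)>& visitor) {
  for (int64_t d : dims) {
    if (d == 0) return;  // Zero elements: never invoke the visitor.
  }
  std::vector<int64_t> index(dims.size(), 0);
  while (true) {
    visitor(index);
    // Odometer-style increment over the dimensions, last dimension fastest.
    int64_t d = static_cast<int64_t>(dims.size()) - 1;
    for (; d >= 0; --d) {
      if (++index[d] < dims[d]) break;
      index[d] = 0;
    }
    if (d < 0) return;  // Wrapped past the first dimension: done.
  }
}

int main() {
  auto count = [](const std::vector<int64_t>& dims) {
    int n = 0;
    ForEachIndex(dims, [&n](const std::vector<int64_t>&) { ++n; });
    return n;
  };
  // Matches the expectations in the new ShapeUtilTest.ForEachIndex cases.
  std::printf("{}          -> %d\n", count({}));           // 1
  std::printf("{0}         -> %d\n", count({0}));          // 0
  std::printf("{16}        -> %d\n", count({16}));         // 16
  std::printf("{3, 0}      -> %d\n", count({3, 0}));       // 0
  std::printf("{4, 16}     -> %d\n", count({4, 16}));      // 64
  std::printf("{6, 11, 17} -> %d\n", count({6, 11, 17}));  // 1122
}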
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 7bf1168dc39..cdf53179ca8 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -46,12 +46,16 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) { ClientLibraryTestBase::ClientLibraryTestBase( se::Platform* platform, - tensorflow::gtl::ArraySlice disabled_pass_names) + tensorflow::gtl::ArraySlice disabled_pass_names, + bool disable_constant_folding) : client_(GetOrCreateLocalClientOrDie(platform)) { legacy_flags::HloPassPipelineFlags* flags = legacy_flags::GetHloPassPipelineFlags(); flags->xla_disable_hlo_passes = tensorflow::str_util::Join(disabled_pass_names, ","); + if (disable_constant_folding) { + flags->xla_disable_hlo_passes += ",constant_folding"; + } } string ClientLibraryTestBase::TestName() const { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 34f82603e89..5bff8ddd835 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -47,7 +47,14 @@ class ClientLibraryTestBase : public ::testing::Test { protected: explicit ClientLibraryTestBase( perftools::gputools::Platform* platform = nullptr, - tensorflow::gtl::ArraySlice disabled_pass_names = {}); + tensorflow::gtl::ArraySlice disabled_pass_names = {}, + // Note: here we are disabling constant_folding paths so that the tests + // (usually written using Constants) will exercise the intended code + // paths, instead of being constant folded. + // TODO(b/38354253): Constant folding is currently disabled here. Change + // tests to Parameters instead of Constants, and re-enable constant + // folding by default. + bool disable_constant_folding = true); // Returns the name of the test currently being run. string TestName() const; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 4f980830333..a8b07a2c5d1 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h"
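The ClientLibraryTestBase change funnels the new disable_constant_folding argument into the same comma-joined xla_disable_hlo_passes flag that carries the caller-supplied pass names, so folding is off by default for every test fixture; a subclass that wants folding back can pass false for the new constructor argument. A small standalone sketch of that flag assembly, with a made-up function name standing in for the constructor logic:

#include <cstdio>
#include <string>
#include <vector>

// Mirrors how the test base assembles xla_disable_hlo_passes: join the
// caller-supplied pass names with commas, then append constant_folding when
// the new disable_constant_folding flag is set.
std::string BuildDisabledPassFlag(const std::vector<std::string>& disabled,
                                  bool disable_constant_folding) {
  std::string joined;
  for (size_t i = 0; i < disabled.size(); ++i) {
    if (i > 0) joined += ",";
    joined += disabled[i];
  }
  if (disable_constant_folding) {
    joined += ",constant_folding";
  }
  return joined;
}

int main() {
  std::printf("%s\n", BuildDisabledPassFlag({"reshape-motion"}, true).c_str());
  // Prints: reshape-motion,constant_folding
  std::printf("%s\n", BuildDisabledPassFlag({"algsimp"}, false).c_str());
  // Prints: algsimp
}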