diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index a05dc968ee5..0f4b077a300 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -1071,6 +1071,7 @@ template ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions, stride_config.step, init_function); } else { + // For scalars. data.at(0) = generator({}); } return Status::OK(); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 9a09822174d..f8f22e2c398 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -807,7 +807,9 @@ TEST_F(LiteralUtilTest, Populate) { std::vector layout; } populate_data[] = { {{}, {}}, + {{0}, {0}}, {{16}, {0}}, + {{2, 0}, {1, 0}}, {{4, 16}, {1, 0}}, {{21, 12}, {0, 1}}, {{6, 11, 17}, {2, 0, 1}}, diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0c1bab0463a..dca39285a3b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -98,10 +98,12 @@ cc_test( ":hlo_evaluator", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test_main", ], @@ -1447,7 +1449,9 @@ cc_library( hdrs = ["hlo_constant_folding.h"], deps = [ ":hlo", + ":hlo_evaluator", ":hlo_pass", + ":hlo_query", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index cb0a99d773c..762ceebf39d 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -24,230 +24,57 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { -namespace { - -template -static std::unique_ptr ConvertIfTypesMatch( - const Literal& src_literal) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< - typename primitive_util::PrimitiveTypeToNative::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); -} - -template -static std::unique_ptr ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -static std::unique_ptr ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -} // namespace - -// ConstantFolderVisitor traverses the HLO computation and reduces certain -// constant graph sections, to literals. -class ConstantFolderVisitor : public DfsHloVisitorWithDefault { - public: - // Default visitor action is to do nothing and return OK. - Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { - return Status::OK(); - } - - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) override; - - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; - - Status HandleReshape(HloInstruction* reshape) override; - - Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; - - Status HandleTranspose(HloInstruction* transpose) override; - - // Returns whether a constant folding operation has occurred. 
-  const bool changed() const { return changed_; }
-
-  // Runs the visitor on a computation and returns whether any changes were
-  // performed.
-  static StatusOr<bool> Run(HloComputation* computation);
-
- private:
-  ConstantFolderVisitor() = default;
-
-  // Replaces the existing HLO instruction old_instruction, with a literal,
-  // and marks the optimizer status as changed.
-  // Returns the Status representing the result of the replace operation.
-  Status ReplaceWithConstant(HloInstruction* old_instruction,
-                             std::unique_ptr<Literal> literal) {
-    TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
-        old_instruction, HloInstruction::CreateConstant(std::move(literal))));
-    changed_ = true;
-    return Status::OK();
-  }
-
-  // Whether any constant folding operations have occurred.
-  bool changed_ = false;
-};
-
-StatusOr<bool> ConstantFolderVisitor::Run(HloComputation* computation) {
-  ConstantFolderVisitor visitor;
-  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
-  return visitor.changed();
-}
 
 StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
+  auto evaluator = MakeUnique<HloEvaluator>();
+  XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString());
   bool changed = false;
-  for (auto& comp : module->computations()) {
-    TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get()));
-    changed = changed || result;
+
+  for (auto& computation : module->computations()) {
+    for (auto instruction : computation->MakeInstructionPostOrder()) {
+      // Skip dead code.
+      if (instruction->user_count() == 0 &&
+          computation->root_instruction() != instruction) {
+        continue;
+      }
+      // Skip Constant and Parameter operations.
+      if (instruction->opcode() == HloOpcode::kParameter ||
+          instruction->opcode() == HloOpcode::kConstant) {
+        continue;
+      }
+      // Skip instructions with non-constant operands.
+      if (!hlo_query::AllOperandsAreConstants(*instruction)) {
+        continue;
+      }
+
+      std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
+      // Currently we skip unimplemented operations.
+      // TODO(b/35975797): Fold constant computations for more operations.
+      if (result == nullptr) {
+        VLOG(2) << "Constant folding failed for instruction: "
+                << instruction->ToString();
+        continue;
+      }
+
+      TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+          instruction, HloInstruction::CreateConstant(std::move(result))));
+      changed = true;
+    }
   }
   XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
   return changed;
 }
 
-Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) {
-  if (reshape->operand(0)->opcode() == HloOpcode::kConstant) {
-    TF_ASSIGN_OR_RETURN(
-        auto reshaped_literal,
-        LiteralUtil::Reshape(reshape->operand(0)->literal(),
-                             AsInt64Slice(reshape->shape().dimensions())));
-    return ReplaceWithConstant(reshape, std::move(reshaped_literal));
-  }
-  return Status::OK();
-}
-
-Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) {
-  if (transpose->operand(0)->opcode() == HloOpcode::kConstant) {
-    auto transposed_literal = LiteralUtil::Transpose(
-        transpose->operand(0)->literal(), transpose->dimensions());
-    return ReplaceWithConstant(transpose, std::move(transposed_literal));
-  }
-  return Status::OK();
-}
-
-Status ConstantFolderVisitor::HandleConcatenate(
-    HloInstruction* concatenate,
-    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
-  if (operands[0]->opcode() == HloOpcode::kConstant) {
-    // If all the operands of a concatenate are constant, fold them into a
-    // single constant tensor.
- // The result concatenate dimension is going to be the sum of all the - // concatenate dimensions of the arrays taking part of the operation. - int64 concat_dim = concatenate->dimensions()[0]; - const Shape& reference_shape = operands[0]->shape(); - CHECK(!ShapeUtil::IsTuple(reference_shape)); - int64 rank = ShapeUtil::Rank(reference_shape); - std::vector concat_dimensions(reference_shape.dimensions().begin(), - reference_shape.dimensions().end()); - if (concat_dim < 0) { - concat_dim += rank; - } - for (int64 i = 1; i < operands.size(); ++i) { - const Shape& operand_shape = operands[i]->shape(); - CHECK(!ShapeUtil::IsTuple(operand_shape)); - if (operands[i]->opcode() != HloOpcode::kConstant) { - return Status::OK(); - } - // Accumulate the concat dimension from all tensors taking part to the - // operation. - concat_dimensions[concat_dim] += - ShapeUtil::GetDimension(operand_shape, concat_dim); - } - - auto literal = LiteralUtil::CreateFromDimensions( - reference_shape.element_type(), concat_dimensions); - std::vector source_indices(rank, 0); - std::vector dest_indices(concat_dimensions.size(), 0); - for (auto operand : operands) { - const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - operand->literal(), source_indices, literal.get(), dest_indices, - AsInt64Slice(operand_shape.dimensions()))); - dest_indices[concat_dim] += - ShapeUtil::GetDimension(operand_shape, concat_dim); - } - return ReplaceWithConstant(concatenate, std::move(literal)); - } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice, - HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kConstant) { - const Shape& shape = slice->shape(); - auto literal = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - std::vector dest_indices(slice->slice_starts().size(), 0); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - operand->literal(), slice->slice_starts(), literal.get(), dest_indices, - AsInt64Slice(shape.dimensions()))); - TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal))); - } - return Status::OK(); -} - -Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kConstant) { - const Literal& src_literal = operand->literal(); - std::unique_ptr new_constant = - ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type()); - return ReplaceWithConstant(convert, std::move(new_constant)); - } - return Status::OK(); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e0447d69aa2..17f9416197d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -46,6 +46,89 @@ limitations under the License. 
namespace xla { +namespace { + +template +StatusOr> Compare(const Shape& shape, HloOpcode opcode, + const Literal& lhs_literal, + const Literal& rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el != rhs_el; + }; + break; + case HloOpcode::kGe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el >= rhs_el; + }; + break; + case HloOpcode::kGt: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el > rhs_el; + }; + break; + case HloOpcode::kLe: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el <= rhs_el; + }; + break; + case HloOpcode::kLt: + compare_op = [](OperandT lhs_el, OperandT rhs_el) { + return lhs_el < rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + auto result = LiteralUtil::CreateFromShape(shape); + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return compare_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index)); + })); + + return std::move(result); +} + +template +StatusOr> ElementWiseUnaryOpImpl( + HloInstruction* instruction, + const std::function& unary_op, + const Literal& operand_literal) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + auto result = LiteralUtil::CreateFromShape(shape); + + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op( + LiteralUtil::Get(operand_literal, multi_index)); + })); + return std::move(result); +} + +} // namespace + template class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { public: @@ -68,7 +151,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return elem_operand; })); return Status::OK(); - }; + } template < typename NativeT, @@ -79,7 +162,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::abs(elem_operand); })); return Status::OK(); - }; + } Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override { return HandleAbs(abs, operand); @@ -101,6 +184,45 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + template + std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { + DCHECK_EQ(src_type, src_literal.shape().element_type()); + return LiteralUtil::Convert< + typename primitive_util::PrimitiveTypeToNative::type, + typename primitive_util::PrimitiveTypeToNative::type>( + src_literal); + } + + Status HandleConvert(HloInstruction* convert, + HloInstruction* operand) override { + auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); + + switch (operand->shape().element_type()) { +#define CONVERT_IF_TYPES_MATCH(src_type) \ + case (src_type): \ + parent_->evaluated_[convert] = LiteralUtil::Convert< \ + typename 
primitive_util::PrimitiveTypeToNative<src_type>::type,                  \
+          ReturnT>(operand_literal);                                      \
+      break;
+      CONVERT_IF_TYPES_MATCH(PRED)
+      CONVERT_IF_TYPES_MATCH(S8)
+      CONVERT_IF_TYPES_MATCH(S32)
+      CONVERT_IF_TYPES_MATCH(S64)
+      CONVERT_IF_TYPES_MATCH(U8)
+      CONVERT_IF_TYPES_MATCH(U32)
+      CONVERT_IF_TYPES_MATCH(U64)
+      CONVERT_IF_TYPES_MATCH(F32)
+      CONVERT_IF_TYPES_MATCH(F64)
+#undef CONVERT_IF_TYPES_MATCH
+      // Other types are not yet supported.
+      default:
+        LOG(FATAL) << "unimplemented operand type for HandleConvert: "
+                   << PrimitiveType_Name(operand->shape().element_type());
+    }
+
+    return Status::OK();
+  }
+
   Status HandleExp(HloInstruction* exp, HloInstruction* operand) override {
     TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
                         ElementWiseUnaryOp(exp, [](ReturnT elem_operand) {
@@ -117,15 +239,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     return Status::OK();
   };
 
-  Status HandleIsFinite(HloInstruction* is_finite,
-                        HloInstruction* operand) override {
-    TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite],
-                        ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) {
-                          return std::isfinite(elem_operand);
-                        }));
-    return Status::OK();
-  };
-
   Status HandleLog(HloInstruction* log, HloInstruction* operand) override {
     TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
                         ElementWiseUnaryOp(log, [](ReturnT elem_operand) {
@@ -209,77 +322,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     return Status::OK();
   };
 
-  Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
-                       HloInstruction* lhs, HloInstruction* rhs) override {
-    std::function<bool(ReturnT, ReturnT)> compare_op;
-    switch (opcode) {
-      case HloOpcode::kEq:
-        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
-          return lhs_el == rhs_el;
-        };
-        break;
-      case HloOpcode::kNe:
-        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
-          return lhs_el != rhs_el;
-        };
-        break;
-      case HloOpcode::kGe:
-        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
-          return lhs_el >= rhs_el;
-        };
-        break;
-      case HloOpcode::kGt:
-        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
-          return lhs_el > rhs_el;
-        };
-        break;
-      case HloOpcode::kLe:
-        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
-          return lhs_el <= rhs_el;
-        };
-        break;
-      case HloOpcode::kLt:
-        compare_op = [](ReturnT lhs_el, ReturnT rhs_el) {
-          return lhs_el < rhs_el;
-        };
-        break;
-      default:
-        LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: "
-                   << HloOpcodeString(opcode);
-    }
-
-    // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is
-    // removed.
- if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Compare operation with mismatched dimensions, likely due to " - "broadcasting is unsupported."); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = LiteralUtil::CreateFromShape(compare->shape()); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - compare_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); - - parent_->evaluated_[compare] = std::move(result); - - return Status::OK(); - }; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, HloInstruction* rhs) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { - return std::max(lhs, rhs); + return std::fmax(lhs, rhs); })); return Status::OK(); }; @@ -289,7 +337,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::min(lhs_el, rhs_el); + return std::fmin(lhs_el, rhs_el); })); return Status::OK(); }; @@ -309,7 +357,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[remainder], ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::remainder(lhs_el, rhs_el); + return std::fmod(lhs_el, rhs_el); })); return Status::OK(); }; @@ -338,7 +386,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { HloInstruction* arg, HloInstruction* max) override { std::function clamp_op = [](ReturnT low, ReturnT high, ReturnT value) { - return std::max(low, std::min(value, high)); + return std::fmax(low, std::fmin(value, high)); }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], ElementWiseTernaryOp(clamp, std::move(clamp_op))); @@ -370,32 +418,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> ElementWiseUnaryOp( HloInstruction* instruction, const std::function& unary_op) { - const auto shape = instruction->shape(); - const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. 
- if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); - } - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - - auto result = LiteralUtil::CreateFromShape(shape); - - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - unary_op(LiteralUtil::Get(operand_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); - - return std::move(result); - }; + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(instruction->operand(0)); + return ElementWiseUnaryOpImpl(instruction, unary_op, + operand_literal); + } StatusOr> ElementWiseBinaryOp( HloInstruction* instruction, @@ -420,16 +447,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); auto result = LiteralUtil::CreateFromShape(shape); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - binary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return binary_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index)); + })); return std::move(result); - }; + } template StatusOr> ElementWiseTernaryOp( @@ -459,17 +484,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); auto result = LiteralUtil::CreateFromShape(shape); - std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); - do { - LiteralUtil::Set( - result.get(), multi_index, - ternary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index), - LiteralUtil::Get(ehs_literal, multi_index))); - } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op( + LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index), + LiteralUtil::Get(ehs_literal, multi_index)); + })); return std::move(result); - }; + } HloEvaluator* parent_; }; @@ -493,6 +518,12 @@ HloEvaluator::HloEvaluator() { }); typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: TUPLE."); + }); + typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: OPAQUE."); + }); } StatusOr> HloEvaluator::Evaluate( @@ -502,15 +533,15 @@ StatusOr> HloEvaluator::Evaluate( evaluated_.clear(); TF_RETURN_IF_ERROR(computation->Accept(this)); - return std::move(FindOrDie(evaluated_, computation->root_instruction())); + return MakeUnique( + GetEvaluatedLiteralFor(computation->root_instruction())); } StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction, tensorflow::gtl::ArraySlice operands) { - DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); - Shape shape = instruction->shape(); - 
TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_ = operands; evaluated_.clear(); @@ -525,13 +556,34 @@ StatusOr> HloEvaluator::Evaluate( TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); evaluated_[operand] = MakeUnique(*input_literal); - } else if (operand->opcode() == HloOpcode::kConstant) { - evaluated_[operand] = MakeUnique(operand->literal()); } } TF_RETURN_IF_ERROR(instruction->Visit(this)); - return std::move(FindOrDie(evaluated_, instruction)); + return MakeUnique(GetEvaluatedLiteralFor(instruction)); +} + +StatusOr> HloEvaluator::Evaluate( + HloInstruction* instruction) { + TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction)); + TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); + + arg_literals_.clear(); + evaluated_.clear(); + TF_RETURN_IF_ERROR(instruction->Visit(this)); + return MakeUnique(GetEvaluatedLiteralFor(instruction)); +} + +std::unique_ptr HloEvaluator::TryEvaluate( + HloInstruction* instruction) { + auto result_or = Evaluate(instruction); + if (!result_or.ok()) { + LOG(ERROR) << "TryEvaluate failed:" << result_or.status(); + return nullptr; + } + + return result_or.ConsumeValueOrDie(); } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { @@ -548,9 +600,191 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { Status HloEvaluator::HandleConstant(HloInstruction* constant, const Literal& literal) { VLOG(2) << "HandleConstant: " << constant->ToString(); - DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape())); + return Status::OK(); +} - evaluated_[constant] = MakeUnique(literal); +Status HloEvaluator::HandleReshape(HloInstruction* reshape) { + TF_ASSIGN_OR_RETURN( + evaluated_[reshape], + LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)), + AsInt64Slice(reshape->shape().dimensions()))); + return Status::OK(); +} + +Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { + evaluated_[transpose] = LiteralUtil::Transpose( + GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions()); + return Status::OK(); +} + +Status HloEvaluator::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + // The result concatenate dimension is going to be the sum of all concatenate + // dimensions of the operands taking part of the operation. + const Shape& reference_shape = operands[0]->shape(); + CHECK(!ShapeUtil::IsTuple(reference_shape)); + const int64 rank = ShapeUtil::Rank(reference_shape); + const int64 concat_dim = concatenate->dimensions()[0]; + CHECK_GE(concat_dim, 0); + CHECK_LT(concat_dim, rank); + + DimensionVector concat_dimensions(reference_shape.dimensions().begin(), + reference_shape.dimensions().end()); + + for (int64 i = 1; i < operands.size(); ++i) { + const Shape& operand_shape = operands[i]->shape(); + CHECK(!ShapeUtil::IsTuple(operand_shape)); + // Accumulate the concat dimension from all tensors taking part to the + // operation. 
+ concat_dimensions[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + auto result_literal = LiteralUtil::CreateFromDimensions( + reference_shape.element_type(), concat_dimensions); + DimensionVector source_indices(rank, 0); + DimensionVector dest_indices(concat_dimensions.size(), 0); + + for (auto operand : operands) { + const Shape& operand_shape = operand->shape(); + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(), + dest_indices, AsInt64Slice(operand_shape.dimensions()))); + dest_indices[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + evaluated_[concatenate] = std::move(result_literal); + return Status::OK(); +} + +Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) { + if (!ShapeUtil::ElementIsFloating(operand->shape())) { + return InvalidArgument( + "expected element type in shape to be float for IsFinite op, got: %s", + PrimitiveType_Name(operand->shape().element_type()).c_str()); + } + + switch (operand->shape().element_type()) { + case F16: + return Unimplemented("unhandled primitive type: F16."); + case F32: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](float elem_operand) { return std::isfinite(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl( + is_finite, + [](double elem_operand) { return std::isfinite(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[is_finite], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "unknown/unhandled primitive type."; + } + + return Status::OK(); +} + +Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) { + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s", + ShapeUtil::HumanString(compare->shape()).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); + + const Literal& lhs_literal = GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = GetEvaluatedLiteralFor(rhs); + + // Note here we switch on the operand's type. 
+ switch (lhs->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U8: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U16: + return Unimplemented("unhandled primitive type: U16."); + case U32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S8: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S16: + return Unimplemented("unhandled primitive type: S16."); + case S32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case F16: + return Unimplemented("unhandled primitive type: F16."); + case F32: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + case F64: { + TF_ASSIGN_OR_RETURN( + evaluated_[compare], + Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); + } break; + default: + LOG(FATAL) << "unknown primitive type."; + } + + return Status::OK(); +} + +Status HloEvaluator::HandleSlice(HloInstruction* slice, + HloInstruction* operand) { + const Shape& shape = slice->shape(); + auto literal = LiteralUtil::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + + DimensionVector dest_indices(slice->slice_starts().size(), 0); + + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(), + dest_indices, AsInt64Slice(shape.dimensions()))); + + evaluated_[slice] = std::move(literal); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 50cb32eb85c..e6798a35a01 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -57,21 +57,32 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. // Precondition: - // 1. argument literals are corresponds to the input instruction's - // parameters in their post-orderring. + // 1. argument literals correspond to the input instruction's parameters in + // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. // TODO(b/35950897): implement more ops other than element-wise ops. StatusOr> Evaluate( HloInstruction* instruction, tensorflow::gtl::ArraySlice arg_literals); + // Evaluates a single HLO instruction with constant operands. + // Returns the evaluated result as literal if successful. + // Precondition: + // 1. all operands of the input instruction are constants. + // 2. the instruction is not a Parameter operation. + StatusOr> Evaluate(HloInstruction* instruction); + + // Same as Evaluate, except returning nullptr on error. + std::unique_ptr TryEvaluate(HloInstruction* instruction); + protected: // Templated DfsHloVisitor. 
Typically ReturnT here indicates the resulting
-  // literal type of each evaluated Handle* method of a TypedVisitor. One
-  // exception to this is HandleCompare, where the resulting literal type is
+  // literal type of each evaluated Handle* method of a TypedVisitor.
+  // There are however a few notable exceptions to this rule:
+  // - HandleCompare and HandleIsFinite: where the resulting literal type is
   // always boolean.
-  // Note the forward declaration here is necessary to enable TypedVisitor to
-  // access parent members.
+  // These operations are handled directly by the parent HloEvaluator rather
+  // than from within TypedVisitor.
   template <typename ReturnT>
   class TypedVisitor;
@@ -81,15 +92,38 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
     return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
   }
 
+  // Operations that are type-agnostic.
+  //
   Status HandleParameter(HloInstruction* parameter) override;
 
   Status HandleConstant(HloInstruction* constant,
                         const Literal& literal) override;
 
+  Status HandleConcatenate(
+      HloInstruction* concatenate,
+      tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+
+  Status HandleReshape(HloInstruction* reshape) override;
+
+  Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
+
+  Status HandleTranspose(HloInstruction* transpose) override;
+
+  Status HandleIsFinite(HloInstruction* is_finite,
+                        HloInstruction* operand) override;
+
+  Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
+                       HloInstruction* lhs, HloInstruction* rhs) override;
+
  private:
   // Returns the already-evaluated literal result for the instruction.
+  // A Constant instruction is considered evaluated and its literal will be
+  // returned directly without looking up the cache.
   // Crash with log if the given instruction has not been evaluated previously.
   const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) {
+    if (hlo->IsConstant()) {
+      return hlo->literal();
+    }
     auto it = evaluated_.find(hlo);
     CHECK(it != evaluated_.end())
         << "could not find evaluated value for: " << hlo->ToString();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 443e5ad4f42..b26ece28b75 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -23,8 +23,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/types.h"
@@ -143,7 +145,7 @@ TEST_F(HloEvaluatorTest, DoesDivide) {
 // element-wise abs op with 1 operand.
 TEST_F(HloEvaluatorTest, DoesAbs) {
   auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
-  Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
+  const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2});
   auto c1 = HloInstruction::CreateConstant(std::move(operand));
   auto instruction =
       HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get());
@@ -154,7 +156,29 @@ TEST_F(HloEvaluatorTest, DoesAbs) {
   auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
 
   EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
-}
+
+  // For R0 literal.
+  const Shape& r0 = ShapeUtil::MakeShape(F32, {});
+  operand = LiteralUtil::CreateR0<float>(-1.0f);
+  c1 = HloInstruction::CreateConstant(std::move(operand));
+  instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get());
+  result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
+  expected = LiteralUtil::CreateR0<float>(1.0f);
+
+  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
+
+  // For R1 literal with dimension of size 0.
+  Shape empty_r1 = ShapeUtil::MakeShape(F32, {0});
+  operand = LiteralUtil::CreateR1<float>({});
+  c1 = HloInstruction::CreateConstant(std::move(operand));
+  instruction =
+      HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get());
+
+  result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie();
+  expected = LiteralUtil::CreateR1<float>({});
+
+  EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
+}
 
 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
 // constant operands.
@@ -187,5 +211,35 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) {
   EXPECT_TRUE(LiteralUtil::Equal(*result, *expected));
 }
 
+// Verifies that a transpose of a constant literal is correctly evaluated.
+TEST_F(HloEvaluatorTest, DoesReshape) {
+  HloComputation::Builder builder(
+      ::testing::UnitTest::GetInstance()->current_test_info()->name());
+
+  const int64 dimensions[] = {11, 8, 7, 5, 9};
+  TF_ASSIGN_OR_ASSERT_OK(auto literal,
+                         LiteralTestUtil::CreateRandomLiteral<F32>(
+                             ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
+  auto literal_clone = LiteralUtil::CloneToUnique(*literal);
+  HloInstruction* literal_instruction = builder.AddInstruction(
+      HloInstruction::CreateConstant(std::move(literal)));
+
+  Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
+  const int64 permutation[] = {1, 2, 0, 4, 3};
+  builder.AddInstruction(
+      HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
+  LiteralUtil::EachCell<NativeT>(
+      *result, [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT value) {
+        std::vector<int64> rindexes = Permute(permutation, indices);
+        EXPECT_TRUE(value ==
+                    LiteralUtil::Get<NativeT>(*literal_clone, rindexes));
+      });
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index ccc1dc63e78..d7c8881da8d 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -756,27 +756,28 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
   // and unmodified_dim_pair have size >1. Otherwise, returns true and appends
   // the degerenate input/output dimensions in the gap to
   // deleted_indices/inserted_indices respectively.
- auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices, - &inserted_indices]( - std::pair prior_unmodified_dim_pair, - std::pair unmodified_dim_pair) { - for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; - modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) { - if (shape_pre.dimensions(modified_input_dim) > 1) { - return false; - } - deleted_indices.push_back(modified_input_dim); - } - for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; - modified_output_dim < unmodified_dim_pair.second; - ++modified_output_dim) { - if (shape_post.dimensions(modified_output_dim) > 1) { - return false; - } - inserted_indices.push_back(modified_output_dim); - } - return true; - }; + auto check_modified_dims = + [&shape_pre, &shape_post, &deleted_indices, &inserted_indices]( + std::pair prior_unmodified_dim_pair, + std::pair unmodified_dim_pair) { + for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1; + modified_input_dim < unmodified_dim_pair.first; + ++modified_input_dim) { + if (shape_pre.dimensions(modified_input_dim) > 1) { + return false; + } + deleted_indices.push_back(modified_input_dim); + } + for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1; + modified_output_dim < unmodified_dim_pair.second; + ++modified_output_dim) { + if (shape_post.dimensions(modified_output_dim) > 1) { + return false; + } + inserted_indices.push_back(modified_output_dim); + } + return true; + }; std::vector> unmodified_dims = DimensionsUnmodifiedByReshape(shape_pre, shape_post); @@ -1189,6 +1190,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, tensorflow::gtl::ArraySlice count, tensorflow::gtl::ArraySlice incr, const IndexVisitorFunction& visitor_function) { + if (ShapeUtil::HasZeroElements(shape)) { + return; + } DCHECK_EQ(Rank(shape), base.size()); DCHECK_EQ(incr.size(), base.size()); DCHECK_EQ(count.size(), base.size()); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 73538b8b88e..eb5467f5e54 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -446,6 +446,34 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); } +TEST(ShapeUtilTest, ForEachIndex) { + struct ShapeDimensionAndNumberInvocations { + std::vector dimensions; + int invocations; + } test_data[] = { + {{}, 1}, {{0}, 0}, {{16}, 16}, {{3, 0}, 0}, + {{0, 2}, 0}, {{4, 16}, 64}, {{6, 11, 17}, 1122}, {{6, 11, 5, 17}, 5610}, + }; + + for (const auto& data : test_data) { + Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); + // Increments at every invocation. + int invocations = 0; + auto increment_func = [&invocations](const std::vector& indexes) { + invocations++; + return true; + }; + + std::vector zero_base(data.dimensions.size(), 0); + std::vector step(data.dimensions.size(), 1); + + ShapeUtil::ForEachIndex(shape, zero_base, data.dimensions, step, + increment_func); + + EXPECT_EQ(invocations, data.invocations); + } +} + TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. 
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 7bf1168dc39..cdf53179ca8 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -46,12 +46,16 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) { ClientLibraryTestBase::ClientLibraryTestBase( se::Platform* platform, - tensorflow::gtl::ArraySlice disabled_pass_names) + tensorflow::gtl::ArraySlice disabled_pass_names, + bool disable_constant_folding) : client_(GetOrCreateLocalClientOrDie(platform)) { legacy_flags::HloPassPipelineFlags* flags = legacy_flags::GetHloPassPipelineFlags(); flags->xla_disable_hlo_passes = tensorflow::str_util::Join(disabled_pass_names, ","); + if (disable_constant_folding) { + flags->xla_disable_hlo_passes += ",constant_folding"; + } } string ClientLibraryTestBase::TestName() const { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 34f82603e89..5bff8ddd835 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -47,7 +47,14 @@ class ClientLibraryTestBase : public ::testing::Test { protected: explicit ClientLibraryTestBase( perftools::gputools::Platform* platform = nullptr, - tensorflow::gtl::ArraySlice disabled_pass_names = {}); + tensorflow::gtl::ArraySlice disabled_pass_names = {}, + // Note: here we are disabling constant_folding paths so that the tests + // (usually written using Constants) will exercise the intended code + // paths, instead of being constant folded. + // TODO(b/38354253): Constant folding is currently disabled here. Change + // tests to Parameters instead of Constants, and re-enable constant + // folding by default. + bool disable_constant_folding = true); // Returns the name of the test currently being run. string TestName() const; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 4f980830333..a8b07a2c5d1 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h"
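For reference, the sketch below (not part of the change above) illustrates how the new HloEvaluator::TryEvaluate entry point is intended to be used on an instruction whose operands are all constants, mirroring the DoesAbs test and the HloConstantFolding pass in this diff. The helper name FoldAbsExample is hypothetical; the APIs used are the ones added or exercised by this change.

// Minimal, illustrative usage sketch; assumes the headers from this change.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

namespace xla {

// Hypothetical helper: builds abs(constant) and folds it with the evaluator.
std::unique_ptr<Literal> FoldAbsExample() {
  auto operand = LiteralUtil::CreateR1<float>({-1.0f, 2.0f});
  auto constant = HloInstruction::CreateConstant(std::move(operand));
  auto abs = HloInstruction::CreateUnary(constant->shape(), HloOpcode::kAbs,
                                         constant.get());
  HloEvaluator evaluator;
  // TryEvaluate returns nullptr (rather than an error Status) when the
  // operation is not yet implemented, which is how HloConstantFolding skips
  // unsupported instructions.
  return evaluator.TryEvaluate(abs.get());
}

}  // namespace xla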