diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7fdffe85c0e..73964733e8a 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2456,6 +2456,7 @@ tf_cc_test( ":hlo", ":hlo_constant_folding", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 6dddda1ca89..2ed645c3aed 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -52,9 +52,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Reduce, and AfterAll operation. - // TODO(b/35975797): Enable Reduce operation once arbitrary computation - // are supported by the evaluator. + // Skip Constant, Parameter, and AfterAll operation. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one // operand in which case constant folding will be impossible and this @@ -62,7 +60,6 @@ StatusOr HloConstantFolding::Run(HloModule* module) { if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kReduce || instruction->opcode() == HloOpcode::kAfterAll) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 64a42c1efc0..7cd1481a8ad 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -202,5 +203,45 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { EXPECT_TRUE(matched); } +const char* const kConstantFoldReduce = R"( + HloModule ConstantFoldReduce + + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a, b) + } + + ENTRY r { + x = s32[3] constant({1, 2, 3}) + init = s32[] constant(0) + ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add + })"; + +TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kConstantFoldReduce)); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_EQ(6, module->entry_computation() + ->root_instruction() + ->literal() + .GetFirstElement()); +} + +TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kConstantFoldReduce)); + HloInstruction* add = module->computations().begin()->root_instruction(); + LayoutUtil::ClearLayout(add->mutable_shape()); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_FALSE(result); + + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 35d9e799df6..fb900494919 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -230,7 +230,6 @@ template StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction, ArraySlice arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); evaluated_.clear(); arg_literals_.clear(); @@ -267,7 +266,6 @@ StatusOr> HloEvaluator::Evaluate( return tensorflow::errors::FailedPrecondition( "Not all operands are constants."); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_.clear(); evaluated_.clear(); @@ -1285,7 +1283,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); - return Status::OK(); + return ShapeUtil::ValidateShape(hlo->shape()); } Status HloEvaluator::Postprocess(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 83d7b404f0b..aafba8afe8a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1544,10 +1544,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); auto result = absl::make_unique(reduce->shape()); + Status eval_status; // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { ReturnT result_val = init_scalar; + if (!eval_status.ok()) { + return result_val; + } std::vector base(arg_dimensions.size()); for (int64 i = 0; i < multi_index.size(); ++i) { @@ -1568,7 +1572,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_dim_steps, func); return static_cast(computed_result); } - auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto func = [&](tensorflow::gtl::ArraySlice input_index) + -> StatusOr { auto curr_val = arg_literal.Get(input_index); // Evaluate computation with specified literal operands. @@ -1576,12 +1581,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result_val_literal = LiteralUtil::CreateR0(result_val); - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate( - *function, - {result_val_literal.get(), curr_val_literal.get()}) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(std::unique_ptr computed_result, + embedded_evaluator.Evaluate( + *function, {result_val_literal.get(), + curr_val_literal.get()})); // Clear visit states so that we can use the evaluator again on // the same computation. embedded_evaluator.ResetVisitStates(); @@ -1591,13 +1594,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { }; // Computes one element of the result, reducing all dimensions that // contribute to that element. - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); + eval_status = ShapeUtil::ForEachIndexWithStatus( + arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func); return result_val; })); parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); + return eval_status; } bool IsScalarAdd(HloComputation* computation) {