[XLA] Propagate invalid shape errors through reduce folding and turn it on

HloEvaluator should be stable enough for reduce folding, but it shouldn't
crash when it encounters an instruction without a layout. Verify the layout on
every instruction that gets evaluated and return an error on failure.

PiperOrigin-RevId: 209641401
Benjamin Kramer 2018-08-21 12:35:33 -07:00 committed by TensorFlower Gardener
parent c61a49ec31
commit 4f41091f88
5 changed files with 57 additions and 17 deletions
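The pattern at the heart of the reduce-folding change: Literal::Populate takes a callback that returns a plain element value, so an evaluation error inside the callback cannot be returned directly. Below is a minimal standalone sketch of that pattern (toy Status and Populate stand-ins, not the XLA/TensorFlow API): capture the first failure in a local status, make later callback invocations bail out early, and return the status once the loop finishes instead of crashing via ConsumeValueOrDie inside the callback. The real code in the last hunk does the same with tensorflow::Status, Literal::Populate, and ShapeUtil::ForEachIndexWithStatus.

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Toy stand-in for tensorflow::Status, purely for illustration.
    struct Status {
      bool ok_ = true;
      std::string message;
      bool ok() const { return ok_; }
      static Status OK() { return Status{}; }
      static Status Error(std::string m) { return Status{false, std::move(m)}; }
    };

    // Populate-style API: invokes `gen` per element and cannot surface errors.
    void Populate(int n, const std::function<int(int)>& gen,
                  std::vector<int>* out) {
      out->assign(n, 0);
      for (int i = 0; i < n; ++i) (*out)[i] = gen(i);
    }

    // Record the first failure in eval_status, short-circuit later callback
    // invocations, and propagate the status after Populate returns.
    Status FoldElements(const std::vector<int>& input, std::vector<int>* out) {
      Status eval_status = Status::OK();
      Populate(static_cast<int>(input.size()),
               [&](int i) {
                 int value = 0;
                 if (!eval_status.ok()) return value;  // error already recorded
                 if (input[i] < 0) {  // simulated evaluator failure
                   eval_status = Status::Error("evaluation failed");
                   return value;
                 }
                 return input[i] * 2;
               },
               out);
      return eval_status;  // an error here, not a crash
    }

    int main() {
      std::vector<int> out;
      std::cout << FoldElements({1, 2, 3}, &out).ok() << "\n";   // 1 (success)
      std::cout << FoldElements({1, -2, 3}, &out).ok() << "\n";  // 0 (error kept)
    }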

tensorflow/compiler/xla/service/BUILD

@@ -2456,6 +2456,7 @@ tf_cc_test(
         ":hlo",
         ":hlo_constant_folding",
         ":hlo_matchers",
+        ":hlo_parser",
         ":hlo_pass",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",

tensorflow/compiler/xla/service/hlo_constant_folding.cc

@@ -52,9 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
           computation->root_instruction() != instruction) {
        continue;
      }
-      // Skip Constant, Parameter, Reduce, and AfterAll operation.
-      // TODO(b/35975797): Enable Reduce operation once arbitrary computation
-      // are supported by the evaluator.
+      // Skip Constant, Parameter, and AfterAll operation.
      // TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
      // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
      // operand in which case constant folding will be impossible and this
@@ -62,7 +60,6 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
      if (instruction->opcode() == HloOpcode::kParameter ||
          instruction->opcode() == HloOpcode::kConstant ||
          instruction->opcode() == HloOpcode::kTuple ||
-          instruction->opcode() == HloOpcode::kReduce ||
          instruction->opcode() == HloOpcode::kAfterAll) {
        continue;
      }

tensorflow/compiler/xla/service/hlo_constant_folding_test.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
@@ -202,5 +203,45 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
   EXPECT_TRUE(matched);
 }
 
+const char* const kConstantFoldReduce = R"(
+  HloModule ConstantFoldReduce
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = s32[] add(a, b)
+  }
+
+  ENTRY r {
+    x = s32[3] constant({1, 2, 3})
+    init = s32[] constant(0)
+    ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
+  })";
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(kConstantFoldReduce));
+  HloConstantFolding const_folder;
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
+
+  EXPECT_EQ(6, module->entry_computation()
+                   ->root_instruction()
+                   ->literal()
+                   .GetFirstElement<int32>());
+}
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(kConstantFoldReduce));
+  HloInstruction* add = module->computations().begin()->root_instruction();
+  LayoutUtil::ClearLayout(add->mutable_shape());
+
+  HloConstantFolding const_folder;
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  EXPECT_FALSE(result);
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
 }  // namespace
 }  // namespace xla

tensorflow/compiler/xla/service/hlo_evaluator.cc

@@ -230,7 +230,6 @@ template <typename LiteralPtr>
 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
     HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) {
   TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
-  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
 
   evaluated_.clear();
   arg_literals_.clear();
@@ -267,7 +266,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
     return tensorflow::errors::FailedPrecondition(
         "Not all operands are constants.");
   }
-  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
 
   arg_literals_.clear();
   evaluated_.clear();
@@ -1285,7 +1283,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
 
 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
   VLOG(2) << "About to visit HLO: " << hlo->ToString();
-  return Status::OK();
+  return ShapeUtil::ValidateShape(hlo->shape());
 }
 
 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
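The hunks above delete the two entry-point-only ValidateShape calls and move the check into Preprocess, which the evaluator's DFS visitor runs before handling each instruction. That way instructions inside nested computations (such as a reduce's to_apply) are validated too, and a malformed shape, for example one with no layout, surfaces as a Status instead of a crash. A rough sketch of the idea, with made-up Instr/Status types standing in for the XLA classes:

    #include <iostream>
    #include <string>
    #include <vector>

    // Illustrative stand-ins; not the XLA classes.
    struct Status {
      bool ok_ = true;
      std::string message;
      bool ok() const { return ok_; }
      static Status OK() { return Status{}; }
      static Status Error(std::string m) { return Status{false, std::move(m)}; }
    };

    struct Instr {
      std::string name;
      bool has_layout = true;  // stand-in for LayoutUtil::HasLayout(shape)
      std::vector<Instr*> operands;
    };

    // Stand-in for ShapeUtil::ValidateShape.
    Status ValidateShape(const Instr& i) {
      return i.has_layout
                 ? Status::OK()
                 : Status::Error("shape of '" + i.name + "' has no layout");
    }

    // Post-order walk: the Preprocess-style check gates *every* visited
    // instruction, so an invalid shape anywhere in the graph becomes an
    // error instead of a crash.
    Status Evaluate(Instr* root) {
      for (Instr* op : root->operands) {
        Status s = Evaluate(op);
        if (!s.ok()) return s;
      }
      Status pre = ValidateShape(*root);  // the Preprocess hook
      if (!pre.ok()) return pre;
      // ... evaluate `root` here ...
      return Status::OK();
    }

    int main() {
      Instr param{"param"};
      Instr add{"add", /*has_layout=*/false, {&param}};
      // Prints: shape of 'add' has no layout
      std::cout << Evaluate(&add).message << "\n";
    }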

tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h

@@ -1544,10 +1544,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
     auto result = absl::make_unique<Literal>(reduce->shape());
+    Status eval_status;
 
     // For each resulting dimension, calculate and assign computed value.
     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
           ReturnT result_val = init_scalar;
+          if (!eval_status.ok()) {
+            return result_val;
+          }
 
           std::vector<int64> base(arg_dimensions.size());
           for (int64 i = 0; i < multi_index.size(); ++i) {
@@ -1568,7 +1572,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
                 arg_dim_steps, func);
             return static_cast<ReturnT>(computed_result);
           }
-          auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+          auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
+              -> StatusOr<bool> {
             auto curr_val = arg_literal.Get<ReturnT>(input_index);
 
             // Evaluate computation with specified literal operands.
@@ -1576,12 +1581,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
             auto result_val_literal =
                 LiteralUtil::CreateR0<ReturnT>(result_val);
 
-            std::unique_ptr<Literal> computed_result =
-                embedded_evaluator
-                    .Evaluate<const Literal*>(
-                        *function,
-                        {result_val_literal.get(), curr_val_literal.get()})
-                    .ConsumeValueOrDie();
+            TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+                                embedded_evaluator.Evaluate<const Literal*>(
+                                    *function, {result_val_literal.get(),
+                                                curr_val_literal.get()}));
             // Clear visit states so that we can use the evaluator again on
             // the same computation.
             embedded_evaluator.ResetVisitStates();
@@ -1591,13 +1594,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
           };
 
           // Computes one element of the result, reducing all dimensions that
           // contribute to that element.
-          ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
-                                  arg_dim_steps, func);
+          eval_status = ShapeUtil::ForEachIndexWithStatus(
+              arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func);
           return result_val;
         }));
     parent_->evaluated_[reduce] = std::move(result);
-    return Status::OK();
+    return eval_status;
   }
 
   bool IsScalarAdd(HloComputation* computation) {