[XLA] Propagate invalid shape errors through reduce folding and turn it on
HloEvaluator should be stable enough for reduce folding, but it must not crash when it encounters an instruction without a layout. Validate the shape (including its layout) of every instruction that gets evaluated and return an error on failure.

PiperOrigin-RevId: 209641401
commit 4f41091f88
parent c61a49ec31
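Context for the diff below: Literal::Populate takes a callback that must return an element value, so the reduce handler cannot return a Status from inside it. Instead, the first failure is parked in a captured Status and returned after the loop. A minimal standalone sketch of that pattern, using absl::Status/absl::StatusOr and invented names (EvaluateOneElement, PopulateAll) rather than XLA's actual API:

#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for evaluating one element of a reduction; it can
// fail the way the HloEvaluator now fails on an invalid shape.
absl::StatusOr<int64_t> EvaluateOneElement(int64_t input) {
  if (input < 0) return absl::InvalidArgumentError("cannot evaluate element");
  return input * 2;
}

// Fills `values` in place. The per-element callback must return a plain
// value (as Literal::Populate's callback must), so the first failure is
// carried out of the loop through the captured `eval_status` and surfaced
// to the caller instead of crashing the process.
absl::Status PopulateAll(std::vector<int64_t>& values) {
  absl::Status eval_status = absl::OkStatus();
  auto compute = [&](int64_t input) -> int64_t {
    if (!eval_status.ok()) return 0;  // A failure was already recorded; bail.
    absl::StatusOr<int64_t> computed = EvaluateOneElement(input);
    if (!computed.ok()) {
      eval_status = computed.status();
      return 0;  // Dummy value; the caller sees eval_status instead.
    }
    return *computed;
  };
  for (int64_t& v : values) v = compute(v);
  return eval_status;  // OkStatus() when every element evaluated cleanly.
}

int main() {
  std::vector<int64_t> ok = {1, 2, 3};
  std::cout << PopulateAll(ok) << "\n";   // OK
  std::vector<int64_t> bad = {1, -2, 3};
  std::cout << PopulateAll(bad) << "\n";  // INVALID_ARGUMENT: cannot evaluate element
}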
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
@@ -2456,6 +2456,7 @@ tf_cc_test(
         ":hlo",
         ":hlo_constant_folding",
         ":hlo_matchers",
+        ":hlo_parser",
         ":hlo_pass",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -52,9 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
         computation->root_instruction() != instruction) {
       continue;
     }
-    // Skip Constant, Parameter, Reduce, and AfterAll operation.
-    // TODO(b/35975797): Enable Reduce operation once arbitrary computation
-    // are supported by the evaluator.
+    // Skip Constant, Parameter, and AfterAll operation.
     // TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
     // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
     // operand in which case constant folding will be impossible and this
@@ -62,7 +60,6 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
     if (instruction->opcode() == HloOpcode::kParameter ||
         instruction->opcode() == HloOpcode::kConstant ||
         instruction->opcode() == HloOpcode::kTuple ||
-        instruction->opcode() == HloOpcode::kReduce ||
         instruction->opcode() == HloOpcode::kAfterAll) {
      continue;
    }
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
@@ -202,5 +203,45 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
   EXPECT_TRUE(matched);
 }
 
+const char* const kConstantFoldReduce = R"(
+  HloModule ConstantFoldReduce
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = s32[] add(a, b)
+  }
+
+  ENTRY r {
+    x = s32[3] constant({1, 2, 3})
+    init = s32[] constant(0)
+    ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
+  })";
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(kConstantFoldReduce));
+  HloConstantFolding const_folder;
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
+
+  EXPECT_EQ(6, module->entry_computation()
+                   ->root_instruction()
+                   ->literal()
+                   .GetFirstElement<int32>());
+}
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(kConstantFoldReduce));
+  HloInstruction* add = module->computations().begin()->root_instruction();
+  LayoutUtil::ClearLayout(add->mutable_shape());
+  HloConstantFolding const_folder;
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  EXPECT_FALSE(result);
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -230,7 +230,6 @@ template <typename LiteralPtr>
 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
     HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) {
   TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
-  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
 
   evaluated_.clear();
   arg_literals_.clear();
@@ -267,7 +266,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
     return tensorflow::errors::FailedPrecondition(
         "Not all operands are constants.");
   }
-  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
 
   arg_literals_.clear();
   evaluated_.clear();
@@ -1285,7 +1283,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {
 
 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
   VLOG(2) << "About to visit HLO: " << hlo->ToString();
-  return Status::OK();
+  return ShapeUtil::ValidateShape(hlo->shape());
 }
 
 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
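Moving ShapeUtil::ValidateShape from the two Evaluate() entry points into Preprocess() means the check runs before every instruction visit, including instructions inside embedded computations (such as the reduce's `add`) that the entry points never see directly. A rough sketch of that control flow, again with absl types and hypothetical names rather than the evaluator's real interfaces:

#include <string>
#include <vector>

#include "absl/status/status.h"

// Hypothetical, pared-down instruction type.
struct Instruction {
  std::string name;
  bool has_valid_shape;
};

// Preprocess runs before every handler, so one ValidateShape-style check
// here covers every instruction the evaluation will ever visit.
absl::Status Preprocess(const Instruction& instr) {
  if (!instr.has_valid_shape) {
    return absl::InvalidArgumentError("shape without layout: " + instr.name);
  }
  return absl::OkStatus();
}

absl::Status EvaluateInPostOrder(const std::vector<Instruction>& post_order) {
  for (const Instruction& instr : post_order) {
    absl::Status s = Preprocess(instr);
    if (!s.ok()) return s;  // Error surfaces to the caller; no crash.
    // ... per-opcode handler and Postprocess would run here ...
  }
  return absl::OkStatus();
}

int main() {
  // The layout-less instruction is rejected before its handler runs.
  absl::Status s = EvaluateInPostOrder({{"add", true}, {"reduce", false}});
  // s.message() == "shape without layout: reduce"
}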
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1544,10 +1544,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
 
     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
     auto result = absl::make_unique<Literal>(reduce->shape());
+    Status eval_status;
     // For each resulting dimension, calculate and assign computed value.
     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
           ReturnT result_val = init_scalar;
+          if (!eval_status.ok()) {
+            return result_val;
+          }
 
           std::vector<int64> base(arg_dimensions.size());
           for (int64 i = 0; i < multi_index.size(); ++i) {
@@ -1568,7 +1572,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
                                     arg_dim_steps, func);
             return static_cast<ReturnT>(computed_result);
           }
-          auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+          auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
+              -> StatusOr<bool> {
             auto curr_val = arg_literal.Get<ReturnT>(input_index);
 
             // Evaluate computation with specified literal operands.
@@ -1576,12 +1581,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
             auto result_val_literal =
                 LiteralUtil::CreateR0<ReturnT>(result_val);
 
-            std::unique_ptr<Literal> computed_result =
-                embedded_evaluator
-                    .Evaluate<const Literal*>(
-                        *function,
-                        {result_val_literal.get(), curr_val_literal.get()})
-                    .ConsumeValueOrDie();
+            TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+                                embedded_evaluator.Evaluate<const Literal*>(
+                                    *function, {result_val_literal.get(),
+                                                curr_val_literal.get()}));
             // Clear visit states so that we can use the evaluator again on
             // the same computation.
             embedded_evaluator.ResetVisitStates();
@@ -1591,13 +1594,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
           };
           // Computes one element of the result, reducing all dimensions that
           // contribute to that element.
-          ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
-                                  arg_dim_steps, func);
+          eval_status = ShapeUtil::ForEachIndexWithStatus(
+              arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func);
           return result_val;
         }));
 
     parent_->evaluated_[reduce] = std::move(result);
-    return Status::OK();
+    return eval_status;
   }
 
   bool IsScalarAdd(HloComputation* computation) {