[XLA] Propagate invalid shape errors through reduce folding and turn it on
HloEvaluator should be stable enough for reduce folding, but it shouldn't crash when it encounters an instruction without a layout. Verify the layout on every instruction that gets evaluated and return an error on failure.

PiperOrigin-RevId: 209641401
commit 4f41091f88 (parent c61a49ec31)
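The shape-validation pattern at the heart of this change, sketched below as a minimal standalone example before the diff: rather than validating only at the Evaluate() entry points, a per-node Preprocess hook validates every node the evaluator visits, so a malformed shape surfaces as an error instead of a crash. Status, Node, ValidateShape, and Evaluator here are simplified stand-ins for illustration, not XLA's actual types.

    #include <string>
    #include <utility>
    #include <vector>

    struct Status {
      bool ok = true;
      std::string message;
      static Status OK() { return {}; }
      static Status Error(std::string m) { return {false, std::move(m)}; }
    };

    struct Node {
      bool has_layout = true;  // stand-in for a shape with a full layout
      std::vector<const Node*> operands;
    };

    // Stand-in for ShapeUtil::ValidateShape.
    Status ValidateShape(const Node& n) {
      if (!n.has_layout) return Status::Error("shape has no layout");
      return Status::OK();
    }

    class Evaluator {
     public:
      Status Evaluate(const Node& n) {
        for (const Node* op : n.operands) {
          Status s = Evaluate(*op);
          if (!s.ok) return s;  // propagate failures instead of crashing
        }
        // Mirrors HloEvaluator::Preprocess returning
        // ShapeUtil::ValidateShape(hlo->shape()): every visited node is
        // checked before its handler runs.
        Status s = Preprocess(n);
        if (!s.ok) return s;
        // ... per-node evaluation would go here ...
        return Status::OK();
      }

     private:
      Status Preprocess(const Node& n) { return ValidateShape(n); }
    };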
@@ -2456,6 +2456,7 @@ tf_cc_test(
         ":hlo",
         ":hlo_constant_folding",
         ":hlo_matchers",
+        ":hlo_parser",
         ":hlo_pass",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
|
@@ -52,9 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
         computation->root_instruction() != instruction) {
       continue;
     }
-    // Skip Constant, Parameter, Reduce, and AfterAll operation.
-    // TODO(b/35975797): Enable Reduce operation once arbitrary computation
-    // are supported by the evaluator.
+    // Skip Constant, Parameter, and AfterAll operation.
     // TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
     // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
     // operand in which case constant folding will be impossible and this
@@ -62,7 +60,6 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
     if (instruction->opcode() == HloOpcode::kParameter ||
         instruction->opcode() == HloOpcode::kConstant ||
         instruction->opcode() == HloOpcode::kTuple ||
-        instruction->opcode() == HloOpcode::kReduce ||
         instruction->opcode() == HloOpcode::kAfterAll) {
       continue;
     }
|
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
@@ -202,5 +203,45 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
   EXPECT_TRUE(matched);
 }

+const char* const kConstantFoldReduce = R"(
+  HloModule ConstantFoldReduce
+
+  add {
+    a = s32[] parameter(0)
+    b = s32[] parameter(1)
+    ROOT add = s32[] add(a, b)
+  }
+
+  ENTRY r {
+    x = s32[3] constant({1, 2, 3})
+    init = s32[] constant(0)
+    ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
+  })";
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(kConstantFoldReduce));
+  HloConstantFolding const_folder;
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  EXPECT_TRUE(result);
+
+  EXPECT_EQ(6, module->entry_computation()
+                   ->root_instruction()
+                   ->literal()
+                   .GetFirstElement<int32>());
+}
+
+TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(kConstantFoldReduce));
+  HloInstruction* add = module->computations().begin()->root_instruction();
+  LayoutUtil::ClearLayout(add->mutable_shape());
+  HloConstantFolding const_folder;
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  EXPECT_FALSE(result);
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
+
 }  // namespace
 }  // namespace xla
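The first test above checks that the reduce folds to the constant 1 + 2 + 3 = 6; the second clears the layout on the add computation's root, then expects the pass to report no change and to leave the reduce instruction in place rather than crashing.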
|
@@ -230,7 +230,6 @@ template <typename LiteralPtr>
 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
     HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) {
   TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
-  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

   evaluated_.clear();
   arg_literals_.clear();
@@ -267,7 +266,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
     return tensorflow::errors::FailedPrecondition(
         "Not all operands are constants.");
   }
-  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

   arg_literals_.clear();
   evaluated_.clear();
@@ -1285,7 +1283,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) {

 Status HloEvaluator::Preprocess(HloInstruction* hlo) {
   VLOG(2) << "About to visit HLO: " << hlo->ToString();
-  return Status::OK();
+  return ShapeUtil::ValidateShape(hlo->shape());
 }

 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
|
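Returning ShapeUtil::ValidateShape from Preprocess replaces the two entry-point checks deleted above: Preprocess runs for every instruction the evaluator visits, including those in nested computations run through embedded evaluators, so an invalid shape anywhere now aborts evaluation with an error.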
@@ -1544,10 +1544,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {

     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
     auto result = absl::make_unique<Literal>(reduce->shape());
+    Status eval_status;
     // For each resulting dimension, calculate and assign computed value.
     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
         [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
           ReturnT result_val = init_scalar;
+          if (!eval_status.ok()) {
+            return result_val;
+          }

           std::vector<int64> base(arg_dimensions.size());
           for (int64 i = 0; i < multi_index.size(); ++i) {
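Literal::Populate's per-element callback must return an element value, not a Status, so the handler cannot use TF_RETURN_IF_ERROR inside it; instead it captures eval_status by reference, bails out early once it is non-OK, and returns it after Populate completes (wired up in the hunks that follow). A minimal standalone sketch of that pattern, with simplified stand-ins for Status, Populate, and the per-element computation:

    #include <functional>
    #include <string>
    #include <vector>

    struct Status {
      bool ok = true;
      std::string message;
    };

    // Stand-in for Literal::Populate: the generator can only return the
    // element value, so it has no direct way to report a failure.
    void Populate(std::vector<int>& out, const std::function<int(int)>& gen) {
      for (int i = 0; i < static_cast<int>(out.size()); ++i) out[i] = gen(i);
    }

    // Stand-in for a fallible per-element computation, e.g. evaluating the
    // reduce's to_apply computation over one window of inputs.
    Status ComputeElement(int index, int* value) {
      *value = index * index;
      return Status{};
    }

    Status PopulateWithStatus(std::vector<int>& out) {
      Status eval_status;  // captured by reference, checked after Populate
      Populate(out, [&](int index) {
        int value = 0;
        if (!eval_status.ok) return value;  // already failed: do no more work
        eval_status = ComputeElement(index, &value);
        return value;  // dummy on failure; caller discards it via eval_status
      });
      return eval_status;  // surface the first failure to the caller
    }

This is why the reduce handler finishes with "return eval_status;" rather than Status::OK().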
@@ -1568,7 +1572,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
                                     arg_dim_steps, func);
             return static_cast<ReturnT>(computed_result);
           }
-          auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) {
+          auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index)
+              -> StatusOr<bool> {
             auto curr_val = arg_literal.Get<ReturnT>(input_index);

             // Evaluate computation with specified literal operands.
@@ -1576,12 +1581,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
             auto result_val_literal =
                 LiteralUtil::CreateR0<ReturnT>(result_val);

-            std::unique_ptr<Literal> computed_result =
-                embedded_evaluator
-                    .Evaluate<const Literal*>(
-                        *function,
-                        {result_val_literal.get(), curr_val_literal.get()})
-                    .ConsumeValueOrDie();
+            TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+                                embedded_evaluator.Evaluate<const Literal*>(
+                                    *function, {result_val_literal.get(),
+                                                curr_val_literal.get()}));
             // Clear visit states so that we can use the evaluator again on
             // the same computation.
             embedded_evaluator.ResetVisitStates();
@@ -1591,13 +1594,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
           };
           // Computes one element of the result, reducing all dimensions that
           // contribute to that element.
-          ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
-                                  arg_dim_steps, func);
+          eval_status = ShapeUtil::ForEachIndexWithStatus(
+              arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func);
           return result_val;
         }));

     parent_->evaluated_[reduce] = std::move(result);
-    return Status::OK();
+    return eval_status;
   }

 bool IsScalarAdd(HloComputation* computation) {
|