[XLA] More module proto verification added.

PiperOrigin-RevId: 216745236
This commit is contained in:
A. Unique TensorFlower 2018-10-11 12:53:55 -07:00 committed by TensorFlower Gardener
parent 55cf8c0db7
commit 2b279edb86
9 changed files with 54 additions and 18 deletions

View File

@ -57,10 +57,13 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
builder.AddInstruction(HloInstruction::CreateBinary(
sparse_shape, HloOpcode::kAdd, param0, param1));
auto module = CreateNewModule();
// Since verifier is reporting sparse layouts as errors, we should
// use a regular HloModule instead of VerifiedHloModule to avoid
// verifier errors being triggered in the destructor.
auto module = HloTestBase::CreateNewModule();
module->AddEntryComputation(builder.Build());
Status status = checker().Run(module).status();
Status status = checker().Run(module.get()).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("CPU backend does not support"));

View File

@ -57,10 +57,13 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
builder.AddInstruction(HloInstruction::CreateBinary(
sparse_shape, HloOpcode::kAdd, param0, param1));
auto module = CreateNewModule();
// Since verifier is reporting sparse layouts as errors, we should
// use a regular HloModule instead of VerifiedHloModule to avoid
// verifier errors being triggered in the destructor.
auto module = HloTestBase::CreateNewModule();
module->AddEntryComputation(builder.Build());
Status status = checker().Run(module).status();
Status status = checker().Run(module.get()).status();
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
EXPECT_THAT(status.error_message(),
HasSubstr("GPU backend does not support"));

View File

@ -2677,7 +2677,6 @@ Status HloInstruction::AcceptOrdered(
}
const Shape& HloInstruction::shape() const {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
return shape_;
}

View File

@ -42,7 +42,7 @@ StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(proto, module_config));
TF_RETURN_IF_ERROR(
HloVerifier(/*layout_sensitive=*/true, /*allow_mixed_precision=*/false)
HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
.Run(module.get())
.status());
return std::move(module);

View File

@ -469,6 +469,9 @@ absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
if (!IsTuple()) {
return *this;
}
if (tuple_elements_.empty()) {
return absl::nullopt;
}
for (int64 i = 1; i < tuple_elements_.size(); ++i) {
if (tuple_elements_[0] != tuple_elements_[i]) {
return absl::nullopt;

View File

@ -27,6 +27,14 @@ limitations under the License.
namespace xla {
Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
if (LayoutUtil::IsSparseArray(hlo->shape())) {
return InternalError("Sparse arrays are not yet fully supported: %s",
hlo->ToString());
}
return Status::OK();
}
static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
if (hlo->operand_count() != expected) {
return InternalError("Expected %d operands for %s instruction: %s",
@ -286,6 +294,10 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0));
if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
return InternalError("Constant is required to have a valid literal: %s",
constant->ToString());
}
return CheckShape(constant, constant->literal().shape());
}
@ -877,14 +889,21 @@ Status VerifyEntryAndExitShapes(const HloModule& module) {
Status CheckEntryComputationLayout(const HloModule& module) {
const HloComputation* computation = module.entry_computation();
const auto& layout = module.entry_computation_layout();
const ShapeLayout& result_layout = layout.result_layout();
if (LayoutUtil::IsSparseArray(result_layout.shape())) {
return Unimplemented(
"Sparse arrays are not yet fully supported in program result shape: %s",
ShapeUtil::HumanStringWithLayout(result_layout.shape()));
}
if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
layout.result_layout().shape())) {
result_layout.shape())) {
return InternalError(
"Shape of the root instruction of entry computation (%s) should be "
"compatible to one specified in module's entry computation layout (%s)",
ShapeUtil::HumanString(computation->root_instruction()->shape()),
ShapeUtil::HumanString(layout.result_layout().shape()));
ShapeUtil::HumanString(result_layout.shape()));
}
if (computation->num_parameters() != layout.parameter_count()) {
@ -895,15 +914,19 @@ Status CheckEntryComputationLayout(const HloModule& module) {
}
for (int i = 0; i < computation->num_parameters(); ++i) {
if (!ShapeUtil::Compatible(computation->parameter_instruction(i)->shape(),
layout.parameter_shape(i))) {
const HloInstruction* parameter = computation->parameter_instruction(i);
if (LayoutUtil::IsSparseArray(layout.parameter_shape(i))) {
return Unimplemented(
"Sparse arrays are not yet fully supported "
"in program parameter shape: %s",
ShapeUtil::HumanStringWithLayout(layout.parameter_shape(i)));
}
if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
return InternalError(
"Shape of the entry computation parameter %d is %s should be "
"compatible to the one specified in module's entry computation "
"layout %s",
i,
ShapeUtil::HumanString(
computation->parameter_instruction(i)->shape()),
i, ShapeUtil::HumanString(parameter->shape()),
ShapeUtil::HumanString(layout.parameter_shape(i)));
}
}

View File

@ -32,6 +32,8 @@ class ShapeVerifier : public DfsHloVisitor {
: layout_sensitive_(layout_sensitive),
allow_mixed_precision_(allow_mixed_precision) {}
Status Preprocess(HloInstruction* hlo) override;
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
Status HandleClamp(HloInstruction* clamp) override;

View File

@ -1523,6 +1523,10 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
// Execute extra verification step once the layout has been finalized.
TF_RETURN_IF_ERROR(Verify(instruction));
// Shape must be valid.
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
// Verify all layouts in the shape have been set.
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
}

View File

@ -345,8 +345,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
for (int64 i = 0; i < module_protos.size(); ++i) {
const HloModuleProto* proto = module_protos[i];
const HloModuleConfig& config = *module_configs[i];
TF_ASSIGN_OR_RETURN(auto module,
HloModule::CreateFromProto(*proto, config));
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
modules.push_back(std::move(module));
}
@ -810,7 +809,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(module_proto, *module_config));
CreateModuleFromProto(module_proto, *module_config));
TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
@ -1081,7 +1080,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(arg->computation(), config));
CreateModuleFromProto(arg->computation(), config));
HloEvaluator evaluator;
TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
@ -1118,7 +1117,7 @@ Status Service::GetComputationGraphStats(
HloModuleConfig config(arg->computation().program_shape());
config.set_debug_options(arg->debug_options());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(arg->computation(), config));
CreateModuleFromProto(arg->computation(), config));
hlo_graph_dumper::MaybeDumpHloModule(*module,
"computation statistics subject");