[XLA] More module proto verification added.
PiperOrigin-RevId: 216745236
This commit is contained in:
parent
55cf8c0db7
commit
2b279edb86
@ -57,10 +57,13 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
sparse_shape, HloOpcode::kAdd, param0, param1));
|
||||
auto module = CreateNewModule();
|
||||
// Since verifier is reporting sparse layouts as errors, we should
|
||||
// use a regular HloModule instead of VerifiedHloModule to avoid
|
||||
// verifier errors being triggered in the destructor.
|
||||
auto module = HloTestBase::CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
Status status = checker().Run(module).status();
|
||||
Status status = checker().Run(module.get()).status();
|
||||
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("CPU backend does not support"));
|
||||
|
@ -57,10 +57,13 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
|
||||
HloInstruction::CreateParameter(1, sparse_shape, "param1"));
|
||||
builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
sparse_shape, HloOpcode::kAdd, param0, param1));
|
||||
auto module = CreateNewModule();
|
||||
// Since verifier is reporting sparse layouts as errors, we should
|
||||
// use a regular HloModule instead of VerifiedHloModule to avoid
|
||||
// verifier errors being triggered in the destructor.
|
||||
auto module = HloTestBase::CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
Status status = checker().Run(module).status();
|
||||
Status status = checker().Run(module.get()).status();
|
||||
ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("GPU backend does not support"));
|
||||
|
@ -2677,7 +2677,6 @@ Status HloInstruction::AcceptOrdered(
|
||||
}
|
||||
|
||||
const Shape& HloInstruction::shape() const {
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
|
||||
return shape_;
|
||||
}
|
||||
|
||||
|
@ -42,7 +42,7 @@ StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
HloModule::CreateFromProto(proto, module_config));
|
||||
TF_RETURN_IF_ERROR(
|
||||
HloVerifier(/*layout_sensitive=*/true, /*allow_mixed_precision=*/false)
|
||||
HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
|
||||
.Run(module.get())
|
||||
.status());
|
||||
return std::move(module);
|
||||
|
@ -469,6 +469,9 @@ absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const {
|
||||
if (!IsTuple()) {
|
||||
return *this;
|
||||
}
|
||||
if (tuple_elements_.empty()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
for (int64 i = 1; i < tuple_elements_.size(); ++i) {
|
||||
if (tuple_elements_[0] != tuple_elements_[i]) {
|
||||
return absl::nullopt;
|
||||
|
@ -27,6 +27,14 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
|
||||
if (LayoutUtil::IsSparseArray(hlo->shape())) {
|
||||
return InternalError("Sparse arrays are not yet fully supported: %s",
|
||||
hlo->ToString());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
|
||||
if (hlo->operand_count() != expected) {
|
||||
return InternalError("Expected %d operands for %s instruction: %s",
|
||||
@ -286,6 +294,10 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
|
||||
|
||||
Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
|
||||
TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0));
|
||||
if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
|
||||
return InternalError("Constant is required to have a valid literal: %s",
|
||||
constant->ToString());
|
||||
}
|
||||
return CheckShape(constant, constant->literal().shape());
|
||||
}
|
||||
|
||||
@ -877,14 +889,21 @@ Status VerifyEntryAndExitShapes(const HloModule& module) {
|
||||
Status CheckEntryComputationLayout(const HloModule& module) {
|
||||
const HloComputation* computation = module.entry_computation();
|
||||
const auto& layout = module.entry_computation_layout();
|
||||
const ShapeLayout& result_layout = layout.result_layout();
|
||||
|
||||
if (LayoutUtil::IsSparseArray(result_layout.shape())) {
|
||||
return Unimplemented(
|
||||
"Sparse arrays are not yet fully supported in program result shape: %s",
|
||||
ShapeUtil::HumanStringWithLayout(result_layout.shape()));
|
||||
}
|
||||
|
||||
if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
|
||||
layout.result_layout().shape())) {
|
||||
result_layout.shape())) {
|
||||
return InternalError(
|
||||
"Shape of the root instruction of entry computation (%s) should be "
|
||||
"compatible to one specified in module's entry computation layout (%s)",
|
||||
ShapeUtil::HumanString(computation->root_instruction()->shape()),
|
||||
ShapeUtil::HumanString(layout.result_layout().shape()));
|
||||
ShapeUtil::HumanString(result_layout.shape()));
|
||||
}
|
||||
|
||||
if (computation->num_parameters() != layout.parameter_count()) {
|
||||
@ -895,15 +914,19 @@ Status CheckEntryComputationLayout(const HloModule& module) {
|
||||
}
|
||||
|
||||
for (int i = 0; i < computation->num_parameters(); ++i) {
|
||||
if (!ShapeUtil::Compatible(computation->parameter_instruction(i)->shape(),
|
||||
layout.parameter_shape(i))) {
|
||||
const HloInstruction* parameter = computation->parameter_instruction(i);
|
||||
if (LayoutUtil::IsSparseArray(layout.parameter_shape(i))) {
|
||||
return Unimplemented(
|
||||
"Sparse arrays are not yet fully supported "
|
||||
"in program parameter shape: %s",
|
||||
ShapeUtil::HumanStringWithLayout(layout.parameter_shape(i)));
|
||||
}
|
||||
if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
|
||||
return InternalError(
|
||||
"Shape of the entry computation parameter %d is %s should be "
|
||||
"compatible to the one specified in module's entry computation "
|
||||
"layout %s",
|
||||
i,
|
||||
ShapeUtil::HumanString(
|
||||
computation->parameter_instruction(i)->shape()),
|
||||
i, ShapeUtil::HumanString(parameter->shape()),
|
||||
ShapeUtil::HumanString(layout.parameter_shape(i)));
|
||||
}
|
||||
}
|
||||
|
@ -32,6 +32,8 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
: layout_sensitive_(layout_sensitive),
|
||||
allow_mixed_precision_(allow_mixed_precision) {}
|
||||
|
||||
Status Preprocess(HloInstruction* hlo) override;
|
||||
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo) override;
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo) override;
|
||||
Status HandleClamp(HloInstruction* clamp) override;
|
||||
|
@ -1523,6 +1523,10 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
|
||||
// Execute extra verification step once the layout has been finalized.
|
||||
TF_RETURN_IF_ERROR(Verify(instruction));
|
||||
|
||||
// Shape must be valid.
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
|
||||
|
||||
// Verify all layouts in the shape have been set.
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
|
||||
}
|
||||
|
@ -345,8 +345,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
for (int64 i = 0; i < module_protos.size(); ++i) {
|
||||
const HloModuleProto* proto = module_protos[i];
|
||||
const HloModuleConfig& config = *module_configs[i];
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
HloModule::CreateFromProto(*proto, config));
|
||||
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
|
||||
modules.push_back(std::move(module));
|
||||
}
|
||||
|
||||
@ -810,7 +809,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
HloModule::CreateFromProto(module_proto, *module_config));
|
||||
CreateModuleFromProto(module_proto, *module_config));
|
||||
|
||||
TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
|
||||
|
||||
@ -1081,7 +1080,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
|
||||
HloModuleConfig config(program_shape);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
HloModule::CreateFromProto(arg->computation(), config));
|
||||
CreateModuleFromProto(arg->computation(), config));
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
|
||||
@ -1118,7 +1117,7 @@ Status Service::GetComputationGraphStats(
|
||||
HloModuleConfig config(arg->computation().program_shape());
|
||||
config.set_debug_options(arg->debug_options());
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
HloModule::CreateFromProto(arg->computation(), config));
|
||||
CreateModuleFromProto(arg->computation(), config));
|
||||
|
||||
hlo_graph_dumper::MaybeDumpHloModule(*module,
|
||||
"computation statistics subject");
|
||||
|
Loading…
Reference in New Issue
Block a user