Move ParseHloString to VerifiedHloModule.
To make this possible, add HloParser interface which only allows VerifiedHloModule to instantiate a parser and make the private HloParser class a HloParserImpl child class. Also use std::string instead of string in hlo_parser.cc. Finally add the verification at the end of the parsing instead of calling Verify() at various call sites. PiperOrigin-RevId: 276055152 Change-Id: I647e7a14e1ff9ae0aa1ba764af8718c753226e6b
This commit is contained in:
parent
c474877e2b
commit
07aae20777
File diff suppressed because it is too large
Load Diff
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
@ -44,11 +47,6 @@ StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
|
||||
StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
|
||||
absl::string_view str);
|
||||
|
||||
// Given a string in the HloModule::ToString() format, parses the string and
|
||||
// builds the HloModule in place at the given module pointer. 'module' must
|
||||
// point to an empty module (no computations).
|
||||
Status ParseHloString(absl::string_view str, HloModule* module);
|
||||
|
||||
// Parses sharding from str. str is supposed to contain the body of the
|
||||
// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g.,
|
||||
// "{replicated}".
|
||||
@ -85,6 +83,19 @@ StatusOr<Shape> ParseShape(absl::string_view str);
|
||||
StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
|
||||
absl::string_view str);
|
||||
|
||||
class HloParser {
|
||||
public:
|
||||
// Runs the parser and constructs the resulting HLO in the given (empty)
|
||||
// HloModule. Returns the error status in case an error occurred.
|
||||
virtual Status Run(HloModule* module) = 0;
|
||||
virtual ~HloParser() {}
|
||||
|
||||
private:
|
||||
static std::unique_ptr<HloParser> CreateHloParserForTests(
|
||||
absl::string_view str);
|
||||
friend class VerifiedHloModule;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
|
||||
|
@ -1707,8 +1707,7 @@ class HloParameterizedParserTest
|
||||
/*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||
ShapeUtil::ByteSizeOfElements);
|
||||
TF_ASSERT_OK(ParseHloString(original, verified_module.get()));
|
||||
TF_ASSERT_OK(verified_module->Verify());
|
||||
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
|
||||
module = std::move(verified_module);
|
||||
} else {
|
||||
TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
|
||||
@ -1768,8 +1767,7 @@ class HloParserTest : public ::testing::Test {
|
||||
/*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||
ShapeUtil::ByteSizeOfElements);
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||
TF_RETURN_IF_ERROR(module->Verify());
|
||||
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
|
||||
return std::move(module);
|
||||
}
|
||||
};
|
||||
|
@ -67,8 +67,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) {
|
||||
"Conv", hlo_config, /*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||
/*shape_size_function=*/ShapeUtil::ByteSizeOfElements);
|
||||
TF_CHECK_OK(xla::ParseHloString(hlo_text, &hlo_module));
|
||||
TF_CHECK_OK(hlo_module.Verify());
|
||||
TF_CHECK_OK(hlo_module.ParseHloStringAndVerifyModule(hlo_text));
|
||||
xla::HloInstruction* conv =
|
||||
hlo_module.entry_computation()->root_instruction();
|
||||
|
||||
|
@ -46,8 +46,7 @@ ENTRY entry {
|
||||
"TupleUtilTest", HloModuleConfig(), /*verifier_layout_sensitive=*/true,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/false,
|
||||
ShapeUtil::ByteSizeOfElements);
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_string, module.get()));
|
||||
TF_RETURN_IF_ERROR(module->Verify());
|
||||
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_string));
|
||||
|
||||
*entry_computation = module->entry_computation();
|
||||
*param0 = (*entry_computation)->parameter_instruction(0);
|
||||
|
@ -114,10 +114,12 @@ cc_library(
|
||||
hdrs = ["verified_hlo_module.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -136,8 +136,7 @@ HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
|
||||
TestName(), config, verifier_layout_sensitive_,
|
||||
allow_mixed_precision_in_hlo_verifier_,
|
||||
backend().compiler()->ShapeSizeBytesFunction());
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||
TF_RETURN_IF_ERROR(module->Verify());
|
||||
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
|
@ -223,8 +223,7 @@ LocalClientTestBase::ParseAndReturnVerifiedModule(
|
||||
TestName(), config, /*verifier_layout_sensitive=*/false,
|
||||
/*allow_mixed_precision_in_hlo_verifier=*/true,
|
||||
local_client_->backend().compiler()->ShapeSizeBytesFunction());
|
||||
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
|
||||
TF_RETURN_IF_ERROR(module->Verify());
|
||||
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
|
@ -14,25 +14,26 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
Status VerifiedHloModule::Verify() {
|
||||
if (computation_count() == 0) {
|
||||
// The computation was never built. Nothing to verify.
|
||||
return Status::OK();
|
||||
}
|
||||
return verifier_.Run(this).status();
|
||||
Status VerifiedHloModule::ParseHloStringAndVerifyModule(absl::string_view str) {
|
||||
TF_RET_CHECK(computation_count() == 0);
|
||||
auto parser = HloParser::CreateHloParserForTests(str);
|
||||
TF_RETURN_IF_ERROR(parser->Run(this));
|
||||
return Verify();
|
||||
}
|
||||
|
||||
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
|
||||
void VerifiedHloModule::VerifyOrAddFailure(absl::string_view message) {
|
||||
Status status = Verify();
|
||||
if (!status.ok()) {
|
||||
ADD_FAILURE() << "HloVerifier failed on module " << name()
|
||||
@ -43,4 +44,12 @@ void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
|
||||
}
|
||||
}
|
||||
|
||||
Status VerifiedHloModule::Verify() {
|
||||
if (computation_count() == 0) {
|
||||
// The computation was never built. Nothing to verify.
|
||||
return Status::OK();
|
||||
}
|
||||
return verifier_.Run(this).status();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
@ -43,14 +43,20 @@ class VerifiedHloModule : public HloModule {
|
||||
|
||||
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
|
||||
|
||||
// Verifies the module using HloVerifier and returns the status.
|
||||
Status Verify();
|
||||
// Given a string in the HloModule::ToString() format, parses the string and
|
||||
// builds the VerifiedHloModule in place. Before calling this method, the
|
||||
// module must be empty (no computations). Finally verifies the module using
|
||||
// HloVerifier and returns the status.
|
||||
Status ParseHloStringAndVerifyModule(absl::string_view str);
|
||||
|
||||
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
|
||||
// included in the failure message.
|
||||
void VerifyOrAddFailure(const string& message);
|
||||
void VerifyOrAddFailure(absl::string_view message);
|
||||
|
||||
private:
|
||||
// Verifies the module using HloVerifier and returns the status.
|
||||
Status Verify();
|
||||
|
||||
HloVerifier verifier_;
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user