Move ParseHloString to VerifiedHloModule.

To make this possible, add HloParser interface which only allows
VerifiedHloModule to instantiate a parser and make the private HloParser
class a HloParserImpl child class.
Also use std::string instead of string in hlo_parser.cc.
Finally add the verification at the end of the parsing instead of calling
Verify() at various call sites.

PiperOrigin-RevId: 276055152
Change-Id: I647e7a14e1ff9ae0aa1ba764af8718c753226e6b
This commit is contained in:
Adrian Kuegel 2019-10-22 06:50:02 -07:00 committed by TensorFlower Gardener
parent c474877e2b
commit 07aae20777
10 changed files with 299 additions and 259 deletions

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -44,11 +47,6 @@ StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
absl::string_view str);
// Given a string in the HloModule::ToString() format, parses the string and
// builds the HloModule in place at the given module pointer. 'module' must
// point to an empty module (no computations).
Status ParseHloString(absl::string_view str, HloModule* module);
// Parses sharding from str. str is supposed to contain the body of the
// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g.,
// "{replicated}".
@ -85,6 +83,19 @@ StatusOr<Shape> ParseShape(absl::string_view str);
StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
absl::string_view str);
class HloParser {
public:
// Runs the parser and constructs the resulting HLO in the given (empty)
// HloModule. Returns the error status in case an error occurred.
virtual Status Run(HloModule* module) = 0;
virtual ~HloParser() {}
private:
static std::unique_ptr<HloParser> CreateHloParserForTests(
absl::string_view str);
friend class VerifiedHloModule;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_

View File

@ -1707,8 +1707,7 @@ class HloParameterizedParserTest
/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(ParseHloString(original, verified_module.get()));
TF_ASSERT_OK(verified_module->Verify());
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
module = std::move(verified_module);
} else {
TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
@ -1768,8 +1767,7 @@ class HloParserTest : public ::testing::Test {
/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
ShapeUtil::ByteSizeOfElements);
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
return std::move(module);
}
};

View File

@ -67,8 +67,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) {
"Conv", hlo_config, /*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
/*shape_size_function=*/ShapeUtil::ByteSizeOfElements);
TF_CHECK_OK(xla::ParseHloString(hlo_text, &hlo_module));
TF_CHECK_OK(hlo_module.Verify());
TF_CHECK_OK(hlo_module.ParseHloStringAndVerifyModule(hlo_text));
xla::HloInstruction* conv =
hlo_module.entry_computation()->root_instruction();

View File

@ -46,8 +46,7 @@ ENTRY entry {
"TupleUtilTest", HloModuleConfig(), /*verifier_layout_sensitive=*/true,
/*allow_mixed_precision_in_hlo_verifier=*/false,
ShapeUtil::ByteSizeOfElements);
TF_RETURN_IF_ERROR(ParseHloString(hlo_string, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_string));
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);

View File

@ -114,10 +114,12 @@ cc_library(
hdrs = ["verified_hlo_module.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
"//tensorflow/core:test",

View File

@ -136,8 +136,7 @@ HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text,
TestName(), config, verifier_layout_sensitive_,
allow_mixed_precision_in_hlo_verifier_,
backend().compiler()->ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
return std::move(module);
}

View File

@ -223,8 +223,7 @@ LocalClientTestBase::ParseAndReturnVerifiedModule(
TestName(), config, /*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
local_client_->backend().compiler()->ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
TF_RETURN_IF_ERROR(module->Verify());
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
return std::move(module);
}

View File

@ -14,25 +14,26 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
Status VerifiedHloModule::Verify() {
if (computation_count() == 0) {
// The computation was never built. Nothing to verify.
return Status::OK();
}
return verifier_.Run(this).status();
Status VerifiedHloModule::ParseHloStringAndVerifyModule(absl::string_view str) {
TF_RET_CHECK(computation_count() == 0);
auto parser = HloParser::CreateHloParserForTests(str);
TF_RETURN_IF_ERROR(parser->Run(this));
return Verify();
}
void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
void VerifiedHloModule::VerifyOrAddFailure(absl::string_view message) {
Status status = Verify();
if (!status.ok()) {
ADD_FAILURE() << "HloVerifier failed on module " << name()
@ -43,4 +44,12 @@ void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
}
}
Status VerifiedHloModule::Verify() {
if (computation_count() == 0) {
// The computation was never built. Nothing to verify.
return Status::OK();
}
return verifier_.Run(this).status();
}
} // namespace xla

View File

@ -16,8 +16,8 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_
#include <functional>
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
@ -43,14 +43,20 @@ class VerifiedHloModule : public HloModule {
~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
// Verifies the module using HloVerifier and returns the status.
Status Verify();
// Given a string in the HloModule::ToString() format, parses the string and
// builds the VerifiedHloModule in place. Before calling this method, the
// module must be empty (no computations). Finally verifies the module using
// HloVerifier and returns the status.
Status ParseHloStringAndVerifyModule(absl::string_view str);
// Verifies the module and flags any error with ADD_FAILURE. 'message' is
// included in the failure message.
void VerifyOrAddFailure(const string& message);
void VerifyOrAddFailure(absl::string_view message);
private:
// Verifies the module using HloVerifier and returns the status.
Status Verify();
HloVerifier verifier_;
};