Open source hlo_module_loader.
This is in preparation of open sourcing another tool. PiperOrigin-RevId: 283927480 Change-Id: I0f38a0e6a1fcdded1b0e1c28ff62d07e51bb1cc9
This commit is contained in:
parent
2c2f30c7d4
commit
d364d465d7
@ -252,3 +252,30 @@ sh_test(
|
||||
srcs = ["interactive_graphviz_test.sh"],
|
||||
data = [":interactive_graphviz"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_module_loader",
|
||||
srcs = ["hlo_module_loader.cc"],
|
||||
hdrs = ["hlo_module_loader.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_protobuf//:protobuf_headers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hlo_module_loader_test",
|
||||
srcs = ["hlo_module_loader_test.cc"],
|
||||
deps = [
|
||||
":hlo_module_loader",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
125
tensorflow/compiler/xla/tools/hlo_module_loader.cc
Normal file
125
tensorflow/compiler/xla/tools/hlo_module_loader.cc
Normal file
@ -0,0 +1,125 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Emits an HLO module in a text form suitable for diffing.
|
||||
|
||||
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
Status OverrideConfig(const hlo_module_loader_details::Config& ovr_config,
|
||||
HloModuleConfig* config) {
|
||||
config->set_replica_count(ovr_config.num_replicas);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
string StripLogHeaders(const string& hlo_string) {
|
||||
// I0521 12:04:45.883483 1509 service.cc:186] ...
|
||||
static RE2* matcher = new RE2(
|
||||
"[IWEF]\\d{4} "
|
||||
"\\d{2}:\\d{2}:\\d{2}\\.\\d+\\s+\\d+\\s+[^:]+:\\d+\\]\\s?(.*)");
|
||||
absl::string_view matches[4];
|
||||
std::vector<string> lines = absl::StrSplit(hlo_string, '\n');
|
||||
for (auto& line : lines) {
|
||||
if (matcher->Match(line, 0, line.size(), RE2::ANCHOR_START, matches, 4)) {
|
||||
line = string(matches[1]);
|
||||
}
|
||||
}
|
||||
return absl::StrJoin(lines, "\n", [](string* out, const string& line) {
|
||||
absl::StrAppend(out, line);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
|
||||
const string& data, const string& format,
|
||||
hlo_module_loader_details::Config ovr_config,
|
||||
const std::function<void(HloModuleConfig*)>& config_modifier_hook) {
|
||||
DebugOptions debug_options = GetDebugOptionsFromFlags();
|
||||
std::unique_ptr<HloModule> module;
|
||||
if (format == "hlo" || format == "txt") {
|
||||
string hlo_string = StripLogHeaders(data);
|
||||
HloModuleConfig config;
|
||||
config.set_debug_options(debug_options);
|
||||
TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config));
|
||||
if (config_modifier_hook) {
|
||||
config_modifier_hook(&config);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(module,
|
||||
ParseAndReturnUnverifiedModule(hlo_string, config));
|
||||
} else {
|
||||
HloSnapshot proto;
|
||||
if (format == "pb") {
|
||||
if (!proto.ParseFromString(data) &&
|
||||
!proto.mutable_hlo()->ParseFromString(data)) {
|
||||
return InvalidArgument("Failed to parse input as HLO protobuf binary");
|
||||
}
|
||||
} else if (format == "pbtxt") {
|
||||
if (!proto2::TextFormat::ParseFromString(data, &proto) &&
|
||||
!proto2::TextFormat::ParseFromString(data, proto.mutable_hlo())) {
|
||||
return InvalidArgument("Failed to parse input as HLO protobuf text");
|
||||
}
|
||||
} else {
|
||||
return InvalidArgument(
|
||||
"Invalid format from file extension: '%s'. Expected: hlo, txt, pb, "
|
||||
"or pbtxt",
|
||||
format);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(HloModuleConfig config,
|
||||
HloModule::CreateModuleConfigFromProto(
|
||||
proto.hlo().hlo_module(), debug_options));
|
||||
TF_RETURN_IF_ERROR(OverrideConfig(ovr_config, &config));
|
||||
if (config_modifier_hook) {
|
||||
config_modifier_hook(&config);
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
module, HloModule::CreateFromProto(proto.hlo().hlo_module(), config));
|
||||
}
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
|
||||
const string& path, hlo_module_loader_details::Config ovr_config,
|
||||
string format,
|
||||
const std::function<void(HloModuleConfig*)>& config_modifier_hook) {
|
||||
string data;
|
||||
if (format.empty()) {
|
||||
format = string(tensorflow::io::Extension(path));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &data));
|
||||
return LoadModuleFromData(data, format, ovr_config, config_modifier_hook);
|
||||
}
|
||||
|
||||
} // namespace xla
|
79
tensorflow/compiler/xla/tools/hlo_module_loader.h
Normal file
79
tensorflow/compiler/xla/tools/hlo_module_loader.h
Normal file
@ -0,0 +1,79 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_HLO_MODULE_LOADER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_TOOLS_HLO_MODULE_LOADER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
namespace hlo_module_loader_details {
|
||||
|
||||
struct Config {
|
||||
Config() {}
|
||||
int64 num_replicas = 1;
|
||||
};
|
||||
|
||||
} // namespace hlo_module_loader_details
|
||||
|
||||
// Given a string composed by multiple lines, strip the log headers, if present
|
||||
// at the beginning of each line.
|
||||
string StripLogHeaders(const string& hlo_string);
|
||||
|
||||
// Loads an HLO module from a string.
|
||||
// The data can have the followings formats:
|
||||
// 1) A binary of text proto file, the proto should be in xla.HloProto type. It
|
||||
// can be a binary proto (format must be "pb"), or a text proto (format must
|
||||
// be "pbtxt").
|
||||
// 2) A hlo text dump, the string should be in HloModule::ToString() format
|
||||
// (format must be "txt" or "hlo"). The input data can also contain log
|
||||
// headers, which will be stripped.
|
||||
// The ovr_config data can be used to override certain fields of the
|
||||
// HloModuleConfig.
|
||||
// The HloModuleConfig is passed to config_modifier_hook for custom
|
||||
// modifications before use.
|
||||
StatusOr<std::unique_ptr<HloModule>> LoadModuleFromData(
|
||||
const string& data, const string& format,
|
||||
hlo_module_loader_details::Config ovr_config =
|
||||
hlo_module_loader_details::Config(),
|
||||
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {});
|
||||
|
||||
// Loads an HLO module from file.
|
||||
// The file can be one of the followings:
|
||||
// 1) A binary of text proto file, the proto should be in xla.HloProto type. It
|
||||
// can be a binary proto (with .pb extension), or a text proto (with a .pbtxt
|
||||
// extension).
|
||||
// 2) A hlo text dump, the string should be in HloModule::ToString() format
|
||||
// (with a .hlo or .txt extension). A text file can also contain log headers,
|
||||
// which will be stripped.
|
||||
// If the format is specified (not empty), it overrides the one guessed from the
|
||||
// file extension. The ovr_config data can be used to override certain fields of
|
||||
// the HloModuleConfig.
|
||||
// The HloModuleConfig is passed to config_modifier_hook for custom
|
||||
// modifications before use.
|
||||
StatusOr<std::unique_ptr<HloModule>> LoadModuleFromFile(
|
||||
const string& path,
|
||||
hlo_module_loader_details::Config ovr_config =
|
||||
hlo_module_loader_details::Config(),
|
||||
string format = "",
|
||||
const std::function<void(HloModuleConfig*)>& config_modifier_hook = {});
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_HLO_MODULE_LOADER_H_
|
48
tensorflow/compiler/xla/tools/hlo_module_loader_test.cc
Normal file
48
tensorflow/compiler/xla/tools/hlo_module_loader_test.cc
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class HloModuleLoaderTest : public HloTestBase {};
|
||||
|
||||
TEST_F(HloModuleLoaderTest, StripsLogHeaders) {
|
||||
const string& hlo_string = R"(
|
||||
I0521 12:04:45.883483 1509 service.cc:186] HloModule test_log_stripping
|
||||
I0521 12:04:45.883483 1509 service.cc:186]
|
||||
I0521 12:04:45.883483 1509 service.cc:186] ENTRY entry {
|
||||
I0521 12:04:45.883483 1509 service.cc:186] p0 = f32[4]{0} parameter(0)
|
||||
I0521 12:04:45.883483 1509 service.cc:186] p1 = f32[4]{0} parameter(1)
|
||||
I0521 12:04:45.883483 1509 service.cc:186] add = f32[4]{0} add(p0, p1)
|
||||
I0521 12:04:45.883483 1509 service.cc:186] ROOT rooty = (f32[4]{0}, f32[4]{0}) tuple(p1, add)
|
||||
I0521 12:04:45.883483 1509 service.cc:186] }
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
|
||||
LoadModuleFromData(hlo_string, "txt"));
|
||||
EXPECT_NE(FindInstruction(hlo_module.get(), "p0"), nullptr);
|
||||
EXPECT_NE(FindInstruction(hlo_module.get(), "p1"), nullptr);
|
||||
EXPECT_NE(FindInstruction(hlo_module.get(), "add"), nullptr);
|
||||
EXPECT_NE(FindInstruction(hlo_module.get(), "rooty"), nullptr);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user