STT-tensorflow/tensorflow/compiler/xla/service/hlo_runner_interface.cc
A. Unique TensorFlower d185bebfe7 Added support to visualise HloModule Protobufs
PiperOrigin-RevId: 350723449
Change-Id: I25bd48e5316ec1124523244cbad52d82d64f1568
2021-01-08 02:06:03 -08:00

118 lines
4.5 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string,
const DebugOptions& debug_options) {
HloModuleConfig config;
config.set_debug_options(debug_options);
return ParseAndReturnUnverifiedModule(hlo_string, config);
}
namespace {
// Creates an HloModule from the given proto.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
const HloProto& proto, const DebugOptions& debug_options) {
TF_ASSIGN_OR_RETURN(HloModuleConfig config,
HloModule::CreateModuleConfigFromProto(proto.hlo_module(),
debug_options));
TF_ASSIGN_OR_RETURN(auto module,
HloModule::CreateFromProto(proto.hlo_module(), config));
return std::move(module);
}
} // namespace
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromBinaryProtoFile(
const std::string& filename, const DebugOptions& debug_options) {
HloProto proto;
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
filename, &proto));
return HloProtoToModule(proto, debug_options);
}
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromTextProtoFile(
const std::string& filename, const DebugOptions& debug_options) {
HloProto proto;
TF_RETURN_IF_ERROR(
tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
return HloProtoToModule(proto, debug_options);
}
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromHloTextFile(
const std::string& filename, const DebugOptions& debug_options) {
string hlo_string;
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
filename, &hlo_string));
HloModuleConfig config;
config.set_debug_options(debug_options);
return ParseAndReturnUnverifiedModule(hlo_string, config);
}
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromModuleBinaryProtofile(
const std::string& filename, const DebugOptions& debug_options) {
HloModuleProto module_proto;
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
filename, &module_proto));
TF_ASSIGN_OR_RETURN(
HloModuleConfig module_config,
HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
return HloModule::CreateFromProto(module_proto, module_config);
}
StatusOr<Literal> HloRunnerInterface::Execute(
std::unique_ptr<HloModule> module, absl::Span<const Literal> arguments,
bool run_hlo_passes, ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
argument_pointers.push_back(&argument);
}
return Execute(
/*module=*/std::move(module),
/*arguments=*/argument_pointers,
/*run_hlo_passes=*/run_hlo_passes,
/*profile=*/profile);
}
StatusOr<Literal> HloRunnerInterface::ExecuteWithExecutable(
std::unique_ptr<Executable> executable, absl::Span<const Literal> arguments,
ExecutionProfile* profile) {
// Construct a vector of plain pointers for the arguments.
std::vector<const Literal*> argument_pointers;
argument_pointers.reserve(arguments.size());
for (const auto& argument : arguments) {
argument_pointers.push_back(&argument);
}
return ExecuteWithExecutable(std::move(executable), argument_pointers,
nullptr);
}
} // namespace xla