STT-tensorflow/tensorflow/compiler/xla/service/hlo_runner_interface.cc
A. Unique TensorFlower c0141706a4 Internal build rule change.
PiperOrigin-RevId: 345134366
Change-Id: Idfa2b6983b3a9aaeaaa2db4a8e62b73c2533bf0c
2020-12-01 17:46:15 -08:00

104 lines
3.9 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_runner_interface.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
namespace xla {
// Parses `hlo_string` (HLO text) into an *unverified* HloModule.
// The module's config is a default HloModuleConfig with `debug_options`
// installed. Any error status from ParseAndReturnUnverifiedModule is returned
// to the caller unchanged.
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::CreateModuleFromString(const absl::string_view hlo_string,
                                           const DebugOptions& debug_options) {
  HloModuleConfig module_config;
  module_config.set_debug_options(debug_options);
  return ParseAndReturnUnverifiedModule(hlo_string, module_config);
}
namespace {

// Converts `proto.hlo_module()` into an HloModule.
// The module config is first derived from the module proto together with
// `debug_options`, and the module is then built from the same module proto
// using that config. An error from either step is returned to the caller.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
    const HloProto& proto, const DebugOptions& debug_options) {
  const auto& module_proto = proto.hlo_module();
  TF_ASSIGN_OR_RETURN(
      HloModuleConfig module_config,
      HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
  // CreateFromProto already yields the StatusOr we return, so hand it back
  // directly rather than unwrapping and re-wrapping the unique_ptr.
  return HloModule::CreateFromProto(module_proto, module_config);
}

}  // namespace
// Loads a binary-serialized HloProto from `filename` through the default
// tensorflow::Env. The proto is then converted into an HloModule whose config
// carries `debug_options`. A failure to read the file or to convert the proto
// is returned as a non-OK status.
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromBinaryProtoFile(
    const std::string& filename, const DebugOptions& debug_options) {
  tensorflow::Env* const env = tensorflow::Env::Default();
  HloProto hlo_proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(env, filename, &hlo_proto));
  return HloProtoToModule(hlo_proto, debug_options);
}
// Loads a text-format HloProto from `filename` through the default
// tensorflow::Env. The proto is then converted into an HloModule whose config
// carries `debug_options`. A failure to read the file or to convert the proto
// is returned as a non-OK status.
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromTextProtoFile(
    const std::string& filename, const DebugOptions& debug_options) {
  tensorflow::Env* const env = tensorflow::Env::Default();
  HloProto hlo_proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadTextProto(env, filename, &hlo_proto));
  return HloProtoToModule(hlo_proto, debug_options);
}
// Reads HLO text from `filename` through the default tensorflow::Env and parses
// it into an *unverified* HloModule configured with `debug_options`.
// Parsing is delegated to CreateModuleFromString, so both HLO-text entry
// points share a single config-building path. A failure to read the file or
// to parse its contents is returned as a non-OK status.
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunnerInterface::ReadModuleFromHloTextFile(
    const std::string& filename, const DebugOptions& debug_options) {
  std::string hlo_string;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  filename, &hlo_string));
  return CreateModuleFromString(hlo_string, debug_options);
}
// Convenience overload: runs `module` on a span of Literals by forwarding to
// the overload that takes `const Literal*` arguments.
// The pointers built here borrow from `arguments`, which stays alive for the
// duration of the forwarded call. `run_hlo_passes` and `profile` are passed
// through unchanged.
StatusOr<Literal> HloRunnerInterface::Execute(
    std::unique_ptr<HloModule> module, absl::Span<const Literal> arguments,
    bool run_hlo_passes, ExecutionProfile* profile) {
  std::vector<const Literal*> literal_ptrs;
  literal_ptrs.reserve(arguments.size());
  for (size_t i = 0; i < arguments.size(); ++i) {
    literal_ptrs.push_back(&arguments[i]);
  }
  return Execute(std::move(module), literal_ptrs, run_hlo_passes, profile);
}
// Convenience overload: runs `executable` on a span of Literals by forwarding
// to the overload that takes `const Literal*` arguments.
// The pointers built here borrow from `arguments`, which stays alive for the
// duration of the forwarded call.
// `profile` is forwarded as-is, not replaced with nullptr, so a caller-supplied
// ExecutionProfile reaches the pointer-based overload. This matches how the
// Span-based Execute() overload treats its `profile` argument.
StatusOr<Literal> HloRunnerInterface::ExecuteWithExecutable(
    std::unique_ptr<Executable> executable, absl::Span<const Literal> arguments,
    ExecutionProfile* profile) {
  std::vector<const Literal*> argument_pointers;
  argument_pointers.reserve(arguments.size());
  for (const auto& argument : arguments) {
    argument_pointers.push_back(&argument);
  }
  return ExecuteWithExecutable(std::move(executable), argument_pointers,
                               profile);
}
} // namespace xla