STT-tensorflow/tensorflow/compiler/aot/compile.cc
Feng Liu 03c15fee32 Remove experimental xla hlo quantization from open source (I)
PiperOrigin-RevId: 305818950
Change-Id: Ie09f6a81bafb100ca3dd3601fd4f9a8a732677ca
2020-04-09 21:09:45 -07:00

268 lines
10 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/compile.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/quantize.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tfcompile {
static llvm::ManagedStatic<QuantizeXlaFn> quantize_xla;
bool RegisterQuantizeFn(const QuantizeXlaFn& fn) {
if (*quantize_xla) return false;
*quantize_xla = fn;
return true;
}
namespace {
// Compiles the XLA computation into executable code.
//
// On success fills in `compile_result` with the program shape, the AOT
// compilation result, the entry point name and the target pointer size.
// Failures to obtain the program shape or to compile are reported as Unknown
// errors carrying the underlying error message.
Status CompileXla(xla::CompileOnlyClient* client,
                  const xla::XlaComputation& computation,
                  const xla::cpu::CpuAotCompilationOptions& aot_opts,
                  CompileResult* compile_result) {
  // Retrieves arg and result layouts from the computation.
  // TODO(toddw): Should we let the user choose the major/minor ordering?
  xla::StatusOr<std::unique_ptr<xla::ProgramShape>> program_shape_or =
      client->GetComputationShape(computation);
  if (!program_shape_or.ok()) {
    return errors::Unknown("Couldn't get XLA program shape: ",
                           program_shape_or.status().error_message());
  }
  compile_result->program_shape = program_shape_or.ValueOrDie()->ToProto();
  xla::ProgramShapeProto& shape_proto = compile_result->program_shape;

  // AotXlaComputationInstance::argument_layouts holds Shape pointers, so the
  // Shape objects are kept alive in `param_shapes`. Both vectors are sized up
  // front, which keeps the element addresses stable while the pointers are
  // collected.
  const int num_params = shape_proto.parameters_size();
  std::vector<const xla::Shape*> param_shape_ptrs(num_params);
  std::vector<xla::Shape> param_shapes(num_params);
  for (int idx = 0; idx < num_params; ++idx) {
    param_shapes[idx] = xla::Shape(shape_proto.parameters(idx));
    param_shape_ptrs[idx] = &param_shapes[idx];
  }

  // `result_shape` must outlive the CompileAheadOfTime() call below because
  // the instance only stores a pointer to it.
  xla::Shape result_shape(shape_proto.result());
  xla::CompileOnlyClient::AotXlaComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = std::move(param_shape_ptrs);
  instance.result_layout = &result_shape;

  xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
      results_or = client->CompileAheadOfTime({instance}, aot_opts);
  if (!results_or.ok()) {
    return errors::Unknown("XLA compilation failed: ",
                           results_or.status().error_message());
  }

  // Take ownership of the last compilation result as a CPU AOT result.
  std::vector<std::unique_ptr<xla::AotCompilationResult>>& results =
      results_or.ValueOrDie();
  compile_result->aot =
      xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
          std::move(results.back()));
  compile_result->entry_point = aot_opts.entry_point_name();
  compile_result->pointer_size =
      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
  return Status::OK();
}
} // namespace
// Converts `graph_def` into an XLA computation and compiles it ahead-of-time.
//
// The conversion path is selected by `flags.mlir_components`:
//   * "Bridge"          -> ConvertGraphDefToXlaViaMlir
//   * "" or "None"      -> ConvertGraphDefToXla
//   * anything else     -> Unknown error
// When `flags.experimental_quantize` is set and a quantization hook has been
// registered, the hook is applied to the computation before compilation. When
// `flags.out_session_module` is non-empty, a deterministically serialized
// HloSnapshot of the computation is written to that path.
//
// Results are stored in `compile_result`. Any failure, including failure to
// obtain the "Host" platform or the compile-only client, is returned as a
// non-OK Status.
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result) {
  // Converts the graph into an XLA computation, and compiles the
  // computation.
  // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
  //
  // Lookup failures are propagated to the caller rather than aborting the
  // process: this function already reports every other failure via Status.
  TF_ASSIGN_OR_RETURN(se::Platform * cpu_platform,
                      se::MultiPlatformManager::PlatformWithName("Host"));
  TF_ASSIGN_OR_RETURN(
      xla::CompileOnlyClient * client,
      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform));

  xla::XlaComputation computation;
  if (flags.mlir_components == "Bridge") {
    TF_RETURN_IF_ERROR(ConvertGraphDefToXlaViaMlir(
        graph_def, config, &computation, flags.debug_info,
        flags.debug_info_path_begin_marker));
  } else if (flags.mlir_components.empty() || flags.mlir_components == "None") {
    TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
                                            client, &computation));
  } else {
    return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
  }

  // Optional quantization pass; skipped unless both requested and registered.
  if (flags.experimental_quantize && *quantize_xla) {
    TF_RETURN_IF_ERROR((*quantize_xla)(config, &computation));
  }

  if (!flags.out_session_module.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
                        computation.Snapshot());
    // Serialize the HloSnapshot deterministically so that all the outputs of a
    // tf_library genrule are deterministic.
    const size_t size = module->ByteSizeLong();
    auto serialized = absl::make_unique<char[]>(size);
    TF_RET_CHECK(
        SerializeToBufferDeterministic(*module, serialized.get(), size));
    TF_RETURN_IF_ERROR(
        WriteStringToFile(Env::Default(), flags.out_session_module,
                          absl::string_view(serialized.get(), size)));
  }

  xla::cpu::CpuAotCompilationOptions aot_opts(
      flags.target_triple, flags.target_cpu, flags.target_features,
      flags.entry_point,
      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
  return CompileXla(client, computation, aot_opts, compile_result);
}
// Loads the file `fname` into `proto`. Files whose name ends in ".pbtxt" are
// read with ReadTextProto; every other file is read with ReadBinaryProto.
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  Env* const env = Env::Default();
  if (!absl::EndsWith(fname, ".pbtxt")) {
    return ReadBinaryProto(env, fname, proto);
  }
  return ReadTextProto(env, fname, proto);
}
// Guards InitializeTargets() so that it runs at most once (see Main()).
static absl::once_flag targets_init;

// Registers the LLVM backends tfcompile may target. Every supported backend is
// registered, not just the host's, so we can cross compile.
static void InitializeTargets() {
  // AArch64 is only registered when the build makes it available.
#if TF_LLVM_AARCH64_AVAILABLE
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
#endif

  // ARM.
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();

  // PowerPC.
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();

  // X86.
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
}
// Replaces {{tag.type tag.name}} in the error message with tag.name.
// TODO(bixia): We currently only handle tag.type == "node".
//
// In the error message, a graph node is represented as {{tag.type tag.name}},
// to allow a Python debugger to insert source information about the graph node.
// For example, a Python add expression may be represented as
// {{node x_y_sum}} = Add(x, y) in the error message. See routine interpolate
// in tensorflow/python/framework/error_interpolation.py for more detail.
//
// Every "{{node NAME}}" occurrence in `message` is rewritten to "NAME"; all
// other text is returned unchanged.
static std::string InterpolateErrorMessage(std::string message) {
  // See _NAME_REGEX in tensorflow/python/framework/error_interpolation.py
  // Change "prefix {{node tag.name}} suffix" to "prefix tag.name suffix".
  //
  // The pattern matches only the tag itself and captures the name
  // non-greedily. A greedy whole-line pattern would let the leading group
  // swallow earlier tags, so that only the last tag on each line is rewritten,
  // and would let the name group run on to a later "}}".
  static LazyRE2 pattern{"\\{\\{node (.*?)\\}\\}"};
  RE2::GlobalReplace(&message, *pattern, "\\1");
  return message;
}
// Entry point of the tfcompile driver.
//
// Steps, in order:
//   1. Initializes the LLVM targets (once per process).
//   2. Reads and validates the tf2xla config named by --config.
//   3. If --dump_fetch_nodes is set, prints the comma-separated fetch node
//      names to stdout and returns without compiling.
//   4. Checks --graph and --cpp_class, and parses --cpp_class, before any
//      expensive work is done or any output file is written.
//   5. Reads the graph, compiles it, and writes the function object file, the
//      metadata object file and the generated header.
//
// Returns InvalidArgument when a required flag is missing; otherwise the
// first error produced by any step. Compilation error messages have their
// {{node NAME}} tags replaced by NAME.
Status Main(const MainFlags& flags) {
  absl::call_once(targets_init, &InitializeTargets);

  // Process config.
  tf2xla::Config config;
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  if (flags.dump_fetch_nodes) {
    // std::set both de-duplicates and orders the node names.
    std::set<string> nodes;
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << absl::StrJoin(nodes, ",");
    return Status::OK();
  }

  // Validate the remaining required flags up front, so that a missing or
  // malformed flag is reported before the (potentially slow) compilation runs
  // and before any output file has been written.
  if (flags.graph.empty()) {
    return errors::InvalidArgument("Must specify --graph");
  }
  if (flags.cpp_class.empty()) {
    return errors::InvalidArgument("Must specify --cpp_class");
  }
  CodegenOpts codegen_opts;
  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                   &codegen_opts.namespaces));
  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
  codegen_opts.gen_program_shape = flags.gen_program_shape;
  codegen_opts.target_triple = flags.target_triple;
  codegen_opts.gen_hlo_profile_printer_data =
      xla::GetDebugOptionsFromFlags().xla_hlo_profile();

  // Read and initialize the graph.
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;
  Status status =
      CompileGraph(std::move(graph_def), config, flags, &compile_result);
  if (!status.ok()) {
    // Keep the error code, but strip the {{node ...}} markup from the message.
    return Status(status.code(),
                  InterpolateErrorMessage(status.error_message()));
  }

  // Write output files.
  Env* env = Env::Default();
  const std::vector<char>& obj = compile_result.aot->object_file_data();
  TF_RETURN_IF_ERROR(
      WriteStringToFile(env, flags.out_function_object,
                        absl::string_view(obj.data(), obj.size())));
  MetadataResult metadata_result;
  TF_RETURN_IF_ERROR(
      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                       metadata_result.object_file_data));
  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                    metadata_result, &header));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
  return Status::OK();
}
} // namespace tfcompile
} // namespace tensorflow