STT-tensorflow/tensorflow/compiler/tf2xla/xla_compiler.cc
Peter Hawkins 8eb161e391 Add support for passes that run post-partitioning to OptimizationRegistry.
To avoid another GraphDef -> Graph -> GraphDef conversion, change Device::MaybeRewriteGraph to take a Graph instead of a GraphDef.

Use std::unique_ptr<> in more places to avoid some awkward .release() magic.
Change: 144532446
2017-01-14 07:45:34 -08:00


/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include <numeric>
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
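// Returns true if `graph` contains at least one _Retval node, i.e. the
// function produces at least one output.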
bool HasRetval(const Graph& graph) {
for (const Node* n : graph.nodes()) {
if (n->type_string() == "_Retval") return true;
}
return false;
}
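
// Checks that the TensorFlow types in `tf_types` are consistent with
// `xla_shape`: for a tuple shape the element count and each element type must
// match; otherwise there must be exactly one type, matching the shape's
// element type.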
Status CheckSignature(const DataTypeVector& tf_types,
const xla::Shape& xla_shape) {
if (xla::ShapeUtil::IsTuple(xla_shape)) {
if (xla::ShapeUtil::TupleElementCount(xla_shape) != tf_types.size()) {
return errors::Internal("XLA shape has ",
xla::ShapeUtil::TupleElementCount(xla_shape),
" elements while function has ", tf_types.size());
}
for (int i = 0; i < tf_types.size(); ++i) {
xla::PrimitiveType type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[i], &type));
if (type !=
xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type()) {
return errors::Internal(
"element ", i, " has XLA type ",
xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type(),
" and TensorFlow type ", DataTypeString(tf_types[i]));
}
}
} else {
if (tf_types.size() != 1) {
return errors::Internal("Expected singleton type, got ", tf_types.size(),
" types");
}
xla::PrimitiveType type;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[0], &type));
if (type != xla_shape.element_type()) {
return errors::Internal("singleton element has XLA type ",
xla_shape.element_type(), " and TensorFlow type ",
DataTypeString(tf_types[0]));
}
}
return Status::OK();
}
} // namespace
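
// Illustrative usage of XlaCompiler (a sketch, not part of this file; `flr`,
// `fn`, `args`, and `options` are assumed to be provided by the caller):
//
//   XlaCompiler compiler(options);
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileFunction(flr, fn, args, &result));
//   std::unique_ptr<xla::LocalExecutable> executable;
//   TF_RETURN_IF_ERROR(compiler.BuildExecutable(result, &executable));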
XlaCompiler::XlaCompiler(const XlaCompiler::Options& options)
: client_(options.client),
allow_cpu_custom_calls_(options.allow_cpu_custom_calls),
local_executable_has_hybrid_result_(
options.local_executable_has_hybrid_result),
next_step_id_(1),
device_(new XlaCompilationDevice(SessionOptions(), options.device_type)),
device_mgr_({device_}) {}
XlaCompiler::~XlaCompiler() = default;
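
// Returns a fresh step id; used to key the per-compilation
// ScopedStepContainer created in ExecuteGraph.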
int64 XlaCompiler::NextStepId() {
mutex_lock l(mu_);
return next_step_id_++;
}
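
// Instantiates `function` via `flr` and compiles its body into an XLA
// computation, with each argument passed as a separate parameter.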
Status XlaCompiler::CompileFunction(
FunctionLibraryRuntime* flr, const NameAttrList& function,
const std::vector<XlaCompiler::Argument>& args,
XlaCompiler::CompilationResult* result) {
const string function_id = Canonicalize(function.name(), function.attr());
VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(
flr->Instantiate(function.name(), function.attr(), &handle));
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CHECK(fbody);
return CompileFunctionBody(flr, *fbody, function_id, args,
/*use_tuple_arg=*/false, result);
}
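
// Compiles `function` into an xla::Computation whose signature must agree
// with `input_shape` and `output_shape`. If `input_shape` is a tuple, the
// arguments are passed to the computation as a single tuple parameter.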
Status XlaCompiler::CompileSubComputation(FunctionLibraryRuntime* flr,
const NameAttrList& function,
const xla::Shape& input_shape,
const xla::Shape& output_shape,
xla::Computation* computation) {
const string function_id = Canonicalize(function.name(), function.attr());
VLOG(1) << "XlaCompiler::CompileSubComputation " << function_id;
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(
flr->Instantiate(function.name(), function.attr(), &handle));
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CHECK(fbody);
TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, input_shape));
TF_RETURN_IF_ERROR(CheckSignature(fbody->ret_types, output_shape));
const bool use_tuple_arg = xla::ShapeUtil::IsTuple(input_shape);
std::vector<XlaCompiler::Argument> args(fbody->arg_types.size());
if (use_tuple_arg) {
for (int i = 0; i < args.size(); ++i) {
xla::Shape xla_shape =
xla::ShapeUtil::GetTupleElementShape(input_shape, i);
args[i].type = fbody->arg_types[i];
args[i].shape = XLAShapeToTensorShape(xla_shape);
args[i].parameter = i;
}
} else {
args[0].type = fbody->arg_types[0];
args[0].shape = XLAShapeToTensorShape(input_shape);
args[0].parameter = 0;
}
CompilationResult result;
TF_RETURN_IF_ERROR(CompileFunctionBody(flr, *fbody, function_id, args,
use_tuple_arg, &result));
if (!xla::ShapeUtil::Compatible(result.xla_output_shape, output_shape)) {
return errors::Internal("output shape mismatch from compilation");
}
*computation = std::move(result.computation);
return Status::OK();
}
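
// Shared helper for CompileFunction and CompileSubComputation: copies and
// optimizes the function body's graph and compiles it with CompileGraph.
// Graphs without _Retval nodes are skipped with an OK status.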
Status XlaCompiler::CompileFunctionBody(
FunctionLibraryRuntime* flr, const FunctionBody& fbody,
const string& function_id, const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg, XlaCompiler::CompilationResult* result) {
VLOG(1) << "XlaCompiler::CompileFunctionBody " << function_id;
std::unique_ptr<Graph> graph(new Graph(flr->GetFunctionLibraryDefinition()));
CopyGraph(*fbody.graph, graph.get());
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile(
strings::StrCat("xla_jit_raw_input_", function_id), *graph);
}
if (!HasRetval(*graph)) {
VLOG(1) << "Graph has no retvals. Skipping compilation.";
return Status::OK();
}
// Optimize the graph before running it through the translator.
// TODO(pbar) The constant folder currently does not simplify int32 operations
// for devices other than CPU.
OptimizerOptions opts;
GraphOptimizer optimizer(opts);
OptimizeGraph(flr, &graph);
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile(
strings::StrCat("xla_jit_final_graph_", function_id), *graph);
}
VLOG(1) << "====================================================";
TF_RETURN_IF_ERROR(CompileGraph(function_id, std::move(graph), flr, args,
use_tuple_arg, result));
VLOG(1) << "====================================================";
return Status::OK();
}
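
// Builds an xla::LocalExecutable from `result.computation`, using the
// compiled argument shapes as argument layouts (plus a trailing opaque
// parameter when a runtime context is required) and `result.xla_output_shape`
// as the result layout.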
Status XlaCompiler::BuildExecutable(
const XlaCompiler::CompilationResult& result,
std::unique_ptr<xla::LocalExecutable>* executable) {
VLOG(2) << "Compiling to local executable";
xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
std::vector<const xla::Shape*> argument_layouts(
result.xla_input_shapes.size());
for (int i = 0; i < result.xla_input_shapes.size(); ++i) {
argument_layouts[i] = &result.xla_input_shapes[i].second;
}
if (result.requires_runtime_context) {
// The final arg is the XlaLocalRuntimeContext*.
argument_layouts.push_back(&opaque_shape);
}
xla::LocalClient* local_client = static_cast<xla::LocalClient*>(client());
xla::ExecutableBuildOptions build_options;
build_options.set_device_ordinal(local_client->default_device_ordinal());
build_options.set_platform(local_client->platform());
build_options.set_result_layout(result.xla_output_shape);
build_options.set_has_hybrid_result(local_executable_has_hybrid_result_);
auto compile_result = local_client->Compile(result.computation,
argument_layouts, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*executable = std::move(compile_result.ValueOrDie());
return Status::OK();
}
namespace {
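// Executes `graph` symbolically on `device`, registering `xla_context` in the
// device's resource manager for the duration of the run so that kernels can
// find it and emit XLA operations into it.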
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
int64 step_id) {
// Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
// resource manager takes ownership via Create, and unrefs via Cleanup. We
// explicitly add a reference to ensure the refcount at entry is maintained at
// all exit points; Create and Cleanup are always called in this function.
//
// The Executor requires us to use ScopedStepContainer. We wrap it in a
// unique_ptr so that we can capture the cleanup status at the end.
xla_context->Ref();
Status cleanup_status;
auto step_container = xla::MakeUnique<ScopedStepContainer>(
step_id, [&cleanup_status, device](const string& name) {
cleanup_status = device->resource_manager()->Cleanup(name);
});
TF_RETURN_IF_ERROR(device->resource_manager()->Create(
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
// Create a LocalExecutor that will own and run the graph.
LocalExecutorParams exec_params;
exec_params.device = device;
exec_params.function_library = flib;
exec_params.create_kernel = [flib](const NodeDef& ndef, OpKernel** kernel) {
return flib->CreateKernel(ndef, kernel);
};
exec_params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
Executor* exec_ptr = nullptr;
TF_RETURN_IF_ERROR(NewLocalExecutor(exec_params, graph.release(), &exec_ptr));
std::unique_ptr<Executor> exec(exec_ptr);
// At this point ownership of the graph has been transferred to exec.
auto runner = [](Executor::Args::Closure c) {
// TODO(misard) Temporarily just schedule c eagerly while we
// decide what to do about the fact that the ComputationBuilder is
// thread-compatible, but we don't really want Op writers to have
// to remember to acquire a lock around every call to
// ComputationBuilder. One possibility is to add the (generally
// useful) ability to run a single-threaded Executor based on an
// option in LocalExecutorParams. Another is to automagically
// acquire a lock around ComputationBuilder calls using some
// wrapper or RAII funny business.
c();
};
// Run the graph symbolically, turning the graph into an XLA computation.
Executor::Args exec_args;
exec_args.step_id = step_id;
exec_args.step_container = step_container.get();
exec_args.runner = runner;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
exec->Run(exec_args),
"Conversion from TensorFlow graph to XLA computation failed.");
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
return cleanup_status;
}
} // namespace
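
// Compiles `graph` into an XLA computation by executing it symbolically:
// argument shapes are converted to xla::Shapes, compile-time constant results
// are recorded separately from the computation's non-constant outputs, and
// the output layouts are forced to major-to-minor order.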
Status XlaCompiler::CompileGraph(string const& name,
std::unique_ptr<Graph> graph,
FunctionLibraryRuntime* flib,
const std::vector<XlaCompiler::Argument>& args,
bool use_tuple_arg,
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
// Converts the input shapes into xla::Shape instances.
result->xla_input_shapes.reserve(args.size());
for (int i = 0; i < args.size(); ++i) {
if (args[i].parameter < 0) {
continue;
}
result->xla_input_shapes.push_back(std::make_pair(i, xla::Shape()));
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
args[i].type, args[i].shape, &result->xla_input_shapes.back().second));
}
XlaContext* xla_context =
new XlaContext(this, client(), name, allow_cpu_custom_calls_);
core::ScopedUnref xla_context_unref(xla_context);
TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg));
TF_RETURN_IF_ERROR(
ExecuteGraph(xla_context, std::move(graph), device_, flib, NextStepId()));
std::vector<XlaContext::ConstRetVal> compile_time_constants;
int num_nonconst_outputs;
TF_RETURN_IF_ERROR(xla_context->CollectResults(
&result->computation, &result->requires_runtime_context,
&compile_time_constants, &num_nonconst_outputs));
result->outputs.resize(compile_time_constants.size() + num_nonconst_outputs);
for (const auto& c : compile_time_constants) {
if (!c.status.ok()) {
Status constant_status = c.status;
errors::AppendToMessage(&constant_status,
"Failed evaluating constant XLA return "
"value ",
c.index);
return constant_status;
}
if (c.index >= result->outputs.size()) {
return errors::InvalidArgument("Invalid argument index ", c.index);
}
OutputDescription& output = result->outputs[c.index];
output.shape = c.value.shape();
output.is_constant = true;
output.constant_value = c.value;
}
if (result->computation.IsNull()) {
return Status::OK();
}
// Compute the output shapes, if there is a computation with non-constant
// outputs.
auto computation_shape = client()->GetComputationShape(result->computation);
if (!computation_shape.ok()) {
return computation_shape.status();
}
result->xla_output_shape.Swap(
computation_shape.ValueOrDie()->mutable_result());
auto num_non_constant_outputs =
(xla::ShapeUtil::IsTuple(result->xla_output_shape))
? xla::ShapeUtil::TupleElementCount(result->xla_output_shape)
: 1;
// TensorFlow expects a major-to-minor order of results.
if (1 == num_non_constant_outputs) {
xla::Shape& s = result->xla_output_shape;
auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major();
minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0);
std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
} else {
for (xla::Shape& s : *result->xla_output_shape.mutable_tuple_shapes()) {
auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major();
minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0);
std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
}
}
// Converts the output shapes to TensorShapes.
int computation_output = 0;
for (int i = 0; i < result->outputs.size(); ++i) {
if (!result->outputs[i].is_constant) {
CHECK_LT(computation_output, num_non_constant_outputs);
if (num_non_constant_outputs > 1) {
result->outputs[i].shape =
XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
result->xla_output_shape, computation_output));
} else {
result->outputs[i].shape =
XLAShapeToTensorShape(result->xla_output_shape);
}
++computation_output;
}
}
return Status::OK();
}
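
// Returns the xla::ChannelHandle cached under `key`, creating a new handle
// through the client on first use. Thread-safe.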
Status XlaCompiler::GetChannelHandle(const string& key,
xla::ChannelHandle* channel) {
mutex_lock lock(mu_);
auto result = channels_.emplace(key, xla::ChannelHandle());
if (result.second) {
TF_ASSIGN_OR_RETURN(result.first->second, client_->CreateChannelHandle());
}
*channel = result.first->second;
return Status::OK();
}
} // namespace tensorflow