Move the NodeDef-to-Graph creation logic from XlaCompiler to XlaCompilationCache

XlaCompilationCache is the only user of single-op compilation, so we can move single-op handling into the cache. This will allow MLIR-based on-demand compilation to reuse this logic in a follow-up change.

PiperOrigin-RevId: 300799049
Change-Id: I50d3f258e815cbc2caa6315eff0d902695146537
Smit Hinsu 2020-03-13 11:54:05 -07:00 committed by TensorFlower Gardener
parent 0ad8a52c4d
commit 499b528806
3 changed files with 53 additions and 56 deletions
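
The caller-visible effect, as a hedged before/after sketch (locals such as compiler, compile_options, node_def, args, result_dtypes, and result stand in for the ones used in XlaCompilationCache::CompileSingleOp; the snippet is illustrative, not part of the diff):

    // Before: XlaCompiler turned the NodeDef into a Graph internally.
    TF_RETURN_IF_ERROR(compiler->CompileSingleOp(compile_options, node_def, args,
                                                 result_dtypes, result));

    // After: the cache builds the single-op Graph itself and funnels it through
    // the general CompileGraph entry point, which a Graph-based MLIR pipeline
    // can reuse the same way.
    TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
    TF_RETURN_IF_ERROR(compiler->CompileGraph(compile_options, node_def.name(),
                                              std::move(graph), args, result));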

tensorflow/compiler/jit/xla_compilation_cache.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -33,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/metrics.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/hash/hash.h"
@@ -202,6 +204,52 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
          execution_count < kMinExecutionsPerCompile * compile_count;
 }
 
+// Creates a simple graph using the specified op as the only op apart from the
+// arg and retval nodes.
+static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
+    absl::Span<const DataType> result_types) {
+  // TODO(b/74182462): We implement this by creating a new dummy Graph including
+  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  Status status;
+  // First create the actual node we care about computing.
+  Node* main_node = graph->AddNode(node_def, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  // Create dummy _Arg nodes. Link these to `node` and also via a control
+  // dependency edge to the _SOURCE node.
+  for (int64 i = 0; i < args.size(); ++i) {
+    Node* node;
+    string arg_name = absl::StrCat("_arg", i);
+    Status status =
+        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
+            .ControlInput(graph->source_node())
+            .Attr("T", args[i].kind == XlaCompiler::Argument::kResource
+                           ? DT_RESOURCE
+                           : args[i].type)
+            .Attr("index", i)
+            .Finalize(graph.get(), &node);
+    TF_RETURN_IF_ERROR(status);
+    graph->AddEdge(node, 0, main_node, i);
+  }
+
+  // Similarly with return values, create dummy _Retval nodes fed by `node`.
+  for (int64 i = 0; i < result_types.size(); ++i) {
+    Node* node;
+    string retval_name = absl::StrCat("_retval", i);
+    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
+                        .Input(main_node, i)
+                        .Attr("T", result_types[i])
+                        .Attr("index", i)
+                        .Finalize(graph.get(), &node);
+    TF_RETURN_IF_ERROR(status);
+  }
+  FixupSourceAndSinkEdges(graph.get());
+  return graph;
+}
+
 Status XlaCompilationCache::CompileSingleOp(
     const XlaCompiler::Options& options,
     absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
@@ -222,8 +270,11 @@ Status XlaCompilationCache::CompileSingleOp(
     for (int i = 0; i < result_dtypes.size(); ++i) {
       result_dtypes[i] = ctx->expected_output_dtype(i);
     }
-    return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
-                                     args, result_dtypes, result);
+
+    const NodeDef& node_def = ctx->op_kernel().def();
+    TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
+    return compiler->CompileGraph(compile_options, node_def.name(),
+                                  std::move(graph), args, result);
   };
   return CompileImpl(options, name, args, compile_op,
                      /*compile_threshold=*/absl::nullopt,

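To make the new helper concrete: for a hypothetical NodeDef naming a single MatMul op with two non-resource float arguments and one result, the dummy graph returned by CreateGraph has roughly this shape (an illustrative sketch, not part of the diff):

    _SOURCE --ctrl--> _arg0 (T=DT_FLOAT, index=0) --\
    _SOURCE --ctrl--> _arg1 (T=DT_FLOAT, index=1) ---> MatMul --> _retval0 (T=DT_FLOAT, index=0)

Each _arg node carries the argument's dtype ("T", with DT_RESOURCE substituted for resource arguments) and its position ("index"); each _retval mirrors one output of the op. FixupSourceAndSinkEdges then gives every node without in-edges a control edge from _SOURCE and every node without out-edges a control edge to _SINK, so the graph is well-formed for CompileGraph.
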
tensorflow/compiler/tf2xla/xla_compiler.cc

@@ -44,7 +44,6 @@ limitations under the License.
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -1174,51 +1173,6 @@ Status XlaCompiler::BuildArguments(
   return Status::OK();
 }
 
-Status XlaCompiler::CompileSingleOp(
-    const XlaCompiler::CompileOptions& options, const NodeDef& node_def,
-    absl::Span<const XlaCompiler::Argument> args,
-    absl::Span<const DataType> result_types, CompilationResult* result) {
-  // TODO(b/74182462): We implement this by creating a new dummy Graph including
-  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-
-  Status status;
-  // First create the actual node we care about computing.
-  Node* main_node = graph->AddNode(node_def, &status);
-  TF_RETURN_IF_ERROR(status);
-
-  // Create dummy _Arg nodes. Link these to `node` and also via a control
-  // dependency edge to the _SOURCE node.
-  for (int64 i = 0; i < args.size(); ++i) {
-    Node* node;
-    string arg_name = absl::StrCat("_arg", i);
-    Status status =
-        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
-            .ControlInput(graph->source_node())
-            .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
-                                                           : args[i].type)
-            .Attr("index", i)
-            .Finalize(graph.get(), &node);
-    TF_RETURN_IF_ERROR(status);
-    graph->AddEdge(node, 0, main_node, i);
-  }
-
-  // Similarly with return values, create dummy _Retval nodes fed by `node`.
-  for (int64 i = 0; i < result_types.size(); ++i) {
-    Node* node;
-    string retval_name = absl::StrCat("_retval", i);
-    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
-                        .Input(main_node, i)
-                        .Attr("T", result_types[i])
-                        .Attr("index", i)
-                        .Finalize(graph.get(), &node);
-    TF_RETURN_IF_ERROR(status);
-  }
-
-  FixupSourceAndSinkEdges(graph.get());
-  return CompileGraph(options, node_def.name(), std::move(graph), args, result);
-}
-
 namespace {
 
 // Check that the ops of all non-functional nodes have been registered.

tensorflow/compiler/tf2xla/xla_compiler.h

@@ -375,14 +375,6 @@ class XlaCompiler {
                       std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
                       CompilationResult* result);
 
-  // Compiles a single Op, given by `node_def`, into an
-  // xla::XlaComputation. Similar to CompileFunction but takes a single Op as
-  // input.
-  Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def,
-                         absl::Span<const Argument> args,
-                         absl::Span<const DataType> result_types,
-                         CompilationResult* result);
-
   // Returns the shape of the XLA parameter for an argument 'arg'.
   // See the class comment for more details about the argument passing
   // convention.