From 499b5288065ee9b90d7253bd0d6e9780b5959e66 Mon Sep 17 00:00:00 2001
From: Smit Hinsu
Date: Fri, 13 Mar 2020 11:54:05 -0700
Subject: [PATCH] Move Graph-from-NodeDef creation logic from XlaCompiler to
 XlaCompilationCache

XlaCompilationCache is the only user of single-op compilation, so the
single-op handling can move into the cache. This will allow MLIR-based
on-demand compilation to reuse this logic in a follow-up change.

PiperOrigin-RevId: 300799049
Change-Id: I50d3f258e815cbc2caa6315eff0d902695146537
---
 .../compiler/jit/xla_compilation_cache.cc  | 55 ++++++++++++++++++-
 tensorflow/compiler/tf2xla/xla_compiler.cc | 46 ----------------
 tensorflow/compiler/tf2xla/xla_compiler.h  |  8 ---
 3 files changed, 53 insertions(+), 56 deletions(-)

diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 03a9a3ad3a4..5540fee7276 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -33,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/metrics.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/hash/hash.h"
@@ -202,6 +204,52 @@ static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
          execution_count < kMinExecutionsPerCompile * compile_count;
 }
 
+// Creates a simple graph using the specified op as the only op apart from the
+// arg and retval nodes.
+static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
+    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
+    absl::Span<const DataType> result_types) {
+  // TODO(b/74182462): We implement this by creating a new dummy Graph including
+  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  Status status;
+  // First create the actual node we care about computing.
+  Node* main_node = graph->AddNode(node_def, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  // Create dummy _Arg nodes. Link these to `node` and also via a control
+  // dependency edge to the _SOURCE node.
+  for (int64 i = 0; i < args.size(); ++i) {
+    Node* node;
+    string arg_name = absl::StrCat("_arg", i);
+    Status status =
+        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
+            .ControlInput(graph->source_node())
+            .Attr("T", args[i].kind == XlaCompiler::Argument::kResource
+                           ? DT_RESOURCE
+                           : args[i].type)
+            .Attr("index", i)
+            .Finalize(graph.get(), &node);
+    TF_RETURN_IF_ERROR(status);
+    graph->AddEdge(node, 0, main_node, i);
+  }
+
+  // Similarly with return values, create dummy _Retval nodes fed by `node`.
+  for (int64 i = 0; i < result_types.size(); ++i) {
+    Node* node;
+    string retval_name = absl::StrCat("_retval", i);
+    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
+                        .Input(main_node, i)
+                        .Attr("T", result_types[i])
+                        .Attr("index", i)
+                        .Finalize(graph.get(), &node);
+    TF_RETURN_IF_ERROR(status);
+  }
+  FixupSourceAndSinkEdges(graph.get());
+  return graph;
+}
+
 Status XlaCompilationCache::CompileSingleOp(
     const XlaCompiler::Options& options,
     absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
@@ -222,8 +270,11 @@ Status XlaCompilationCache::CompileSingleOp(
     for (int i = 0; i < result_dtypes.size(); ++i) {
       result_dtypes[i] = ctx->expected_output_dtype(i);
     }
-    return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
-                                     args, result_dtypes, result);
+
+    const NodeDef& node_def = ctx->op_kernel().def();
+    TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
+    return compiler->CompileGraph(compile_options, node_def.name(),
+                                  std::move(graph), args, result);
   };
   return CompileImpl(options, name, args, compile_op,
                      /*compile_threshold=*/absl::nullopt,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 3ea62882dcb..c30b1c0e17d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -44,7 +44,6 @@ limitations under the License.
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -1174,51 +1173,6 @@ Status XlaCompiler::BuildArguments(
   return Status::OK();
 }
 
-Status XlaCompiler::CompileSingleOp(
-    const XlaCompiler::CompileOptions& options, const NodeDef& node_def,
-    absl::Span<const XlaCompiler::Argument> args,
-    absl::Span<const DataType> result_types, CompilationResult* result) {
-  // TODO(b/74182462): We implement this by creating a new dummy Graph including
-  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-
-  Status status;
-  // First create the actual node we care about computing.
-  Node* main_node = graph->AddNode(node_def, &status);
-  TF_RETURN_IF_ERROR(status);
-
-  // Create dummy _Arg nodes. Link these to `node` and also via a control
-  // dependency edge to the _SOURCE node.
-  for (int64 i = 0; i < args.size(); ++i) {
-    Node* node;
-    string arg_name = absl::StrCat("_arg", i);
-    Status status =
-        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
-            .ControlInput(graph->source_node())
-            .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
-                                                           : args[i].type)
-            .Attr("index", i)
-            .Finalize(graph.get(), &node);
-    TF_RETURN_IF_ERROR(status);
-    graph->AddEdge(node, 0, main_node, i);
-  }
-
-  // Similarly with return values, create dummy _Retval nodes fed by `node`.
-  for (int64 i = 0; i < result_types.size(); ++i) {
-    Node* node;
-    string retval_name = absl::StrCat("_retval", i);
-    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
-                        .Input(main_node, i)
-                        .Attr("T", result_types[i])
-                        .Attr("index", i)
-                        .Finalize(graph.get(), &node);
-    TF_RETURN_IF_ERROR(status);
-  }
-  FixupSourceAndSinkEdges(graph.get());
-
-  return CompileGraph(options, node_def.name(), std::move(graph), args, result);
-}
-
 namespace {
 
 // Check that the ops of all non-functional nodes have been registered.
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 5ec5866632b..6a56136a9f6 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -375,14 +375,6 @@ class XlaCompiler {
                       std::unique_ptr<Graph> graph,
                       absl::Span<const Argument> args,
                       CompilationResult* result);
-  // Compiles a single Op, given by `node_def`, into an
-  // xla::XlaComputation. Similar to CompileFunction but takes a single Op as
-  // input.
-  Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def,
-                         absl::Span<const Argument> args,
-                         absl::Span<const DataType> result_types,
-                         CompilationResult* result);
-
   // Returns the shape of the XLA parameter for an argument 'arg'.
   // See the class comment for more details about the argument passing
   // convention.
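
As context for reviewers, here is a minimal sketch of the single-op graph the
relocated CreateGraph helper builds. It is illustrative only, not part of the
patch: the "Add" op, the two float arguments, and the names are assumptions,
and the final call compiles only inside xla_compilation_cache.cc, where the
file-static CreateGraph is visible.

    // Assumed include on top of what this patch already pulls in:
    // "tensorflow/core/framework/node_def_builder.h".
    //
    // Hypothetical setup: a NodeDef for a lone "Add" op whose two inputs will
    // be fed by the dummy _Arg nodes that CreateGraph synthesizes.
    NodeDef node_def;
    TF_CHECK_OK(NodeDefBuilder("add", "Add")
                    .Input("_arg0", 0, DT_FLOAT)
                    .Input("_arg1", 0, DT_FLOAT)
                    .Finalize(&node_def));

    // Two parameter arguments matching the op's inputs.
    std::vector<XlaCompiler::Argument> args(2);
    for (XlaCompiler::Argument& arg : args) {
      arg.kind = XlaCompiler::Argument::kParameter;
      arg.type = DT_FLOAT;
    }

    // Expected result, with ^ marking the control edges added by
    // ControlInput() and FixupSourceAndSinkEdges():
    //   _SOURCE -^-> _arg0 --0--> add --0--> _retval0 -^-> _SINK
    //   _SOURCE -^-> _arg1 --1--> add
    xla::StatusOr<std::unique_ptr<Graph>> graph =
        CreateGraph(node_def, args, {DT_FLOAT});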