diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 15eeaa8066a..4de807bde31 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -131,6 +131,8 @@ class EagerContext { Device* HostCPU() { return devices_[0]; } + GraphCollector* GetGraphCollector() { return &graph_collector_; } + uint64 NextId() { return executor_.NextId(); } void ExecutorAdd(EagerNode* node) { executor_.Add(node); } @@ -249,6 +251,7 @@ class EagerContext { std::atomic should_store_metadata_{false}; mutex metadata_mu_; RunMetadata run_metadata_ GUARDED_BY(metadata_mu_); + GraphCollector graph_collector_; const bool log_device_placement_; // EagerExecutor for async execution. EagerExecutor executor_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 81e0cd1b71a..c5f1d52e43d 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -325,7 +325,9 @@ Status EagerLocalExecute(EagerOperation* op, if (!status.ok()) return status; std::unique_ptr maybe_stats; StepStats* maybe_step_stats = nullptr; + GraphCollector* graph_collector = nullptr; if (ctx->ShouldStoreMetadata()) { + graph_collector = ctx->GetGraphCollector(); maybe_step_stats = ctx->RunMetadataProto()->mutable_step_stats(); int64 now_nanos = Env::Default()->NowNanos(); maybe_stats.reset(new NodeExecStats); @@ -349,14 +351,14 @@ Status EagerLocalExecute(EagerOperation* op, } EagerNode* node = new ExecuteNode( id, ctx, op->Device(), op->Inputs(), kernel, maybe_stats.release(), - maybe_step_stats, output_dtypes, *retvals); + maybe_step_stats, graph_collector, output_dtypes, *retvals); ctx->ExecutorAdd(node); } else { // Execute checks if retvals[i] is nullptr or not to figure if it needs to // allocate it. - status = - EagerExecute(ctx, op->Device(), op->Inputs(), kernel, maybe_stats.get(), - maybe_step_stats, retvals->data(), *num_retvals); + status = EagerExecute(ctx, op->Device(), op->Inputs(), kernel, + maybe_stats.get(), maybe_step_stats, graph_collector, + retvals->data(), *num_retvals); } return status; @@ -710,7 +712,8 @@ Status EagerExecute(EagerOperation* op, Status EagerExecute(EagerContext* ctx, Device* device, const gtl::InlinedVector& op_inputs, KernelAndDevice* kernel, NodeExecStats* maybe_stats, - StepStats* maybe_step_stats, TensorHandle** retvals, + StepStats* maybe_step_stats, + GraphCollector* graph_collector, TensorHandle** retvals, int num_retvals) { if (device == nullptr) { // TODO(apassos) debug how the assignment below might return a different @@ -732,11 +735,11 @@ Status EagerExecute(EagerContext* ctx, Device* device, // TODO(agarwal): change Run to take vector of handles ? ScopedStepContainer* container = ctx->StepContainer(); if (container == nullptr) { - TF_RETURN_IF_ERROR( - kernel->Run(&inputs, &outputs, maybe_stats, maybe_step_stats)); + TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats, + maybe_step_stats, graph_collector)); } else { TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats, - maybe_step_stats)); + maybe_step_stats, graph_collector)); } if (maybe_stats != nullptr) { int64 nanos = Env::Default()->NowNanos(); @@ -748,6 +751,14 @@ Status EagerExecute(EagerContext* ctx, Device* device, maybe_stats->set_all_end_rel_nanos(nanos - maybe_stats->all_start_nanos()); mutex_lock ml(*ctx->MetadataMu()); if (ctx->ShouldStoreMetadata()) { + { + GraphCollector* collector = ctx->GetGraphCollector(); + mutex_lock mll(collector->mu); + for (const auto& graph : collector->graphs) { + *ctx->RunMetadataProto()->add_partition_graphs() = graph; + } + collector->graphs.clear(); + } auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats(); // Lazily initialize the RunMetadata with information about all devices if // this is the first call. diff --git a/tensorflow/core/common_runtime/eager/execute.h b/tensorflow/core/common_runtime/eager/execute.h index 0e997bdfa9d..6143a52d4b9 100644 --- a/tensorflow/core/common_runtime/eager/execute.h +++ b/tensorflow/core/common_runtime/eager/execute.h @@ -46,7 +46,8 @@ Status EagerExecute( Status EagerExecute(EagerContext* ctx, Device* device, const gtl::InlinedVector& op_inputs, KernelAndDevice* kernel, NodeExecStats* maybe_stats, - StepStats* maybe_step_stats, TensorHandle** retvals, + StepStats* maybe_step_stats, + GraphCollector* graph_collector, TensorHandle** retvals, int num_retvals); // Low-level utility to copy a tensor handle from one device to another. diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h index 18b1892f5d9..a99d509dd60 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.h +++ b/tensorflow/core/common_runtime/eager/execute_node.h @@ -34,7 +34,8 @@ class ExecuteNode : public EagerNode { ExecuteNode(uint64 id, EagerContext* ctx, Device* op_device, const tensorflow::gtl::InlinedVector& inputs, KernelAndDevice* kernel, NodeExecStats* maybe_stats, - StepStats* maybe_step_stats, const DataTypeVector& output_dtypes, + StepStats* maybe_step_stats, GraphCollector* graph_collector, + const DataTypeVector& output_dtypes, const tensorflow::gtl::InlinedVector& retvals) : EagerNode(id), ctx_(ctx), @@ -43,6 +44,7 @@ class ExecuteNode : public EagerNode { kernel_(kernel), maybe_stats_(maybe_stats), maybe_step_stats_(maybe_step_stats), + graph_collector_(graph_collector), retvals_(retvals) { for (auto handle : inputs_) { handle->Ref(); @@ -62,9 +64,9 @@ class ExecuteNode : public EagerNode { } tensorflow::Status Run() override { - const Status status = - EagerExecute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(), - maybe_step_stats_, retvals_.begin(), retvals_.size()); + const Status status = EagerExecute( + ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(), + maybe_step_stats_, graph_collector_, retvals_.begin(), retvals_.size()); if (status.ok()) { return status; } else { @@ -82,6 +84,7 @@ class ExecuteNode : public EagerNode { tensorflow::KernelAndDevice* kernel_; std::unique_ptr maybe_stats_; StepStats* maybe_step_stats_; + tensorflow::GraphCollector* graph_collector_; tensorflow::gtl::InlinedVector retvals_; }; diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 0adfcd7697d..b63257907f6 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -48,17 +48,20 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, Status KernelAndDevice::Run(std::vector* inputs, std::vector* outputs, NodeExecStats* stats, - StepStats* step_stats) { + StepStats* step_stats, + GraphCollector* graph_collector) { ScopedStepContainer step_container(0, [this](const string& name) { device_->resource_manager()->Cleanup(name).IgnoreError(); }); - return this->Run(&step_container, inputs, outputs, stats, step_stats); + return this->Run(&step_container, inputs, outputs, stats, step_stats, + graph_collector); } Status KernelAndDevice::Run(ScopedStepContainer* step_container, std::vector* inputs, std::vector* outputs, NodeExecStats* stats, - StepStats* step_stats) { + StepStats* step_stats, + GraphCollector* graph_collector) { gtl::InlinedVector input_vector; for (Tensor& t : *inputs) { input_vector.push_back(TensorValue(&t)); @@ -87,6 +90,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container, step_stats_collector.reset(new StepStatsCollector(step_stats)); params.track_allocations = true; params.stats_collector = step_stats_collector.get(); + params.graph_collector = graph_collector; } if (runner_ == nullptr) { params.runner = &default_runner_; diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index cbfa0af5074..ac9143b253a 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -62,11 +62,12 @@ class KernelAndDevice { // TODO(ashankar): Handle list-valued inputs. Status Run(std::vector* inputs, std::vector* outputs, - NodeExecStats* stats, StepStats* step_stats); + NodeExecStats* stats, StepStats* step_stats, + GraphCollector* graph_collector); Status Run(ScopedStepContainer* step_container, std::vector* inputs, std::vector* outputs, NodeExecStats* stats, - StepStats* step_stats); + StepStats* step_stats, GraphCollector* graph_collector); const OpKernel* kernel() const { return kernel_.get(); } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index fbe0f46f8a5..948bdbcaf53 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -132,7 +132,7 @@ void BM_KernelAndDeviceRun(int iters) { nullptr, &kernel)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr)); + TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr, nullptr)); } } BENCHMARK(BM_KernelAndDeviceRun); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 3b1f57a4571..aae94e50e9e 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/control_flow.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" @@ -489,6 +490,17 @@ struct TensorValue { Tensor* tensor; }; +// Used to store partitioned graphs from function-calling ops. +struct GraphCollector { + mutex mu; + std::vector graphs GUARDED_BY(mu); + + void CollectGraph(const GraphDef& graph) { + mutex_lock ml(mu); + graphs.push_back(graph); + } +}; + class OpKernelContext { public: // The first element of a WrappedAllocator is a "base" Allocator and @@ -589,6 +601,7 @@ class OpKernelContext { FunctionLibraryRuntime* function_library = nullptr; std::function)>* runner = nullptr; StepStatsCollectorInterface* stats_collector = nullptr; + GraphCollector* graph_collector = nullptr; // TensorSliceReaderCache support. checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; @@ -711,6 +724,9 @@ class OpKernelContext { // Usage: if (!context->ValidateInputsAreSameShape(this)) return; bool ValidateInputsAreSameShape(OpKernel* op); + // If non-null, kernels should populate with any partition subgraphs created. + GraphCollector* graph_collector() { return params_->graph_collector; } + // Input to output forwarding. // Set the output Ref Tensor at output_index to be an alias of the diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 0ec14f7a2a1..37739e0f921 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -183,6 +183,13 @@ class PartitionedCallOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC( ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), done); + if (ctx->graph_collector() != nullptr) { + for (const auto& pair : subgraphs) { + GraphDef def; + pair.second->ToGraphDef(&def); + ctx->graph_collector()->CollectGraph(def); + } + } optimization_options.graph = nullptr; optimization_options.device_set = nullptr; optimization_options.partition_graphs = &subgraphs; diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index a3731cf81e8..73d787b8be7 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -630,7 +630,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase): return x * x with ops.device('cpu:0'): - f(constant_op.constant(1.0)) # pre-build the defun context.enable_run_metadata() f(constant_op.constant(1.0)) run_metadata = context.export_run_metadata() @@ -645,6 +644,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): # arbitrarily many (placeholders, return identities, etc, might be included # or not in the future, so shouldn't be tested for exactly. self.assertGreaterEqual(len(cpu_stats.node_stats), 2) + self.assertEqual(len(run_metadata.partition_graphs), 1) def testGraphModeCaptureVariable(self): with context.graph_mode(), self.cached_session() as sess: