From 217f018519a3121ba6bfcb38cb9aff59e275e308 Mon Sep 17 00:00:00 2001 From: Xiao Yu Date: Mon, 26 Oct 2020 12:32:39 -0700 Subject: [PATCH] Allow collecting graph via RunMetadata in TFRT. PiperOrigin-RevId: 339094080 Change-Id: Ic938770b9dc2f655af7f6e4e54dea74f8e75a18b --- tensorflow/c/eager/c_api.cc | 10 ++++------ .../c/eager/immediate_execution_context.h | 11 ++++++++--- tensorflow/core/common_runtime/eager/context.cc | 12 ++++++++++-- tensorflow/core/common_runtime/eager/context.h | 15 ++++++--------- tensorflow/python/eager/context_test.py | 17 ++++++++++++++++- tensorflow/python/eager/function_test.py | 1 - 6 files changed, 44 insertions(+), 22 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3418bccf050..9c73d1aba8c 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -1447,13 +1447,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->Executor().WaitForAllPendingNodes(); + auto* context = tensorflow::unwrap(ctx); + status->status = context->AsyncWait(); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(*context->MetadataMu()); - status->status = MessageToBuffer(*context->RunMetadataProto(), buf); - context->ClearRunMetadata(); + auto run_metadata = context->ExportRunMetadata(); + status->status = MessageToBuffer(*run_metadata, buf); } namespace { diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index a3e3857b34b..27fa17127b8 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -124,6 +125,13 @@ class ImmediateExecutionContext : public AbstractContext { // Returns the device placement policy for the current thread. virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; + // Configure graph collection in RunMetadata. + virtual void SetShouldStoreGraphs(bool value) = 0; + + // Return the collected RunMetadata. This method will transfer the ownership + // to the caller. + virtual std::unique_ptr ExportRunMetadata() = 0; + // For LLVM style RTTI. static bool classof(const AbstractContext* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; @@ -149,9 +157,6 @@ class ImmediateExecutionContext : public AbstractContext { // Update the Eager Executor for current thread. virtual void SetExecutorForThread(EagerExecutor* executor) = 0; - // Configure graph collection in RunMetadata. - virtual void SetShouldStoreGraphs(bool value) = 0; - protected: explicit ImmediateExecutionContext(AbstractContextKind kind) : AbstractContext(kind) {} diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 757ac1f7783..d7c1359a3e1 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" #if !defined(IS_MOBILE_PLATFORM) @@ -108,6 +109,8 @@ EagerContext::EagerContext( this->thread_pool_->Schedule(std::move(closure)); }; + run_metadata_ = std::make_unique(); + #if !defined(IS_MOBILE_PLATFORM) context_id_ = kInvalidContextId; context_view_id_ = 0; @@ -577,7 +580,12 @@ const FunctionDef* EagerContext::FindFunctionDef(const string& name) const { return func_lib_def_.Find(name); } -void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); } +std::unique_ptr EagerContext::ExportRunMetadata() { + mutex_lock ml(metadata_mu_); + auto result = std::make_unique(); + run_metadata_.swap(result); + return result; +} bool EagerContext::UsesTFRT() { return false; } @@ -858,7 +866,7 @@ void EagerContext::SetShouldStoreGraphs(bool value) { mutex_lock ml(metadata_mu_); should_store_graphs_.store(value); if (!value) { - run_metadata_.Clear(); + run_metadata_.reset(new RunMetadata); } } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index f48da696d48..62093dcc1d0 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -79,12 +79,6 @@ namespace eager { class RemoteMgr; } // namespace eager -class RunMetadataListener { - public: - virtual ~RunMetadataListener() {} - virtual void BeforeClearRunMetadata() = 0; -}; - class TensorHandle; class EagerOperation; @@ -310,8 +304,11 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; } bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_); void SetShouldStoreGraphs(bool value) override; - RunMetadata* RunMetadataProto() { return &run_metadata_; } - void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_); + RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) { + return run_metadata_.get(); + } + std::unique_ptr ExportRunMetadata() override + TF_LOCKS_EXCLUDED(metadata_mu_); void StartStep() override; void EndStep() override; @@ -587,7 +584,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // Whether we should compute RunMetadata. std::atomic should_store_graphs_{false}; mutex metadata_mu_; - RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_); + std::unique_ptr run_metadata_ TF_GUARDED_BY(metadata_mu_); GraphCollector graph_collector_; std::atomic log_device_placement_; std::atomic allow_soft_placement_; diff --git a/tensorflow/python/eager/context_test.py b/tensorflow/python/eager/context_test.py index fe86104cc0b..4ee59ff484a 100644 --- a/tensorflow/python/eager/context_test.py +++ b/tensorflow/python/eager/context_test.py @@ -75,9 +75,24 @@ class ContextTest(test.TestCase): del tensor2 self.assertIs(weak_c(), None) - @test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet') def testSimpleGraphCollection(self): + @def_function.function + def f(x): + with ops.device('CPU:0'): + return x + constant_op.constant(1.) + + with context.collect_graphs() as graphs: + f(constant_op.constant(1.)) + + self.assertLen(graphs, 1) + graph, = graphs + self.assertIn('CPU:0', graph.node[1].device) + + @test_util.disable_tfrt( + 'b/171600738: tfrt does not support exporting post-optimization graph') + def testGraphCollectionAfterDevicePlacement(self): + @def_function.function def f(x): return x + constant_op.constant(1.) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index de7557e848f..a584a6cc4fa 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1322,7 +1322,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.assertIsInstance( self.v, resource_variable_ops.ResourceVariable) - @test_util.disable_tfrt('b/169294215') def testRunMetadata(self): @def_function.function