Allow collecting graph via RunMetadata in TFRT.

PiperOrigin-RevId: 339094080
Change-Id: Ic938770b9dc2f655af7f6e4e54dea74f8e75a18b
This commit is contained in:
Xiao Yu 2020-10-26 12:32:39 -07:00 committed by TensorFlower Gardener
parent ae6c9dae43
commit 217f018519
6 changed files with 44 additions and 22 deletions

View File

@ -1447,13 +1447,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
// Writes the context's RunMetadata into `buf` via MessageToBuffer. The result
// of each step is stored in `status->status`, and the function returns early
// when the wait step does not report ok().
// NOTE(review): this listing looks like a diff hunk whose +/- markers were
// stripped, so old and new statements appear back to back (`context` is
// declared twice; the metadata is both read-then-cleared and exported) --
// confirm which lines survive in the real file before relying on this body.
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->Executor().WaitForAllPendingNodes();
auto* context = tensorflow::unwrap(ctx);
status->status = context->AsyncWait();
// Nothing is written to `buf` if waiting failed.
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*context->MetadataMu());
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
context->ClearRunMetadata();
// ExportRunMetadata() returns an owning pointer; it is dereferenced only for
// the duration of the MessageToBuffer call below.
auto run_metadata = context->ExportRunMetadata();
status->status = MessageToBuffer(*run_metadata, buf);
}
namespace {

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@ -124,6 +125,13 @@ class ImmediateExecutionContext : public AbstractContext {
// Returns the device placement policy for the current thread.
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
// Configure graph collection in RunMetadata.
virtual void SetShouldStoreGraphs(bool value) = 0;
// Return the collected RunMetadata. This method will transfer the ownership
// to the caller.
virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
@ -149,9 +157,6 @@ class ImmediateExecutionContext : public AbstractContext {
// Update the Eager Executor for current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
// Configure graph collection in RunMetadata.
virtual void SetShouldStoreGraphs(bool value) = 0;
protected:
// Forwards `kind` to the AbstractContext base. `explicit` rules out implicit
// conversion from an AbstractContextKind value.
explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {}

View File

@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#if !defined(IS_MOBILE_PLATFORM)
@ -108,6 +109,8 @@ EagerContext::EagerContext(
this->thread_pool_->Schedule(std::move(closure));
};
run_metadata_ = std::make_unique<RunMetadata>();
#if !defined(IS_MOBILE_PLATFORM)
context_id_ = kInvalidContextId;
context_view_id_ = 0;
@ -577,7 +580,12 @@ const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
return func_lib_def_.Find(name);
}
// Clears the stored RunMetadata in place.
void EagerContext::ClearRunMetadata() {
  run_metadata_.Clear();
}
std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
mutex_lock ml(metadata_mu_);
auto result = std::make_unique<RunMetadata>();
run_metadata_.swap(result);
return result;
}
// Always reports false for this context implementation.
bool EagerContext::UsesTFRT() {
  return false;
}
@ -858,7 +866,7 @@ void EagerContext::SetShouldStoreGraphs(bool value) {
mutex_lock ml(metadata_mu_);
should_store_graphs_.store(value);
if (!value) {
run_metadata_.Clear();
run_metadata_.reset(new RunMetadata);
}
}

View File

@ -79,12 +79,6 @@ namespace eager {
class RemoteMgr;
} // namespace eager
// Listener interface exposing a single hook, BeforeClearRunMetadata(), which
// every implementer must override.
// NOTE(review): presumably invoked just before a context's RunMetadata is
// cleared -- the call site is not visible here; confirm.
class RunMetadataListener {
 public:
  // Virtual so that implementations can be destroyed through a base pointer.
  virtual ~RunMetadataListener() = default;

  virtual void BeforeClearRunMetadata() = 0;
};
class TensorHandle;
class EagerOperation;
@ -310,8 +304,11 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
void SetShouldStoreGraphs(bool value) override;
RunMetadata* RunMetadataProto() { return &run_metadata_; }
void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
// Returns a non-owning pointer to the RunMetadata held by this context.
// Per the annotation, callers must hold metadata_mu_ exclusively.
RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
  RunMetadata* const proto = run_metadata_.get();
  return proto;
}
std::unique_ptr<RunMetadata> ExportRunMetadata() override
TF_LOCKS_EXCLUDED(metadata_mu_);
void StartStep() override;
void EndStep() override;
@ -587,7 +584,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
// Whether we should compute RunMetadata.
std::atomic<bool> should_store_graphs_{false};
mutex metadata_mu_;
RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_);
std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
GraphCollector graph_collector_;
std::atomic<bool> log_device_placement_;
std::atomic<bool> allow_soft_placement_;

View File

@ -75,9 +75,24 @@ class ContextTest(test.TestCase):
del tensor2
self.assertIs(weak_c(), None)
@test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet')
def testSimpleGraphCollection(self):
# Calls a def_function.function-decorated function inside
# context.collect_graphs() and checks that exactly one graph was collected
# and that its node at index 1 carries a 'CPU:0' device string.
# NOTE(review): indentation is missing in this listing; the nesting of the
# statements below must be taken from the real file.
@def_function.function
def f(x):
with ops.device('CPU:0'):
return x + constant_op.constant(1.)
with context.collect_graphs() as graphs:
f(constant_op.constant(1.))
self.assertLen(graphs, 1)
# Single-element unpacking; raises if `graphs` does not hold exactly one item.
graph, = graphs
# NOTE(review): node[1] is presumably the add op placed under ops.device --
# confirm the node ordering.
self.assertIn('CPU:0', graph.node[1].device)
@test_util.disable_tfrt(
'b/171600738: tfrt does not support exporting post-optimization graph')
def testGraphCollectionAfterDevicePlacement(self):
@def_function.function
def f(x):
return x + constant_op.constant(1.)

View File

@ -1322,7 +1322,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.assertIsInstance(
self.v, resource_variable_ops.ResourceVariable)
@test_util.disable_tfrt('b/169294215')
def testRunMetadata(self):
@def_function.function