Allow collecting graph via RunMetadata in TFRT.
PiperOrigin-RevId: 339094080 Change-Id: Ic938770b9dc2f655af7f6e4e54dea74f8e75a18b
This commit is contained in:
parent
ae6c9dae43
commit
217f018519
@ -1447,13 +1447,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
|||||||
|
|
||||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
auto* context = tensorflow::unwrap(ctx);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
status->status = context->AsyncWait();
|
||||||
status->status = context->Executor().WaitForAllPendingNodes();
|
|
||||||
if (!status->status.ok()) return;
|
if (!status->status.ok()) return;
|
||||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
auto run_metadata = context->ExportRunMetadata();
|
||||||
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
|
status->status = MessageToBuffer(*run_metadata, buf);
|
||||||
context->ClearRunMetadata();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/tstring.h"
|
#include "tensorflow/core/platform/tstring.h"
|
||||||
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -124,6 +125,13 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// Returns the device placement policy for the current thread.
|
// Returns the device placement policy for the current thread.
|
||||||
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
||||||
|
|
||||||
|
// Configure graph collection in RunMetadata.
|
||||||
|
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||||
|
|
||||||
|
// Return the collected RunMetadata. This method will transfer the ownership
|
||||||
|
// to the caller.
|
||||||
|
virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
|
||||||
|
|
||||||
// For LLVM style RTTI.
|
// For LLVM style RTTI.
|
||||||
static bool classof(const AbstractContext* ptr) {
|
static bool classof(const AbstractContext* ptr) {
|
||||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||||
@ -149,9 +157,6 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// Update the Eager Executor for current thread.
|
// Update the Eager Executor for current thread.
|
||||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||||
|
|
||||||
// Configure graph collection in RunMetadata.
|
|
||||||
virtual void SetShouldStoreGraphs(bool value) = 0;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||||
: AbstractContext(kind) {}
|
: AbstractContext(kind) {}
|
||||||
|
@ -42,6 +42,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/graph_def_util.h"
|
#include "tensorflow/core/framework/graph_def_util.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
#if !defined(IS_MOBILE_PLATFORM)
|
#if !defined(IS_MOBILE_PLATFORM)
|
||||||
@ -108,6 +109,8 @@ EagerContext::EagerContext(
|
|||||||
this->thread_pool_->Schedule(std::move(closure));
|
this->thread_pool_->Schedule(std::move(closure));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
run_metadata_ = std::make_unique<RunMetadata>();
|
||||||
|
|
||||||
#if !defined(IS_MOBILE_PLATFORM)
|
#if !defined(IS_MOBILE_PLATFORM)
|
||||||
context_id_ = kInvalidContextId;
|
context_id_ = kInvalidContextId;
|
||||||
context_view_id_ = 0;
|
context_view_id_ = 0;
|
||||||
@ -577,7 +580,12 @@ const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
|
|||||||
return func_lib_def_.Find(name);
|
return func_lib_def_.Find(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
|
std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
|
||||||
|
mutex_lock ml(metadata_mu_);
|
||||||
|
auto result = std::make_unique<RunMetadata>();
|
||||||
|
run_metadata_.swap(result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
bool EagerContext::UsesTFRT() { return false; }
|
bool EagerContext::UsesTFRT() { return false; }
|
||||||
|
|
||||||
@ -858,7 +866,7 @@ void EagerContext::SetShouldStoreGraphs(bool value) {
|
|||||||
mutex_lock ml(metadata_mu_);
|
mutex_lock ml(metadata_mu_);
|
||||||
should_store_graphs_.store(value);
|
should_store_graphs_.store(value);
|
||||||
if (!value) {
|
if (!value) {
|
||||||
run_metadata_.Clear();
|
run_metadata_.reset(new RunMetadata);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,12 +79,6 @@ namespace eager {
|
|||||||
class RemoteMgr;
|
class RemoteMgr;
|
||||||
} // namespace eager
|
} // namespace eager
|
||||||
|
|
||||||
class RunMetadataListener {
|
|
||||||
public:
|
|
||||||
virtual ~RunMetadataListener() {}
|
|
||||||
virtual void BeforeClearRunMetadata() = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class TensorHandle;
|
class TensorHandle;
|
||||||
class EagerOperation;
|
class EagerOperation;
|
||||||
|
|
||||||
@ -310,8 +304,11 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
||||||
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
||||||
void SetShouldStoreGraphs(bool value) override;
|
void SetShouldStoreGraphs(bool value) override;
|
||||||
RunMetadata* RunMetadataProto() { return &run_metadata_; }
|
RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
|
||||||
void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
|
return run_metadata_.get();
|
||||||
|
}
|
||||||
|
std::unique_ptr<RunMetadata> ExportRunMetadata() override
|
||||||
|
TF_LOCKS_EXCLUDED(metadata_mu_);
|
||||||
|
|
||||||
void StartStep() override;
|
void StartStep() override;
|
||||||
void EndStep() override;
|
void EndStep() override;
|
||||||
@ -587,7 +584,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
// Whether we should compute RunMetadata.
|
// Whether we should compute RunMetadata.
|
||||||
std::atomic<bool> should_store_graphs_{false};
|
std::atomic<bool> should_store_graphs_{false};
|
||||||
mutex metadata_mu_;
|
mutex metadata_mu_;
|
||||||
RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_);
|
std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
|
||||||
GraphCollector graph_collector_;
|
GraphCollector graph_collector_;
|
||||||
std::atomic<bool> log_device_placement_;
|
std::atomic<bool> log_device_placement_;
|
||||||
std::atomic<bool> allow_soft_placement_;
|
std::atomic<bool> allow_soft_placement_;
|
||||||
|
@ -75,9 +75,24 @@ class ContextTest(test.TestCase):
|
|||||||
del tensor2
|
del tensor2
|
||||||
self.assertIs(weak_c(), None)
|
self.assertIs(weak_c(), None)
|
||||||
|
|
||||||
@test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet')
|
|
||||||
def testSimpleGraphCollection(self):
|
def testSimpleGraphCollection(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x):
|
||||||
|
with ops.device('CPU:0'):
|
||||||
|
return x + constant_op.constant(1.)
|
||||||
|
|
||||||
|
with context.collect_graphs() as graphs:
|
||||||
|
f(constant_op.constant(1.))
|
||||||
|
|
||||||
|
self.assertLen(graphs, 1)
|
||||||
|
graph, = graphs
|
||||||
|
self.assertIn('CPU:0', graph.node[1].device)
|
||||||
|
|
||||||
|
@test_util.disable_tfrt(
|
||||||
|
'b/171600738: tfrt does not support exporting post-optimization graph')
|
||||||
|
def testGraphCollectionAfterDevicePlacement(self):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def f(x):
|
def f(x):
|
||||||
return x + constant_op.constant(1.)
|
return x + constant_op.constant(1.)
|
||||||
|
@ -1322,7 +1322,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
self.v, resource_variable_ops.ResourceVariable)
|
self.v, resource_variable_ops.ResourceVariable)
|
||||||
|
|
||||||
@test_util.disable_tfrt('b/169294215')
|
|
||||||
def testRunMetadata(self):
|
def testRunMetadata(self):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
|
Loading…
Reference in New Issue
Block a user