Allow collecting graph via RunMetadata in TFRT.
PiperOrigin-RevId: 339094080 Change-Id: Ic938770b9dc2f655af7f6e4e54dea74f8e75a18b
This commit is contained in:
parent
ae6c9dae43
commit
217f018519
@ -1447,13 +1447,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
|
||||
|
||||
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->Executor().WaitForAllPendingNodes();
|
||||
auto* context = tensorflow::unwrap(ctx);
|
||||
status->status = context->AsyncWait();
|
||||
if (!status->status.ok()) return;
|
||||
tensorflow::mutex_lock ml(*context->MetadataMu());
|
||||
status->status = MessageToBuffer(*context->RunMetadataProto(), buf);
|
||||
context->ClearRunMetadata();
|
||||
auto run_metadata = context->ExportRunMetadata();
|
||||
status->status = MessageToBuffer(*run_metadata, buf);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -124,6 +125,13 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// Returns the device placement policy for the current thread.
|
||||
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
||||
|
||||
// Configure graph collection in RunMetadata.
|
||||
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||
|
||||
// Return the collected RunMetadata. This method will transfer the ownership
|
||||
// to the caller.
|
||||
virtual std::unique_ptr<RunMetadata> ExportRunMetadata() = 0;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractContext* ptr) {
|
||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||
@ -149,9 +157,6 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// Update the Eager Executor for current thread.
|
||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||
|
||||
// Configure graph collection in RunMetadata.
|
||||
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||
|
||||
protected:
|
||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||
: AbstractContext(kind) {}
|
||||
|
@ -42,6 +42,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
@ -108,6 +109,8 @@ EagerContext::EagerContext(
|
||||
this->thread_pool_->Schedule(std::move(closure));
|
||||
};
|
||||
|
||||
run_metadata_ = std::make_unique<RunMetadata>();
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
context_id_ = kInvalidContextId;
|
||||
context_view_id_ = 0;
|
||||
@ -577,7 +580,12 @@ const FunctionDef* EagerContext::FindFunctionDef(const string& name) const {
|
||||
return func_lib_def_.Find(name);
|
||||
}
|
||||
|
||||
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
|
||||
std::unique_ptr<RunMetadata> EagerContext::ExportRunMetadata() {
|
||||
mutex_lock ml(metadata_mu_);
|
||||
auto result = std::make_unique<RunMetadata>();
|
||||
run_metadata_.swap(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
bool EagerContext::UsesTFRT() { return false; }
|
||||
|
||||
@ -858,7 +866,7 @@ void EagerContext::SetShouldStoreGraphs(bool value) {
|
||||
mutex_lock ml(metadata_mu_);
|
||||
should_store_graphs_.store(value);
|
||||
if (!value) {
|
||||
run_metadata_.Clear();
|
||||
run_metadata_.reset(new RunMetadata);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -79,12 +79,6 @@ namespace eager {
|
||||
class RemoteMgr;
|
||||
} // namespace eager
|
||||
|
||||
class RunMetadataListener {
|
||||
public:
|
||||
virtual ~RunMetadataListener() {}
|
||||
virtual void BeforeClearRunMetadata() = 0;
|
||||
};
|
||||
|
||||
class TensorHandle;
|
||||
class EagerOperation;
|
||||
|
||||
@ -310,8 +304,11 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
|
||||
bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_);
|
||||
void SetShouldStoreGraphs(bool value) override;
|
||||
RunMetadata* RunMetadataProto() { return &run_metadata_; }
|
||||
void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
|
||||
RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) {
|
||||
return run_metadata_.get();
|
||||
}
|
||||
std::unique_ptr<RunMetadata> ExportRunMetadata() override
|
||||
TF_LOCKS_EXCLUDED(metadata_mu_);
|
||||
|
||||
void StartStep() override;
|
||||
void EndStep() override;
|
||||
@ -587,7 +584,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
// Whether we should compute RunMetadata.
|
||||
std::atomic<bool> should_store_graphs_{false};
|
||||
mutex metadata_mu_;
|
||||
RunMetadata run_metadata_ TF_GUARDED_BY(metadata_mu_);
|
||||
std::unique_ptr<RunMetadata> run_metadata_ TF_GUARDED_BY(metadata_mu_);
|
||||
GraphCollector graph_collector_;
|
||||
std::atomic<bool> log_device_placement_;
|
||||
std::atomic<bool> allow_soft_placement_;
|
||||
|
@ -75,9 +75,24 @@ class ContextTest(test.TestCase):
|
||||
del tensor2
|
||||
self.assertIs(weak_c(), None)
|
||||
|
||||
@test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet')
|
||||
def testSimpleGraphCollection(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
with ops.device('CPU:0'):
|
||||
return x + constant_op.constant(1.)
|
||||
|
||||
with context.collect_graphs() as graphs:
|
||||
f(constant_op.constant(1.))
|
||||
|
||||
self.assertLen(graphs, 1)
|
||||
graph, = graphs
|
||||
self.assertIn('CPU:0', graph.node[1].device)
|
||||
|
||||
@test_util.disable_tfrt(
|
||||
'b/171600738: tfrt does not support exporting post-optimization graph')
|
||||
def testGraphCollectionAfterDevicePlacement(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
return x + constant_op.constant(1.)
|
||||
|
@ -1322,7 +1322,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsInstance(
|
||||
self.v, resource_variable_ops.ResourceVariable)
|
||||
|
||||
@test_util.disable_tfrt('b/169294215')
|
||||
def testRunMetadata(self):
|
||||
|
||||
@def_function.function
|
||||
|
Loading…
Reference in New Issue
Block a user