Populate partition graphs in run metadata from partitioned call via eager.

Caveats: the partition graphs are only populated the first time the function is called, and they are not populated if run metadata is requested from session.run.

PiperOrigin-RevId: 217773737
commit 0136e95df8 (parent 8898af469c)
Changed files:
- tensorflow/core/common_runtime/eager/context.h
- tensorflow/core/common_runtime/eager/execute.cc
- tensorflow/core/common_runtime/eager/execute.h
- tensorflow/core/common_runtime/eager/execute_node.h
- tensorflow/core/common_runtime/eager/kernel_and_device.cc
- tensorflow/core/common_runtime/eager/kernel_and_device.h
- tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
- tensorflow/core/framework/op_kernel.h
- tensorflow/core/kernels/partitioned_function_ops.cc
- tensorflow/python/eager/function_test.py
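
For context, a minimal usage sketch of what this change enables, modeled on the function_test.py update at the bottom of this diff. It is an illustration, not part of the change: it uses the same internal APIs as the test (function.defun, context.enable_run_metadata, context.export_run_metadata) and assumes a TensorFlow build of this era with eager execution enabled.

import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op

tf.enable_eager_execution()  # TF 1.x-era switch; this code path is eager-only

@function.defun
def f(x):
  return x * x

context.enable_run_metadata()
f(constant_op.constant(1.0))  # first call: the function is partitioned here
run_metadata = context.export_run_metadata()

# With this commit, the subgraphs produced by PartitionedCallOp are recorded
# in RunMetadata.partition_graphs alongside the step stats.
for graph_def in run_metadata.partition_graphs:
  print(len(graph_def.node), 'nodes in this partition')
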
tensorflow/core/common_runtime/eager/context.h
@@ -131,6 +131,8 @@ class EagerContext {
   Device* HostCPU() { return devices_[0]; }
 
+  GraphCollector* GetGraphCollector() { return &graph_collector_; }
+
   uint64 NextId() { return executor_.NextId(); }
 
   void ExecutorAdd(EagerNode* node) { executor_.Add(node); }
 
@@ -249,6 +251,7 @@ class EagerContext {
   std::atomic<bool> should_store_metadata_{false};
   mutex metadata_mu_;
   RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
+  GraphCollector graph_collector_;
   const bool log_device_placement_;
   // EagerExecutor for async execution.
   EagerExecutor executor_;

tensorflow/core/common_runtime/eager/execute.cc
@@ -325,7 +325,9 @@ Status EagerLocalExecute(EagerOperation* op,
   if (!status.ok()) return status;
   std::unique_ptr<NodeExecStats> maybe_stats;
   StepStats* maybe_step_stats = nullptr;
+  GraphCollector* graph_collector = nullptr;
   if (ctx->ShouldStoreMetadata()) {
+    graph_collector = ctx->GetGraphCollector();
     maybe_step_stats = ctx->RunMetadataProto()->mutable_step_stats();
     int64 now_nanos = Env::Default()->NowNanos();
     maybe_stats.reset(new NodeExecStats);
@@ -349,14 +351,14 @@ Status EagerLocalExecute(EagerOperation* op,
     }
     EagerNode* node = new ExecuteNode(
         id, ctx, op->Device(), op->Inputs(), kernel, maybe_stats.release(),
-        maybe_step_stats, output_dtypes, *retvals);
+        maybe_step_stats, graph_collector, output_dtypes, *retvals);
     ctx->ExecutorAdd(node);
   } else {
     // Execute checks if retvals[i] is nullptr or not to figure if it needs to
     // allocate it.
-    status =
-        EagerExecute(ctx, op->Device(), op->Inputs(), kernel, maybe_stats.get(),
-                     maybe_step_stats, retvals->data(), *num_retvals);
+    status = EagerExecute(ctx, op->Device(), op->Inputs(), kernel,
+                          maybe_stats.get(), maybe_step_stats, graph_collector,
+                          retvals->data(), *num_retvals);
   }
 
   return status;
@@ -710,7 +712,8 @@ Status EagerExecute(EagerOperation* op,
 Status EagerExecute(EagerContext* ctx, Device* device,
                     const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
                     KernelAndDevice* kernel, NodeExecStats* maybe_stats,
-                    StepStats* maybe_step_stats, TensorHandle** retvals,
+                    StepStats* maybe_step_stats,
+                    GraphCollector* graph_collector, TensorHandle** retvals,
                     int num_retvals) {
   if (device == nullptr) {
     // TODO(apassos) debug how the assignment below might return a different
@@ -732,11 +735,11 @@ Status EagerExecute(EagerContext* ctx, Device* device,
   // TODO(agarwal): change Run to take vector of handles ?
   ScopedStepContainer* container = ctx->StepContainer();
   if (container == nullptr) {
-    TF_RETURN_IF_ERROR(
-        kernel->Run(&inputs, &outputs, maybe_stats, maybe_step_stats));
+    TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats,
+                                   maybe_step_stats, graph_collector));
   } else {
     TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats,
-                                   maybe_step_stats));
+                                   maybe_step_stats, graph_collector));
   }
   if (maybe_stats != nullptr) {
     int64 nanos = Env::Default()->NowNanos();
@@ -748,6 +751,14 @@ Status EagerExecute(EagerContext* ctx, Device* device,
     maybe_stats->set_all_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
     mutex_lock ml(*ctx->MetadataMu());
     if (ctx->ShouldStoreMetadata()) {
+      {
+        GraphCollector* collector = ctx->GetGraphCollector();
+        mutex_lock mll(collector->mu);
+        for (const auto& graph : collector->graphs) {
+          *ctx->RunMetadataProto()->add_partition_graphs() = graph;
+        }
+        collector->graphs.clear();
+      }
       auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
       // Lazily initialize the RunMetadata with information about all devices if
       // this is the first call.

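The last hunk above is where the collected graphs reach the user-visible proto: under the metadata lock, every GraphDef in the collector is appended to RunMetadata.partition_graphs and the collector is drained. A rough Python analogue of that bookkeeping, using the real proto classes but a plain list as a stand-in for GraphCollector::graphs:

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2

collected = [graph_pb2.GraphDef()]    # stand-in for GraphCollector::graphs
run_metadata = config_pb2.RunMetadata()

for graph in collected:               # mirrors the loop over collector->graphs
  run_metadata.partition_graphs.add().CopyFrom(graph)
collected.clear()                     # mirrors collector->graphs.clear()

print(len(run_metadata.partition_graphs))  # 1
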
tensorflow/core/common_runtime/eager/execute.h
@@ -46,7 +46,8 @@ Status EagerExecute(
 Status EagerExecute(EagerContext* ctx, Device* device,
                     const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
                     KernelAndDevice* kernel, NodeExecStats* maybe_stats,
-                    StepStats* maybe_step_stats, TensorHandle** retvals,
+                    StepStats* maybe_step_stats,
+                    GraphCollector* graph_collector, TensorHandle** retvals,
                     int num_retvals);
 
 // Low-level utility to copy a tensor handle from one device to another.

tensorflow/core/common_runtime/eager/execute_node.h
@@ -34,7 +34,8 @@ class ExecuteNode : public EagerNode {
   ExecuteNode(uint64 id, EagerContext* ctx, Device* op_device,
               const tensorflow::gtl::InlinedVector<TensorHandle*, 4>& inputs,
               KernelAndDevice* kernel, NodeExecStats* maybe_stats,
-              StepStats* maybe_step_stats, const DataTypeVector& output_dtypes,
+              StepStats* maybe_step_stats, GraphCollector* graph_collector,
+              const DataTypeVector& output_dtypes,
               const tensorflow::gtl::InlinedVector<TensorHandle*, 2>& retvals)
       : EagerNode(id),
         ctx_(ctx),
@@ -43,6 +44,7 @@ class ExecuteNode : public EagerNode {
         kernel_(kernel),
         maybe_stats_(maybe_stats),
         maybe_step_stats_(maybe_step_stats),
+        graph_collector_(graph_collector),
         retvals_(retvals) {
     for (auto handle : inputs_) {
       handle->Ref();
@@ -62,9 +64,9 @@ class ExecuteNode : public EagerNode {
   }
 
   tensorflow::Status Run() override {
-    const Status status =
-        EagerExecute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
-                     maybe_step_stats_, retvals_.begin(), retvals_.size());
+    const Status status = EagerExecute(
+        ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
+        maybe_step_stats_, graph_collector_, retvals_.begin(), retvals_.size());
     if (status.ok()) {
       return status;
     } else {
@@ -82,6 +84,7 @@ class ExecuteNode : public EagerNode {
   tensorflow::KernelAndDevice* kernel_;
   std::unique_ptr<NodeExecStats> maybe_stats_;
   StepStats* maybe_step_stats_;
+  tensorflow::GraphCollector* graph_collector_;
   tensorflow::gtl::InlinedVector<TensorHandle*, 2> retvals_;
 };
 

tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -48,17 +48,20 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
 
 Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
                             std::vector<Tensor>* outputs, NodeExecStats* stats,
-                            StepStats* step_stats) {
+                            StepStats* step_stats,
+                            GraphCollector* graph_collector) {
   ScopedStepContainer step_container(0, [this](const string& name) {
     device_->resource_manager()->Cleanup(name).IgnoreError();
   });
-  return this->Run(&step_container, inputs, outputs, stats, step_stats);
+  return this->Run(&step_container, inputs, outputs, stats, step_stats,
+                   graph_collector);
 }
 
 Status KernelAndDevice::Run(ScopedStepContainer* step_container,
                             std::vector<Tensor>* inputs,
                             std::vector<Tensor>* outputs, NodeExecStats* stats,
-                            StepStats* step_stats) {
+                            StepStats* step_stats,
+                            GraphCollector* graph_collector) {
   gtl::InlinedVector<TensorValue, 4> input_vector;
   for (Tensor& t : *inputs) {
     input_vector.push_back(TensorValue(&t));
@@ -87,6 +90,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
     step_stats_collector.reset(new StepStatsCollector(step_stats));
     params.track_allocations = true;
     params.stats_collector = step_stats_collector.get();
+    params.graph_collector = graph_collector;
   }
   if (runner_ == nullptr) {
     params.runner = &default_runner_;

tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -62,11 +62,12 @@ class KernelAndDevice {
 
   // TODO(ashankar): Handle list-valued inputs.
   Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
-             NodeExecStats* stats, StepStats* step_stats);
+             NodeExecStats* stats, StepStats* step_stats,
+             GraphCollector* graph_collector);
 
   Status Run(ScopedStepContainer* step_container, std::vector<Tensor>* inputs,
              std::vector<Tensor>* outputs, NodeExecStats* stats,
-             StepStats* step_stats);
+             StepStats* step_stats, GraphCollector* graph_collector);
 
   const OpKernel* kernel() const { return kernel_.get(); }
 

tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
@@ -132,7 +132,7 @@ void BM_KernelAndDeviceRun(int iters) {
                                      nullptr, &kernel));
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
-    TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr));
+    TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr, nullptr));
   }
 }
 BENCHMARK(BM_KernelAndDeviceRun);

tensorflow/core/framework/op_kernel.h
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/control_flow.h"
 #include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -489,6 +490,17 @@ struct TensorValue {
   Tensor* tensor;
 };
 
+// Used to store partitioned graphs from function-calling ops.
+struct GraphCollector {
+  mutex mu;
+  std::vector<GraphDef> graphs GUARDED_BY(mu);
+
+  void CollectGraph(const GraphDef& graph) {
+    mutex_lock ml(mu);
+    graphs.push_back(graph);
+  }
+};
+
 class OpKernelContext {
  public:
   // The first element of a WrappedAllocator is a "base" Allocator and
@@ -589,6 +601,7 @@ class OpKernelContext {
     FunctionLibraryRuntime* function_library = nullptr;
     std::function<void(std::function<void()>)>* runner = nullptr;
     StepStatsCollectorInterface* stats_collector = nullptr;
+    GraphCollector* graph_collector = nullptr;
 
     // TensorSliceReaderCache support.
     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -711,6 +724,9 @@ class OpKernelContext {
   // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
   bool ValidateInputsAreSameShape(OpKernel* op);
 
+  // If non-null, kernels should populate with any partition subgraphs created.
+  GraphCollector* graph_collector() { return params_->graph_collector; }
+
   // Input to output forwarding.
 
   // Set the output Ref Tensor at output_index to be an alias of the

tensorflow/core/kernels/partitioned_function_ops.cc
@@ -183,6 +183,13 @@ class PartitionedCallOp : public AsyncOpKernel {
     OP_REQUIRES_OK_ASYNC(
         ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
         done);
+    if (ctx->graph_collector() != nullptr) {
+      for (const auto& pair : subgraphs) {
+        GraphDef def;
+        pair.second->ToGraphDef(&def);
+        ctx->graph_collector()->CollectGraph(def);
+      }
+    }
     optimization_options.graph = nullptr;
     optimization_options.device_set = nullptr;
     optimization_options.partition_graphs = &subgraphs;

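The hunk above also explains the first caveat in the commit message: PartitionedCallOp partitions the function only when it is first instantiated (the result is cached), so subgraphs are handed to the collector only on that first call. A sketch of the consequence, using the same internal APIs as the test below (and mirroring why the pre-build call is removed from that test):

import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op

tf.enable_eager_execution()

@function.defun
def f(x):
  return x * x

f(constant_op.constant(1.0))      # partitioned here, but metadata is off, so nothing collected
context.enable_run_metadata()
f(constant_op.constant(1.0))      # already partitioned; nothing new is collected
md = context.export_run_metadata()
print(len(md.partition_graphs))   # expected to be 0, per the caveat above
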
tensorflow/python/eager/function_test.py
@@ -630,7 +630,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       return x * x
 
     with ops.device('cpu:0'):
-      f(constant_op.constant(1.0))  # pre-build the defun
       context.enable_run_metadata()
       f(constant_op.constant(1.0))
       run_metadata = context.export_run_metadata()
@@ -645,6 +644,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
    # arbitrarily many (placeholders, return identities, etc, might be included
    # or not in the future, so shouldn't be tested for exactly.
    self.assertGreaterEqual(len(cpu_stats.node_stats), 2)
+    self.assertEqual(len(run_metadata.partition_graphs), 1)
 
   def testGraphModeCaptureVariable(self):
     with context.graph_mode(), self.cached_session() as sess: