Populate partition graphs in RunMetadata when a partitioned function call runs via eager execution.

Caveats: the partition graphs are only populated the first time the function is called; they are not populated when RunMetadata is requested from session.run().
PiperOrigin-RevId: 217773737
Alexandre Passos 2018-10-18 14:52:29 -07:00 committed by TensorFlower Gardener
parent 8898af469c
commit 0136e95df8
10 changed files with 66 additions and 20 deletions
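For orientation, the user-visible effect can be sketched as below. This is a minimal sketch, not code from this change: the defun decorator and the import paths are assumptions (they mirror the updated function_test.py at the bottom of this diff); only enable_run_metadata, export_run_metadata, and the new partition_graphs field are the pieces this commit actually exercises.

# Minimal sketch (assumptions noted above); mirrors the updated function_test.py.
from tensorflow.python.eager import context
from tensorflow.python.eager import function  # assumed home of the defun decorator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops


@function.defun
def f(x):
  return x * x


with ops.device('cpu:0'):
  context.enable_run_metadata()        # ask the eager context to store metadata
  f(constant_op.constant(1.0))         # PartitionedCallOp partitions and runs f
  run_metadata = context.export_run_metadata()

# EagerExecute drains the GraphCollector into RunMetadata.partition_graphs,
# so a single CPU-only call yields one partition graph (matching the new test).
assert len(run_metadata.partition_graphs) == 1

Because the graphs are only collected the first time the function is called, run metadata has to be enabled before that first call; that is also why the test below drops its pre-building call to f.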

View File

@@ -131,6 +131,8 @@ class EagerContext {
   Device* HostCPU() { return devices_[0]; }
 
+  GraphCollector* GetGraphCollector() { return &graph_collector_; }
+
   uint64 NextId() { return executor_.NextId(); }
   void ExecutorAdd(EagerNode* node) { executor_.Add(node); }
@@ -249,6 +251,7 @@ class EagerContext {
   std::atomic<bool> should_store_metadata_{false};
   mutex metadata_mu_;
   RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
+  GraphCollector graph_collector_;
   const bool log_device_placement_;
   // EagerExecutor for async execution.
   EagerExecutor executor_;

View File

@@ -325,7 +325,9 @@ Status EagerLocalExecute(EagerOperation* op,
   if (!status.ok()) return status;
   std::unique_ptr<NodeExecStats> maybe_stats;
   StepStats* maybe_step_stats = nullptr;
+  GraphCollector* graph_collector = nullptr;
   if (ctx->ShouldStoreMetadata()) {
+    graph_collector = ctx->GetGraphCollector();
     maybe_step_stats = ctx->RunMetadataProto()->mutable_step_stats();
     int64 now_nanos = Env::Default()->NowNanos();
     maybe_stats.reset(new NodeExecStats);
@@ -349,14 +351,14 @@ Status EagerLocalExecute(EagerOperation* op,
     }
     EagerNode* node = new ExecuteNode(
         id, ctx, op->Device(), op->Inputs(), kernel, maybe_stats.release(),
-        maybe_step_stats, output_dtypes, *retvals);
+        maybe_step_stats, graph_collector, output_dtypes, *retvals);
     ctx->ExecutorAdd(node);
   } else {
     // Execute checks if retvals[i] is nullptr or not to figure if it needs to
     // allocate it.
-    status =
-        EagerExecute(ctx, op->Device(), op->Inputs(), kernel, maybe_stats.get(),
-                     maybe_step_stats, retvals->data(), *num_retvals);
+    status = EagerExecute(ctx, op->Device(), op->Inputs(), kernel,
+                          maybe_stats.get(), maybe_step_stats, graph_collector,
+                          retvals->data(), *num_retvals);
   }
   return status;
@@ -710,7 +712,8 @@ Status EagerExecute(EagerOperation* op,
 Status EagerExecute(EagerContext* ctx, Device* device,
                     const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
                     KernelAndDevice* kernel, NodeExecStats* maybe_stats,
-                    StepStats* maybe_step_stats, TensorHandle** retvals,
+                    StepStats* maybe_step_stats,
+                    GraphCollector* graph_collector, TensorHandle** retvals,
                     int num_retvals) {
   if (device == nullptr) {
     // TODO(apassos) debug how the assignment below might return a different
@@ -732,11 +735,11 @@ Status EagerExecute(EagerContext* ctx, Device* device,
   // TODO(agarwal): change Run to take vector of handles ?
   ScopedStepContainer* container = ctx->StepContainer();
   if (container == nullptr) {
-    TF_RETURN_IF_ERROR(
-        kernel->Run(&inputs, &outputs, maybe_stats, maybe_step_stats));
+    TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats,
+                                   maybe_step_stats, graph_collector));
   } else {
     TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats,
-                                   maybe_step_stats));
+                                   maybe_step_stats, graph_collector));
   }
   if (maybe_stats != nullptr) {
     int64 nanos = Env::Default()->NowNanos();
@@ -748,6 +751,14 @@ Status EagerExecute(EagerContext* ctx, Device* device,
     maybe_stats->set_all_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
     mutex_lock ml(*ctx->MetadataMu());
     if (ctx->ShouldStoreMetadata()) {
+      {
+        GraphCollector* collector = ctx->GetGraphCollector();
+        mutex_lock mll(collector->mu);
+        for (const auto& graph : collector->graphs) {
+          *ctx->RunMetadataProto()->add_partition_graphs() = graph;
+        }
+        collector->graphs.clear();
+      }
       auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
       // Lazily initialize the RunMetadata with information about all devices if
       // this is the first call.

View File

@@ -46,7 +46,8 @@ Status EagerExecute(
 Status EagerExecute(EagerContext* ctx, Device* device,
                     const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
                     KernelAndDevice* kernel, NodeExecStats* maybe_stats,
-                    StepStats* maybe_step_stats, TensorHandle** retvals,
+                    StepStats* maybe_step_stats,
+                    GraphCollector* graph_collector, TensorHandle** retvals,
                     int num_retvals);
 
 // Low-level utility to copy a tensor handle from one device to another.

View File

@@ -34,7 +34,8 @@ class ExecuteNode : public EagerNode {
   ExecuteNode(uint64 id, EagerContext* ctx, Device* op_device,
               const tensorflow::gtl::InlinedVector<TensorHandle*, 4>& inputs,
               KernelAndDevice* kernel, NodeExecStats* maybe_stats,
-              StepStats* maybe_step_stats, const DataTypeVector& output_dtypes,
+              StepStats* maybe_step_stats, GraphCollector* graph_collector,
+              const DataTypeVector& output_dtypes,
               const tensorflow::gtl::InlinedVector<TensorHandle*, 2>& retvals)
       : EagerNode(id),
         ctx_(ctx),
@@ -43,6 +44,7 @@ class ExecuteNode : public EagerNode {
         kernel_(kernel),
         maybe_stats_(maybe_stats),
         maybe_step_stats_(maybe_step_stats),
+        graph_collector_(graph_collector),
         retvals_(retvals) {
     for (auto handle : inputs_) {
       handle->Ref();
@@ -62,9 +64,9 @@ class ExecuteNode : public EagerNode {
   }
 
   tensorflow::Status Run() override {
-    const Status status =
-        EagerExecute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
-                     maybe_step_stats_, retvals_.begin(), retvals_.size());
+    const Status status = EagerExecute(
+        ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
+        maybe_step_stats_, graph_collector_, retvals_.begin(), retvals_.size());
     if (status.ok()) {
       return status;
     } else {
@@ -82,6 +84,7 @@ class ExecuteNode : public EagerNode {
   tensorflow::KernelAndDevice* kernel_;
   std::unique_ptr<NodeExecStats> maybe_stats_;
   StepStats* maybe_step_stats_;
+  tensorflow::GraphCollector* graph_collector_;
   tensorflow::gtl::InlinedVector<TensorHandle*, 2> retvals_;
 };

View File

@@ -48,17 +48,20 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
 Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
                             std::vector<Tensor>* outputs, NodeExecStats* stats,
-                            StepStats* step_stats) {
+                            StepStats* step_stats,
+                            GraphCollector* graph_collector) {
   ScopedStepContainer step_container(0, [this](const string& name) {
     device_->resource_manager()->Cleanup(name).IgnoreError();
   });
-  return this->Run(&step_container, inputs, outputs, stats, step_stats);
+  return this->Run(&step_container, inputs, outputs, stats, step_stats,
+                   graph_collector);
 }
 
 Status KernelAndDevice::Run(ScopedStepContainer* step_container,
                             std::vector<Tensor>* inputs,
                             std::vector<Tensor>* outputs, NodeExecStats* stats,
-                            StepStats* step_stats) {
+                            StepStats* step_stats,
+                            GraphCollector* graph_collector) {
   gtl::InlinedVector<TensorValue, 4> input_vector;
   for (Tensor& t : *inputs) {
     input_vector.push_back(TensorValue(&t));
@@ -87,6 +90,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
     step_stats_collector.reset(new StepStatsCollector(step_stats));
     params.track_allocations = true;
     params.stats_collector = step_stats_collector.get();
+    params.graph_collector = graph_collector;
   }
   if (runner_ == nullptr) {
     params.runner = &default_runner_;

View File

@@ -62,11 +62,12 @@ class KernelAndDevice {
   // TODO(ashankar): Handle list-valued inputs.
   Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
-             NodeExecStats* stats, StepStats* step_stats);
+             NodeExecStats* stats, StepStats* step_stats,
+             GraphCollector* graph_collector);
 
   Status Run(ScopedStepContainer* step_container, std::vector<Tensor>* inputs,
              std::vector<Tensor>* outputs, NodeExecStats* stats,
-             StepStats* step_stats);
+             StepStats* step_stats, GraphCollector* graph_collector);
 
   const OpKernel* kernel() const { return kernel_.get(); }

View File

@@ -132,7 +132,7 @@ void BM_KernelAndDeviceRun(int iters) {
                                    nullptr, &kernel));
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
-    TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr));
+    TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr, nullptr));
   }
 }
 BENCHMARK(BM_KernelAndDeviceRun);

View File

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/control_flow.h"
 #include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -489,6 +490,17 @@ struct TensorValue {
   Tensor* tensor;
 };
 
+// Used to store partitioned graphs from function-calling ops.
+struct GraphCollector {
+  mutex mu;
+  std::vector<GraphDef> graphs GUARDED_BY(mu);
+
+  void CollectGraph(const GraphDef& graph) {
+    mutex_lock ml(mu);
+    graphs.push_back(graph);
+  }
+};
+
 class OpKernelContext {
  public:
   // The first element of a WrappedAllocator is a "base" Allocator and
@@ -589,6 +601,7 @@ class OpKernelContext {
     FunctionLibraryRuntime* function_library = nullptr;
     std::function<void(std::function<void()>)>* runner = nullptr;
     StepStatsCollectorInterface* stats_collector = nullptr;
+    GraphCollector* graph_collector = nullptr;
 
     // TensorSliceReaderCache support.
     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -711,6 +724,9 @@ class OpKernelContext {
   // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
   bool ValidateInputsAreSameShape(OpKernel* op);
 
+  // If non-null, kernels should populate with any partition subgraphs created.
+  GraphCollector* graph_collector() { return params_->graph_collector; }
+
   // Input to output forwarding.
   // Set the output Ref Tensor at output_index to be an alias of the

View File

@@ -183,6 +183,13 @@ class PartitionedCallOp : public AsyncOpKernel {
       OP_REQUIRES_OK_ASYNC(
           ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
           done);
+      if (ctx->graph_collector() != nullptr) {
+        for (const auto& pair : subgraphs) {
+          GraphDef def;
+          pair.second->ToGraphDef(&def);
+          ctx->graph_collector()->CollectGraph(def);
+        }
+      }
       optimization_options.graph = nullptr;
       optimization_options.device_set = nullptr;
       optimization_options.partition_graphs = &subgraphs;

View File

@@ -630,7 +630,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       return x * x
 
     with ops.device('cpu:0'):
-      f(constant_op.constant(1.0))  # pre-build the defun
       context.enable_run_metadata()
       f(constant_op.constant(1.0))
       run_metadata = context.export_run_metadata()
@@ -645,6 +644,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     # arbitrarily many (placeholders, return identities, etc, might be included
     # or not in the future, so shouldn't be tested for exactly.
     self.assertGreaterEqual(len(cpu_stats.node_stats), 2)
+    self.assertEqual(len(run_metadata.partition_graphs), 1)
 
   def testGraphModeCaptureVariable(self):
     with context.graph_mode(), self.cached_session() as sess: