Populate partition graphs in run metadata from partitioned calls in eager mode.

Caveats: the partition graphs are only populated the first time the function is
called, and are not populated when run metadata is requested through session.run.
PiperOrigin-RevId: 217773737
Alexandre Passos 2018-10-18 14:52:29 -07:00 committed by TensorFlower Gardener
parent 8898af469c
commit 0136e95df8
10 changed files with 66 additions and 20 deletions
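For context, here is a minimal user-level sketch of the behavior this change enables. It is not part of the commit itself; it assumes a TensorFlow build from around this revision (TF 1.12 era) where tf.contrib.eager.defun and the tensorflow.python.eager.context metadata helpers used in the updated function_test.py hunk below are available, and the function name is illustrative.

# Minimal sketch, assuming a TF build that includes this commit.
import tensorflow as tf
from tensorflow.python.eager import context

tf.enable_eager_execution()

@tf.contrib.eager.defun
def square(x):
  return x * x

# Metadata collection must be on for the *first* call: per the caveat above,
# partition graphs are only recorded when the function is built and partitioned.
context.enable_run_metadata()
square(tf.constant(1.0))

run_metadata = context.export_run_metadata()
# The partitioned-call subgraphs now show up alongside the step stats.
print(len(run_metadata.partition_graphs))  # expected: >= 1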

View File

@@ -131,6 +131,8 @@ class EagerContext {
Device* HostCPU() { return devices_[0]; }
GraphCollector* GetGraphCollector() { return &graph_collector_; }
uint64 NextId() { return executor_.NextId(); }
void ExecutorAdd(EagerNode* node) { executor_.Add(node); }
@@ -249,6 +251,7 @@ class EagerContext {
std::atomic<bool> should_store_metadata_{false};
mutex metadata_mu_;
RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
GraphCollector graph_collector_;
const bool log_device_placement_;
// EagerExecutor for async execution.
EagerExecutor executor_;

View File

@@ -325,7 +325,9 @@ Status EagerLocalExecute(EagerOperation* op,
if (!status.ok()) return status;
std::unique_ptr<NodeExecStats> maybe_stats;
StepStats* maybe_step_stats = nullptr;
GraphCollector* graph_collector = nullptr;
if (ctx->ShouldStoreMetadata()) {
graph_collector = ctx->GetGraphCollector();
maybe_step_stats = ctx->RunMetadataProto()->mutable_step_stats();
int64 now_nanos = Env::Default()->NowNanos();
maybe_stats.reset(new NodeExecStats);
@@ -349,14 +351,14 @@ Status EagerLocalExecute(EagerOperation* op,
}
EagerNode* node = new ExecuteNode(
id, ctx, op->Device(), op->Inputs(), kernel, maybe_stats.release(),
maybe_step_stats, output_dtypes, *retvals);
maybe_step_stats, graph_collector, output_dtypes, *retvals);
ctx->ExecutorAdd(node);
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
status =
EagerExecute(ctx, op->Device(), op->Inputs(), kernel, maybe_stats.get(),
maybe_step_stats, retvals->data(), *num_retvals);
status = EagerExecute(ctx, op->Device(), op->Inputs(), kernel,
maybe_stats.get(), maybe_step_stats, graph_collector,
retvals->data(), *num_retvals);
}
return status;
@@ -710,7 +712,8 @@ Status EagerExecute(EagerOperation* op,
Status EagerExecute(EagerContext* ctx, Device* device,
const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
KernelAndDevice* kernel, NodeExecStats* maybe_stats,
StepStats* maybe_step_stats, TensorHandle** retvals,
StepStats* maybe_step_stats,
GraphCollector* graph_collector, TensorHandle** retvals,
int num_retvals) {
if (device == nullptr) {
// TODO(apassos) debug how the assignment below might return a different
@@ -732,11 +735,11 @@ Status EagerExecute(EagerContext* ctx, Device* device,
// TODO(agarwal): change Run to take vector of handles ?
ScopedStepContainer* container = ctx->StepContainer();
if (container == nullptr) {
TF_RETURN_IF_ERROR(
kernel->Run(&inputs, &outputs, maybe_stats, maybe_step_stats));
TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats,
maybe_step_stats, graph_collector));
} else {
TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats,
maybe_step_stats));
maybe_step_stats, graph_collector));
}
if (maybe_stats != nullptr) {
int64 nanos = Env::Default()->NowNanos();
@@ -748,6 +751,14 @@ Status EagerExecute(EagerContext* ctx, Device* device,
maybe_stats->set_all_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
mutex_lock ml(*ctx->MetadataMu());
if (ctx->ShouldStoreMetadata()) {
{
GraphCollector* collector = ctx->GetGraphCollector();
mutex_lock mll(collector->mu);
for (const auto& graph : collector->graphs) {
*ctx->RunMetadataProto()->add_partition_graphs() = graph;
}
collector->graphs.clear();
}
auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
// Lazily initialize the RunMetadata with information about all devices if
// this is the first call.

View File

@@ -46,7 +46,8 @@ Status EagerExecute(
Status EagerExecute(EagerContext* ctx, Device* device,
const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
KernelAndDevice* kernel, NodeExecStats* maybe_stats,
StepStats* maybe_step_stats, TensorHandle** retvals,
StepStats* maybe_step_stats,
GraphCollector* graph_collector, TensorHandle** retvals,
int num_retvals);
// Low-level utility to copy a tensor handle from one device to another.

View File

@@ -34,7 +34,8 @@ class ExecuteNode : public EagerNode {
ExecuteNode(uint64 id, EagerContext* ctx, Device* op_device,
const tensorflow::gtl::InlinedVector<TensorHandle*, 4>& inputs,
KernelAndDevice* kernel, NodeExecStats* maybe_stats,
StepStats* maybe_step_stats, const DataTypeVector& output_dtypes,
StepStats* maybe_step_stats, GraphCollector* graph_collector,
const DataTypeVector& output_dtypes,
const tensorflow::gtl::InlinedVector<TensorHandle*, 2>& retvals)
: EagerNode(id),
ctx_(ctx),
@@ -43,6 +44,7 @@ class ExecuteNode : public EagerNode {
kernel_(kernel),
maybe_stats_(maybe_stats),
maybe_step_stats_(maybe_step_stats),
graph_collector_(graph_collector),
retvals_(retvals) {
for (auto handle : inputs_) {
handle->Ref();
@@ -62,9 +64,9 @@ class ExecuteNode : public EagerNode {
}
tensorflow::Status Run() override {
const Status status =
EagerExecute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
maybe_step_stats_, retvals_.begin(), retvals_.size());
const Status status = EagerExecute(
ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
maybe_step_stats_, graph_collector_, retvals_.begin(), retvals_.size());
if (status.ok()) {
return status;
} else {
@@ -82,6 +84,7 @@ class ExecuteNode : public EagerNode {
tensorflow::KernelAndDevice* kernel_;
std::unique_ptr<NodeExecStats> maybe_stats_;
StepStats* maybe_step_stats_;
tensorflow::GraphCollector* graph_collector_;
tensorflow::gtl::InlinedVector<TensorHandle*, 2> retvals_;
};

View File

@@ -48,17 +48,20 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
std::vector<Tensor>* outputs, NodeExecStats* stats,
StepStats* step_stats) {
StepStats* step_stats,
GraphCollector* graph_collector) {
ScopedStepContainer step_container(0, [this](const string& name) {
device_->resource_manager()->Cleanup(name).IgnoreError();
});
return this->Run(&step_container, inputs, outputs, stats, step_stats);
return this->Run(&step_container, inputs, outputs, stats, step_stats,
graph_collector);
}
Status KernelAndDevice::Run(ScopedStepContainer* step_container,
std::vector<Tensor>* inputs,
std::vector<Tensor>* outputs, NodeExecStats* stats,
StepStats* step_stats) {
StepStats* step_stats,
GraphCollector* graph_collector) {
gtl::InlinedVector<TensorValue, 4> input_vector;
for (Tensor& t : *inputs) {
input_vector.push_back(TensorValue(&t));
@@ -87,6 +90,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
step_stats_collector.reset(new StepStatsCollector(step_stats));
params.track_allocations = true;
params.stats_collector = step_stats_collector.get();
params.graph_collector = graph_collector;
}
if (runner_ == nullptr) {
params.runner = &default_runner_;

View File

@@ -62,11 +62,12 @@ class KernelAndDevice {
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
NodeExecStats* stats, StepStats* step_stats);
NodeExecStats* stats, StepStats* step_stats,
GraphCollector* graph_collector);
Status Run(ScopedStepContainer* step_container, std::vector<Tensor>* inputs,
std::vector<Tensor>* outputs, NodeExecStats* stats,
StepStats* step_stats);
StepStats* step_stats, GraphCollector* graph_collector);
const OpKernel* kernel() const { return kernel_.get(); }

View File

@@ -132,7 +132,7 @@ void BM_KernelAndDeviceRun(int iters) {
nullptr, &kernel));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr));
TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr, nullptr));
}
}
BENCHMARK(BM_KernelAndDeviceRun);

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -489,6 +490,17 @@ struct TensorValue {
Tensor* tensor;
};
// Used to store partitioned graphs from function-calling ops.
struct GraphCollector {
mutex mu;
std::vector<GraphDef> graphs GUARDED_BY(mu);
void CollectGraph(const GraphDef& graph) {
mutex_lock ml(mu);
graphs.push_back(graph);
}
};
class OpKernelContext {
public:
// The first element of a WrappedAllocator is a "base" Allocator and
@@ -589,6 +601,7 @@ class OpKernelContext {
FunctionLibraryRuntime* function_library = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
StepStatsCollectorInterface* stats_collector = nullptr;
GraphCollector* graph_collector = nullptr;
// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -711,6 +724,9 @@ class OpKernelContext {
// Usage: if (!context->ValidateInputsAreSameShape(this)) return;
bool ValidateInputsAreSameShape(OpKernel* op);
// If non-null, kernels should populate with any partition subgraphs created.
GraphCollector* graph_collector() { return params_->graph_collector; }
// Input to output forwarding.
// Set the output Ref Tensor at output_index to be an alias of the

View File

@@ -183,6 +183,13 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(
ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
done);
if (ctx->graph_collector() != nullptr) {
for (const auto& pair : subgraphs) {
GraphDef def;
pair.second->ToGraphDef(&def);
ctx->graph_collector()->CollectGraph(def);
}
}
optimization_options.graph = nullptr;
optimization_options.device_set = nullptr;
optimization_options.partition_graphs = &subgraphs;

View File

@@ -630,7 +630,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
return x * x
with ops.device('cpu:0'):
f(constant_op.constant(1.0)) # pre-build the defun
context.enable_run_metadata()
f(constant_op.constant(1.0))
run_metadata = context.export_run_metadata()
@@ -645,6 +644,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
# arbitrarily many (placeholders, return identities, etc, might be included
# or not in the future, so shouldn't be tested for exactly.
self.assertGreaterEqual(len(cpu_stats.node_stats), 2)
self.assertEqual(len(run_metadata.partition_graphs), 1)
def testGraphModeCaptureVariable(self):
with context.graph_mode(), self.cached_session() as sess: