Populate partition graphs in run metadata from partitioned call via eager.
Caveats: only populated the first time the function is called; not populated if runmetadata is requested from session.run. PiperOrigin-RevId: 217773737
This commit is contained in:
parent
8898af469c
commit
0136e95df8
@ -131,6 +131,8 @@ class EagerContext {
|
||||
|
||||
Device* HostCPU() { return devices_[0]; }
|
||||
|
||||
GraphCollector* GetGraphCollector() { return &graph_collector_; }
|
||||
|
||||
uint64 NextId() { return executor_.NextId(); }
|
||||
|
||||
void ExecutorAdd(EagerNode* node) { executor_.Add(node); }
|
||||
@ -249,6 +251,7 @@ class EagerContext {
|
||||
std::atomic<bool> should_store_metadata_{false};
|
||||
mutex metadata_mu_;
|
||||
RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
|
||||
GraphCollector graph_collector_;
|
||||
const bool log_device_placement_;
|
||||
// EagerExecutor for async execution.
|
||||
EagerExecutor executor_;
|
||||
|
@ -325,7 +325,9 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
if (!status.ok()) return status;
|
||||
std::unique_ptr<NodeExecStats> maybe_stats;
|
||||
StepStats* maybe_step_stats = nullptr;
|
||||
GraphCollector* graph_collector = nullptr;
|
||||
if (ctx->ShouldStoreMetadata()) {
|
||||
graph_collector = ctx->GetGraphCollector();
|
||||
maybe_step_stats = ctx->RunMetadataProto()->mutable_step_stats();
|
||||
int64 now_nanos = Env::Default()->NowNanos();
|
||||
maybe_stats.reset(new NodeExecStats);
|
||||
@ -349,14 +351,14 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
}
|
||||
EagerNode* node = new ExecuteNode(
|
||||
id, ctx, op->Device(), op->Inputs(), kernel, maybe_stats.release(),
|
||||
maybe_step_stats, output_dtypes, *retvals);
|
||||
maybe_step_stats, graph_collector, output_dtypes, *retvals);
|
||||
ctx->ExecutorAdd(node);
|
||||
} else {
|
||||
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
|
||||
// allocate it.
|
||||
status =
|
||||
EagerExecute(ctx, op->Device(), op->Inputs(), kernel, maybe_stats.get(),
|
||||
maybe_step_stats, retvals->data(), *num_retvals);
|
||||
status = EagerExecute(ctx, op->Device(), op->Inputs(), kernel,
|
||||
maybe_stats.get(), maybe_step_stats, graph_collector,
|
||||
retvals->data(), *num_retvals);
|
||||
}
|
||||
|
||||
return status;
|
||||
@ -710,7 +712,8 @@ Status EagerExecute(EagerOperation* op,
|
||||
Status EagerExecute(EagerContext* ctx, Device* device,
|
||||
const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
|
||||
KernelAndDevice* kernel, NodeExecStats* maybe_stats,
|
||||
StepStats* maybe_step_stats, TensorHandle** retvals,
|
||||
StepStats* maybe_step_stats,
|
||||
GraphCollector* graph_collector, TensorHandle** retvals,
|
||||
int num_retvals) {
|
||||
if (device == nullptr) {
|
||||
// TODO(apassos) debug how the assignment below might return a different
|
||||
@ -732,11 +735,11 @@ Status EagerExecute(EagerContext* ctx, Device* device,
|
||||
// TODO(agarwal): change Run to take vector of handles ?
|
||||
ScopedStepContainer* container = ctx->StepContainer();
|
||||
if (container == nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
kernel->Run(&inputs, &outputs, maybe_stats, maybe_step_stats));
|
||||
TF_RETURN_IF_ERROR(kernel->Run(&inputs, &outputs, maybe_stats,
|
||||
maybe_step_stats, graph_collector));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(kernel->Run(container, &inputs, &outputs, maybe_stats,
|
||||
maybe_step_stats));
|
||||
maybe_step_stats, graph_collector));
|
||||
}
|
||||
if (maybe_stats != nullptr) {
|
||||
int64 nanos = Env::Default()->NowNanos();
|
||||
@ -748,6 +751,14 @@ Status EagerExecute(EagerContext* ctx, Device* device,
|
||||
maybe_stats->set_all_end_rel_nanos(nanos - maybe_stats->all_start_nanos());
|
||||
mutex_lock ml(*ctx->MetadataMu());
|
||||
if (ctx->ShouldStoreMetadata()) {
|
||||
{
|
||||
GraphCollector* collector = ctx->GetGraphCollector();
|
||||
mutex_lock mll(collector->mu);
|
||||
for (const auto& graph : collector->graphs) {
|
||||
*ctx->RunMetadataProto()->add_partition_graphs() = graph;
|
||||
}
|
||||
collector->graphs.clear();
|
||||
}
|
||||
auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
|
||||
// Lazily initialize the RunMetadata with information about all devices if
|
||||
// this is the first call.
|
||||
|
@ -46,7 +46,8 @@ Status EagerExecute(
|
||||
Status EagerExecute(EagerContext* ctx, Device* device,
|
||||
const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
|
||||
KernelAndDevice* kernel, NodeExecStats* maybe_stats,
|
||||
StepStats* maybe_step_stats, TensorHandle** retvals,
|
||||
StepStats* maybe_step_stats,
|
||||
GraphCollector* graph_collector, TensorHandle** retvals,
|
||||
int num_retvals);
|
||||
|
||||
// Low-level utility to copy a tensor handle from one device to another.
|
||||
|
@ -34,7 +34,8 @@ class ExecuteNode : public EagerNode {
|
||||
ExecuteNode(uint64 id, EagerContext* ctx, Device* op_device,
|
||||
const tensorflow::gtl::InlinedVector<TensorHandle*, 4>& inputs,
|
||||
KernelAndDevice* kernel, NodeExecStats* maybe_stats,
|
||||
StepStats* maybe_step_stats, const DataTypeVector& output_dtypes,
|
||||
StepStats* maybe_step_stats, GraphCollector* graph_collector,
|
||||
const DataTypeVector& output_dtypes,
|
||||
const tensorflow::gtl::InlinedVector<TensorHandle*, 2>& retvals)
|
||||
: EagerNode(id),
|
||||
ctx_(ctx),
|
||||
@ -43,6 +44,7 @@ class ExecuteNode : public EagerNode {
|
||||
kernel_(kernel),
|
||||
maybe_stats_(maybe_stats),
|
||||
maybe_step_stats_(maybe_step_stats),
|
||||
graph_collector_(graph_collector),
|
||||
retvals_(retvals) {
|
||||
for (auto handle : inputs_) {
|
||||
handle->Ref();
|
||||
@ -62,9 +64,9 @@ class ExecuteNode : public EagerNode {
|
||||
}
|
||||
|
||||
tensorflow::Status Run() override {
|
||||
const Status status =
|
||||
EagerExecute(ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
|
||||
maybe_step_stats_, retvals_.begin(), retvals_.size());
|
||||
const Status status = EagerExecute(
|
||||
ctx_, op_device_, inputs_, kernel_, maybe_stats_.get(),
|
||||
maybe_step_stats_, graph_collector_, retvals_.begin(), retvals_.size());
|
||||
if (status.ok()) {
|
||||
return status;
|
||||
} else {
|
||||
@ -82,6 +84,7 @@ class ExecuteNode : public EagerNode {
|
||||
tensorflow::KernelAndDevice* kernel_;
|
||||
std::unique_ptr<NodeExecStats> maybe_stats_;
|
||||
StepStats* maybe_step_stats_;
|
||||
tensorflow::GraphCollector* graph_collector_;
|
||||
tensorflow::gtl::InlinedVector<TensorHandle*, 2> retvals_;
|
||||
};
|
||||
|
||||
|
@ -48,17 +48,20 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
|
||||
|
||||
Status KernelAndDevice::Run(std::vector<Tensor>* inputs,
|
||||
std::vector<Tensor>* outputs, NodeExecStats* stats,
|
||||
StepStats* step_stats) {
|
||||
StepStats* step_stats,
|
||||
GraphCollector* graph_collector) {
|
||||
ScopedStepContainer step_container(0, [this](const string& name) {
|
||||
device_->resource_manager()->Cleanup(name).IgnoreError();
|
||||
});
|
||||
return this->Run(&step_container, inputs, outputs, stats, step_stats);
|
||||
return this->Run(&step_container, inputs, outputs, stats, step_stats,
|
||||
graph_collector);
|
||||
}
|
||||
|
||||
Status KernelAndDevice::Run(ScopedStepContainer* step_container,
|
||||
std::vector<Tensor>* inputs,
|
||||
std::vector<Tensor>* outputs, NodeExecStats* stats,
|
||||
StepStats* step_stats) {
|
||||
StepStats* step_stats,
|
||||
GraphCollector* graph_collector) {
|
||||
gtl::InlinedVector<TensorValue, 4> input_vector;
|
||||
for (Tensor& t : *inputs) {
|
||||
input_vector.push_back(TensorValue(&t));
|
||||
@ -87,6 +90,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
|
||||
step_stats_collector.reset(new StepStatsCollector(step_stats));
|
||||
params.track_allocations = true;
|
||||
params.stats_collector = step_stats_collector.get();
|
||||
params.graph_collector = graph_collector;
|
||||
}
|
||||
if (runner_ == nullptr) {
|
||||
params.runner = &default_runner_;
|
||||
|
@ -62,11 +62,12 @@ class KernelAndDevice {
|
||||
|
||||
// TODO(ashankar): Handle list-valued inputs.
|
||||
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
|
||||
NodeExecStats* stats, StepStats* step_stats);
|
||||
NodeExecStats* stats, StepStats* step_stats,
|
||||
GraphCollector* graph_collector);
|
||||
|
||||
Status Run(ScopedStepContainer* step_container, std::vector<Tensor>* inputs,
|
||||
std::vector<Tensor>* outputs, NodeExecStats* stats,
|
||||
StepStats* step_stats);
|
||||
StepStats* step_stats, GraphCollector* graph_collector);
|
||||
|
||||
const OpKernel* kernel() const { return kernel_.get(); }
|
||||
|
||||
|
@ -132,7 +132,7 @@ void BM_KernelAndDeviceRun(int iters) {
|
||||
nullptr, &kernel));
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr));
|
||||
TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr, nullptr, nullptr));
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_KernelAndDeviceRun);
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/control_flow.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
@ -489,6 +490,17 @@ struct TensorValue {
|
||||
Tensor* tensor;
|
||||
};
|
||||
|
||||
// Used to store partitioned graphs from function-calling ops.
|
||||
struct GraphCollector {
|
||||
mutex mu;
|
||||
std::vector<GraphDef> graphs GUARDED_BY(mu);
|
||||
|
||||
void CollectGraph(const GraphDef& graph) {
|
||||
mutex_lock ml(mu);
|
||||
graphs.push_back(graph);
|
||||
}
|
||||
};
|
||||
|
||||
class OpKernelContext {
|
||||
public:
|
||||
// The first element of a WrappedAllocator is a "base" Allocator and
|
||||
@ -589,6 +601,7 @@ class OpKernelContext {
|
||||
FunctionLibraryRuntime* function_library = nullptr;
|
||||
std::function<void(std::function<void()>)>* runner = nullptr;
|
||||
StepStatsCollectorInterface* stats_collector = nullptr;
|
||||
GraphCollector* graph_collector = nullptr;
|
||||
|
||||
// TensorSliceReaderCache support.
|
||||
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
|
||||
@ -711,6 +724,9 @@ class OpKernelContext {
|
||||
// Usage: if (!context->ValidateInputsAreSameShape(this)) return;
|
||||
bool ValidateInputsAreSameShape(OpKernel* op);
|
||||
|
||||
// If non-null, kernels should populate with any partition subgraphs created.
|
||||
GraphCollector* graph_collector() { return params_->graph_collector; }
|
||||
|
||||
// Input to output forwarding.
|
||||
|
||||
// Set the output Ref Tensor at output_index to be an alias of the
|
||||
|
@ -183,6 +183,13 @@ class PartitionedCallOp : public AsyncOpKernel {
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
|
||||
done);
|
||||
if (ctx->graph_collector() != nullptr) {
|
||||
for (const auto& pair : subgraphs) {
|
||||
GraphDef def;
|
||||
pair.second->ToGraphDef(&def);
|
||||
ctx->graph_collector()->CollectGraph(def);
|
||||
}
|
||||
}
|
||||
optimization_options.graph = nullptr;
|
||||
optimization_options.device_set = nullptr;
|
||||
optimization_options.partition_graphs = &subgraphs;
|
||||
|
@ -630,7 +630,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
return x * x
|
||||
|
||||
with ops.device('cpu:0'):
|
||||
f(constant_op.constant(1.0)) # pre-build the defun
|
||||
context.enable_run_metadata()
|
||||
f(constant_op.constant(1.0))
|
||||
run_metadata = context.export_run_metadata()
|
||||
@ -645,6 +644,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
# arbitrarily many (placeholders, return identities, etc, might be included
|
||||
# or not in the future, so shouldn't be tested for exactly.
|
||||
self.assertGreaterEqual(len(cpu_stats.node_stats), 2)
|
||||
self.assertEqual(len(run_metadata.partition_graphs), 1)
|
||||
|
||||
def testGraphModeCaptureVariable(self):
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
|
Loading…
Reference in New Issue
Block a user