Remove ExecutorImpl::graph_.

After it has been initialized, an executor no longer needs to keep the graph it is executing. This change also modifies all Executor factory methods to take a `const Graph&` instead of a `std::unique_ptr<const Graph>`.
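
For illustration, a minimal caller-side sketch of the new convention (the `BuildExecutor` helper below is hypothetical and not part of this change; the `NewLocalExecutor` signature is the one introduced here). The caller keeps ownership of the graph and passes it by const reference, and per the description above the executor does not hold on to it after initialization:

```cpp
#include <memory>

#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

// Hypothetical helper showing the new calling convention: the caller retains
// ownership of `graph`; NewLocalExecutor no longer consumes a
// std::unique_ptr<const Graph>.
tensorflow::Status BuildExecutor(const tensorflow::LocalExecutorParams& params,
                                 const tensorflow::Graph& graph,
                                 std::unique_ptr<tensorflow::Executor>* out) {
  tensorflow::Executor* raw = nullptr;
  // Before this change: NewLocalExecutor(params, std::move(owned_graph), &raw)
  TF_RETURN_IF_ERROR(tensorflow::NewLocalExecutor(params, graph, &raw));
  out->reset(raw);
  return tensorflow::Status::OK();
}
```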

PiperOrigin-RevId: 272099412
Derek Murray 2019-09-30 17:08:00 -07:00 committed by TensorFlower Gardener
parent 6f9c242cd0
commit a0e1a4d1bc
17 changed files with 51 additions and 68 deletions

@@ -736,8 +736,8 @@ Status DirectSession::RunInternal(
     std::unordered_map<string, const Graph*> device_to_graph;
     for (const PerPartitionExecutorsAndLib& partition :
          executors_and_keys->items) {
-      const Graph* graph = partition.graph;
-      const string device = partition.flib->device()->name();
+      const Graph* graph = partition.graph.get();
+      const string& device = partition.flib->device()->name();
       device_to_graph[device] = graph;
     }
@@ -748,7 +748,7 @@ Status DirectSession::RunInternal(
     CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
     for (const auto& item : executors_and_keys->items) {
       TF_RETURN_IF_ERROR(
-          cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
+          cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
     }
   }
@@ -1353,13 +1353,12 @@ Status DirectSession::CreateExecutors(
     TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                          device->name(),
                                          partition_graph.get()));
-    // NewLocalExecutor takes ownership of partition_graph.
-    item->graph = partition_graph.get();
+    item->graph = std::move(partition_graph);
     item->executor = nullptr;
     item->device = device;
     auto executor_type = options_.config.experimental().executor_type();
-    TF_RETURN_IF_ERROR(NewExecutor(
-        executor_type, params, std::move(partition_graph), &item->executor));
+    TF_RETURN_IF_ERROR(
+        NewExecutor(executor_type, params, *item->graph, &item->executor));
   }

   // Cache the mapping from input/output names to graph elements to

@@ -134,7 +134,7 @@ class DirectSession : public Session {
   // We create one executor and its dependent library runtime for
   // every partition.
   struct PerPartitionExecutorsAndLib {
-    Graph* graph = nullptr;  // not owned.
+    std::unique_ptr<Graph> graph = nullptr;
     Device* device = nullptr;                // not owned.
     FunctionLibraryRuntime* flib = nullptr;  // not owned.
     std::unique_ptr<Executor> executor;

@@ -300,12 +300,14 @@ class GraphView {
                : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
   }

+  int32 num_nodes() const { return num_nodes_; }
+
  private:
   char* InitializeNode(char* ptr, const Node* n);
   size_t NodeItemBytes(const Node* n);

   int32 num_nodes_ = 0;
-  uint32* node_offsets_ = nullptr;  // array of size "graph_.num_node_ids()"
+  uint32* node_offsets_ = nullptr;  // array of size "num_nodes_"
   // node_offsets_[id] holds the byte offset for node w/ "id" in space_
   char* space_;  // NodeItem objects are allocated here
@@ -315,14 +317,13 @@ class GraphView {
 class ExecutorImpl : public Executor {
  public:
-  ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
-      : params_(p), graph_(std::move(g)), gview_() {
+  explicit ExecutorImpl(const LocalExecutorParams& p) : params_(p), gview_() {
     CHECK(p.create_kernel != nullptr);
     CHECK(p.delete_kernel != nullptr);
   }

   ~ExecutorImpl() override {
-    for (int i = 0; i < graph_->num_node_ids(); i++) {
+    for (int32 i = 0; i < gview_.num_nodes(); i++) {
       NodeItem* item = gview_.node(i);
       if (item != nullptr) {
         params_.delete_kernel(item->kernel);
@@ -333,7 +334,7 @@ class ExecutorImpl : public Executor {
     }
   }

-  Status Initialize();
+  Status Initialize(const Graph& graph);

   // Process all Nodes in the current graph, attempting to infer the
   // memory allocation attributes to be used wherever they may allocate
@@ -394,7 +395,6 @@ class ExecutorImpl : public Executor {
   // Owned.
   LocalExecutorParams params_;
-  std::unique_ptr<const Graph> graph_;
   GraphView gview_;

   // A cached value of params_
@@ -623,12 +623,12 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending,
   *max_dead_count = num_in_edges;
 }

-Status ExecutorImpl::Initialize() {
-  gview_.Initialize(graph_.get());
+Status ExecutorImpl::Initialize(const Graph& graph) {
+  gview_.Initialize(&graph);

   // Build the information about frames in this subgraph.
   ControlFlowInfo cf_info;
-  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));
+  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

   // Cache this value so we make this virtual function call once, rather
   // that O(# steps * # nodes per step) times.
@@ -641,7 +641,7 @@ Status ExecutorImpl::Initialize() {
   // Preprocess every node in the graph to create an instance of op
   // kernel for each node.
-  for (const Node* n : graph_->nodes()) {
+  for (const Node* n : graph.nodes()) {
     const int id = n->id();
     const string& frame_name = cf_info.frame_names[id];
     FrameInfo* frame_info = EnsureFrameInfo(frame_name);
@@ -707,9 +707,9 @@ Status ExecutorImpl::Initialize() {
   // Initialize PendingCounts only after item->pending_id is initialized for
   // all nodes.
-  InitializePending(graph_.get(), cf_info);
+  InitializePending(&graph, cf_info);

-  return gview_.SetAllocAttrs(graph_.get(), params_.device);
+  return gview_.SetAllocAttrs(&graph, params_.device);
 }

 // If a Node has been marked to use a ScopedAllocator x for output i, then
@@ -2914,11 +2914,10 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
 }  // namespace

-Status NewLocalExecutor(const LocalExecutorParams& params,
-                        std::unique_ptr<const Graph> graph,
+Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
                         Executor** executor) {
-  ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
-  const Status s = impl->Initialize();
+  ExecutorImpl* impl = new ExecutorImpl(params);
+  const Status s = impl->Initialize(graph);
   if (s.ok()) {
     *executor = impl;
   } else {
@@ -2950,8 +2949,7 @@ class DefaultExecutorRegistrar {
  private:
   class Factory : public ExecutorFactory {
-    Status NewExecutor(const LocalExecutorParams& params,
-                       std::unique_ptr<const Graph> graph,
+    Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
                        std::unique_ptr<Executor>* out_executor) override {
       Executor* ret = nullptr;
       TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));

@@ -147,8 +147,7 @@ struct LocalExecutorParams {
   Executor::RendezvousFactory rendezvous_factory;
 };

 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
-                                      std::unique_ptr<const Graph> graph,
-                                      Executor** executor);
+                                      const Graph& graph, Executor** executor);

 // A class to help run multiple executors in parallel and wait until
 // all of them are complete.

@@ -74,8 +74,7 @@ Status ExecutorFactory::GetFactory(const string& executor_type,
 }

 Status NewExecutor(const string& executor_type,
-                   const LocalExecutorParams& params,
-                   std::unique_ptr<const Graph> graph,
+                   const LocalExecutorParams& params, const Graph& graph,
                    std::unique_ptr<Executor>* out_executor) {
   ExecutorFactory* factory = nullptr;
   TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));

@@ -32,7 +32,7 @@ struct LocalExecutorParams;
 class ExecutorFactory {
  public:
   virtual Status NewExecutor(const LocalExecutorParams& params,
-                             std::unique_ptr<const Graph> graph,
+                             const Graph& graph,
                              std::unique_ptr<Executor>* out_executor) = 0;

   virtual ~ExecutorFactory() {}
@@ -42,8 +42,7 @@ class ExecutorFactory {
 };

 Status NewExecutor(const string& executor_type,
-                   const LocalExecutorParams& params,
-                   std::unique_ptr<const Graph> graph,
+                   const LocalExecutorParams& params, const Graph& graph,
                    std::unique_ptr<Executor>* out_executor);

 }  // namespace tensorflow

@@ -77,7 +77,7 @@ class ExecutorTest : public ::testing::Test {
       return Status::OK();
     };
     delete exec_;
-    TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &exec_));
+    TF_CHECK_OK(NewLocalExecutor(params, *graph, &exec_));
     runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
   }

@@ -395,7 +395,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   // object, and an executor is created for the graph.
   struct Item {
     uint64 instantiation_counter = 0;
-    const Graph* graph = nullptr;  // Owned by exec.
+    std::unique_ptr<const Graph> graph = nullptr;
     const FunctionLibraryDefinition* lib_def = nullptr;  // Not owned.
     FunctionBody* func_graph = nullptr;
     Executor* exec = nullptr;
@@ -952,14 +952,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
   };
   params.rendezvous_factory = (*item)->rendezvous_factory;
   params.session_metadata = session_metadata_;
-  Graph* graph = g.get();
   std::unique_ptr<Executor> exec;
-  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
+  TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
   {
     // Guard item since it is already inserted in items_.
     mutex_lock l(mu_);
     if ((*item)->exec == nullptr) {
-      (*item)->graph = graph;
+      (*item)->graph = std::move(g);
       (*item)->exec = exec.release();
     }
   }
@@ -1230,7 +1229,7 @@ string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
   LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
   Status s = GetOrCreateItem(local_handle, &item);
   if (s.ok()) {
-    return tensorflow::DebugString(item->graph);
+    return tensorflow::DebugString(item->graph.get());
   } else {
     return s.ToString();
   }

@@ -104,7 +104,7 @@ class FunctionTest : public ::testing::Test {
       return Status::OK();
     };
     Executor* exec;
-    TF_CHECK_OK(NewLocalExecutor(params, std::move(g), &exec));
+    TF_CHECK_OK(NewLocalExecutor(params, *g, &exec));
     exec_.reset(exec);
   }
@@ -603,8 +603,7 @@ class DummyExecutorRegistrar {
  private:
   class Factory : public ExecutorFactory {
-    Status NewExecutor(const LocalExecutorParams& params,
-                       std::unique_ptr<const Graph> graph,
+    Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
                        std::unique_ptr<Executor>* out_executor) override {
       return errors::Internal("This is a dummy.");
     }

@@ -171,8 +171,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
   };

   Executor* executor;
-  TF_RETURN_IF_ERROR(
-      NewLocalExecutor(params, std::move(graph_to_run), &executor));
+  TF_RETURN_IF_ERROR(NewLocalExecutor(params, *graph_to_run, &executor));
   std::unique_ptr<Executor> executor_unref(executor);

   Executor::Args args;

@@ -88,16 +88,14 @@ Benchmark::Benchmark(const string& device, Graph* g,
   if (init) {
     std::unique_ptr<Executor> init_exec;
-    TF_CHECK_OK(NewExecutor(executor_type, params, std::unique_ptr<Graph>(init),
-                            &init_exec));
+    TF_CHECK_OK(NewExecutor(executor_type, params, *init, &init_exec));
     Executor::Args args;
     args.rendezvous = rendez_;
     args.runner = runner;
     TF_CHECK_OK(init_exec->Run(args));
   }

-  TF_CHECK_OK(
-      NewExecutor(executor_type, params, std::unique_ptr<Graph>(g), &exec_));
+  TF_CHECK_OK(NewExecutor(executor_type, params, *g, &exec_));
 }

 Benchmark::~Benchmark() {

@@ -74,7 +74,7 @@ GraphMgr::Item::~Item() {
   for (const auto& unit : this->units) {
     CHECK_NOTNULL(unit.device);
     if (!graph_mgr->skip_cost_models_) {
-      graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph);
+      graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get());
     }
     delete unit.root;
     unit.device->op_segment()->RemoveHold(this->session);
@@ -277,13 +277,12 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
     TF_RETURN_IF_ERROR(
         EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                           unit->device->name(), subgraph.get()));
-    unit->graph = subgraph.get();
+    unit->graph = std::move(subgraph);
     unit->build_cost_model = graph_options.build_cost_model();
     if (unit->build_cost_model > 0) {
      skip_cost_models_ = false;
     }
-    TF_RETURN_IF_ERROR(
-        NewLocalExecutor(params, std::move(subgraph), &unit->root));
+    TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
   }
   return Status::OK();
 }
@@ -552,14 +551,14 @@ void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
   std::unordered_map<string, const Graph*> device_to_graph;
   for (const auto& unit : item->units) {
     if (unit.build_cost_model > 0) {
-      device_to_graph[unit.device->name()] = unit.graph;
+      device_to_graph[unit.device->name()] = unit.graph.get();
     }
   }
   collector->BuildCostModel(&cost_model_manager_, device_to_graph);

   if (cost_graph != nullptr) {
     for (const auto& unit : item->units) {
-      cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph)
+      cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph)
           .IgnoreError();
     }
   }

@@ -108,7 +108,7 @@ class GraphMgr {
   typedef GraphMgr ME;

   struct ExecutionUnit {
-    Graph* graph = nullptr;                 // not owned.
+    std::unique_ptr<Graph> graph = nullptr;
     Device* device = nullptr;               // not owned.
     Executor* root = nullptr;               // not owned.
     FunctionLibraryRuntime* lib = nullptr;  // not owned.

@@ -475,7 +475,7 @@ Status DatasetOpsTestBase::RunFunction(
   };

   Executor* cur_exec;
-  TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &cur_exec));
+  TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
   exec.reset(cur_exec);
   FunctionCallFrame frame(arg_types, ret_types);
   TF_RETURN_IF_ERROR(frame.SetArgs(args));

@@ -361,12 +361,10 @@ class SingleThreadedExecutorRegistrar {
  private:
   class Factory : public ExecutorFactory {
-    Status NewExecutor(const LocalExecutorParams& params,
-                       std::unique_ptr<const Graph> graph,
+    Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
                        std::unique_ptr<Executor>* out_executor) override {
       Executor* ret;
-      TF_RETURN_IF_ERROR(
-          NewSingleThreadedExecutor(params, std::move(graph), &ret));
+      TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &ret));
       out_executor->reset(ret);
       return Status::OK();
     }
@@ -377,11 +375,9 @@ static SingleThreadedExecutorRegistrar registrar;
 }  // namespace

 Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
-                                 std::unique_ptr<const Graph> graph,
-                                 Executor** executor) {
-  std::unique_ptr<SingleThreadedExecutorImpl> impl =
-      absl::make_unique<SingleThreadedExecutorImpl>(params);
-  TF_RETURN_IF_ERROR(impl->Initialize(*graph));
+                                 const Graph& graph, Executor** executor) {
+  auto impl = absl::make_unique<SingleThreadedExecutorImpl>(params);
+  TF_RETURN_IF_ERROR(impl->Initialize(graph));
   *executor = impl.release();
   return Status::OK();
 }

@@ -53,8 +53,7 @@ namespace data {
 // The single-threaded executor is primarily suitable for executing simple
 // TensorFlow functions, such as one might find in a `tf.data` pipeline.
 Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
-                                 std::unique_ptr<const Graph> graph,
-                                 Executor** executor);
+                                 const Graph& graph, Executor** executor);

 }  // namespace data
 }  // namespace tensorflow

@@ -67,7 +67,7 @@ class ExecutorTest : public ::testing::Test {
       DeleteNonCachedKernel(kernel);
     };
     delete exec_;
-    TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
+    TF_CHECK_OK(NewSingleThreadedExecutor(params, *graph, &exec_));
     runner_ = [](std::function<void()> fn) { fn(); };
     rendez_ = NewLocalRendezvous();
   }