Remove ExecutorImpl::graph_.

An executor no longer needs to hold on to the graph it is executing once it has been initialized. This change also modifies all Executor factory methods to take a `const Graph&` instead of a `std::unique_ptr<const Graph>`.
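
For illustration, a minimal caller-side sketch of the new contract. This snippet is not part of the commit; the helper name is hypothetical and the setup is assumed, but the `NewLocalExecutor` signature matches the updated declaration in executor.h below. The caller keeps ownership of the graph and passes a const reference; the executor only reads the graph while it is being initialized.

    #include <memory>

    #include "tensorflow/core/common_runtime/executor.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/lib/core/errors.h"

    // Hypothetical helper: builds an executor for a graph that the caller
    // continues to own (e.g. for cost models or DebugString, as the updated
    // call sites below do).
    tensorflow::Status BuildExecutor(
        const tensorflow::LocalExecutorParams& params,
        const std::unique_ptr<const tensorflow::Graph>& graph,
        std::unique_ptr<tensorflow::Executor>* out) {
      tensorflow::Executor* raw = nullptr;
      // Before this change, the factory took the std::unique_ptr by value and
      // the executor kept the graph alive; now it takes `const Graph&`.
      TF_RETURN_IF_ERROR(tensorflow::NewLocalExecutor(params, *graph, &raw));
      out->reset(raw);
      // `graph` is still owned by the caller and remains valid here.
      return tensorflow::Status::OK();
    }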

PiperOrigin-RevId: 272099412
Derek Murray, 2019-09-30 17:08:00 -07:00, committed by TensorFlower Gardener
parent 6f9c242cd0
commit a0e1a4d1bc
17 changed files with 51 additions and 68 deletions

@ -736,8 +736,8 @@ Status DirectSession::RunInternal(
std::unordered_map<string, const Graph*> device_to_graph;
for (const PerPartitionExecutorsAndLib& partition :
executors_and_keys->items) {
const Graph* graph = partition.graph;
const string device = partition.flib->device()->name();
const Graph* graph = partition.graph.get();
const string& device = partition.flib->device()->name();
device_to_graph[device] = graph;
}
@ -748,7 +748,7 @@ Status DirectSession::RunInternal(
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
for (const auto& item : executors_and_keys->items) {
TF_RETURN_IF_ERROR(
cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
}
}
@ -1353,13 +1353,12 @@ Status DirectSession::CreateExecutors(
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
device->name(),
partition_graph.get()));
// NewLocalExecutor takes ownership of partition_graph.
item->graph = partition_graph.get();
item->graph = std::move(partition_graph);
item->executor = nullptr;
item->device = device;
auto executor_type = options_.config.experimental().executor_type();
TF_RETURN_IF_ERROR(NewExecutor(
executor_type, params, std::move(partition_graph), &item->executor));
TF_RETURN_IF_ERROR(
NewExecutor(executor_type, params, *item->graph, &item->executor));
}
// Cache the mapping from input/output names to graph elements to
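
Note the ordering in the CreateExecutors hunk above: partition_graph is moved into item->graph before the factory call, so the call must dereference the stored pointer rather than the now-empty local. Annotated excerpt of the two new lines from that hunk (the comments are added here, not part of the commit):

    item->graph = std::move(partition_graph);  // item now owns the graph
    // partition_graph is null after the move, so pass the owned graph instead:
    TF_RETURN_IF_ERROR(
        NewExecutor(executor_type, params, *item->graph, &item->executor));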

@ -134,7 +134,7 @@ class DirectSession : public Session {
// We create one executor and its dependent library runtime for
// every partition.
struct PerPartitionExecutorsAndLib {
Graph* graph = nullptr; // not owned.
std::unique_ptr<Graph> graph = nullptr;
Device* device = nullptr; // not owned.
FunctionLibraryRuntime* flib = nullptr; // not owned.
std::unique_ptr<Executor> executor;
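
The same inversion of ownership recurs in FunctionLibraryRuntimeImpl::Item and GraphMgr::ExecutionUnit further down: the per-partition container, which previously borrowed the graph ("not owned" / "Owned by exec."), now owns it outright, while the executor merely reads it during initialization. A simplified sketch of the resulting pattern (illustrative only, not an actual struct from the commit):

    #include <memory>

    #include "tensorflow/core/common_runtime/executor.h"
    #include "tensorflow/core/graph/graph.h"

    // Illustrative only: the owning container holds both pieces.
    struct PartitionItem {
      std::unique_ptr<const tensorflow::Graph> graph;   // owned here now
      std::unique_ptr<tensorflow::Executor> executor;   // reads *graph at init time
    };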

@ -300,12 +300,14 @@ class GraphView {
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
}
int32 num_nodes() const { return num_nodes_; }
private:
char* InitializeNode(char* ptr, const Node* n);
size_t NodeItemBytes(const Node* n);
int32 num_nodes_ = 0;
uint32* node_offsets_ = nullptr; // array of size "graph_.num_node_ids()"
uint32* node_offsets_ = nullptr; // array of size "num_nodes_"
// node_offsets_[id] holds the byte offset for node w/ "id" in space_
char* space_; // NodeItem objects are allocated here
@ -315,14 +317,13 @@ class GraphView {
class ExecutorImpl : public Executor {
public:
ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
: params_(p), graph_(std::move(g)), gview_() {
explicit ExecutorImpl(const LocalExecutorParams& p) : params_(p), gview_() {
CHECK(p.create_kernel != nullptr);
CHECK(p.delete_kernel != nullptr);
}
~ExecutorImpl() override {
for (int i = 0; i < graph_->num_node_ids(); i++) {
for (int32 i = 0; i < gview_.num_nodes(); i++) {
NodeItem* item = gview_.node(i);
if (item != nullptr) {
params_.delete_kernel(item->kernel);
@ -333,7 +334,7 @@ class ExecutorImpl : public Executor {
}
}
Status Initialize();
Status Initialize(const Graph& graph);
// Process all Nodes in the current graph, attempting to infer the
// memory allocation attributes to be used wherever they may allocate
@ -394,7 +395,6 @@ class ExecutorImpl : public Executor {
// Owned.
LocalExecutorParams params_;
std::unique_ptr<const Graph> graph_;
GraphView gview_;
// A cached value of params_
@ -623,12 +623,12 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending,
*max_dead_count = num_in_edges;
}
Status ExecutorImpl::Initialize() {
gview_.Initialize(graph_.get());
Status ExecutorImpl::Initialize(const Graph& graph) {
gview_.Initialize(&graph);
// Build the information about frames in this subgraph.
ControlFlowInfo cf_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));
TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));
// Cache this value so we make this virtual function call once, rather
// than O(# steps * # nodes per step) times.
@ -641,7 +641,7 @@ Status ExecutorImpl::Initialize() {
// Preprocess every node in the graph to create an instance of op
// kernel for each node.
for (const Node* n : graph_->nodes()) {
for (const Node* n : graph.nodes()) {
const int id = n->id();
const string& frame_name = cf_info.frame_names[id];
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
@ -707,9 +707,9 @@ Status ExecutorImpl::Initialize() {
// Initialize PendingCounts only after item->pending_id is initialized for
// all nodes.
InitializePending(graph_.get(), cf_info);
InitializePending(&graph, cf_info);
return gview_.SetAllocAttrs(graph_.get(), params_.device);
return gview_.SetAllocAttrs(&graph, params_.device);
}
// If a Node has been marked to use a ScopedAllocator x for output i, then
@ -2914,11 +2914,10 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
} // namespace
Status NewLocalExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
Executor** executor) {
ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
const Status s = impl->Initialize();
ExecutorImpl* impl = new ExecutorImpl(params);
const Status s = impl->Initialize(graph);
if (s.ok()) {
*executor = impl;
} else {
@ -2950,8 +2949,7 @@ class DefaultExecutorRegistrar {
private:
class Factory : public ExecutorFactory {
Status NewExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) override {
Executor* ret = nullptr;
TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));

@ -147,8 +147,7 @@ struct LocalExecutorParams {
Executor::RendezvousFactory rendezvous_factory;
};
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Executor** executor);
const Graph& graph, Executor** executor);
// A class to help run multiple executors in parallel and wait until
// all of them are complete.

@ -74,8 +74,7 @@ Status ExecutorFactory::GetFactory(const string& executor_type,
}
Status NewExecutor(const string& executor_type,
const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) {
ExecutorFactory* factory = nullptr;
TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));

@ -32,7 +32,7 @@ struct LocalExecutorParams;
class ExecutorFactory {
public:
virtual Status NewExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
const Graph& graph,
std::unique_ptr<Executor>* out_executor) = 0;
virtual ~ExecutorFactory() {}
@ -42,8 +42,7 @@ class ExecutorFactory {
};
Status NewExecutor(const string& executor_type,
const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor);
} // namespace tensorflow
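
Third-party executor factories must adopt the same signature. A hedged sketch of an override after this change (the class name and the delegation to NewLocalExecutor are illustrative; the shape follows the DefaultExecutorRegistrar and SingleThreadedExecutorRegistrar factories touched in this commit, and registration itself is unchanged):

    #include <memory>

    #include "tensorflow/core/common_runtime/executor.h"
    #include "tensorflow/core/common_runtime/executor_factory.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/lib/core/errors.h"

    // Hypothetical factory that delegates to the default local executor.
    class MyExecutorFactory : public tensorflow::ExecutorFactory {
     public:
      tensorflow::Status NewExecutor(
          const tensorflow::LocalExecutorParams& params,
          const tensorflow::Graph& graph,  // was std::unique_ptr<const Graph>
          std::unique_ptr<tensorflow::Executor>* out_executor) override {
        tensorflow::Executor* ret = nullptr;
        TF_RETURN_IF_ERROR(tensorflow::NewLocalExecutor(params, graph, &ret));
        out_executor->reset(ret);
        return tensorflow::Status::OK();
      }
    };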

@ -77,7 +77,7 @@ class ExecutorTest : public ::testing::Test {
return Status::OK();
};
delete exec_;
TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &exec_));
TF_CHECK_OK(NewLocalExecutor(params, *graph, &exec_));
runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
}

@ -395,7 +395,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
// object, and an executor is created for the graph.
struct Item {
uint64 instantiation_counter = 0;
const Graph* graph = nullptr; // Owned by exec.
std::unique_ptr<const Graph> graph = nullptr;
const FunctionLibraryDefinition* lib_def = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
@ -952,14 +952,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
};
params.rendezvous_factory = (*item)->rendezvous_factory;
params.session_metadata = session_metadata_;
Graph* graph = g.get();
std::unique_ptr<Executor> exec;
TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
{
// Guard item since it is already inserted in items_.
mutex_lock l(mu_);
if ((*item)->exec == nullptr) {
(*item)->graph = graph;
(*item)->graph = std::move(g);
(*item)->exec = exec.release();
}
}
@ -1230,7 +1229,7 @@ string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
Status s = GetOrCreateItem(local_handle, &item);
if (s.ok()) {
return tensorflow::DebugString(item->graph);
return tensorflow::DebugString(item->graph.get());
} else {
return s.ToString();
}

@ -104,7 +104,7 @@ class FunctionTest : public ::testing::Test {
return Status::OK();
};
Executor* exec;
TF_CHECK_OK(NewLocalExecutor(params, std::move(g), &exec));
TF_CHECK_OK(NewLocalExecutor(params, *g, &exec));
exec_.reset(exec);
}
@ -603,8 +603,7 @@ class DummyExecutorRegistrar {
private:
class Factory : public ExecutorFactory {
Status NewExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) override {
return errors::Internal("This is a dummy.");
}

@ -171,8 +171,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
};
Executor* executor;
TF_RETURN_IF_ERROR(
NewLocalExecutor(params, std::move(graph_to_run), &executor));
TF_RETURN_IF_ERROR(NewLocalExecutor(params, *graph_to_run, &executor));
std::unique_ptr<Executor> executor_unref(executor);
Executor::Args args;

@ -88,16 +88,14 @@ Benchmark::Benchmark(const string& device, Graph* g,
if (init) {
std::unique_ptr<Executor> init_exec;
TF_CHECK_OK(NewExecutor(executor_type, params, std::unique_ptr<Graph>(init),
&init_exec));
TF_CHECK_OK(NewExecutor(executor_type, params, *init, &init_exec));
Executor::Args args;
args.rendezvous = rendez_;
args.runner = runner;
TF_CHECK_OK(init_exec->Run(args));
}
TF_CHECK_OK(
NewExecutor(executor_type, params, std::unique_ptr<Graph>(g), &exec_));
TF_CHECK_OK(NewExecutor(executor_type, params, *g, &exec_));
}
Benchmark::~Benchmark() {

@ -74,7 +74,7 @@ GraphMgr::Item::~Item() {
for (const auto& unit : this->units) {
CHECK_NOTNULL(unit.device);
if (!graph_mgr->skip_cost_models_) {
graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph);
graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get());
}
delete unit.root;
unit.device->op_segment()->RemoveHold(this->session);
@ -277,13 +277,12 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
TF_RETURN_IF_ERROR(
EnsureMemoryTypes(DeviceType(unit->device->device_type()),
unit->device->name(), subgraph.get()));
unit->graph = subgraph.get();
unit->graph = std::move(subgraph);
unit->build_cost_model = graph_options.build_cost_model();
if (unit->build_cost_model > 0) {
skip_cost_models_ = false;
}
TF_RETURN_IF_ERROR(
NewLocalExecutor(params, std::move(subgraph), &unit->root));
TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
}
return Status::OK();
}
@ -552,14 +551,14 @@ void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
std::unordered_map<string, const Graph*> device_to_graph;
for (const auto& unit : item->units) {
if (unit.build_cost_model > 0) {
device_to_graph[unit.device->name()] = unit.graph;
device_to_graph[unit.device->name()] = unit.graph.get();
}
}
collector->BuildCostModel(&cost_model_manager_, device_to_graph);
if (cost_graph != nullptr) {
for (const auto& unit : item->units) {
cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph)
cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph)
.IgnoreError();
}
}

@ -108,7 +108,7 @@ class GraphMgr {
typedef GraphMgr ME;
struct ExecutionUnit {
Graph* graph = nullptr; // not owned.
std::unique_ptr<Graph> graph = nullptr;
Device* device = nullptr; // not owned.
Executor* root = nullptr; // not owned.
FunctionLibraryRuntime* lib = nullptr; // not owned.

@ -475,7 +475,7 @@ Status DatasetOpsTestBase::RunFunction(
};
Executor* cur_exec;
TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &cur_exec));
TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
exec.reset(cur_exec);
FunctionCallFrame frame(arg_types, ret_types);
TF_RETURN_IF_ERROR(frame.SetArgs(args));

@ -361,12 +361,10 @@ class SingleThreadedExecutorRegistrar {
private:
class Factory : public ExecutorFactory {
Status NewExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) override {
Executor* ret;
TF_RETURN_IF_ERROR(
NewSingleThreadedExecutor(params, std::move(graph), &ret));
TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &ret));
out_executor->reset(ret);
return Status::OK();
}
@ -377,11 +375,9 @@ static SingleThreadedExecutorRegistrar registrar;
} // namespace
Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Executor** executor) {
std::unique_ptr<SingleThreadedExecutorImpl> impl =
absl::make_unique<SingleThreadedExecutorImpl>(params);
TF_RETURN_IF_ERROR(impl->Initialize(*graph));
const Graph& graph, Executor** executor) {
auto impl = absl::make_unique<SingleThreadedExecutorImpl>(params);
TF_RETURN_IF_ERROR(impl->Initialize(graph));
*executor = impl.release();
return Status::OK();
}

@ -53,8 +53,7 @@ namespace data {
// The single-threaded executor is primarily suitable for executing simple
// TensorFlow functions, such as one might find in a `tf.data` pipeline.
Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
std::unique_ptr<const Graph> graph,
Executor** executor);
const Graph& graph, Executor** executor);
} // namespace data
} // namespace tensorflow

@ -67,7 +67,7 @@ class ExecutorTest : public ::testing::Test {
DeleteNonCachedKernel(kernel);
};
delete exec_;
TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
TF_CHECK_OK(NewSingleThreadedExecutor(params, *graph, &exec_));
runner_ = [](std::function<void()> fn) { fn(); };
rendez_ = NewLocalRendezvous();
}