Remove `ExecutorImpl::graph_`.

An executor no longer needs to keep a copy of the graph it is executing after it has been initialized. This change also modifies all Executor factory methods to take a `const Graph&` instead of a `std::unique_ptr<const Graph>`.

PiperOrigin-RevId: 272099412
This commit is contained in:
parent
6f9c242cd0
commit
a0e1a4d1bc
@ -736,8 +736,8 @@ Status DirectSession::RunInternal(
|
||||
std::unordered_map<string, const Graph*> device_to_graph;
|
||||
for (const PerPartitionExecutorsAndLib& partition :
|
||||
executors_and_keys->items) {
|
||||
const Graph* graph = partition.graph;
|
||||
const string device = partition.flib->device()->name();
|
||||
const Graph* graph = partition.graph.get();
|
||||
const string& device = partition.flib->device()->name();
|
||||
device_to_graph[device] = graph;
|
||||
}
|
||||
|
||||
@ -748,7 +748,7 @@ Status DirectSession::RunInternal(
|
||||
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
|
||||
for (const auto& item : executors_and_keys->items) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
|
||||
cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1353,13 +1353,12 @@ Status DirectSession::CreateExecutors(
|
||||
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
|
||||
device->name(),
|
||||
partition_graph.get()));
|
||||
// NewLocalExecutor takes ownership of partition_graph.
|
||||
item->graph = partition_graph.get();
|
||||
item->graph = std::move(partition_graph);
|
||||
item->executor = nullptr;
|
||||
item->device = device;
|
||||
auto executor_type = options_.config.experimental().executor_type();
|
||||
TF_RETURN_IF_ERROR(NewExecutor(
|
||||
executor_type, params, std::move(partition_graph), &item->executor));
|
||||
TF_RETURN_IF_ERROR(
|
||||
NewExecutor(executor_type, params, *item->graph, &item->executor));
|
||||
}
|
||||
|
||||
// Cache the mapping from input/output names to graph elements to
|
||||
|
@ -134,7 +134,7 @@ class DirectSession : public Session {
|
||||
// We create one executor and its dependent library runtime for
|
||||
// every partition.
|
||||
struct PerPartitionExecutorsAndLib {
|
||||
Graph* graph = nullptr; // not owned.
|
||||
std::unique_ptr<Graph> graph = nullptr;
|
||||
Device* device = nullptr; // not owned.
|
||||
FunctionLibraryRuntime* flib = nullptr; // not owned.
|
||||
std::unique_ptr<Executor> executor;
|
||||
|
@ -300,12 +300,14 @@ class GraphView {
|
||||
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
|
||||
}
|
||||
|
||||
int32 num_nodes() const { return num_nodes_; }
|
||||
|
||||
private:
|
||||
char* InitializeNode(char* ptr, const Node* n);
|
||||
size_t NodeItemBytes(const Node* n);
|
||||
|
||||
int32 num_nodes_ = 0;
|
||||
uint32* node_offsets_ = nullptr; // array of size "graph_.num_node_ids()"
|
||||
uint32* node_offsets_ = nullptr; // array of size "num_nodes_"
|
||||
// node_offsets_[id] holds the byte offset for node w/ "id" in space_
|
||||
|
||||
char* space_; // NodeItem objects are allocated here
|
||||
@ -315,14 +317,13 @@ class GraphView {
|
||||
|
||||
class ExecutorImpl : public Executor {
|
||||
public:
|
||||
ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
|
||||
: params_(p), graph_(std::move(g)), gview_() {
|
||||
explicit ExecutorImpl(const LocalExecutorParams& p) : params_(p), gview_() {
|
||||
CHECK(p.create_kernel != nullptr);
|
||||
CHECK(p.delete_kernel != nullptr);
|
||||
}
|
||||
|
||||
~ExecutorImpl() override {
|
||||
for (int i = 0; i < graph_->num_node_ids(); i++) {
|
||||
for (int32 i = 0; i < gview_.num_nodes(); i++) {
|
||||
NodeItem* item = gview_.node(i);
|
||||
if (item != nullptr) {
|
||||
params_.delete_kernel(item->kernel);
|
||||
@ -333,7 +334,7 @@ class ExecutorImpl : public Executor {
|
||||
}
|
||||
}
|
||||
|
||||
Status Initialize();
|
||||
Status Initialize(const Graph& graph);
|
||||
|
||||
// Process all Nodes in the current graph, attempting to infer the
|
||||
// memory allocation attributes to be used wherever they may allocate
|
||||
@ -394,7 +395,6 @@ class ExecutorImpl : public Executor {
|
||||
|
||||
// Owned.
|
||||
LocalExecutorParams params_;
|
||||
std::unique_ptr<const Graph> graph_;
|
||||
GraphView gview_;
|
||||
|
||||
// A cached value of params_
|
||||
@ -623,12 +623,12 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending,
|
||||
*max_dead_count = num_in_edges;
|
||||
}
|
||||
|
||||
Status ExecutorImpl::Initialize() {
|
||||
gview_.Initialize(graph_.get());
|
||||
Status ExecutorImpl::Initialize(const Graph& graph) {
|
||||
gview_.Initialize(&graph);
|
||||
|
||||
// Build the information about frames in this subgraph.
|
||||
ControlFlowInfo cf_info;
|
||||
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));
|
||||
TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));
|
||||
|
||||
// Cache this value so we make this virtual function call once, rather
|
||||
// than O(# steps * # nodes per step) times.
|
||||
@ -641,7 +641,7 @@ Status ExecutorImpl::Initialize() {
|
||||
|
||||
// Preprocess every node in the graph to create an instance of op
|
||||
// kernel for each node.
|
||||
for (const Node* n : graph_->nodes()) {
|
||||
for (const Node* n : graph.nodes()) {
|
||||
const int id = n->id();
|
||||
const string& frame_name = cf_info.frame_names[id];
|
||||
FrameInfo* frame_info = EnsureFrameInfo(frame_name);
|
||||
@ -707,9 +707,9 @@ Status ExecutorImpl::Initialize() {
|
||||
|
||||
// Initialize PendingCounts only after item->pending_id is initialized for
|
||||
// all nodes.
|
||||
InitializePending(graph_.get(), cf_info);
|
||||
InitializePending(&graph, cf_info);
|
||||
|
||||
return gview_.SetAllocAttrs(graph_.get(), params_.device);
|
||||
return gview_.SetAllocAttrs(&graph, params_.device);
|
||||
}
|
||||
|
||||
// If a Node has been marked to use a ScopedAllocator x for output i, then
|
||||
@ -2914,11 +2914,10 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
|
||||
|
||||
} // namespace
|
||||
|
||||
Status NewLocalExecutor(const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
|
||||
Executor** executor) {
|
||||
ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
|
||||
const Status s = impl->Initialize();
|
||||
ExecutorImpl* impl = new ExecutorImpl(params);
|
||||
const Status s = impl->Initialize(graph);
|
||||
if (s.ok()) {
|
||||
*executor = impl;
|
||||
} else {
|
||||
@ -2950,8 +2949,7 @@ class DefaultExecutorRegistrar {
|
||||
|
||||
private:
|
||||
class Factory : public ExecutorFactory {
|
||||
Status NewExecutor(const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
|
||||
std::unique_ptr<Executor>* out_executor) override {
|
||||
Executor* ret = nullptr;
|
||||
TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
|
||||
|
@ -147,8 +147,7 @@ struct LocalExecutorParams {
|
||||
Executor::RendezvousFactory rendezvous_factory;
|
||||
};
|
||||
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
Executor** executor);
|
||||
const Graph& graph, Executor** executor);
|
||||
|
||||
// A class to help run multiple executors in parallel and wait until
|
||||
// all of them are complete.
|
||||
|
@ -74,8 +74,7 @@ Status ExecutorFactory::GetFactory(const string& executor_type,
|
||||
}
|
||||
|
||||
Status NewExecutor(const string& executor_type,
|
||||
const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
const LocalExecutorParams& params, const Graph& graph,
|
||||
std::unique_ptr<Executor>* out_executor) {
|
||||
ExecutorFactory* factory = nullptr;
|
||||
TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));
|
||||
|
@ -32,7 +32,7 @@ struct LocalExecutorParams;
|
||||
class ExecutorFactory {
|
||||
public:
|
||||
virtual Status NewExecutor(const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
const Graph& graph,
|
||||
std::unique_ptr<Executor>* out_executor) = 0;
|
||||
virtual ~ExecutorFactory() {}
|
||||
|
||||
@ -42,8 +42,7 @@ class ExecutorFactory {
|
||||
};
|
||||
|
||||
Status NewExecutor(const string& executor_type,
|
||||
const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
const LocalExecutorParams& params, const Graph& graph,
|
||||
std::unique_ptr<Executor>* out_executor);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -77,7 +77,7 @@ class ExecutorTest : public ::testing::Test {
|
||||
return Status::OK();
|
||||
};
|
||||
delete exec_;
|
||||
TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &exec_));
|
||||
TF_CHECK_OK(NewLocalExecutor(params, *graph, &exec_));
|
||||
runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
|
||||
}
|
||||
|
||||
|
@ -395,7 +395,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
// object, and an executor is created for the graph.
|
||||
struct Item {
|
||||
uint64 instantiation_counter = 0;
|
||||
const Graph* graph = nullptr; // Owned by exec.
|
||||
std::unique_ptr<const Graph> graph = nullptr;
|
||||
const FunctionLibraryDefinition* lib_def = nullptr; // Not owned.
|
||||
FunctionBody* func_graph = nullptr;
|
||||
Executor* exec = nullptr;
|
||||
@ -952,14 +952,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) {
|
||||
};
|
||||
params.rendezvous_factory = (*item)->rendezvous_factory;
|
||||
params.session_metadata = session_metadata_;
|
||||
Graph* graph = g.get();
|
||||
std::unique_ptr<Executor> exec;
|
||||
TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
|
||||
TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, *g, &exec));
|
||||
{
|
||||
// Guard item since it is already inserted in items_.
|
||||
mutex_lock l(mu_);
|
||||
if ((*item)->exec == nullptr) {
|
||||
(*item)->graph = graph;
|
||||
(*item)->graph = std::move(g);
|
||||
(*item)->exec = exec.release();
|
||||
}
|
||||
}
|
||||
@ -1230,7 +1229,7 @@ string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
|
||||
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
|
||||
Status s = GetOrCreateItem(local_handle, &item);
|
||||
if (s.ok()) {
|
||||
return tensorflow::DebugString(item->graph);
|
||||
return tensorflow::DebugString(item->graph.get());
|
||||
} else {
|
||||
return s.ToString();
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ class FunctionTest : public ::testing::Test {
|
||||
return Status::OK();
|
||||
};
|
||||
Executor* exec;
|
||||
TF_CHECK_OK(NewLocalExecutor(params, std::move(g), &exec));
|
||||
TF_CHECK_OK(NewLocalExecutor(params, *g, &exec));
|
||||
exec_.reset(exec);
|
||||
}
|
||||
|
||||
@ -603,8 +603,7 @@ class DummyExecutorRegistrar {
|
||||
|
||||
private:
|
||||
class Factory : public ExecutorFactory {
|
||||
Status NewExecutor(const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
|
||||
std::unique_ptr<Executor>* out_executor) override {
|
||||
return errors::Internal("This is a dummy.");
|
||||
}
|
||||
|
@ -171,8 +171,7 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
};
|
||||
|
||||
Executor* executor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NewLocalExecutor(params, std::move(graph_to_run), &executor));
|
||||
TF_RETURN_IF_ERROR(NewLocalExecutor(params, *graph_to_run, &executor));
|
||||
std::unique_ptr<Executor> executor_unref(executor);
|
||||
|
||||
Executor::Args args;
|
||||
|
@ -88,16 +88,14 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
||||
|
||||
if (init) {
|
||||
std::unique_ptr<Executor> init_exec;
|
||||
TF_CHECK_OK(NewExecutor(executor_type, params, std::unique_ptr<Graph>(init),
|
||||
&init_exec));
|
||||
TF_CHECK_OK(NewExecutor(executor_type, params, *init, &init_exec));
|
||||
Executor::Args args;
|
||||
args.rendezvous = rendez_;
|
||||
args.runner = runner;
|
||||
TF_CHECK_OK(init_exec->Run(args));
|
||||
}
|
||||
|
||||
TF_CHECK_OK(
|
||||
NewExecutor(executor_type, params, std::unique_ptr<Graph>(g), &exec_));
|
||||
TF_CHECK_OK(NewExecutor(executor_type, params, *g, &exec_));
|
||||
}
|
||||
|
||||
Benchmark::~Benchmark() {
|
||||
|
@ -74,7 +74,7 @@ GraphMgr::Item::~Item() {
|
||||
for (const auto& unit : this->units) {
|
||||
CHECK_NOTNULL(unit.device);
|
||||
if (!graph_mgr->skip_cost_models_) {
|
||||
graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph);
|
||||
graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph.get());
|
||||
}
|
||||
delete unit.root;
|
||||
unit.device->op_segment()->RemoveHold(this->session);
|
||||
@ -277,13 +277,12 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
|
||||
TF_RETURN_IF_ERROR(
|
||||
EnsureMemoryTypes(DeviceType(unit->device->device_type()),
|
||||
unit->device->name(), subgraph.get()));
|
||||
unit->graph = subgraph.get();
|
||||
unit->graph = std::move(subgraph);
|
||||
unit->build_cost_model = graph_options.build_cost_model();
|
||||
if (unit->build_cost_model > 0) {
|
||||
skip_cost_models_ = false;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
NewLocalExecutor(params, std::move(subgraph), &unit->root));
|
||||
TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -552,14 +551,14 @@ void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
|
||||
std::unordered_map<string, const Graph*> device_to_graph;
|
||||
for (const auto& unit : item->units) {
|
||||
if (unit.build_cost_model > 0) {
|
||||
device_to_graph[unit.device->name()] = unit.graph;
|
||||
device_to_graph[unit.device->name()] = unit.graph.get();
|
||||
}
|
||||
}
|
||||
collector->BuildCostModel(&cost_model_manager_, device_to_graph);
|
||||
|
||||
if (cost_graph != nullptr) {
|
||||
for (const auto& unit : item->units) {
|
||||
cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph)
|
||||
cost_model_manager_.AddToCostGraphDef(unit.graph.get(), cost_graph)
|
||||
.IgnoreError();
|
||||
}
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ class GraphMgr {
|
||||
typedef GraphMgr ME;
|
||||
|
||||
struct ExecutionUnit {
|
||||
Graph* graph = nullptr; // not owned.
|
||||
std::unique_ptr<Graph> graph = nullptr;
|
||||
Device* device = nullptr; // not owned.
|
||||
Executor* root = nullptr; // not owned.
|
||||
FunctionLibraryRuntime* lib = nullptr; // not owned.
|
||||
|
@ -475,7 +475,7 @@ Status DatasetOpsTestBase::RunFunction(
|
||||
};
|
||||
|
||||
Executor* cur_exec;
|
||||
TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &cur_exec));
|
||||
TF_RETURN_IF_ERROR(NewLocalExecutor(params, *g, &cur_exec));
|
||||
exec.reset(cur_exec);
|
||||
FunctionCallFrame frame(arg_types, ret_types);
|
||||
TF_RETURN_IF_ERROR(frame.SetArgs(args));
|
||||
|
@ -361,12 +361,10 @@ class SingleThreadedExecutorRegistrar {
|
||||
|
||||
private:
|
||||
class Factory : public ExecutorFactory {
|
||||
Status NewExecutor(const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
|
||||
std::unique_ptr<Executor>* out_executor) override {
|
||||
Executor* ret;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NewSingleThreadedExecutor(params, std::move(graph), &ret));
|
||||
TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &ret));
|
||||
out_executor->reset(ret);
|
||||
return Status::OK();
|
||||
}
|
||||
@ -377,11 +375,9 @@ static SingleThreadedExecutorRegistrar registrar;
|
||||
} // namespace
|
||||
|
||||
Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
Executor** executor) {
|
||||
std::unique_ptr<SingleThreadedExecutorImpl> impl =
|
||||
absl::make_unique<SingleThreadedExecutorImpl>(params);
|
||||
TF_RETURN_IF_ERROR(impl->Initialize(*graph));
|
||||
const Graph& graph, Executor** executor) {
|
||||
auto impl = absl::make_unique<SingleThreadedExecutorImpl>(params);
|
||||
TF_RETURN_IF_ERROR(impl->Initialize(graph));
|
||||
*executor = impl.release();
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -53,8 +53,7 @@ namespace data {
|
||||
// The single-threaded executor is primarily suitable for executing simple
|
||||
// TensorFlow functions, such as one might find in a `tf.data` pipeline.
|
||||
Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
|
||||
std::unique_ptr<const Graph> graph,
|
||||
Executor** executor);
|
||||
const Graph& graph, Executor** executor);
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -67,7 +67,7 @@ class ExecutorTest : public ::testing::Test {
|
||||
DeleteNonCachedKernel(kernel);
|
||||
};
|
||||
delete exec_;
|
||||
TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
|
||||
TF_CHECK_OK(NewSingleThreadedExecutor(params, *graph, &exec_));
|
||||
runner_ = [](std::function<void()> fn) { fn(); };
|
||||
rendez_ = NewLocalRendezvous();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user