Add Reset implementation for DirectSession.
- Reset clears and closes the specified containers for ALL DirectSession objects.
- Add closed bit to DirectSession to ensure that operations that occur after Close is called fail.

Change: 131889161
This commit is contained in:
parent e11b99749d
commit 62c159ffe8
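Before the diff, a minimal sketch of how the new Reset entry point is intended to be used from client code, mirroring the TestDirectSessionReset test added below. It assumes the public tensorflow::Reset() helper declared in tensorflow/core/public/session.h; the function name ResetAllDirectSessions and the omitted error handling are illustrative only, not part of this commit.

#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session.h"

// Sketch: Reset clears the given resource containers and closes every live
// DirectSession, so a subsequent Run() on an existing session is rejected.
void ResetAllDirectSessions(tensorflow::Session* session,
                            const tensorflow::SessionOptions& options) {
  // An empty container list clears the default container.
  TF_CHECK_OK(tensorflow::Reset(options, /*containers=*/{}));

  // The session has been closed by Reset; this call is expected to return
  // "Cancelled: Session has been closed." rather than run the graph.
  tensorflow::Status s = session->Run({}, {}, {}, nullptr);
  LOG(INFO) << s;
}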
@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
@@ -113,6 +112,77 @@ string GetRendezvousKey(const string& tensor_name,

}  // namespace

class DirectSessionFactory : public SessionFactory {
 public:
  DirectSessionFactory() {}

  bool AcceptsOptions(const SessionOptions& options) override {
    return options.target.empty();
  }

  Session* NewSession(const SessionOptions& options) override {
    // Must do this before the CPU allocator is created.
    if (options.config.graph_options().build_cost_model() > 0) {
      EnableCPUAllocatorFullStats(true);
    }
    std::vector<Device*> devices;
    Status s = DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!s.ok()) {
      LOG(ERROR) << s;
      return nullptr;
    }

    DirectSession* session =
        new DirectSession(options, new DeviceMgr(devices), this);
    {
      mutex_lock l(sessions_lock_);
      sessions_.push_back(session);
    }
    return session;
  }

  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    std::vector<DirectSession*> sessions_to_reset;
    {
      mutex_lock l(sessions_lock_);
      // We create a copy to ensure that we don't have a deadlock when
      // session->Close calls the DirectSessionFactory.Deregister, which
      // acquires sessions_lock_.
      std::swap(sessions_to_reset, sessions_);
    }
    Status s;
    for (auto session : sessions_to_reset) {
      s.Update(session->Reset(containers));
    }
    // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
    // it doesn't close the sessions?
    for (auto session : sessions_to_reset) {
      s.Update(session->Close());
    }
    return s;
  }

  void Deregister(const DirectSession* session) {
    mutex_lock l(sessions_lock_);
    sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
                    sessions_.end());
  }

 private:
  mutex sessions_lock_;
  std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_);
};

class DirectSessionRegistrar {
 public:
  DirectSessionRegistrar() {
    SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
  }
};
static DirectSessionRegistrar registrar;

std::atomic_int_fast64_t DirectSession::step_id_counter_(1);

// NOTE: On Android with a single device, there is never
@@ -146,10 +216,13 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
}

DirectSession::DirectSession(const SessionOptions& options,
                             const DeviceMgr* device_mgr)
                             const DeviceMgr* device_mgr,
                             DirectSessionFactory* const factory)
    : options_(options),
      device_mgr_(device_mgr),
      factory_(factory),
      cancellation_manager_(new CancellationManager()),
      closed_(false),
      operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
  if (options_.config.session_inter_op_thread_pool_size() > 0) {
    for (int i = 0; i < options_.config.session_inter_op_thread_pool_size();
@@ -194,6 +267,7 @@ DirectSession::DirectSession(const SessionOptions& options,
}

DirectSession::~DirectSession() {
  if (!closed_) Close();
  for (auto& it : partial_runs_) {
    it.second.reset(nullptr);
  }
@@ -237,6 +311,7 @@ Status DirectSession::Create(const GraphDef& graph) {
}

Status DirectSession::Extend(const GraphDef& graph) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(graph_def_lock_);
  return ExtendLocked(graph);
}
@@ -267,6 +342,7 @@ Status DirectSession::Run(const RunOptions& run_options,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  direct_session_runs->GetCell()->IncrementBy(1);
  {
    mutex_lock l(graph_def_lock_);
@@ -412,6 +488,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
                                const std::vector<string>& output_names,
                                const std::vector<string>& target_nodes,
                                string* handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  {
    mutex_lock l(graph_def_lock_);
    if (!graph_created_) {
@@ -487,6 +564,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  std::vector<string> parts = str_util::Split(handle, ';');
  const string& key = parts[0];
  // Get the executors for this partial run.
@@ -1002,8 +1080,20 @@ Status DirectSession::CreateGraphs(
  return s;
}

::tensorflow::Status DirectSession::Reset(
    const std::vector<string>& containers) {
  device_mgr_->ClearContainers(containers);
  return ::tensorflow::Status::OK();
}

::tensorflow::Status DirectSession::Close() {
  cancellation_manager_->StartCancel();
  {
    mutex_lock l(mu_);
    if (closed_) return ::tensorflow::Status::OK();
    closed_ = true;
  }
  if (factory_ != nullptr) factory_->Deregister(this);
  return ::tensorflow::Status::OK();
}

@@ -1051,37 +1141,4 @@ void DirectSession::WaitForNotification(RunState* run_state,
  }
}

class DirectSessionFactory : public SessionFactory {
 public:
  DirectSessionFactory() {}

  bool AcceptsOptions(const SessionOptions& options) override {
    return options.target.empty();
  }

  Session* NewSession(const SessionOptions& options) override {
    // Must do this before the CPU allocator is created.
    if (options.config.graph_options().build_cost_model() > 0) {
      EnableCPUAllocatorFullStats(true);
    }
    std::vector<Device*> devices;
    Status s = DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!s.ok()) {
      LOG(ERROR) << s;
      return nullptr;
    }

    return new DirectSession(options, new DeviceMgr(devices));
  }
};

class DirectSessionRegistrar {
 public:
  DirectSessionRegistrar() {
    SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
  }
};
static DirectSessionRegistrar registrar;

}  // namespace tensorflow
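One detail worth calling out in the factory's Reset() above: it swaps the session list out from under sessions_lock_ before calling Reset()/Close() on each session, because Close() calls back into Deregister(), which takes the same lock. Below is a condensed, self-contained illustration of that copy-then-release pattern; Registry and Widget are hypothetical names for the sketch, not TensorFlow types.

#include <algorithm>
#include <mutex>
#include <vector>

class Widget;

// Hypothetical registry showing the swap-outside-the-lock pattern: hold the
// lock only long enough to steal the list, then invoke callbacks that may
// re-enter Deregister() without deadlocking on a non-recursive mutex.
class Registry {
 public:
  void Register(Widget* w) {
    std::lock_guard<std::mutex> l(mu_);
    widgets_.push_back(w);
  }

  void Deregister(Widget* w) {
    std::lock_guard<std::mutex> l(mu_);
    widgets_.erase(std::remove(widgets_.begin(), widgets_.end(), w),
                   widgets_.end());
  }

  void CloseAll();

 private:
  std::mutex mu_;
  std::vector<Widget*> widgets_;
};

class Widget {
 public:
  explicit Widget(Registry* registry) : registry_(registry) {}
  void Close() { registry_->Deregister(this); }  // re-enters the registry

 private:
  Registry* registry_;
};

void Registry::CloseAll() {
  std::vector<Widget*> to_close;
  {
    std::lock_guard<std::mutex> l(mu_);
    std::swap(to_close, widgets_);  // release mu_ before calling Close()
  }
  // Deregister() can now take mu_ freely from inside each Close() call.
  for (Widget* w : to_close) w->Close();
}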
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/framework/cancellation.h"
@@ -47,11 +48,18 @@ namespace tensorflow {
class CostModel;
class DebugGateway;
class Device;
class DirectSessionFactory;

class DirectSession : public Session {
 public:
  typedef std::function<void(Session*)> CloseCallback;

  // Takes ownership of 'device_mgr'.
  DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr);
  // 'factory' is used to unregister the DirectSession with 'factory' when it's
  // closed. This ensures that Reset requests from the 'factory' don't get sent
  // to sessions that are already closed.
  DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
                DirectSessionFactory* factory);
  ~DirectSession() override;

  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
@@ -83,6 +91,10 @@ class DirectSession : public Session {
                         const std::vector<string>& output_names,
                         std::vector<Tensor>* outputs) override;

  // Reset clears 'containers' from the device_mgr of the DirectSession.
  // If 'containers' is empty, then Reset clears the default container.
  ::tensorflow::Status Reset(const std::vector<string>& containers);

  ::tensorflow::Status Close() override;

  void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
@@ -198,6 +210,12 @@ class DirectSession : public Session {
  // operation_timeout_in_ms is greater than 0.
  void WaitForNotification(RunState* run_state, int64 timeout_in_ms);

  ::tensorflow::Status CheckNotClosed() {
    mutex_lock l(mu_);
    if (closed_) return errors::Cancelled("Session has been closed.");
    return ::tensorflow::Status::OK();
  }

  const SessionOptions options_;

  // Device structures.
@@ -232,10 +250,12 @@ class DirectSession : public Session {
  // This holds all the tensors that are currently alive in the session.
  SessionState session_state_;

  DirectSessionFactory* const factory_;  // not owned
  CancellationManager* cancellation_manager_;

  // Saves and restores device placements for stateful nodes.
  mutex mu_;

  // Map of placed stateful nodes, i.e. nodes for which is_stateful()
  // is true, such as "params" and "queue" nodes. Once placed these
  // nodes can not be moved to a different device. Maps node names to
@@ -251,6 +271,9 @@ class DirectSession : public Session {
  // library; it copies and modifies the function library.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;

  // true if the Session has been Closed.
  bool closed_ GUARDED_BY(mu_);

  // For generating unique names.
  int64 name_counter_ GUARDED_BY(mu_) = 0;

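The header changes above introduce the closed_ bit plus a CheckNotClosed() helper that Run, Extend, PRunSetup, and PRun call before doing any work, which is what the tests below rely on. As a rough standalone sketch of that guard pattern (a hypothetical Connection class; TensorFlow returns a Status where this sketch uses a bool and an error message):

#include <mutex>
#include <string>

// Minimal sketch of the closed-bit guard: every public operation checks the
// bit under the same mutex that Close() sets it with, so calls made after
// Close() fail deterministically instead of touching torn-down state.
class Connection {
 public:
  void Close() {
    std::lock_guard<std::mutex> l(mu_);
    closed_ = true;
  }

  bool Run(std::string* error) {
    if (!CheckNotClosed(error)) return false;  // mirrors TF_RETURN_IF_ERROR
    // ... real work would go here ...
    return true;
  }

 private:
  bool CheckNotClosed(std::string* error) {
    std::lock_guard<std::mutex> l(mu_);
    if (closed_) {
      *error = "Cancelled: Connection has been closed.";
      return false;
    }
    return true;
  }

  std::mutex mu_;
  bool closed_ = false;
};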
@@ -970,5 +970,129 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
  }
}

TEST(DirectSessionTest, TestDirectSessionRunClose) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2};
  Node* var_val = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, var_val);
  GraphDef def;
  test::graph::ToGraphDef(&g, &def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Assign a value to the var.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Run a read on the variable to ensure that it works.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Close the session.
  session->Close();

  // Run the read on the variable to get an error.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}

TEST(DirectSessionTest, TestDirectSessionPRunClose) {
  GraphDef def;
  Graph g(OpRegistry::Global());

  Tensor first_value(DT_FLOAT, TensorShape({}));
  first_value.scalar<float>()() = 1.0;
  Node* first_const = test::graph::Constant(&g, first_value);
  Node* first_identity = test::graph::Identity(&g, first_const);

  Tensor second_value(DT_FLOAT, TensorShape({}));
  second_value.scalar<float>()() = 2.0;
  Node* second_const = test::graph::Constant(&g, second_value);
  Node* second_identity = test::graph::Identity(&g, second_const);

  Node* third = test::graph::Add(&g, first_identity, second_identity);
  Node* third_identity = test::graph::Identity(&g, third);

  test::graph::ToGraphDef(&g, &def);

  std::unique_ptr<Session> session(CreateSession());
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  std::vector<Tensor> outputs;

  string handle;
  Status s = session->PRunSetup(
      {first_const->name(), second_const->name()},
      {first_identity->name() + ":0", second_identity->name() + ":0",
       third_identity->name() + ":0"},
      {}, &handle);
  TF_ASSERT_OK(s);

  Tensor value_11(DT_FLOAT, TensorShape({}));
  value_11.scalar<float>()() = 11.0;
  Tensor value_22(DT_FLOAT, TensorShape({}));
  value_22.scalar<float>()() = 22.0;

  // Close the session.
  session->Close();

  // Feed first_const, fetch first_identity
  s = session->PRun(handle, {{first_const->name(), value_11}},
                    {first_identity->name() + ":0"}, &outputs);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}

TEST(DirectSessionTest, TestDirectSessionReset) {
  // Construct a graph with a variable and a single assign.
  Graph g(OpRegistry::Global());
  Tensor t(DT_FLOAT, TensorShape({}));
  t.scalar<float>()() = {1.2};
  Node* var_val = test::graph::Constant(&g, t);
  Node* var = test::graph::Var(&g, DT_FLOAT, {});
  Node* var_assign = test::graph::Assign(&g, var, var_val);
  GraphDef def;
  test::graph::ToGraphDef(&g, &def);

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;
  std::unique_ptr<Session> session(NewSession(options));
  ASSERT_TRUE(session != nullptr);
  TF_ASSERT_OK(session->Create(def));

  // Assign a value to the var.
  TF_ASSERT_OK(session->Run({} /* inputs */, {},
                            {var_assign->name()} /* target_nodes */, nullptr));

  // Run a read on the variable to ensure that it works.
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(session->Run(
      {} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
  EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
  outputs.clear();

  // Reset the containers.
  Reset(options, {});

  // Run the read on the variable to get an error.
  // TODO(suharshs): This test only works because we close the Session in Reset.
  // If we change the behavior of Reset to not close the Session, this test will
  // fail, since the Variable buffer is cached by var.
  Status s = session->Run({} /* inputs */, {},
                          {var_assign->name()} /* target_nodes */, nullptr);
  EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}

}  // namespace
}  // namespace tensorflow