Add Reset implementation for DirectSession.

- Reset clears the specified containers and closes ALL DirectSession objects.
- Add closed bit to DirectSession to ensure that operations that occur after Close is called fail.
Change: 131889161
This commit is contained in:
Suharsh Sivakumar 2016-08-31 15:43:07 -08:00 committed by TensorFlower Gardener
parent e11b99749d
commit 62c159ffe8
3 changed files with 240 additions and 36 deletions

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
@ -113,6 +112,77 @@ string GetRendezvousKey(const string& tensor_name,
} // namespace
// Session factory for DirectSession objects (in-process execution, i.e. an
// empty session "target"). Tracks every live session it creates so that
// Reset() can clear containers on, and then close, all of them.
class DirectSessionFactory : public SessionFactory {
public:
DirectSessionFactory() {}
// A DirectSession handles exactly the requests whose target string is empty.
bool AcceptsOptions(const SessionOptions& options) override {
return options.target.empty();
}
// Creates a new DirectSession and records it in sessions_ so a later
// Reset() can reach it. Returns nullptr on device-creation failure.
Session* NewSession(const SessionOptions& options) override {
// Must do this before the CPU allocator is created.
if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true);
}
std::vector<Device*> devices;
Status s = DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices);
if (!s.ok()) {
// NOTE(review): the error is reported only via the log; the caller just
// sees a nullptr session.
LOG(ERROR) << s;
return nullptr;
}
// The session keeps a pointer back to this factory so its Close() can
// call Deregister (see below).
DirectSession* session =
new DirectSession(options, new DeviceMgr(devices), this);
{
mutex_lock l(sessions_lock_);
sessions_.push_back(session);
}
return session;
}
// Clears 'containers' on every session created by this factory, then closes
// those sessions. Returns the first non-OK status encountered (via Update).
Status Reset(const SessionOptions& options,
const std::vector<string>& containers) override {
std::vector<DirectSession*> sessions_to_reset;
{
mutex_lock l(sessions_lock_);
// We create a copy to ensure that we don't have a deadlock when
// session->Close calls the DirectSessionFactory.Deregister, which
// acquires sessions_lock_.
std::swap(sessions_to_reset, sessions_);
}
Status s;
for (auto session : sessions_to_reset) {
s.Update(session->Reset(containers));
}
// TODO(suharshs): Change the Reset behavior of all SessionFactories so that
// it doesn't close the sessions?
for (auto session : sessions_to_reset) {
s.Update(session->Close());
}
return s;
}
// Called from DirectSession::Close to drop the factory's pointer to the
// session. Harmless for sessions no longer in sessions_ (e.g. ones already
// swapped out by Reset): erase/remove simply finds nothing.
void Deregister(const DirectSession* session) {
mutex_lock l(sessions_lock_);
sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
sessions_.end());
}
private:
mutex sessions_lock_;
// Live sessions created by NewSession; entries are removed by Deregister or
// swapped out wholesale by Reset. Pointers are not owned here.
std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_);
};
// Registers DirectSessionFactory with the global SessionFactory registry at
// static-initialization time under the name "DIRECT_SESSION".
class DirectSessionRegistrar {
public:
DirectSessionRegistrar() {
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
}
};
// File-scope instance whose constructor performs the registration.
static DirectSessionRegistrar registrar;
// Global monotonically increasing step-id counter shared by all
// DirectSession instances; starts at 1.
std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
// NOTE: On Android with a single device, there is never
@ -146,10 +216,13 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
}
DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr)
const DeviceMgr* device_mgr,
DirectSessionFactory* const factory)
: options_(options),
device_mgr_(device_mgr),
factory_(factory),
cancellation_manager_(new CancellationManager()),
closed_(false),
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
if (options_.config.session_inter_op_thread_pool_size() > 0) {
for (int i = 0; i < options_.config.session_inter_op_thread_pool_size();
@ -194,6 +267,7 @@ DirectSession::DirectSession(const SessionOptions& options,
}
DirectSession::~DirectSession() {
if (!closed_) Close();
for (auto& it : partial_runs_) {
it.second.reset(nullptr);
}
@ -237,6 +311,7 @@ Status DirectSession::Create(const GraphDef& graph) {
}
Status DirectSession::Extend(const GraphDef& graph) {
TF_RETURN_IF_ERROR(CheckNotClosed());
mutex_lock l(graph_def_lock_);
return ExtendLocked(graph);
}
@ -267,6 +342,7 @@ Status DirectSession::Run(const RunOptions& run_options,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
TF_RETURN_IF_ERROR(CheckNotClosed());
direct_session_runs->GetCell()->IncrementBy(1);
{
mutex_lock l(graph_def_lock_);
@ -412,6 +488,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) {
TF_RETURN_IF_ERROR(CheckNotClosed());
{
mutex_lock l(graph_def_lock_);
if (!graph_created_) {
@ -487,6 +564,7 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) {
TF_RETURN_IF_ERROR(CheckNotClosed());
std::vector<string> parts = str_util::Split(handle, ';');
const string& key = parts[0];
// Get the executors for this partial run.
@ -1002,8 +1080,20 @@ Status DirectSession::CreateGraphs(
return s;
}
// Clears the resource containers named in 'containers' from this session's
// device manager (per the header comment, an empty list clears the default
// container). Does not mark the session closed; DirectSessionFactory::Reset
// closes sessions in a separate pass after resetting them.
::tensorflow::Status DirectSession::Reset(
const std::vector<string>& containers) {
device_mgr_->ClearContainers(containers);
return ::tensorflow::Status::OK();
}
// Closes the session: cancels pending work, sets the closed_ bit so that
// subsequent Run/PRunSetup/PRun/Extend calls fail via CheckNotClosed, and
// deregisters the session from its factory. Idempotent.
::tensorflow::Status DirectSession::Close() {
// Start cancellation before taking mu_ so in-flight steps begin unwinding.
cancellation_manager_->StartCancel();
{
mutex_lock l(mu_);
// Second and later calls are no-ops.
if (closed_) return ::tensorflow::Status::OK();
closed_ = true;
}
// Deregister after releasing mu_: Deregister acquires the factory's
// sessions_lock_, and we avoid holding both locks at once.
if (factory_ != nullptr) factory_->Deregister(this);
return ::tensorflow::Status::OK();
}
@ -1051,37 +1141,4 @@ void DirectSession::WaitForNotification(RunState* run_state,
}
}
// (Removed by this commit.) Previous factory definition: identical to the
// new one at the top of the file except that it did not track created
// sessions and so could not implement Reset.
class DirectSessionFactory : public SessionFactory {
public:
DirectSessionFactory() {}
// Accepts only sessions with an empty target (in-process execution).
bool AcceptsOptions(const SessionOptions& options) override {
return options.target.empty();
}
Session* NewSession(const SessionOptions& options) override {
// Must do this before the CPU allocator is created.
if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true);
}
std::vector<Device*> devices;
Status s = DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices);
if (!s.ok()) {
// Failure is reported via the log only; the caller receives nullptr.
LOG(ERROR) << s;
return nullptr;
}
// Old two-argument constructor: no factory back-pointer.
return new DirectSession(options, new DeviceMgr(devices));
}
};
// (Removed by this commit.) Previous registrar, relocated verbatim to the
// top of the file so it sits next to the factory definition.
class DirectSessionRegistrar {
public:
DirectSessionRegistrar() {
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
}
};
static DirectSessionRegistrar registrar;
} // namespace tensorflow

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/framework/cancellation.h"
@ -47,11 +48,18 @@ namespace tensorflow {
class CostModel;
class DebugGateway;
class Device;
class DirectSessionFactory;
class DirectSession : public Session {
public:
typedef std::function<void(Session*)> CloseCallback;
// Takes ownership of 'device_mgr'.
DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr);
// 'factory' is used to unregister the DirectSession with 'factory' when it is
// closed. This ensures that Reset requests from the 'factory' don't get sent
// to sessions that are already closed.
DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
DirectSessionFactory* factory);
~DirectSession() override;
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
@ -83,6 +91,10 @@ class DirectSession : public Session {
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) override;
// Reset clears 'containers' from the device_mgr of the DirectSession.
// If 'containers' is empty, then Reset clears the default container.
::tensorflow::Status Reset(const std::vector<string>& containers);
::tensorflow::Status Close() override;
void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
@ -198,6 +210,12 @@ class DirectSession : public Session {
// operation_timeout_in_ms is greater than 0.
void WaitForNotification(RunState* run_state, int64 timeout_in_ms);
// Returns Cancelled if Close() has already been called on this session,
// and OK otherwise. Takes mu_ to read the closed_ bit safely.
::tensorflow::Status CheckNotClosed() {
mutex_lock l(mu_);
return closed_ ? errors::Cancelled("Session has been closed.")
: ::tensorflow::Status::OK();
}
const SessionOptions options_;
// Device structures.
@ -232,10 +250,12 @@ class DirectSession : public Session {
// This holds all the tensors that are currently alive in the session.
SessionState session_state_;
DirectSessionFactory* const factory_; // not owned
CancellationManager* cancellation_manager_;
// Saves and restores device placements for stateful nodes.
mutex mu_;
// Map of placed stateful nodes, i.e. nodes for which is_stateful()
// is true, such as "params" and "queue" nodes. Once placed these
// nodes can not be moved to a different device. Maps node names to
@ -251,6 +271,9 @@ class DirectSession : public Session {
// library; it copies and modifies the function library.
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
// true if the Session has been Closed.
bool closed_ GUARDED_BY(mu_);
// For generating unique names.
int64 name_counter_ GUARDED_BY(mu_) = 0;

View File

@ -970,5 +970,129 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) {
}
}
// Verifies that Run() works normally before Close() and fails with a
// Cancelled status afterwards.
TEST(DirectSessionTest, TestDirectSessionRunClose) {
// Construct a graph with a variable and a single assign.
Graph g(OpRegistry::Global());
Tensor t(DT_FLOAT, TensorShape({}));
t.scalar<float>()() = {1.2};
Node* var_val = test::graph::Constant(&g, t);
Node* var = test::graph::Var(&g, DT_FLOAT, {});
Node* var_assign = test::graph::Assign(&g, var, var_val);
GraphDef def;
test::graph::ToGraphDef(&g, &def);
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
std::unique_ptr<Session> session(NewSession(options));
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
// Assign a value to the var.
TF_ASSERT_OK(session->Run({} /* inputs */, {},
{var_assign->name()} /* target_nodes */, nullptr));
// Run a read on the variable to ensure that it works.
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run(
{} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
outputs.clear();
// Close the session.
session->Close();
// Run the read on the variable to get an error. The exact message comes
// from DirectSession::CheckNotClosed.
Status s = session->Run({} /* inputs */, {},
{var_assign->name()} /* target_nodes */, nullptr);
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
// Verifies that PRun() fails with a Cancelled status when the session is
// closed after PRunSetup but before the partial run is fed.
TEST(DirectSessionTest, TestDirectSessionPRunClose) {
GraphDef def;
Graph g(OpRegistry::Global());
// Build third = (first + second), with identities to feed/fetch through.
Tensor first_value(DT_FLOAT, TensorShape({}));
first_value.scalar<float>()() = 1.0;
Node* first_const = test::graph::Constant(&g, first_value);
Node* first_identity = test::graph::Identity(&g, first_const);
Tensor second_value(DT_FLOAT, TensorShape({}));
second_value.scalar<float>()() = 2.0;
Node* second_const = test::graph::Constant(&g, second_value);
Node* second_identity = test::graph::Identity(&g, second_const);
Node* third = test::graph::Add(&g, first_identity, second_identity);
Node* third_identity = test::graph::Identity(&g, third);
test::graph::ToGraphDef(&g, &def);
std::unique_ptr<Session> session(CreateSession());
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
std::vector<Tensor> outputs;
string handle;
// Setup succeeds while the session is still open.
Status s = session->PRunSetup(
{first_const->name(), second_const->name()},
{first_identity->name() + ":0", second_identity->name() + ":0",
third_identity->name() + ":0"},
{}, &handle);
TF_ASSERT_OK(s);
Tensor value_11(DT_FLOAT, TensorShape({}));
value_11.scalar<float>()() = 11.0;
Tensor value_22(DT_FLOAT, TensorShape({}));
value_22.scalar<float>()() = 22.0;
// Close the session.
session->Close();
// Feed first_const, fetch first_identity; this must now fail with the
// CheckNotClosed error rather than executing.
s = session->PRun(handle, {{first_const->name(), value_11}},
{first_identity->name() + ":0"}, &outputs);
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
// Verifies that the global Reset() closes DirectSessions: after Reset, a
// Run on an existing session returns Cancelled.
TEST(DirectSessionTest, TestDirectSessionReset) {
// Construct a graph with a variable and a single assign.
Graph g(OpRegistry::Global());
Tensor t(DT_FLOAT, TensorShape({}));
t.scalar<float>()() = {1.2};
Node* var_val = test::graph::Constant(&g, t);
Node* var = test::graph::Var(&g, DT_FLOAT, {});
Node* var_assign = test::graph::Assign(&g, var, var_val);
GraphDef def;
test::graph::ToGraphDef(&g, &def);
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
std::unique_ptr<Session> session(NewSession(options));
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
// Assign a value to the var.
TF_ASSERT_OK(session->Run({} /* inputs */, {},
{var_assign->name()} /* target_nodes */, nullptr));
// Run a read on the variable to ensure that it works.
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run(
{} /* inputs */, {var->name() + ":0"} /* output_names */, {}, &outputs));
EXPECT_EQ(t.scalar<float>()(), outputs[0].scalar<float>()());
outputs.clear();
// Reset the containers (empty list => default container).
Reset(options, {});
// Run the read on the variable to get an error.
// TODO(suharshs): This test only works because we close the Session in Reset.
// If we change the behavior of Reset to not close the Session, this test will
// fail, since the Variable buffer is cached by var.
Status s = session->Run({} /* inputs */, {},
{var_assign->name()} /* target_nodes */, nullptr);
EXPECT_EQ("Cancelled: Session has been closed.", s.ToString());
}
} // namespace
} // namespace tensorflow