Fix race in C API.

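RecordMutation could race with ExtendSessionGraphHelper, which releases
the graph lock and holds only the session lock while extending the
session. RecordMutation now also takes each session's lock, and
ExtendSessionGraphHelper acquires the graph lock and then the session
lock itself instead of requiring callers to hold the session lock.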

Also makes sure thread-safety annotations are on declarations, not
definitions (otherwise they have no effect).

PiperOrigin-RevId: 188747158
This commit is contained in:
Skye Wanderman-Milne 2018-03-12 11:02:29 -07:00 committed by TensorFlower Gardener
parent 89177f289e
commit 1d6a57edc0
3 changed files with 24 additions and 29 deletions

View File

@ -63,6 +63,7 @@ limitations under the License.
// brain namespace because we are defining 'extern "C"' functions. // brain namespace because we are defining 'extern "C"' functions.
using tensorflow::AllocationDescription; using tensorflow::AllocationDescription;
using tensorflow::DataType; using tensorflow::DataType;
using tensorflow::ExtendSessionGraphHelper;
using tensorflow::Graph; using tensorflow::Graph;
using tensorflow::GraphDef; using tensorflow::GraphDef;
using tensorflow::mutex_lock; using tensorflow::mutex_lock;
@ -640,11 +641,11 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
} }
void RecordMutation(TF_Graph* graph, const TF_Operation& op, void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type) const char* mutation_type) {
EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
// If any session has already run this node_id, mark this session as // If any session has already run this node_id, mark this session as
// unrunnable. // unrunnable.
for (auto it : graph->sessions) { for (auto it : graph->sessions) {
mutex_lock session_lock(it.first->mu);
if (it.first->last_num_graph_nodes > op.node.id()) { if (it.first->last_num_graph_nodes > op.node.id()) {
it.second = FailedPrecondition( it.second = FailedPrecondition(
"Operation '", op.node.DebugString(), "' was changed by ", "Operation '", op.node.DebugString(), "' was changed by ",
@ -713,10 +714,12 @@ Status LoadLibrary(const char* library_filename, void** result,
// TODO(josh11b,mrry): Change Session to be able to use a Graph* // TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and // directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend(). // call Session::Extend().
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
EXCLUSIVE_LOCKS_REQUIRED(session->mu) {
if (session->graph != nullptr) { if (session->graph != nullptr) {
// Take the graph lock before the session lock to avoid deadlock. This is
// safe since session->graph does not change.
session->graph->mu.lock(); session->graph->mu.lock();
mutex_lock session_lock(session->mu);
const Graph& graph = session->graph->graph; const Graph& graph = session->graph->graph;
status->status = session->graph->sessions[session]; status->status = session->graph->sessions[session];
@ -2571,13 +2574,10 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
// TODO(josh11b,mrry): Change Session to be able to use a Graph* // TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and // directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend(). // call Session::Extend().
{
mutex_lock l(session->mu);
if (session->extend_before_run && if (session->extend_before_run &&
!tensorflow::ExtendSessionGraphHelper(session, status)) { !ExtendSessionGraphHelper(session, status)) {
return; return;
} }
}
TF_Run_Setup(noutputs, output_values, status); TF_Run_Setup(noutputs, output_values, status);
@ -2612,13 +2612,10 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
const char** handle, TF_Status* status) { const char** handle, TF_Status* status) {
*handle = nullptr; *handle = nullptr;
{
mutex_lock l(session->mu);
if (session->extend_before_run && if (session->extend_before_run &&
!tensorflow::ExtendSessionGraphHelper(session, status)) { !ExtendSessionGraphHelper(session, status)) {
return; return;
} }
}
std::vector<string> input_names(ninputs); std::vector<string> input_names(ninputs);
for (int i = 0; i < ninputs; ++i) { for (int i = 0; i < ninputs; ++i) {
@ -2659,13 +2656,10 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
// TODO(josh11b,mrry): Change Session to be able to use a Graph* // TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and // directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend(). // call Session::Extend().
{
mutex_lock l(session->mu);
if (session->extend_before_run && if (session->extend_before_run &&
!tensorflow::ExtendSessionGraphHelper(session, status)) { !ExtendSessionGraphHelper(session, status)) {
return; return;
} }
}
TF_Run_Setup(noutputs, output_values, status); TF_Run_Setup(noutputs, output_values, status);

View File

@ -124,16 +124,16 @@ struct TF_Session {
TF_Session(tensorflow::Session* s, TF_Graph* g); TF_Session(tensorflow::Session* s, TF_Graph* g);
tensorflow::Session* session; tensorflow::Session* session;
TF_Graph* graph; TF_Graph* const graph;
tensorflow::mutex mu; tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu);
int last_num_graph_nodes; int last_num_graph_nodes;
// If true, TF_SessionRun and similar methods will call // If true, TF_SessionRun and similar methods will call
// ExtendSessionGraphHelper before running the graph (this is the default // ExtendSessionGraphHelper before running the graph (this is the default
// public behavior). Can be set to false if the caller needs to call // public behavior). Can be set to false if the caller needs to call
// ExtendSessionGraphHelper manually. // ExtendSessionGraphHelper manually.
bool extend_before_run GUARDED_BY(mu); std::atomic<bool> extend_before_run;
}; };
struct TF_ImportGraphDefOptions { struct TF_ImportGraphDefOptions {
@ -211,9 +211,11 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
TF_Status* status); TF_Status* status);
void RecordMutation(TF_Graph* graph, const TF_Operation& op, void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type); const char* mutation_type)
EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status); bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
LOCKS_EXCLUDED(session->graph->mu, session->mu);
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -105,9 +105,8 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
} }
void ExtendSession(TF_Session* session, TF_Status* status) { void ExtendSession(TF_Session* session, TF_Status* status) {
mutex_lock l(session->mu);
session->extend_before_run = false;
ExtendSessionGraphHelper(session, status); ExtendSessionGraphHelper(session, status);
session->extend_before_run = false;
} }
} // namespace tensorflow } // namespace tensorflow