[TF:XLA] Remove locking from XlaContext, which is only ever used from a single thread.
Change: 147038627
This commit is contained in:
parent
f439016532
commit
e65fff66d4
tensorflow/compiler/tf2xla
@ -256,24 +256,12 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
|
||||
std::unique_ptr<Executor> exec(exec_ptr);
|
||||
// At this point ownership of the graph has been transferred to exec.
|
||||
|
||||
auto runner = [](Executor::Args::Closure c) {
|
||||
// TODO(misard) Temporarily just schedule c eagerly while we
|
||||
// decide what to do about the fact that the ComputationBuilder is
|
||||
// thread-compatible, but we don't really want Op writers to have
|
||||
// to remember to acquire a lock around every call to
|
||||
// ComputationBuilder. One possibility is to add the (generally
|
||||
// useful) ability to run a single-threaded Executor based on an
|
||||
// option in LocalExecutorParams. Another is to automagically
|
||||
// acquire a lock around ComputationBuilder calls using some
|
||||
// wrapper or RAII funny business.
|
||||
c();
|
||||
};
|
||||
|
||||
// Run the graph symbolically, turning the graph into an XLA computation.
|
||||
Executor::Args exec_args;
|
||||
exec_args.step_id = step_id;
|
||||
exec_args.step_container = step_container.get();
|
||||
exec_args.runner = runner;
|
||||
// Run all compilation kernels on the main thread.
|
||||
exec_args.runner = [](Executor::Args::Closure c) { c(); };
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
exec->Run(exec_args),
|
||||
"Conversion from TensorFlow graph to XLA computation failed.");
|
||||
|
@ -131,8 +131,6 @@ Status XlaContext::CollectResults(
|
||||
xla::Computation* computation, bool* requires_runtime_context,
|
||||
std::vector<ConstRetVal>* compile_time_constants,
|
||||
int* num_nonconst_outputs) {
|
||||
mutex_lock l(mu_);
|
||||
|
||||
xla::ComputationDataHandle handle;
|
||||
if (retval_.empty() && has_side_effects_) {
|
||||
// Build a empty tuple return value for computations that have side effects
|
||||
@ -200,7 +198,6 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::Client* client,
|
||||
|
||||
const xla::ComputationDataHandle&
|
||||
XlaContext::GetOrCreateRuntimeContextParameter() {
|
||||
mutex_lock lock(mu_);
|
||||
CHECK(allow_cpu_custom_calls_);
|
||||
CHECK(!use_tuple_arg_);
|
||||
if (has_context_parameter_) return context_parameter_;
|
||||
@ -220,7 +217,6 @@ void XlaContext::AddRetval(int retval_index,
|
||||
// Add the return value to the list being built up. The executor
|
||||
// is multi-threaded so this has to happen under the
|
||||
// lock.
|
||||
mutex_lock l(mu_);
|
||||
retval_.emplace_back(retval_index, handle);
|
||||
}
|
||||
|
||||
@ -232,17 +228,14 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
ConstRetVal value;
|
||||
value.index = retval_index;
|
||||
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value.value));
|
||||
mutex_lock l(mu_);
|
||||
compile_time_constant_.push_back(std::move(value));
|
||||
} else {
|
||||
mutex_lock l(mu_);
|
||||
retval_.emplace_back(retval_index, xla_builder_.ConstantLiteral(literal));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaContext::AddSideEffects() {
|
||||
mutex_lock lock(mu_);
|
||||
has_side_effects_ = true;
|
||||
}
|
||||
|
||||
@ -323,7 +316,6 @@ const xla::Computation* XlaContext::LookupOrCreate(
|
||||
DataType type, ComputationMap* out,
|
||||
const std::function<xla::Computation()>& create) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
const auto& entry = (*out)[type];
|
||||
if (!entry.IsNull()) {
|
||||
return &entry;
|
||||
@ -331,7 +323,6 @@ const xla::Computation* XlaContext::LookupOrCreate(
|
||||
}
|
||||
auto new_entry = create();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
// Somebody else might have made one concurrently.
|
||||
auto& entry = (*out)[type];
|
||||
if (entry.IsNull()) {
|
||||
|
@ -29,8 +29,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -223,11 +221,9 @@ class XlaContext : public ResourceBase {
|
||||
|
||||
XlaCompiler* const compiler_;
|
||||
|
||||
mutable mutex mu_;
|
||||
|
||||
// The ComputationBuilder used to construct the subgraph's compiled
|
||||
// representation.
|
||||
xla::ComputationBuilder xla_builder_ GUARDED_BY(mu_);
|
||||
xla::ComputationBuilder xla_builder_;
|
||||
|
||||
// Number of XLA Parameters, not counting the context parameter, if any.
|
||||
int num_parameters_;
|
||||
@ -252,18 +248,17 @@ class XlaContext : public ResourceBase {
|
||||
// for an additional final parameter to the computation, through which will be
|
||||
// passed a XlaLocalRuntimeContext* at runtime. Created on demand by
|
||||
// GetOrCreateRuntimeContextParameter().
|
||||
bool has_context_parameter_ GUARDED_BY(mu_) = false;
|
||||
xla::ComputationDataHandle context_parameter_ GUARDED_BY(mu_);
|
||||
bool has_context_parameter_ = false;
|
||||
xla::ComputationDataHandle context_parameter_;
|
||||
|
||||
// The data-dependent return values of the computation.
|
||||
std::vector<std::pair<int, xla::ComputationDataHandle>> retval_
|
||||
GUARDED_BY(mu_);
|
||||
std::vector<std::pair<int, xla::ComputationDataHandle>> retval_;
|
||||
|
||||
// The non-data-dependent return values of the computation.
|
||||
std::vector<ConstRetVal> compile_time_constant_ GUARDED_BY(mu_);
|
||||
std::vector<ConstRetVal> compile_time_constant_;
|
||||
|
||||
// Does the computation have side effects, i.e., Send() calls?
|
||||
bool has_side_effects_ GUARDED_BY(mu_) = false;
|
||||
bool has_side_effects_ = false;
|
||||
|
||||
// Cache of prebuilt computations indexed by their type.
|
||||
using ComputationMap = std::map<DataType, xla::Computation>;
|
||||
@ -273,16 +268,16 @@ class XlaContext : public ResourceBase {
|
||||
// map. The returned value != nullptr and is owned by the map.
|
||||
const xla::Computation* LookupOrCreate(
|
||||
DataType type, ComputationMap* out,
|
||||
const std::function<xla::Computation()>& create) LOCKS_EXCLUDED(mu_);
|
||||
const std::function<xla::Computation()>& create);
|
||||
|
||||
// Cached computation to compute Max of two elements, specialized by type.
|
||||
ComputationMap max_func_ GUARDED_BY(mu_);
|
||||
ComputationMap max_func_;
|
||||
|
||||
// Cached computation to compute Sum of two elements, specialized by type.
|
||||
ComputationMap add_func_ GUARDED_BY(mu_);
|
||||
ComputationMap add_func_;
|
||||
|
||||
// Cached computation to compute Sigmoid of an element, specialized by type.
|
||||
ComputationMap sigmoid_func_ GUARDED_BY(mu_);
|
||||
ComputationMap sigmoid_func_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaContext);
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user