Update API to traffic in unique_ptrs rather than owning raw pointers
PiperOrigin-RevId: 163414320
This commit is contained in:
parent
31a77bc775
commit
569a00e681
@ -36,7 +36,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Checks that arguments `args` match types `types`.
|
||||
@ -90,14 +89,14 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
||||
|
||||
local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
|
||||
FunctionDefLibrary{}));
|
||||
local_flib_runtime_.reset(NewFunctionLibraryRuntime(
|
||||
local_flib_runtime_ = NewFunctionLibraryRuntime(
|
||||
&device_mgr_, Env::Default(), device_, options.graph_def_version,
|
||||
local_flib_def_.get(), OptimizerOptions(),
|
||||
nullptr /* custom_kernel_creator */));
|
||||
flib_runtime_.reset(NewFunctionLibraryRuntime(
|
||||
nullptr /* custom_kernel_creator */);
|
||||
flib_runtime_ = NewFunctionLibraryRuntime(
|
||||
&device_mgr_, Env::Default(), device_, options.graph_def_version,
|
||||
options.flib_def, OptimizerOptions(),
|
||||
nullptr /* custom_kernel_creator */));
|
||||
nullptr /* custom_kernel_creator */);
|
||||
}
|
||||
|
||||
XlaCompiler::~XlaCompiler() = default;
|
||||
|
@ -1157,9 +1157,9 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
|
||||
ek->items.resize(ek->items.size() + 1);
|
||||
auto* item = &(ek->items.back());
|
||||
item->flib.reset(NewFunctionLibraryRuntime(
|
||||
device_mgr_.get(), options_.env, device, graph_def_version,
|
||||
ek->flib_def.get(), optimizer_opts));
|
||||
item->flib = NewFunctionLibraryRuntime(device_mgr_.get(), options_.env,
|
||||
device, graph_def_version,
|
||||
ek->flib_def.get(), optimizer_opts);
|
||||
|
||||
LocalExecutorParams params;
|
||||
params.device = device;
|
||||
|
@ -606,27 +606,27 @@ CustomCreatorSingleton* GetCustomCreatorSingleton() {
|
||||
return ccs;
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
} // namespace
|
||||
|
||||
void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
|
||||
GetCustomCreatorSingleton()->Set(std::move(cb));
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
||||
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
const OptimizerOptions& optimizer_options,
|
||||
CustomKernelCreator custom_kernel_creator) {
|
||||
return new FunctionLibraryRuntimeImpl(dmgr, env, device, graph_def_version,
|
||||
lib_def, optimizer_options,
|
||||
std::move(custom_kernel_creator));
|
||||
return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
|
||||
device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
|
||||
std::move(custom_kernel_creator)));
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
||||
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
const OptimizerOptions& optimizer_options) {
|
||||
return NewFunctionLibraryRuntime(dmgr, env, device, graph_def_version,
|
||||
return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
|
||||
lib_def, optimizer_options,
|
||||
GetCustomCreatorSingleton()->Get());
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
|
||||
// The returned object does not take ownerships of "device" or
|
||||
// "lib_def". The caller must ensure "device" and "lib_def" outlives
|
||||
// the returned object.
|
||||
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
||||
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
const OptimizerOptions& optimizer_options,
|
||||
@ -59,7 +59,7 @@ FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
||||
// Same as above except that the returned runtime consults with the
|
||||
// global default custom kernel creator registered by
|
||||
// RegisterDefaultCustomKernelCreator.
|
||||
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
||||
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
const OptimizerOptions& optimizer_options);
|
||||
|
@ -43,7 +43,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
typedef FunctionDefHelper FDH;
|
||||
using FDH = FunctionDefHelper;
|
||||
|
||||
Status GetOpSig(const string& op, const OpDef** sig) {
|
||||
return OpRegistry::Global()->LookUpOpDef(op, sig);
|
||||
@ -163,9 +163,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
||||
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
||||
OptimizerOptions opts;
|
||||
lib_.reset(NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(),
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts));
|
||||
lib_ =
|
||||
NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(),
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts);
|
||||
fdef_lib_ = lib_def_->ToProto();
|
||||
}
|
||||
|
||||
@ -1264,5 +1264,5 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
|
||||
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
|
||||
}
|
||||
|
||||
} // end namespace
|
||||
} // end namespace tensorflow
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -214,9 +214,10 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
||||
|
||||
// Function library runtime.
|
||||
unit->lib = NewFunctionLibraryRuntime(
|
||||
device_mgr_, worker_env_->env, unit->device,
|
||||
subgraph->versions().producer(), item->lib_def,
|
||||
graph_options.optimizer_options());
|
||||
device_mgr_, worker_env_->env, unit->device,
|
||||
subgraph->versions().producer(), item->lib_def,
|
||||
graph_options.optimizer_options())
|
||||
.release();
|
||||
|
||||
// Construct the root executor for the subgraph.
|
||||
params.device = unit->device;
|
||||
|
Loading…
x
Reference in New Issue
Block a user