Update API to traffic in unique_ptrs rather than owning raw pointers
PiperOrigin-RevId: 163414320
This commit is contained in:
parent
31a77bc775
commit
569a00e681
@ -36,7 +36,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Checks that arguments `args` match types `types`.
|
// Checks that arguments `args` match types `types`.
|
||||||
@ -90,14 +89,14 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
|||||||
|
|
||||||
local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
|
local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
|
||||||
FunctionDefLibrary{}));
|
FunctionDefLibrary{}));
|
||||||
local_flib_runtime_.reset(NewFunctionLibraryRuntime(
|
local_flib_runtime_ = NewFunctionLibraryRuntime(
|
||||||
&device_mgr_, Env::Default(), device_, options.graph_def_version,
|
&device_mgr_, Env::Default(), device_, options.graph_def_version,
|
||||||
local_flib_def_.get(), OptimizerOptions(),
|
local_flib_def_.get(), OptimizerOptions(),
|
||||||
nullptr /* custom_kernel_creator */));
|
nullptr /* custom_kernel_creator */);
|
||||||
flib_runtime_.reset(NewFunctionLibraryRuntime(
|
flib_runtime_ = NewFunctionLibraryRuntime(
|
||||||
&device_mgr_, Env::Default(), device_, options.graph_def_version,
|
&device_mgr_, Env::Default(), device_, options.graph_def_version,
|
||||||
options.flib_def, OptimizerOptions(),
|
options.flib_def, OptimizerOptions(),
|
||||||
nullptr /* custom_kernel_creator */));
|
nullptr /* custom_kernel_creator */);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaCompiler::~XlaCompiler() = default;
|
XlaCompiler::~XlaCompiler() = default;
|
||||||
|
@ -1157,9 +1157,9 @@ Status DirectSession::GetOrCreateExecutors(
|
|||||||
|
|
||||||
ek->items.resize(ek->items.size() + 1);
|
ek->items.resize(ek->items.size() + 1);
|
||||||
auto* item = &(ek->items.back());
|
auto* item = &(ek->items.back());
|
||||||
item->flib.reset(NewFunctionLibraryRuntime(
|
item->flib = NewFunctionLibraryRuntime(device_mgr_.get(), options_.env,
|
||||||
device_mgr_.get(), options_.env, device, graph_def_version,
|
device, graph_def_version,
|
||||||
ek->flib_def.get(), optimizer_opts));
|
ek->flib_def.get(), optimizer_opts);
|
||||||
|
|
||||||
LocalExecutorParams params;
|
LocalExecutorParams params;
|
||||||
params.device = device;
|
params.device = device;
|
||||||
|
@ -606,27 +606,27 @@ CustomCreatorSingleton* GetCustomCreatorSingleton() {
|
|||||||
return ccs;
|
return ccs;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace
|
} // namespace
|
||||||
|
|
||||||
void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
|
void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
|
||||||
GetCustomCreatorSingleton()->Set(std::move(cb));
|
GetCustomCreatorSingleton()->Set(std::move(cb));
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||||
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
|
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||||
const FunctionLibraryDefinition* lib_def,
|
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||||
const OptimizerOptions& optimizer_options,
|
const OptimizerOptions& optimizer_options,
|
||||||
CustomKernelCreator custom_kernel_creator) {
|
CustomKernelCreator custom_kernel_creator) {
|
||||||
return new FunctionLibraryRuntimeImpl(dmgr, env, device, graph_def_version,
|
return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
|
||||||
lib_def, optimizer_options,
|
device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
|
||||||
std::move(custom_kernel_creator));
|
std::move(custom_kernel_creator)));
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||||
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
|
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||||
const FunctionLibraryDefinition* lib_def,
|
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||||
const OptimizerOptions& optimizer_options) {
|
const OptimizerOptions& optimizer_options) {
|
||||||
return NewFunctionLibraryRuntime(dmgr, env, device, graph_def_version,
|
return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
|
||||||
lib_def, optimizer_options,
|
lib_def, optimizer_options,
|
||||||
GetCustomCreatorSingleton()->Get());
|
GetCustomCreatorSingleton()->Get());
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
|
|||||||
// The returned object does not take ownerships of "device" or
|
// The returned object does not take ownerships of "device" or
|
||||||
// "lib_def". The caller must ensure "device" and "lib_def" outlives
|
// "lib_def". The caller must ensure "device" and "lib_def" outlives
|
||||||
// the returned object.
|
// the returned object.
|
||||||
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||||
const DeviceMgr* device_mgr, Env* env, Device* device,
|
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||||
const OptimizerOptions& optimizer_options,
|
const OptimizerOptions& optimizer_options,
|
||||||
@ -59,7 +59,7 @@ FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
|||||||
// Same as above except that the returned runtime consults with the
|
// Same as above except that the returned runtime consults with the
|
||||||
// global default custom kernel creator registered by
|
// global default custom kernel creator registered by
|
||||||
// RegisterDefaultCustomKernelCreator.
|
// RegisterDefaultCustomKernelCreator.
|
||||||
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
|
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||||
const DeviceMgr* device_mgr, Env* env, Device* device,
|
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||||
const OptimizerOptions& optimizer_options);
|
const OptimizerOptions& optimizer_options);
|
||||||
|
@ -43,7 +43,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
typedef FunctionDefHelper FDH;
|
using FDH = FunctionDefHelper;
|
||||||
|
|
||||||
Status GetOpSig(const string& op, const OpDef** sig) {
|
Status GetOpSig(const string& op, const OpDef** sig) {
|
||||||
return OpRegistry::Global()->LookUpOpDef(op, sig);
|
return OpRegistry::Global()->LookUpOpDef(op, sig);
|
||||||
@ -163,9 +163,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
|
||||||
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
|
||||||
OptimizerOptions opts;
|
OptimizerOptions opts;
|
||||||
lib_.reset(NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(),
|
lib_ =
|
||||||
TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
NewFunctionLibraryRuntime(nullptr, Env::Default(), device_.get(),
|
||||||
opts));
|
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts);
|
||||||
fdef_lib_ = lib_def_->ToProto();
|
fdef_lib_ = lib_def_->ToProto();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1264,5 +1264,5 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
|
|||||||
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
|
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace
|
} // namespace
|
||||||
} // end namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -214,9 +214,10 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
|
|||||||
|
|
||||||
// Function library runtime.
|
// Function library runtime.
|
||||||
unit->lib = NewFunctionLibraryRuntime(
|
unit->lib = NewFunctionLibraryRuntime(
|
||||||
device_mgr_, worker_env_->env, unit->device,
|
device_mgr_, worker_env_->env, unit->device,
|
||||||
subgraph->versions().producer(), item->lib_def,
|
subgraph->versions().producer(), item->lib_def,
|
||||||
graph_options.optimizer_options());
|
graph_options.optimizer_options())
|
||||||
|
.release();
|
||||||
|
|
||||||
// Construct the root executor for the subgraph.
|
// Construct the root executor for the subgraph.
|
||||||
params.device = unit->device;
|
params.device = unit->device;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user