Merge pull request #2695 from lissyx/tf_unique_ptr

Use std::unique_ptr<> for TensorFlow session
This commit is contained in:
lissyx 2020-01-29 10:54:37 +01:00 committed by GitHub
commit d74ab7dc1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 7 deletions

View File

@ -20,7 +20,6 @@ TFModelState::~TFModelState()
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
}
}
delete mmap_env_;
}
int
@ -35,7 +34,7 @@ TFModelState::init(const char* model_path,
Status status;
SessionOptions options;
mmap_env_ = new MemmappedEnv(Env::Default());
mmap_env_.reset(new MemmappedEnv(Env::Default()));
bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
if (!is_mmap) {
@ -50,17 +49,19 @@ TFModelState::init(const char* model_path,
options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_opt_level(::OptimizerOptions::L0);
options.env = mmap_env_;
options.env = mmap_env_.get();
}
status = NewSession(options, &session_);
Session* session;
status = NewSession(options, &session);
if (!status.ok()) {
std::cerr << status << std::endl;
return DS_ERR_FAIL_INIT_SESS;
}
session_.reset(session);
if (is_mmap) {
status = ReadBinaryProto(mmap_env_,
status = ReadBinaryProto(mmap_env_.get(),
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&graph_def_);
} else {

View File

@ -11,8 +11,8 @@
struct TFModelState : public ModelState
{
tensorflow::MemmappedEnv* mmap_env_;
tensorflow::Session* session_;
std::unique_ptr<tensorflow::MemmappedEnv> mmap_env_;
std::unique_ptr<tensorflow::Session> session_;
tensorflow::GraphDef graph_def_;
TFModelState();