Merge pull request #2695 from lissyx/tf_unique_ptr

Use std::unique_ptr<> for TensorFlow session
This commit is contained in:
lissyx 2020-01-29 10:54:37 +01:00 committed by GitHub
commit d74ab7dc1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 7 deletions

View File

@ -20,7 +20,6 @@ TFModelState::~TFModelState()
std::cerr << "Error closing TensorFlow session: " << status << std::endl; std::cerr << "Error closing TensorFlow session: " << status << std::endl;
} }
} }
delete mmap_env_;
} }
int int
@ -35,7 +34,7 @@ TFModelState::init(const char* model_path,
Status status; Status status;
SessionOptions options; SessionOptions options;
mmap_env_ = new MemmappedEnv(Env::Default()); mmap_env_.reset(new MemmappedEnv(Env::Default()));
bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos; bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
if (!is_mmap) { if (!is_mmap) {
@ -50,17 +49,19 @@ TFModelState::init(const char* model_path,
options.config.mutable_graph_options() options.config.mutable_graph_options()
->mutable_optimizer_options() ->mutable_optimizer_options()
->set_opt_level(::OptimizerOptions::L0); ->set_opt_level(::OptimizerOptions::L0);
options.env = mmap_env_; options.env = mmap_env_.get();
} }
status = NewSession(options, &session_); Session* session;
status = NewSession(options, &session);
if (!status.ok()) { if (!status.ok()) {
std::cerr << status << std::endl; std::cerr << status << std::endl;
return DS_ERR_FAIL_INIT_SESS; return DS_ERR_FAIL_INIT_SESS;
} }
session_.reset(session);
if (is_mmap) { if (is_mmap) {
status = ReadBinaryProto(mmap_env_, status = ReadBinaryProto(mmap_env_.get(),
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&graph_def_); &graph_def_);
} else { } else {

View File

@ -11,8 +11,8 @@
struct TFModelState : public ModelState struct TFModelState : public ModelState
{ {
tensorflow::MemmappedEnv* mmap_env_; std::unique_ptr<tensorflow::MemmappedEnv> mmap_env_;
tensorflow::Session* session_; std::unique_ptr<tensorflow::Session> session_;
tensorflow::GraphDef graph_def_; tensorflow::GraphDef graph_def_;
TFModelState(); TFModelState();