/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
#define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
|
|
|
|
namespace thread {
|
|
|
|
struct ThreadPoolOptions;
|
|
|
|
}
|
|
|
|
/// @addtogroup core
|
|
/// @{
|
|
|
|
/// A `ClientSession` object lets the caller drive the evaluation of the
|
|
/// TensorFlow graph constructed with the C++ API.
|
|
///
|
|
/// Example:
|
|
///
|
|
/// Scope root = Scope::NewRootScope();
|
|
/// auto a = Placeholder(root, DT_INT32);
|
|
/// auto c = Add(root, a, {41});
|
|
///
|
|
/// ClientSession session(root);
|
|
/// std::vector<Tensor> outputs;
|
|
///
|
|
/// Status s = session.Run({ {a, {1}} }, {c}, &outputs);
|
|
/// if (!s.ok()) { ... }
|
|
class ClientSession {
|
|
public:
|
|
/// A data type to represent feeds to a Run call.
|
|
///
|
|
/// This is a map of `Output` objects returned by op-constructors to the value
|
|
/// to feed them with. See `Input::Initializer` for details on what can be
|
|
/// used as feed values.
|
|
typedef std::unordered_map<Output, Input::Initializer, OutputHash> FeedType;
|
|
|
|
/// Create a new session to evaluate the graph contained in `scope` by
|
|
/// connecting to the TensorFlow runtime specified by `target`.
|
|
ClientSession(const Scope& scope, const string& target);
|
|
|
|
/// Same as above, but use the empty string ("") as the target specification.
|
|
explicit ClientSession(const Scope& scope);
|
|
|
|
/// Create a new session, configuring it with `session_options`.
|
|
ClientSession(const Scope& scope, const SessionOptions& session_options);
|
|
|
|
~ClientSession();
|
|
|
|
/// Evaluate the tensors in `fetch_outputs`. The values are returned as
|
|
/// `Tensor` objects in `outputs`. The number and order of `outputs` will
|
|
/// match `fetch_outputs`.
|
|
Status Run(const std::vector<Output>& fetch_outputs,
|
|
std::vector<Tensor>* outputs) const;
|
|
|
|
/// Same as above, but use the mapping in `inputs` as feeds.
|
|
Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
|
|
std::vector<Tensor>* outputs) const;
|
|
|
|
/// Same as above. Additionally runs the operations ins `run_outputs`.
|
|
Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
|
|
const std::vector<Operation>& run_outputs,
|
|
std::vector<Tensor>* outputs) const;
|
|
|
|
/// Use `run_options` to turn on performance profiling. `run_metadata`, if not
|
|
/// null, is filled in with the profiling results.
|
|
Status Run(const RunOptions& run_options, const FeedType& inputs,
|
|
const std::vector<Output>& fetch_outputs,
|
|
const std::vector<Operation>& run_outputs,
|
|
std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
|
|
|
|
/// Same as above. Additionally allows user to provide custom threadpool
|
|
/// implementation via ThreadPoolOptions.
|
|
Status Run(const RunOptions& run_options, const FeedType& inputs,
|
|
const std::vector<Output>& fetch_outputs,
|
|
const std::vector<Operation>& run_outputs,
|
|
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
|
|
const thread::ThreadPoolOptions& threadpool_options) const;
|
|
|
|
/// \brief A handle to a subgraph, created with
|
|
/// `ClientSession::MakeCallable()`.
|
|
typedef int64 CallableHandle;
|
|
|
|
/// \brief Creates a `handle` for invoking the subgraph defined by
|
|
/// `callable_options`.
|
|
/// NOTE: This API is still experimental and may change.
|
|
Status MakeCallable(const CallableOptions& callable_options,
|
|
CallableHandle* out_handle);
|
|
|
|
/// \brief Invokes the subgraph named by `handle` with the given options and
|
|
/// input tensors.
|
|
///
|
|
/// The order of tensors in `feed_tensors` must match the order of names in
|
|
/// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
|
|
/// match the order of names in `CallableOptions::fetch()` when this subgraph
|
|
/// was created.
|
|
/// NOTE: This API is still experimental and may change.
|
|
Status RunCallable(CallableHandle handle,
|
|
const std::vector<Tensor>& feed_tensors,
|
|
std::vector<Tensor>* fetch_tensors,
|
|
RunMetadata* run_metadata);
|
|
|
|
/// \brief Invokes the subgraph named by `handle` with the given options and
|
|
/// input tensors.
|
|
///
|
|
/// The order of tensors in `feed_tensors` must match the order of names in
|
|
/// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
|
|
/// match the order of names in `CallableOptions::fetch()` when this subgraph
|
|
/// was created.
|
|
/// NOTE: This API is still experimental and may change.
|
|
Status RunCallable(CallableHandle handle,
|
|
const std::vector<Tensor>& feed_tensors,
|
|
std::vector<Tensor>* fetch_tensors,
|
|
RunMetadata* run_metadata,
|
|
const thread::ThreadPoolOptions& options);
|
|
|
|
/// \brief Releases resources associated with the given `handle` in this
|
|
/// session.
|
|
/// NOTE: This API is still experimental and may change.
|
|
Status ReleaseCallable(CallableHandle handle);
|
|
|
|
private:
|
|
class Impl;
|
|
std::unique_ptr<Impl> impl_;
|
|
Impl* impl() { return impl_.get(); }
|
|
const Impl* impl() const { return impl_.get(); }
|
|
};
|
|
|
|
/// @}
|
|
|
|
} // end namespace tensorflow
|
|
|
|
#endif  // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_