Does a deep copy of the tensors output from GraphRunner::Run(...)
The current GraphRunner::Run(...) outputs tensors produced from running the Executor on the graph, but these tensors are actually owned by the allocator from the device created for the Run(...), which could be deleted along with the device. The deep copy allows the ownership to be transferred to the global static cpu_allocator(). Before, the allocator was always a global cpu_allocator(), but with a recent change there is an option to tie allocations to a memory limited allocator per-session. Change: 152756520
This commit is contained in:
parent
72c023d396
commit
d34eec7ec3
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
@ -175,8 +176,13 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(output_key, &parsed));
|
||||
bool is_dead;
|
||||
Tensor output_tensor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
rendez->Recv(parsed, Rendezvous::Args(), &(*outputs)[i], &is_dead));
|
||||
rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead));
|
||||
// Does a deep copy so that ownership of the tensor isn't tied to the
|
||||
// allocator of the cpu device we created above. The allocator could be
|
||||
// deleted along with the device.
|
||||
(*outputs)[i] = tensor::DeepCopy(output_tensor);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -53,6 +53,33 @@ TEST(GraphRunnerTest, SingleConst) {
|
||||
ExpectEqual(42.0f, outputs[0].scalar<float>()());
|
||||
}
|
||||
|
||||
// If not using DeepCopy, and the allocator is deleted with the cpu-device,
|
||||
// this test will seg-fault.
|
||||
TEST(GraphRunnerTest, DeepCopy) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto p1 = ops::Placeholder(root.WithOpName("p1"), DT_FLOAT);
|
||||
auto p2 = ops::Placeholder(root.WithOpName("p2"), DT_FLOAT);
|
||||
auto add = ops::Add(root.WithOpName("add"), p1, p2);
|
||||
|
||||
Tensor p1_data(DT_FLOAT, TensorShape({}));
|
||||
Tensor p2_data(DT_FLOAT, TensorShape({}));
|
||||
p1_data.scalar<float>()() = 1.0f;
|
||||
p2_data.scalar<float>()() = 2.0f;
|
||||
std::vector<std::pair<string, Tensor>> inputs = {{"p1:0", p1_data},
|
||||
{"p2:0", p2_data}};
|
||||
|
||||
// Create and destroy the GraphRunner, and ensure that the outputs are
|
||||
// consumable beyond the lifetime of GraphRunner.
|
||||
std::vector<Tensor> outputs;
|
||||
{
|
||||
GraphRunner graph_runner(Env::Default());
|
||||
Status s =
|
||||
graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs);
|
||||
TF_ASSERT_OK(s);
|
||||
}
|
||||
ExpectEqual(3.0f, outputs[0].scalar<float>()());
|
||||
}
|
||||
|
||||
TEST(GraphRunnerTest, MultiFetchConst) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
auto c = ops::Const(root, 42.0f);
|
||||
|
Loading…
x
Reference in New Issue
Block a user