[tf.data] Create core/data directory and add standalone dataset

core/data will be a place to put tf.data C++ code that isn't
part of a kernel implementation (kernel implementations go in
core/kernels/data). For example, the tf.data service implementation
will go in core/data.

PiperOrigin-RevId: 300594272
Change-Id: I9d693ca6edf6d8683aeedfff2d30554bdb3111b5
Author: Andrew Audibert <2020-03-12 11:58:41 -07:00>
Committed by: TensorFlower Gardener
Parent: 79d479b9b2
Commit: dbf532e329
4 changed files with 613 additions and 0 deletions

tensorflow/core/data/BUILD
@@ -0,0 +1,35 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core/platform:build_config.bzl", "tf_protos_all")
package(
default_visibility = [
"//tensorflow:internal",
],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
cc_library(
name = "standalone",
srcs = ["standalone.cc"],
hdrs = ["standalone.h"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:session_options",
"@com_google_absl//absl/memory",
],
)
tf_cc_test(
name = "standalone_test",
srcs = ["standalone_test.cc"],
deps = [
":standalone",
"//tensorflow/core:all_kernels",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
] + tf_protos_all(),
)

tensorflow/core/data/standalone.cc
@@ -0,0 +1,149 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/standalone.h"
#include <memory>
#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
namespace standalone {
Status Iterator::GetNext(std::vector<Tensor>* outputs, bool* end_of_input) {
return iterator_->GetNext(ctx_.get(), outputs, end_of_input);
}
Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx)
: iterator_(iterator), ctx_(ctx) {}
Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
std::unique_ptr<Dataset>* result) {
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
// Instantiate enough of the TF runtime to run `graph` on a single CPU device.
auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", params.session_options, "/job:localhost/replica:0/task:0"));
Device* device = device_mgr->ListDevices()[0];
// Clone the `FunctionLibraryDefinition` so that its lifetime extends beyond
// the lifetime of `graph`.
auto flib_def =
absl::make_unique<FunctionLibraryDefinition>(graph.flib_def());
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, flib_def.get(), OptimizerOptions{},
/*thread_pool=*/nullptr, /*parent=*/nullptr,
/*custom_kernel_creator=*/nullptr,
/*session_metadata=*/nullptr,
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
});
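// Locate the graph's `_Retval` node; its input names the tensor that
// produces the dataset.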
string fetch_node = "";
for (auto node : graph_def.node()) {
if (node.op() == "_Retval") {
fetch_node = node.input(0);
}
}
if (fetch_node.empty()) {
return errors::NotFound("Failed to find a _Retval op in the given dataset");
}
// Run the graph up to `fetch_node` and extract the `DatasetBase` stored in the
// DT_VARIANT output tensor.
data::DatasetBase* dataset;
{
std::vector<Tensor> outputs;
GraphRunner graph_runner(device);
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"),
{}, {fetch_node}, &outputs));
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
// NOTE(mrry): The dataset is currently owned by `outputs[0]`, so acquire an
// additional reference.
dataset->Ref();
}
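// Create a thread pool from the session options; it will back the `runner_`
// that executes closures for this dataset's iterators.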
std::unique_ptr<thread::ThreadPool> pool(
NewThreadPoolFromSessionOptions(params.session_options));
*result =
WrapUnique(new Dataset(dataset, device_mgr.release(), pflr.release(),
flib_def.release(), pool.release()));
return Status::OK();
} // static
Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
// Create an `IteratorContext`, which bundles together the necessary runtime
// support to create and get elements from an iterator.
std::unique_ptr<IteratorContext> ctx;
{
// NOTE(mrry): In the current API, an `IteratorContext` is always initially
// created from an `OpKernelContext*`, so we need to create a fake
// `OpKernelContext` with the appropriate subset of parameters.
OpKernelContext::Params op_params;
op_params.function_library = pflr_->GetFLR("/device:CPU:0");
op_params.device = device_mgr_->ListDevices()[0];
op_params.runner = &runner_;
OpKernelContext op_ctx(&op_params, 0);
IteratorContext::Params params(&op_ctx);
params.function_handle_cache = function_handle_cache_.get();
params.resource_mgr = &resource_mgr_;
params.cancellation_manager = &cancellation_manager_;
ctx = absl::make_unique<IteratorContext>(std::move(params));
}
// Create the iterator from the dataset.
std::unique_ptr<IteratorBase> iterator;
TF_RETURN_IF_ERROR(dataset_->MakeIterator(ctx.get(), /*parent=*/nullptr,
"iterator", &iterator));
*result = WrapUnique(new Iterator(iterator.release(), ctx.release()));
return Status::OK();
}
Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
ProcessFunctionLibraryRuntime* pflr,
FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool)
: dataset_(dataset),
device_mgr_(device_mgr),
flib_def_(flib_def),
pflr_(pflr),
pool_(pool) {
runner_ = [this](std::function<void()> c) { pool_->Schedule(std::move(c)); };
function_handle_cache_ =
absl::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
}
Dataset::~Dataset() { dataset_->Unref(); }
} // namespace standalone
} // namespace data
} // namespace tensorflow

tensorflow/core/data/standalone.h
@@ -0,0 +1,122 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_STANDALONE_H_
#define TENSORFLOW_CORE_DATA_STANDALONE_H_
#include <memory>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace data {
namespace standalone {
// The purpose of the API in this file is to facilitate standalone execution of
// a tf.data input pipeline graph.
//
// The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which
// encapsulate TensorFlow runtime.
//
// The `Dataset` abstraction represents an input pipeline as a collection
// of data sources and a logical plan of transformations that operate over the
// data.
//
// The `Iterator` abstraction represents an execution of an input pipeline that
// can be used to enumerate its elements.
//
// Example usage:
//
// // Create a `Dataset` by running the `graph_def` graph.
// tensorflow::data::standalone::Dataset::Params params;
// std::unique_ptr<tensorflow::data::standalone::Dataset> dataset;
// Status s = tensorflow::data::standalone::Dataset::FromGraph(
// params, graph_def, &dataset);
// if (!s.ok()) { /* error handling */ }
//
// std::unique_ptr<tensorflow::data::standalone::Iterator> iterator;
// s = dataset->MakeIterator(&iterator);
// if (!s.ok()) { /* error handling */ }
//
// bool end_of_input = false;
// while (!end_of_input) {
// std::vector<tensorflow::Tensor> outputs;
// s = iterator->GetNext(&outputs, &end_of_input);
// if (!s.ok()) { /* error handling */ }
// if (!end_of_input) { /* output handling */ }
// }
class Dataset;
// Represents an execution of an input pipeline that can be used to enumerate
// its elements.
class Iterator {
public:
// Returns the next element of the input pipeline (if there is one) and an
// indication of whether the end of the input pipeline has been reached.
Status GetNext(std::vector<Tensor>* outputs, bool* end_of_input);
private:
friend class Dataset;
Iterator(IteratorBase* iterator, IteratorContext* ctx);
std::unique_ptr<IteratorBase> iterator_;
std::unique_ptr<IteratorContext> ctx_;
};
// Represents an input pipeline as a collection of data sources and a logical
// plan of transformations that operate over the data.
class Dataset {
public:
// Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration).
struct Params {
SessionOptions session_options;
};
// Creates a new `Dataset` instance by running the given dataset graph.
static Status FromGraph(Params params, const GraphDef& graph_def,
std::unique_ptr<Dataset>* result);
~Dataset();
// Creates an iterator for this dataset.
Status MakeIterator(std::unique_ptr<Iterator>* result);
private:
Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
ProcessFunctionLibraryRuntime* pflr,
FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool);
DatasetBase* dataset_; // owned
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
std::unique_ptr<thread::ThreadPool> pool_;
std::unique_ptr<FunctionHandleCache> function_handle_cache_;
std::function<void(std::function<void()>)> runner_;
ResourceMgr resource_mgr_;
CancellationManager cancellation_manager_;
};
} // namespace standalone
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_STANDALONE_H_
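
For orientation, a minimal driver built against this header might look like the following sketch. It is illustrative rather than part of this change: `kGraphProto` is a placeholder for any text-format dataset graph (for example, `kRangeGraphProto` from the test below), and error handling is reduced to status propagation.

#include <memory>
#include <vector>

#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"

// Placeholder: a text-format GraphDef describing a dataset.
extern const char* const kGraphProto;

tensorflow::Status RunPipeline() {
  using tensorflow::data::standalone::Dataset;
  using tensorflow::data::standalone::Iterator;

  tensorflow::GraphDef graph_def;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(kGraphProto,
                                                         &graph_def)) {
    return tensorflow::errors::InvalidArgument("Could not parse GraphDef");
  }

  // Build the dataset from the graph, then create an iterator over it.
  std::unique_ptr<Dataset> dataset;
  TF_RETURN_IF_ERROR(Dataset::FromGraph({}, graph_def, &dataset));
  std::unique_ptr<Iterator> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));

  // Drain the pipeline one element at a time.
  bool end_of_input = false;
  while (!end_of_input) {
    std::vector<tensorflow::Tensor> outputs;
    TF_RETURN_IF_ERROR(iterator->GetNext(&outputs, &end_of_input));
    if (!end_of_input) {
      // `outputs` holds the components of one dataset element.
    }
  }
  return tensorflow::Status::OK();
}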

tensorflow/core/data/standalone_test.cc
@@ -0,0 +1,307 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/standalone.h"
#include <memory>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace data {
namespace standalone {
namespace {
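// range(10)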
constexpr const char* const kRangeGraphProto = R"proto(
node {
name: "Const/_0"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 0
}
}
}
}
node {
name: "Const/_1"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 10
}
}
}
}
node {
name: "Const/_2"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 1
}
}
}
}
node {
name: "RangeDataset/_3"
op: "RangeDataset"
input: "Const/_0"
input: "Const/_1"
input: "Const/_2"
attr {
key: "output_shapes"
value { list { shape {} } }
}
attr {
key: "output_types"
value { list { type: DT_INT64 } }
}
}
node {
name: "dataset"
op: "_Retval"
input: "RangeDataset/_3"
attr {
key: "T"
value { type: DT_VARIANT }
}
attr {
key: "index"
value { i: 0 }
}
}
library {}
versions { producer: 96 }
)proto";
// range(10).map(lambda x: x*x)
constexpr const char* const kMapGraphProto = R"proto(
node {
name: "Const/_0"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 0
}
}
}
}
node {
name: "Const/_1"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 10
}
}
}
}
node {
name: "Const/_2"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 1
}
}
}
}
node {
name: "RangeDataset/_3"
op: "RangeDataset"
input: "Const/_0"
input: "Const/_1"
input: "Const/_2"
attr {
key: "output_shapes"
value { list { shape {} } }
}
attr {
key: "output_types"
value { list { type: DT_INT64 } }
}
}
node {
name: "MapDataset/_4"
op: "MapDataset"
input: "RangeDataset/_3"
attr {
key: "Targuments"
value { list {} }
}
attr {
key: "f"
value { func { name: "__inference_Dataset_map_<lambda>_67" } }
}
attr {
key: "output_shapes"
value { list { shape {} } }
}
attr {
key: "output_types"
value { list { type: DT_INT64 } }
}
attr {
key: "preserve_cardinality"
value { b: false }
}
attr {
key: "use_inter_op_parallelism"
value { b: true }
}
}
node {
name: "dataset"
op: "_Retval"
input: "MapDataset/_4"
attr {
key: "T"
value { type: DT_VARIANT }
}
attr {
key: "index"
value { i: 0 }
}
}
library {
function {
signature {
name: "__inference_Dataset_map_<lambda>_67"
input_arg { name: "args_0" type: DT_INT64 }
output_arg { name: "identity" type: DT_INT64 }
}
node_def {
name: "mul"
op: "Mul"
input: "args_0"
input: "args_0"
attr {
key: "T"
value { type: DT_INT64 }
}
}
node_def {
name: "Identity"
op: "Identity"
input: "mul:z:0"
attr {
key: "T"
value { type: DT_INT64 }
}
}
ret { key: "identity" value: "Identity:output:0" }
arg_attr {
key: 0
value {
attr {
key: "_user_specified_name"
value { s: "args_0" }
}
}
}
}
}
versions { producer: 96 min_consumer: 12 }
)proto";
TEST(Scalar, Standalone) {
struct TestCase {
string graph_string;
std::vector<int64> expected_outputs;
};
auto test_cases = {
TestCase{kRangeGraphProto, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}},
TestCase{kMapGraphProto, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81}},
};
for (const auto& test_case : test_cases) {
GraphDef graph_def;
ASSERT_TRUE(
protobuf::TextFormat::ParseFromString(test_case.graph_string, &graph_def));
std::unique_ptr<Dataset> dataset;
auto s = Dataset::FromGraph({}, graph_def, &dataset);
TF_EXPECT_OK(s);
std::unique_ptr<Iterator> iterator;
s = dataset->MakeIterator(&iterator);
TF_EXPECT_OK(s);
bool end_of_input = false;
for (int num_outputs = 0; !end_of_input; ++num_outputs) {
std::vector<tensorflow::Tensor> outputs;
s = iterator->GetNext(&outputs, &end_of_input);
TF_EXPECT_OK(s);
if (!end_of_input) {
EXPECT_EQ(outputs[0].scalar<int64>()(),
test_case.expected_outputs[num_outputs]);
} else {
EXPECT_EQ(test_case.expected_outputs.size(), num_outputs);
}
}
}
}
} // namespace
} // namespace standalone
} // namespace data
} // namespace tensorflow