[tf.data] Create core/data directory and add standalone dataset
core/data will be a place to put tf.data c++ code that isn't part of a kernel implementation (that goes in core/kernels/data). For example, the tf.data service implementation will go in core/data. PiperOrigin-RevId: 300594272 Change-Id: I9d693ca6edf6d8683aeedfff2d30554bdb3111b5
This commit is contained in:
parent
79d479b9b2
commit
dbf532e329
35
tensorflow/core/data/BUILD
Normal file
35
tensorflow/core/data/BUILD
Normal file
@ -0,0 +1,35 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow/core/platform:build_config.bzl", "tf_protos_all")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
cc_library(
|
||||
name = "standalone",
|
||||
srcs = ["standalone.cc"],
|
||||
hdrs = ["standalone.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:session_options",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "standalone_test",
|
||||
srcs = ["standalone_test.cc"],
|
||||
deps = [
|
||||
":standalone",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
] + tf_protos_all(),
|
||||
)
|
149
tensorflow/core/data/standalone.cc
Normal file
149
tensorflow/core/data/standalone.cc
Normal file
@ -0,0 +1,149 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/data/standalone.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/graph_runner.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace standalone {
|
||||
|
||||
Status Iterator::GetNext(std::vector<Tensor>* outputs, bool* end_of_input) {
|
||||
return iterator_->GetNext(ctx_.get(), outputs, end_of_input);
|
||||
}
|
||||
|
||||
Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx)
|
||||
: iterator_(iterator), ctx_(ctx) {}
|
||||
|
||||
Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
|
||||
std::unique_ptr<Dataset>* result) {
|
||||
Graph graph(OpRegistry::Global());
|
||||
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
|
||||
|
||||
// Instantiate enough of the TF runtime to run `graph` on a single CPU device.
|
||||
auto device_mgr = absl::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||
"CPU", params.session_options, "/job:localhost/replica:0/task:0"));
|
||||
Device* device = device_mgr->ListDevices()[0];
|
||||
// Clone the `FunctionLibraryDefinition` to extend its lifetime extends beyond
|
||||
// the lifetime of `graph`.
|
||||
auto flib_def =
|
||||
absl::make_unique<FunctionLibraryDefinition>(graph.flib_def());
|
||||
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, flib_def.get(), OptimizerOptions{},
|
||||
/*thread_pool=*/nullptr, /*parent=*/nullptr,
|
||||
/*custom_kernel_creator=*/nullptr,
|
||||
/*session_metadata=*/nullptr,
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
string fetch_node = "";
|
||||
for (auto node : graph_def.node()) {
|
||||
if (node.op() == "_Retval") {
|
||||
fetch_node = node.input(0);
|
||||
}
|
||||
}
|
||||
if (fetch_node.empty()) {
|
||||
return errors::NotFound("Failed to find a _Retval op in the given dataset");
|
||||
}
|
||||
|
||||
// Run graph up to `output_node` and extract the `DatasetBase` stored in the
|
||||
// DT_VARIANT output tensor.
|
||||
data::DatasetBase* dataset;
|
||||
{
|
||||
std::vector<Tensor> outputs;
|
||||
GraphRunner graph_runner(device);
|
||||
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"),
|
||||
{}, {fetch_node}, &outputs));
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
|
||||
// NOTE(mrry): The dataset is currently owned by `outputs[0]`, so acquire an
|
||||
// additional reference.
|
||||
dataset->Ref();
|
||||
}
|
||||
|
||||
std::unique_ptr<thread::ThreadPool> pool(
|
||||
NewThreadPoolFromSessionOptions(params.session_options));
|
||||
*result =
|
||||
WrapUnique(new Dataset(dataset, device_mgr.release(), pflr.release(),
|
||||
flib_def.release(), pool.release()));
|
||||
return Status::OK();
|
||||
} // static
|
||||
|
||||
Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
|
||||
// Create an `IteratorContext`, which bundles together the necessary runtime
|
||||
// support to create and get elements from an iterator.
|
||||
std::unique_ptr<IteratorContext> ctx;
|
||||
{
|
||||
// NOTE(mrry): In the current API, an `IteratorContext` is always initially
|
||||
// created from an `OpKernelContext*`, so we need to create a fake
|
||||
// `OpKernelContext` with the appropriate subset of parameters.
|
||||
OpKernelContext::Params op_params;
|
||||
op_params.function_library = pflr_->GetFLR("/device:CPU:0");
|
||||
op_params.device = device_mgr_->ListDevices()[0];
|
||||
op_params.runner = &runner_;
|
||||
OpKernelContext op_ctx(&op_params, 0);
|
||||
IteratorContext::Params params(&op_ctx);
|
||||
params.function_handle_cache = function_handle_cache_.get();
|
||||
params.resource_mgr = &resource_mgr_;
|
||||
params.cancellation_manager = &cancellation_manager_;
|
||||
|
||||
ctx = absl::make_unique<IteratorContext>(std::move(params));
|
||||
}
|
||||
|
||||
// Create the iterator from the dataset.
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(ctx.get(), /*parent=*/nullptr,
|
||||
"iterator", &iterator));
|
||||
|
||||
*result = WrapUnique(new Iterator(iterator.release(), ctx.release()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
|
||||
ProcessFunctionLibraryRuntime* pflr,
|
||||
FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool)
|
||||
: dataset_(dataset),
|
||||
device_mgr_(device_mgr),
|
||||
flib_def_(flib_def),
|
||||
pflr_(pflr),
|
||||
pool_(pool) {
|
||||
runner_ = [this](std::function<void()> c) { pool_->Schedule(std::move(c)); };
|
||||
function_handle_cache_ =
|
||||
absl::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
|
||||
}
|
||||
|
||||
Dataset::~Dataset() { dataset_->Unref(); }
|
||||
|
||||
} // namespace standalone
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
122
tensorflow/core/data/standalone.h
Normal file
122
tensorflow/core/data/standalone.h
Normal file
@ -0,0 +1,122 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DATA_STANDALONE_H_
|
||||
#define TENSORFLOW_CORE_DATA_STANDALONE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/function_handle_cache.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace standalone {
|
||||
|
||||
// The purpose of the API in this file is to facilitate standalone execution of
|
||||
// a tf.data input pipeline graph.
|
||||
//
|
||||
// The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which
|
||||
// encapsulate TensorFlow runtime.
|
||||
//
|
||||
// The `Dataset` abstraction represents an input pipeline as a collection
|
||||
// of data sources and a logical plan of transformations that operate over the
|
||||
// data.
|
||||
//
|
||||
// The `Iterator` abstraction represents an execution of an input pipeline that
|
||||
// can be used to enumerate its elements.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// // Create a `Dataset` by running the `graph_def` graph.
|
||||
// tensorflow::data:standalone::Dataset::Params params;
|
||||
// std::unique_ptr<tensorflow::data::standalone::Dataset> dataset;
|
||||
// Status s = tensorflow::data::standalone::Dataset::FromGraph(
|
||||
// params, graph_def, &dataset);
|
||||
// if (!s.ok()) { /* error handling */ }
|
||||
//
|
||||
// std::unique_ptr<tensorflow::data::standalone::Iterator> iterator;
|
||||
// s = dataset->MakeIterator(&iterator);
|
||||
// if (!s.ok()) { /* error handling */ }
|
||||
//
|
||||
// bool end_of_input = false;
|
||||
// while (!end_of_input) {
|
||||
// std::vector<tensorflow::Tensor> outputs;
|
||||
// s = iterator->GetNext(&outputs, &end_of_input);
|
||||
// if (!s.ok()) { /* error handling */ }
|
||||
// if (!end_of_input) { /* output handling */ }
|
||||
// }
|
||||
|
||||
class Dataset;
|
||||
|
||||
// Represents an execution of an input pipeline that can be used to enumerate
|
||||
// its elements.
|
||||
class Iterator {
|
||||
public:
|
||||
// Returns the next element of the input pipeline (if there is one) and an
|
||||
// indication of whether the end of the input pipeline has been reached.
|
||||
Status GetNext(std::vector<Tensor>* outputs, bool* end_of_input);
|
||||
|
||||
private:
|
||||
friend class Dataset;
|
||||
|
||||
Iterator(IteratorBase* iterator, IteratorContext* ctx);
|
||||
|
||||
std::unique_ptr<IteratorBase> iterator_;
|
||||
std::unique_ptr<IteratorContext> ctx_;
|
||||
};
|
||||
|
||||
// Represents an input pipeline as a collection of data sources and a logical
|
||||
// plan of transformations that operate over the data.
|
||||
class Dataset {
|
||||
public:
|
||||
// Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration).
|
||||
struct Params {
|
||||
SessionOptions session_options;
|
||||
};
|
||||
|
||||
// Creates a new `Dataset` instance by running the given dataset graph.
|
||||
static Status FromGraph(Params params, const GraphDef& graph_def,
|
||||
std::unique_ptr<Dataset>* result);
|
||||
|
||||
~Dataset();
|
||||
|
||||
// Creates an iterator for this dataset.
|
||||
Status MakeIterator(std::unique_ptr<Iterator>* result);
|
||||
|
||||
private:
|
||||
Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
|
||||
ProcessFunctionLibraryRuntime* pflr,
|
||||
FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool);
|
||||
|
||||
DatasetBase* dataset_; // owned
|
||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
|
||||
std::unique_ptr<thread::ThreadPool> pool_;
|
||||
std::unique_ptr<FunctionHandleCache> function_handle_cache_;
|
||||
std::function<void(std::function<void()>)> runner_;
|
||||
ResourceMgr resource_mgr_;
|
||||
CancellationManager cancellation_manager_;
|
||||
};
|
||||
|
||||
} // namespace standalone
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DATA_STANDALONE_H_
|
307
tensorflow/core/data/standalone_test.cc
Normal file
307
tensorflow/core/data/standalone_test.cc
Normal file
@ -0,0 +1,307 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/data/standalone.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace standalone {
|
||||
namespace {
|
||||
|
||||
constexpr const char* const kRangeGraphProto = R"proto(
|
||||
node {
|
||||
name: "Const/_0"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value { type: DT_INT64 }
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT64
|
||||
tensor_shape {}
|
||||
int64_val: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const/_1"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value { type: DT_INT64 }
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT64
|
||||
tensor_shape {}
|
||||
int64_val: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const/_2"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value { type: DT_INT64 }
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT64
|
||||
tensor_shape {}
|
||||
int64_val: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "RangeDataset/_3"
|
||||
op: "RangeDataset"
|
||||
input: "Const/_0"
|
||||
input: "Const/_1"
|
||||
input: "Const/_2"
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value { list { shape {} } }
|
||||
}
|
||||
attr {
|
||||
key: "output_types"
|
||||
value { list { type: DT_INT64 } }
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "dataset"
|
||||
op: "_Retval"
|
||||
input: "RangeDataset/_3"
|
||||
attr {
|
||||
key: "T"
|
||||
value { type: DT_VARIANT }
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value { i: 0 }
|
||||
}
|
||||
}
|
||||
library {}
|
||||
versions { producer: 96 }
|
||||
)proto";
|
||||
|
||||
// range(10).map(lambda x: x*x)
|
||||
constexpr const char* const kMapGraphProto = R"proto(
|
||||
node {
|
||||
name: "Const/_0"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value { type: DT_INT64 }
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT64
|
||||
tensor_shape {}
|
||||
int64_val: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const/_1"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value { type: DT_INT64 }
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT64
|
||||
tensor_shape {}
|
||||
int64_val: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Const/_2"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value { type: DT_INT64 }
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT64
|
||||
tensor_shape {}
|
||||
int64_val: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "RangeDataset/_3"
|
||||
op: "RangeDataset"
|
||||
input: "Const/_0"
|
||||
input: "Const/_1"
|
||||
input: "Const/_2"
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value { list { shape {} } }
|
||||
}
|
||||
attr {
|
||||
key: "output_types"
|
||||
value { list { type: DT_INT64 } }
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "MapDataset/_4"
|
||||
op: "MapDataset"
|
||||
input: "RangeDataset/_3"
|
||||
attr {
|
||||
key: "Targuments"
|
||||
value { list {} }
|
||||
}
|
||||
attr {
|
||||
key: "f"
|
||||
value { func { name: "__inference_Dataset_map_<lambda>_67" } }
|
||||
}
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value { list { shape {} } }
|
||||
}
|
||||
attr {
|
||||
key: "output_types"
|
||||
value { list { type: DT_INT64 } }
|
||||
}
|
||||
attr {
|
||||
key: "preserve_cardinality"
|
||||
value { b: false }
|
||||
}
|
||||
attr {
|
||||
key: "use_inter_op_parallelism"
|
||||
value { b: true }
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "dataset"
|
||||
op: "_Retval"
|
||||
input: "MapDataset/_4"
|
||||
attr {
|
||||
key: "T"
|
||||
value { type: DT_VARIANT }
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value { i: 0 }
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "__inference_Dataset_map_<lambda>_67"
|
||||
input_arg { name: "args_0" type: DT_INT64 }
|
||||
output_arg { name: "identity" type: DT_INT64 }
|
||||
}
|
||||
node_def {
|
||||
name: "mul"
|
||||
op: "Mul"
|
||||
input: "args_0"
|
||||
input: "args_0"
|
||||
attr {
|
||||
key: "T"
|
||||
value { type: DT_INT64 }
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Identity"
|
||||
op: "Identity"
|
||||
input: "mul:z:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value { type: DT_INT64 }
|
||||
}
|
||||
}
|
||||
ret { key: "identity" value: "Identity:output:0" }
|
||||
arg_attr {
|
||||
key: 0
|
||||
value {
|
||||
attr {
|
||||
key: "_user_specified_name"
|
||||
value { s: "args_0" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
versions { producer: 96 min_consumer: 12 }
|
||||
)proto";
|
||||
|
||||
TEST(Scalar, Standalone) {
|
||||
struct TestCase {
|
||||
string graph_string;
|
||||
std::vector<int64> expected_outputs;
|
||||
};
|
||||
auto test_cases = {
|
||||
TestCase{kRangeGraphProto, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}},
|
||||
TestCase{kMapGraphProto, {0, 1, 4, 9, 16, 25, 36, 49, 64, 81}},
|
||||
};
|
||||
for (auto test_case : test_cases) {
|
||||
GraphDef graph_def;
|
||||
protobuf::TextFormat::ParseFromString(test_case.graph_string, &graph_def);
|
||||
std::unique_ptr<Dataset> dataset;
|
||||
auto s = Dataset::FromGraph({}, graph_def, &dataset);
|
||||
TF_EXPECT_OK(s);
|
||||
std::unique_ptr<Iterator> iterator;
|
||||
s = dataset->MakeIterator(&iterator);
|
||||
TF_EXPECT_OK(s);
|
||||
bool end_of_input = false;
|
||||
for (int num_outputs = 0; !end_of_input; ++num_outputs) {
|
||||
std::vector<tensorflow::Tensor> outputs;
|
||||
s = iterator->GetNext(&outputs, &end_of_input);
|
||||
TF_EXPECT_OK(s);
|
||||
if (!end_of_input) {
|
||||
EXPECT_EQ(outputs[0].scalar<int64>()(),
|
||||
test_case.expected_outputs[num_outputs]);
|
||||
} else {
|
||||
EXPECT_EQ(test_case.expected_outputs.size(), num_outputs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace standalone
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user