[tf.data] Add dataset splitting mechanism.
This CL introduces the concept of a SplitProvider. A SplitProvider produces a sequence of "split" tensors which are interpreted by source datasets to produce dataset elements. When we initialize an iterator, a SplitProvider can be passed through the IteratorContext to indicate that the iterator should only iterate through the splits provided by the SplitProvider. This CL adds an optional DatasetBase::MakeSplitIterator method which creates a SplitIterator to create splits for the dataset. For non-source datasets, the proper implementation is generally just to call MakeSplitIterator on their input. To support this reasonable default, we add a `DatasetBase::InputDatasets` method, which produces the input datasets for a dataset. If a dataset implements InputDatasets and has a single input dataset, MakeSplitIterator will call delegate to the input by default. This CL only implements splitting for range_dataset_op; other splitting implementation will come in later CLs. This CL also implements a `ShardingSplitProvider`, to better test the range_dataset_op splitting implementation. `ShardingSplitProvider` will be useful in its own right for implementing an alternative to AutoShard which leverages splitting. PiperOrigin-RevId: 332056019 Change-Id: I73b9b03cb91ae689c57a72fa6ba0acd092cf4cbe
This commit is contained in:
parent
221535f56d
commit
5703a4eed0
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/resource.h"
|
||||
@ -427,6 +428,31 @@ Status DatasetBase::MakeIterator(
|
||||
return s;
|
||||
}
|
||||
|
||||
Status DatasetBase::MakeSplitProvider(
|
||||
std::unique_ptr<SplitProvider>* split_provider) const {
|
||||
std::vector<const DatasetBase*> inputs;
|
||||
Status s = InputDatasets(&inputs);
|
||||
if (errors::IsUnimplemented(s)) {
|
||||
return errors::Unimplemented(
|
||||
"Cannot create a split provider for dataset of type ", type_string(),
|
||||
", because the dataset implements neither `InputDatasets` nor "
|
||||
"`MakeSplitProvider`.");
|
||||
}
|
||||
if (inputs.size() != 1) {
|
||||
return errors::Unimplemented(
|
||||
"Cannot create a split provider for dataset of type ", type_string(),
|
||||
", because the dataset is not unary (having arity ", inputs.size(),
|
||||
"), and no custom implementation of `MakeSplitProvider` is defined.");
|
||||
}
|
||||
return inputs[0]->MakeSplitProvider(split_provider);
|
||||
}
|
||||
|
||||
Status DatasetBase::InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const {
|
||||
return errors::Unimplemented("InputDatasets not implemented for ",
|
||||
type_string());
|
||||
}
|
||||
|
||||
Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
|
||||
SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
|
||||
Status status = dataset->AsGraphDefInternal(ctx, this, output);
|
||||
|
@ -295,6 +295,24 @@ class Runner {
|
||||
static Runner* get();
|
||||
};
|
||||
|
||||
// A class which provides a sequence of splits. Iterators created with a split
|
||||
// provider will iterate over only the splits provided by the split provider.
|
||||
class SplitProvider {
|
||||
public:
|
||||
virtual ~SplitProvider() {}
|
||||
// Stores the next split in `*split`, setting `*end_of_splits` to indicate
|
||||
// whether there were any splits left.
|
||||
virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
|
||||
// Resets the split provider to its beginning.
|
||||
virtual Status Reset() = 0;
|
||||
// Saves the state of this split provider.
|
||||
virtual Status Save(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateWriter* writer) = 0;
|
||||
// Saves the state of this split provider.
|
||||
virtual Status Restore(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateReader* reader) = 0;
|
||||
};
|
||||
|
||||
// A cut-down version of `OpKernelContext` for running computations in
|
||||
// iterators. Note that we cannot simply use `OpKernelContext` here because we
|
||||
// might run computation in an iterator whose lifetime is not nested within the
|
||||
@ -319,6 +337,7 @@ class IteratorContext {
|
||||
model(ctx->model()),
|
||||
runner(*(ctx->runner())),
|
||||
runner_threadpool_size(ctx->runner_threadpool_size()),
|
||||
split_provider(ctx->split_provider()),
|
||||
stats_aggregator(ctx->stats_aggregator()),
|
||||
thread_factory(ctx->thread_factory()),
|
||||
thread_pool(ctx->thread_pool()) {}
|
||||
@ -386,6 +405,9 @@ class IteratorContext {
|
||||
// Number of threads used for executing user-defined functions.
|
||||
int32 runner_threadpool_size = 0;
|
||||
|
||||
// An optional split provider indicating which splits to process.
|
||||
std::shared_ptr<SplitProvider> split_provider = nullptr;
|
||||
|
||||
// The `StatsAggregator` object to record statistics about the iterator.
|
||||
std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
|
||||
|
||||
@ -432,6 +454,10 @@ class IteratorContext {
|
||||
|
||||
int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
|
||||
|
||||
std::shared_ptr<SplitProvider> split_provider() {
|
||||
return params_.split_provider;
|
||||
}
|
||||
|
||||
std::shared_ptr<StatsAggregator> stats_aggregator() {
|
||||
return params_.stats_aggregator;
|
||||
}
|
||||
@ -802,6 +828,12 @@ class DatasetBase : public core::RefCounted {
|
||||
return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
|
||||
}
|
||||
|
||||
// Returns a split provider which partitions the dataset's data into splits
|
||||
// and provides them in a sequence. The split provider is stored in
|
||||
// `*split_provider`.
|
||||
virtual Status MakeSplitProvider(
|
||||
std::unique_ptr<SplitProvider>* split_provider) const;
|
||||
|
||||
// Returns a vector of DataType values, representing the respective
|
||||
// element types of each tuple component in the outputs of this
|
||||
// dataset.
|
||||
@ -824,6 +856,14 @@ class DatasetBase : public core::RefCounted {
|
||||
// A human-readable debug string for this dataset.
|
||||
virtual string DebugString() const = 0;
|
||||
|
||||
// Stores the dataset's input datasets in `*inputs`. The pointers stored in
|
||||
// `*inputs` are borrowed. The only valid non-ok return status is
|
||||
// UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset
|
||||
// subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
|
||||
// default implementation of `MakeSplitProvider` when there is a single input
|
||||
// dataset.
|
||||
virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;
|
||||
|
||||
// Indicates whether the dataset depends on any external state which would
|
||||
// prevent it from being serializable. If so, the method returns
|
||||
// `errors::FailedPrecondition` with a message that identifies the external
|
||||
|
@ -26,6 +26,7 @@ cc_library(
|
||||
":map_dataset_op",
|
||||
":name_utils",
|
||||
":range_dataset_op",
|
||||
":split_providers",
|
||||
":take_dataset_op",
|
||||
":tensor_slice_dataset_op",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -737,6 +738,31 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "split_providers",
|
||||
srcs = ["split_providers.cc"],
|
||||
hdrs = ["split_providers.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "split_providers_test",
|
||||
size = "small",
|
||||
srcs = ["split_providers_test.cc"],
|
||||
deps = [
|
||||
":dataset_test_base",
|
||||
":dataset_utils",
|
||||
":split_providers",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/framework:tensor_testutil",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "take_dataset_op",
|
||||
srcs = ["take_dataset_op.cc"],
|
||||
@ -813,6 +839,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -62,6 +62,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/data/map_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/name_utils.h"
|
||||
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/split_providers.h"
|
||||
#include "tensorflow/core/kernels/data/take_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -606,6 +607,47 @@ Status DatasetOpsTestBase::CheckIteratorGetNext(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::CheckSplitProviderFullIteration(
|
||||
const DatasetParams& params, const std::vector<Tensor>& expected_outputs) {
|
||||
std::unique_ptr<TestDataset> dataset;
|
||||
TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
|
||||
std::unique_ptr<SplitProvider> split_provider;
|
||||
TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProvider(&split_provider));
|
||||
std::unique_ptr<TestIterator> iterator;
|
||||
TF_RETURN_IF_ERROR(
|
||||
MakeIterator(params, *dataset, std::move(split_provider), &iterator));
|
||||
TF_RETURN_IF_ERROR(CheckIteratorGetNext(iterator.get(), expected_outputs,
|
||||
/*compare_order=*/true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::CheckSplitProviderShardedIteration(
|
||||
const DatasetParams& params, int64 num_shards, int64 shard_index,
|
||||
const std::vector<Tensor>& expected_outputs) {
|
||||
std::unique_ptr<TestDataset> dataset;
|
||||
TF_RETURN_IF_ERROR(MakeDataset(params, &dataset));
|
||||
std::unique_ptr<SplitProvider> split_provider;
|
||||
TF_RETURN_IF_ERROR(dataset->dataset()->MakeSplitProvider(&split_provider));
|
||||
split_provider = absl::make_unique<ShardingSplitProvider>(
|
||||
num_shards, shard_index, std::move(split_provider));
|
||||
std::unique_ptr<IteratorContext> iterator_ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateIteratorContext(dataset->op_kernel_context(), &iterator_ctx));
|
||||
IteratorContext::Params iterator_params(iterator_ctx.get());
|
||||
iterator_params.split_provider = std::move(split_provider);
|
||||
iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
|
||||
int mid_breakpoint = expected_outputs.size() / 2;
|
||||
int near_end_breakpoint = expected_outputs.size() - 1;
|
||||
int end_breakpoint = expected_outputs.size();
|
||||
TF_RETURN_IF_ERROR(CheckIteratorSaveAndRestore(
|
||||
dataset->dataset(), iterator_ctx.get(), params.iterator_prefix(),
|
||||
expected_outputs,
|
||||
/*breakpoints=*/
|
||||
{0, mid_breakpoint, near_end_breakpoint, end_breakpoint},
|
||||
/*compare_order=*/true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::CheckDatasetNodeName(
|
||||
const string& expected_dataset_node_name) {
|
||||
EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name);
|
||||
@ -658,11 +700,13 @@ Status DatasetOpsTestBase::CheckIteratorPrefix(
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
|
||||
const string& iterator_prefix, const std::vector<Tensor>& expected_outputs,
|
||||
DatasetBase* dataset, IteratorContext* iterator_ctx,
|
||||
const std::string& iterator_prefix,
|
||||
const std::vector<Tensor>& expected_outputs,
|
||||
const std::vector<int>& breakpoints, bool compare_order) {
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
|
||||
iterator_ctx_.get(), /*parent=*/nullptr, iterator_prefix, &iterator));
|
||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr,
|
||||
iterator_prefix, &iterator));
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
|
||||
bool end_of_sequence = false;
|
||||
@ -674,33 +718,31 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
|
||||
std::vector<const VariantTensorData*> data;
|
||||
writer.GetData(&data);
|
||||
VariantTensorDataReader reader(data);
|
||||
TF_EXPECT_OK(RestoreIterator(iterator_ctx_.get(), &reader, iterator_prefix,
|
||||
*dataset_, &iterator));
|
||||
TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix,
|
||||
*dataset, &iterator));
|
||||
|
||||
while (cur_iteration <= breakpoint) {
|
||||
std::vector<Tensor> next;
|
||||
TF_RETURN_IF_ERROR(
|
||||
iterator->GetNext(iterator_ctx_.get(), &next, &end_of_sequence));
|
||||
iterator->GetNext(iterator_ctx, &next, &end_of_sequence));
|
||||
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
|
||||
cur_iteration++;
|
||||
}
|
||||
|
||||
if (dataset_->Cardinality() == kUnknownCardinality) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (dataset_->Cardinality() == kInfiniteCardinality ||
|
||||
breakpoint < dataset_->Cardinality()) {
|
||||
EXPECT_FALSE(end_of_sequence);
|
||||
} else {
|
||||
EXPECT_TRUE(end_of_sequence);
|
||||
}
|
||||
}
|
||||
TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
|
||||
/*compare_order=*/compare_order));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
|
||||
const std::string& iterator_prefix,
|
||||
const std::vector<Tensor>& expected_outputs,
|
||||
const std::vector<int>& breakpoints, bool compare_order) {
|
||||
return CheckIteratorSaveAndRestore(dataset_, iterator_ctx_.get(),
|
||||
iterator_prefix, expected_outputs,
|
||||
breakpoints, compare_order);
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
|
||||
if (initialized_) {
|
||||
return errors::Internal(
|
||||
@ -793,10 +835,14 @@ Status DatasetOpsTestBase::MakeDataset(
|
||||
|
||||
Status DatasetOpsTestBase::MakeIterator(
|
||||
const DatasetParams& dataset_params, const TestDataset& dataset,
|
||||
std::unique_ptr<SplitProvider> split_provider,
|
||||
std::unique_ptr<TestIterator>* iterator) {
|
||||
std::unique_ptr<IteratorContext> iterator_ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
|
||||
IteratorContext::Params iterator_params(iterator_ctx.get());
|
||||
iterator_params.split_provider = std::move(split_provider);
|
||||
iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
|
||||
std::unique_ptr<IteratorBase> iterator_base;
|
||||
TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
|
||||
iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
|
||||
@ -806,6 +852,13 @@ Status DatasetOpsTestBase::MakeIterator(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::MakeIterator(
|
||||
const DatasetParams& dataset_params, const TestDataset& dataset,
|
||||
std::unique_ptr<TestIterator>* iterator) {
|
||||
return MakeIterator(dataset_params, dataset, /*split_provider=*/nullptr,
|
||||
iterator);
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params,
|
||||
std::vector<Tensor>* outputs) {
|
||||
TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, ¶ms_,
|
||||
|
@ -517,6 +517,12 @@ class DatasetOpsTestBase : public ::testing::Test {
|
||||
Status MakeDataset(const DatasetParams& dataset_params,
|
||||
std::unique_ptr<TestDataset>* dataset);
|
||||
|
||||
// Creates an iterator for the given dataset, using the specified split
|
||||
// provider.
|
||||
Status MakeIterator(const DatasetParams& dataset_params,
|
||||
const TestDataset& dataset,
|
||||
std::unique_ptr<SplitProvider> split_provider,
|
||||
std::unique_ptr<TestIterator>* iterator);
|
||||
// Creates an iterator for the given dataset.
|
||||
Status MakeIterator(const DatasetParams& dataset_params,
|
||||
const TestDataset& dataset,
|
||||
@ -556,6 +562,18 @@ class DatasetOpsTestBase : public ::testing::Test {
|
||||
const std::vector<Tensor>& expected_outputs,
|
||||
bool compare_order);
|
||||
|
||||
// Checks that iterating through the dataset using a split provider produces
|
||||
// the expected outputs.
|
||||
Status CheckSplitProviderFullIteration(
|
||||
const DatasetParams& params, const std::vector<Tensor>& expected_outputs);
|
||||
|
||||
// Checks that iterating through the dataset using a sharded split provider
|
||||
// with the given `num_shards` and `shard_index` produces the expected
|
||||
// outputs.
|
||||
Status CheckSplitProviderShardedIteration(
|
||||
const DatasetParams& params, int64 num_shards, int64 shard_index,
|
||||
const std::vector<Tensor>& expected_outputs);
|
||||
|
||||
// Checks `DatasetBase::node_name()`.
|
||||
Status CheckDatasetNodeName(const string& expected_dataset_node_name);
|
||||
|
||||
@ -583,9 +601,14 @@ class DatasetOpsTestBase : public ::testing::Test {
|
||||
// Checks `IteratorBase::prefix()`.
|
||||
Status CheckIteratorPrefix(const string& expected_iterator_prefix);
|
||||
|
||||
// Checks `IteratorBase::GetNext()`.
|
||||
Status CheckIteratorSaveAndRestore(
|
||||
const string& iterator_prefix,
|
||||
DatasetBase* dataset, IteratorContext* iterator_ctx,
|
||||
const std::string& iterator_prefix,
|
||||
const std::vector<Tensor>& expected_outputs,
|
||||
const std::vector<int>& breakpoints, bool compare_order);
|
||||
|
||||
Status CheckIteratorSaveAndRestore(
|
||||
const std::string& iterator_prefix,
|
||||
const std::vector<Tensor>& expected_outputs,
|
||||
const std::vector<int>& breakpoints, bool compare_order);
|
||||
|
||||
@ -660,6 +683,7 @@ class DatasetOpsTestBase : public ::testing::Test {
|
||||
OpKernelContext* const op_context,
|
||||
std::unique_ptr<IteratorContext>* iterator_context);
|
||||
|
||||
// Creates a new iterator context for iterating the dataset.
|
||||
// Creates a new serialization context for serializing the dataset and
|
||||
// iterator.
|
||||
Status CreateSerializationContext(
|
||||
|
@ -99,6 +99,12 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -32,7 +33,97 @@ namespace data {
|
||||
/* static */ constexpr const char* const RangeDatasetOp::kOutputTypes;
|
||||
/* static */ constexpr const char* const RangeDatasetOp::kOutputShapes;
|
||||
|
||||
namespace {
|
||||
constexpr char kNext[] = "next";
|
||||
constexpr char kHasSplitProvider[] = "has_split_provider";
|
||||
constexpr char kSlash[] = "/";
|
||||
constexpr char kSplitProvider[] = "split_provider";
|
||||
|
||||
// Class which produces the elements of `range(start, stop, step)`. Threadsafe.
|
||||
class RangeCounter {
|
||||
public:
|
||||
RangeCounter(int64 start, int64 stop, int64 step)
|
||||
: start_(start), stop_(stop), step_(step), next_(start) {}
|
||||
|
||||
// Returns the next value for the counter. Sets `*end_of_counter` to indicate
|
||||
// whether the end of the counter was reached.
|
||||
int64 GetNext(bool* end_of_counter) {
|
||||
mutex_lock l(mu_);
|
||||
if ((step_ > 0 && next_ >= stop_) || (step_ < 0 && next_ <= stop_)) {
|
||||
*end_of_counter = true;
|
||||
return -1;
|
||||
}
|
||||
*end_of_counter = false;
|
||||
int result = next_;
|
||||
next_ += step_;
|
||||
return result;
|
||||
}
|
||||
|
||||
int64 Peek() const {
|
||||
mutex_lock l(mu_);
|
||||
return next_;
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
mutex_lock l(mu_);
|
||||
next_ = start_;
|
||||
}
|
||||
|
||||
void SetNext(int64 value) {
|
||||
mutex_lock l(mu_);
|
||||
next_ = value;
|
||||
}
|
||||
|
||||
private:
|
||||
const int64 start_;
|
||||
const int64 stop_;
|
||||
const int64 step_;
|
||||
mutable mutex mu_;
|
||||
int64 next_ TF_GUARDED_BY(mu_);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Split provider where splits are individual outputs from RangeDataset.
|
||||
// For example, the "splits" of range(0, 10, 2) will be {0, 2, 4, 6, 8}.
|
||||
// The split tensors are scalars of type DT_INT64.
|
||||
class RangeDatasetOp::RangeSplitProvider : public SplitProvider {
|
||||
public:
|
||||
RangeSplitProvider(int64 start, int64 stop, int64 step)
|
||||
: counter_(start, stop, step) {}
|
||||
|
||||
Status GetNext(Tensor* split, bool* end_of_splits) override {
|
||||
int64 next = counter_.GetNext(end_of_splits);
|
||||
if (*end_of_splits) {
|
||||
return Status::OK();
|
||||
}
|
||||
*split = Tensor(DT_INT64, TensorShape{});
|
||||
split->scalar<int64>()() = next;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Reset() override {
|
||||
counter_.Reset();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Save(std::function<std::string(std::string)> key_name_fn,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(key_name_fn(kNext), counter_.Peek()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Restore(std::function<std::string(std::string)> key_name_fn,
|
||||
IteratorStateReader* reader) override {
|
||||
int64 next;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(key_name_fn(kNext), &next));
|
||||
counter_.SetNext(next);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
RangeCounter counter_;
|
||||
};
|
||||
|
||||
class RangeDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
@ -74,6 +165,18 @@ class RangeDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status MakeSplitProvider(
|
||||
std::unique_ptr<SplitProvider>* split_provider) const override {
|
||||
*split_provider =
|
||||
absl::make_unique<RangeSplitProvider>(start_, stop_, step_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->clear();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
@ -93,24 +196,40 @@ class RangeDatasetOp::Dataset : public DatasetBase {
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params) {
|
||||
next_ = params.dataset->start_;
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
split_provider_ = ctx->split_provider();
|
||||
if (!split_provider_) {
|
||||
counter_ = absl::make_unique<RangeCounter>(
|
||||
dataset()->start_, dataset()->stop_, dataset()->step_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
|
||||
(dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
int64 value;
|
||||
if (split_provider_ != nullptr) {
|
||||
Tensor split;
|
||||
TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
|
||||
if (*end_of_sequence) {
|
||||
return Status::OK();
|
||||
}
|
||||
value = split.scalar<int64>()();
|
||||
} else {
|
||||
value = counter_->GetNext(end_of_sequence);
|
||||
if (*end_of_sequence) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
out_tensors->reserve(1);
|
||||
switch (dataset()->output_dtypes()[0]) {
|
||||
#define HANDLE_TYPE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
out_tensors->emplace_back(static_cast<type>(next_)); \
|
||||
out_tensors->emplace_back(static_cast<type>(value)); \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_NUMBER_TYPES(HANDLE_TYPE);
|
||||
@ -120,9 +239,6 @@ class RangeDatasetOp::Dataset : public DatasetBase {
|
||||
"Unsupported data type: ",
|
||||
DataTypeString(dataset()->output_dtypes()[0]));
|
||||
}
|
||||
*end_of_sequence = false;
|
||||
next_ += dataset()->step_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -134,21 +250,44 @@ class RangeDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNext), next_));
|
||||
if (split_provider_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kHasSplitProvider), true));
|
||||
TF_RETURN_IF_ERROR(split_provider_->Save(
|
||||
[this](const std::string& key) {
|
||||
return SplitProviderKeyNameFn(key);
|
||||
},
|
||||
writer));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kNext), counter_->Peek()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNext), &next_));
|
||||
if (reader->Contains(full_name(kHasSplitProvider))) {
|
||||
TF_RETURN_IF_ERROR(split_provider_->Restore(
|
||||
[this](const std::string& key) {
|
||||
return SplitProviderKeyNameFn(key);
|
||||
},
|
||||
reader));
|
||||
} else {
|
||||
int64 next;
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNext), &next));
|
||||
counter_->SetNext(next);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string SplitProviderKeyNameFn(const std::string& key) {
|
||||
return full_name(absl::StrCat(kSplitProvider, kSlash, key));
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
int64 next_ TF_GUARDED_BY(mu_);
|
||||
std::unique_ptr<RangeCounter> counter_;
|
||||
std::shared_ptr<SplitProvider> split_provider_;
|
||||
};
|
||||
|
||||
const int64 start_;
|
||||
|
@ -36,6 +36,7 @@ class RangeDatasetOp : public DatasetOpKernel {
|
||||
|
||||
private:
|
||||
class Dataset;
|
||||
class RangeSplitProvider;
|
||||
DataTypeVector output_types_;
|
||||
};
|
||||
|
||||
|
@ -141,6 +141,39 @@ TEST_F(RangeDatasetOpTest, ZeroStep) {
|
||||
tensorflow::error::INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
TEST_F(RangeDatasetOpTest, SplitProviderPositiveStep) {
|
||||
auto params = RangeDatasetParams(/*start=*/0, /*stop=*/10, /*step=*/3,
|
||||
/*output_dtypes=*/{DT_INT64});
|
||||
TF_ASSERT_OK(InitializeRuntime(params));
|
||||
TF_EXPECT_OK(CheckSplitProviderFullIteration(
|
||||
params, CreateTensors<int64>(TensorShape({}), {{0}, {3}, {6}, {9}})));
|
||||
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
|
||||
params, /*num_shards=*/2, /*shard_index=*/1,
|
||||
CreateTensors<int64>(TensorShape({}), {{3}, {9}})));
|
||||
}
|
||||
|
||||
TEST_F(RangeDatasetOpTest, SplitProviderNegativeStep) {
|
||||
auto params = RangeDatasetParams(/*start=*/10, /*stop=*/0, /*step=*/-3,
|
||||
/*output_dtypes=*/{DT_INT64});
|
||||
TF_ASSERT_OK(InitializeRuntime(params));
|
||||
TF_EXPECT_OK(CheckSplitProviderFullIteration(
|
||||
params, CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})));
|
||||
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
|
||||
params, /*num_shards=*/2, /*shard_index=*/0,
|
||||
CreateTensors<int64>(TensorShape({}), {{10}, {4}})));
|
||||
}
|
||||
|
||||
TEST_F(RangeDatasetOpTest, SplitProviderEmpty) {
|
||||
auto params = RangeDatasetParams(/*start=*/0, /*stop=*/0, /*step=*/1,
|
||||
/*output_dtypes=*/{DT_INT64});
|
||||
TF_ASSERT_OK(InitializeRuntime(params));
|
||||
TF_EXPECT_OK(CheckSplitProviderFullIteration(
|
||||
params, CreateTensors<int64>(TensorShape({}), {})));
|
||||
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
|
||||
params, /*num_shards=*/3, /*shard_index=*/2,
|
||||
CreateTensors<int64>(TensorShape({}), {})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
83
tensorflow/core/kernels/data/split_providers.cc
Normal file
83
tensorflow/core/kernels/data/split_providers.cc
Normal file
@ -0,0 +1,83 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/split_providers.h"
|
||||
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
namespace {
|
||||
constexpr char kNumToSkip[] = "num_to_skip";
|
||||
constexpr char kSplitProvider[] = "split_provider";
|
||||
constexpr char kSlash[] = "/";
|
||||
} // namespace
|
||||
|
||||
ShardingSplitProvider::ShardingSplitProvider(
|
||||
int64 num_shards, int64 shard_index,
|
||||
std::shared_ptr<SplitProvider> split_provider)
|
||||
: num_shards_(num_shards),
|
||||
shard_index_(shard_index),
|
||||
split_provider_(split_provider),
|
||||
num_to_skip_(shard_index_) {}
|
||||
|
||||
Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
|
||||
mutex_lock l(mu_);
|
||||
while (num_to_skip_ > 0) {
|
||||
TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
|
||||
if (*end_of_splits) {
|
||||
return Status::OK();
|
||||
}
|
||||
num_to_skip_--;
|
||||
}
|
||||
num_to_skip_ = num_shards_ - 1;
|
||||
TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardingSplitProvider::Reset() {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(split_provider_->Reset());
|
||||
num_to_skip_ = shard_index_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardingSplitProvider::Save(
|
||||
std::function<std::string(std::string)> full_name,
|
||||
IteratorStateWriter* writer) {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(split_provider_->Save(
|
||||
[&](const std::string& key) {
|
||||
return full_name(absl::StrCat(kSplitProvider, kSlash, key));
|
||||
},
|
||||
writer));
|
||||
return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
|
||||
}
|
||||
|
||||
Status ShardingSplitProvider::Restore(
|
||||
std::function<std::string(std::string)> full_name,
|
||||
IteratorStateReader* reader) {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(split_provider_->Restore(
|
||||
[&](const std::string& key) {
|
||||
return full_name(absl::StrCat(kSplitProvider, kSlash, key));
|
||||
},
|
||||
reader));
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
48
tensorflow/core/kernels/data/split_providers.h
Normal file
48
tensorflow/core/kernels/data/split_providers.h
Normal file
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// A SplitProvider which wraps another split provider, but drops all splits
|
||||
// where `index != shard_index % num_shards`
|
||||
class ShardingSplitProvider : public SplitProvider {
|
||||
public:
|
||||
ShardingSplitProvider(int64 num_shards, int64 shard_index,
|
||||
std::shared_ptr<SplitProvider> split_provider);
|
||||
|
||||
Status GetNext(Tensor* split, bool* end_of_splits) override;
|
||||
Status Reset() override;
|
||||
Status Save(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateWriter* writer) override;
|
||||
Status Restore(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateReader* reader) override;
|
||||
|
||||
private:
|
||||
const int64 num_shards_;
|
||||
const int64 shard_index_;
|
||||
mutex mu_;
|
||||
std::shared_ptr<SplitProvider> split_provider_ TF_GUARDED_BY(mu_);
|
||||
int64 num_to_skip_ TF_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
|
167
tensorflow/core/kernels/data/split_providers_test.cc
Normal file
167
tensorflow/core/kernels/data/split_providers_test.cc
Normal file
@ -0,0 +1,167 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/split_providers.h"
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
|
||||
constexpr char kPipe[] = "|";
|
||||
constexpr char kColon[] = ":";
|
||||
constexpr char kSplits[] = "splits";
|
||||
constexpr char kSplitsSize[] = "splits_size";
|
||||
|
||||
std::string full_name(const std::string& name) {
|
||||
return strings::StrCat(kFullNameRandomHex, kPipe, "test", kColon, name);
|
||||
}
|
||||
|
||||
Status SaveAndRestore(SplitProvider* split_provider) {
|
||||
VariantTensorDataWriter writer;
|
||||
TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer));
|
||||
std::vector<const VariantTensorData*> variants;
|
||||
writer.GetData(&variants);
|
||||
VariantTensorDataReader reader(variants);
|
||||
TF_RETURN_IF_ERROR(split_provider->Restore(full_name, &reader));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A split provider that provides pre-defined splits.
|
||||
class TestSplitProvider : public SplitProvider {
|
||||
public:
|
||||
explicit TestSplitProvider(std::vector<Tensor> splits) : splits_(splits) {}
|
||||
|
||||
Status GetNext(Tensor* split, bool* end_of_splits) override {
|
||||
*end_of_splits = i_ >= splits_.size();
|
||||
if (*end_of_splits) {
|
||||
return Status::OK();
|
||||
}
|
||||
*split = splits_[i_++];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Reset() override {
|
||||
i_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Save(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kSplitsSize), splits_.size()));
|
||||
for (int i = 0; i < splits_.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(absl::StrCat(kSplits, "[", i, "]")), splits_[i]));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i_"), i_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Restore(std::function<std::string(std::string)> full_name,
|
||||
IteratorStateReader* reader) override {
|
||||
int64 splits_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name(kSplitsSize), &splits_size));
|
||||
splits_.clear();
|
||||
for (int i = 0; i < splits_size; ++i) {
|
||||
splits_.emplace_back();
|
||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||
full_name(absl::StrCat(kSplits, "[", i, "]")), &splits_.back()));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i_"), &i_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Tensor> splits_;
|
||||
int64 i_ = 0;
|
||||
};
|
||||
|
||||
Status CheckOutput(ShardingSplitProvider& split_provider,
|
||||
std::vector<Tensor> expected) {
|
||||
int64 next = 0;
|
||||
bool end_of_splits = false;
|
||||
while (!end_of_splits) {
|
||||
Tensor split;
|
||||
TF_RETURN_IF_ERROR(split_provider.GetNext(&split, &end_of_splits));
|
||||
if (!end_of_splits) {
|
||||
test::ExpectEqual(split, expected[next++]);
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(next, expected.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, TwoWayShardZero) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
|
||||
ShardingSplitProvider split_provider(2, 0, base);
|
||||
TF_EXPECT_OK(CheckOutput(split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {2}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, TwoWayShardOne) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
|
||||
ShardingSplitProvider split_provider(2, 1, base);
|
||||
TF_EXPECT_OK(CheckOutput(split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {3}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, ThreeWayShardOne) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
|
||||
ShardingSplitProvider split_provider(3, 1, base);
|
||||
TF_EXPECT_OK(CheckOutput(split_provider,
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {4}})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, Empty) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}}));
|
||||
ShardingSplitProvider split_provider(2, 1, base);
|
||||
TF_EXPECT_OK(
|
||||
CheckOutput(split_provider, CreateTensors<int64>(TensorShape({}), {})));
|
||||
}
|
||||
|
||||
TEST(ShardingSplitProvider, SaveAndRestore) {
|
||||
auto base = std::make_shared<TestSplitProvider>(
|
||||
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
|
||||
std::vector<Tensor> expected =
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {4}});
|
||||
ShardingSplitProvider split_provider(3, 1, base);
|
||||
for (int i = 0; i < expected.size(); ++i) {
|
||||
TF_ASSERT_OK(SaveAndRestore(&split_provider));
|
||||
Tensor split;
|
||||
bool end_of_splits = true;
|
||||
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
|
||||
EXPECT_FALSE(end_of_splits);
|
||||
test::ExpectEqual(split, expected[i]);
|
||||
}
|
||||
TF_ASSERT_OK(SaveAndRestore(&split_provider));
|
||||
Tensor split;
|
||||
bool end_of_splits = false;
|
||||
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
|
||||
EXPECT_TRUE(end_of_splits);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user