[tf.data] Add dataset splitting mechanism.

This CL introduces the concept of a SplitProvider. A SplitProvider produces a sequence of "split" tensors which are interpreted by source datasets to produce dataset elements.

When we initialize an iterator, a SplitProvider can be passed through the IteratorContext to indicate that the iterator should only iterate through the splits provided by the SplitProvider.

This CL adds an optional DatasetBase::MakeSplitProvider method which creates a SplitProvider to create splits for the dataset. For non-source datasets, the proper implementation is generally just to call MakeSplitProvider on their input. To support this reasonable default, we add a `DatasetBase::InputDatasets` method, which produces the input datasets for a dataset. If a dataset implements InputDatasets and has a single input dataset, MakeSplitProvider will delegate to the input by default.

This CL only implements splitting for range_dataset_op; other splitting implementations will come in later CLs. This CL also implements a `ShardingSplitProvider`, to better test the range_dataset_op splitting implementation. `ShardingSplitProvider` will be useful in its own right for implementing an alternative to AutoShard which leverages splitting.

PiperOrigin-RevId: 332056019
Change-Id: I73b9b03cb91ae689c57a72fa6ba0acd092cf4cbe
This commit is contained in:
Andrew Audibert 2020-09-16 11:58:12 -07:00 committed by TensorFlower Gardener
parent 221535f56d
commit 5703a4eed0
12 changed files with 683 additions and 36 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/resource.h"
@ -427,6 +428,31 @@ Status DatasetBase::MakeIterator(
return s;
}
// Default implementation of `MakeSplitProvider`.
//
// Asks the dataset for its inputs via `InputDatasets` and, when there is
// exactly one input, delegates to that input's `MakeSplitProvider`, storing
// the result in `*split_provider`.
//
// Returns UNIMPLEMENTED when `InputDatasets` is not implemented or when the
// dataset does not have exactly one input. Any other error reported by
// `InputDatasets` is returned unchanged.
Status DatasetBase::MakeSplitProvider(
    std::unique_ptr<SplitProvider>* split_provider) const {
  std::vector<const DatasetBase*> inputs;
  const Status s = InputDatasets(&inputs);
  if (errors::IsUnimplemented(s)) {
    return errors::Unimplemented(
        "Cannot create a split provider for dataset of type ", type_string(),
        ", because the dataset implements neither `InputDatasets` nor "
        "`MakeSplitProvider`.");
  }
  // Surface any other failure as-is; otherwise it would be dropped and
  // misreported below as an arity problem on a possibly half-filled vector.
  TF_RETURN_IF_ERROR(s);
  if (inputs.size() != 1) {
    return errors::Unimplemented(
        "Cannot create a split provider for dataset of type ", type_string(),
        ", because the dataset is not unary (having arity ", inputs.size(),
        "), and no custom implementation of `MakeSplitProvider` is defined.");
  }
  return inputs[0]->MakeSplitProvider(split_provider);
}
// Base-class fallback: datasets that do not override `InputDatasets` report
// UNIMPLEMENTED, naming their own type in the message. `inputs` is left
// untouched.
Status DatasetBase::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  const auto& dataset_type = type_string();
  return errors::Unimplemented("InputDatasets not implemented for ",
                               dataset_type);
}
Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
Status status = dataset->AsGraphDefInternal(ctx, this, output);

View File

@ -295,6 +295,24 @@ class Runner {
static Runner* get();
};
// A class which provides a sequence of splits. Iterators created with a split
// provider will iterate over only the splits provided by the split provider.
class SplitProvider {
public:
virtual ~SplitProvider() {}
// Stores the next split in `*split`, setting `*end_of_splits` to indicate
// whether there were any splits left.
virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
// Resets the split provider to its beginning.
virtual Status Reset() = 0;
// Saves the state of this split provider to `writer`. `full_name` maps a
// provider-local key to the key under which the value is written.
virtual Status Save(std::function<std::string(std::string)> full_name,
IteratorStateWriter* writer) = 0;
// Restores the state of this split provider from `reader`. `full_name` maps a
// provider-local key to the key under which the value is read.
virtual Status Restore(std::function<std::string(std::string)> full_name,
IteratorStateReader* reader) = 0;
};
// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
@ -319,6 +337,7 @@ class IteratorContext {
model(ctx->model()),
runner(*(ctx->runner())),
runner_threadpool_size(ctx->runner_threadpool_size()),
split_provider(ctx->split_provider()),
stats_aggregator(ctx->stats_aggregator()),
thread_factory(ctx->thread_factory()),
thread_pool(ctx->thread_pool()) {}
@ -386,6 +405,9 @@ class IteratorContext {
// Number of threads used for executing user-defined functions.
int32 runner_threadpool_size = 0;
// An optional split provider indicating which splits to process.
std::shared_ptr<SplitProvider> split_provider = nullptr;
// The `StatsAggregator` object to record statistics about the iterator.
std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
@ -432,6 +454,10 @@ class IteratorContext {
int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
// Returns the split provider stored in this context's params. The param
// defaults to nullptr, i.e. no split provider.
std::shared_ptr<SplitProvider> split_provider() {
return params_.split_provider;
}
std::shared_ptr<StatsAggregator> stats_aggregator() {
return params_.stats_aggregator;
}
@ -802,6 +828,12 @@ class DatasetBase : public core::RefCounted {
return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
}
// Returns a split provider which partitions the dataset's data into splits
// and provides them in a sequence. The split provider is stored in
// `*split_provider`.
virtual Status MakeSplitProvider(
std::unique_ptr<SplitProvider>* split_provider) const;
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
// dataset.
@ -824,6 +856,14 @@ class DatasetBase : public core::RefCounted {
// A human-readable debug string for this dataset.
virtual string DebugString() const = 0;
// Stores the dataset's input datasets in `*inputs`. The pointers stored in
// `*inputs` are borrowed. The only valid non-ok return status is
// UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset
// subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
// default implementation of `MakeSplitProvider` when there is a single input
// dataset.
virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;
// Indicates whether the dataset depends on any external state which would
// prevent it from being serializable. If so, the method returns
// `errors::FailedPrecondition` with a message that identifies the external

View File

@ -26,6 +26,7 @@ cc_library(
":map_dataset_op",
":name_utils",
":range_dataset_op",
":split_providers",
":take_dataset_op",
":tensor_slice_dataset_op",
"//tensorflow/core:core_cpu",
@ -737,6 +738,31 @@ tf_cc_test(
],
)
# Library built from split_providers.h / split_providers.cc.
cc_library(
name = "split_providers",
srcs = ["split_providers.cc"],
hdrs = ["split_providers.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core/platform:errors",
],
)
# Unit tests for the ":split_providers" library (split_providers_test.cc).
tf_cc_test(
name = "split_providers_test",
size = "small",
srcs = ["split_providers_test.cc"],
deps = [
":dataset_test_base",
":dataset_utils",
":split_providers",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/framework:tensor_testutil",
],
)
tf_kernel_library(
name = "take_dataset_op",
srcs = ["take_dataset_op.cc"],
@ -813,6 +839,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/memory",
],
)

View File

@ -62,6 +62,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/map_dataset_op.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "tensorflow/core/kernels/data/split_providers.h"
#include "tensorflow/core/kernels/data/take_dataset_op.h"
#include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -606,6 +607,47 @@ Status DatasetOpsTestBase::CheckIteratorGetNext(
return Status::OK();
}
// Builds the dataset described by `params`, obtains a split provider from it,
// creates an iterator driven by that split provider, and checks that the
// iterator yields `expected_outputs` in order.
Status DatasetOpsTestBase::CheckSplitProviderFullIteration(
    const DatasetParams& params, const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> test_dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &test_dataset));

  std::unique_ptr<SplitProvider> provider;
  TF_RETURN_IF_ERROR(test_dataset->dataset()->MakeSplitProvider(&provider));

  std::unique_ptr<TestIterator> test_iterator;
  TF_RETURN_IF_ERROR(MakeIterator(params, *test_dataset, std::move(provider),
                                  &test_iterator));

  return CheckIteratorGetNext(test_iterator.get(), expected_outputs,
                              /*compare_order=*/true);
}
// Builds the dataset described by `params`, wraps its split provider in a
// `ShardingSplitProvider` for (`num_shards`, `shard_index`), and runs the
// save/restore check against `expected_outputs` with breakpoints at the
// start, the middle, one before the end, and the end of the expected outputs.
Status DatasetOpsTestBase::CheckSplitProviderShardedIteration(
    const DatasetParams& params, int64 num_shards, int64 shard_index,
    const std::vector<Tensor>& expected_outputs) {
  std::unique_ptr<TestDataset> test_dataset;
  TF_RETURN_IF_ERROR(MakeDataset(params, &test_dataset));

  // Obtain the dataset's own provider, then layer sharding on top of it.
  std::unique_ptr<SplitProvider> base_provider;
  TF_RETURN_IF_ERROR(
      test_dataset->dataset()->MakeSplitProvider(&base_provider));
  std::unique_ptr<SplitProvider> sharded_provider =
      absl::make_unique<ShardingSplitProvider>(num_shards, shard_index,
                                               std::move(base_provider));

  // Re-create the iterator context with the sharded provider installed.
  std::unique_ptr<IteratorContext> ctx;
  TF_RETURN_IF_ERROR(
      CreateIteratorContext(test_dataset->op_kernel_context(), &ctx));
  IteratorContext::Params ctx_params(ctx.get());
  ctx_params.split_provider = std::move(sharded_provider);
  ctx = absl::make_unique<IteratorContext>(ctx_params);

  const int num_outputs = static_cast<int>(expected_outputs.size());
  const std::vector<int> breakpoints = {0, num_outputs / 2, num_outputs - 1,
                                        num_outputs};
  return CheckIteratorSaveAndRestore(
      test_dataset->dataset(), ctx.get(), params.iterator_prefix(),
      expected_outputs, breakpoints, /*compare_order=*/true);
}
Status DatasetOpsTestBase::CheckDatasetNodeName(
const string& expected_dataset_node_name) {
EXPECT_EQ(dataset_->node_name(), expected_dataset_node_name);
@ -658,11 +700,13 @@ Status DatasetOpsTestBase::CheckIteratorPrefix(
}
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
const string& iterator_prefix, const std::vector<Tensor>& expected_outputs,
DatasetBase* dataset, IteratorContext* iterator_ctx,
const std::string& iterator_prefix,
const std::vector<Tensor>& expected_outputs,
const std::vector<int>& breakpoints, bool compare_order) {
std::unique_ptr<IteratorBase> iterator;
TF_RETURN_IF_ERROR(dataset_->MakeIterator(
iterator_ctx_.get(), /*parent=*/nullptr, iterator_prefix, &iterator));
TF_RETURN_IF_ERROR(dataset->MakeIterator(iterator_ctx, /*parent=*/nullptr,
iterator_prefix, &iterator));
std::unique_ptr<SerializationContext> serialization_ctx;
TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_ctx));
bool end_of_sequence = false;
@ -674,33 +718,31 @@ Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
std::vector<const VariantTensorData*> data;
writer.GetData(&data);
VariantTensorDataReader reader(data);
TF_EXPECT_OK(RestoreIterator(iterator_ctx_.get(), &reader, iterator_prefix,
*dataset_, &iterator));
TF_EXPECT_OK(RestoreIterator(iterator_ctx, &reader, iterator_prefix,
*dataset, &iterator));
while (cur_iteration <= breakpoint) {
std::vector<Tensor> next;
TF_RETURN_IF_ERROR(
iterator->GetNext(iterator_ctx_.get(), &next, &end_of_sequence));
iterator->GetNext(iterator_ctx, &next, &end_of_sequence));
out_tensors.insert(out_tensors.end(), next.begin(), next.end());
cur_iteration++;
}
if (dataset_->Cardinality() == kUnknownCardinality) {
continue;
}
if (dataset_->Cardinality() == kInfiniteCardinality ||
breakpoint < dataset_->Cardinality()) {
EXPECT_FALSE(end_of_sequence);
} else {
EXPECT_TRUE(end_of_sequence);
}
}
TF_EXPECT_OK(ExpectEqual(out_tensors, expected_outputs,
/*compare_order=*/compare_order));
return Status::OK();
}
// Convenience overload: runs the save/restore check against the fixture's own
// `dataset_` and `iterator_ctx_`.
Status DatasetOpsTestBase::CheckIteratorSaveAndRestore(
    const std::string& iterator_prefix,
    const std::vector<Tensor>& expected_outputs,
    const std::vector<int>& breakpoints, bool compare_order) {
  IteratorContext* const fixture_ctx = iterator_ctx_.get();
  return CheckIteratorSaveAndRestore(dataset_, fixture_ctx, iterator_prefix,
                                     expected_outputs, breakpoints,
                                     compare_order);
}
Status DatasetOpsTestBase::Initialize(const DatasetParams& dataset_params) {
if (initialized_) {
return errors::Internal(
@ -793,10 +835,14 @@ Status DatasetOpsTestBase::MakeDataset(
Status DatasetOpsTestBase::MakeIterator(
const DatasetParams& dataset_params, const TestDataset& dataset,
std::unique_ptr<SplitProvider> split_provider,
std::unique_ptr<TestIterator>* iterator) {
std::unique_ptr<IteratorContext> iterator_ctx;
TF_RETURN_IF_ERROR(
CreateIteratorContext(dataset.op_kernel_context(), &iterator_ctx));
IteratorContext::Params iterator_params(iterator_ctx.get());
iterator_params.split_provider = std::move(split_provider);
iterator_ctx = absl::make_unique<IteratorContext>(iterator_params);
std::unique_ptr<IteratorBase> iterator_base;
TF_RETURN_IF_ERROR(dataset.dataset()->MakeIterator(
iterator_ctx.get(), /*parent=*/nullptr, dataset_params.iterator_prefix(),
@ -806,6 +852,13 @@ Status DatasetOpsTestBase::MakeIterator(
return Status::OK();
}
// Convenience overload: creates an iterator for `dataset` without a split
// provider by forwarding a null provider to the four-argument overload.
Status DatasetOpsTestBase::MakeIterator(
    const DatasetParams& dataset_params, const TestDataset& dataset,
    std::unique_ptr<TestIterator>* iterator) {
  std::unique_ptr<SplitProvider> no_split_provider;  // Intentionally null.
  return MakeIterator(dataset_params, dataset, std::move(no_split_provider),
                      iterator);
}
Status DatasetOpsTestBase::RunDatasetOp(const DatasetParams& dataset_params,
std::vector<Tensor>* outputs) {
TF_RETURN_IF_ERROR(RunDatasetOp(dataset_params, &dataset_kernel_, &params_,

View File

@ -517,6 +517,12 @@ class DatasetOpsTestBase : public ::testing::Test {
Status MakeDataset(const DatasetParams& dataset_params,
std::unique_ptr<TestDataset>* dataset);
// Creates an iterator for the given dataset, using the specified split
// provider.
Status MakeIterator(const DatasetParams& dataset_params,
const TestDataset& dataset,
std::unique_ptr<SplitProvider> split_provider,
std::unique_ptr<TestIterator>* iterator);
// Creates an iterator for the given dataset.
Status MakeIterator(const DatasetParams& dataset_params,
const TestDataset& dataset,
@ -556,6 +562,18 @@ class DatasetOpsTestBase : public ::testing::Test {
const std::vector<Tensor>& expected_outputs,
bool compare_order);
// Checks that iterating through the dataset using a split provider produces
// the expected outputs.
Status CheckSplitProviderFullIteration(
const DatasetParams& params, const std::vector<Tensor>& expected_outputs);
// Checks that iterating through the dataset using a sharded split provider
// with the given `num_shards` and `shard_index` produces the expected
// outputs.
Status CheckSplitProviderShardedIteration(
const DatasetParams& params, int64 num_shards, int64 shard_index,
const std::vector<Tensor>& expected_outputs);
// Checks `DatasetBase::node_name()`.
Status CheckDatasetNodeName(const string& expected_dataset_node_name);
@ -583,9 +601,14 @@ class DatasetOpsTestBase : public ::testing::Test {
// Checks `IteratorBase::prefix()`.
Status CheckIteratorPrefix(const string& expected_iterator_prefix);
// Checks `IteratorBase::GetNext()`.
Status CheckIteratorSaveAndRestore(
const string& iterator_prefix,
DatasetBase* dataset, IteratorContext* iterator_ctx,
const std::string& iterator_prefix,
const std::vector<Tensor>& expected_outputs,
const std::vector<int>& breakpoints, bool compare_order);
Status CheckIteratorSaveAndRestore(
const std::string& iterator_prefix,
const std::vector<Tensor>& expected_outputs,
const std::vector<int>& breakpoints, bool compare_order);
@ -660,6 +683,7 @@ class DatasetOpsTestBase : public ::testing::Test {
OpKernelContext* const op_context,
std::unique_ptr<IteratorContext>* iterator_context);
// Creates a new iterator context for iterating the dataset.
// Creates a new serialization context for serializing the dataset and
// iterator.
Status CreateSerializationContext(

View File

@ -99,6 +99,12 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
// Appends this dataset's single input, `input_`, to `*inputs`.
Status InputDatasets(
    std::vector<const DatasetBase*>* inputs) const override {
  inputs->emplace_back(input_);
  return Status::OK();
}
Status CheckExternalState() const override {
return input_->CheckExternalState();
}

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
@ -32,7 +33,97 @@ namespace data {
/* static */ constexpr const char* const RangeDatasetOp::kOutputTypes;
/* static */ constexpr const char* const RangeDatasetOp::kOutputShapes;
namespace {
constexpr char kNext[] = "next";
constexpr char kHasSplitProvider[] = "has_split_provider";
constexpr char kSlash[] = "/";
constexpr char kSplitProvider[] = "split_provider";
// Class which produces the elements of `range(start, stop, step)`. Threadsafe:
// every access to the mutable position `next_` is made while holding `mu_`.
class RangeCounter {
 public:
  RangeCounter(int64 start, int64 stop, int64 step)
      : start_(start), stop_(stop), step_(step), next_(start) {}

  // Returns the next value for the counter. Sets `*end_of_counter` to indicate
  // whether the end of the counter was reached; in that case the returned
  // value is a placeholder (-1) and must be ignored.
  int64 GetNext(bool* end_of_counter) {
    mutex_lock l(mu_);
    if ((step_ > 0 && next_ >= stop_) || (step_ < 0 && next_ <= stop_)) {
      *end_of_counter = true;
      return -1;
    }
    *end_of_counter = false;
    // Keep the full 64-bit width: storing the value in a plain `int` would
    // truncate range elements that do not fit in 32 bits.
    const int64 result = next_;
    next_ += step_;
    return result;
  }

  // Returns the value the next successful `GetNext` call would produce,
  // without advancing the counter.
  int64 Peek() const {
    mutex_lock l(mu_);
    return next_;
  }

  // Rewinds the counter to `start`.
  void Reset() {
    mutex_lock l(mu_);
    next_ = start_;
  }

  // Overrides the counter position, e.g. when restoring saved state.
  void SetNext(int64 value) {
    mutex_lock l(mu_);
    next_ = value;
  }

 private:
  const int64 start_;
  const int64 stop_;
  const int64 step_;
  // `mutable` so that the const `Peek` can lock it.
  mutable mutex mu_;
  int64 next_ TF_GUARDED_BY(mu_);
};
} // namespace
// A `SplitProvider` whose splits are the individual outputs of a RangeDataset.
// For instance, range(0, 10, 2) yields the splits {0, 2, 4, 6, 8}.
// Every split is a scalar tensor of type DT_INT64.
class RangeDatasetOp::RangeSplitProvider : public SplitProvider {
 public:
  RangeSplitProvider(int64 start, int64 stop, int64 step)
      : counter_(start, stop, step) {}

  // Advances the underlying counter. When a value is available it is written
  // to `*split` as a DT_INT64 scalar; otherwise `*split` is left untouched and
  // `*end_of_splits` is set to true.
  Status GetNext(Tensor* split, bool* end_of_splits) override {
    const int64 value = counter_.GetNext(end_of_splits);
    if (!*end_of_splits) {
      *split = Tensor(DT_INT64, TensorShape{});
      split->scalar<int64>()() = value;
    }
    return Status::OK();
  }

  // Rewinds the provider to the start of the range.
  Status Reset() override {
    counter_.Reset();
    return Status::OK();
  }

  // Writes the counter's current position under the key `key_name_fn(kNext)`.
  Status Save(std::function<std::string(std::string)> key_name_fn,
              IteratorStateWriter* writer) override {
    return writer->WriteScalar(key_name_fn(kNext), counter_.Peek());
  }

  // Reads the position stored under `key_name_fn(kNext)` and repositions the
  // counter there. The counter is not modified if the read fails.
  Status Restore(std::function<std::string(std::string)> key_name_fn,
                 IteratorStateReader* reader) override {
    int64 restored_next;
    const Status read_status =
        reader->ReadScalar(key_name_fn(kNext), &restored_next);
    if (!read_status.ok()) {
      return read_status;
    }
    counter_.SetNext(restored_next);
    return Status::OK();
  }

 private:
  RangeCounter counter_;
};
class RangeDatasetOp::Dataset : public DatasetBase {
public:
@ -74,6 +165,18 @@ class RangeDatasetOp::Dataset : public DatasetBase {
}
}
// Stores in `*split_provider` a new `RangeSplitProvider` covering this
// dataset's `start_`, `stop_` and `step_`.
Status MakeSplitProvider(
    std::unique_ptr<SplitProvider>* split_provider) const override {
  std::unique_ptr<RangeSplitProvider> range_provider =
      absl::make_unique<RangeSplitProvider>(start_, stop_, step_);
  *split_provider = std::move(range_provider);
  return Status::OK();
}
// Reports that this dataset has no input datasets by emptying `*inputs`.
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
inputs->clear();
return Status::OK();
}
Status CheckExternalState() const override { return Status::OK(); }
protected:
@ -93,24 +196,40 @@ class RangeDatasetOp::Dataset : public DatasetBase {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params) {
next_ = params.dataset->start_;
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
// Picks the element source for this iterator: the split provider carried by
// `ctx` when there is one, otherwise a private `RangeCounter` over the
// dataset's own start/stop/step.
Status Initialize(IteratorContext* ctx) override {
  split_provider_ = ctx->split_provider();
  if (split_provider_ == nullptr) {
    const auto* range = dataset();
    counter_ = absl::make_unique<RangeCounter>(range->start_, range->stop_,
                                               range->step_);
  }
  return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
(dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
*end_of_sequence = true;
return Status::OK();
int64 value;
if (split_provider_ != nullptr) {
Tensor split;
TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
if (*end_of_sequence) {
return Status::OK();
}
value = split.scalar<int64>()();
} else {
value = counter_->GetNext(end_of_sequence);
if (*end_of_sequence) {
return Status::OK();
}
}
out_tensors->reserve(1);
switch (dataset()->output_dtypes()[0]) {
#define HANDLE_TYPE(type) \
case DataTypeToEnum<type>::value: { \
out_tensors->emplace_back(static_cast<type>(next_)); \
out_tensors->emplace_back(static_cast<type>(value)); \
break; \
}
TF_CALL_NUMBER_TYPES(HANDLE_TYPE);
@ -120,9 +239,6 @@ class RangeDatasetOp::Dataset : public DatasetBase {
"Unsupported data type: ",
DataTypeString(dataset()->output_dtypes()[0]));
}
*end_of_sequence = false;
next_ += dataset()->step_;
return Status::OK();
}
@ -134,21 +250,44 @@ class RangeDatasetOp::Dataset : public DatasetBase {
// Saves the iterator's position to `writer`. With a split provider, a
// `kHasSplitProvider` marker is written and the provider saves its own state
// under keys produced by `SplitProviderKeyNameFn`; without one, the counter's
// next value is written under `kNext`.
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
// NOTE(review): this statement uses `next_`, while the branches below use
// `counter_` / `split_provider_`; it looks like the pre-change line kept by
// the diff rendering — confirm against the final file.
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNext), next_));
if (split_provider_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kHasSplitProvider), true));
TF_RETURN_IF_ERROR(split_provider_->Save(
[this](const std::string& key) {
return SplitProviderKeyNameFn(key);
},
writer));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(kNext), counter_->Peek()));
}
return Status::OK();
}
// Restores the iterator's position from `reader`. If the checkpoint contains
// the `kHasSplitProvider` marker, the split provider restores its own state;
// otherwise the counter is repositioned to the value stored under `kNext`.
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
// NOTE(review): this statement uses `next_`, while the branches below use
// `counter_` / `split_provider_`; it looks like the pre-change line kept by
// the diff rendering — confirm against the final file.
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNext), &next_));
if (reader->Contains(full_name(kHasSplitProvider))) {
// NOTE(review): the branch is chosen from the checkpoint contents, but
// `split_provider_` comes from the context in `Initialize` and may be null
// here; that case would dereference a null pointer — consider returning an
// error instead.
TF_RETURN_IF_ERROR(split_provider_->Restore(
[this](const std::string& key) {
return SplitProviderKeyNameFn(key);
},
reader));
} else {
// NOTE(review): `counter_` is only created in `Initialize` when no split
// provider is present, so it may be null if this iterator does have a split
// provider but the checkpoint lacks the marker — confirm and guard.
int64 next;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNext), &next));
counter_->SetNext(next);
}
return Status::OK();
}
// Maps a split-provider-local `key` to this iterator's checkpoint namespace:
// the key is prefixed with "split_provider/" and passed through `full_name`.
std::string SplitProviderKeyNameFn(const std::string& key) {
  const std::string nested_key = absl::StrCat(kSplitProvider, kSlash, key);
  return full_name(nested_key);
}
private:
mutex mu_;
int64 next_ TF_GUARDED_BY(mu_);
std::unique_ptr<RangeCounter> counter_;
std::shared_ptr<SplitProvider> split_provider_;
};
const int64 start_;

View File

@ -36,6 +36,7 @@ class RangeDatasetOp : public DatasetOpKernel {
private:
class Dataset;
class RangeSplitProvider;
DataTypeVector output_types_;
};

View File

@ -141,6 +141,39 @@ TEST_F(RangeDatasetOpTest, ZeroStep) {
tensorflow::error::INVALID_ARGUMENT);
}
// range(0, 10, 3) through a split provider: full iteration must yield
// {0, 3, 6, 9}; shard 1 of 2 must yield {3, 9}.
TEST_F(RangeDatasetOpTest, SplitProviderPositiveStep) {
auto params = RangeDatasetParams(/*start=*/0, /*stop=*/10, /*step=*/3,
/*output_dtypes=*/{DT_INT64});
TF_ASSERT_OK(InitializeRuntime(params));
TF_EXPECT_OK(CheckSplitProviderFullIteration(
params, CreateTensors<int64>(TensorShape({}), {{0}, {3}, {6}, {9}})));
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
params, /*num_shards=*/2, /*shard_index=*/1,
CreateTensors<int64>(TensorShape({}), {{3}, {9}})));
}
// range(10, 0, -3) through a split provider: full iteration must yield
// {10, 7, 4, 1}; shard 0 of 2 must yield {10, 4}.
TEST_F(RangeDatasetOpTest, SplitProviderNegativeStep) {
auto params = RangeDatasetParams(/*start=*/10, /*stop=*/0, /*step=*/-3,
/*output_dtypes=*/{DT_INT64});
TF_ASSERT_OK(InitializeRuntime(params));
TF_EXPECT_OK(CheckSplitProviderFullIteration(
params, CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})));
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
params, /*num_shards=*/2, /*shard_index=*/0,
CreateTensors<int64>(TensorShape({}), {{10}, {4}})));
}
// An empty range (start == stop) must produce no elements, both for full
// iteration and for shard 2 of 3.
TEST_F(RangeDatasetOpTest, SplitProviderEmpty) {
auto params = RangeDatasetParams(/*start=*/0, /*stop=*/0, /*step=*/1,
/*output_dtypes=*/{DT_INT64});
TF_ASSERT_OK(InitializeRuntime(params));
TF_EXPECT_OK(CheckSplitProviderFullIteration(
params, CreateTensors<int64>(TensorShape({}), {})));
TF_EXPECT_OK(CheckSplitProviderShardedIteration(
params, /*num_shards=*/3, /*shard_index=*/2,
CreateTensors<int64>(TensorShape({}), {})));
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/split_providers.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNumToSkip[] = "num_to_skip";
constexpr char kSplitProvider[] = "split_provider";
constexpr char kSlash[] = "/";
} // namespace
// Wraps `split_provider` so that only this shard's splits are produced.
// The first `shard_index` splits of the wrapped provider are skipped before
// the first split is returned.
ShardingSplitProvider::ShardingSplitProvider(
    int64 num_shards, int64 shard_index,
    std::shared_ptr<SplitProvider> split_provider)
    : num_shards_(num_shards),
      shard_index_(shard_index),
      split_provider_(split_provider),
      num_to_skip_(shard_index) {}
// Produces this shard's next split: first discards `num_to_skip_` splits from
// the wrapped provider (these belong to other shards), then returns the one
// after them. Stops early with `*end_of_splits == true` if the wrapped
// provider runs out while skipping.
Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  for (; num_to_skip_ > 0; --num_to_skip_) {
    TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
    if (*end_of_splits) {
      return Status::OK();
    }
  }
  // After returning our split, the next `num_shards_ - 1` splits are not ours.
  num_to_skip_ = num_shards_ - 1;
  return split_provider_->GetNext(split, end_of_splits);
}
// Resets the wrapped provider and, if that succeeds, restores the initial
// skip count (`shard_index_`).
Status ShardingSplitProvider::Reset() {
  mutex_lock l(mu_);
  const Status reset_status = split_provider_->Reset();
  if (!reset_status.ok()) {
    return reset_status;
  }
  num_to_skip_ = shard_index_;
  return Status::OK();
}
// Saves the wrapped provider's state under keys prefixed with
// "split_provider/", followed by this provider's own `num_to_skip_`.
Status ShardingSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  auto nested_name = [&full_name](const std::string& key) {
    return full_name(absl::StrCat(kSplitProvider, kSlash, key));
  };
  TF_RETURN_IF_ERROR(split_provider_->Save(nested_name, writer));
  return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
}
// Restores the wrapped provider's state from keys prefixed with
// "split_provider/", then reads back this provider's own `num_to_skip_`.
Status ShardingSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  auto nested_name = [&full_name](const std::string& key) {
    return full_name(absl::StrCat(kSplitProvider, kSlash, key));
  };
  TF_RETURN_IF_ERROR(split_provider_->Restore(nested_name, reader));
  return reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_);
}
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
// A SplitProvider which wraps another split provider and keeps only the
// splits belonging to one shard: with `index` the zero-based position of a
// split in the wrapped provider's sequence, a split is kept when
// `index % num_shards == shard_index` and dropped otherwise.
// NOTE(review): presumably callers pass 0 <= shard_index < num_shards; this
// is not checked here — confirm.
class ShardingSplitProvider : public SplitProvider {
public:
ShardingSplitProvider(int64 num_shards, int64 shard_index,
std::shared_ptr<SplitProvider> split_provider);
// Stores this shard's next split in `*split`, or sets `*end_of_splits`.
Status GetNext(Tensor* split, bool* end_of_splits) override;
// Resets the wrapped provider and the skip count.
Status Reset() override;
// Saves the wrapped provider's state and the current skip count.
Status Save(std::function<std::string(std::string)> full_name,
IteratorStateWriter* writer) override;
// Restores the wrapped provider's state and the skip count.
Status Restore(std::function<std::string(std::string)> full_name,
IteratorStateReader* reader) override;
private:
const int64 num_shards_;
const int64 shard_index_;
mutex mu_;
// The wrapped provider that supplies the unsharded split sequence.
std::shared_ptr<SplitProvider> split_provider_ TF_GUARDED_BY(mu_);
// Number of wrapped-provider splits to discard before the next kept split.
int64 num_to_skip_ TF_GUARDED_BY(mu_);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SPLIT_PROVIDERS_H_

View File

@ -0,0 +1,167 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/split_providers.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";
constexpr char kSplits[] = "splits";
constexpr char kSplitsSize[] = "splits_size";
// Builds the checkpoint key used by these tests for `name`, of the form
// "<kFullNameRandomHex>|test:<name>".
std::string full_name(const std::string& name) {
  const std::string prefix =
      strings::StrCat(kFullNameRandomHex, kPipe, "test", kColon);
  return strings::StrCat(prefix, name);
}
// Round-trips `split_provider` through its own Save and Restore: the state is
// written to an in-memory writer and immediately read back into the same
// provider.
Status SaveAndRestore(SplitProvider* split_provider) {
  VariantTensorDataWriter state_writer;
  TF_RETURN_IF_ERROR(split_provider->Save(full_name, &state_writer));

  std::vector<const VariantTensorData*> saved_state;
  state_writer.GetData(&saved_state);

  VariantTensorDataReader state_reader(saved_state);
  return split_provider->Restore(full_name, &state_reader);
}
// A split provider that hands out a fixed, pre-defined list of splits in
// order. Used as the wrapped provider in the ShardingSplitProvider tests.
class TestSplitProvider : public SplitProvider {
 public:
  explicit TestSplitProvider(std::vector<Tensor> splits) : splits_(splits) {}

  // Returns the split at the current position and advances, or reports
  // end-of-splits once every pre-defined split has been handed out.
  Status GetNext(Tensor* split, bool* end_of_splits) override {
    if (i_ >= splits_.size()) {
      *end_of_splits = true;
      return Status::OK();
    }
    *end_of_splits = false;
    *split = splits_[i_];
    ++i_;
    return Status::OK();
  }

  // Rewinds to the first split.
  Status Reset() override {
    i_ = 0;
    return Status::OK();
  }

  // Writes the number of splits, each split tensor under "splits[<index>]",
  // and the current position under "i_".
  Status Save(std::function<std::string(std::string)> full_name,
              IteratorStateWriter* writer) override {
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(full_name(kSplitsSize), splits_.size()));
    int index = 0;
    for (const Tensor& tensor : splits_) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          full_name(absl::StrCat(kSplits, "[", index, "]")), tensor));
      ++index;
    }
    return writer->WriteScalar(full_name("i_"), i_);
  }

  // Replaces the split list and position with the values written by `Save`.
  Status Restore(std::function<std::string(std::string)> full_name,
                 IteratorStateReader* reader) override {
    int64 splits_size;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kSplitsSize), &splits_size));
    splits_.clear();
    for (int index = 0; index < splits_size; ++index) {
      splits_.emplace_back();
      TF_RETURN_IF_ERROR(reader->ReadTensor(
          full_name(absl::StrCat(kSplits, "[", index, "]")),
          &splits_.back()));
    }
    return reader->ReadScalar(full_name("i_"), &i_);
  }

 private:
  std::vector<Tensor> splits_;
  // Index of the next split to hand out.
  int64 i_ = 0;
};
// Drains `split_provider` and checks that it yields exactly the tensors in
// `expected`, in order.
//
// Returns a non-OK status if the provider itself fails, or if it yields more
// splits than `expected` contains (checked before indexing so the helper never
// reads past the end of `expected`). Tensor mismatches and a short sequence
// are reported through the usual test expectations.
Status CheckOutput(ShardingSplitProvider& split_provider,
                   std::vector<Tensor> expected) {
  int64 next = 0;
  bool end_of_splits = false;
  while (!end_of_splits) {
    Tensor split;
    TF_RETURN_IF_ERROR(split_provider.GetNext(&split, &end_of_splits));
    if (end_of_splits) {
      break;
    }
    if (next >= static_cast<int64>(expected.size())) {
      return errors::Internal("Split provider produced more than the ",
                              expected.size(), " expected splits.");
    }
    test::ExpectEqual(split, expected[next++]);
  }
  EXPECT_EQ(next, expected.size());
  return Status::OK();
}
// Shard 0 of 2 over splits {0, 1, 2, 3} must yield {0, 2}.
TEST(ShardingSplitProvider, TwoWayShardZero) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
ShardingSplitProvider split_provider(2, 0, base);
TF_EXPECT_OK(CheckOutput(split_provider,
CreateTensors<int64>(TensorShape({}), {{0}, {2}})));
}
// Shard 1 of 2 over splits {0, 1, 2, 3} must yield {1, 3}.
TEST(ShardingSplitProvider, TwoWayShardOne) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}}));
ShardingSplitProvider split_provider(2, 1, base);
TF_EXPECT_OK(CheckOutput(split_provider,
CreateTensors<int64>(TensorShape({}), {{1}, {3}})));
}
// Shard 1 of 3 over splits {0, ..., 5} must yield {1, 4}.
TEST(ShardingSplitProvider, ThreeWayShardOne) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
ShardingSplitProvider split_provider(3, 1, base);
TF_EXPECT_OK(CheckOutput(split_provider,
CreateTensors<int64>(TensorShape({}), {{1}, {4}})));
}
// Shard 1 of 2 over the single split {0} must yield nothing: the only split
// is consumed while skipping to this shard's first position.
TEST(ShardingSplitProvider, Empty) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}}));
ShardingSplitProvider split_provider(2, 1, base);
TF_EXPECT_OK(
CheckOutput(split_provider, CreateTensors<int64>(TensorShape({}), {})));
}
// Saves and restores the provider before every GetNext call and checks that
// shard 1 of 3 over {0, ..., 5} still yields {1, 4} followed by end-of-splits.
TEST(ShardingSplitProvider, SaveAndRestore) {
auto base = std::make_shared<TestSplitProvider>(
CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}}));
std::vector<Tensor> expected =
CreateTensors<int64>(TensorShape({}), {{1}, {4}});
ShardingSplitProvider split_provider(3, 1, base);
for (int i = 0; i < expected.size(); ++i) {
TF_ASSERT_OK(SaveAndRestore(&split_provider));
Tensor split;
// Start from the opposite value so the check proves GetNext wrote the flag.
bool end_of_splits = true;
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
EXPECT_FALSE(end_of_splits);
test::ExpectEqual(split, expected[i]);
}
// One more round trip: the provider must now report that it is exhausted.
TF_ASSERT_OK(SaveAndRestore(&split_provider));
Tensor split;
bool end_of_splits = false;
TF_ASSERT_OK(split_provider.GetNext(&split, &end_of_splits));
EXPECT_TRUE(end_of_splits);
}
} // namespace
} // namespace data
} // namespace tensorflow