Avoid baking random seeds into RandomDataset and SamplingDataset.
Instead of resolving nondeterministic seeds during dataset construction, we will now resolve them during iterator construction. This way, each iterator will produce data using independent seeds. I tried to move ShuffleDataset to the new paradigm, but it is not simple because we need to support `reshuffle_each_iteration`, which requires shuffle datasets to remember their nondeterministically chosen seeds across iterations.

PiperOrigin-RevId: 298928444
Change-Id: I72e1cfa0dd9b2584fef624ddd5562aaee8ceb7d4
commit 97f595f5f7 (parent f6793de3fe)
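As a quick illustration of the intended behavior (not part of the change itself), the sketch below assumes TensorFlow 2.x with eager execution, where `tf.data.experimental.RandomDataset` is available:

# Illustrative sketch only; assumes TensorFlow 2.x and eager execution.
import tensorflow as tf

# Without a seed the dataset keeps op-level seeds of (0, 0); each iterator
# resolves its own seeds, so two passes should yield independent sequences.
unseeded = tf.data.experimental.RandomDataset().take(3)
print(list(unseeded.as_numpy_iterator()))
print(list(unseeded.as_numpy_iterator()))  # almost certainly different

# With an explicit seed the seed pair is fixed, so every pass repeats the
# same pseudo-random sequence.
seeded = tf.data.experimental.RandomDataset(seed=42).take(3)
print(list(seeded.as_numpy_iterator()))
print(list(seeded.as_numpy_iterator()))  # identical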
@@ -419,6 +419,13 @@ Status HashGraph(const GraphDef& graph_def, uint64* hash) {
   return Status::OK();
 }
 
+std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds) {
+  if (seeds.first == 0 && seeds.second == 0) {
+    return {random::New64(), random::New64()};
+  }
+  return seeds;
+}
+
 Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
                                     std::function<void()> register_fn,
                                     std::function<void()>* deregister_fn) {
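A rough Python rendering of the helper added above, shown purely for illustration (the function name here is hypothetical; the real logic is the C++ MaybeOverrideSeeds in the hunk):

import random

def maybe_override_seeds(seed, seed2):
  # TensorFlow convention: the pair (0, 0) means "choose seeds
  # nondeterministically"; any other pair is passed through unchanged.
  if seed == 0 and seed2 == 0:
    # Stand-in for random::New64() in the C++ helper.
    return random.getrandbits(63), random.getrandbits(63)
  return seed, seed2

print(maybe_override_seeds(0, 0))   # a fresh random pair on every call
print(maybe_override_seeds(7, 13))  # (7, 13), unchanged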
@@ -188,6 +188,13 @@ class DeterminismPolicy {
   Type determinism_;
 };
 
+// Resolves non-deterministic seeds if necessary, returning either the original
+// seeds or the resolved seeds.
+//
+// By TensorFlow convention, if both seeds are 0, they should be replaced with
+// non-deterministically chosen seeds.
+std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds);
+
 // Helper class for reading data from a vector of VariantTensorData objects.
 class VariantTensorDataReader : public IteratorStateReader {
  public:
@@ -320,6 +320,7 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels/data:dataset_utils",
     ],
 )
 
@@ -364,6 +365,7 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels/data:dataset_utils",
         "//tensorflow/core/kernels/data:name_utils",
     ],
 )
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
@@ -37,7 +38,7 @@ namespace experimental {
 class RandomDatasetOp::Dataset : public DatasetBase {
  public:
   Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
-      : DatasetBase(DatasetContext(ctx)), seed_(seed), seed2_(seed2) {}
+      : DatasetBase(DatasetContext(ctx)), seeds_(seed, seed2) {}
 
   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
@@ -57,8 +58,8 @@ class RandomDatasetOp::Dataset : public DatasetBase {
   }
 
   string DebugString() const override {
-    return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_,
-                           ")::Dataset");
+    return strings::StrCat("RandomDatasetOp(", seeds_.first, ", ",
+                           seeds_.second, ")::Dataset");
   }
 
   int64 Cardinality() const override { return kInfiniteCardinality; }
@@ -69,8 +70,8 @@ class RandomDatasetOp::Dataset : public DatasetBase {
                             Node** output) const override {
     Node* seed = nullptr;
     Node* seed2 = nullptr;
-    TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
-    TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
+    TF_RETURN_IF_ERROR(b->AddScalar(seeds_.first, &seed));
+    TF_RETURN_IF_ERROR(b->AddScalar(seeds_.second, &seed2));
     TF_RETURN_IF_ERROR(b->AddDataset(this, {seed, seed2}, output));
     return Status::OK();
   }
@@ -80,7 +81,8 @@ class RandomDatasetOp::Dataset : public DatasetBase {
    public:
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params),
-          parent_generator_(dataset()->seed_, dataset()->seed2_),
+          seeds_(MaybeOverrideSeeds(dataset()->seeds_)),
+          parent_generator_(seeds_.first, seeds_.second),
           generator_(&parent_generator_) {}
 
     Status GetNextInternal(IteratorContext* ctx,
@@ -112,8 +114,7 @@ class RandomDatasetOp::Dataset : public DatasetBase {
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
                                             &num_random_samples_));
-      parent_generator_ =
-          random::PhiloxRandom(dataset()->seed_, dataset()->seed2_);
+      parent_generator_ = random::PhiloxRandom(seeds_.first, seeds_.second);
       generator_ =
           random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
       generator_.Skip(num_random_samples_);
@@ -127,6 +128,7 @@ class RandomDatasetOp::Dataset : public DatasetBase {
       auto out = generator_();
       return out;
     }
+    const std::pair<int64, int64> seeds_;
     mutex mu_;
     random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
     random::SingleSampleAdapter<random::PhiloxRandom> generator_
@@ -134,8 +136,7 @@ class RandomDatasetOp::Dataset : public DatasetBase {
     int64 num_random_samples_ GUARDED_BY(mu_) = 0;
   };
 
-  const int64 seed_;
-  const int64 seed2_;
+  const std::pair<int64, int64> seeds_;
 };  // RandomDatasetOp::Dataset
 
 RandomDatasetOp::RandomDatasetOp(OpKernelConstruction* ctx)
@@ -148,13 +149,6 @@ void RandomDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
   int64 seed2;
   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
 
-  // By TensorFlow convention, passing 0 for both seeds indicates
-  // that the shuffling should be seeded non-deterministically.
-  if (seed == 0 && seed2 == 0) {
-    seed = random::New64();
-    seed2 = random::New64();
-  }
-
   *output = new Dataset(ctx, seed, seed2);
 }
 namespace {
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/name_utils.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random.h"
@@ -43,8 +44,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
           const DatasetBase* input)
       : DatasetBase(DatasetContext(ctx)),
         rate_(rate),
-        seed_(seed),
-        seed2_(seed2),
+        seeds_(seed, seed2),
         input_(input) {
     input_->Ref();
   }
@@ -55,7 +55,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
       const string& prefix) const override {
     return std::unique_ptr<IteratorBase>(
         new Iterator({this, name_utils::IteratorPrefix(kDatasetType, prefix)},
-                     seed_, seed2_));
+                     seeds_.first, seeds_.second));
   }
 
   const DataTypeVector& output_dtypes() const override {
@@ -84,8 +84,8 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
     Node* seed = nullptr;
     Node* seed2 = nullptr;
     TF_RETURN_IF_ERROR(b->AddScalar(rate_, &rate));
-    TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
-    TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
+    TF_RETURN_IF_ERROR(b->AddScalar(seeds_.first, &seed));
+    TF_RETURN_IF_ERROR(b->AddScalar(seeds_.second, &seed2));
     TF_RETURN_IF_ERROR(
         b->AddDataset(this, {input_graph_node, rate, seed, seed2}, output));
     return Status::OK();
@@ -96,9 +96,8 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
    public:
     explicit Iterator(const Params& params, int64 seed, int64 seed2)
         : DatasetIterator<Dataset>(params),
-          seed_(seed),
-          seed2_(seed2),
-          parent_generator_(seed, seed2),
+          seeds_(MaybeOverrideSeeds({seed, seed2})),
+          parent_generator_(seeds_.first, seeds_.second),
           generator_(&parent_generator_) {}
 
     Status Initialize(IteratorContext* ctx) override {
@@ -140,7 +139,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
    protected:
     void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       // Reset the generators based on the current iterator seeds.
-      parent_generator_ = random::PhiloxRandom(seed_, seed2_);
+      parent_generator_ = random::PhiloxRandom(seeds_.first, seeds_.second);
       generator_ =
           random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
       generator_.Skip(num_random_samples_);
@@ -151,8 +150,10 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
       // Save state needed to restore the random number generators.
       TF_RETURN_IF_ERROR(writer->WriteScalar(
           this->full_name("num_random_samples"), num_random_samples_));
-      TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("seed"), seed_));
-      TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("seed2"), seed2_));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(this->full_name("seed"), seeds_.first));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(this->full_name("seed2"), seeds_.second));
 
       if (input_impl_) {
         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -169,8 +170,11 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
       // Restore the random number generators.
       TF_RETURN_IF_ERROR(reader->ReadScalar(
           this->full_name("num_random_samples"), &num_random_samples_));
-      TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed_));
-      TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed2"), &seed2_));
+      int64 seed;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed));
+      int64 seed2;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed2"), &seed2));
+      seeds_ = {seed, seed2};
       ResetRngs();
 
       if (!reader->Contains(full_name("input_impl_empty"))) {
@@ -182,8 +186,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
     }
 
     mutex mu_;
-    int64 seed_ GUARDED_BY(mu_);
-    int64 seed2_ GUARDED_BY(mu_);
+    std::pair<int64, int64> seeds_ GUARDED_BY(mu_);
 
    private:
     std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
@@ -206,7 +209,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
   };
 
   const float rate_;
-  const int64 seed_, seed2_;
+  const std::pair<int64, int64> seeds_;
   const DatasetBase* const input_;
 };  // SamplingDatasetOp::Dataset
 
@@ -223,10 +226,6 @@ void SamplingDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
 
-  if (seed == 0 && seed2 == 0) {
-    seed = random::New64();
-    seed2 = random::New64();
-  }
   *output = new Dataset(ctx, rate, seed, seed2, input);
 }
 
@@ -540,6 +540,20 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "random_dataset_test",
+    size = "small",
+    srcs = ["random_dataset_test.py"],
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python/data/kernel_tests:test_base",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 py_library(
     name = "reader_dataset_ops_test_base",
     srcs = [
@@ -0,0 +1,55 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.RandomDataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import combinations
+from tensorflow.python.framework import random_seed
+from tensorflow.python.platform import test
+
+
+class RandomDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(global_seed=[None, 10], local_seed=[None, 20])))
+  def testDeterminism(self, global_seed, local_seed):
+    expect_determinism = (global_seed is not None) or (local_seed is not None)
+
+    random_seed.set_random_seed(global_seed)
+    ds = random_ops.RandomDataset(seed=local_seed).take(10)
+
+    output_1 = self.getDatasetOutput(ds)
+    ds = self.graphRoundTrip(ds)
+    output_2 = self.getDatasetOutput(ds)
+
+    if expect_determinism:
+      self.assertEqual(output_1, output_2)
+    else:
+      # Technically not guaranteed since the two randomly-chosen int64 seeds
+      # could match, but that is sufficiently unlikely (1/2^128 with perfect
+      # random number generation).
+      self.assertNotEqual(output_1, output_2)
+
+
+if __name__ == "__main__":
+  test.main()