[tf.data] Add non-determinstic seed code path for RandomSeedGenerator to match TF 1.X behavior.

Fixes: #31706
PiperOrigin-RevId: 263878374
This commit is contained in:
Jiri Simsa 2019-08-16 17:16:09 -07:00 committed by TensorFlower Gardener
parent cf74e4b75e
commit c73f7cfe54
2 changed files with 30 additions and 2 deletions

View File

@ -81,8 +81,16 @@ AnonymousRandomSeedGeneratorHandleOp::AnonymousRandomSeedGeneratorHandleOp(
: AnonymousResourceOp<RandomSeedGenerator>(ctx) {}
void AnonymousRandomSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed_));
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2_));
int64 seed;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
int64 seed2;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
if (seed == 0 && seed2 == 0) {
seed = random::New64();
seed2 = random::New64();
}
seed_ = seed;
seed2_ = seed2;
AnonymousResourceOp<RandomSeedGenerator>::Compute(ctx);
}

View File

@ -273,5 +273,25 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertAllEqual(results, range(10))
@combinations.generate(
combinations.times(
combinations.combine(tf_api_version=[1, 2], mode="eager"),
combinations.combine(reshuffle=[True, False], seed=[None, 42])))
def testReshuffleSeparateTransformations(self, reshuffle, seed):
dataset = dataset_ops.Dataset.range(10)
first_epoch = []
for elem in dataset.shuffle(
10, seed=seed, reshuffle_each_iteration=reshuffle):
first_epoch.append(elem.numpy())
second_epoch = []
for elem in dataset.shuffle(
10, seed=seed, reshuffle_each_iteration=reshuffle):
second_epoch.append(elem.numpy())
self.assertEqual(first_epoch != second_epoch, seed is None)
if __name__ == "__main__":
test.main()