[tf.data] Add non-determinstic seed code path for RandomSeedGenerator to match TF 1.X behavior.
Fixes: #31706 PiperOrigin-RevId: 263878374
This commit is contained in:
parent
cf74e4b75e
commit
c73f7cfe54
@ -81,8 +81,16 @@ AnonymousRandomSeedGeneratorHandleOp::AnonymousRandomSeedGeneratorHandleOp(
|
|||||||
: AnonymousResourceOp<RandomSeedGenerator>(ctx) {}
|
: AnonymousResourceOp<RandomSeedGenerator>(ctx) {}
|
||||||
|
|
||||||
void AnonymousRandomSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
|
void AnonymousRandomSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed_));
|
int64 seed;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2_));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
|
||||||
|
int64 seed2;
|
||||||
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
|
||||||
|
if (seed == 0 && seed2 == 0) {
|
||||||
|
seed = random::New64();
|
||||||
|
seed2 = random::New64();
|
||||||
|
}
|
||||||
|
seed_ = seed;
|
||||||
|
seed2_ = seed2;
|
||||||
AnonymousResourceOp<RandomSeedGenerator>::Compute(ctx);
|
AnonymousResourceOp<RandomSeedGenerator>::Compute(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,5 +273,25 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertAllEqual(results, range(10))
|
self.assertAllEqual(results, range(10))
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.times(
|
||||||
|
combinations.combine(tf_api_version=[1, 2], mode="eager"),
|
||||||
|
combinations.combine(reshuffle=[True, False], seed=[None, 42])))
|
||||||
|
def testReshuffleSeparateTransformations(self, reshuffle, seed):
|
||||||
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
|
|
||||||
|
first_epoch = []
|
||||||
|
for elem in dataset.shuffle(
|
||||||
|
10, seed=seed, reshuffle_each_iteration=reshuffle):
|
||||||
|
first_epoch.append(elem.numpy())
|
||||||
|
|
||||||
|
second_epoch = []
|
||||||
|
for elem in dataset.shuffle(
|
||||||
|
10, seed=seed, reshuffle_each_iteration=reshuffle):
|
||||||
|
second_epoch.append(elem.numpy())
|
||||||
|
|
||||||
|
self.assertEqual(first_epoch != second_epoch, seed is None)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user