[tf.data] Add non-determinstic seed code path for RandomSeedGenerator to match TF 1.X behavior.
Fixes: #31706 PiperOrigin-RevId: 263878374
This commit is contained in:
parent
cf74e4b75e
commit
c73f7cfe54
@ -81,8 +81,16 @@ AnonymousRandomSeedGeneratorHandleOp::AnonymousRandomSeedGeneratorHandleOp(
|
||||
: AnonymousResourceOp<RandomSeedGenerator>(ctx) {}
|
||||
|
||||
void AnonymousRandomSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed_));
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2_));
|
||||
int64 seed;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
|
||||
int64 seed2;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
|
||||
if (seed == 0 && seed2 == 0) {
|
||||
seed = random::New64();
|
||||
seed2 = random::New64();
|
||||
}
|
||||
seed_ = seed;
|
||||
seed2_ = seed2;
|
||||
AnonymousResourceOp<RandomSeedGenerator>::Compute(ctx);
|
||||
}
|
||||
|
||||
|
@ -273,5 +273,25 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(results, range(10))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
combinations.combine(tf_api_version=[1, 2], mode="eager"),
|
||||
combinations.combine(reshuffle=[True, False], seed=[None, 42])))
|
||||
def testReshuffleSeparateTransformations(self, reshuffle, seed):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
|
||||
first_epoch = []
|
||||
for elem in dataset.shuffle(
|
||||
10, seed=seed, reshuffle_each_iteration=reshuffle):
|
||||
first_epoch.append(elem.numpy())
|
||||
|
||||
second_epoch = []
|
||||
for elem in dataset.shuffle(
|
||||
10, seed=seed, reshuffle_each_iteration=reshuffle):
|
||||
second_epoch.append(elem.numpy())
|
||||
|
||||
self.assertEqual(first_epoch != second_epoch, seed is None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user