[tf.data] Make AnonymousSeedGeneratorHandleOp thread-safe.

PiperOrigin-RevId: 308432160
Change-Id: I397bcd46dfe45833b9487f0f3ee7027d2ca09780
This commit is contained in:
Jiri Simsa 2020-04-25 12:32:33 -07:00 committed by TensorFlower Gardener
parent 6d9cde7cc7
commit 7f37206771
2 changed files with 5 additions and 2 deletions

View File

@ -72,6 +72,7 @@ void AnonymousSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
int64 seed2; int64 seed2;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2)); OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
// Seeds will be consumed by `CreateResource`, which is called via `Compute`. // Seeds will be consumed by `CreateResource`, which is called via `Compute`.
mutex_lock l(mu_);
seeds_ = absl::make_unique<RandomSeeds>(seed, seed2); seeds_ = absl::make_unique<RandomSeeds>(seed, seed2);
OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, kReshuffle, &reshuffle_)); OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, kReshuffle, &reshuffle_));
AnonymousResourceOp<SeedGeneratorManager>::Compute(ctx); AnonymousResourceOp<SeedGeneratorManager>::Compute(ctx);
@ -82,7 +83,8 @@ std::string AnonymousSeedGeneratorHandleOp::name() { return kSeedGenerator; }
Status AnonymousSeedGeneratorHandleOp::CreateResource( Status AnonymousSeedGeneratorHandleOp::CreateResource(
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def, OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* lib, SeedGeneratorManager** manager) { FunctionLibraryRuntime* lib, SeedGeneratorManager** manager)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (reshuffle_) { if (reshuffle_) {
*manager = new SeedGeneratorManager(new RandomSeedGenerator(*seeds_)); *manager = new SeedGeneratorManager(new RandomSeedGenerator(*seeds_));
} else { } else {

View File

@ -142,7 +142,8 @@ class AnonymousSeedGeneratorHandleOp
FunctionLibraryRuntime* lib, FunctionLibraryRuntime* lib,
SeedGeneratorManager** manager) override; SeedGeneratorManager** manager) override;
std::unique_ptr<RandomSeeds> seeds_ = nullptr; mutex mu_;
std::unique_ptr<RandomSeeds> seeds_ TF_GUARDED_BY(mu_);
bool reshuffle_; bool reshuffle_;
}; };