[tf.data] Make AnonymousSeedGeneratorHandleOp
thread-safe.
PiperOrigin-RevId: 308432160 Change-Id: I397bcd46dfe45833b9487f0f3ee7027d2ca09780
This commit is contained in:
parent
6d9cde7cc7
commit
7f37206771
@ -72,6 +72,7 @@ void AnonymousSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
|
|||||||
int64 seed2;
|
int64 seed2;
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
|
||||||
// Seeds will be consumed by `CreateResource`, which is called via `Compute`.
|
// Seeds will be consumed by `CreateResource`, which is called via `Compute`.
|
||||||
|
mutex_lock l(mu_);
|
||||||
seeds_ = absl::make_unique<RandomSeeds>(seed, seed2);
|
seeds_ = absl::make_unique<RandomSeeds>(seed, seed2);
|
||||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, kReshuffle, &reshuffle_));
|
OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, kReshuffle, &reshuffle_));
|
||||||
AnonymousResourceOp<SeedGeneratorManager>::Compute(ctx);
|
AnonymousResourceOp<SeedGeneratorManager>::Compute(ctx);
|
||||||
@ -82,7 +83,8 @@ std::string AnonymousSeedGeneratorHandleOp::name() { return kSeedGenerator; }
|
|||||||
Status AnonymousSeedGeneratorHandleOp::CreateResource(
|
Status AnonymousSeedGeneratorHandleOp::CreateResource(
|
||||||
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
||||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
||||||
FunctionLibraryRuntime* lib, SeedGeneratorManager** manager) {
|
FunctionLibraryRuntime* lib, SeedGeneratorManager** manager)
|
||||||
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
if (reshuffle_) {
|
if (reshuffle_) {
|
||||||
*manager = new SeedGeneratorManager(new RandomSeedGenerator(*seeds_));
|
*manager = new SeedGeneratorManager(new RandomSeedGenerator(*seeds_));
|
||||||
} else {
|
} else {
|
||||||
|
@ -142,7 +142,8 @@ class AnonymousSeedGeneratorHandleOp
|
|||||||
FunctionLibraryRuntime* lib,
|
FunctionLibraryRuntime* lib,
|
||||||
SeedGeneratorManager** manager) override;
|
SeedGeneratorManager** manager) override;
|
||||||
|
|
||||||
std::unique_ptr<RandomSeeds> seeds_ = nullptr;
|
mutex mu_;
|
||||||
|
std::unique_ptr<RandomSeeds> seeds_ TF_GUARDED_BY(mu_);
|
||||||
bool reshuffle_;
|
bool reshuffle_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user