[tf.data] Make AnonymousSeedGeneratorHandleOp
thread-safe.
PiperOrigin-RevId: 308432160 Change-Id: I397bcd46dfe45833b9487f0f3ee7027d2ca09780
This commit is contained in:
parent
6d9cde7cc7
commit
7f37206771
@ -72,6 +72,7 @@ void AnonymousSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) {
|
||||
int64 seed2;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
|
||||
// Seeds will be consumed by `CreateResource`, which is called via `Compute`.
|
||||
mutex_lock l(mu_);
|
||||
seeds_ = absl::make_unique<RandomSeeds>(seed, seed2);
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, kReshuffle, &reshuffle_));
|
||||
AnonymousResourceOp<SeedGeneratorManager>::Compute(ctx);
|
||||
@ -82,7 +83,8 @@ std::string AnonymousSeedGeneratorHandleOp::name() { return kSeedGenerator; }
|
||||
Status AnonymousSeedGeneratorHandleOp::CreateResource(
|
||||
OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
||||
FunctionLibraryRuntime* lib, SeedGeneratorManager** manager) {
|
||||
FunctionLibraryRuntime* lib, SeedGeneratorManager** manager)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (reshuffle_) {
|
||||
*manager = new SeedGeneratorManager(new RandomSeedGenerator(*seeds_));
|
||||
} else {
|
||||
|
@ -142,7 +142,8 @@ class AnonymousSeedGeneratorHandleOp
|
||||
FunctionLibraryRuntime* lib,
|
||||
SeedGeneratorManager** manager) override;
|
||||
|
||||
std::unique_ptr<RandomSeeds> seeds_ = nullptr;
|
||||
mutex mu_;
|
||||
std::unique_ptr<RandomSeeds> seeds_ TF_GUARDED_BY(mu_);
|
||||
bool reshuffle_;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user