STT/training/coqui_stt_training/util/augmentations.py

import math
import os
import random
import re
from multiprocessing import Process, Queue

import numpy as np
import resampy

from .audio import (
    AUDIO_TYPE_NP,
    AUDIO_TYPE_OPUS,
    AUDIO_TYPE_PCM,
    gain_db_to_ratio,
    max_dbfs,
    normalize_audio,
)
from .helpers import (
    MEGABYTE,
    LimitingPool,
    float_range,
    int_range,
    pick_value_from_range,
    tf_pick_value_from_range,
)
from .sample_collections import samples_from_source, unpack_maybe

BUFFER_SIZE = 1 * MEGABYTE
SPEC_PARSER = re.compile(r"^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$")


class Augmentation:
    def __init__(self, p=1.0):
        self.probability = float(p)


class SampleAugmentation(Augmentation):
    def start(self, buffering=BUFFER_SIZE):
        pass

    def apply(self, sample, clock=0.0):
        raise NotImplementedError

    def stop(self):
        pass


class GraphAugmentation(Augmentation):
    def __init__(self, p=1.0, domain="spectrogram"):
        super(GraphAugmentation, self).__init__(p)
        if domain not in ["signal", "spectrogram", "features"]:
            raise ValueError("Unsupported augmentation domain: {}".format(domain))
        self.domain = domain

    def apply(self, tensor, transcript=None, clock=0.0):
        raise NotImplementedError

    def apply_with_probability(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        rv = tf.random.stateless_uniform(
            [], seed=(clock * tf.int32.min, clock * tf.int32.max)
        )
        return tf.cond(
            tf.less(rv, self.probability),
            lambda: self.apply(tensor, transcript=transcript, clock=clock),
            lambda: tensor,
        )

    def maybe_apply(self, domain, tensor, transcript=None, clock=0.0):
        if domain == self.domain:
            return self.apply_with_probability(
                tensor, transcript=transcript, clock=clock
            )
        return tensor

    def units_per_ms(self):
        from .config import Config  # pylint: disable=import-outside-toplevel

        return (
            Config.audio_sample_rate / 1000.0
            if self.domain == "signal"
            else 1.0 / Config.feature_win_step
        )
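

# Illustrative sketch (not part of the upstream module): a GraphAugmentation only
# touches tensors from its own domain; for any other domain, maybe_apply returns
# the input unchanged, so the feeding pipeline can pass every tensor through the
# full augmentation list. FrequencyMask used below is the spectrogram-domain
# augmentation defined further down in this file.
def _example_domain_dispatch():
    aug = FrequencyMask(n=2, size=3)  # domain defaults to "spectrogram"
    signal_tensor = object()  # stands in for a real signal tensor
    # Domain mismatch: the tensor comes back untouched, without importing TensorFlow.
    assert aug.maybe_apply("signal", signal_tensor) is signal_tensor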


def parse_augmentation(augmentation_spec):
    """
    Parses an augmentation specification.

    Parameters
    ----------
    augmentation_spec : str
        Augmentation specification like "reverb[delay=20.0,decay=1.0]".

    Returns
    -------
    Instance of an augmentation class from util.augmentations.*.
    """
    match = SPEC_PARSER.match(augmentation_spec)
    if not match:
        raise ValueError("Augmentation specification has wrong format")
    cls_name = "".join(
        map(lambda p: p[0].upper() + p[1:], match.group("cls").split("_"))
    )
    augmentation_cls = globals()[cls_name] if cls_name in globals() else None
    if (
        augmentation_cls is None
        or not issubclass(augmentation_cls, Augmentation)
        or augmentation_cls == Augmentation
    ):
        raise ValueError("Unknown augmentation: {}".format(cls_name))
    parameters = match.group("params")
    parameters = [] if parameters is None else parameters.split(",")
    args = []
    kwargs = {}
    for parameter in parameters:
        pair = tuple(list(map(str.strip, (parameter.split("=")))))
        if len(pair) == 1:
            args.append(pair[0])  # positional parameter without "=" assignment
        elif len(pair) == 2:
            kwargs[pair[0]] = pair[1]
        else:
            raise ValueError("Unable to parse augmentation value assignment")
    return augmentation_cls(*args, **kwargs)


def parse_augmentations(augmentation_specs):
    """
    Parses a list of augmentation specifications.

    Parameters
    ----------
    augmentation_specs : list of str
        List of augmentation specifications like ["reverb[delay=20.0,decay=1.0]", "volume"].

    Returns
    -------
    List of augmentation class instances from util.augmentations.*.
    """
    return list(map(parse_augmentation, augmentation_specs or []))
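

# A minimal usage sketch of the parsers above (the spec strings are illustrative,
# not a complete list of valid options): keyword values arrive as strings and are
# converted by float_range/int_range inside each augmentation's constructor.
def _example_parse_augmentations():
    augs = parse_augmentations(["reverb[delay=20.0,decay=1.0]", "volume"])
    assert isinstance(augs[0], Reverb)
    assert isinstance(augs[1], Volume)
    return augs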


def apply_graph_augmentations(
    domain, tensor, augmentations, transcript=None, clock=0.0
):
    """
    Augments training sample tensor of a certain domain with matching augmentations of passed list.

    Parameters
    ----------
    domain : str
        Domain of the tensor to apply augmentations to. One of "signal", "spectrogram" or "features".
    tensor : Tensor of type float32
        Tensor to apply augmentations to.
    augmentations : list of augmentation class instances from util.augmentations.*.
        List of augmentations of which only the graph augmentations matching the given domain
        will get applied to the tensor.
    transcript : SparseTensor
        Transcript of the sample; used by augmentations that must not shrink the time
        dimension below the transcript length (e.g. Tempo).
    clock : Tensor of type float32
        Time indicator for augmentation value-ranges. Running from 0.0 (start of training) to 1.0 (end of training).

    Returns
    -------
    Tensor of type float32
        The augmented tensor
    """
    if augmentations:
        for augmentation in augmentations:
            if isinstance(augmentation, GraphAugmentation):
                tensor = augmentation.maybe_apply(
                    domain, tensor, transcript=transcript, clock=clock
                )
    return tensor
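

# Assumed call pattern (a sketch based on the docstring above, not code taken from
# the upstream feeding pipeline): the same augmentation list is passed once per
# domain, and each GraphAugmentation only fires where its domain matches:
#
#   audio = apply_graph_augmentations("signal", audio, augmentations, clock=clock)
#   spectrogram = apply_graph_augmentations("spectrogram", spectrogram, augmentations,
#                                           transcript=transcript, clock=clock)
#   features = apply_graph_augmentations("features", features, augmentations, clock=clock)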


class AugmentationContext:
    def __init__(self, target_audio_type, augmentations):
        self.target_audio_type = target_audio_type
        self.augmentations = augmentations


AUGMENTATION_CONTEXT = None


def _init_augmentation_worker(preparation_context):
    global AUGMENTATION_CONTEXT  # pylint: disable=global-statement
    AUGMENTATION_CONTEXT = preparation_context


def _load_and_augment_sample(timed_sample, context=None):
    sample, clock = timed_sample
    realized_sample = unpack_maybe(sample)
    return _augment_sample((realized_sample, clock), context)


def _augment_sample(timed_sample, context=None):
    context = AUGMENTATION_CONTEXT if context is None else context
    sample, clock = timed_sample
    for augmentation in context.augmentations:
        if random.random() < augmentation.probability:
            augmentation.apply(sample, clock)
    sample.change_audio_type(new_audio_type=context.target_audio_type)
    return sample


def apply_sample_augmentations(
    samples,
    augmentations,
    audio_type=AUDIO_TYPE_NP,
    buffering=BUFFER_SIZE,
    process_ahead=None,
    clock=0.0,
    final_clock=None,
):
    """
    Prepares samples for being used during training.
    This includes parallel and buffered application of augmentations and a conversion to a specified audio-type.

    Parameters
    ----------
    samples : Sample enumeration
        Typically produced by util.sample_collections.samples_from_sources.
    augmentations : list of augmentation class instances from util.augmentations.*.
        List of augmentations of which only the signal ones will get applied to the samples.
    audio_type : str
        Target audio-type to convert samples to. See util.audio.Sample.__init__ .
    buffering : int
        Read-buffer size to use while reading files.
    process_ahead : int
        Number of samples to pre-process ahead of time.
    clock : float
        Start or fixed clock value between 0.0 and 1.0 for the first or all samples.
        Has to be <= final_clock.
    final_clock : float
        Final clock value between 0.0 and 1.0 for the last sample. Has to be >= clock.
        Requires samples.__len__ attribute.

    Returns
    -------
    iterable of util.sample_collections.LabeledSample or util.audio.Sample
    """

    def timed_samples():
        if final_clock is None:
            for sample in samples:
                yield sample, clock
        else:
            for sample_index, sample in enumerate(samples):
                sample_clock = clock + (final_clock - clock) * (
                    sample_index / len(samples)
                )
                yield sample, sample_clock

    assert 0.0 <= clock <= 1.0
    if final_clock is not None:
        assert 0.0 <= final_clock <= 1.0
        assert clock <= final_clock
    augmentations = (
        [aug for aug in augmentations if isinstance(aug, SampleAugmentation)]
        if augmentations
        else []
    )
    try:
        for augmentation in augmentations:
            augmentation.start(buffering=buffering)
        context = AugmentationContext(audio_type, augmentations)
        if process_ahead == 0:
            for timed_sample in timed_samples():
                yield _load_and_augment_sample(timed_sample, context=context)
        else:
            with LimitingPool(
                process_ahead=process_ahead,
                initializer=_init_augmentation_worker,
                initargs=(context,),
            ) as pool:
                yield from pool.imap(_load_and_augment_sample, timed_samples())
    finally:
        for augmentation in augmentations:
            augmentation.stop()
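

# A minimal usage sketch (assumptions: `samples` is a sample collection supporting
# __len__, e.g. from util.sample_collections.samples_from_sources; the spec strings
# are illustrative):
def _example_apply_sample_augmentations(samples):
    augmentations = parse_augmentations(["volume[dbfs=-10.0]", "codec[bitrate=8000]"])
    # clock/final_clock make the augmentation value ranges sweep from their start
    # to their end values across the enumeration.
    yield from apply_sample_augmentations(
        samples, augmentations, audio_type=AUDIO_TYPE_NP, clock=0.0, final_clock=1.0
    )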


def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
    """
    As the central distribution point for overlay samples this function is supposed to run in one process only.
    This ensures that samples are not used twice if not required.
    It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
    These then do the decompression, potential conversion and overlaying in parallel.
    """
    samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
    while True:
        for sample in samples:
            queue.put(sample)


class Overlay(SampleAugmentation):
    """See "Overlay augmentation" in training documentation"""

    def __init__(self, source, p=1.0, snr=3.0, layers=1):
        super(Overlay, self).__init__(p)
        self.source = source
        self.snr = float_range(snr)
        self.layers = int_range(layers)
        self.current_sample = None
        self.queue = None
        self.enqueue_process = None

    def __repr__(self):
        return f"Overlay(source={self.source!r}, p={self.probability!r}, snr={self.snr!r}, layers={self.layers!r})"

    def start(self, buffering=BUFFER_SIZE):
        self.queue = Queue(
            max(1, math.floor(self.probability * self.layers[1] * os.cpu_count()))
        )
        self.enqueue_process = Process(
            target=_enqueue_overlay_samples,
            args=(self.source, self.queue),
            kwargs={"buffering": buffering},
        )
        self.enqueue_process.start()

    def apply(self, sample, clock=0.0):
        sample = unpack_maybe(sample)
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        n_layers = pick_value_from_range(self.layers, clock=clock)
        audio = sample.audio
        overlay_data = np.zeros_like(audio)
        for _ in range(n_layers):
            overlay_offset = 0
            while overlay_offset < len(audio):
                if self.current_sample is None:
                    next_overlay_sample = self.queue.get()
                    next_overlay_sample = unpack_maybe(next_overlay_sample)
                    next_overlay_sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
                    self.current_sample = next_overlay_sample.audio
                n_required = len(audio) - overlay_offset
                n_current = len(self.current_sample)
                if n_required >= n_current:  # take it completely
                    overlay_data[
                        overlay_offset : overlay_offset + n_current
                    ] += self.current_sample
                    overlay_offset += n_current
                    self.current_sample = None
                else:  # take required slice from head and keep tail for next layer or sample
                    overlay_data[
                        overlay_offset : overlay_offset + n_required
                    ] += self.current_sample[0:n_required]
                    overlay_offset += n_required
                    self.current_sample = self.current_sample[n_required:]
        snr_db = pick_value_from_range(self.snr, clock=clock)
        orig_dbfs = max_dbfs(audio)
        overlay_gain = orig_dbfs - max_dbfs(overlay_data) - snr_db
        audio += overlay_data * gain_db_to_ratio(overlay_gain)
        sample.audio = normalize_audio(audio, dbfs=orig_dbfs)

    def stop(self):
        if self.enqueue_process is not None:
            self.enqueue_process.terminate()
            self.enqueue_process = None
        self.current_sample = None
        self.queue = None
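

# Worked example of the SNR gain computation in Overlay.apply above (a sketch, not
# part of the original code): with the speech peaking at -6 dBFS, the collected
# overlay peaking at -20 dBFS and a picked SNR of 3 dB, the overlay gain is
# -6 - (-20) - 3 = 11 dB, so the overlay is scaled by gain_db_to_ratio(11) and its
# peak lands at -9 dBFS, i.e. 3 dB below the speech peak, before the mix is
# normalized back to the original -6 dBFS.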


class Codec(SampleAugmentation):
    """See "Codec augmentation" in training documentation"""

    def __init__(self, p=1.0, bitrate=3200):
        super(Codec, self).__init__(p)
        self.bitrate = int_range(bitrate)

    def __repr__(self):
        return f"Codec(p={self.probability!r}, bitrate={self.bitrate!r})"

    def apply(self, sample, clock=0.0):
        bitrate = pick_value_from_range(self.bitrate, clock=clock)
        sample.change_audio_type(
            new_audio_type=AUDIO_TYPE_PCM
        )  # decoding to ensure it has to get encoded again
        sample.change_audio_type(
            new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate
        )  # will get decoded again downstream


class Reverb(SampleAugmentation):
    """See "Reverb augmentation" in training documentation"""

    def __init__(self, p=1.0, delay=20.0, decay=10.0):
        super(Reverb, self).__init__(p)
        self.delay = float_range(delay)
        self.decay = float_range(decay)

    def __repr__(self):
        return f"Reverb(p={self.probability!r}, delay={self.delay!r}, decay={self.decay!r})"

    def apply(self, sample, clock=0.0):
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        audio = np.array(sample.audio, dtype=np.float64)
        orig_dbfs = max_dbfs(audio)
        delay = pick_value_from_range(self.delay, clock=clock)
        decay = pick_value_from_range(self.decay, clock=clock)
        decay = gain_db_to_ratio(-decay)
        result = np.copy(audio)
        primes = [17, 19, 23, 29, 31]
        for delay_prime in primes:  # primes to minimize comb filter interference
            layer = np.copy(audio)
            n_delay = math.floor(
                delay * (delay_prime / primes[0]) * sample.audio_format.rate / 1000.0
            )
            n_delay = max(
                16, n_delay
            )  # 16 samples minimum to avoid performance trap and risk of division by zero
            for w_index in range(0, math.floor(len(audio) / n_delay)):
                w1 = w_index * n_delay
                w2 = (w_index + 1) * n_delay
                width = min(len(audio) - w2, n_delay)  # last window could be smaller
                layer[w2 : w2 + width] += decay * layer[w1 : w1 + width]
            result += layer
        audio = normalize_audio(result, dbfs=orig_dbfs)
        sample.audio = np.array(audio, dtype=np.float32)


class Resample(SampleAugmentation):
    """See "Resample augmentation" in training documentation"""

    def __init__(self, p=1.0, rate=8000):
        super(Resample, self).__init__(p)
        self.rate = int_range(rate)

    def __repr__(self):
        return f"Resample(p={self.probability!r}, rate={self.rate!r})"

    def apply(self, sample, clock=0.0):
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        rate = pick_value_from_range(self.rate, clock=clock)
        orig_len = len(sample.audio)
        # Resample down to the picked rate and back up to the original rate to
        # simulate a reduced-bandwidth channel.
        resampled = resampy.resample(
            sample.audio, sample.audio_format.rate, rate, axis=0, filter="kaiser_fast"
        )
        sample.audio = resampy.resample(
            resampled, rate, sample.audio_format.rate, axis=0, filter="kaiser_fast"
        )[:orig_len]


class NormalizeSampleRate(SampleAugmentation):
    def __init__(self, rate):
        super().__init__(p=1.0)
        self.rate = rate

    def __repr__(self):
        return f"NormalizeSampleRate(rate={self.rate!r})"

    def apply(self, sample, clock=0.0):
        if sample.audio_format.rate == self.rate:
            return
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        sample.audio = resampy.resample(
            sample.audio,
            sample.audio_format.rate,
            self.rate,
            axis=0,
            filter="kaiser_fast",
        )
        sample.audio_format = sample.audio_format._replace(rate=self.rate)


class Volume(SampleAugmentation):
    """See "Volume augmentation" in training documentation"""

    def __init__(self, p=1.0, dbfs=3.0103):
        super(Volume, self).__init__(p)
        self.target_dbfs = float_range(dbfs)

    def __repr__(self):
        return f"Volume(p={self.probability!r}, dbfs={self.target_dbfs!r})"

    def apply(self, sample, clock=0.0):
        sample.change_audio_type(new_audio_type=AUDIO_TYPE_NP)
        target_dbfs = pick_value_from_range(self.target_dbfs, clock=clock)
        sample.audio = normalize_audio(sample.audio, dbfs=target_dbfs)


class Pitch(GraphAugmentation):
    """See "Pitch augmentation" in training documentation"""

    def __init__(self, p=1.0, pitch=(1.075, 1.075, 0.125)):
        super(Pitch, self).__init__(p, domain="spectrogram")
        self.pitch = float_range(pitch)

    def __repr__(self):
        return f"Pitch(p={self.probability!r}, pitch={self.pitch!r})"

    def apply(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        original_shape = tf.shape(tensor)
        pitch = tf_pick_value_from_range(self.pitch, clock=clock)
        new_freq_size = tf.cast(
            tf.cast(original_shape[2], tf.float32) * pitch, tf.int32
        )
        spectrogram_aug = tf.image.resize_bilinear(
            tf.expand_dims(tensor, -1), [original_shape[1], new_freq_size]
        )
        spectrogram_aug = tf.image.crop_to_bounding_box(
            spectrogram_aug,
            offset_height=0,
            offset_width=0,
            target_height=original_shape[1],
            target_width=tf.math.minimum(original_shape[2], new_freq_size),
        )
        spectrogram_aug = tf.cond(
            pitch < 1,
            lambda: tf.image.pad_to_bounding_box(
                spectrogram_aug,
                offset_height=0,
                offset_width=0,
                target_height=tf.shape(spectrogram_aug)[1],
                target_width=original_shape[2],
            ),
            lambda: spectrogram_aug,
        )
        return spectrogram_aug[:, :, :, 0]


class Tempo(GraphAugmentation):
    """See "Tempo augmentation" in training documentation"""

    def __init__(self, p=1.0, factor=1.1, max_time=-1):
        super(Tempo, self).__init__(p, domain="spectrogram")
        self.factor = float_range(factor)
        self.max_time = float(max_time)

    def __repr__(self):
        return f"Tempo(p={self.probability!r}, factor={self.factor!r}, max_time={self.max_time!r})"

    def apply(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        factor = tf_pick_value_from_range(self.factor, clock=clock)
        original_shape = tf.shape(tensor)
        new_time_size = tf.cast(
            tf.cast(original_shape[1], tf.float32) / factor, tf.int32
        )
        if transcript is not None:
            new_time_size = tf.math.maximum(new_time_size, tf.shape(transcript)[1])
        if self.max_time > 0:
            new_time_size = tf.math.minimum(
                new_time_size, tf.cast(self.max_time * self.units_per_ms(), tf.int32)
            )
        spectrogram_aug = tf.image.resize_bilinear(
            tf.expand_dims(tensor, -1), [new_time_size, original_shape[2]]
        )
        return spectrogram_aug[:, :, :, 0]


class Warp(GraphAugmentation):
    """See "Warp augmentation" in training documentation"""

    def __init__(self, p=1.0, num_t=1, num_f=1, warp_t=0.1, warp_f=0.0):
        super(Warp, self).__init__(p, domain="spectrogram")
        self.num_t = int_range(num_t)
        self.num_f = int_range(num_f)
        self.warp_t = float_range(warp_t)
        self.warp_f = float_range(warp_f)

    def __repr__(self):
        return f"Warp(p={self.probability!r}, num_t={self.num_t!r}, num_f={self.num_f!r}, warp_t={self.warp_t!r}, warp_f={self.warp_f!r})"

    def apply(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        original_shape = tf.shape(tensor)
        size_t, size_f = original_shape[1], original_shape[2]
        seed = (clock * tf.int32.min, clock * tf.int32.max)
        num_t = tf_pick_value_from_range(self.num_t, clock=clock)
        num_f = tf_pick_value_from_range(self.num_f, clock=clock)

        def get_flows(n, size, warp):
            warp = tf_pick_value_from_range(warp, clock=clock)
            warp = (
                warp
                * tf.cast(size, dtype=tf.float32)
                / tf.cast(2 * (n + 1), dtype=tf.float32)
            )
            f = tf.random.stateless_normal(
                [num_t, num_f], seed, mean=0.0, stddev=warp, dtype=tf.float32
            )
            return tf.pad(
                f, tf.constant([[1, 1], [1, 1]]), "CONSTANT"
            )  # zero flow at all edges

        flows = tf.stack(
            [
                get_flows(num_t, size_t, self.warp_t),
                get_flows(num_f, size_f, self.warp_f),
            ],
            axis=2,
        )
        flows = tf.image.resize_bicubic(tf.expand_dims(flows, 0), [size_t, size_f])
        spectrogram_aug = tf.contrib.image.dense_image_warp(
            tf.expand_dims(tensor, -1), flows
        )
        return tf.reshape(spectrogram_aug, shape=(1, -1, size_f))


class FrequencyMask(GraphAugmentation):
    """See "Frequency mask augmentation" in training documentation"""

    def __init__(self, p=1.0, n=3, size=2):
        super(FrequencyMask, self).__init__(p, domain="spectrogram")
        self.n = int_range(n)  # pylint: disable=invalid-name
        self.size = int_range(size)

    def __repr__(self):
        return (
            f"FrequencyMask(p={self.probability!r}, n={self.n!r}, size={self.size!r})"
        )

    def apply(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        time_max = tf.shape(tensor)[1]
        freq_max = tf.shape(tensor)[2]
        n = tf_pick_value_from_range(self.n, clock=clock)

        def body(i, spectrogram_aug):
            size = tf_pick_value_from_range(self.size, clock=clock)
            size = tf.math.maximum(1, tf.math.minimum(freq_max - 1, size))
            seed = tf.cast(clock * tf.int32.max, tf.int32) - i
            f0 = tf.random.stateless_uniform(
                (),
                (-seed, seed),
                minval=0,
                maxval=freq_max - size,
                dtype=tf.dtypes.int32,
            )
            freq_mask = tf.concat(
                [
                    tf.ones([1, time_max, f0]),
                    tf.zeros([1, time_max, size]),
                    tf.ones([1, time_max, freq_max - f0 - size]),
                ],
                axis=2,
            )
            return i + 1, spectrogram_aug * freq_mask

        return tf.while_loop(lambda i, spectrogram_aug: i < n, body, (0, tensor))[1]


class TimeMask(GraphAugmentation):
    """See "Time mask augmentation" in training documentation"""

    def __init__(self, p=1.0, domain="spectrogram", n=3, size=10.0):
        super(TimeMask, self).__init__(p, domain=domain)
        self.n = int_range(n)  # pylint: disable=invalid-name
        self.size = float_range(size)

    def __repr__(self):
        return f"TimeMask(p={self.probability!r}, domain={self.domain!r}, n={self.n!r}, size={self.size!r})"

    def apply(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        time_max = tf.shape(tensor)[0 if self.domain == "signal" else 1]
        n = tf_pick_value_from_range(self.n, clock=clock)

        def body(i, augmented):
            size = tf.cast(
                tf_pick_value_from_range(self.size, clock=clock) * self.units_per_ms(),
                dtype=tf.int32,
            )
            size = tf.math.maximum(1, tf.math.minimum(time_max - 1, size))
            seed = tf.cast(clock * tf.int32.max, tf.int32) - i
            t0 = tf.random.stateless_uniform(
                (),
                (-seed, seed),
                minval=0,
                maxval=time_max - size,
                dtype=tf.dtypes.int32,
            )
            rest = time_max - t0 - size
            if self.domain == "spectrogram":
                fm = tf.shape(tensor)[2]
                time_mask = tf.concat(
                    [
                        tf.ones([1, t0, fm]),
                        tf.zeros([1, size, fm]),
                        tf.ones([1, rest, fm]),
                    ],
                    axis=1,
                )
            elif self.domain == "signal":
                time_mask = tf.concat(
                    [tf.ones([t0, 1]), tf.zeros([size, 1]), tf.ones([rest, 1])], axis=0
                )
            else:
                time_mask = tf.concat(
                    [tf.ones([1, t0]), tf.zeros([1, size]), tf.ones([1, rest])], axis=1
                )
            return i + 1, augmented * time_mask

        return tf.while_loop(lambda i, augmented: i < n, body, (0, tensor))[1]


class Dropout(GraphAugmentation):
    """See "Dropout augmentation" in training documentation"""

    def __init__(self, p=1.0, domain="spectrogram", rate=0.05):
        super(Dropout, self).__init__(p, domain=domain)
        self.rate = float_range(rate)

    def __repr__(self):
        return f"Dropout(p={self.probability!r}, domain={self.domain!r}, rate={self.rate!r})"

    def apply(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        rate = tf_pick_value_from_range(self.rate, clock=clock)
        rate = tf.math.maximum(0.0, rate)
        factors = tf.random.stateless_uniform(
            tf.shape(tensor),
            (clock * tf.int32.min, clock * tf.int32.max),
            minval=0.0,
            maxval=1.0,
            dtype=tf.float32,
        )
        return tensor * tf.math.sign(tf.math.floor(factors + rate))


class Add(GraphAugmentation):
    """See "Add augmentation" in training documentation"""

    def __init__(self, p=1.0, domain="features", stddev=5):
        super(Add, self).__init__(p, domain=domain)
        self.stddev = float_range(stddev)

    def __repr__(self):
        return f"Add(p={self.probability!r}, domain={self.domain!r}, stddev={self.stddev!r})"

    def apply(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        stddev = tf_pick_value_from_range(self.stddev, clock=clock)
        seed = (clock * tf.int32.min, clock * tf.int32.max)
        return tensor + tf.random.stateless_normal(
            tf.shape(tensor), seed, mean=0.0, stddev=stddev
        )


class Multiply(GraphAugmentation):
    """See "Multiply augmentation" in training documentation"""

    def __init__(self, p=1.0, domain="features", stddev=5):
        super(Multiply, self).__init__(p, domain=domain)
        self.stddev = float_range(stddev)

    def __repr__(self):
        return f"Multiply(p={self.probability!r}, domain={self.domain!r}, stddev={self.stddev!r})"

    def apply(self, tensor, transcript=None, clock=0.0):
        import tensorflow as tf  # pylint: disable=import-outside-toplevel

        stddev = tf_pick_value_from_range(self.stddev, clock=clock)
        seed = (clock * tf.int32.min, clock * tf.int32.max)
        return tensor * tf.random.stateless_normal(
            tf.shape(tensor), seed, mean=1.0, stddev=stddev
        )