Fix regression caused by PR #1868

This commit is contained in:
Reuben Morais 2021-05-18 16:54:29 +02:00
parent 5ad6e6abbf
commit 5114362f6d
1 changed files with 8 additions and 10 deletions

View File

@ -255,16 +255,14 @@ def tf_pick_value_from_range(value_range, clock=None, double_precision=False):
tf.minimum(tf.constant(1.0, dtype=tf.float64), clock), tf.minimum(tf.constant(1.0, dtype=tf.float64), clock),
) )
value = value_range.start + clock * (value_range.end - value_range.start) value = value_range.start + clock * (value_range.end - value_range.start)
if value_range.r: # sample the value from a uniform distribution with "radius" <r>
# if the option <r> (<value>~<r>, randomization radius) is supplied, value = tf.random.stateless_uniform(
# sample the value from a uniform distribution with "radius" <r> [],
value = tf.random.stateless_uniform( minval=value - value_range.r,
[], maxval=value + value_range.r,
minval=value - value_range.r, seed=(clock * tf.int32.min, clock * tf.int32.max),
maxval=value + value_range.r, dtype=tf.float64,
seed=(clock * tf.int32.min, clock * tf.int32.max), )
dtype=tf.float64,
)
if isinstance(value_range.start, int): if isinstance(value_range.start, int):
return tf.cast(tf.math.round(value), tf.int64 if double_precision else tf.int32) return tf.cast(tf.math.round(value), tf.int64 if double_precision else tf.int32)
return tf.cast(value, tf.float64 if double_precision else tf.float32) return tf.cast(value, tf.float64 if double_precision else tf.float32)