Reverts stateless_random_normal to use old op because the new op can OOM on TPU due to excessive padding.
PiperOrigin-RevId: 339153626 Change-Id: Ib0a47d35f99c82a3b0b0e44e023d7a39b6710db2
This commit is contained in:
parent
6b5e470f7f
commit
08e92a07fe
@ -508,12 +508,10 @@ def stateless_random_normal(shape,
|
||||
shape = tensor_util.shape_tensor(shape)
|
||||
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
|
||||
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=alg)
|
||||
else:
|
||||
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
|
||||
# TODO(b/171746875): stateless_random_normal([1024, 32000], dtype='float32')
|
||||
# OOM on TPU with StatelessRandomNormalV2 because of excessive padding.
|
||||
# Investigate and switch to StatelessRandomNormalV2.
|
||||
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
|
||||
result = math_ops.add(rnd * stddev, mean, name=name)
|
||||
tensor_util.maybe_set_static_shape(result, shape)
|
||||
return result
|
||||
|
Loading…
Reference in New Issue
Block a user