Reverts stateless_random_normal to use old op because the new op can OOM on TPU due to excessive padding.

PiperOrigin-RevId: 339153626
Change-Id: Ib0a47d35f99c82a3b0b0e44e023d7a39b6710db2
This commit is contained in:
Peng Wang 2020-10-26 17:32:23 -07:00 committed by TensorFlower Gardener
parent 6b5e470f7f
commit 08e92a07fe

View File

@ -508,12 +508,10 @@ def stateless_random_normal(shape,
shape = tensor_util.shape_tensor(shape)
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
if compat.forward_compatible(2020, 10, 25):
key, counter, alg = _get_key_counter_alg(seed)
rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg)
else:
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
# TODO(b/171746875): stateless_random_normal([1024, 32000], dtype='float32')
# OOM on TPU with StatelessRandomNormalV2 because of excessive padding.
# Investigate and switch to StatelessRandomNormalV2.
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
result = math_ops.add(rnd * stddev, mean, name=name)
tensor_util.maybe_set_static_shape(result, shape)
return result