Replaces uses of op StatelessRandomGetKeyCounterAlg with StatelessRandomGetKeyCounter and StatelessRandomGetAlg, so that seed is no longer required by XLA to be a compile-time constant.

Adds SetIsStateful to StatelessRandomGetKeyCounter and StatelessRandomGetAlg, so that they won't be constant-folded away.

Re-introduces StatelessRandomNormalV2 in tf.random.stateless_normal, since the OOM problem has been fixed.

PiperOrigin-RevId: 351442194
Change-Id: I5bc20df735e88e3d1cbb7fc7e69b070386af27d6
This commit is contained in:
Peng Wang 2021-01-12 13:53:34 -08:00 committed by TensorFlower Gardener
parent 88e57640cd
commit 5a7e5687c6
2 changed files with 17 additions and 6 deletions

View File

@ -120,6 +120,7 @@ REGISTER_OP("StatelessRandomGetKeyCounter")
.Output("key: uint64")
.Output("counter: uint64")
.Attr("Tseed: {int32, int64} = DT_INT64")
.SetIsStateful() // because outputs depend on device
.SetShapeFn([](InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
@ -135,6 +136,7 @@ REGISTER_OP("StatelessRandomGetKeyCounter")
REGISTER_OP("StatelessRandomGetAlg")
.Output("alg: int32")
.SetIsStateful() // because outputs depend on device
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->MakeShape({}));
return Status::OK();

View File

@ -123,8 +123,15 @@ def fold_in(seed, data):
return array_ops.stack([seed1, data])
_get_key_counter_alg = (gen_stateless_random_ops_v2
.stateless_random_get_key_counter_alg)
def _get_key_counter_alg(seed):
if compat.forward_compatible(2021, 2, 2):
key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
seed)
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
return key, counter, alg
else:
return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg(
seed)
@tf_export("random.stateless_uniform")
@ -508,10 +515,12 @@ def stateless_random_normal(shape,
shape = tensor_util.shape_tensor(shape)
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
# TODO(b/171746875): stateless_random_normal([1024, 32000], dtype='float32')
# OOM on TPU with StatelessRandomNormalV2 because of excessive padding.
# Investigate and switch to StatelessRandomNormalV2.
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
if compat.forward_compatible(2021, 2, 2):
key, counter, alg = _get_key_counter_alg(seed)
rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg)
else:
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
result = math_ops.add(rnd * stddev, mean, name=name)
tensor_util.maybe_set_static_shape(result, shape)
return result