From 5a7e5687c649fa0ada7888eb92be5bde2f246772 Mon Sep 17 00:00:00 2001 From: Peng Wang Date: Tue, 12 Jan 2021 13:53:34 -0800 Subject: [PATCH] Replaces uses of op StatelessRandomGetKeyCounterAlg with StatelessRandomGetKeyCounter and StatelessRandomGetAlg, so that `seed` is no longer required by XLA to be a compile-time constant. Adds SetIsStateful to StatelessRandomGetKeyCounter and StatelessRandomGetAlg, so that they won't be constant-folded away. Re-introduces StatelessRandomNormalV2 in tf.random.stateless_normal, since the OOM problem has been fixed. PiperOrigin-RevId: 351442194 Change-Id: I5bc20df735e88e3d1cbb7fc7e69b070386af27d6 --- .../core/ops/stateless_random_ops_v2.cc | 2 ++ tensorflow/python/ops/stateless_random_ops.py | 21 +++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/ops/stateless_random_ops_v2.cc b/tensorflow/core/ops/stateless_random_ops_v2.cc index b905680ff5b..64751541510 100644 --- a/tensorflow/core/ops/stateless_random_ops_v2.cc +++ b/tensorflow/core/ops/stateless_random_ops_v2.cc @@ -120,6 +120,7 @@ REGISTER_OP("StatelessRandomGetKeyCounter") .Output("key: uint64") .Output("counter: uint64") .Attr("Tseed: {int32, int64} = DT_INT64") + .SetIsStateful() // because outputs depend on device .SetShapeFn([](InferenceContext* c) { // Check seed shape ShapeHandle seed; @@ -135,6 +136,7 @@ REGISTER_OP("StatelessRandomGetKeyCounter") REGISTER_OP("StatelessRandomGetAlg") .Output("alg: int32") + .SetIsStateful() // because outputs depend on device .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->MakeShape({})); return Status::OK(); diff --git a/tensorflow/python/ops/stateless_random_ops.py b/tensorflow/python/ops/stateless_random_ops.py index ed0b66443cf..f26d4b679a7 100644 --- a/tensorflow/python/ops/stateless_random_ops.py +++ b/tensorflow/python/ops/stateless_random_ops.py @@ -123,8 +123,15 @@ def fold_in(seed, data): return array_ops.stack([seed1, data]) -_get_key_counter_alg = (gen_stateless_random_ops_v2 - .stateless_random_get_key_counter_alg) +def _get_key_counter_alg(seed): + if compat.forward_compatible(2021, 2, 2): + key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter( + seed) + alg = gen_stateless_random_ops_v2.stateless_random_get_alg() + return key, counter, alg + else: + return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg( + seed) @tf_export("random.stateless_uniform") @@ -508,10 +515,12 @@ def stateless_random_normal(shape, shape = tensor_util.shape_tensor(shape) mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") - # TODO(b/171746875): stateless_random_normal([1024, 32000], dtype='float32') - # OOM on TPU with StatelessRandomNormalV2 because of excessive padding. - # Investigate and switch to StatelessRandomNormalV2. - rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype) + if compat.forward_compatible(2021, 2, 2): + key, counter, alg = _get_key_counter_alg(seed) + rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2( + shape, key=key, counter=counter, dtype=dtype, alg=alg) + else: + rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype) result = math_ops.add(rnd * stddev, mean, name=name) tensor_util.maybe_set_static_shape(result, shape) return result