Replaces uses of op StatelessRandomGetKeyCounterAlg with StatelessRandomGetKeyCounter and StatelessRandomGetAlg, so that seed is no longer required by XLA to be a compile-time constant.

Adds SetIsStateful to StatelessRandomGetKeyCounter and StatelessRandomGetAlg, so that they won't be constant-folded away.

Re-introduces StatelessRandomNormalV2 in tf.random.stateless_normal, since the OOM problem has been fixed.

PiperOrigin-RevId: 351442194
Change-Id: I5bc20df735e88e3d1cbb7fc7e69b070386af27d6
This commit is contained in:
Peng Wang 2021-01-12 13:53:34 -08:00 committed by TensorFlower Gardener
parent 88e57640cd
commit 5a7e5687c6
2 changed files with 17 additions and 6 deletions

View File

@ -120,6 +120,7 @@ REGISTER_OP("StatelessRandomGetKeyCounter")
.Output("key: uint64") .Output("key: uint64")
.Output("counter: uint64") .Output("counter: uint64")
.Attr("Tseed: {int32, int64} = DT_INT64") .Attr("Tseed: {int32, int64} = DT_INT64")
.SetIsStateful() // because outputs depend on device
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
// Check seed shape // Check seed shape
ShapeHandle seed; ShapeHandle seed;
@ -135,6 +136,7 @@ REGISTER_OP("StatelessRandomGetKeyCounter")
REGISTER_OP("StatelessRandomGetAlg") REGISTER_OP("StatelessRandomGetAlg")
.Output("alg: int32") .Output("alg: int32")
.SetIsStateful() // because outputs depend on device
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->MakeShape({})); c->set_output(0, c->MakeShape({}));
return Status::OK(); return Status::OK();

View File

@ -123,8 +123,15 @@ def fold_in(seed, data):
return array_ops.stack([seed1, data]) return array_ops.stack([seed1, data])
_get_key_counter_alg = (gen_stateless_random_ops_v2 def _get_key_counter_alg(seed):
.stateless_random_get_key_counter_alg) if compat.forward_compatible(2021, 2, 2):
key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
seed)
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
return key, counter, alg
else:
return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg(
seed)
@tf_export("random.stateless_uniform") @tf_export("random.stateless_uniform")
@ -508,10 +515,12 @@ def stateless_random_normal(shape,
shape = tensor_util.shape_tensor(shape) shape = tensor_util.shape_tensor(shape)
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
# TODO(b/171746875): stateless_random_normal([1024, 32000], dtype='float32') if compat.forward_compatible(2021, 2, 2):
# OOM on TPU with StatelessRandomNormalV2 because of excessive padding. key, counter, alg = _get_key_counter_alg(seed)
# Investigate and switch to StatelessRandomNormalV2. rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype) shape, key=key, counter=counter, dtype=dtype, alg=alg)
else:
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
result = math_ops.add(rnd * stddev, mean, name=name) result = math_ops.add(rnd * stddev, mean, name=name)
tensor_util.maybe_set_static_shape(result, shape) tensor_util.maybe_set_static_shape(result, shape)
return result return result