Replaces uses of op StatelessRandomGetKeyCounterAlg with StatelessRandomGetKeyCounter and StatelessRandomGetAlg, so that seed
is no longer required by XLA to be a compile-time constant.
Adds SetIsStateful to StatelessRandomGetKeyCounter and StatelessRandomGetAlg, so that they won't be constant-folded away. Re-introduces StatelessRandomNormalV2 in tf.random.stateless_normal, since the OOM problem has been fixed. PiperOrigin-RevId: 351442194 Change-Id: I5bc20df735e88e3d1cbb7fc7e69b070386af27d6
This commit is contained in:
parent
88e57640cd
commit
5a7e5687c6
@ -120,6 +120,7 @@ REGISTER_OP("StatelessRandomGetKeyCounter")
|
||||
.Output("key: uint64")
|
||||
.Output("counter: uint64")
|
||||
.Attr("Tseed: {int32, int64} = DT_INT64")
|
||||
.SetIsStateful() // because outputs depend on device
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// Check seed shape
|
||||
ShapeHandle seed;
|
||||
@ -135,6 +136,7 @@ REGISTER_OP("StatelessRandomGetKeyCounter")
|
||||
|
||||
REGISTER_OP("StatelessRandomGetAlg")
|
||||
.Output("alg: int32")
|
||||
.SetIsStateful() // because outputs depend on device
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->MakeShape({}));
|
||||
return Status::OK();
|
||||
|
@ -123,8 +123,15 @@ def fold_in(seed, data):
|
||||
return array_ops.stack([seed1, data])
|
||||
|
||||
|
||||
_get_key_counter_alg = (gen_stateless_random_ops_v2
|
||||
.stateless_random_get_key_counter_alg)
|
||||
def _get_key_counter_alg(seed):
|
||||
if compat.forward_compatible(2021, 2, 2):
|
||||
key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
|
||||
seed)
|
||||
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
|
||||
return key, counter, alg
|
||||
else:
|
||||
return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg(
|
||||
seed)
|
||||
|
||||
|
||||
@tf_export("random.stateless_uniform")
|
||||
@ -508,10 +515,12 @@ def stateless_random_normal(shape,
|
||||
shape = tensor_util.shape_tensor(shape)
|
||||
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
|
||||
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
|
||||
# TODO(b/171746875): stateless_random_normal([1024, 32000], dtype='float32')
|
||||
# OOM on TPU with StatelessRandomNormalV2 because of excessive padding.
|
||||
# Investigate and switch to StatelessRandomNormalV2.
|
||||
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
|
||||
if compat.forward_compatible(2021, 2, 2):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=alg)
|
||||
else:
|
||||
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
|
||||
result = math_ops.add(rnd * stddev, mean, name=name)
|
||||
tensor_util.maybe_set_static_shape(result, shape)
|
||||
return result
|
||||
|
Loading…
x
Reference in New Issue
Block a user