Splits StatelessRandomGetKeyCounter into two ops because XLA propagates a compile-time-constant requirement (originating from V2 stateless RNG ops) on output alg to input seed, which prevents seed from depending on a variable. With the split ops, alg no longer depends on seed.

The resulting ops are also more conceptually correct, as the output actually depends on the input, and no needless dependencies are introduced.

PiperOrigin-RevId: 342688742
Change-Id: I94e6e99dbfd5e614367710e56883be08080a49a2
This commit is contained in:
Peng Wang 2020-11-16 12:02:56 -08:00 committed by TensorFlower Gardener
parent eebe967339
commit 7c77ee880e
12 changed files with 294 additions and 28 deletions

View File

@ -1990,6 +1990,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"StatelessCase",
"StatelessIf",
"StatelessMultinomial",
"StatelessRandomGetAlg",
"StatelessRandomGetKeyCounter",
"StatelessRandomGetKeyCounterAlg",
"StatelessRandomNormal",
"StatelessRandomNormalV2",

View File

@ -12302,6 +12302,41 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
}
def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
let summary = [{
Picks the best counter-based RNG algorithm based on device.
}];
let description = [{
This op picks the best counter-based RNG algorithm based on device.
}];
let results = (outs
TF_Int32Tensor:$alg
);
}
def TF_StatelessRandomGetKeyCounterOp : TF_Op<"StatelessRandomGetKeyCounter", []> {
let summary = [{
Scrambles seed into key and counter, using the best algorithm based on device.
}];
let description = [{
This op scrambles a shape-[2] seed into a key and a counter, both needed by counter-based RNG algorithms. The scrambing uses the best algorithm based on device. The scrambling is opaque but approximately satisfies the property that different seed results in different key/counter pair (which will in turn result in different random numbers).
}];
let arguments = (ins
TF_I32OrI64Tensor:$seed
);
let results = (outs
TF_Uint64Tensor:$key,
TF_Uint64Tensor:$counter
);
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<0>;
}
def TF_StatelessRandomGetKeyCounterAlgOp : TF_Op<"StatelessRandomGetKeyCounterAlg", []> {
let summary = [{
Picks the best algorithm based on device, and scrambles seed into key and counter.

View File

@ -229,6 +229,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::SqrtGradOp>(),
TypeID::get<TF::SquareOp>(),
TypeID::get<TF::StatelessMultinomialOp>(),
TypeID::get<TF::StatelessRandomGetAlgOp>(),
TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
TypeID::get<TF::StatelessRandomNormalOp>(),
TypeID::get<TF::StatelessRandomNormalV2Op>(),

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
@ -30,6 +29,7 @@ from tensorflow.python.kernel_tests.random import util as \
random_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -51,18 +51,34 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
This test checks that stateless_random_* can be used in forced-compilation
scenarios (e.g. TPU). The new version of stateless_random_* requires the
intermediate tensor `alg` to be compile-time constant, so we need to check
that this requirement is met. We use xla.compile instead of tf.function's
jit_compile because the latter doesn't throw an error even if the
compile-time-constant constraint is not met.
that this requirement won't prevent `seed` from depending on variables.
"""
if config.list_logical_devices('TPU'):
self.skipTest('To accommodate OSS, xla.compile support for TPU is not '
'linked in.')
@def_function.function
def f(x):
return xla.compile(
lambda x: stateless.stateless_random_normal([], seed=x), [x])
f([1, 2])
self.skipTest('To accommodate OSS, experimental_compile support for TPU '
'is not linked in.')
# GPU doesn't support int32 variables, so we use int64.
v = variables.Variable([1, 2], dtype=dtypes.int64)
@def_function.function(experimental_compile=True)
def f():
key, counter = (
gen_stateless_random_ops_v2.stateless_random_get_key_counter(
seed=math_ops.cast(v.read_value(), dtypes.int32)))
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
return gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape=[], key=key, counter=counter, alg=alg)
f()
@test_util.run_v2_only
def testGetKeyCounterAlg(self):
seed = [1, 2]
key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
seed)
self.assertAllEqual(key.shape, [1])
self.assertAllEqual(counter.shape, [2])
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
self.assertAllEqual(alg.shape, [])
def testLargeNormal(self):
"""Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
@ -74,7 +90,15 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
alg=alg)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
self.assertAllEqual([1024, 32000], y.shape)
self.assertAllEqual([1024, 32000], y.shape)
key, counter = (gen_stateless_random_ops_v2.
stateless_random_get_key_counter(seed_t))
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
alg=alg)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
self.assertAllEqual([1024, 32000], y.shape)
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)

View File

@ -79,20 +79,30 @@ xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
/*state=*/new_counter};
}
std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
std::tuple<xla::XlaOp, xla::XlaOp> GetKeyCounter(
absl::string_view device_type_string, xla::XlaOp key) {
// The Philox algorithm may cause performance regression on other devices.
// Turn on the Philox algorithm for the CPU and GPU backends only.
if (device_type_string == DEVICE_GPU_XLA_JIT ||
device_type_string == DEVICE_CPU_XLA_JIT) {
auto counter_key = xla::ScramblePhiloxKey(key);
return std::make_tuple(counter_key.second, counter_key.first,
RNG_ALG_PHILOX);
return std::make_tuple(counter_key.second, counter_key.first);
} else {
auto counter_shape =
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
auto counter = xla::Zeros(key.builder(), counter_shape);
return std::make_tuple(key, counter, RNG_ALG_XLA_DEFAULT);
return std::make_tuple(key, counter);
}
}
Algorithm GetAlg(absl::string_view device_type_string) {
// The Philox algorithm may cause performance regression on other devices.
// Turn on the Philox algorithm for the CPU and GPU backends only.
if (device_type_string == DEVICE_GPU_XLA_JIT ||
device_type_string == DEVICE_CPU_XLA_JIT) {
return RNG_ALG_PHILOX;
} else {
return RNG_ALG_XLA_DEFAULT;
}
}
@ -471,10 +481,10 @@ class GetKeyCounterAlgOp : public XlaOpKernel {
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
auto key_counter_alg = GetKeyCounterAlg(device_type_string_, key);
key = std::get<0>(key_counter_alg);
auto counter = std::get<1>(key_counter_alg);
auto alg = std::get<2>(key_counter_alg);
auto key_counter = GetKeyCounter(device_type_string_, key);
key = std::get<0>(key_counter);
auto counter = std::get<1>(key_counter);
auto alg = GetAlg(device_type_string_);
key = xla::Reshape(key, {RNG_KEY_SIZE});
ctx->SetOutput(0, key);
ctx->SetOutput(1, counter);
@ -489,5 +499,60 @@ class GetKeyCounterAlgOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);
class GetKeyCounterOp : public XlaOpKernel {
public:
explicit GetKeyCounterOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape seed_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
xla::XlaOp seed = ctx->Input(0);
xla::XlaBuilder* builder = seed.builder();
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
auto key_counter = GetKeyCounter(device_type_string_, key);
key = std::get<0>(key_counter);
auto counter = std::get<1>(key_counter);
key = xla::Reshape(key, {RNG_KEY_SIZE});
ctx->SetOutput(0, key);
ctx->SetOutput(1, counter);
}
private:
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp);
};
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp);
class GetAlgOp : public XlaOpKernel {
public:
explicit GetAlgOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {}
void Compile(XlaOpKernelContext* ctx) override {
auto alg = GetAlg(device_type_string_);
auto builder = ctx->builder();
ctx->SetOutput(0, ConstantR0(builder, static_cast<int>(alg)));
}
private:
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(GetAlgOp);
};
REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,14 @@
op {
graph_op_name: "StatelessRandomGetAlg"
visibility: HIDDEN
out_arg {
name: "alg"
description: <<END
The RNG algorithm (shape int32[]).
END
}
summary: "Picks the best counter-based RNG algorithm based on device."
description: <<END
This op picks the best counter-based RNG algorithm based on device.
END
}

View File

@ -0,0 +1,26 @@
op {
graph_op_name: "StatelessRandomGetKeyCounter"
visibility: HIDDEN
in_arg {
name: "seed"
description: <<END
2 seeds (shape [2]).
END
}
out_arg {
name: "key"
description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
}
out_arg {
name: "counter"
description: <<END
Counter for the counter-based RNG algorithm. Since counter size is algorithm-dependent, this output will be right-padded with zeros to reach shape uint64[2] (the current maximal counter size among algorithms).
END
}
summary: "Scrambles seed into key and counter, using the best algorithm based on device."
description: <<END
This op scrambles a shape-[2] seed into a key and a counter, both needed by counter-based RNG algorithms. The scrambing uses the best algorithm based on device. The scrambling is opaque but approximately satisfies the property that different seed results in different key/counter pair (which will in turn result in different random numbers).
END
}

View File

@ -225,6 +225,43 @@ class GetKeyCounterAlgOp : public OpKernel {
}
};
class GetKeyCounterOp : public OpKernel {
public:
explicit GetKeyCounterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
const Tensor& seed_t = ctx->input(0);
OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_t.shape().DebugString()));
// Allocate outputs
Tensor* key_output;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output));
Tensor* counter_output;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}),
&counter_output));
random::PhiloxRandom::Key key;
random::PhiloxRandom::ResultType counter;
OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
WriteKeyToMem(key, key_output->flat<uint64>().data());
WriteCounterToMem(counter, counter_output->flat<uint64>().data());
}
};
class GetAlgOp : public OpKernel {
public:
explicit GetAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
Tensor* alg_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &alg_output));
alg_output->flat<int>()(0) = RNG_ALG_PHILOX;
}
};
#define REGISTER(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniformV2") \
@ -289,14 +326,23 @@ TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);
#define REGISTER_GET_KCA(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
.Device(DEVICE_##DEVICE) \
.HostMemory("seed") \
.HostMemory("key") \
.HostMemory("counter") \
.HostMemory("alg"), \
GetKeyCounterAlgOp)
#define REGISTER_GET_KCA(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
.Device(DEVICE_##DEVICE) \
.HostMemory("seed") \
.HostMemory("key") \
.HostMemory("counter") \
.HostMemory("alg"), \
GetKeyCounterAlgOp) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounter") \
.Device(DEVICE_##DEVICE) \
.HostMemory("seed") \
.HostMemory("key") \
.HostMemory("counter"), \
GetKeyCounterOp) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomGetAlg").Device(DEVICE_##DEVICE).HostMemory("alg"), \
GetAlgOp)
REGISTER_GET_KCA(CPU);

View File

@ -115,4 +115,29 @@ REGISTER_OP("StatelessRandomGetKeyCounterAlg")
return Status::OK();
});
REGISTER_OP("StatelessRandomGetKeyCounter")
.Input("seed: Tseed")
.Output("key: uint64")
.Output("counter: uint64")
.Attr("Tseed: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
// Set output shapes
c->set_output(0, c->MakeShape({RNG_KEY_SIZE}));
c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE}));
return Status::OK();
});
REGISTER_OP("StatelessRandomGetAlg")
.Output("alg: int32")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->MakeShape({}));
return Status::OK();
});
} // namespace tensorflow

View File

@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops as stateless
@ -412,6 +413,16 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
self.skipTest('Lacking XLA kernel')
self._test_determinism(case, seed_type)
@test_util.run_v2_only
def testGetKeyCounterAlg(self):
seed = [1, 2]
key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
seed)
self.assertAllEqual(key.shape, [1])
self.assertAllEqual(counter.shape, [2])
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
self.assertAllEqual(alg.shape, [])
def assertDTypeEqual(self, a, b):
self.assertEqual(dtypes.as_dtype(a), dtypes.as_dtype(b))

View File

@ -4568,6 +4568,14 @@ tf_module {
name: "StatelessRandomGammaV2"
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomGetAlg"
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomGetKeyCounter"
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomGetKeyCounterAlg"
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -4568,6 +4568,14 @@ tf_module {
name: "StatelessRandomGammaV2"
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomGetAlg"
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomGetKeyCounter"
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomGetKeyCounterAlg"
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "