Splits StatelessRandomGetKeyCounter into two ops because XLA propagates a compile-time-constant requirement (originating from V2 stateless RNG ops) on output alg
to input seed
, which prevents seed
from depending on a variable. With the split ops, alg
no longer depends on seed
.
The resulting ops are also more conceptually correct, as the output actually depends on the input, and no needless dependencies are introduced. PiperOrigin-RevId: 342688742 Change-Id: I94e6e99dbfd5e614367710e56883be08080a49a2
This commit is contained in:
parent
eebe967339
commit
7c77ee880e
@ -1990,6 +1990,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
|||||||
"StatelessCase",
|
"StatelessCase",
|
||||||
"StatelessIf",
|
"StatelessIf",
|
||||||
"StatelessMultinomial",
|
"StatelessMultinomial",
|
||||||
|
"StatelessRandomGetAlg",
|
||||||
|
"StatelessRandomGetKeyCounter",
|
||||||
"StatelessRandomGetKeyCounterAlg",
|
"StatelessRandomGetKeyCounterAlg",
|
||||||
"StatelessRandomNormal",
|
"StatelessRandomNormal",
|
||||||
"StatelessRandomNormalV2",
|
"StatelessRandomNormalV2",
|
||||||
|
@ -12302,6 +12302,41 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
|
|||||||
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
|
||||||
|
let summary = [{
|
||||||
|
Picks the best counter-based RNG algorithm based on device.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
This op picks the best counter-based RNG algorithm based on device.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_Int32Tensor:$alg
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_StatelessRandomGetKeyCounterOp : TF_Op<"StatelessRandomGetKeyCounter", []> {
|
||||||
|
let summary = [{
|
||||||
|
Scrambles seed into key and counter, using the best algorithm based on device.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
This op scrambles a shape-[2] seed into a key and a counter, both needed by counter-based RNG algorithms. The scrambing uses the best algorithm based on device. The scrambling is opaque but approximately satisfies the property that different seed results in different key/counter pair (which will in turn result in different random numbers).
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_I32OrI64Tensor:$seed
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_Uint64Tensor:$key,
|
||||||
|
TF_Uint64Tensor:$counter
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_StatelessRandomGetKeyCounterAlgOp : TF_Op<"StatelessRandomGetKeyCounterAlg", []> {
|
def TF_StatelessRandomGetKeyCounterAlgOp : TF_Op<"StatelessRandomGetKeyCounterAlg", []> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Picks the best algorithm based on device, and scrambles seed into key and counter.
|
Picks the best algorithm based on device, and scrambles seed into key and counter.
|
||||||
|
@ -229,6 +229,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
|||||||
TypeID::get<TF::SqrtGradOp>(),
|
TypeID::get<TF::SqrtGradOp>(),
|
||||||
TypeID::get<TF::SquareOp>(),
|
TypeID::get<TF::SquareOp>(),
|
||||||
TypeID::get<TF::StatelessMultinomialOp>(),
|
TypeID::get<TF::StatelessMultinomialOp>(),
|
||||||
|
TypeID::get<TF::StatelessRandomGetAlgOp>(),
|
||||||
|
TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
|
||||||
TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
|
TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
|
||||||
TypeID::get<TF::StatelessRandomNormalOp>(),
|
TypeID::get<TF::StatelessRandomNormalOp>(),
|
||||||
TypeID::get<TF::StatelessRandomNormalV2Op>(),
|
TypeID::get<TF::StatelessRandomNormalV2Op>(),
|
||||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.compiler.tests import xla_test
|
from tensorflow.compiler.tests import xla_test
|
||||||
from tensorflow.python.compiler.xla import xla
|
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -30,6 +29,7 @@ from tensorflow.python.kernel_tests.random import util as \
|
|||||||
random_test_util
|
random_test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import stateless_random_ops as stateless
|
from tensorflow.python.ops import stateless_random_ops as stateless
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -51,18 +51,34 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
|||||||
This test checks that stateless_random_* can be used in forced-compilation
|
This test checks that stateless_random_* can be used in forced-compilation
|
||||||
scenarios (e.g. TPU). The new version of stateless_random_* requires the
|
scenarios (e.g. TPU). The new version of stateless_random_* requires the
|
||||||
intermediate tensor `alg` to be compile-time constant, so we need to check
|
intermediate tensor `alg` to be compile-time constant, so we need to check
|
||||||
that this requirement is met. We use xla.compile instead of tf.function's
|
that this requirement won't prevent `seed` from depending on variables.
|
||||||
jit_compile because the latter doesn't throw an error even if the
|
|
||||||
compile-time-constant constraint is not met.
|
|
||||||
"""
|
"""
|
||||||
if config.list_logical_devices('TPU'):
|
if config.list_logical_devices('TPU'):
|
||||||
self.skipTest('To accommodate OSS, xla.compile support for TPU is not '
|
self.skipTest('To accommodate OSS, experimental_compile support for TPU '
|
||||||
'linked in.')
|
'is not linked in.')
|
||||||
@def_function.function
|
# GPU doesn't support int32 variables, so we use int64.
|
||||||
def f(x):
|
v = variables.Variable([1, 2], dtype=dtypes.int64)
|
||||||
return xla.compile(
|
|
||||||
lambda x: stateless.stateless_random_normal([], seed=x), [x])
|
@def_function.function(experimental_compile=True)
|
||||||
f([1, 2])
|
def f():
|
||||||
|
key, counter = (
|
||||||
|
gen_stateless_random_ops_v2.stateless_random_get_key_counter(
|
||||||
|
seed=math_ops.cast(v.read_value(), dtypes.int32)))
|
||||||
|
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
|
||||||
|
return gen_stateless_random_ops_v2.stateless_random_normal_v2(
|
||||||
|
shape=[], key=key, counter=counter, alg=alg)
|
||||||
|
|
||||||
|
f()
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testGetKeyCounterAlg(self):
|
||||||
|
seed = [1, 2]
|
||||||
|
key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
|
||||||
|
seed)
|
||||||
|
self.assertAllEqual(key.shape, [1])
|
||||||
|
self.assertAllEqual(counter.shape, [2])
|
||||||
|
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
|
||||||
|
self.assertAllEqual(alg.shape, [])
|
||||||
|
|
||||||
def testLargeNormal(self):
|
def testLargeNormal(self):
|
||||||
"""Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
|
"""Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
|
||||||
@ -74,7 +90,15 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
|||||||
shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
|
shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
|
||||||
alg=alg)
|
alg=alg)
|
||||||
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
|
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
|
||||||
self.assertAllEqual([1024, 32000], y.shape)
|
self.assertAllEqual([1024, 32000], y.shape)
|
||||||
|
key, counter = (gen_stateless_random_ops_v2.
|
||||||
|
stateless_random_get_key_counter(seed_t))
|
||||||
|
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
|
||||||
|
x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
|
||||||
|
shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
|
||||||
|
alg=alg)
|
||||||
|
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
|
||||||
|
self.assertAllEqual([1024, 32000], y.shape)
|
||||||
|
|
||||||
def testDeterminism(self):
|
def testDeterminism(self):
|
||||||
# Stateless values should be equal iff the seeds are equal (roughly)
|
# Stateless values should be equal iff the seeds are equal (roughly)
|
||||||
|
@ -79,20 +79,30 @@ xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
|
|||||||
/*state=*/new_counter};
|
/*state=*/new_counter};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
|
std::tuple<xla::XlaOp, xla::XlaOp> GetKeyCounter(
|
||||||
absl::string_view device_type_string, xla::XlaOp key) {
|
absl::string_view device_type_string, xla::XlaOp key) {
|
||||||
// The Philox algorithm may cause performance regression on other devices.
|
// The Philox algorithm may cause performance regression on other devices.
|
||||||
// Turn on the Philox algorithm for the CPU and GPU backends only.
|
// Turn on the Philox algorithm for the CPU and GPU backends only.
|
||||||
if (device_type_string == DEVICE_GPU_XLA_JIT ||
|
if (device_type_string == DEVICE_GPU_XLA_JIT ||
|
||||||
device_type_string == DEVICE_CPU_XLA_JIT) {
|
device_type_string == DEVICE_CPU_XLA_JIT) {
|
||||||
auto counter_key = xla::ScramblePhiloxKey(key);
|
auto counter_key = xla::ScramblePhiloxKey(key);
|
||||||
return std::make_tuple(counter_key.second, counter_key.first,
|
return std::make_tuple(counter_key.second, counter_key.first);
|
||||||
RNG_ALG_PHILOX);
|
|
||||||
} else {
|
} else {
|
||||||
auto counter_shape =
|
auto counter_shape =
|
||||||
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
|
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
|
||||||
auto counter = xla::Zeros(key.builder(), counter_shape);
|
auto counter = xla::Zeros(key.builder(), counter_shape);
|
||||||
return std::make_tuple(key, counter, RNG_ALG_XLA_DEFAULT);
|
return std::make_tuple(key, counter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Algorithm GetAlg(absl::string_view device_type_string) {
|
||||||
|
// The Philox algorithm may cause performance regression on other devices.
|
||||||
|
// Turn on the Philox algorithm for the CPU and GPU backends only.
|
||||||
|
if (device_type_string == DEVICE_GPU_XLA_JIT ||
|
||||||
|
device_type_string == DEVICE_CPU_XLA_JIT) {
|
||||||
|
return RNG_ALG_PHILOX;
|
||||||
|
} else {
|
||||||
|
return RNG_ALG_XLA_DEFAULT;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -471,10 +481,10 @@ class GetKeyCounterAlgOp : public XlaOpKernel {
|
|||||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||||
ConstantR0WithType(builder, xla::U64, 32));
|
ConstantR0WithType(builder, xla::U64, 32));
|
||||||
auto key_counter_alg = GetKeyCounterAlg(device_type_string_, key);
|
auto key_counter = GetKeyCounter(device_type_string_, key);
|
||||||
key = std::get<0>(key_counter_alg);
|
key = std::get<0>(key_counter);
|
||||||
auto counter = std::get<1>(key_counter_alg);
|
auto counter = std::get<1>(key_counter);
|
||||||
auto alg = std::get<2>(key_counter_alg);
|
auto alg = GetAlg(device_type_string_);
|
||||||
key = xla::Reshape(key, {RNG_KEY_SIZE});
|
key = xla::Reshape(key, {RNG_KEY_SIZE});
|
||||||
ctx->SetOutput(0, key);
|
ctx->SetOutput(0, key);
|
||||||
ctx->SetOutput(1, counter);
|
ctx->SetOutput(1, counter);
|
||||||
@ -489,5 +499,60 @@ class GetKeyCounterAlgOp : public XlaOpKernel {
|
|||||||
|
|
||||||
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);
|
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);
|
||||||
|
|
||||||
|
class GetKeyCounterOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit GetKeyCounterOp(OpKernelConstruction* ctx)
|
||||||
|
: XlaOpKernel(ctx),
|
||||||
|
device_type_string_(ctx->device_type().type_string()) {}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
TensorShape seed_shape = ctx->InputShape(0);
|
||||||
|
OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
|
||||||
|
errors::InvalidArgument("seed must have shape [2], not ",
|
||||||
|
seed_shape.DebugString()));
|
||||||
|
xla::XlaOp seed = ctx->Input(0);
|
||||||
|
|
||||||
|
xla::XlaBuilder* builder = seed.builder();
|
||||||
|
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||||
|
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||||
|
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||||
|
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||||
|
ConstantR0WithType(builder, xla::U64, 32));
|
||||||
|
auto key_counter = GetKeyCounter(device_type_string_, key);
|
||||||
|
key = std::get<0>(key_counter);
|
||||||
|
auto counter = std::get<1>(key_counter);
|
||||||
|
key = xla::Reshape(key, {RNG_KEY_SIZE});
|
||||||
|
ctx->SetOutput(0, key);
|
||||||
|
ctx->SetOutput(1, counter);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
string device_type_string_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp);
|
||||||
|
|
||||||
|
class GetAlgOp : public XlaOpKernel {
|
||||||
|
public:
|
||||||
|
explicit GetAlgOp(OpKernelConstruction* ctx)
|
||||||
|
: XlaOpKernel(ctx),
|
||||||
|
device_type_string_(ctx->device_type().type_string()) {}
|
||||||
|
|
||||||
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
auto alg = GetAlg(device_type_string_);
|
||||||
|
auto builder = ctx->builder();
|
||||||
|
ctx->SetOutput(0, ConstantR0(builder, static_cast<int>(alg)));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
string device_type_string_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(GetAlgOp);
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_XLA_OP(Name("StatelessRandomGetAlg"), GetAlgOp);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "StatelessRandomGetAlg"
|
||||||
|
visibility: HIDDEN
|
||||||
|
out_arg {
|
||||||
|
name: "alg"
|
||||||
|
description: <<END
|
||||||
|
The RNG algorithm (shape int32[]).
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Picks the best counter-based RNG algorithm based on device."
|
||||||
|
description: <<END
|
||||||
|
This op picks the best counter-based RNG algorithm based on device.
|
||||||
|
END
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "StatelessRandomGetKeyCounter"
|
||||||
|
visibility: HIDDEN
|
||||||
|
in_arg {
|
||||||
|
name: "seed"
|
||||||
|
description: <<END
|
||||||
|
2 seeds (shape [2]).
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "key"
|
||||||
|
description: <<END
|
||||||
|
Key for the counter-based RNG algorithm (shape uint64[1]).
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "counter"
|
||||||
|
description: <<END
|
||||||
|
Counter for the counter-based RNG algorithm. Since counter size is algorithm-dependent, this output will be right-padded with zeros to reach shape uint64[2] (the current maximal counter size among algorithms).
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Scrambles seed into key and counter, using the best algorithm based on device."
|
||||||
|
description: <<END
|
||||||
|
This op scrambles a shape-[2] seed into a key and a counter, both needed by counter-based RNG algorithms. The scrambing uses the best algorithm based on device. The scrambling is opaque but approximately satisfies the property that different seed results in different key/counter pair (which will in turn result in different random numbers).
|
||||||
|
END
|
||||||
|
}
|
@ -225,6 +225,43 @@ class GetKeyCounterAlgOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class GetKeyCounterOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit GetKeyCounterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
const Tensor& seed_t = ctx->input(0);
|
||||||
|
OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
|
||||||
|
errors::InvalidArgument("seed must have shape [2], not ",
|
||||||
|
seed_t.shape().DebugString()));
|
||||||
|
// Allocate outputs
|
||||||
|
Tensor* key_output;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output));
|
||||||
|
Tensor* counter_output;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}),
|
||||||
|
&counter_output));
|
||||||
|
|
||||||
|
random::PhiloxRandom::Key key;
|
||||||
|
random::PhiloxRandom::ResultType counter;
|
||||||
|
OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
|
||||||
|
WriteKeyToMem(key, key_output->flat<uint64>().data());
|
||||||
|
WriteCounterToMem(counter, counter_output->flat<uint64>().data());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class GetAlgOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit GetAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
Tensor* alg_output;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &alg_output));
|
||||||
|
alg_output->flat<int>()(0) = RNG_ALG_PHILOX;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#define REGISTER(DEVICE, TYPE) \
|
#define REGISTER(DEVICE, TYPE) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("StatelessRandomUniformV2") \
|
Name("StatelessRandomUniformV2") \
|
||||||
@ -289,14 +326,23 @@ TF_CALL_int64(REGISTER_INT_CPU);
|
|||||||
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
|
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
|
||||||
TF_CALL_uint64(REGISTER_FULL_INT_CPU);
|
TF_CALL_uint64(REGISTER_FULL_INT_CPU);
|
||||||
|
|
||||||
#define REGISTER_GET_KCA(DEVICE) \
|
#define REGISTER_GET_KCA(DEVICE) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
|
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
|
||||||
.Device(DEVICE_##DEVICE) \
|
.Device(DEVICE_##DEVICE) \
|
||||||
.HostMemory("seed") \
|
.HostMemory("seed") \
|
||||||
.HostMemory("key") \
|
.HostMemory("key") \
|
||||||
.HostMemory("counter") \
|
.HostMemory("counter") \
|
||||||
.HostMemory("alg"), \
|
.HostMemory("alg"), \
|
||||||
GetKeyCounterAlgOp)
|
GetKeyCounterAlgOp) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounter") \
|
||||||
|
.Device(DEVICE_##DEVICE) \
|
||||||
|
.HostMemory("seed") \
|
||||||
|
.HostMemory("key") \
|
||||||
|
.HostMemory("counter"), \
|
||||||
|
GetKeyCounterOp) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("StatelessRandomGetAlg").Device(DEVICE_##DEVICE).HostMemory("alg"), \
|
||||||
|
GetAlgOp)
|
||||||
|
|
||||||
REGISTER_GET_KCA(CPU);
|
REGISTER_GET_KCA(CPU);
|
||||||
|
|
||||||
|
@ -115,4 +115,29 @@ REGISTER_OP("StatelessRandomGetKeyCounterAlg")
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("StatelessRandomGetKeyCounter")
|
||||||
|
.Input("seed: Tseed")
|
||||||
|
.Output("key: uint64")
|
||||||
|
.Output("counter: uint64")
|
||||||
|
.Attr("Tseed: {int32, int64} = DT_INT64")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
// Check seed shape
|
||||||
|
ShapeHandle seed;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed));
|
||||||
|
DimensionHandle unused;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
|
||||||
|
|
||||||
|
// Set output shapes
|
||||||
|
c->set_output(0, c->MakeShape({RNG_KEY_SIZE}));
|
||||||
|
c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE}));
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("StatelessRandomGetAlg")
|
||||||
|
.Output("alg: int32")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
c->set_output(0, c->MakeShape({}));
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import stateless_random_ops as stateless
|
from tensorflow.python.ops import stateless_random_ops as stateless
|
||||||
@ -412,6 +413,16 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.skipTest('Lacking XLA kernel')
|
self.skipTest('Lacking XLA kernel')
|
||||||
self._test_determinism(case, seed_type)
|
self._test_determinism(case, seed_type)
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testGetKeyCounterAlg(self):
|
||||||
|
seed = [1, 2]
|
||||||
|
key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
|
||||||
|
seed)
|
||||||
|
self.assertAllEqual(key.shape, [1])
|
||||||
|
self.assertAllEqual(counter.shape, [2])
|
||||||
|
alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
|
||||||
|
self.assertAllEqual(alg.shape, [])
|
||||||
|
|
||||||
def assertDTypeEqual(self, a, b):
|
def assertDTypeEqual(self, a, b):
|
||||||
self.assertEqual(dtypes.as_dtype(a), dtypes.as_dtype(b))
|
self.assertEqual(dtypes.as_dtype(a), dtypes.as_dtype(b))
|
||||||
|
|
||||||
|
@ -4568,6 +4568,14 @@ tf_module {
|
|||||||
name: "StatelessRandomGammaV2"
|
name: "StatelessRandomGammaV2"
|
||||||
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "StatelessRandomGetAlg"
|
||||||
|
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "StatelessRandomGetKeyCounter"
|
||||||
|
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "StatelessRandomGetKeyCounterAlg"
|
name: "StatelessRandomGetKeyCounterAlg"
|
||||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -4568,6 +4568,14 @@ tf_module {
|
|||||||
name: "StatelessRandomGammaV2"
|
name: "StatelessRandomGammaV2"
|
||||||
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "StatelessRandomGetAlg"
|
||||||
|
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "StatelessRandomGetKeyCounter"
|
||||||
|
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "StatelessRandomGetKeyCounterAlg"
|
name: "StatelessRandomGetKeyCounterAlg"
|
||||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user