Adds a new set of stateless RNG ops, and rebases existing stateless-RNG Python API and tf.random.Generator onto them.
The new ops have three differences from existing (old) stateless RNG ops: * They take in `key` and `counter` instead of `seed` (thus no seed scrambling). * They take in an `alg` argument to control which RNG algorithm to use, unlike the old ones which pick algorithm based on device. * They don't have `HostMemory` constraints on `key` and `counter` (the old ones have such constraints on `seed`). Two new ops `StatelessRandomGetKeyCounterAlg` and `RngReadAndSkip` are also added to bridge the gaps between the new stateless ops and the Python API for stateless RNGs and tf.random.Generator, so that the Python API's behavior doesn't change. Also adds set_soft_device_placement(False) to tests to control which kernels are tested. PiperOrigin-RevId: 332346574 Change-Id: Ibe0e41cccce82e50b5581ea6298218efb163157a
This commit is contained in:
parent
2fc3df3e6f
commit
e922e10a0f
@ -1995,6 +1995,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"ResourceScatterNdUpdate",
|
||||
"ResourceScatterSub",
|
||||
"ResourceScatterUpdate",
|
||||
"RngReadAndSkip",
|
||||
"RngSkip",
|
||||
"Roll",
|
||||
"ScatterNd",
|
||||
"SelfAdjointEigV2",
|
||||
@ -2017,11 +2019,17 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
"StatelessCase",
|
||||
"StatelessIf",
|
||||
"StatelessMultinomial",
|
||||
"StatelessRandomGetKeyCounterAlg",
|
||||
"StatelessRandomNormal",
|
||||
"StatelessRandomNormalV2",
|
||||
"StatelessRandomUniform",
|
||||
"StatelessRandomUniformV2",
|
||||
"StatelessRandomUniformInt",
|
||||
"StatelessRandomUniformIntV2",
|
||||
"StatelessRandomUniformFullInt",
|
||||
"StatelessRandomUniformFullIntV2",
|
||||
"StatelessTruncatedNormal",
|
||||
"StatelessTruncatedNormalV2",
|
||||
"StatelessWhile",
|
||||
"Svd",
|
||||
"SymbolicGradient",
|
||||
|
@ -25,7 +25,9 @@ import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
@ -156,6 +158,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
def testNewStateThreeFry(self):
|
||||
"""Tests that the new state is correct (for ThreeFry).
|
||||
"""
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
self.skipTest("The expected values in this test is inconsistent with "
|
||||
"CPU/GPU. testXLAEqualsCPU has the correct checks of the "
|
||||
"new states for the new version.")
|
||||
with ops.device(xla_device_name()):
|
||||
counter = 57
|
||||
key = 0x1234
|
||||
@ -171,6 +177,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
def testNewStatePhilox(self):
|
||||
"""Tests that the new state is correct (for Philox).
|
||||
"""
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
self.skipTest("The expected values in this test is inconsistent with "
|
||||
"CPU/GPU. testXLAEqualsCPU has the correct checks of the "
|
||||
"new states for the new version.")
|
||||
with ops.device(xla_device_name()):
|
||||
counter_low = 57
|
||||
counter_high = 283
|
||||
@ -204,13 +214,39 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
"""Tests that XLA and CPU kernels generate the same integers."""
|
||||
seed = 1234
|
||||
shape = [315, 49]
|
||||
with ops.device("/device:CPU:0"):
|
||||
cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
|
||||
.uniform_full_int(shape=shape, dtype=dtype))
|
||||
with ops.device(xla_device_name()):
|
||||
xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
|
||||
.uniform_full_int(shape=shape, dtype=dtype))
|
||||
self.assertAllEqual(cpu, xla)
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
with ops.device("/device:CPU:0"):
|
||||
cpu_gen = random.Generator.from_seed(
|
||||
seed=seed, alg=random.RNG_ALG_PHILOX)
|
||||
with ops.device(xla_device_name()):
|
||||
xla_gen = random.Generator.from_seed(
|
||||
seed=seed, alg=random.RNG_ALG_PHILOX)
|
||||
# Repeat multiple times to make sure that the state after
|
||||
# number-generation are the same between CPU and XLA.
|
||||
for _ in range(5):
|
||||
with ops.device("/device:CPU:0"):
|
||||
# Test both number-generation and skip
|
||||
cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
cpu_gen.skip(100)
|
||||
with ops.device(xla_device_name()):
|
||||
xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
xla_gen.skip(100)
|
||||
self.assertAllEqual(cpu, xla)
|
||||
self.assertAllEqual(cpu_gen.state, xla_gen.state)
|
||||
else:
|
||||
# The old version doesn't guarantee that CPU and XLA are in the same state
|
||||
# after number-generation, which is a bug.
|
||||
with ops.device("/device:CPU:0"):
|
||||
cpu = (
|
||||
random.Generator.from_seed(
|
||||
seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
|
||||
shape=shape, dtype=dtype))
|
||||
with ops.device(xla_device_name()):
|
||||
xla = (
|
||||
random.Generator.from_seed(
|
||||
seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
|
||||
shape=shape, dtype=dtype))
|
||||
self.assertAllEqual(cpu, xla)
|
||||
|
||||
def _testRngIsNotConstant(self, rng, dtype):
|
||||
# Tests that 'rng' does not always return the same value.
|
||||
@ -364,4 +400,5 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
config.set_soft_device_placement(False)
|
||||
test.main()
|
||||
|
@ -21,7 +21,11 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.kernel_tests.random import util as \
|
||||
random_test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -39,6 +43,26 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
||||
allowed_types.update({dtypes.int32, dtypes.int64})
|
||||
return self.all_tf_types & allowed_types
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testForcedCompile(self):
|
||||
"""Tests whole-function forced-compilation.
|
||||
|
||||
This test checks that stateless_random_* can be used in forced-compilation
|
||||
scenarios (e.g. TPU). The new version of stateless_random_* requires the
|
||||
intermediate tensor `alg` to be compile-time constant, so we need to check
|
||||
that this requirement is met. We use xla.compile instead of tf.function's
|
||||
experimental_compile because the latter doesn't throw an error even if the
|
||||
compile-time-constant constraint is not met.
|
||||
"""
|
||||
if config.list_logical_devices('TPU'):
|
||||
self.skipTest('To accommodate OSS, xla.compile support for TPU is not '
|
||||
'linked in.')
|
||||
@def_function.function
|
||||
def f(x):
|
||||
return xla.compile(
|
||||
lambda x: stateless.stateless_random_normal([], seed=x), [x])
|
||||
f([1, 2])
|
||||
|
||||
def testDeterminism(self):
|
||||
# Stateless values should be equal iff the seeds are equal (roughly)
|
||||
with self.session(), self.test_scope():
|
||||
@ -138,7 +162,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark):
|
||||
|
||||
def _benchmarkUniform(self, name, dtype, use_xla_jit):
|
||||
|
||||
def BuilderFn():
|
||||
def builder_fn():
|
||||
shape = (10, 1000, 1000)
|
||||
seed_var = variables.Variable((312, 456),
|
||||
dtype=dtypes.int32,
|
||||
@ -147,7 +171,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark):
|
||||
shape, seed=seed_var, dtype=dtype)
|
||||
return '%s.shape%s' % (name, shape), [random_t]
|
||||
|
||||
xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu')
|
||||
xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu')
|
||||
|
||||
def benchmarkUniformF32(self):
|
||||
self._benchmarkUniform(
|
||||
@ -167,4 +191,5 @@ class StatelessRandomOpsBenchmark(test.Benchmark):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.set_soft_device_placement(False)
|
||||
test.main()
|
||||
|
@ -108,6 +108,7 @@ tf_kernel_library(
|
||||
"stack_ops.cc",
|
||||
"stateful_random_ops.cc",
|
||||
"stateless_random_ops.cc",
|
||||
"stateless_random_ops_v2.cc",
|
||||
"strided_slice_op.cc",
|
||||
"tensor_array_ops.cc",
|
||||
"tensor_list_ops.cc",
|
||||
@ -187,6 +188,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/kernels:stateful_random_ops_header",
|
||||
"//tensorflow/core/kernels:stateless_random_ops_v2_header",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/rng_alg.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
@ -180,7 +181,7 @@ Status CompileImpl(
|
||||
}
|
||||
xla::Literal alg_literal;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
|
||||
auto alg = alg_literal.Get<Algorithm>({});
|
||||
Algorithm alg = Algorithm(alg_literal.Get<int>({}));
|
||||
if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) {
|
||||
return errors::InvalidArgument("Unsupported algorithm id: ", alg);
|
||||
}
|
||||
@ -407,5 +408,80 @@ REGISTER_XLA_OP(Name("StatefulUniformFullInt")
|
||||
{DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
|
||||
StatefulUniformFullIntOp);
|
||||
|
||||
xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter,
|
||||
xla::XlaOp delta) {
|
||||
// Multiplying 256 to be consistent with the CPU/GPU kernels
|
||||
delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256);
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
return xla::PhiloxIncreaseCounter(counter, delta);
|
||||
} else {
|
||||
return counter + delta;
|
||||
}
|
||||
}
|
||||
|
||||
xla::XlaOp PadRight(xla::XlaOp a, int n) {
|
||||
return xla::Pad(a, xla::ScalarLike(a, 0),
|
||||
xla::MakeEdgePaddingConfig({{0, n}}));
|
||||
}
|
||||
|
||||
template <typename AlgEnumType = int64, bool read_old_value = false>
|
||||
class RngSkipOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const int state_input_idx = 0;
|
||||
const int alg_input_idx = 1;
|
||||
const int delta_input_idx = 2;
|
||||
xla::XlaOp var;
|
||||
TensorShape var_shape;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE,
|
||||
&var_shape, &var));
|
||||
xla::Literal alg_literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal));
|
||||
Algorithm alg = Algorithm(alg_literal.Get<AlgEnumType>({}));
|
||||
OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX,
|
||||
errors::InvalidArgument("Unsupported algorithm id: ", alg));
|
||||
OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape));
|
||||
if (read_old_value) {
|
||||
auto counter_size = GetCounterSize(alg);
|
||||
xla::XlaOp output = var;
|
||||
if (RNG_MAX_COUNTER_SIZE > counter_size) {
|
||||
// Because the size of `var` depends on the algorithm while we want the
|
||||
// output to have a fixed size (to help shape inference), we fix the
|
||||
// output size to be the maximal state size among algorithms, and right-
|
||||
// pad it with zeros if var's size is smaller than that.
|
||||
output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size);
|
||||
}
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
xla::XlaOp counter;
|
||||
xla::XlaOp key;
|
||||
std::tie(counter, key) = StateAndKeyFromVariable(alg, var);
|
||||
xla::XlaOp delta = ctx->Input(delta_input_idx);
|
||||
delta = BitcastConvertType(delta, xla::U64);
|
||||
auto new_counter = IncreaseCounter(alg, counter, delta);
|
||||
var = StateAndKeyToVariable(alg, new_counter, key);
|
||||
xla::PrimitiveType state_element_type;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
|
||||
var = BitcastConvertType(var, state_element_type);
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"),
|
||||
RngSkipOp<>);
|
||||
|
||||
using RngReadAndSkipOp = RngSkipOp<int32, true>;
|
||||
|
||||
REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"),
|
||||
RngReadAndSkipOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -111,6 +111,8 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
|
||||
xla::XlaOp seeds,
|
||||
const xla::Shape& shape) {
|
||||
@ -140,8 +142,6 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class StatelessRandomUniformOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
|
||||
|
485
tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
Normal file
485
tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc
Normal file
@ -0,0 +1,485 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/stateless_random_ops_v2.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/random.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/prng.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/rng_alg.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
return xla::RandomAlgorithm::RNG_PHILOX;
|
||||
}
|
||||
return xla::RandomAlgorithm::RNG_THREE_FRY;
|
||||
}
|
||||
|
||||
inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
|
||||
if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
|
||||
return RNG_ALG_PHILOX;
|
||||
}
|
||||
return RNG_ALG_THREEFRY;
|
||||
}
|
||||
|
||||
xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) {
|
||||
Algorithm alg_ = RandomAlgorithmToAlgorithm(alg);
|
||||
return xla::Slice(state, {RNG_KEY_SIZE},
|
||||
{RNG_KEY_SIZE + GetCounterSize(alg_)}, {1});
|
||||
}
|
||||
|
||||
xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
|
||||
xla::XlaOp counter, const xla::Shape& shape) {
|
||||
key = BitcastConvertType(key, xla::U64);
|
||||
counter = BitcastConvertType(counter, xla::U64);
|
||||
xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0);
|
||||
xla::XlaOp result = xla::RngBitGenerator(alg, state, shape);
|
||||
auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0));
|
||||
new_counter = BitcastConvertType(new_counter, xla::S64);
|
||||
return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
|
||||
/*state=*/new_counter};
|
||||
}
|
||||
|
||||
std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
|
||||
absl::string_view device_type_string, xla::XlaOp key) {
|
||||
// The Philox algorithm may cause performance regression on other devices.
|
||||
// Turn on the Philox algorithm for the CPU and GPU backends only.
|
||||
if (device_type_string == DEVICE_GPU_XLA_JIT ||
|
||||
device_type_string == DEVICE_CPU_XLA_JIT) {
|
||||
auto counter_key = xla::ScramblePhiloxKey(key);
|
||||
return std::make_tuple(counter_key.second, counter_key.first,
|
||||
RNG_ALG_PHILOX);
|
||||
} else {
|
||||
auto counter_shape =
|
||||
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
|
||||
auto counter = xla::Zeros(key.builder(), counter_shape);
|
||||
return std::make_tuple(key, counter, RNG_ALG_THREEFRY);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg,
|
||||
xla::XlaOp key, xla::XlaOp counter,
|
||||
const xla::Shape& shape, xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
xla::XlaBuilder* builder = key.builder();
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
using std::placeholders::_1;
|
||||
using std::placeholders::_2;
|
||||
using std::placeholders::_3;
|
||||
auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
|
||||
switch (type) {
|
||||
case xla::F32:
|
||||
case xla::F64:
|
||||
return xla::UniformFloatingPointDistribution(key, counter, generator,
|
||||
minval, maxval, shape);
|
||||
case xla::S32:
|
||||
case xla::S64:
|
||||
case xla::U32:
|
||||
case xla::U64:
|
||||
return UniformIntDistribution(key, counter, generator, minval, maxval,
|
||||
shape);
|
||||
break;
|
||||
default:
|
||||
return {builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, S32, S64, U32 and U64 are not "
|
||||
"implemented by "
|
||||
"StatelessRngUniformV2; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
counter};
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg,
|
||||
xla::XlaOp key, xla::XlaOp counter,
|
||||
const xla::Shape& shape) {
|
||||
xla::XlaBuilder* builder = key.builder();
|
||||
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
xla::RngOutput output = BitGenerator(alg, key, counter, shape);
|
||||
switch (type) {
|
||||
case xla::U32:
|
||||
case xla::U64:
|
||||
return output;
|
||||
case xla::S32:
|
||||
case xla::S64:
|
||||
return xla::RngOutput{BitcastConvertType(output.value, type),
|
||||
output.state};
|
||||
default:
|
||||
return {
|
||||
builder->ReportError(xla::Unimplemented(
|
||||
"Types other than U32, S32, U64 and S64 are not implemented by "
|
||||
"StatelessRngUniformFullInt; got: %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
output.state};
|
||||
}
|
||||
}
|
||||
|
||||
Status GetAlgorithm(XlaOpKernelContext* ctx, int alg_input_idx,
|
||||
xla::RandomAlgorithm* alg) {
|
||||
auto alg_shape = ctx->InputShape(alg_input_idx);
|
||||
if (alg_shape.dims() != 0) {
|
||||
return errors::InvalidArgument("algorithm must be of shape [], not ",
|
||||
alg_shape.DebugString());
|
||||
}
|
||||
xla::Literal alg_literal;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
|
||||
auto alg_ = Algorithm(alg_literal.Get<int>({}));
|
||||
*alg = AlgorithmToRandomAlgorithm(alg_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg,
|
||||
TensorShape const& counter_shape,
|
||||
xla::XlaOp counter) {
|
||||
auto input_counter_size = counter_shape.dim_size(0);
|
||||
auto real_counter_size = GetCounterSize(RandomAlgorithmToAlgorithm(alg));
|
||||
if (input_counter_size > real_counter_size) {
|
||||
counter = xla::Slice(counter, {0}, {real_counter_size}, {1});
|
||||
}
|
||||
return counter;
|
||||
}
|
||||
|
||||
class StatelessRandomUniformOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
|
||||
|
||||
const int key_input_idx = 1;
|
||||
const int counter_input_idx = 2;
|
||||
const int alg_input_idx = 3;
|
||||
xla::XlaOp key = ctx->Input(key_input_idx);
|
||||
xla::XlaOp counter = ctx->Input(counter_input_idx);
|
||||
|
||||
xla::RandomAlgorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
|
||||
|
||||
auto counter_shape = ctx->InputShape(counter_input_idx);
|
||||
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
|
||||
ctx->InputShape(key_input_idx),
|
||||
counter_shape));
|
||||
|
||||
DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
|
||||
xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
|
||||
|
||||
counter = MaybeSliceCounter(alg, counter_shape, counter);
|
||||
|
||||
auto result = StatelessRngUniformV2(
|
||||
alg, key, counter, xla_shape,
|
||||
xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
|
||||
xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
|
||||
auto uniform = MaybeConvertF32ToBF16(result.value, dtype_);
|
||||
ctx->SetOutput(0, uniform);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StatelessRandomUniformV2")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.CompileTimeConstantInput("alg")
|
||||
.TypeConstraint("dtype",
|
||||
{DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
|
||||
StatelessRandomUniformOp);
|
||||
|
||||
class StatelessRandomUniformIntOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
|
||||
|
||||
const int key_input_idx = 1;
|
||||
const int counter_input_idx = 2;
|
||||
const int alg_input_idx = 3;
|
||||
xla::XlaOp key = ctx->Input(key_input_idx);
|
||||
xla::XlaOp counter = ctx->Input(counter_input_idx);
|
||||
|
||||
xla::RandomAlgorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
|
||||
|
||||
auto counter_shape = ctx->InputShape(counter_input_idx);
|
||||
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
|
||||
ctx->InputShape(key_input_idx),
|
||||
counter_shape));
|
||||
|
||||
const int minval_input_idx = 4;
|
||||
const int maxval_input_idx = 5;
|
||||
TensorShape minval_shape = ctx->InputShape(minval_input_idx);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
|
||||
errors::InvalidArgument("minval must be scalar, got shape ",
|
||||
minval_shape.DebugString()));
|
||||
TensorShape maxval_shape = ctx->InputShape(maxval_input_idx);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
|
||||
errors::InvalidArgument("maxval must be scalar, got shape ",
|
||||
maxval_shape.DebugString()));
|
||||
|
||||
xla::XlaOp minval = ctx->Input(minval_input_idx);
|
||||
xla::XlaOp maxval = ctx->Input(maxval_input_idx);
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
|
||||
counter = MaybeSliceCounter(alg, counter_shape, counter);
|
||||
auto result =
|
||||
StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval);
|
||||
ctx->SetOutput(0, result.value);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.CompileTimeConstantInput("alg")
|
||||
.TypeConstraint("dtype",
|
||||
{DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
|
||||
StatelessRandomUniformIntOp);
|
||||
|
||||
class StatelessRandomUniformFullIntOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
|
||||
|
||||
const int key_input_idx = 1;
|
||||
const int counter_input_idx = 2;
|
||||
const int alg_input_idx = 3;
|
||||
xla::XlaOp key = ctx->Input(key_input_idx);
|
||||
xla::XlaOp counter = ctx->Input(counter_input_idx);
|
||||
|
||||
xla::RandomAlgorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
|
||||
|
||||
auto counter_shape = ctx->InputShape(counter_input_idx);
|
||||
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
|
||||
ctx->InputShape(key_input_idx),
|
||||
counter_shape));
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
|
||||
counter = MaybeSliceCounter(alg, counter_shape, counter);
|
||||
auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape);
|
||||
ctx->SetOutput(0, result.value);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.CompileTimeConstantInput("alg")
|
||||
.TypeConstraint("dtype",
|
||||
{DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
|
||||
StatelessRandomUniformFullIntOp);
|
||||
|
||||
class StatelessRandomNormalOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
|
||||
|
||||
const int key_input_idx = 1;
|
||||
const int counter_input_idx = 2;
|
||||
const int alg_input_idx = 3;
|
||||
xla::XlaOp key = ctx->Input(key_input_idx);
|
||||
xla::XlaOp counter = ctx->Input(counter_input_idx);
|
||||
|
||||
xla::RandomAlgorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
|
||||
|
||||
auto counter_shape = ctx->InputShape(counter_input_idx);
|
||||
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
|
||||
ctx->InputShape(key_input_idx),
|
||||
counter_shape));
|
||||
|
||||
DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
|
||||
|
||||
using std::placeholders::_1;
|
||||
using std::placeholders::_2;
|
||||
using std::placeholders::_3;
|
||||
auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
|
||||
counter = MaybeSliceCounter(alg, counter_shape, counter);
|
||||
auto result = xla::NormalFloatingPointDistribution(key, counter, generator,
|
||||
xla_shape);
|
||||
auto normal = MaybeConvertF32ToBF16(result.value, dtype_);
|
||||
ctx->SetOutput(0, normal);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StatelessRandomNormalV2")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.CompileTimeConstantInput("alg")
|
||||
.TypeConstraint("dtype",
|
||||
{DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
|
||||
StatelessRandomNormalOp);
|
||||
|
||||
class StatelessTruncatedNormalOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
|
||||
|
||||
const int key_input_idx = 1;
|
||||
const int counter_input_idx = 2;
|
||||
const int alg_input_idx = 3;
|
||||
xla::XlaOp key = ctx->Input(key_input_idx);
|
||||
xla::XlaOp counter = ctx->Input(counter_input_idx);
|
||||
|
||||
xla::RandomAlgorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
|
||||
|
||||
auto counter_shape = ctx->InputShape(counter_input_idx);
|
||||
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
|
||||
ctx->InputShape(key_input_idx),
|
||||
counter_shape));
|
||||
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
|
||||
DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
|
||||
|
||||
counter = MaybeSliceCounter(alg, counter_shape, counter);
|
||||
auto result = StatelessRngUniformV2(
|
||||
alg, key, counter, xla_shape,
|
||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||
xla::One(builder, xla_shape.element_type()));
|
||||
xla::XlaOp truncated_normal = TruncatedNormal(result.value);
|
||||
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
|
||||
ctx->SetOutput(0, truncated_normal);
|
||||
}
|
||||
|
||||
private:
|
||||
DataType dtype_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2")
|
||||
.CompileTimeConstantInput("shape")
|
||||
.CompileTimeConstantInput("alg")
|
||||
.TypeConstraint("dtype",
|
||||
{DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
|
||||
StatelessTruncatedNormalOp);
|
||||
|
||||
class GetKeyCounterAlgOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx),
|
||||
device_type_string_(ctx->device_type().type_string()) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
TensorShape seed_shape = ctx->InputShape(0);
|
||||
OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
|
||||
errors::InvalidArgument("seed must have shape [2], not ",
|
||||
seed_shape.DebugString()));
|
||||
xla::XlaOp seed = ctx->Input(0);
|
||||
|
||||
xla::XlaBuilder* builder = seed.builder();
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
auto key_counter_alg = GetKeyCounterAlg(device_type_string_, key);
|
||||
key = std::get<0>(key_counter_alg);
|
||||
auto counter = std::get<1>(key_counter_alg);
|
||||
auto alg = std::get<2>(key_counter_alg);
|
||||
key = xla::Reshape(key, {RNG_KEY_SIZE});
|
||||
ctx->SetOutput(0, key);
|
||||
ctx->SetOutput(1, counter);
|
||||
ctx->SetOutput(2, ConstantR0(builder, static_cast<int>(alg)));
|
||||
}
|
||||
|
||||
private:
|
||||
string device_type_string_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -83,6 +83,8 @@ CreateResourceOpInfoMap() {
|
||||
add("ResourceScatterSub" , kReadWrite, kVariable);
|
||||
add("ResourceScatterUpdate" , kReadWrite, kVariable);
|
||||
add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
|
||||
add("RngReadAndSkip" , kReadWrite, kVariable);
|
||||
add("RngSkip" , kReadWrite, kVariable);
|
||||
add("StatefulStandardNormalV2" , kReadWrite, kVariable);
|
||||
add("StatefulTruncatedNormal" , kReadWrite, kVariable);
|
||||
add("StatefulUniform" , kReadWrite, kVariable);
|
||||
|
@ -487,6 +487,10 @@ std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
|
||||
return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
|
||||
}
|
||||
|
||||
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const Shape& shape) {
|
||||
PrimitiveType type = shape.element_type();
|
||||
|
@ -89,6 +89,9 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
|
||||
xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
|
||||
absl::Span<const xla::XlaOp> scalars);
|
||||
|
||||
// Increases Philox counter (an uint128) by a delta (an uint64).
|
||||
xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_
|
||||
|
@ -488,6 +488,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/framework:register_types_traits.h",
|
||||
"//tensorflow/core/framework:resource_mgr.h",
|
||||
"//tensorflow/core/framework:resource_op_kernel.h",
|
||||
"//tensorflow/core/framework:rng_alg.h",
|
||||
"//tensorflow/core/framework:selective_registration.h",
|
||||
"//tensorflow/core/framework:session_state.h",
|
||||
"//tensorflow/core/framework:shape_inference.h",
|
||||
@ -652,6 +653,7 @@ tf_gen_op_libs(
|
||||
"spectral_ops",
|
||||
"state_ops",
|
||||
"stateless_random_ops",
|
||||
"stateless_random_ops_v2",
|
||||
"summary_ops",
|
||||
"training_ops",
|
||||
],
|
||||
@ -871,6 +873,7 @@ cc_library(
|
||||
":spectral_ops_op_lib",
|
||||
":state_ops_op_lib",
|
||||
":stateless_random_ops_op_lib",
|
||||
":stateless_random_ops_v2_op_lib",
|
||||
":string_ops_op_lib",
|
||||
":training_ops_op_lib",
|
||||
":user_ops_op_lib",
|
||||
|
@ -0,0 +1,35 @@
|
||||
op {
|
||||
graph_op_name: "RngReadAndSkip"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "resource"
|
||||
description: <<END
|
||||
The handle of the resource variable that stores the state of the RNG.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alg"
|
||||
description: <<END
|
||||
The RNG algorithm.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "delta"
|
||||
description: <<END
|
||||
The amount of advancement.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "value"
|
||||
description: <<END
|
||||
The old value of the resource variable, before incrementing. Since state size is algorithm-dependent, this output will be right-padded with zeros to reach shape int64[3] (the current maximal state size among algorithms).
|
||||
END
|
||||
}
|
||||
summary: "Advance the counter of a counter-based RNG."
|
||||
description: <<END
|
||||
The state of the RNG after
|
||||
`rng_read_and_skip(n)` will be the same as that after `uniform([n])`
|
||||
(or any other distribution). The actual increment added to the
|
||||
counter is an unspecified implementation choice.
|
||||
END
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
op {
|
||||
graph_op_name: "StatelessRandomGetKeyCounterAlg"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "seed"
|
||||
description: <<END
|
||||
2 seeds (shape [2]).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "key"
|
||||
description: <<END
|
||||
Key for the counter-based RNG algorithm (shape uint64[1]).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "counter"
|
||||
description: <<END
|
||||
Counter for the counter-based RNG algorithm. Since counter size is algorithm-dependent, this output will be right-padded with zeros to reach shape uint64[2] (the current maximal counter size among algorithms).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "alg"
|
||||
description: <<END
|
||||
The RNG algorithm (shape int32[]).
|
||||
END
|
||||
}
|
||||
summary: "Picks the best algorithm based on device, and scrambles seed into key and counter."
|
||||
description: <<END
|
||||
This op picks the best counter-based RNG algorithm based on device, and scrambles a shape-[2] seed into a key and a counter, both needed by the counter-based algorithm. The scrambling is opaque but approximately satisfies the property that different seed results in different key/counter pair (which will in turn result in different random numbers).
|
||||
END
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
op {
|
||||
graph_op_name: "StatelessRandomNormalV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "key"
|
||||
description: <<END
|
||||
Key for the counter-based RNG algorithm (shape uint64[1]).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "counter"
|
||||
description: <<END
|
||||
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alg"
|
||||
description: <<END
|
||||
The RNG algorithm (shape int32[]).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Random values with specified shape.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
summary: "Outputs deterministic pseudorandom values from a normal distribution."
|
||||
description: <<END
|
||||
The generated values will have mean 0 and standard deviation 1.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
|
||||
END
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
op {
|
||||
graph_op_name: "StatelessRandomUniformFullIntV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "key"
|
||||
description: <<END
|
||||
Key for the counter-based RNG algorithm (shape uint64[1]).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "counter"
|
||||
description: <<END
|
||||
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alg"
|
||||
description: <<END
|
||||
The RNG algorithm (shape int32[]).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Random values with specified shape.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
|
||||
description: <<END
|
||||
The generated values are uniform integers covering the whole range of `dtype`.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
|
||||
END
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
op {
|
||||
graph_op_name: "StatelessRandomUniformIntV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "key"
|
||||
description: <<END
|
||||
Key for the counter-based RNG algorithm (shape uint64[1]).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "counter"
|
||||
description: <<END
|
||||
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alg"
|
||||
description: <<END
|
||||
The RNG algorithm (shape int32[]).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "minval"
|
||||
description: <<END
|
||||
Minimum value (inclusive, scalar).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "maxval"
|
||||
description: <<END
|
||||
Maximum value (exclusive, scalar).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Random values with specified shape.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
|
||||
description: <<END
|
||||
The generated values follow a uniform distribution in the range `[minval, maxval)`.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `key`, `counter`, `alg`, `minval` and `maxval`.
|
||||
END
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
op {
|
||||
graph_op_name: "StatelessRandomUniformV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "key"
|
||||
description: <<END
|
||||
Key for the counter-based RNG algorithm (shape uint64[1]).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "counter"
|
||||
description: <<END
|
||||
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alg"
|
||||
description: <<END
|
||||
The RNG algorithm (shape int32[]).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Random values with specified shape.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
summary: "Outputs deterministic pseudorandom random values from a uniform distribution."
|
||||
description: <<END
|
||||
The generated values follow a uniform distribution in the range `[0, 1)`. The
|
||||
lower bound 0 is included in the range, while the upper bound 1 is excluded.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
|
||||
END
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
op {
|
||||
graph_op_name: "StatelessTruncatedNormalV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "shape"
|
||||
description: <<END
|
||||
The shape of the output tensor.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "key"
|
||||
description: <<END
|
||||
Key for the counter-based RNG algorithm (shape uint64[1]).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "counter"
|
||||
description: <<END
|
||||
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "alg"
|
||||
description: <<END
|
||||
The RNG algorithm (shape int32[]).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "output"
|
||||
description: <<END
|
||||
Random values with specified shape.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the output.
|
||||
END
|
||||
}
|
||||
summary: "Outputs deterministic pseudorandom values from a truncated normal distribution."
|
||||
description: <<END
|
||||
The generated values follow a normal distribution with mean 0 and standard
|
||||
deviation 1, except that values whose magnitude is more than 2 standard
|
||||
deviations from the mean are dropped and re-picked.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
|
||||
END
|
||||
}
|
@ -68,6 +68,7 @@ exports_files(
|
||||
"resource_mgr.h",
|
||||
"resource_op_kernel.h",
|
||||
"resource_var.h",
|
||||
"rng_alg.h",
|
||||
"run_handler.h",
|
||||
"run_handler_util.h",
|
||||
"session_state.h",
|
||||
@ -385,6 +386,7 @@ filegroup(
|
||||
"resource_mgr.h",
|
||||
"resource_op_kernel.h",
|
||||
"resource_var.h",
|
||||
"rng_alg.h",
|
||||
"run_handler.cc",
|
||||
"run_handler.h",
|
||||
"run_handler_util.cc",
|
||||
|
34
tensorflow/core/framework/rng_alg.h
Normal file
34
tensorflow/core/framework/rng_alg.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
enum Algorithm { RNG_ALG_PHILOX = 1, RNG_ALG_THREEFRY = 2 };
|
||||
|
||||
static constexpr int RNG_KEY_SIZE = 1;
|
||||
static constexpr int RNG_MAX_COUNTER_SIZE = 2;
|
||||
inline int GetCounterSize(Algorithm alg) {
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
|
@ -4457,12 +4457,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stateless_random_ops_v2_header",
|
||||
hdrs = ["stateless_random_ops_v2.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "stateful_random_ops",
|
||||
prefix = "stateful_random_ops",
|
||||
deps = [
|
||||
":bounds_check",
|
||||
":dense_update_functor",
|
||||
":fill_functor",
|
||||
":gather_functor",
|
||||
":mutex_ops",
|
||||
":random_op",
|
||||
@ -5337,7 +5347,7 @@ tf_kernel_library(
|
||||
prefix = "random_binomial_op",
|
||||
deps = [
|
||||
":cwise_op",
|
||||
":random_ops",
|
||||
":random_op",
|
||||
":resource_variable_ops",
|
||||
":stateful_random_ops",
|
||||
":stateless_random_ops",
|
||||
@ -6159,6 +6169,7 @@ filegroup(
|
||||
"ragged_tensor_to_tensor_op.cc",
|
||||
"random_op.cc",
|
||||
"random_op_cpu.h",
|
||||
"random_ops_util.h",
|
||||
"random_poisson_op.cc",
|
||||
"reduce_join_op.cc",
|
||||
"reduction_ops_all.cc",
|
||||
|
@ -66,8 +66,9 @@ struct MultinomialFunctor<GPUDevice, T, OutputType> {
|
||||
typename TTypes<OutputType>::Matrix output) {
|
||||
// Uniform, [0, 1).
|
||||
typedef random::UniformDistribution<random::PhiloxRandom, float> Dist;
|
||||
functor::FillPhiloxRandom<GPUDevice, Dist>()(ctx, d, gen, noises.data(),
|
||||
noises.size(), Dist());
|
||||
functor::FillPhiloxRandom<GPUDevice, Dist>()(
|
||||
ctx, d, /*key=*/nullptr, /*counter=*/nullptr, gen, noises.data(),
|
||||
noises.size(), Dist());
|
||||
|
||||
#if defined(EIGEN_HAS_INDEX_LIST)
|
||||
Eigen::IndexList<int, int, int> bsc;
|
||||
|
@ -30,8 +30,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/rng_alg.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/random_ops_util.h"
|
||||
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
|
||||
#include "tensorflow/core/kernels/stateless_random_ops.h"
|
||||
#include "tensorflow/core/kernels/training_op_helpers.h"
|
||||
@ -375,7 +377,7 @@ class RandomBinomialOp : public OpKernel {
|
||||
OP_REQUIRES(ctx, alg_tensor.dims() == 0,
|
||||
errors::InvalidArgument("algorithm must be of shape [], not ",
|
||||
alg_tensor.shape().DebugString()));
|
||||
Algorithm alg = alg_tensor.flat<Algorithm>()(0);
|
||||
Algorithm alg = Algorithm(alg_tensor.flat<int64>()(0));
|
||||
|
||||
int64 samples_per_batch = 1;
|
||||
const int64 num_sample_dims =
|
||||
|
@ -74,7 +74,7 @@ class PhiloxRandomOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
|
||||
auto output_flat = output->flat<T>();
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
ctx, ctx->eigen_device<Device>(),
|
||||
ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
|
||||
// Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
|
||||
// it just here.
|
||||
generator_.ReserveRandomOutputs(output_flat.size(), 256),
|
||||
@ -123,7 +123,7 @@ class RandomUniformIntOp : public OpKernel {
|
||||
|
||||
auto output_flat = output->flat<IntType>();
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
ctx, ctx->eigen_device<Device>(),
|
||||
ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
|
||||
// Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
|
||||
// it just here.
|
||||
generator_.ReserveRandomOutputs(output_flat.size(), 256),
|
||||
|
@ -34,10 +34,14 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
// NOTE: Due to inlining done by the compiler, you may need to add
|
||||
// explicit instantiation of the functor in random_op.cc. See example
|
||||
// functor::FillPhiloxRandom<CPUDevice, random::UniformDistribution>.
|
||||
//
|
||||
// This functor can take the PhiloxRandom input from either device memory `key`
|
||||
// and `counter` or a stack value `gen`. If both `key` and `counter` are not
|
||||
// nullptr, they provide the input; otherwise `gen` provides the input.
|
||||
template <class Distribution>
|
||||
struct FillPhiloxRandom<CPUDevice, Distribution> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d,
|
||||
random::PhiloxRandom gen,
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& d, const uint64* key,
|
||||
const uint64* counter, random::PhiloxRandom gen,
|
||||
typename Distribution::ResultElementType* data, int64 size,
|
||||
Distribution dist);
|
||||
};
|
||||
@ -47,14 +51,13 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
// Declares the partially GPU-specialized functor struct.
|
||||
template <class Distribution>
|
||||
struct FillPhiloxRandom<GPUDevice, Distribution> {
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& d,
|
||||
random::PhiloxRandom gen,
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& d, const uint64* key,
|
||||
const uint64* counter, random::PhiloxRandom gen,
|
||||
typename Distribution::ResultElementType* data, int64 size,
|
||||
Distribution dist);
|
||||
};
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/random_op.h"
|
||||
#include "tensorflow/core/kernels/random_ops_util.h"
|
||||
#include "tensorflow/core/lib/hash/crc32c.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
@ -59,8 +60,9 @@ using random::SingleSampleAdapter;
|
||||
template <typename Device, class Distribution>
|
||||
struct FillPhiloxRandom {
|
||||
typedef typename Distribution::ResultElementType T;
|
||||
void operator()(OpKernelContext* ctx, const Device&, random::PhiloxRandom gen,
|
||||
T* data, int64 size, Distribution dist) {
|
||||
void operator()(OpKernelContext* ctx, const Device&, const uint64* key,
|
||||
const uint64* counter, random::PhiloxRandom gen, T* data,
|
||||
int64 size, Distribution dist) {
|
||||
OP_REQUIRES(
|
||||
ctx, false,
|
||||
errors::Internal(
|
||||
@ -154,18 +156,24 @@ struct FillPhiloxRandomTask<Distribution, true> {
|
||||
// It splits the work into several tasks and run them in parallel
|
||||
template <class Distribution>
|
||||
void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
|
||||
OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
|
||||
OpKernelContext* ctx, const CPUDevice&, const uint64* key,
|
||||
const uint64* counter, random::PhiloxRandom gen,
|
||||
typename Distribution::ResultElementType* data, int64 size,
|
||||
Distribution dist) {
|
||||
const int kGroupSize = Distribution::kResultElementCount;
|
||||
|
||||
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
|
||||
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||
|
||||
int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
|
||||
|
||||
const int kGroupCost =
|
||||
random::PhiloxRandom::kResultElementCount *
|
||||
(random::PhiloxRandom::kElementCost + Distribution::kElementCost);
|
||||
|
||||
if (key != nullptr && counter != nullptr) {
|
||||
gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
|
||||
}
|
||||
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
|
||||
kGroupCost,
|
||||
[&gen, data, size, dist](int64 start_group, int64 limit_group) {
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#if defined(__CUDACC__) || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "tensorflow/core/kernels/random_op.h"
|
||||
#include "tensorflow/core/kernels/random_ops_util.h"
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
@ -33,14 +34,16 @@ struct FillPhiloxRandomKernel;
|
||||
template <class Distribution>
|
||||
struct FillPhiloxRandomKernel<Distribution, false> {
|
||||
typedef typename Distribution::ResultElementType T;
|
||||
PHILOX_DEVICE_INLINE void Run(random::PhiloxRandom gen, T* data, int64 size,
|
||||
PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter,
|
||||
random::PhiloxRandom gen, T* data, int64 size,
|
||||
Distribution dist);
|
||||
};
|
||||
|
||||
template <class Distribution>
|
||||
struct FillPhiloxRandomKernel<Distribution, true> {
|
||||
typedef typename Distribution::ResultElementType T;
|
||||
PHILOX_DEVICE_INLINE void Run(const random::PhiloxRandom& base_gen, T* data,
|
||||
PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter,
|
||||
random::PhiloxRandom base_gen, T* data,
|
||||
int64 size, Distribution dist);
|
||||
};
|
||||
|
||||
@ -136,12 +139,16 @@ class SampleCopier<int64, 2> {
|
||||
// distribution. Each output takes a fixed number of samples.
|
||||
template <class Distribution>
|
||||
PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
|
||||
random::PhiloxRandom gen, T* data, int64 size, Distribution dist) {
|
||||
const uint64* key, const uint64* counter, random::PhiloxRandom gen, T* data,
|
||||
int64 size, Distribution dist) {
|
||||
const int kGroupSize = Distribution::kResultElementCount;
|
||||
|
||||
const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32 total_thread_count = gridDim.x * blockDim.x;
|
||||
int32 offset = thread_id * kGroupSize;
|
||||
if (key != nullptr && counter != nullptr) {
|
||||
gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
|
||||
}
|
||||
gen.Skip(thread_id);
|
||||
|
||||
const SampleCopier<T, kGroupSize> copier;
|
||||
@ -167,8 +174,8 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
|
||||
// distribution. Each output takes a variable number of samples.
|
||||
template <class Distribution>
|
||||
PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
|
||||
const random::PhiloxRandom& base_gen, T* data, int64 size,
|
||||
Distribution dist) {
|
||||
const uint64* key, const uint64* counter, random::PhiloxRandom base_gen,
|
||||
T* data, int64 size, Distribution dist) {
|
||||
using random::PhiloxRandom;
|
||||
using random::SingleSampleAdapter;
|
||||
|
||||
@ -183,6 +190,9 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
|
||||
int64 group_index = thread_id;
|
||||
int64 offset = group_index * kGroupSize;
|
||||
|
||||
if (key != nullptr && counter != nullptr) {
|
||||
base_gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
|
||||
}
|
||||
while (offset < size) {
|
||||
// Since each output takes a variable number of samples, we need to
|
||||
// realign the generator to the beginning for the current output group
|
||||
@ -208,18 +218,20 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
|
||||
// A simple launch pad to call the correct function templates to fill the data
|
||||
template <class Distribution>
|
||||
__global__ void __launch_bounds__(1024)
|
||||
FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
|
||||
FillPhiloxRandomKernelLaunch(const uint64* key, const uint64* counter,
|
||||
random::PhiloxRandom base_gen,
|
||||
typename Distribution::ResultElementType* data,
|
||||
int64 size, Distribution dist) {
|
||||
FillPhiloxRandomKernel<Distribution,
|
||||
Distribution::kVariableSamplesPerOutput>()
|
||||
.Run(base_gen, data, size, dist);
|
||||
.Run(key, counter, base_gen, data, size, dist);
|
||||
}
|
||||
|
||||
// Partial specialization for GPU
|
||||
template <class Distribution>
|
||||
void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
|
||||
OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
|
||||
OpKernelContext*, const GPUDevice& d, const uint64* key,
|
||||
const uint64* counter, random::PhiloxRandom gen,
|
||||
typename Distribution::ResultElementType* data, int64 size,
|
||||
Distribution dist) {
|
||||
const int32 block_size = d.maxGpuThreadsPerBlock();
|
||||
@ -228,8 +240,8 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
|
||||
block_size;
|
||||
|
||||
TF_CHECK_OK(GpuLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
|
||||
num_blocks, block_size, 0, d.stream(), gen, data,
|
||||
size, dist));
|
||||
num_blocks, block_size, 0, d.stream(), key,
|
||||
counter, gen, data, size, dist));
|
||||
}
|
||||
|
||||
} // namespace functor
|
||||
|
72
tensorflow/core/kernels/random_ops_util.h
Normal file
72
tensorflow/core/kernels/random_ops_util.h
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
|
||||
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using random::PhiloxRandom;
|
||||
|
||||
// The following 2 functions use the contract "lower 32 bits for the first
|
||||
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
|
||||
// unlike a direct memory copy `memcpy(output, &input, 8)`.
|
||||
PHILOX_DEVICE_INLINE void Uint64ToUint32s(uint64 input, uint32* output1,
|
||||
uint32* output2) {
|
||||
*output1 = static_cast<uint32>(input);
|
||||
*output2 = static_cast<uint32>(input >> 32);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE uint64 Uint32sToUint64(uint32 input1, uint32 input2) {
|
||||
auto u64_1 = static_cast<uint64>(input1);
|
||||
auto u64_2 = static_cast<uint64>(input2);
|
||||
return u64_1 | (u64_2 << 32);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE PhiloxRandom::ResultType GetCounterFromMem(
|
||||
uint64 const* ptr) {
|
||||
PhiloxRandom::ResultType counter;
|
||||
Uint64ToUint32s(ptr[0], &counter[0], &counter[1]);
|
||||
Uint64ToUint32s(ptr[1], &counter[2], &counter[3]);
|
||||
return counter;
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE void WriteCounterToMem(
|
||||
PhiloxRandom::ResultType const& counter, uint64* ptr) {
|
||||
ptr[0] = Uint32sToUint64(counter[0], counter[1]);
|
||||
ptr[1] = Uint32sToUint64(counter[2], counter[3]);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE PhiloxRandom::Key GetKeyFromMem(uint64 const* ptr) {
|
||||
PhiloxRandom::Key key;
|
||||
Uint64ToUint32s(ptr[0], &key[0], &key[1]);
|
||||
return key;
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE void WriteKeyToMem(PhiloxRandom::Key const& key,
|
||||
uint64* ptr) {
|
||||
*ptr = Uint32sToUint64(key[0], key[1]);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE PhiloxRandom GetPhiloxRandomFromCounterKeyMem(
|
||||
uint64 const* counter_ptr, uint64 const* key_ptr) {
|
||||
return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr));
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/framework/rng_alg.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/random_op_cpu.h"
|
||||
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
|
||||
#include "tensorflow/core/kernels/training_op_helpers.h"
|
||||
@ -23,6 +25,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename Distribution>
|
||||
struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
|
||||
void operator()(OpKernelContext* ctx, const CPUDevice& device,
|
||||
@ -42,10 +46,13 @@ struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
|
||||
// No longer needs the lock.
|
||||
state_var_guard->Release();
|
||||
functor::FillPhiloxRandom<CPUDevice, Distribution>()(
|
||||
ctx, device, philox, output_data, output_size, dist);
|
||||
ctx, device, /*key=*/nullptr, /*counter=*/nullptr, philox, output_data,
|
||||
output_size, dist);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
Status CheckState(const Tensor& state) {
|
||||
if (state.dtype() != STATE_ELEMENT_DTYPE) {
|
||||
return errors::InvalidArgument("dtype of RNG state variable must be ",
|
||||
@ -64,11 +71,12 @@ Status CheckPhiloxState(const Tensor& state, int64 alg_tag_skip = 0) {
|
||||
"StateElementType must be int64");
|
||||
static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
|
||||
"PhiloxRandom::ResultElementType must be uint32");
|
||||
if (state.NumElements() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) {
|
||||
auto min_size = alg_tag_skip + PHILOX_MIN_STATE_SIZE;
|
||||
if (state.NumElements() < min_size) {
|
||||
return errors::InvalidArgument(
|
||||
"For the Philox algorithm, the size of state"
|
||||
" must be at least ",
|
||||
alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ", state.NumElements());
|
||||
min_size, "; got ", state.NumElements());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -95,7 +103,7 @@ Status UpdateVariableAndFill(
|
||||
if (var_tensor_flat.size() < 1) {
|
||||
return errors::InvalidArgument("Size of tensor must be at least 1");
|
||||
}
|
||||
alg = var_tensor_flat(0);
|
||||
alg = Algorithm(var_tensor_flat(0));
|
||||
}
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
TF_RETURN_IF_ERROR(CheckPhiloxState(*var_tensor, alg_tag_skip));
|
||||
@ -107,7 +115,7 @@ Status UpdateVariableAndFill(
|
||||
arg.alg_tag_skip = alg_tag_skip;
|
||||
arg.not_used = &state_var_guard;
|
||||
arg.state_tensor = var_tensor;
|
||||
UpdateVariableAndFill_Philox<Device, Distribution>()(
|
||||
functor::UpdateVariableAndFill_Philox<Device, Distribution>()(
|
||||
ctx, ctx->eigen_device<Device>(), dist, &arg, output_data);
|
||||
return Status::OK();
|
||||
} else {
|
||||
@ -138,7 +146,8 @@ class StatefulRandomOp : public OpKernel {
|
||||
explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true, 0);
|
||||
StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true,
|
||||
RNG_ALG_PHILOX /*dummy*/);
|
||||
}
|
||||
};
|
||||
|
||||
@ -159,6 +168,14 @@ Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename AlgEnumType>
|
||||
Status GetAlg(OpKernelContext* ctx, int input_idx, Algorithm* alg) {
|
||||
AlgEnumType alg_id;
|
||||
TF_RETURN_IF_ERROR(GetScalar(ctx->input(input_idx), input_idx, &alg_id));
|
||||
*alg = Algorithm(alg_id);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename Device, class Distribution>
|
||||
class StatefulRandomOpV2 : public OpKernel {
|
||||
public:
|
||||
@ -166,7 +183,7 @@ class StatefulRandomOpV2 : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Algorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
|
||||
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
|
||||
StatefulRandomCompute<Device>(ctx, Distribution(), /*state_input_idx=*/0,
|
||||
/*shape_input_idx=*/2,
|
||||
/*read_alg_from_state=*/false, alg);
|
||||
@ -180,7 +197,7 @@ class StatefulUniformIntOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Algorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
|
||||
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
|
||||
const Tensor& minval = ctx->input(3);
|
||||
const Tensor& maxval = ctx->input(4);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
|
||||
@ -217,7 +234,7 @@ class StatefulUniformFullIntOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Algorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
|
||||
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
|
||||
StatefulRandomCompute<Device>(
|
||||
ctx,
|
||||
random::UniformFullIntDistribution<random::PhiloxRandom, IntType>(),
|
||||
@ -226,38 +243,66 @@ class StatefulUniformFullIntOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <>
|
||||
struct RngSkip_Philox<CPUDevice> {
|
||||
void operator()(const CPUDevice& device, int64 delta, Tensor* state_tensor) {
|
||||
auto state_data = state_tensor->flat<StateElementType>().data();
|
||||
void operator()(const CPUDevice& device, const StateElementType* in_data,
|
||||
uint64 delta, StateElementType* out_data) {
|
||||
// Delegates to PhiloxRandom to do the actual increasing.
|
||||
auto philox = GetPhiloxRandomFromMem(state_data);
|
||||
UpdateMemWithPhiloxRandom(philox, delta, state_data);
|
||||
auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
|
||||
UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
} // end namespace functor
|
||||
|
||||
template <typename Device, typename AlgEnumType = int64,
|
||||
typename DeltaType = int64, bool read_old_value = false>
|
||||
class RngSkipOp : public OpKernel {
|
||||
public:
|
||||
explicit RngSkipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
auto state_input_idx = 0;
|
||||
auto alg_input_idx = 1;
|
||||
auto delta_input_idx = 2;
|
||||
Algorithm alg;
|
||||
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
|
||||
int64 delta;
|
||||
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(2), 2, &delta));
|
||||
OP_REQUIRES_OK(ctx, GetAlg<AlgEnumType>(ctx, alg_input_idx, &alg));
|
||||
DeltaType delta_;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, GetScalar(ctx->input(delta_input_idx), delta_input_idx, &delta_));
|
||||
uint64 delta = static_cast<uint64>(delta_);
|
||||
Var* var = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
|
||||
ScopedUnlockUnrefVar state_var_guard(var);
|
||||
Tensor* var_tensor = var->tensor();
|
||||
OP_REQUIRES_OK(ctx, CheckState(*var_tensor));
|
||||
using T = StateElementType;
|
||||
OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, T>(
|
||||
ctx, var_tensor, var->copy_on_read_mode.load()));
|
||||
if (read_old_value) {
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_output(0, {RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE},
|
||||
&output));
|
||||
auto output_flat = output->flat<T>();
|
||||
if (RNG_MAX_COUNTER_SIZE > GetCounterSize(alg)) {
|
||||
functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
|
||||
output_flat);
|
||||
}
|
||||
functor::DenseUpdate<Device, T, ASSIGN>()(
|
||||
ctx->eigen_device<Device>(), output_flat,
|
||||
const_cast<const Tensor*>(var_tensor)->flat<T>());
|
||||
}
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
OP_REQUIRES_OK(ctx, CheckPhiloxState(*var_tensor));
|
||||
OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
|
||||
ctx, var_tensor, var->copy_on_read_mode.load()));
|
||||
RngSkip_Philox<Device>()(ctx->eigen_device<Device>(), delta, var_tensor);
|
||||
// var_tensor layout is counter+key, so var_tensor data is also counter
|
||||
// data.
|
||||
auto counter_data = var_tensor->flat<T>().data();
|
||||
functor::RngSkip_Philox<Device>()(ctx->eigen_device<Device>(),
|
||||
counter_data, delta, counter_data);
|
||||
} else {
|
||||
OP_REQUIRES(ctx, false,
|
||||
errors::InvalidArgument("Unsupported algorithm id: ", alg));
|
||||
@ -393,13 +438,20 @@ TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU);
|
||||
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU);
|
||||
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU);
|
||||
|
||||
// TODO(wangpeng): Remove `HostMemory("delta")` for RngReadAndSkip
|
||||
#define REGISTER_RngSkip(DEVICE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("RngSkip") \
|
||||
.Device(DEVICE_##DEVICE) \
|
||||
.HostMemory("resource") \
|
||||
.HostMemory("algorithm") \
|
||||
.HostMemory("delta"), \
|
||||
RngSkipOp<DEVICE##Device>);
|
||||
RngSkipOp<DEVICE##Device>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("RngReadAndSkip") \
|
||||
.Device(DEVICE_##DEVICE) \
|
||||
.HostMemory("resource") \
|
||||
.HostMemory("alg") \
|
||||
.HostMemory("delta"), \
|
||||
RngSkipOp<DEVICE##Device, int32, uint64, true>);
|
||||
|
||||
REGISTER_RngSkip(CPU);
|
||||
|
||||
|
@ -22,15 +22,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
// 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained
|
||||
// in b/111604096 and cl/171681867), so I use signed int here. I choose int64
|
||||
// instead of int32 because `VarHandleOp` doesn't support int32 on GPU.
|
||||
// in b/111604096 and cl/171681867), so we use signed int here. We choose int64
|
||||
// instead of int32 because `VarHandleOp` doesn't support int32 on GPU, and
|
||||
// because of the "int32 problem".
|
||||
using StateElementType = int64;
|
||||
static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64;
|
||||
|
||||
using Algorithm = StateElementType;
|
||||
static constexpr DataType ALGORITHM_DTYPE = STATE_ELEMENT_DTYPE;
|
||||
static constexpr Algorithm RNG_ALG_PHILOX = 1;
|
||||
static constexpr Algorithm RNG_ALG_THREEFRY = 2;
|
||||
|
||||
using random::PhiloxRandom;
|
||||
|
||||
|
@ -17,59 +17,51 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
|
||||
|
||||
#include "tensorflow/core/framework/resource_var.h"
|
||||
#include "tensorflow/core/kernels/random_ops_util.h"
|
||||
#include "tensorflow/core/kernels/stateful_random_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// The following 5 functions are made templates to avoid duplicate symbols when
|
||||
// linking.
|
||||
|
||||
// The following 2 functions use the contract "lower 32 bits for the first
|
||||
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
|
||||
// unlike a direct memory copy `memcpy(output, &input, 8)`.
|
||||
PHILOX_DEVICE_INLINE void Int64ToUint32s(int64 input, uint32* output1,
|
||||
uint32* output2) {
|
||||
auto u64 = static_cast<uint64>(input);
|
||||
*output1 = static_cast<uint32>(u64);
|
||||
*output2 = static_cast<uint32>(u64 >> 32);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE int64 Uint32sToInt64(uint32 input1, uint32 input2) {
|
||||
auto u64_1 = static_cast<uint64>(input1);
|
||||
auto u64_2 = static_cast<uint64>(input2);
|
||||
return static_cast<int64>(u64_1 | (u64_2 << 32));
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE PhiloxRandom
|
||||
GetPhiloxRandomFromMem(StateElementType const* ptr) {
|
||||
PhiloxRandom::ResultType counter;
|
||||
PhiloxRandom::Key key;
|
||||
Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
|
||||
Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
|
||||
Int64ToUint32s(ptr[2], &key[0], &key[1]);
|
||||
return PhiloxRandom(counter, key);
|
||||
auto ptr_ = reinterpret_cast<uint64 const*>(ptr);
|
||||
return GetPhiloxRandomFromCounterKeyMem(ptr_, ptr_ + 2);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox,
|
||||
StateElementType* ptr) {
|
||||
PhiloxRandom::ResultType const& counter = philox.counter();
|
||||
PhiloxRandom::Key const& key = philox.key();
|
||||
ptr[0] = Uint32sToInt64(counter[0], counter[1]);
|
||||
ptr[1] = Uint32sToInt64(counter[2], counter[3]);
|
||||
ptr[2] = Uint32sToInt64(key[0], key[1]);
|
||||
auto ptr_ = reinterpret_cast<uint64*>(ptr);
|
||||
WriteCounterToMem(philox.counter(), ptr_);
|
||||
WriteKeyToMem(philox.key(), ptr_ + 2);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE PhiloxRandom SkipPhiloxRandom(PhiloxRandom const& philox,
|
||||
uint64 output_size) {
|
||||
auto new_philox = philox;
|
||||
// Multiplier 256 is the same as in FillPhiloxRandomTask; do not change it
|
||||
// just here.
|
||||
auto delta = output_size * 256;
|
||||
new_philox.Skip(delta); // do the actual increasing
|
||||
return new_philox;
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox,
|
||||
int64 output_size,
|
||||
uint64 output_size,
|
||||
StateElementType* ptr) {
|
||||
auto new_philox = philox;
|
||||
// Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
|
||||
// it just here.
|
||||
auto delta = output_size * 256;
|
||||
new_philox.Skip(delta); // do the actual increasing
|
||||
auto new_philox = SkipPhiloxRandom(philox, output_size);
|
||||
WritePhiloxRandomToMem(new_philox, ptr);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE void UpdateCounterMemWithPhiloxRandom(
|
||||
PhiloxRandom::ResultType const& counter, uint64 output_size,
|
||||
StateElementType* ptr) {
|
||||
auto philox = PhiloxRandom(counter, PhiloxRandom::Key() /*dummy*/);
|
||||
auto new_philox = SkipPhiloxRandom(philox, output_size);
|
||||
WriteCounterToMem(new_philox.counter(), reinterpret_cast<uint64*>(ptr));
|
||||
}
|
||||
|
||||
namespace functor {
|
||||
|
||||
// A per-device helper function that does the actual work for
|
||||
// `UpdateVariableAndFill`.
|
||||
// Reason to use functor: C++ doesn't allow function-template partial
|
||||
@ -80,6 +72,8 @@ struct UpdateVariableAndFill_Philox;
|
||||
template <typename Device>
|
||||
struct RngSkip_Philox;
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
|
||||
struct UpdateVariableAndFill_Philox_Arg {
|
||||
@ -93,6 +87,8 @@ struct UpdateVariableAndFill_Philox_Arg {
|
||||
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Declares the partially GPU-specialized functor structs.
|
||||
// must be kept at <=6 arguments because of a gcc/clang ABI incompatibility bug
|
||||
template <typename Distribution>
|
||||
@ -104,9 +100,12 @@ struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
|
||||
|
||||
template <>
|
||||
struct RngSkip_Philox<GPUDevice> {
|
||||
void operator()(const GPUDevice& device, int64 delta, Tensor* state_tensor);
|
||||
void operator()(const GPUDevice& device, const StateElementType* in_data,
|
||||
uint64 delta, StateElementType* out_data);
|
||||
};
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -31,6 +31,8 @@ __device__ int tensorflow_philox_thread_counter;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace functor {
|
||||
|
||||
using random::PhiloxRandom;
|
||||
|
||||
template <typename Distribution>
|
||||
@ -48,7 +50,8 @@ __global__ void FillKernel(
|
||||
__syncthreads();
|
||||
functor::FillPhiloxRandomKernel<Distribution,
|
||||
Distribution::kVariableSamplesPerOutput>()
|
||||
.Run(*philox, output_data, output_size, dist);
|
||||
.Run(/*key=*/nullptr, /*counter=*/nullptr, *philox, output_data,
|
||||
output_size, dist);
|
||||
// The last thread updates the state.
|
||||
auto total_thread_count = gridDim.x * blockDim.x;
|
||||
auto old_counter_value = atomicAdd(&tensorflow_philox_thread_counter, 1);
|
||||
@ -96,16 +99,19 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
|
||||
}
|
||||
|
||||
// Precondition: there is only 1 block and 1 thread.
|
||||
__global__ void SkipKernel(int64 delta,
|
||||
StateElementType* __restrict__ state_data) {
|
||||
auto philox = GetPhiloxRandomFromMem(state_data);
|
||||
UpdateMemWithPhiloxRandom(philox, delta, state_data);
|
||||
__global__ void SkipKernel(const StateElementType* __restrict__ in_data,
|
||||
uint64 delta,
|
||||
StateElementType* __restrict__ out_data) {
|
||||
auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
|
||||
UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
|
||||
}
|
||||
|
||||
void RngSkip_Philox<GPUDevice>::operator()(const GPUDevice& d, int64 delta,
|
||||
Tensor* state_tensor) {
|
||||
TF_CHECK_OK(GpuLaunchKernel(SkipKernel, 1, 1, 0, d.stream(), delta,
|
||||
state_tensor->flat<StateElementType>().data()));
|
||||
void RngSkip_Philox<GPUDevice>::operator()(const GPUDevice& d,
|
||||
const StateElementType* in_data,
|
||||
uint64 delta,
|
||||
StateElementType* out_data) {
|
||||
TF_CHECK_OK(GpuLaunchKernel(SkipKernel, 1, 1, 0, d.stream(), in_data, delta,
|
||||
out_data));
|
||||
}
|
||||
|
||||
// Explicit instantiation of the GPU distributions functors.
|
||||
@ -154,6 +160,7 @@ template struct UpdateVariableAndFill_Philox<
|
||||
random::PhiloxRandom, uint64> >;
|
||||
// clang-format on
|
||||
|
||||
} // end namespace functor
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -121,8 +121,8 @@ class StatelessRandomOp : public StatelessRandomOpBase {
|
||||
auto flat = output->flat<T>();
|
||||
// Reuse the compute kernels from the stateful random ops
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
context, context->eigen_device<Device>(), random, flat.data(),
|
||||
flat.size(), Distribution());
|
||||
context, context->eigen_device<Device>(), /*key=*/nullptr,
|
||||
/*counter=*/nullptr, random, flat.data(), flat.size(), Distribution());
|
||||
}
|
||||
};
|
||||
|
||||
@ -158,8 +158,8 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
|
||||
auto flat = output->flat<IntType>();
|
||||
// Reuse the compute kernels from the stateful random ops
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
context, context->eigen_device<Device>(), random, flat.data(),
|
||||
flat.size(), dist);
|
||||
context, context->eigen_device<Device>(), /*key=*/nullptr,
|
||||
/*counter=*/nullptr, random, flat.data(), flat.size(), dist);
|
||||
}
|
||||
};
|
||||
|
||||
@ -178,8 +178,8 @@ class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
|
||||
auto flat = output->flat<IntType>();
|
||||
// Reuse the compute kernels from the stateful random ops
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
context, context->eigen_device<Device>(), random, flat.data(),
|
||||
flat.size(), dist);
|
||||
context, context->eigen_device<Device>(), /*key=*/nullptr,
|
||||
/*counter=*/nullptr, random, flat.data(), flat.size(), dist);
|
||||
}
|
||||
};
|
||||
|
||||
|
330
tensorflow/core/kernels/stateless_random_ops_v2.cc
Normal file
330
tensorflow/core/kernels/stateless_random_ops_v2.cc
Normal file
@ -0,0 +1,330 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/stateless_random_ops_v2.h"
|
||||
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/rng_alg.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
#include "tensorflow/core/kernels/random_op.h"
|
||||
#include "tensorflow/core/kernels/random_ops_util.h"
|
||||
#include "tensorflow/core/kernels/random_poisson_op.h"
|
||||
#include "tensorflow/core/kernels/stateless_random_ops.h"
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
|
||||
#define DISABLE_FLOAT_EQUALITY_WARNING \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
|
||||
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
|
||||
#else
|
||||
#define DISABLE_FLOAT_EQUALITY_WARNING
|
||||
#define ENABLE_FLOAT_EQUALITY_WARNING
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
|
||||
auto dtype = DataTypeToEnum<T>::v();
|
||||
if (tensor.dims() != 0) {
|
||||
return errors::InvalidArgument("input ", std::to_string(input_idx),
|
||||
" (0-based) must have shape [], not ",
|
||||
tensor.shape().DebugString());
|
||||
}
|
||||
if (tensor.dtype() != dtype) {
|
||||
return errors::InvalidArgument("dtype of input ", std::to_string(input_idx),
|
||||
" (0-based) must be ", DataTypeString(dtype),
|
||||
", not ", DataTypeString(tensor.dtype()));
|
||||
}
|
||||
*result = tensor.flat<T>()(0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class StatelessRandomOpBase : public OpKernel {
|
||||
public:
|
||||
explicit StatelessRandomOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// Sanitize input
|
||||
const Tensor& shape_t = ctx->input(0);
|
||||
const Tensor& key_t = ctx->input(1);
|
||||
const Tensor& counter_t = ctx->input(2);
|
||||
const int alg_input_idx = 3;
|
||||
const Tensor& alg_t = ctx->input(alg_input_idx);
|
||||
|
||||
int alg_id;
|
||||
OP_REQUIRES_OK(ctx, GetScalar(alg_t, alg_input_idx, &alg_id));
|
||||
Algorithm alg = Algorithm(alg_id);
|
||||
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CheckKeyCounterShape(alg, key_t.shape(), counter_t.shape()));
|
||||
|
||||
// Allocate output
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
|
||||
if (shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Fill in the random numbers
|
||||
Fill(ctx, alg, key_t, counter_t, output);
|
||||
}
|
||||
|
||||
// The part of Compute that depends on device, type, and distribution
|
||||
virtual void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
|
||||
const Tensor& counter, Tensor* output) = 0;
|
||||
};
|
||||
|
||||
template <typename Device, typename Distribution>
|
||||
class StatelessRandomOp : public StatelessRandomOpBase {
|
||||
public:
|
||||
using StatelessRandomOpBase::StatelessRandomOpBase;
|
||||
|
||||
void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
|
||||
const Tensor& counter, Tensor* output) override {
|
||||
typedef typename Distribution::ResultElementType T;
|
||||
auto flat = output->flat<T>();
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
// Reuse the compute kernels from the stateful random ops
|
||||
auto key_data = key.flat<uint64>().data();
|
||||
auto counter_data = counter.flat<uint64>().data();
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
|
||||
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(),
|
||||
Distribution());
|
||||
} else {
|
||||
OP_REQUIRES(ctx, false,
|
||||
errors::InvalidArgument("Unsupported algorithm id: ", alg));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename IntType>
|
||||
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
|
||||
public:
|
||||
using StatelessRandomOpBase::StatelessRandomOpBase;
|
||||
|
||||
void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
|
||||
const Tensor& counter, Tensor* output) override {
|
||||
const Tensor& minval = ctx->input(4);
|
||||
const Tensor& maxval = ctx->input(5);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
|
||||
errors::InvalidArgument("minval must be 0-D, got shape ",
|
||||
minval.shape().DebugString()));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
|
||||
errors::InvalidArgument("maxval must be 0-D, got shape ",
|
||||
maxval.shape().DebugString()));
|
||||
|
||||
// Verify that minval < maxval. Note that we'll never reach this point for
|
||||
// empty output. Zero impossible things are fine.
|
||||
const auto lo = minval.scalar<IntType>()();
|
||||
const auto hi = maxval.scalar<IntType>()();
|
||||
OP_REQUIRES(
|
||||
ctx, lo < hi,
|
||||
errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
|
||||
|
||||
// Build distribution
|
||||
typedef random::UniformDistribution<random::PhiloxRandom, IntType>
|
||||
Distribution;
|
||||
Distribution dist(lo, hi);
|
||||
|
||||
auto flat = output->flat<IntType>();
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
// Reuse the compute kernels from the stateful random ops
|
||||
auto key_data = key.flat<uint64>().data();
|
||||
auto counter_data = counter.flat<uint64>().data();
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
|
||||
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
|
||||
} else {
|
||||
OP_REQUIRES(ctx, false,
|
||||
errors::InvalidArgument("Unsupported algorithm id: ", alg));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename IntType>
|
||||
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
|
||||
public:
|
||||
using StatelessRandomOpBase::StatelessRandomOpBase;
|
||||
|
||||
void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
|
||||
const Tensor& counter, Tensor* output) override {
|
||||
// Build distribution
|
||||
typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
|
||||
Distribution;
|
||||
Distribution dist;
|
||||
|
||||
auto flat = output->flat<IntType>();
|
||||
if (alg == RNG_ALG_PHILOX) {
|
||||
// Reuse the compute kernels from the stateful random ops
|
||||
auto key_data = key.flat<uint64>().data();
|
||||
auto counter_data = counter.flat<uint64>().data();
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
|
||||
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
|
||||
} else {
|
||||
OP_REQUIRES(ctx, false,
|
||||
errors::InvalidArgument("Unsupported algorithm id: ", alg));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class GetKeyCounterAlgOp : public OpKernel {
|
||||
public:
|
||||
explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& seed_t = ctx->input(0);
|
||||
OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
|
||||
errors::InvalidArgument("seed must have shape [2], not ",
|
||||
seed_t.shape().DebugString()));
|
||||
// Allocate outputs
|
||||
Tensor* key_output;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output));
|
||||
Tensor* counter_output;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}),
|
||||
&counter_output));
|
||||
Tensor* alg_output;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({}), &alg_output));
|
||||
|
||||
random::PhiloxRandom::Key key;
|
||||
random::PhiloxRandom::ResultType counter;
|
||||
OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
|
||||
WriteKeyToMem(key, key_output->flat<uint64>().data());
|
||||
WriteCounterToMem(counter, counter_output->flat<uint64>().data());
|
||||
alg_output->flat<int>()(0) = RNG_ALG_PHILOX;
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER(DEVICE, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StatelessRandomUniformV2") \
|
||||
.Device(DEVICE_##DEVICE) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("alg") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
|
||||
random::PhiloxRandom, TYPE> >); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StatelessRandomNormalV2") \
|
||||
.Device(DEVICE_##DEVICE) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("alg") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
|
||||
random::PhiloxRandom, TYPE> >); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StatelessTruncatedNormalV2") \
|
||||
.Device(DEVICE_##DEVICE) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("alg") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomOp< \
|
||||
DEVICE##Device, \
|
||||
random::TruncatedNormalDistribution< \
|
||||
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
|
||||
|
||||
#define REGISTER_FULL_INT(DEVICE, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("StatelessRandomUniformFullIntV2") \
|
||||
.Device(DEVICE_##DEVICE) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("alg") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)
|
||||
|
||||
#define REGISTER_INT(DEVICE, TYPE) \
|
||||
REGISTER_FULL_INT(DEVICE, TYPE); \
|
||||
REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformIntV2") \
|
||||
.Device(DEVICE_##DEVICE) \
|
||||
.HostMemory("shape") \
|
||||
.HostMemory("alg") \
|
||||
.HostMemory("minval") \
|
||||
.HostMemory("maxval") \
|
||||
.TypeConstraint<TYPE>("dtype"), \
|
||||
StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)
|
||||
|
||||
#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
|
||||
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
|
||||
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
|
||||
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
|
||||
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
|
||||
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)
|
||||
|
||||
TF_CALL_half(REGISTER_CPU);
|
||||
TF_CALL_bfloat16(REGISTER_CPU);
|
||||
TF_CALL_float(REGISTER_CPU);
|
||||
TF_CALL_double(REGISTER_CPU);
|
||||
TF_CALL_int32(REGISTER_INT_CPU);
|
||||
TF_CALL_int64(REGISTER_INT_CPU);
|
||||
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
|
||||
TF_CALL_uint64(REGISTER_FULL_INT_CPU);
|
||||
|
||||
#define REGISTER_GET_KCA(DEVICE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
|
||||
.Device(DEVICE_##DEVICE) \
|
||||
.HostMemory("seed") \
|
||||
.HostMemory("key") \
|
||||
.HostMemory("counter") \
|
||||
.HostMemory("alg"), \
|
||||
GetKeyCounterAlgOp)
|
||||
|
||||
REGISTER_GET_KCA(CPU);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
TF_CALL_half(REGISTER_GPU);
|
||||
TF_CALL_float(REGISTER_GPU);
|
||||
TF_CALL_double(REGISTER_GPU);
|
||||
TF_CALL_int32(REGISTER_INT_GPU);
|
||||
TF_CALL_int64(REGISTER_INT_GPU);
|
||||
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
|
||||
TF_CALL_uint64(REGISTER_FULL_INT_GPU);
|
||||
|
||||
REGISTER_GET_KCA(GPU);
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#undef REGISTER
|
||||
#undef REGISTER_INT
|
||||
#undef REGISTER_CPU
|
||||
#undef REGISTER_GPU
|
||||
#undef REGISTER_INT_CPU
|
||||
#undef REGISTER_INT_GPU
|
||||
#undef REGISTER_FULL_INT_CPU
|
||||
#undef REGISTER_FULL_INT_GPU
|
||||
|
||||
#undef REGISTER_GET_KCA
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
46
tensorflow/core/kernels/stateless_random_ops_v2.h
Normal file
46
tensorflow/core/kernels/stateless_random_ops_v2.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_
|
||||
|
||||
#include "tensorflow/core/framework/rng_alg.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
inline Status CheckKeyCounterShape(Algorithm const& alg,
|
||||
TensorShape const& key_shape,
|
||||
TensorShape const& counter_shape) {
|
||||
if (!(key_shape.dims() == 1 && key_shape.dim_size(0) == RNG_KEY_SIZE)) {
|
||||
return errors::InvalidArgument(
|
||||
"key must have shape [", RNG_KEY_SIZE, "], not ",
|
||||
key_shape.DebugString(),
|
||||
". (Note that batched keys are not supported yet.)");
|
||||
}
|
||||
auto counter_size = GetCounterSize(alg);
|
||||
if (!(counter_shape.dims() == 1 &&
|
||||
counter_shape.dim_size(0) >= counter_size)) {
|
||||
return errors::InvalidArgument(
|
||||
"counter must be a vector with length at least ", counter_size,
|
||||
"; got shape: ", counter_shape.DebugString(),
|
||||
". (Note that batched counters are not supported yet.)");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/rng_alg.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -90,6 +91,19 @@ REGISTER_OP("RngSkip")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("RngReadAndSkip")
|
||||
.Input("resource: resource")
|
||||
.Input("alg: int32")
|
||||
.Input("delta: uint64")
|
||||
.Output("value: int64")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
c->set_output(0, c->MakeShape({RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE}));
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("NonDeterministicInts")
|
||||
.Input("shape: shape_dtype")
|
||||
.SetIsStateful()
|
||||
|
119
tensorflow/core/ops/stateless_random_ops_v2.cc
Normal file
119
tensorflow/core/ops/stateless_random_ops_v2.cc
Normal file
@ -0,0 +1,119 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/rng_alg.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::DimensionHandle;
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
static Status StatelessShapeV2(InferenceContext* c) {
|
||||
// Check key and counter shapes
|
||||
ShapeHandle key;
|
||||
ShapeHandle counter;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &key));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &counter));
|
||||
shape_inference::ShapeHandle unused_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
|
||||
DimensionHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), RNG_KEY_SIZE, &unused));
|
||||
|
||||
// Set output shape
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
|
||||
c->set_output(0, out);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#define REGISTER_STATELESS_OP(name) \
|
||||
REGISTER_OP(name) \
|
||||
.Input("shape: Tshape") \
|
||||
.Input("key: uint64") \
|
||||
.Input("counter: uint64") \
|
||||
.Input("alg: int32") \
|
||||
.Output("output: dtype") \
|
||||
.Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
|
||||
.Attr("Tshape: {int32, int64} = DT_INT32") \
|
||||
.SetShapeFn(StatelessShapeV2)
|
||||
|
||||
REGISTER_STATELESS_OP("StatelessRandomUniformV2");
|
||||
REGISTER_STATELESS_OP("StatelessRandomNormalV2");
|
||||
REGISTER_STATELESS_OP("StatelessTruncatedNormalV2");
|
||||
|
||||
#undef REGISTER_STATELESS_OP
|
||||
|
||||
REGISTER_OP("StatelessRandomUniformIntV2")
|
||||
.Input("shape: Tshape")
|
||||
.Input("key: uint64")
|
||||
.Input("counter: uint64")
|
||||
.Input("alg: int32")
|
||||
.Input("minval: dtype")
|
||||
.Input("maxval: dtype")
|
||||
.Output("output: dtype")
|
||||
.Attr("dtype: {int32, int64, uint32, uint64}")
|
||||
.Attr("Tshape: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
Status s = c->WithRank(c->input(4), 0, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"minval must be a scalar; got a tensor of shape ",
|
||||
c->DebugString(c->input(4)));
|
||||
}
|
||||
s = c->WithRank(c->input(5), 0, &unused);
|
||||
if (!s.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"maxval must be a scalar; got a tensor of shape ",
|
||||
c->DebugString(c->input(5)));
|
||||
}
|
||||
return StatelessShapeV2(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("StatelessRandomUniformFullIntV2")
|
||||
.Input("shape: Tshape")
|
||||
.Input("key: uint64")
|
||||
.Input("counter: uint64")
|
||||
.Input("alg: int32")
|
||||
.Output("output: dtype")
|
||||
.Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
|
||||
.Attr("Tshape: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(StatelessShapeV2);
|
||||
|
||||
REGISTER_OP("StatelessRandomGetKeyCounterAlg")
|
||||
.Input("seed: Tseed")
|
||||
.Output("key: uint64")
|
||||
.Output("counter: uint64")
|
||||
.Output("alg: int32")
|
||||
.Attr("Tseed: {int32, int64} = DT_INT64")
|
||||
.SetIsStateful() // because outputs depend on device
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// Check seed shape
|
||||
ShapeHandle seed;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed));
|
||||
DimensionHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
|
||||
|
||||
// Set output shapes
|
||||
c->set_output(0, c->MakeShape({RNG_KEY_SIZE}));
|
||||
c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE}));
|
||||
c->set_output(2, c->MakeShape({}));
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
} // namespace tensorflow
|
@ -3183,6 +3183,10 @@ tf_gen_op_wrapper_private_py(
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "stateless_random_ops_v2_gen",
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "list_ops_gen",
|
||||
)
|
||||
@ -4440,6 +4444,7 @@ py_library(
|
||||
":framework_ops",
|
||||
":math_ops",
|
||||
":stateful_random_ops_gen",
|
||||
":stateless_random_ops_v2_gen",
|
||||
":variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
@ -4474,6 +4479,7 @@ py_library(
|
||||
":math_ops",
|
||||
":random_ops",
|
||||
":stateless_random_ops_gen",
|
||||
":stateless_random_ops_v2_gen",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
|
||||
|
||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
const tensorflow::string &op_name) {
|
||||
static std::array<OpIndexInfo, 352> a = {{
|
||||
static std::array<OpIndexInfo, 357> a = {{
|
||||
{"Acosh"},
|
||||
{"AllToAll", 1, {0}},
|
||||
{"ApproximateEqual"},
|
||||
@ -332,11 +332,16 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
{"StatelessRandomBinomial"},
|
||||
{"StatelessRandomGammaV2", 1, {1}},
|
||||
{"StatelessRandomNormal"},
|
||||
{"StatelessRandomNormalV2"},
|
||||
{"StatelessRandomPoisson"},
|
||||
{"StatelessRandomUniform"},
|
||||
{"StatelessRandomUniformFullInt"},
|
||||
{"StatelessRandomUniformFullIntV2"},
|
||||
{"StatelessRandomUniformInt"},
|
||||
{"StatelessRandomUniformIntV2"},
|
||||
{"StatelessRandomUniformV2"},
|
||||
{"StatelessTruncatedNormal"},
|
||||
{"StatelessTruncatedNormalV2"},
|
||||
{"StopGradient"},
|
||||
{"StridedSliceGrad", 2, {0, 4}},
|
||||
{"StringSplit"},
|
||||
@ -415,7 +420,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
|
||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
||||
const tensorflow::string &op_name) {
|
||||
static std::array<OpIndexInfo, 468> a = {{
|
||||
static std::array<OpIndexInfo, 473> a = {{
|
||||
{"Abs"},
|
||||
{"AccumulateNV2"},
|
||||
{"Acos"},
|
||||
@ -793,11 +798,16 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
||||
{"StatelessMultinomial"},
|
||||
{"StatelessRandomBinomial"},
|
||||
{"StatelessRandomNormal"},
|
||||
{"StatelessRandomNormalV2"},
|
||||
{"StatelessRandomPoisson"},
|
||||
{"StatelessRandomUniform"},
|
||||
{"StatelessRandomUniformFullInt"},
|
||||
{"StatelessRandomUniformFullIntV2"},
|
||||
{"StatelessRandomUniformInt"},
|
||||
{"StatelessRandomUniformIntV2"},
|
||||
{"StatelessRandomUniformV2"},
|
||||
{"StatelessTruncatedNormal"},
|
||||
{"StatelessTruncatedNormalV2"},
|
||||
{"StopGradient"},
|
||||
{"StridedSlice"},
|
||||
{"StridedSliceGrad"},
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||
from tensorflow.python.ops import gen_stateful_random_ops
|
||||
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||
from tensorflow.python.ops import image_ops_impl as image_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
@ -1237,11 +1238,14 @@ class RandomHeightTest(keras_parameterized.TestCase):
|
||||
mock_factor = 0
|
||||
with test.mock.patch.object(
|
||||
gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
|
||||
with testing_utils.use_gpu():
|
||||
img = np.random.random((12, 5, 8, 3))
|
||||
layer = image_preprocessing.RandomHeight(.4)
|
||||
img_out = layer(img, training=True)
|
||||
self.assertEqual(img_out.shape[1], 3)
|
||||
with test.mock.patch.object(
|
||||
gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
|
||||
return_value=mock_factor):
|
||||
with testing_utils.use_gpu():
|
||||
img = np.random.random((12, 5, 8, 3))
|
||||
layer = image_preprocessing.RandomHeight(.4)
|
||||
img_out = layer(img, training=True)
|
||||
self.assertEqual(img_out.shape[1], 3)
|
||||
|
||||
def test_random_height_longer_numeric(self):
|
||||
for dtype in (np.int64, np.float32):
|
||||
@ -1328,11 +1332,14 @@ class RandomWidthTest(keras_parameterized.TestCase):
|
||||
mock_factor = 0
|
||||
with test.mock.patch.object(
|
||||
gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
|
||||
with testing_utils.use_gpu():
|
||||
img = np.random.random((12, 8, 5, 3))
|
||||
layer = image_preprocessing.RandomWidth(.4)
|
||||
img_out = layer(img, training=True)
|
||||
self.assertEqual(img_out.shape[2], 3)
|
||||
with test.mock.patch.object(
|
||||
gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
|
||||
return_value=mock_factor):
|
||||
with testing_utils.use_gpu():
|
||||
img = np.random.random((12, 8, 5, 3))
|
||||
layer = image_preprocessing.RandomWidth(.4)
|
||||
img_out = layer(img, training=True)
|
||||
self.assertEqual(img_out.shape[2], 3)
|
||||
|
||||
def test_random_width_longer_numeric(self):
|
||||
for dtype in (np.int64, np.float32):
|
||||
|
@ -22,10 +22,13 @@ import functools
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -35,6 +38,31 @@ from tensorflow.python.ops import stateless_random_ops as stateless
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# Note that in theory each test will reset the eager context and may choose to
|
||||
# hide some devices, so we shouldn't cache this transient info. Tests in this
|
||||
# file don't make those config changes, so caching is fine. It provides a good
|
||||
# speed-up.
|
||||
_cached_device = None
|
||||
|
||||
|
||||
def get_device():
|
||||
global _cached_device
|
||||
if _cached_device is not None:
|
||||
return _cached_device
|
||||
# Precedence from high to low
|
||||
for device_type in ('XLA_GPU', 'GPU', 'XLA_CPU', 'CPU'):
|
||||
devices = config.list_logical_devices(device_type)
|
||||
if devices:
|
||||
_cached_device = devices[0]
|
||||
return _cached_device
|
||||
raise ValueError('Cannot find any suitable device. Available devices: %s' %
|
||||
config.list_logical_devices())
|
||||
|
||||
|
||||
BEFORE_EXPIRE = (2020, 10, 24)
|
||||
AFTER_EXPIRE = (2020, 10, 26)
|
||||
|
||||
|
||||
def invert_philox(key, value):
|
||||
"""Invert the Philox bijection."""
|
||||
key = np.array(key, dtype=np.uint32)
|
||||
@ -59,47 +87,71 @@ SEED_TYPES = [dtypes.int32, dtypes.int64]
|
||||
def float_cases(shape_dtypes=(None,)):
|
||||
cases = (
|
||||
# Uniform distribution, with and without range
|
||||
(stateless.stateless_random_uniform, random_ops.random_uniform, {}),
|
||||
(stateless.stateless_random_uniform, random_ops.random_uniform,
|
||||
dict(minval=2.2, maxval=7.1)),
|
||||
('uniform', stateless.stateless_random_uniform, random_ops.random_uniform,
|
||||
{}),
|
||||
('uniform2', stateless.stateless_random_uniform,
|
||||
random_ops.random_uniform, dict(minval=2.2, maxval=7.1)),
|
||||
# Normal distribution, with and without mean+stddev
|
||||
(stateless.stateless_random_normal, random_ops.random_normal, {}),
|
||||
(stateless.stateless_random_normal, random_ops.random_normal,
|
||||
('normal', stateless.stateless_random_normal, random_ops.random_normal,
|
||||
{}),
|
||||
('normal2', stateless.stateless_random_normal, random_ops.random_normal,
|
||||
dict(mean=2, stddev=3)),
|
||||
# Truncated normal distribution, with and without mean+stddev
|
||||
(stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
|
||||
(stateless.stateless_truncated_normal, random_ops.truncated_normal,
|
||||
dict(mean=3, stddev=4)),
|
||||
('trnorm', stateless.stateless_truncated_normal,
|
||||
random_ops.truncated_normal, {}),
|
||||
('trnorm2', stateless.stateless_truncated_normal,
|
||||
random_ops.truncated_normal, dict(mean=3, stddev=4)),
|
||||
)
|
||||
# Explicitly passing in params because capturing cell variable from loop is
|
||||
# problematic in Python
|
||||
def wrap(op, dtype, shape, shape_dtype, kwds, seed):
|
||||
device_type = get_device().device_type
|
||||
# Some dtypes are not supported on some devices
|
||||
if (dtype == dtypes.float16 and device_type in ('XLA_GPU', 'XLA_CPU') or
|
||||
dtype == dtypes.bfloat16 and device_type == 'GPU'):
|
||||
dtype = dtypes.float32
|
||||
shape_ = (constant_op.constant(shape, dtype=shape_dtype)
|
||||
if shape_dtype is not None else shape)
|
||||
return op(seed=seed, shape=shape_, dtype=dtype, **kwds)
|
||||
for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
|
||||
|
||||
def _name(a):
|
||||
if hasattr(a, 'name'):
|
||||
return a.name
|
||||
else:
|
||||
return a
|
||||
|
||||
for dtype in dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64:
|
||||
for shape_dtype in shape_dtypes:
|
||||
for shape in (), (3,), (2, 5):
|
||||
for stateless_op, stateful_op, kwds in cases:
|
||||
yield (functools.partial(wrap, stateless_op, dtype, shape,
|
||||
for name, stateless_op, stateful_op, kwds in cases:
|
||||
yield (('%s_%s_%s_%s' %
|
||||
(name, _name(dtype), shape, _name(shape_dtype))).replace(
|
||||
' ', ''),
|
||||
functools.partial(wrap, stateless_op, dtype, shape,
|
||||
shape_dtype, kwds),
|
||||
functools.partial(wrap, stateful_op, dtype, shape,
|
||||
shape_dtype, kwds))
|
||||
functools.partial(wrap, stateful_op, dtype, shape, shape_dtype,
|
||||
kwds))
|
||||
|
||||
|
||||
def int_cases(shape_dtypes=(None,)):
|
||||
def wrap(op, shape, shape_dtype, dtype, seed):
|
||||
def int_cases(shape_dtypes=(None,), minval_maxval=None):
|
||||
|
||||
def wrap(op, minval, maxval, shape, shape_dtype, dtype, seed):
|
||||
shape_ = (constant_op.constant(shape, dtype=shape_dtype)
|
||||
if shape_dtype is not None else shape)
|
||||
return op(seed=seed, shape=shape_, minval=2, maxval=11111,
|
||||
dtype=dtype)
|
||||
for shape_dtype in shape_dtypes:
|
||||
for shape in (), (3,), (2, 5):
|
||||
for dtype in dtypes.int32, dtypes.int64:
|
||||
yield (functools.partial(wrap, stateless.stateless_random_uniform,
|
||||
shape, shape_dtype, dtype),
|
||||
functools.partial(wrap, random_ops.random_uniform,
|
||||
shape, shape_dtype, dtype))
|
||||
return op(
|
||||
seed=seed, shape=shape_, minval=minval, maxval=maxval, dtype=dtype)
|
||||
|
||||
if minval_maxval is None:
|
||||
minval_maxval = ((2, 11111),)
|
||||
for minval, maxval in minval_maxval:
|
||||
for shape_dtype in shape_dtypes:
|
||||
for shape in (), (3,), (2, 5):
|
||||
for dtype in dtypes.int32, dtypes.int64:
|
||||
yield ('uniform_%s_%s' % (minval, maxval),
|
||||
functools.partial(wrap, stateless.stateless_random_uniform,
|
||||
minval, maxval, shape, shape_dtype, dtype),
|
||||
functools.partial(wrap, random_ops.random_uniform, minval,
|
||||
maxval, shape, shape_dtype, dtype))
|
||||
|
||||
|
||||
def multinomial_cases():
|
||||
@ -112,7 +164,8 @@ def multinomial_cases():
|
||||
for output_dtype in dtypes.int32, dtypes.int64:
|
||||
for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
|
||||
[0.25, 0.75]]):
|
||||
yield (functools.partial(wrap, stateless.stateless_multinomial, logits,
|
||||
yield ('multinomial',
|
||||
functools.partial(wrap, stateless.stateless_multinomial, logits,
|
||||
logits_dtype, output_dtype),
|
||||
functools.partial(wrap, random_ops.multinomial, logits,
|
||||
logits_dtype, output_dtype))
|
||||
@ -124,10 +177,11 @@ def gamma_cases():
|
||||
alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
|
||||
for dtype in np.float16, np.float32, np.float64:
|
||||
for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
|
||||
yield (functools.partial(wrap, stateless.stateless_random_gamma, alpha,
|
||||
yield ('gamma',
|
||||
functools.partial(wrap, stateless.stateless_random_gamma, alpha,
|
||||
dtype, (10,) + tuple(np.shape(alpha))),
|
||||
functools.partial(wrap, random_ops.random_gamma, alpha,
|
||||
dtype, (10,)))
|
||||
functools.partial(wrap, random_ops.random_gamma, alpha, dtype,
|
||||
(10,)))
|
||||
|
||||
|
||||
def poisson_cases():
|
||||
@ -138,7 +192,8 @@ def poisson_cases():
|
||||
for lam_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
|
||||
for out_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
|
||||
for lam in ([[5.5, 1., 2.]], [[7.5, 10.5], [3.8, 8.2], [1.25, 9.75]]):
|
||||
yield (functools.partial(wrap, stateless.stateless_random_poisson, lam,
|
||||
yield ('poisson',
|
||||
functools.partial(wrap, stateless.stateless_random_poisson, lam,
|
||||
lam_dtype, out_dtype,
|
||||
(10,) + tuple(np.shape(lam))),
|
||||
functools.partial(wrap, random_ops.random_poisson, lam,
|
||||
@ -153,22 +208,28 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
|
||||
key = 0x3ec8f720, 0x02461e29
|
||||
preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
|
||||
preseed = preseed[::2] | preseed[1::2] << 32
|
||||
random_seed.set_random_seed(seed[0])
|
||||
with test_util.use_gpu():
|
||||
stateless_op, stateful_op = case
|
||||
if context.executing_eagerly():
|
||||
# Call set_random_seed in order to clear kernel cache, to prevent
|
||||
# kernel reusing for the stateful op
|
||||
random_seed.set_random_seed(seed[0])
|
||||
with ops.device(get_device().name):
|
||||
_, stateless_op, stateful_op = case
|
||||
random_seed.set_random_seed(seed[0])
|
||||
stateful = stateful_op(seed=seed[1])
|
||||
pure = stateless_op(seed=preseed)
|
||||
self.assertAllEqual(stateful, pure)
|
||||
|
||||
def _test_old_and_new_stateless_match(self, case, seed):
|
||||
"""Tests that the new stateless ops match the old stateless ones."""
|
||||
with ops.device(get_device().name):
|
||||
_, stateless_op, _ = case
|
||||
with compat.forward_compatibility_horizon(*BEFORE_EXPIRE):
|
||||
old = stateless_op(seed=seed)
|
||||
with compat.forward_compatibility_horizon(*AFTER_EXPIRE):
|
||||
new = stateless_op(seed=seed)
|
||||
self.assertAllClose(old, new)
|
||||
|
||||
def _test_determinism(self, case, seed_type):
|
||||
# Stateless values should be equal iff the seeds are equal (roughly)
|
||||
seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension
|
||||
with self.test_session(use_gpu=True), test_util.use_gpu():
|
||||
stateless_op, _ = case
|
||||
with self.test_session(use_gpu=True), ops.device(get_device().name):
|
||||
_, stateless_op, _ = case
|
||||
if context.executing_eagerly():
|
||||
values = [
|
||||
(seed, stateless_op(seed=constant_op.constant(seed, seed_type)))
|
||||
@ -183,88 +244,172 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
|
||||
]
|
||||
for s0, v0 in values:
|
||||
for s1, v1 in values:
|
||||
self.assertEqual(s0 == s1, np.all(v0 == v1))
|
||||
if dtypes.as_dtype(v0.dtype) != dtypes.bfloat16:
|
||||
self.assertEqual(s0 == s1, np.all(v0 == v1))
|
||||
elif s0 == s1:
|
||||
# Skip the s0 != s1 case because v0 and v1 can be either equal or
|
||||
# unequal in that case due to bfloat16's low precision
|
||||
self.assertAllEqual(v0, v1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
for seed_id, seed in enumerate(SEEDS)
|
||||
for case_id, case in enumerate(float_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testMatchFloat(self, case, seed):
|
||||
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Skip on XLA because XLA kernels do not support int64 '
|
||||
'seeds needed by this test.')
|
||||
self._test_match(case, seed)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
for seed_id, seed in enumerate(SEEDS)
|
||||
for case_id, case in enumerate(int_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testMatchInt(self, case, seed):
|
||||
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Skip on XLA because XLA kernels do not support int64 '
|
||||
'seeds needed by this test.')
|
||||
self._test_match(case, seed)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
for seed_id, seed in enumerate(SEEDS)
|
||||
for case_id, case in enumerate(multinomial_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testMatchMultinomial(self, case, seed):
|
||||
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking XLA kernel')
|
||||
self._test_match(case, seed)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
for seed_id, seed in enumerate(SEEDS)
|
||||
for case_id, case in enumerate(gamma_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testMatchGamma(self, case, seed):
|
||||
if get_device().device_type == 'GPU':
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking GPU kernel')
|
||||
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking XLA kernel')
|
||||
self._test_match(case, seed)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
for seed_id, seed in enumerate(SEEDS)
|
||||
for case_id, case in enumerate(poisson_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testMatchPoisson(self, case, seed):
|
||||
if get_device().device_type == 'GPU':
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking GPU kernel')
|
||||
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking XLA kernel')
|
||||
self._test_match(case, seed)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for type_id, seed_type in enumerate(SEED_TYPES)
|
||||
for case_id, case in enumerate(float_cases(
|
||||
shape_dtypes=(dtypes.int32, dtypes.int64))))
|
||||
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
for seed_id, seed in enumerate(SEEDS)
|
||||
for case_id, case in enumerate(float_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testOldAndNewStatelessMatchFloat(self, case, seed):
|
||||
self._test_old_and_new_stateless_match(case, seed)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
|
||||
for seed_id, seed in enumerate(SEEDS)
|
||||
for case_id, case in enumerate(
|
||||
int_cases(minval_maxval=((2, 11111), (None, None)))))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testOldAndNewStatelessMatchInt(self, case, seed):
|
||||
self._test_old_and_new_stateless_match(case, seed)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for seed_type in SEED_TYPES
|
||||
for case_id, case in enumerate(
|
||||
float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testDeterminismFloat(self, case, seed_type):
|
||||
if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
|
||||
'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest(
|
||||
'Skip on XLA because XLA kernels do not support int64 seeds.')
|
||||
self._test_determinism(case, seed_type)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for type_id, seed_type in enumerate(SEED_TYPES)
|
||||
for case_id, case in enumerate(int_cases(
|
||||
shape_dtypes=(dtypes.int32, dtypes.int64))))
|
||||
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for seed_type in SEED_TYPES
|
||||
for case_id, case in enumerate(
|
||||
int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testDeterminismInt(self, case, seed_type):
|
||||
if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
|
||||
'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest(
|
||||
'Skip on XLA because XLA kernels do not support int64 seeds.')
|
||||
self._test_determinism(case, seed_type)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for type_id, seed_type in enumerate(SEED_TYPES)
|
||||
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for seed_type in SEED_TYPES
|
||||
for case_id, case in enumerate(multinomial_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testDeterminismMultinomial(self, case, seed_type):
|
||||
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking XLA kernel')
|
||||
self._test_determinism(case, seed_type)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for type_id, seed_type in enumerate(SEED_TYPES)
|
||||
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for seed_type in SEED_TYPES
|
||||
for case_id, case in enumerate(gamma_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testDeterminismGamma(self, case, seed_type):
|
||||
if get_device().device_type == 'GPU':
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking GPU kernel')
|
||||
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking XLA kernel')
|
||||
self._test_determinism(case, seed_type)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for type_id, seed_type in enumerate(SEED_TYPES)
|
||||
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
|
||||
for seed_type in SEED_TYPES
|
||||
for case_id, case in enumerate(poisson_cases()))
|
||||
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
|
||||
def testDeterminismPoisson(self, case, seed_type):
|
||||
if get_device().device_type == 'GPU':
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking GPU kernel')
|
||||
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernels.
|
||||
self.skipTest('Lacking XLA kernel')
|
||||
self._test_determinism(case, seed_type)
|
||||
|
||||
def assertDTypeEqual(self, a, b):
|
||||
@ -327,4 +472,6 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.set_soft_device_placement(False)
|
||||
context.context().enable_xla_devices()
|
||||
test.main()
|
||||
|
@ -23,6 +23,7 @@ import enum # pylint: disable=g-bad-import-order
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
@ -32,12 +33,14 @@ from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_stateful_random_ops
|
||||
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# A seed for random ops (stateful and stateless) will always be 1024
|
||||
# bits, all of which will be sent to the C++ code. The actual C++
|
||||
# implementation of some algorithms may only use a lower part of the bits.
|
||||
@ -136,6 +139,15 @@ def _make_1d_state(state_size, seed):
|
||||
return seed
|
||||
|
||||
|
||||
def _get_counter_size(alg):
|
||||
if alg == RNG_ALG_PHILOX:
|
||||
return 2
|
||||
elif alg == RNG_ALG_THREEFRY:
|
||||
return 1
|
||||
else:
|
||||
raise ValueError("Unsupported algorithm id: %s" % alg)
|
||||
|
||||
|
||||
def _get_state_size(alg):
|
||||
if alg == RNG_ALG_PHILOX:
|
||||
return PHILOX_STATE_SIZE
|
||||
@ -560,6 +572,10 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
return self._alg
|
||||
|
||||
def _standard_normal(self, shape, dtype):
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter = self._prepare_key_counter(shape)
|
||||
return gen_stateless_random_ops_v2.stateless_random_normal_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
|
||||
return gen_stateful_random_ops.stateful_standard_normal_v2(
|
||||
self.state.handle, self.algorithm, shape, dtype=dtype)
|
||||
|
||||
@ -586,6 +602,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
else:
|
||||
raise ValueError("Unsupported algorithm id: %s" % alg)
|
||||
|
||||
# TODO(wangpeng): Add "Returns" section to docstring once new version kicks in
|
||||
# pylint: disable=g-doc-return-or-yield
|
||||
def skip(self, delta):
|
||||
"""Advance the counter of a counter-based RNG.
|
||||
|
||||
@ -595,7 +613,24 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
(or any other distribution). The actual increment added to the
|
||||
counter is an unspecified implementation detail.
|
||||
"""
|
||||
gen_stateful_random_ops.rng_skip(self.state.handle, self.algorithm, delta)
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
return gen_stateful_random_ops.rng_read_and_skip(
|
||||
self.state.handle,
|
||||
alg=math_ops.cast(self.algorithm, dtypes.int32),
|
||||
delta=math_ops.cast(delta, dtypes.uint64))
|
||||
gen_stateful_random_ops.rng_skip(
|
||||
self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
|
||||
math_ops.cast(delta, dtypes.int64))
|
||||
# pylint: enable=g-doc-return-or-yield
|
||||
|
||||
def _prepare_key_counter(self, shape):
|
||||
delta = math_ops.reduce_prod(shape)
|
||||
counter_key = self.skip(delta)
|
||||
counter_size = _get_counter_size(self.algorithm)
|
||||
counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
|
||||
key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
|
||||
dtypes.uint64)
|
||||
return key, counter
|
||||
|
||||
# The following functions return a tensor and as a side effect update
|
||||
# self._state_var.
|
||||
@ -624,6 +659,14 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
return math_ops.add(rnd * stddev, mean, name=name)
|
||||
|
||||
def _truncated_normal(self, shape, dtype):
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter = self._prepare_key_counter(shape)
|
||||
return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
|
||||
shape=shape,
|
||||
key=key,
|
||||
counter=counter,
|
||||
dtype=dtype,
|
||||
alg=self.algorithm)
|
||||
return gen_stateful_random_ops.stateful_truncated_normal(
|
||||
self.state.handle, self.algorithm, shape, dtype=dtype)
|
||||
|
||||
@ -662,10 +705,27 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
return math_ops.add(mul, mean_tensor, name=name)
|
||||
|
||||
def _uniform(self, shape, dtype):
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter = self._prepare_key_counter(shape)
|
||||
return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
|
||||
shape=shape,
|
||||
key=key,
|
||||
counter=counter,
|
||||
dtype=dtype,
|
||||
alg=self.algorithm)
|
||||
return gen_stateful_random_ops.stateful_uniform(
|
||||
self.state.handle, self.algorithm, shape=shape, dtype=dtype)
|
||||
|
||||
def _uniform_full_int(self, shape, dtype, name=None):
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter = self._prepare_key_counter(shape)
|
||||
return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
|
||||
shape=shape,
|
||||
key=key,
|
||||
counter=counter,
|
||||
dtype=dtype,
|
||||
alg=self.algorithm,
|
||||
name=name)
|
||||
return gen_stateful_random_ops.stateful_uniform_full_int(
|
||||
self.state.handle, self.algorithm, shape=shape,
|
||||
dtype=dtype, name=name)
|
||||
@ -729,6 +789,16 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
|
||||
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
|
||||
if dtype.is_integer:
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter = self._prepare_key_counter(shape)
|
||||
return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
|
||||
shape=shape,
|
||||
key=key,
|
||||
counter=counter,
|
||||
minval=minval,
|
||||
maxval=maxval,
|
||||
alg=self.algorithm,
|
||||
name=name)
|
||||
return gen_stateful_random_ops.stateful_uniform_int(
|
||||
self.state.handle, self.algorithm, shape=shape,
|
||||
minval=minval, maxval=maxval, name=name)
|
||||
|
@ -274,7 +274,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
with self.cached_session() as sess:
|
||||
gen1 = random.Generator.from_seed(seed)
|
||||
gen2 = random.Generator.from_non_deterministic_state()
|
||||
sess.run((gen1._state_var.initializer, gen2._state_var.initializer))
|
||||
sess.run((gen1.state.initializer, gen2.state.initializer))
|
||||
r1 = gen1.normal(shape, dtype=dtypes.float32)
|
||||
r2 = gen2.normal(shape, dtype=dtypes.float32)
|
||||
def f():
|
||||
@ -372,7 +372,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
|
||||
delta = 432
|
||||
gen.skip(delta)
|
||||
new_counter = gen._state_var[0]
|
||||
new_counter = gen.state[0]
|
||||
self.assertAllEqual(counter + delta * 256, new_counter)
|
||||
|
||||
def _sameAsOldRandomOps(self, device, floats):
|
||||
@ -394,7 +394,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.device(device):
|
||||
return new(dtype, gen)
|
||||
|
||||
for _ in range(100):
|
||||
for _ in range(5):
|
||||
self.assertAllEqual(run_old(), run_new())
|
||||
|
||||
shape = constant_op.constant([4, 7])
|
||||
@ -582,6 +582,11 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
@test_util.run_v2_only
|
||||
def testGetGlobalGeneratorWithXla(self):
|
||||
"""Demonstrates using the global generator with XLA."""
|
||||
# This test was passing before because soft placement silently picked the
|
||||
# CPU kernel.
|
||||
# TODO(wangpeng): Remove this skip
|
||||
self.skipTest("NonDeterministicInts lacks XLA kernel.")
|
||||
|
||||
if not config.list_physical_devices("XLA_CPU"):
|
||||
self.skipTest("No XLA_CPU device available.")
|
||||
|
||||
@ -675,17 +680,16 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int64))
|
||||
|
||||
@test_util.run_v2_only
|
||||
@test_util.run_cuda_only
|
||||
def testMirroredStratSeq(self):
|
||||
def testCreateOutsideMirroredStrat(self):
|
||||
"""Tests RNG/MirrorStrategy interaction #1.
|
||||
|
||||
If an RNG is created outside strategy.scope(), all replicas will access the
|
||||
If an RNG is created outside a DS scope, all replicas will access the
|
||||
same RNG object, and accesses are serialized.
|
||||
"""
|
||||
shape = [3, 4]
|
||||
dtype = dtypes.int32
|
||||
gen = random.Generator.from_seed(1234)
|
||||
strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
|
||||
strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
|
||||
with strat.scope():
|
||||
def f():
|
||||
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
@ -762,4 +766,5 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.set_soft_device_placement(False)
|
||||
test.main()
|
||||
|
@ -20,16 +20,19 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_stateless_random_ops
|
||||
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
ops.NotDifferentiable("StatelessMultinomial")
|
||||
ops.NotDifferentiable("StatelessRandomBinomial")
|
||||
ops.NotDifferentiable("StatelessRandomNormal")
|
||||
@ -40,6 +43,13 @@ ops.NotDifferentiable("StatelessRandomUniformFullInt")
|
||||
ops.NotDifferentiable("StatelessTruncatedNormal")
|
||||
|
||||
|
||||
ops.NotDifferentiable("StatelessRandomNormalV2")
|
||||
ops.NotDifferentiable("StatelessRandomUniformV2")
|
||||
ops.NotDifferentiable("StatelessRandomUniformIntV2")
|
||||
ops.NotDifferentiable("StatelessRandomUniformFullIntV2")
|
||||
ops.NotDifferentiable("StatelessTruncatedNormalV2")
|
||||
|
||||
|
||||
@tf_export("random.experimental.stateless_split")
|
||||
@dispatch.add_dispatch_support
|
||||
def split(seed, num=2):
|
||||
@ -113,6 +123,10 @@ def fold_in(seed, data):
|
||||
return array_ops.stack([seed1, data])
|
||||
|
||||
|
||||
_get_key_counter_alg = (gen_stateless_random_ops_v2
|
||||
.stateless_random_get_key_counter_alg)
|
||||
|
||||
|
||||
@tf_export("random.stateless_uniform")
|
||||
@dispatch.add_dispatch_support
|
||||
def stateless_random_uniform(shape,
|
||||
@ -192,17 +206,35 @@ def stateless_random_uniform(shape,
|
||||
[shape, seed, minval, maxval]) as name:
|
||||
shape = tensor_util.shape_tensor(shape)
|
||||
if dtype.is_integer and minval is None:
|
||||
result = gen_stateless_random_ops.stateless_random_uniform_full_int(
|
||||
shape, seed=seed, dtype=dtype, name=name)
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
result = (gen_stateless_random_ops_v2
|
||||
.stateless_random_uniform_full_int_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=alg,
|
||||
name=name))
|
||||
else:
|
||||
result = gen_stateless_random_ops.stateless_random_uniform_full_int(
|
||||
shape, seed=seed, dtype=dtype, name=name)
|
||||
else:
|
||||
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
|
||||
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
|
||||
if dtype.is_integer:
|
||||
result = gen_stateless_random_ops.stateless_random_uniform_int(
|
||||
shape, seed=seed, minval=minval, maxval=maxval, name=name)
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
|
||||
shape, key=key, counter=counter, minval=minval, maxval=maxval,
|
||||
alg=alg, name=name)
|
||||
else:
|
||||
result = gen_stateless_random_ops.stateless_random_uniform_int(
|
||||
shape, seed=seed, minval=minval, maxval=maxval, name=name)
|
||||
else:
|
||||
rnd = gen_stateless_random_ops.stateless_random_uniform(
|
||||
shape, seed=seed, dtype=dtype)
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=alg)
|
||||
else:
|
||||
rnd = gen_stateless_random_ops.stateless_random_uniform(
|
||||
shape, seed=seed, dtype=dtype)
|
||||
result = math_ops.add(rnd * (maxval - minval), minval, name=name)
|
||||
tensor_util.maybe_set_static_shape(result, shape)
|
||||
return result
|
||||
@ -476,7 +508,12 @@ def stateless_random_normal(shape,
|
||||
shape = tensor_util.shape_tensor(shape)
|
||||
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
|
||||
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
|
||||
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=alg)
|
||||
else:
|
||||
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
|
||||
result = math_ops.add(rnd * stddev, mean, name=name)
|
||||
tensor_util.maybe_set_static_shape(result, shape)
|
||||
return result
|
||||
@ -521,8 +558,13 @@ def stateless_truncated_normal(shape,
|
||||
shape = tensor_util.shape_tensor(shape)
|
||||
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
|
||||
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
|
||||
rnd = gen_stateless_random_ops.stateless_truncated_normal(
|
||||
shape, seed, dtype)
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=alg)
|
||||
else:
|
||||
rnd = gen_stateless_random_ops.stateless_truncated_normal(
|
||||
shape, seed, dtype)
|
||||
result = math_ops.add(rnd * stddev, mean, name=name)
|
||||
tensor_util.maybe_set_static_shape(result, shape)
|
||||
return result
|
||||
|
@ -3792,6 +3792,10 @@ tf_module {
|
||||
name: "Rint"
|
||||
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RngReadAndSkip"
|
||||
argspec: "args=[\'resource\', \'alg\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RngSkip"
|
||||
argspec: "args=[\'resource\', \'algorithm\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -4544,10 +4548,18 @@ tf_module {
|
||||
name: "StatelessRandomGammaV2"
|
||||
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomGetKeyCounterAlg"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomNormal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomNormalV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomPoisson"
|
||||
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -4560,10 +4572,22 @@ tf_module {
|
||||
name: "StatelessRandomUniformFullInt"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniformFullIntV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniformInt"
|
||||
argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniformIntV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniformV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessSampleDistortedBoundingBox"
|
||||
argspec: "args=[\'image_size\', \'bounding_boxes\', \'min_object_covered\', \'seed\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'[0.75, 1.33]\', \'[0.05, 1]\', \'100\', \'False\', \'None\'], "
|
||||
@ -4572,6 +4596,10 @@ tf_module {
|
||||
name: "StatelessTruncatedNormal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessTruncatedNormalV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessWhile"
|
||||
argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "
|
||||
|
@ -3792,6 +3792,10 @@ tf_module {
|
||||
name: "Rint"
|
||||
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RngReadAndSkip"
|
||||
argspec: "args=[\'resource\', \'alg\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "RngSkip"
|
||||
argspec: "args=[\'resource\', \'algorithm\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -4544,10 +4548,18 @@ tf_module {
|
||||
name: "StatelessRandomGammaV2"
|
||||
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomGetKeyCounterAlg"
|
||||
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomNormal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomNormalV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomPoisson"
|
||||
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
@ -4560,10 +4572,22 @@ tf_module {
|
||||
name: "StatelessRandomUniformFullInt"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniformFullIntV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniformInt"
|
||||
argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniformIntV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessRandomUniformV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessSampleDistortedBoundingBox"
|
||||
argspec: "args=[\'image_size\', \'bounding_boxes\', \'min_object_covered\', \'seed\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'[0.75, 1.33]\', \'[0.05, 1]\', \'100\', \'False\', \'None\'], "
|
||||
@ -4572,6 +4596,10 @@ tf_module {
|
||||
name: "StatelessTruncatedNormal"
|
||||
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessTruncatedNormalV2"
|
||||
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "StatelessWhile"
|
||||
argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user