Internal change

PiperOrigin-RevId: 332361104
Change-Id: I1f66d7fa0a7fa5e48656232278ae9e22f26f4747

This commit is contained in: parent 4ec907ddd1, commit b1c97a0bb2

Changed paths:
- tensorflow/compiler
- tensorflow/core: BUILD
- tensorflow/core/api_def/base_api: api_def_RngReadAndSkip.pbtxt, api_def_StatelessRandomGetKeyCounterAlg.pbtxt, api_def_StatelessRandomNormalV2.pbtxt, api_def_StatelessRandomUniformFullIntV2.pbtxt, api_def_StatelessRandomUniformIntV2.pbtxt, api_def_StatelessRandomUniformV2.pbtxt, api_def_StatelessTruncatedNormalV2.pbtxt
- tensorflow/core/framework
- tensorflow/core/kernels: BUILD, multinomial_op_gpu.cu.cc, random_binomial_op.cc, random_op.cc, random_op.h, random_op_cpu.h, random_op_gpu.h, random_ops_util.h, stateful_random_ops.cc, stateful_random_ops.h, stateful_random_ops_cpu_gpu.h, stateful_random_ops_gpu.cu.cc, stateless_random_ops.cc, stateless_random_ops_v2.cc, stateless_random_ops_v2.h
- tensorflow/core/ops
- tensorflow/python
- tensorflow/tools/api/golden
@@ -1995,8 +1995,6 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
      "ResourceScatterNdUpdate",
      "ResourceScatterSub",
      "ResourceScatterUpdate",
      "RngReadAndSkip",
      "RngSkip",
      "Roll",
      "ScatterNd",
      "SelfAdjointEigV2",
@@ -2019,17 +2017,11 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
      "StatelessCase",
      "StatelessIf",
      "StatelessMultinomial",
      "StatelessRandomGetKeyCounterAlg",
      "StatelessRandomNormal",
      "StatelessRandomNormalV2",
      "StatelessRandomUniform",
      "StatelessRandomUniformV2",
      "StatelessRandomUniformInt",
      "StatelessRandomUniformIntV2",
      "StatelessRandomUniformFullInt",
      "StatelessRandomUniformFullIntV2",
      "StatelessTruncatedNormal",
      "StatelessTruncatedNormalV2",
      "StatelessWhile",
      "Svd",
      "SymbolicGradient",
@@ -25,9 +25,7 @@ import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import device_lib
from tensorflow.python.compat import compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
@@ -158,10 +156,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):

  def testNewStateThreeFry(self):
    """Tests that the new state is correct (for ThreeFry).
    """
    if compat.forward_compatible(2020, 10, 25):
      self.skipTest("The expected values in this test is inconsistent with "
                    "CPU/GPU. testXLAEqualsCPU has the correct checks of the "
                    "new states for the new version.")
    with ops.device(xla_device_name()):
      counter = 57
      key = 0x1234
@@ -177,10 +171,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):

  def testNewStatePhilox(self):
    """Tests that the new state is correct (for Philox).
    """
    if compat.forward_compatible(2020, 10, 25):
      self.skipTest("The expected values in this test is inconsistent with "
                    "CPU/GPU. testXLAEqualsCPU has the correct checks of the "
                    "new states for the new version.")
    with ops.device(xla_device_name()):
      counter_low = 57
      counter_high = 283
@@ -214,39 +204,13 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
    """Tests that XLA and CPU kernels generate the same integers."""
    seed = 1234
    shape = [315, 49]
    if compat.forward_compatible(2020, 10, 25):
      with ops.device("/device:CPU:0"):
        cpu_gen = random.Generator.from_seed(
            seed=seed, alg=random.RNG_ALG_PHILOX)
      with ops.device(xla_device_name()):
        xla_gen = random.Generator.from_seed(
            seed=seed, alg=random.RNG_ALG_PHILOX)
      # Repeat multiple times to make sure that the state after
      # number-generation are the same between CPU and XLA.
      for _ in range(5):
        with ops.device("/device:CPU:0"):
          # Test both number-generation and skip
          cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
          cpu_gen.skip(100)
        with ops.device(xla_device_name()):
          xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
          xla_gen.skip(100)
        self.assertAllEqual(cpu, xla)
        self.assertAllEqual(cpu_gen.state, xla_gen.state)
    else:
      # The old version doesn't guarantee that CPU and XLA are in the same state
      # after number-generation, which is a bug.
      with ops.device("/device:CPU:0"):
        cpu = (
            random.Generator.from_seed(
                seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
                    shape=shape, dtype=dtype))
      with ops.device(xla_device_name()):
        xla = (
            random.Generator.from_seed(
                seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
                    shape=shape, dtype=dtype))
      self.assertAllEqual(cpu, xla)
    with ops.device("/device:CPU:0"):
      cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
             .uniform_full_int(shape=shape, dtype=dtype))
    with ops.device(xla_device_name()):
      xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
             .uniform_full_int(shape=shape, dtype=dtype))
    self.assertAllEqual(cpu, xla)

  def _testRngIsNotConstant(self, rng, dtype):
    # Tests that 'rng' does not always return the same value.
@@ -400,5 +364,4 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):

if __name__ == "__main__":
  ops.enable_eager_execution()
  config.set_soft_device_placement(False)
  test.main()
@@ -21,11 +21,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
    random_test_util
from tensorflow.python.ops import array_ops
@@ -43,26 +39,6 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
    allowed_types.update({dtypes.int32, dtypes.int64})
    return self.all_tf_types & allowed_types

  @test_util.run_v2_only
  def testForcedCompile(self):
    """Tests whole-function forced-compilation.

    This test checks that stateless_random_* can be used in forced-compilation
    scenarios (e.g. TPU). The new version of stateless_random_* requires the
    intermediate tensor `alg` to be compile-time constant, so we need to check
    that this requirement is met. We use xla.compile instead of tf.function's
    experimental_compile because the latter doesn't throw an error even if the
    compile-time-constant constraint is not met.
    """
    if config.list_logical_devices('TPU'):
      self.skipTest('To accommodate OSS, xla.compile support for TPU is not '
                    'linked in.')
    @def_function.function
    def f(x):
      return xla.compile(
          lambda x: stateless.stateless_random_normal([], seed=x), [x])
    f([1, 2])

  def testDeterminism(self):
    # Stateless values should be equal iff the seeds are equal (roughly)
    with self.session(), self.test_scope():
@@ -162,7 +138,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark):

  def _benchmarkUniform(self, name, dtype, use_xla_jit):

    def builder_fn():
    def BuilderFn():
      shape = (10, 1000, 1000)
      seed_var = variables.Variable((312, 456),
                                    dtype=dtypes.int32,
@@ -171,7 +147,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark):
          shape, seed=seed_var, dtype=dtype)
      return '%s.shape%s' % (name, shape), [random_t]

    xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu')
    xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu')

  def benchmarkUniformF32(self):
    self._benchmarkUniform(
@@ -191,5 +167,4 @@ class StatelessRandomOpsBenchmark(test.Benchmark):


if __name__ == '__main__':
  config.set_soft_device_placement(False)
  test.main()
@@ -108,7 +108,6 @@ tf_kernel_library(
        "stack_ops.cc",
        "stateful_random_ops.cc",
        "stateless_random_ops.cc",
        "stateless_random_ops_v2.cc",
        "strided_slice_op.cc",
        "tensor_array_ops.cc",
        "tensor_list_ops.cc",
@@ -188,7 +187,6 @@ tf_kernel_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/kernels:stateful_random_ops_header",
        "//tensorflow/core/kernels:stateless_random_ops_v2_header",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -181,7 +180,7 @@ Status CompileImpl(
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  Algorithm alg = Algorithm(alg_literal.Get<int>({}));
  auto alg = alg_literal.Get<Algorithm>({});
  if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) {
    return errors::InvalidArgument("Unsupported algorithm id: ", alg);
  }
@@ -408,80 +407,5 @@ REGISTER_XLA_OP(Name("StatefulUniformFullInt")
                    {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                StatefulUniformFullIntOp);

xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter,
                           xla::XlaOp delta) {
  // Multiplying 256 to be consistent with the CPU/GPU kernels
  delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256);
  if (alg == RNG_ALG_PHILOX) {
    return xla::PhiloxIncreaseCounter(counter, delta);
  } else {
    return counter + delta;
  }
}

xla::XlaOp PadRight(xla::XlaOp a, int n) {
  return xla::Pad(a, xla::ScalarLike(a, 0),
                  xla::MakeEdgePaddingConfig({{0, n}}));
}

template <typename AlgEnumType = int64, bool read_old_value = false>
class RngSkipOp : public XlaOpKernel {
 public:
  explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const int state_input_idx = 0;
    const int alg_input_idx = 1;
    const int delta_input_idx = 2;
    xla::XlaOp var;
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE,
                                          &var_shape, &var));
    xla::Literal alg_literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal));
    Algorithm alg = Algorithm(alg_literal.Get<AlgEnumType>({}));
    OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX,
                errors::InvalidArgument("Unsupported algorithm id: ", alg));
    OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape));
    if (read_old_value) {
      auto counter_size = GetCounterSize(alg);
      xla::XlaOp output = var;
      if (RNG_MAX_COUNTER_SIZE > counter_size) {
        // Because the size of `var` depends on the algorithm while we want the
        // output to have a fixed size (to help shape inference), we fix the
        // output size to be the maximal state size among algorithms, and right-
        // pad it with zeros if var's size is smaller than that.
        output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size);
      }
      ctx->SetOutput(0, output);
    }
    xla::XlaOp counter;
    xla::XlaOp key;
    std::tie(counter, key) = StateAndKeyFromVariable(alg, var);
    xla::XlaOp delta = ctx->Input(delta_input_idx);
    delta = BitcastConvertType(delta, xla::U64);
    auto new_counter = IncreaseCounter(alg, counter, delta);
    var = StateAndKeyToVariable(alg, new_counter, key);
    xla::PrimitiveType state_element_type;
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
    var = BitcastConvertType(var, state_element_type);
    OP_REQUIRES_OK(
        ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp);
};

REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"),
                RngSkipOp<>);

using RngReadAndSkipOp = RngSkipOp<int32, true>;

REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"),
                RngReadAndSkipOp);

}  // namespace
}  // namespace tensorflow
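Note on the multiplier in `IncreaseCounter` above: one unit of `delta` reserves 256 Philox counter steps, matching what the CPU/GPU kernels reserve per output unit, so skipping and generating advance the state identically. A minimal Python sketch of that contract, assuming `tf.random.Generator` (whose `skip` is documented to behave like the RngSkip/RngReadAndSkip ops in this diff):

    import tensorflow as tf

    g1 = tf.random.Generator.from_seed(1234)
    g2 = tf.random.Generator.from_seed(1234)

    _ = g1.uniform([5])  # generate a batch of 5 samples
    g2.skip(5)           # advance the counter without generating

    # Documented contract: the state after skip(n) equals the state after
    # generating a batch of n samples, so the two streams now coincide.
    tf.debugging.assert_equal(g1.state, g2.state)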
@@ -111,8 +111,6 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
  }
}

namespace {

xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
                                      xla::XlaOp seeds,
                                      const xla::Shape& shape) {
@@ -142,6 +140,8 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
  }
}

namespace {

class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
@@ -1,485 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

namespace {

inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
  if (alg == RNG_ALG_PHILOX) {
    return xla::RandomAlgorithm::RNG_PHILOX;
  }
  return xla::RandomAlgorithm::RNG_THREE_FRY;
}

inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
  if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
    return RNG_ALG_PHILOX;
  }
  return RNG_ALG_THREEFRY;
}

xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) {
  Algorithm alg_ = RandomAlgorithmToAlgorithm(alg);
  return xla::Slice(state, {RNG_KEY_SIZE},
                    {RNG_KEY_SIZE + GetCounterSize(alg_)}, {1});
}

xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
                            xla::XlaOp counter, const xla::Shape& shape) {
  key = BitcastConvertType(key, xla::U64);
  counter = BitcastConvertType(counter, xla::U64);
  xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0);
  xla::XlaOp result = xla::RngBitGenerator(alg, state, shape);
  auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0));
  new_counter = BitcastConvertType(new_counter, xla::S64);
  return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                        /*state=*/new_counter};
}

std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
    absl::string_view device_type_string, xla::XlaOp key) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    auto counter_key = xla::ScramblePhiloxKey(key);
    return std::make_tuple(counter_key.second, counter_key.first,
                           RNG_ALG_PHILOX);
  } else {
    auto counter_shape =
        xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
    auto counter = xla::Zeros(key.builder(), counter_shape);
    return std::make_tuple(key, counter, RNG_ALG_THREEFRY);
  }
}

}  // namespace

xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg,
                                     xla::XlaOp key, xla::XlaOp counter,
                                     const xla::Shape& shape, xla::XlaOp minval,
                                     xla::XlaOp maxval) {
  xla::XlaBuilder* builder = key.builder();
  xla::PrimitiveType type = shape.element_type();
  using std::placeholders::_1;
  using std::placeholders::_2;
  using std::placeholders::_3;
  auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(key, counter, generator,
                                                   minval, maxval, shape);
    case xla::S32:
    case xla::S64:
    case xla::U32:
    case xla::U64:
      return UniformIntDistribution(key, counter, generator, minval, maxval,
                                    shape);
      break;
    default:
      return {builder->ReportError(xla::Unimplemented(
                  "Types other than F32, S32, S64, U32 and U64 are not "
                  "implemented by "
                  "StatelessRngUniformV2; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              counter};
  }
}

namespace {

xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg,
                                          xla::XlaOp key, xla::XlaOp counter,
                                          const xla::Shape& shape) {
  xla::XlaBuilder* builder = key.builder();

  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGenerator(alg, key, counter, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      return xla::RngOutput{BitcastConvertType(output.value, type),
                            output.state};
    default:
      return {
          builder->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatelessRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          output.state};
  }
}

Status GetAlgorithm(XlaOpKernelContext* ctx, int alg_input_idx,
                    xla::RandomAlgorithm* alg) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  auto alg_ = Algorithm(alg_literal.Get<int>({}));
  *alg = AlgorithmToRandomAlgorithm(alg_);
  return Status::OK();
}

xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg,
                             TensorShape const& counter_shape,
                             xla::XlaOp counter) {
  auto input_counter_size = counter_shape.dim_size(0);
  auto real_counter_size = GetCounterSize(RandomAlgorithmToAlgorithm(alg));
  if (input_counter_size > real_counter_size) {
    counter = xla::Slice(counter, {0}, {real_counter_size}, {1});
  }
  return counter;
}

class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
    xla::PrimitiveType rng_primitive_type = xla_shape.element_type();

    counter = MaybeSliceCounter(alg, counter_shape, counter);

    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
        xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
    auto uniform = MaybeConvertF32ToBF16(result.value, dtype_);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomUniformOp);

class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    const int minval_input_idx = 4;
    const int maxval_input_idx = 5;
    TensorShape minval_shape = ctx->InputShape(minval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(maxval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp minval = ctx->Input(minval_input_idx);
    xla::XlaOp maxval = ctx->Input(maxval_input_idx);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result =
        StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformIntOp);

class StatelessRandomUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformFullIntOp);

class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    using std::placeholders::_1;
    using std::placeholders::_2;
    using std::placeholders::_3;
    auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = xla::NormalFloatingPointDistribution(key, counter, generator,
                                                       xla_shape);
    auto normal = MaybeConvertF32ToBF16(result.value, dtype_);
    ctx->SetOutput(0, normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

REGISTER_XLA_OP(Name("StatelessRandomNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomNormalOp);

class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::XlaBuilder* builder = ctx->builder();

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(result.value);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessTruncatedNormalOp);

class GetKeyCounterAlgOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaBuilder* builder = seed.builder();
    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
                     ShiftLeft(ConvertElementType(seed1, xla::U64),
                               ConstantR0WithType(builder, xla::U64, 32));
    auto key_counter_alg = GetKeyCounterAlg(device_type_string_, key);
    key = std::get<0>(key_counter_alg);
    auto counter = std::get<1>(key_counter_alg);
    auto alg = std::get<2>(key_counter_alg);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
    ctx->SetOutput(2, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
};

REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);

}  // namespace
}  // namespace tensorflow
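For orientation, the pipeline these XLA kernels implement can be exercised end to end from Python through `tf.raw_ops` (a sketch; these ops are hidden, and the key/counter shapes are taken from the api_defs below rather than verified here):

    import tensorflow as tf

    seed = tf.constant([1, 2], dtype=tf.int32)

    # Scramble the seed into (key, counter) and pick a device-appropriate
    # algorithm (1 == Philox, 2 == ThreeFry).
    key, counter, alg = tf.raw_ops.StatelessRandomGetKeyCounterAlg(seed=seed)

    # Deterministic outputs: the same (shape, key, counter, alg) always
    # yields the same values.
    u = tf.raw_ops.StatelessRandomUniformV2(
        shape=[2, 3], key=key, counter=counter, alg=alg, dtype=tf.float32)
    n = tf.raw_ops.StatelessRandomNormalV2(
        shape=[2, 3], key=key, counter=counter, alg=alg, dtype=tf.float32)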
@@ -83,8 +83,6 @@ CreateResourceOpInfoMap() {
  add("ResourceScatterSub"          , kReadWrite, kVariable);
  add("ResourceScatterUpdate"       , kReadWrite, kVariable);
  add("ResourceStridedSliceAssign"  , kReadWrite, kVariable);
  add("RngReadAndSkip"              , kReadWrite, kVariable);
  add("RngSkip"                     , kReadWrite, kVariable);
  add("StatefulStandardNormalV2"    , kReadWrite, kVariable);
  add("StatefulTruncatedNormal"     , kReadWrite, kVariable);
  add("StatefulUniform"             , kReadWrite, kVariable);
@@ -487,10 +487,6 @@ std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {

}  // namespace

XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
  return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
}

RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
                               const Shape& shape) {
  PrimitiveType type = shape.element_type();
@@ -89,9 +89,6 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
                         absl::Span<const xla::XlaOp> scalars);

// Increases Philox counter (an uint128) by a delta (an uint64).
xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_
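As a plain-arithmetic illustration of what `PhiloxIncreaseCounter` computes (a uint128 counter plus a uint64 delta), here is a minimal Python model; the carry handling is the point, and the XLA lowering via `Uint128AddUint64` is not reproduced:

    MASK64 = (1 << 64) - 1

    def philox_increase_counter(counter_lo, counter_hi, delta):
      """Models uint128 (counter_hi:counter_lo) += uint64 delta."""
      lo = (counter_lo + delta) & MASK64
      hi = (counter_hi + (1 if lo < counter_lo else 0)) & MASK64  # carry
      return lo, hi

    # The carry propagates into the high word when the low word wraps.
    assert philox_increase_counter(MASK64, 0, 1) == (0, 1)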
@@ -488,7 +488,6 @@ tf_cuda_library(
        "//tensorflow/core/framework:register_types_traits.h",
        "//tensorflow/core/framework:resource_mgr.h",
        "//tensorflow/core/framework:resource_op_kernel.h",
        "//tensorflow/core/framework:rng_alg.h",
        "//tensorflow/core/framework:selective_registration.h",
        "//tensorflow/core/framework:session_state.h",
        "//tensorflow/core/framework:shape_inference.h",
@@ -653,7 +652,6 @@ tf_gen_op_libs(
        "spectral_ops",
        "state_ops",
        "stateless_random_ops",
        "stateless_random_ops_v2",
        "summary_ops",
        "training_ops",
    ],
@@ -873,7 +871,6 @@ cc_library(
        ":spectral_ops_op_lib",
        ":state_ops_op_lib",
        ":stateless_random_ops_op_lib",
        ":stateless_random_ops_v2_op_lib",
        ":string_ops_op_lib",
        ":training_ops_op_lib",
        ":user_ops_op_lib",
@@ -1,35 +0,0 @@
op {
  graph_op_name: "RngReadAndSkip"
  visibility: HIDDEN
  in_arg {
    name: "resource"
    description: <<END
The handle of the resource variable that stores the state of the RNG.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm.
END
  }
  in_arg {
    name: "delta"
    description: <<END
The amount of advancement.
END
  }
  out_arg {
    name: "value"
    description: <<END
The old value of the resource variable, before incrementing. Since state size
is algorithm-dependent, this output will be right-padded with zeros to reach
shape int64[3] (the current maximal state size among algorithms).
END
  }
  summary: "Advance the counter of a counter-based RNG."
  description: <<END
The state of the RNG after `rng_read_and_skip(n)` will be the same as that
after `uniform([n])` (or any other distribution). The actual increment added
to the counter is an unspecified implementation choice.
END
}
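A hypothetical direct invocation of this op (normally it is reached via `tf.random.Generator`; the `alg`/`delta` dtypes below are assumptions based on the kernels above):

    import tensorflow as tf

    g = tf.random.Generator.from_seed(42)  # Philox-backed generator
    old_state = tf.raw_ops.RngReadAndSkip(
        resource=g.state.handle,           # the RNG state variable
        alg=tf.constant(1, tf.int32),      # 1 == RNG_ALG_PHILOX (assumed dtype)
        delta=tf.constant(10, tf.uint64))  # amount of advancement (assumed dtype)
    print(old_state.shape)  # (3,): the pre-skip state, right-padded to int64[3]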
@@ -1,32 +0,0 @@
op {
  graph_op_name: "StatelessRandomGetKeyCounterAlg"
  visibility: HIDDEN
  in_arg {
    name: "seed"
    description: <<END
2 seeds (shape [2]).
END
  }
  out_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  out_arg {
    name: "counter"
    description: <<END
Counter for the counter-based RNG algorithm. Since counter size is
algorithm-dependent, this output will be right-padded with zeros to reach
shape uint64[2] (the current maximal counter size among algorithms).
END
  }
  out_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  summary: "Picks the best algorithm based on device, and scrambles seed into key and counter."
  description: <<END
This op picks the best counter-based RNG algorithm based on device, and
scrambles a shape-[2] seed into a key and a counter, both needed by the
counter-based algorithm. The scrambling is opaque but approximately satisfies
the property that different seeds result in different key/counter pairs
(which will in turn result in different random numbers).
END
}
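A small sketch of the documented scrambling property (different seeds should yield different key/counter pairs; the op is hidden, so it is called via `tf.raw_ops` here):

    import tensorflow as tf

    k1, c1, a1 = tf.raw_ops.StatelessRandomGetKeyCounterAlg(seed=[1, 2])
    k2, c2, a2 = tf.raw_ops.StatelessRandomGetKeyCounterAlg(seed=[1, 3])

    # Same device, so the algorithm choice agrees; the scrambled pairs
    # differ (approximately guaranteed by the op contract).
    assert int(a1) == int(a2)
    assert (k1.numpy().tolist(), c1.numpy().tolist()) != \
           (k2.numpy().tolist(), c2.numpy().tolist())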
@@ -1,46 +0,0 @@
op {
  graph_op_name: "StatelessRandomNormalV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or
uint64[1] depending on the algorithm). If a larger vector is given, only the
needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom values from a normal distribution."
  description: <<END
The generated values will have mean 0 and standard deviation 1.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
@@ -1,46 +0,0 @@
op {
  graph_op_name: "StatelessRandomUniformFullIntV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or
uint64[1] depending on the algorithm). If a larger vector is given, only the
needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom integers from a uniform distribution."
  description: <<END
The generated values are uniform integers covering the whole range of `dtype`.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
@@ -1,58 +0,0 @@
op {
  graph_op_name: "StatelessRandomUniformIntV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or
uint64[1] depending on the algorithm). If a larger vector is given, only the
needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  in_arg {
    name: "minval"
    description: <<END
Minimum value (inclusive, scalar).
END
  }
  in_arg {
    name: "maxval"
    description: <<END
Maximum value (exclusive, scalar).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom integers from a uniform distribution."
  description: <<END
The generated values follow a uniform distribution in the range `[minval, maxval)`.

The outputs are a deterministic function of `shape`, `key`, `counter`, `alg`, `minval` and `maxval`.
END
}
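A sketch of calling this op directly (hidden; in practice it is reached through `tf.random.stateless_uniform`, a mapping assumed here). The output `dtype` is inferred from `minval`/`maxval`:

    import tensorflow as tf

    key, counter, alg = tf.raw_ops.StatelessRandomGetKeyCounterAlg(seed=[7, 9])
    dice = tf.raw_ops.StatelessRandomUniformIntV2(
        shape=[4], key=key, counter=counter, alg=alg,
        minval=tf.constant(1, tf.int32),   # inclusive
        maxval=tf.constant(7, tf.int32))   # exclusive
    # Re-running with the same inputs reproduces the same values.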
@@ -1,47 +0,0 @@
op {
  graph_op_name: "StatelessRandomUniformV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or
uint64[1] depending on the algorithm). If a larger vector is given, only the
needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom values from a uniform distribution."
  description: <<END
The generated values follow a uniform distribution in the range `[0, 1)`. The
lower bound 0 is included in the range, while the upper bound 1 is excluded.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
@@ -1,48 +0,0 @@
op {
  graph_op_name: "StatelessTruncatedNormalV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or
uint64[1] depending on the algorithm). If a larger vector is given, only the
needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom values from a truncated normal distribution."
  description: <<END
The generated values follow a normal distribution with mean 0 and standard
deviation 1, except that values whose magnitude is more than 2 standard
deviations from the mean are dropped and re-picked.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
@@ -68,7 +68,6 @@ exports_files(
        "resource_mgr.h",
        "resource_op_kernel.h",
        "resource_var.h",
        "rng_alg.h",
        "run_handler.h",
        "run_handler_util.h",
        "session_state.h",
@@ -386,7 +385,6 @@ filegroup(
        "resource_mgr.h",
        "resource_op_kernel.h",
        "resource_var.h",
        "rng_alg.h",
        "run_handler.cc",
        "run_handler.h",
        "run_handler_util.cc",
@@ -1,34 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
#define TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_

namespace tensorflow {

enum Algorithm { RNG_ALG_PHILOX = 1, RNG_ALG_THREEFRY = 2 };

static constexpr int RNG_KEY_SIZE = 1;
static constexpr int RNG_MAX_COUNTER_SIZE = 2;
inline int GetCounterSize(Algorithm alg) {
  if (alg == RNG_ALG_PHILOX) {
    return 2;
  }
  return 1;
}

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
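The constants above imply a maximal state of RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE = 3 64-bit words for Philox (counter size 2 plus key size 1), which is visible from Python as the generator's state shape (a sketch; the state layout is an implementation detail):

    import tensorflow as tf

    g = tf.random.Generator.from_seed(0, alg="philox")
    print(g.state.shape)  # (3,) == 2 counter words + 1 key word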
@@ -4457,22 +4457,12 @@ cc_library(
    ],
)

cc_library(
    name = "stateless_random_ops_v2_header",
    hdrs = ["stateless_random_ops_v2.h"],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)

tf_kernel_library(
    name = "stateful_random_ops",
    prefix = "stateful_random_ops",
    deps = [
        ":bounds_check",
        ":dense_update_functor",
        ":fill_functor",
        ":gather_functor",
        ":mutex_ops",
        ":random_op",
@@ -5347,7 +5337,7 @@ tf_kernel_library(
    prefix = "random_binomial_op",
    deps = [
        ":cwise_op",
        ":random_op",
        ":random_ops",
        ":resource_variable_ops",
        ":stateful_random_ops",
        ":stateless_random_ops",
@@ -6169,7 +6159,6 @@ filegroup(
        "ragged_tensor_to_tensor_op.cc",
        "random_op.cc",
        "random_op_cpu.h",
        "random_ops_util.h",
        "random_poisson_op.cc",
        "reduce_join_op.cc",
        "reduction_ops_all.cc",
@@ -66,9 +66,8 @@ struct MultinomialFunctor<GPUDevice, T, OutputType> {
                  typename TTypes<OutputType>::Matrix output) {
    // Uniform, [0, 1).
    typedef random::UniformDistribution<random::PhiloxRandom, float> Dist;
    functor::FillPhiloxRandom<GPUDevice, Dist>()(
        ctx, d, /*key=*/nullptr, /*counter=*/nullptr, gen, noises.data(),
        noises.size(), Dist());
    functor::FillPhiloxRandom<GPUDevice, Dist>()(ctx, d, gen, noises.data(),
                                                 noises.size(), Dist());

#if defined(EIGEN_HAS_INDEX_LIST)
    Eigen::IndexList<int, int, int> bsc;
@@ -30,10 +30,8 @@ limitations under the License.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
@@ -377,7 +375,7 @@ class RandomBinomialOp : public OpKernel {
    OP_REQUIRES(ctx, alg_tensor.dims() == 0,
                errors::InvalidArgument("algorithm must be of shape [], not ",
                                        alg_tensor.shape().DebugString()));
    Algorithm alg = Algorithm(alg_tensor.flat<int64>()(0));
    Algorithm alg = alg_tensor.flat<Algorithm>()(0);

    int64 samples_per_batch = 1;
    const int64 num_sample_dims =
@@ -74,7 +74,7 @@ class PhiloxRandomOp : public OpKernel {
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    auto output_flat = output->flat<T>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
        ctx, ctx->eigen_device<Device>(),
        // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
        // it just here.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
@@ -123,7 +123,7 @@ class RandomUniformIntOp : public OpKernel {

    auto output_flat = output->flat<IntType>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
        ctx, ctx->eigen_device<Device>(),
        // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
        // it just here.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
@@ -34,14 +34,10 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
// NOTE: Due to inlining done by the compiler, you may need to add
// explicit instantiation of the functor in random_op.cc. See example
// functor::FillPhiloxRandom<CPUDevice, random::UniformDistribution>.
//
// This functor can take the PhiloxRandom input from either device memory `key`
// and `counter` or a stack value `gen`. If both `key` and `counter` are not
// nullptr, they provide the input; otherwise `gen` provides the input.
template <class Distribution>
struct FillPhiloxRandom<CPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d, const uint64* key,
                  const uint64* counter, random::PhiloxRandom gen,
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  random::PhiloxRandom gen,
                  typename Distribution::ResultElementType* data, int64 size,
                  Distribution dist);
};
@@ -51,13 +47,14 @@ typedef Eigen::GpuDevice GPUDevice;
// Declares the partially GPU-specialized functor struct.
template <class Distribution>
struct FillPhiloxRandom<GPUDevice, Distribution> {
  void operator()(OpKernelContext* ctx, const GPUDevice& d, const uint64* key,
                  const uint64* counter, random::PhiloxRandom gen,
  void operator()(OpKernelContext* ctx, const GPUDevice& d,
                  random::PhiloxRandom gen,
                  typename Distribution::ResultElementType* data, int64 size,
                  Distribution dist);
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace functor
}  // namespace tensorflow
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
@@ -60,9 +59,8 @@ using random::SingleSampleAdapter;
template <typename Device, class Distribution>
struct FillPhiloxRandom {
  typedef typename Distribution::ResultElementType T;
  void operator()(OpKernelContext* ctx, const Device&, const uint64* key,
                  const uint64* counter, random::PhiloxRandom gen, T* data,
                  int64 size, Distribution dist) {
  void operator()(OpKernelContext* ctx, const Device&, random::PhiloxRandom gen,
                  T* data, int64 size, Distribution dist) {
    OP_REQUIRES(
        ctx, false,
        errors::Internal(
@@ -156,24 +154,18 @@ struct FillPhiloxRandomTask<Distribution, true> {
// It splits the work into several tasks and run them in parallel
template <class Distribution>
void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
    OpKernelContext* ctx, const CPUDevice&, const uint64* key,
    const uint64* counter, random::PhiloxRandom gen,
    OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
    typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
  const int kGroupSize = Distribution::kResultElementCount;

  auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

  int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;

  const int kGroupCost =
      random::PhiloxRandom::kResultElementCount *
      (random::PhiloxRandom::kElementCost + Distribution::kElementCost);

  if (key != nullptr && counter != nullptr) {
    gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
  }

  Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
        kGroupCost,
        [&gen, data, size, dist](int64 start_group, int64 limit_group) {
@@ -19,7 +19,6 @@ limitations under the License.
#if defined(__CUDACC__) || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
@@ -34,16 +33,14 @@ struct FillPhiloxRandomKernel;
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
  typedef typename Distribution::ResultElementType T;
  PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter,
                                random::PhiloxRandom gen, T* data, int64 size,
  PHILOX_DEVICE_INLINE void Run(random::PhiloxRandom gen, T* data, int64 size,
                                Distribution dist);
};

template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
  typedef typename Distribution::ResultElementType T;
  PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter,
                                random::PhiloxRandom base_gen, T* data,
  PHILOX_DEVICE_INLINE void Run(const random::PhiloxRandom& base_gen, T* data,
                                int64 size, Distribution dist);
};

@@ -139,16 +136,12 @@ class SampleCopier<int64, 2> {
// distribution. Each output takes a fixed number of samples.
template <class Distribution>
PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
    const uint64* key, const uint64* counter, random::PhiloxRandom gen, T* data,
    int64 size, Distribution dist) {
    random::PhiloxRandom gen, T* data, int64 size, Distribution dist) {
  const int kGroupSize = Distribution::kResultElementCount;

  const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int32 total_thread_count = gridDim.x * blockDim.x;
  int32 offset = thread_id * kGroupSize;
  if (key != nullptr && counter != nullptr) {
    gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
  }
  gen.Skip(thread_id);

  const SampleCopier<T, kGroupSize> copier;
@@ -174,8 +167,8 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
// distribution. Each output takes a variable number of samples.
template <class Distribution>
PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
    const uint64* key, const uint64* counter, random::PhiloxRandom base_gen,
    T* data, int64 size, Distribution dist) {
    const random::PhiloxRandom& base_gen, T* data, int64 size,
    Distribution dist) {
  using random::PhiloxRandom;
  using random::SingleSampleAdapter;

@@ -190,9 +183,6 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
  int64 group_index = thread_id;
  int64 offset = group_index * kGroupSize;

  if (key != nullptr && counter != nullptr) {
    base_gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
  }
  while (offset < size) {
    // Since each output takes a variable number of samples, we need to
    // realign the generator to the beginning for the current output group
@@ -218,20 +208,18 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
// A simple launch pad to call the correct function templates to fill the data
template <class Distribution>
__global__ void __launch_bounds__(1024)
    FillPhiloxRandomKernelLaunch(const uint64* key, const uint64* counter,
                                 random::PhiloxRandom base_gen,
    FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
                                 typename Distribution::ResultElementType* data,
                                 int64 size, Distribution dist) {
  FillPhiloxRandomKernel<Distribution,
                         Distribution::kVariableSamplesPerOutput>()
      .Run(key, counter, base_gen, data, size, dist);
      .Run(base_gen, data, size, dist);
}

// Partial specialization for GPU
template <class Distribution>
void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
    OpKernelContext*, const GPUDevice& d, const uint64* key,
    const uint64* counter, random::PhiloxRandom gen,
    OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
    typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
  const int32 block_size = d.maxGpuThreadsPerBlock();
@@ -240,8 +228,8 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
      block_size;

  TF_CHECK_OK(GpuLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
                              num_blocks, block_size, 0, d.stream(), key,
                              counter, gen, data, size, dist));
                              num_blocks, block_size, 0, d.stream(), gen, data,
                              size, dist));
}

}  // namespace functor
@ -1,72 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
|
||||
|
||||
#include "tensorflow/core/lib/random/philox_random.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using random::PhiloxRandom;
|
||||
|
||||
// The following 2 functions use the contract "lower 32 bits for the first
|
||||
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
|
||||
// unlike a direct memory copy `memcpy(output, &input, 8)`.
|
||||
PHILOX_DEVICE_INLINE void Uint64ToUint32s(uint64 input, uint32* output1,
|
||||
uint32* output2) {
|
||||
*output1 = static_cast<uint32>(input);
|
||||
*output2 = static_cast<uint32>(input >> 32);
|
||||
}
|
||||
|
||||
PHILOX_DEVICE_INLINE uint64 Uint32sToUint64(uint32 input1, uint32 input2) {
|
||||
auto u64_1 = static_cast<uint64>(input1);
|
||||
auto u64_2 = static_cast<uint64>(input2);
|
||||
return u64_1 | (u64_2 << 32);
|
||||
}
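A standalone sketch of that split/join contract, using <cstdint> stand-ins for the TF uint32/uint64 typedefs:

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t input = 0x1122334455667788ULL;
  // Low half goes to the first word, high half to the second -- the result
  // is the same on big- and little-endian hosts, unlike a memcpy.
  const uint32_t lo = static_cast<uint32_t>(input);
  const uint32_t hi = static_cast<uint32_t>(input >> 32);
  assert(lo == 0x55667788u);
  assert(hi == 0x11223344u);
  // Joining the two words restores the original value.
  const uint64_t back =
      static_cast<uint64_t>(lo) | (static_cast<uint64_t>(hi) << 32);
  assert(back == input);
  return 0;
}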
PHILOX_DEVICE_INLINE PhiloxRandom::ResultType GetCounterFromMem(
uint64 const* ptr) {
PhiloxRandom::ResultType counter;
Uint64ToUint32s(ptr[0], &counter[0], &counter[1]);
Uint64ToUint32s(ptr[1], &counter[2], &counter[3]);
return counter;
}

PHILOX_DEVICE_INLINE void WriteCounterToMem(
PhiloxRandom::ResultType const& counter, uint64* ptr) {
ptr[0] = Uint32sToUint64(counter[0], counter[1]);
ptr[1] = Uint32sToUint64(counter[2], counter[3]);
}

PHILOX_DEVICE_INLINE PhiloxRandom::Key GetKeyFromMem(uint64 const* ptr) {
PhiloxRandom::Key key;
Uint64ToUint32s(ptr[0], &key[0], &key[1]);
return key;
}

PHILOX_DEVICE_INLINE void WriteKeyToMem(PhiloxRandom::Key const& key,
uint64* ptr) {
*ptr = Uint32sToUint64(key[0], key[1]);
}

PHILOX_DEVICE_INLINE PhiloxRandom GetPhiloxRandomFromCounterKeyMem(
uint64 const* counter_ptr, uint64 const* key_ptr) {
return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr));
}
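For illustration, the buffer layout these helpers consume, checked with plain integer arithmetic (all values below are made up, not taken from the diff):

#include <cassert>
#include <cstdint>

int main() {
  // Two uint64 words of counter (four uint32 counter words, low halves
  // first) and one uint64 word of key.
  const uint64_t counter_mem[2] = {0x0000000200000001ULL,
                                   0x0000000400000003ULL};
  const uint64_t key_mem[1] = {0x0000006600000055ULL};
  // The helpers above would decode counter {1, 2, 3, 4} and key {0x55, 0x66}.
  assert(static_cast<uint32_t>(counter_mem[0]) == 1);
  assert(static_cast<uint32_t>(counter_mem[0] >> 32) == 2);
  assert(static_cast<uint32_t>(counter_mem[1]) == 3);
  assert(static_cast<uint32_t>(counter_mem[1] >> 32) == 4);
  assert(static_cast<uint32_t>(key_mem[0]) == 0x55);
  assert(static_cast<uint32_t>(key_mem[0] >> 32) == 0x66);
  return 0;
}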
} // end namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
@ -15,9 +15,7 @@ limitations under the License.

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
@ -25,8 +23,6 @@ limitations under the License.

namespace tensorflow {

namespace functor {

template <typename Distribution>
struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
void operator()(OpKernelContext* ctx, const CPUDevice& device,
@ -46,13 +42,10 @@ struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
// No longer needs the lock.
state_var_guard->Release();
functor::FillPhiloxRandom<CPUDevice, Distribution>()(
ctx, device, /*key=*/nullptr, /*counter=*/nullptr, philox, output_data,
output_size, dist);
ctx, device, philox, output_data, output_size, dist);
}
};

} // end namespace functor

Status CheckState(const Tensor& state) {
if (state.dtype() != STATE_ELEMENT_DTYPE) {
return errors::InvalidArgument("dtype of RNG state variable must be ",
@ -71,12 +64,11 @@ Status CheckPhiloxState(const Tensor& state, int64 alg_tag_skip = 0) {
"StateElementType must be int64");
static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
"PhiloxRandom::ResultElementType must be uint32");
auto min_size = alg_tag_skip + PHILOX_MIN_STATE_SIZE;
if (state.NumElements() < min_size) {
if (state.NumElements() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) {
return errors::InvalidArgument(
"For the Philox algorithm, the size of state"
" must be at least ",
min_size, "; got ", state.NumElements());
alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ", state.NumElements());
}
return Status::OK();
}
@ -103,7 +95,7 @@ Status UpdateVariableAndFill(
if (var_tensor_flat.size() < 1) {
return errors::InvalidArgument("Size of tensor must be at least 1");
}
alg = Algorithm(var_tensor_flat(0));
alg = var_tensor_flat(0);
}
if (alg == RNG_ALG_PHILOX) {
TF_RETURN_IF_ERROR(CheckPhiloxState(*var_tensor, alg_tag_skip));
@ -115,7 +107,7 @@ Status UpdateVariableAndFill(
arg.alg_tag_skip = alg_tag_skip;
arg.not_used = &state_var_guard;
arg.state_tensor = var_tensor;
functor::UpdateVariableAndFill_Philox<Device, Distribution>()(
UpdateVariableAndFill_Philox<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), dist, &arg, output_data);
return Status::OK();
} else {
@ -146,8 +138,7 @@ class StatefulRandomOp : public OpKernel {
explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true,
RNG_ALG_PHILOX /*dummy*/);
StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true, 0);
}
};

@ -168,14 +159,6 @@ Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
return Status::OK();
}

template <typename AlgEnumType>
Status GetAlg(OpKernelContext* ctx, int input_idx, Algorithm* alg) {
AlgEnumType alg_id;
TF_RETURN_IF_ERROR(GetScalar(ctx->input(input_idx), input_idx, &alg_id));
*alg = Algorithm(alg_id);
return Status::OK();
}

template <typename Device, class Distribution>
class StatefulRandomOpV2 : public OpKernel {
public:
@ -183,7 +166,7 @@ class StatefulRandomOpV2 : public OpKernel {

void Compute(OpKernelContext* ctx) override {
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
StatefulRandomCompute<Device>(ctx, Distribution(), /*state_input_idx=*/0,
/*shape_input_idx=*/2,
/*read_alg_from_state=*/false, alg);
@ -197,7 +180,7 @@ class StatefulUniformIntOp : public OpKernel {

void Compute(OpKernelContext* ctx) override {
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
const Tensor& minval = ctx->input(3);
const Tensor& maxval = ctx->input(4);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
@ -234,7 +217,7 @@ class StatefulUniformFullIntOp : public OpKernel {

void Compute(OpKernelContext* ctx) override {
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
StatefulRandomCompute<Device>(
ctx,
random::UniformFullIntDistribution<random::PhiloxRandom, IntType>(),
@ -243,66 +226,38 @@ class StatefulUniformFullIntOp : public OpKernel {
}
};

namespace functor {

template <>
struct RngSkip_Philox<CPUDevice> {
void operator()(const CPUDevice& device, const StateElementType* in_data,
uint64 delta, StateElementType* out_data) {
void operator()(const CPUDevice& device, int64 delta, Tensor* state_tensor) {
auto state_data = state_tensor->flat<StateElementType>().data();
// Delegates to PhiloxRandom to do the actual increasing.
auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
auto philox = GetPhiloxRandomFromMem(state_data);
UpdateMemWithPhiloxRandom(philox, delta, state_data);
}
};

} // end namespace functor

template <typename Device, typename AlgEnumType = int64,
typename DeltaType = int64, bool read_old_value = false>
template <typename Device>
class RngSkipOp : public OpKernel {
public:
explicit RngSkipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
auto state_input_idx = 0;
auto alg_input_idx = 1;
auto delta_input_idx = 2;
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlg<AlgEnumType>(ctx, alg_input_idx, &alg));
DeltaType delta_;
OP_REQUIRES_OK(
ctx, GetScalar(ctx->input(delta_input_idx), delta_input_idx, &delta_));
uint64 delta = static_cast<uint64>(delta_);
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
int64 delta;
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(2), 2, &delta));
Var* var = nullptr;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
ScopedUnlockUnrefVar state_var_guard(var);
Tensor* var_tensor = var->tensor();
OP_REQUIRES_OK(ctx, CheckState(*var_tensor));
using T = StateElementType;
OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, T>(
ctx, var_tensor, var->copy_on_read_mode.load()));
if (read_old_value) {
Tensor* output;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, {RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE},
&output));
auto output_flat = output->flat<T>();
if (RNG_MAX_COUNTER_SIZE > GetCounterSize(alg)) {
functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
output_flat);
}
functor::DenseUpdate<Device, T, ASSIGN>()(
ctx->eigen_device<Device>(), output_flat,
const_cast<const Tensor*>(var_tensor)->flat<T>());
}
if (alg == RNG_ALG_PHILOX) {
OP_REQUIRES_OK(ctx, CheckPhiloxState(*var_tensor));
// var_tensor layout is counter+key, so var_tensor data is also counter
// data.
auto counter_data = var_tensor->flat<T>().data();
functor::RngSkip_Philox<Device>()(ctx->eigen_device<Device>(),
counter_data, delta, counter_data);
OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
ctx, var_tensor, var->copy_on_read_mode.load()));
RngSkip_Philox<Device>()(ctx->eigen_device<Device>(), delta, var_tensor);
} else {
OP_REQUIRES(ctx, false,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
@ -438,20 +393,13 @@ TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU);

// TODO(wangpeng): Remove `HostMemory("delta")` for RngReadAndSkip
#define REGISTER_RngSkip(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("RngSkip") \
.Device(DEVICE_##DEVICE) \
.HostMemory("resource") \
.HostMemory("algorithm") \
.HostMemory("delta"), \
RngSkipOp<DEVICE##Device>); \
REGISTER_KERNEL_BUILDER(Name("RngReadAndSkip") \
.Device(DEVICE_##DEVICE) \
.HostMemory("resource") \
.HostMemory("alg") \
.HostMemory("delta"), \
RngSkipOp<DEVICE##Device, int32, uint64, true>);
RngSkipOp<DEVICE##Device>);

REGISTER_RngSkip(CPU);
@ -22,12 +22,15 @@ limitations under the License.
namespace tensorflow {

// 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained
// in b/111604096 and cl/171681867), so we use signed int here. We choose int64
// instead of int32 because `VarHandleOp` doesn't support int32 on GPU, and
// because of the "int32 problem".
// in b/111604096 and cl/171681867), so I use signed int here. I choose int64
// instead of int32 because `VarHandleOp` doesn't support int32 on GPU.
using StateElementType = int64;
static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64;

using Algorithm = StateElementType;
static constexpr DataType ALGORITHM_DTYPE = STATE_ELEMENT_DTYPE;
static constexpr Algorithm RNG_ALG_PHILOX = 1;
static constexpr Algorithm RNG_ALG_THREEFRY = 2;

using random::PhiloxRandom;
@ -17,51 +17,59 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_

#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/stateful_random_ops.h"

namespace tensorflow {

// The following 5 functions are made templates to avoid duplicate symbols when
// linking.

// The following 2 functions use the contract "lower 32 bits for the first
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
// unlike a direct memory copy `memcpy(output, &input, 8)`.
PHILOX_DEVICE_INLINE void Int64ToUint32s(int64 input, uint32* output1,
uint32* output2) {
auto u64 = static_cast<uint64>(input);
*output1 = static_cast<uint32>(u64);
*output2 = static_cast<uint32>(u64 >> 32);
}

PHILOX_DEVICE_INLINE int64 Uint32sToInt64(uint32 input1, uint32 input2) {
auto u64_1 = static_cast<uint64>(input1);
auto u64_2 = static_cast<uint64>(input2);
return static_cast<int64>(u64_1 | (u64_2 << 32));
}

PHILOX_DEVICE_INLINE PhiloxRandom
GetPhiloxRandomFromMem(StateElementType const* ptr) {
auto ptr_ = reinterpret_cast<uint64 const*>(ptr);
return GetPhiloxRandomFromCounterKeyMem(ptr_, ptr_ + 2);
PhiloxRandom::ResultType counter;
PhiloxRandom::Key key;
Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
Int64ToUint32s(ptr[2], &key[0], &key[1]);
return PhiloxRandom(counter, key);
}

PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox,
StateElementType* ptr) {
auto ptr_ = reinterpret_cast<uint64*>(ptr);
WriteCounterToMem(philox.counter(), ptr_);
WriteKeyToMem(philox.key(), ptr_ + 2);
}

PHILOX_DEVICE_INLINE PhiloxRandom SkipPhiloxRandom(PhiloxRandom const& philox,
uint64 output_size) {
auto new_philox = philox;
// Multiplier 256 is the same as in FillPhiloxRandomTask; do not change it
// just here.
auto delta = output_size * 256;
new_philox.Skip(delta); // do the actual increasing
return new_philox;
PhiloxRandom::ResultType const& counter = philox.counter();
PhiloxRandom::Key const& key = philox.key();
ptr[0] = Uint32sToInt64(counter[0], counter[1]);
ptr[1] = Uint32sToInt64(counter[2], counter[3]);
ptr[2] = Uint32sToInt64(key[0], key[1]);
}

PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox,
uint64 output_size,
int64 output_size,
StateElementType* ptr) {
auto new_philox = SkipPhiloxRandom(philox, output_size);
auto new_philox = philox;
// Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
// it just here.
auto delta = output_size * 256;
new_philox.Skip(delta); // do the actual increasing
WritePhiloxRandomToMem(new_philox, ptr);
}

PHILOX_DEVICE_INLINE void UpdateCounterMemWithPhiloxRandom(
PhiloxRandom::ResultType const& counter, uint64 output_size,
StateElementType* ptr) {
auto philox = PhiloxRandom(counter, PhiloxRandom::Key() /*dummy*/);
auto new_philox = SkipPhiloxRandom(philox, output_size);
WriteCounterToMem(new_philox.counter(), reinterpret_cast<uint64*>(ptr));
}
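A worked example of the skip-ahead arithmetic (the plain uint64 below only models the low word of the 128-bit Philox counter; the 256 multiplier is the per-output budget referenced in the comments above):

#include <cassert>
#include <cstdint>

int main() {
  uint64_t counter = 0;               // low word of the Philox counter
  const uint64_t output_size = 1000;  // elements the caller will generate
  const uint64_t delta = output_size * 256;
  counter += delta;                   // what PhiloxRandom::Skip(delta) does
  assert(counter == 256000);          // state reserved for 1000 outputs
  return 0;
}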

namespace functor {

// A per-device helper function that does the actual work for
// `UpdateVariableAndFill`.
// Reason to use functor: C++ doesn't allow function-template partial
@ -72,8 +80,6 @@ struct UpdateVariableAndFill_Philox;
template <typename Device>
struct RngSkip_Philox;

} // end namespace functor

using CPUDevice = Eigen::ThreadPoolDevice;

struct UpdateVariableAndFill_Philox_Arg {
@ -87,8 +93,6 @@ struct UpdateVariableAndFill_Philox_Arg {

using GPUDevice = Eigen::GpuDevice;

namespace functor {

// Declares the partially GPU-specialized functor structs.
// must be kept at <=6 arguments because of a gcc/clang ABI incompatibility bug
template <typename Distribution>
@ -100,12 +104,9 @@ struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {

template <>
struct RngSkip_Philox<GPUDevice> {
void operator()(const GPUDevice& device, const StateElementType* in_data,
uint64 delta, StateElementType* out_data);
void operator()(const GPUDevice& device, int64 delta, Tensor* state_tensor);
};

} // end namespace functor

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace tensorflow
@ -31,8 +31,6 @@ __device__ int tensorflow_philox_thread_counter;

namespace tensorflow {

namespace functor {

using random::PhiloxRandom;

template <typename Distribution>
@ -50,8 +48,7 @@ __global__ void FillKernel(
__syncthreads();
functor::FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>()
.Run(/*key=*/nullptr, /*counter=*/nullptr, *philox, output_data,
output_size, dist);
.Run(*philox, output_data, output_size, dist);
// The last thread updates the state.
auto total_thread_count = gridDim.x * blockDim.x;
auto old_counter_value = atomicAdd(&tensorflow_philox_thread_counter, 1);
@ -99,19 +96,16 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
}

// Precondition: there is only 1 block and 1 thread.
__global__ void SkipKernel(const StateElementType* __restrict__ in_data,
uint64 delta,
StateElementType* __restrict__ out_data) {
auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
__global__ void SkipKernel(int64 delta,
StateElementType* __restrict__ state_data) {
auto philox = GetPhiloxRandomFromMem(state_data);
UpdateMemWithPhiloxRandom(philox, delta, state_data);
}

void RngSkip_Philox<GPUDevice>::operator()(const GPUDevice& d,
const StateElementType* in_data,
uint64 delta,
StateElementType* out_data) {
TF_CHECK_OK(GpuLaunchKernel(SkipKernel, 1, 1, 0, d.stream(), in_data, delta,
out_data));
void RngSkip_Philox<GPUDevice>::operator()(const GPUDevice& d, int64 delta,
Tensor* state_tensor) {
TF_CHECK_OK(GpuLaunchKernel(SkipKernel, 1, 1, 0, d.stream(), delta,
state_tensor->flat<StateElementType>().data()));
}

// Explicit instantiation of the GPU distributions functors.
@ -160,7 +154,6 @@ template struct UpdateVariableAndFill_Philox<
random::PhiloxRandom, uint64> >;
// clang-format on

} // end namespace functor
} // end namespace tensorflow

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@ -121,8 +121,8 @@ class StatelessRandomOp : public StatelessRandomOpBase {
auto flat = output->flat<T>();
// Reuse the compute kernels from the stateful random ops
functor::FillPhiloxRandom<Device, Distribution>()(
context, context->eigen_device<Device>(), /*key=*/nullptr,
/*counter=*/nullptr, random, flat.data(), flat.size(), Distribution());
context, context->eigen_device<Device>(), random, flat.data(),
flat.size(), Distribution());
}
};

@ -158,8 +158,8 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
auto flat = output->flat<IntType>();
// Reuse the compute kernels from the stateful random ops
functor::FillPhiloxRandom<Device, Distribution>()(
context, context->eigen_device<Device>(), /*key=*/nullptr,
/*counter=*/nullptr, random, flat.data(), flat.size(), dist);
context, context->eigen_device<Device>(), random, flat.data(),
flat.size(), dist);
}
};

@ -178,8 +178,8 @@ class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
auto flat = output->flat<IntType>();
// Reuse the compute kernels from the stateful random ops
functor::FillPhiloxRandom<Device, Distribution>()(
context, context->eigen_device<Device>(), /*key=*/nullptr,
/*counter=*/nullptr, random, flat.data(), flat.size(), dist);
context, context->eigen_device<Device>(), random, flat.data(),
flat.size(), dist);
}
};
@ -1,330 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/random_poisson_op.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace {

template <typename T>
Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
auto dtype = DataTypeToEnum<T>::v();
if (tensor.dims() != 0) {
return errors::InvalidArgument("input ", std::to_string(input_idx),
" (0-based) must have shape [], not ",
tensor.shape().DebugString());
}
if (tensor.dtype() != dtype) {
return errors::InvalidArgument("dtype of input ", std::to_string(input_idx),
" (0-based) must be ", DataTypeString(dtype),
", not ", DataTypeString(tensor.dtype()));
}
*result = tensor.flat<T>()(0);
return Status::OK();
}

class StatelessRandomOpBase : public OpKernel {
public:
explicit StatelessRandomOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
// Sanitize input
const Tensor& shape_t = ctx->input(0);
const Tensor& key_t = ctx->input(1);
const Tensor& counter_t = ctx->input(2);
const int alg_input_idx = 3;
const Tensor& alg_t = ctx->input(alg_input_idx);

int alg_id;
OP_REQUIRES_OK(ctx, GetScalar(alg_t, alg_input_idx, &alg_id));
Algorithm alg = Algorithm(alg_id);

TensorShape shape;
OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
OP_REQUIRES_OK(ctx,
CheckKeyCounterShape(alg, key_t.shape(), counter_t.shape()));

// Allocate output
Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
if (shape.num_elements() == 0) {
return;
}

// Fill in the random numbers
Fill(ctx, alg, key_t, counter_t, output);
}

// The part of Compute that depends on device, type, and distribution
virtual void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
const Tensor& counter, Tensor* output) = 0;
};

template <typename Device, typename Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;

void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
const Tensor& counter, Tensor* output) override {
typedef typename Distribution::ResultElementType T;
auto flat = output->flat<T>();
if (alg == RNG_ALG_PHILOX) {
// Reuse the compute kernels from the stateful random ops
auto key_data = key.flat<uint64>().data();
auto counter_data = counter.flat<uint64>().data();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(),
Distribution());
} else {
OP_REQUIRES(ctx, false,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
}
}
};

template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;

void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
const Tensor& counter, Tensor* output) override {
const Tensor& minval = ctx->input(4);
const Tensor& maxval = ctx->input(5);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
errors::InvalidArgument("minval must be 0-D, got shape ",
minval.shape().DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));

// Verify that minval < maxval. Note that we'll never reach this point for
// empty output. Zero impossible things are fine.
const auto lo = minval.scalar<IntType>()();
const auto hi = maxval.scalar<IntType>()();
OP_REQUIRES(
ctx, lo < hi,
errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

// Build distribution
typedef random::UniformDistribution<random::PhiloxRandom, IntType>
Distribution;
Distribution dist(lo, hi);

auto flat = output->flat<IntType>();
if (alg == RNG_ALG_PHILOX) {
// Reuse the compute kernels from the stateful random ops
auto key_data = key.flat<uint64>().data();
auto counter_data = counter.flat<uint64>().data();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
} else {
OP_REQUIRES(ctx, false,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
}
}
};

template <typename Device, typename IntType>
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;

void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
const Tensor& counter, Tensor* output) override {
// Build distribution
typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
Distribution;
Distribution dist;

auto flat = output->flat<IntType>();
if (alg == RNG_ALG_PHILOX) {
// Reuse the compute kernels from the stateful random ops
auto key_data = key.flat<uint64>().data();
auto counter_data = counter.flat<uint64>().data();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
} else {
OP_REQUIRES(ctx, false,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
}
}
};

class GetKeyCounterAlgOp : public OpKernel {
public:
explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
const Tensor& seed_t = ctx->input(0);
OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_t.shape().DebugString()));
// Allocate outputs
Tensor* key_output;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output));
Tensor* counter_output;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}),
&counter_output));
Tensor* alg_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({}), &alg_output));

random::PhiloxRandom::Key key;
random::PhiloxRandom::ResultType counter;
OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
WriteKeyToMem(key, key_output->flat<uint64>().data());
WriteCounterToMem(counter, counter_output->flat<uint64>().data());
alg_output->flat<int>()(0) = RNG_ALG_PHILOX;
}
};

#define REGISTER(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniformV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
random::PhiloxRandom, TYPE> >); \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomNormalV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
random::PhiloxRandom, TYPE> >); \
REGISTER_KERNEL_BUILDER( \
Name("StatelessTruncatedNormalV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp< \
DEVICE##Device, \
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)

#define REGISTER_FULL_INT(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniformFullIntV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)

#define REGISTER_INT(DEVICE, TYPE) \
REGISTER_FULL_INT(DEVICE, TYPE); \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformIntV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.HostMemory("minval") \
.HostMemory("maxval") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)

#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)

TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);

#define REGISTER_GET_KCA(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
.Device(DEVICE_##DEVICE) \
.HostMemory("seed") \
.HostMemory("key") \
.HostMemory("counter") \
.HostMemory("alg"), \
GetKeyCounterAlgOp)

REGISTER_GET_KCA(CPU);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
TF_CALL_uint64(REGISTER_FULL_INT_GPU);

REGISTER_GET_KCA(GPU);

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU
#undef REGISTER_FULL_INT_CPU
#undef REGISTER_FULL_INT_GPU

#undef REGISTER_GET_KCA

} // namespace

} // namespace tensorflow
@ -1,46 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_
#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_

#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

inline Status CheckKeyCounterShape(Algorithm const& alg,
TensorShape const& key_shape,
TensorShape const& counter_shape) {
if (!(key_shape.dims() == 1 && key_shape.dim_size(0) == RNG_KEY_SIZE)) {
return errors::InvalidArgument(
"key must have shape [", RNG_KEY_SIZE, "], not ",
key_shape.DebugString(),
". (Note that batched keys are not supported yet.)");
}
auto counter_size = GetCounterSize(alg);
if (!(counter_shape.dims() == 1 &&
counter_shape.dim_size(0) >= counter_size)) {
return errors::InvalidArgument(
"counter must be a vector with length at least ", counter_size,
"; got shape: ", counter_shape.DebugString(),
". (Note that batched counters are not supported yet.)");
}
return Status::OK();
}
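A toy stand-in for the check above, assuming RNG_KEY_SIZE == 1 and a Philox counter size of 2 (both values assumed here, not quoted from the diff):

#include <cassert>

bool ShapesOk(int key_len, int counter_len) {
  const int kKeySize = 1, kCounterSize = 2;  // assumed constants
  return key_len == kKeySize && counter_len >= kCounterSize;
}

int main() {
  assert(ShapesOk(1, 2));   // exact minimum
  assert(ShapesOk(1, 3));   // counter may be longer than the minimum
  assert(!ShapesOk(2, 2));  // batched keys not supported yet
  assert(!ShapesOk(1, 1));  // counter shorter than the minimum
  return 0;
}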

} // end namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_
@ -15,7 +15,6 @@ limitations under the License.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
@ -91,19 +90,6 @@ REGISTER_OP("RngSkip")
return Status::OK();
});

REGISTER_OP("RngReadAndSkip")
.Input("resource: resource")
.Input("alg: int32")
.Input("delta: uint64")
.Output("value: int64")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(0, c->MakeShape({RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE}));
return Status::OK();
});

REGISTER_OP("NonDeterministicInts")
.Input("shape: shape_dtype")
.SetIsStateful()
@ -1,119 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rng_alg.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

static Status StatelessShapeV2(InferenceContext* c) {
// Check key and counter shapes
ShapeHandle key;
ShapeHandle counter;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &key));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &counter));
shape_inference::ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), RNG_KEY_SIZE, &unused));

// Set output shape
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
c->set_output(0, out);
return Status::OK();
}

#define REGISTER_STATELESS_OP(name) \
REGISTER_OP(name) \
.Input("shape: Tshape") \
.Input("key: uint64") \
.Input("counter: uint64") \
.Input("alg: int32") \
.Output("output: dtype") \
.Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
.Attr("Tshape: {int32, int64} = DT_INT32") \
.SetShapeFn(StatelessShapeV2)

REGISTER_STATELESS_OP("StatelessRandomUniformV2");
REGISTER_STATELESS_OP("StatelessRandomNormalV2");
REGISTER_STATELESS_OP("StatelessTruncatedNormalV2");

#undef REGISTER_STATELESS_OP

REGISTER_OP("StatelessRandomUniformIntV2")
.Input("shape: Tshape")
.Input("key: uint64")
.Input("counter: uint64")
.Input("alg: int32")
.Input("minval: dtype")
.Input("maxval: dtype")
.Output("output: dtype")
.Attr("dtype: {int32, int64, uint32, uint64}")
.Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
Status s = c->WithRank(c->input(4), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"minval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(4)));
}
s = c->WithRank(c->input(5), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"maxval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(5)));
}
return StatelessShapeV2(c);
});

REGISTER_OP("StatelessRandomUniformFullIntV2")
.Input("shape: Tshape")
.Input("key: uint64")
.Input("counter: uint64")
.Input("alg: int32")
.Output("output: dtype")
.Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
.Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn(StatelessShapeV2);

REGISTER_OP("StatelessRandomGetKeyCounterAlg")
.Input("seed: Tseed")
.Output("key: uint64")
.Output("counter: uint64")
.Output("alg: int32")
.Attr("Tseed: {int32, int64} = DT_INT64")
.SetIsStateful() // because outputs depend on device
.SetShapeFn([](InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));

// Set output shapes
c->set_output(0, c->MakeShape({RNG_KEY_SIZE}));
c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE}));
c->set_output(2, c->MakeShape({}));
return Status::OK();
});

} // namespace tensorflow
@ -3183,10 +3183,6 @@ tf_gen_op_wrapper_private_py(
],
)

tf_gen_op_wrapper_private_py(
name = "stateless_random_ops_v2_gen",
)

tf_gen_op_wrapper_private_py(
name = "list_ops_gen",
)
@ -4444,7 +4440,6 @@ py_library(
":framework_ops",
":math_ops",
":stateful_random_ops_gen",
":stateless_random_ops_v2_gen",
":variables",
"//third_party/py/numpy",
],
@ -4479,7 +4474,6 @@ py_library(
":math_ops",
":random_ops",
":stateless_random_ops_gen",
":stateless_random_ops_v2_gen",
],
)
@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {

absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 357> a = {{
static std::array<OpIndexInfo, 352> a = {{
{"Acosh"},
{"AllToAll", 1, {0}},
{"ApproximateEqual"},
@ -332,16 +332,11 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
{"StatelessRandomBinomial"},
{"StatelessRandomGammaV2", 1, {1}},
{"StatelessRandomNormal"},
{"StatelessRandomNormalV2"},
{"StatelessRandomPoisson"},
{"StatelessRandomUniform"},
{"StatelessRandomUniformFullInt"},
{"StatelessRandomUniformFullIntV2"},
{"StatelessRandomUniformInt"},
{"StatelessRandomUniformIntV2"},
{"StatelessRandomUniformV2"},
{"StatelessTruncatedNormal"},
{"StatelessTruncatedNormalV2"},
{"StopGradient"},
{"StridedSliceGrad", 2, {0, 4}},
{"StringSplit"},
@ -420,7 +415,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(

absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 473> a = {{
static std::array<OpIndexInfo, 468> a = {{
{"Abs"},
{"AccumulateNV2"},
{"Acos"},
@ -798,16 +793,11 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
{"StatelessMultinomial"},
{"StatelessRandomBinomial"},
{"StatelessRandomNormal"},
{"StatelessRandomNormalV2"},
{"StatelessRandomPoisson"},
{"StatelessRandomUniform"},
{"StatelessRandomUniformFullInt"},
{"StatelessRandomUniformFullIntV2"},
{"StatelessRandomUniformInt"},
{"StatelessRandomUniformIntV2"},
{"StatelessRandomUniformV2"},
{"StatelessTruncatedNormal"},
{"StatelessTruncatedNormalV2"},
{"StopGradient"},
{"StridedSlice"},
{"StridedSliceGrad"},
@ -30,7 +30,6 @@ from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import image_ops_impl as image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@ -1238,14 +1237,11 @@ class RandomHeightTest(keras_parameterized.TestCase):
mock_factor = 0
with test.mock.patch.object(
gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
with test.mock.patch.object(
gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
return_value=mock_factor):
with testing_utils.use_gpu():
img = np.random.random((12, 5, 8, 3))
layer = image_preprocessing.RandomHeight(.4)
img_out = layer(img, training=True)
self.assertEqual(img_out.shape[1], 3)
with testing_utils.use_gpu():
img = np.random.random((12, 5, 8, 3))
layer = image_preprocessing.RandomHeight(.4)
img_out = layer(img, training=True)
self.assertEqual(img_out.shape[1], 3)

def test_random_height_longer_numeric(self):
for dtype in (np.int64, np.float32):
@ -1332,14 +1328,11 @@ class RandomWidthTest(keras_parameterized.TestCase):
mock_factor = 0
with test.mock.patch.object(
gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
with test.mock.patch.object(
gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
return_value=mock_factor):
with testing_utils.use_gpu():
img = np.random.random((12, 8, 5, 3))
layer = image_preprocessing.RandomWidth(.4)
img_out = layer(img, training=True)
self.assertEqual(img_out.shape[2], 3)
with testing_utils.use_gpu():
img = np.random.random((12, 8, 5, 3))
layer = image_preprocessing.RandomWidth(.4)
img_out = layer(img, training=True)
self.assertEqual(img_out.shape[2], 3)

def test_random_width_longer_numeric(self):
for dtype in (np.int64, np.float32):
@ -22,13 +22,10 @@ import functools

from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@ -38,31 +35,6 @@ from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.platform import test


# Note that in theory each test will reset the eager context and may choose to
# hide some devices, so we shouldn't cache this transient info. Tests in this
# file don't make those config changes, so caching is fine. It provides a good
# speed-up.
_cached_device = None


def get_device():
global _cached_device
if _cached_device is not None:
return _cached_device
# Precedence from high to low
for device_type in ('XLA_GPU', 'GPU', 'XLA_CPU', 'CPU'):
devices = config.list_logical_devices(device_type)
if devices:
_cached_device = devices[0]
return _cached_device
raise ValueError('Cannot find any suitable device. Available devices: %s' %
config.list_logical_devices())


BEFORE_EXPIRE = (2020, 10, 24)
AFTER_EXPIRE = (2020, 10, 26)


def invert_philox(key, value):
"""Invert the Philox bijection."""
key = np.array(key, dtype=np.uint32)
@ -87,71 +59,47 @@ SEED_TYPES = [dtypes.int32, dtypes.int64]
def float_cases(shape_dtypes=(None,)):
cases = (
# Uniform distribution, with and without range
('uniform', stateless.stateless_random_uniform, random_ops.random_uniform,
{}),
('uniform2', stateless.stateless_random_uniform,
random_ops.random_uniform, dict(minval=2.2, maxval=7.1)),
(stateless.stateless_random_uniform, random_ops.random_uniform, {}),
(stateless.stateless_random_uniform, random_ops.random_uniform,
dict(minval=2.2, maxval=7.1)),
# Normal distribution, with and without mean+stddev
('normal', stateless.stateless_random_normal, random_ops.random_normal,
{}),
('normal2', stateless.stateless_random_normal, random_ops.random_normal,
(stateless.stateless_random_normal, random_ops.random_normal, {}),
(stateless.stateless_random_normal, random_ops.random_normal,
dict(mean=2, stddev=3)),
# Truncated normal distribution, with and without mean+stddev
('trnorm', stateless.stateless_truncated_normal,
random_ops.truncated_normal, {}),
('trnorm2', stateless.stateless_truncated_normal,
random_ops.truncated_normal, dict(mean=3, stddev=4)),
(stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
(stateless.stateless_truncated_normal, random_ops.truncated_normal,
dict(mean=3, stddev=4)),
)
# Explicitly passing in params because capturing cell variable from loop is
# problematic in Python
def wrap(op, dtype, shape, shape_dtype, kwds, seed):
device_type = get_device().device_type
# Some dtypes are not supported on some devices
if (dtype == dtypes.float16 and device_type in ('XLA_GPU', 'XLA_CPU') or
dtype == dtypes.bfloat16 and device_type == 'GPU'):
dtype = dtypes.float32
shape_ = (constant_op.constant(shape, dtype=shape_dtype)
if shape_dtype is not None else shape)
return op(seed=seed, shape=shape_, dtype=dtype, **kwds)

def _name(a):
if hasattr(a, 'name'):
return a.name
else:
return a

for dtype in dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64:
for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
for shape_dtype in shape_dtypes:
for shape in (), (3,), (2, 5):
for name, stateless_op, stateful_op, kwds in cases:
yield (('%s_%s_%s_%s' %
(name, _name(dtype), shape, _name(shape_dtype))).replace(
' ', ''),
functools.partial(wrap, stateless_op, dtype, shape,
for stateless_op, stateful_op, kwds in cases:
yield (functools.partial(wrap, stateless_op, dtype, shape,
shape_dtype, kwds),
functools.partial(wrap, stateful_op, dtype, shape, shape_dtype,
kwds))
functools.partial(wrap, stateful_op, dtype, shape,
shape_dtype, kwds))
|
||||
|
||||
|
||||
-def int_cases(shape_dtypes=(None,), minval_maxval=None):
-
-  def wrap(op, minval, maxval, shape, shape_dtype, dtype, seed):
+def int_cases(shape_dtypes=(None,)):
+  def wrap(op, shape, shape_dtype, dtype, seed):
    shape_ = (constant_op.constant(shape, dtype=shape_dtype)
              if shape_dtype is not None else shape)
-    return op(
-        seed=seed, shape=shape_, minval=minval, maxval=maxval, dtype=dtype)
-
-  if minval_maxval is None:
-    minval_maxval = ((2, 11111),)
-  for minval, maxval in minval_maxval:
-    for shape_dtype in shape_dtypes:
-      for shape in (), (3,), (2, 5):
-        for dtype in dtypes.int32, dtypes.int64:
-          yield ('uniform_%s_%s' % (minval, maxval),
-                 functools.partial(wrap, stateless.stateless_random_uniform,
-                                   minval, maxval, shape, shape_dtype, dtype),
-                 functools.partial(wrap, random_ops.random_uniform, minval,
-                                   maxval, shape, shape_dtype, dtype))
+    return op(seed=seed, shape=shape_, minval=2, maxval=11111,
+              dtype=dtype)
+  for shape_dtype in shape_dtypes:
+    for shape in (), (3,), (2, 5):
+      for dtype in dtypes.int32, dtypes.int64:
+        yield (functools.partial(wrap, stateless.stateless_random_uniform,
+                                 shape, shape_dtype, dtype),
+               functools.partial(wrap, random_ops.random_uniform,
+                                 shape, shape_dtype, dtype))

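int_cases hardcodes the bounds minval=2, maxval=11111 after this change. The same integer path through the public wrapper (a hedged sketch assuming TF 2.x, not part of the diff):

import tensorflow as tf

x = tf.random.stateless_uniform(
    shape=[3], seed=[4, 2], minval=2, maxval=11111, dtype=tf.int32)
# All draws land in the half-open range [2, 11111).
assert bool(tf.reduce_all((x >= 2) & (x < 11111)))
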
@ -164,8 +112,7 @@ def multinomial_cases():
  for output_dtype in dtypes.int32, dtypes.int64:
    for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
                                              [0.25, 0.75]]):
-      yield ('multinomial',
-             functools.partial(wrap, stateless.stateless_multinomial, logits,
+      yield (functools.partial(wrap, stateless.stateless_multinomial, logits,
                               logits_dtype, output_dtype),
             functools.partial(wrap, random_ops.multinomial, logits,
                               logits_dtype, output_dtype))
@ -177,11 +124,10 @@ def gamma_cases():
        alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
  for dtype in np.float16, np.float32, np.float64:
    for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
-      yield ('gamma',
-             functools.partial(wrap, stateless.stateless_random_gamma, alpha,
+      yield (functools.partial(wrap, stateless.stateless_random_gamma, alpha,
                               dtype, (10,) + tuple(np.shape(alpha))),
-             functools.partial(wrap, random_ops.random_gamma, alpha, dtype,
-                               (10,)))
+             functools.partial(wrap, random_ops.random_gamma, alpha,
+                               dtype, (10,)))

@ -192,8 +138,7 @@ def poisson_cases():
  for lam_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
    for out_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
      for lam in ([[5.5, 1., 2.]], [[7.5, 10.5], [3.8, 8.2], [1.25, 9.75]]):
-        yield ('poisson',
-               functools.partial(wrap, stateless.stateless_random_poisson, lam,
+        yield (functools.partial(wrap, stateless.stateless_random_poisson, lam,
                                 lam_dtype, out_dtype,
                                 (10,) + tuple(np.shape(lam))),
               functools.partial(wrap, random_ops.random_poisson, lam,
@ -208,28 +153,22 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
    key = 0x3ec8f720, 0x02461e29
    preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
    preseed = preseed[::2] | preseed[1::2] << 32
-    with ops.device(get_device().name):
-      _, stateless_op, stateful_op = case
-      random_seed.set_random_seed(seed[0])
+    random_seed.set_random_seed(seed[0])
+    with test_util.use_gpu():
+      stateless_op, stateful_op = case
+      if context.executing_eagerly():
+        # Call set_random_seed in order to clear kernel cache, to prevent
+        # kernel reusing for the stateful op
+        random_seed.set_random_seed(seed[0])
      stateful = stateful_op(seed=seed[1])
      pure = stateless_op(seed=preseed)
      self.assertAllEqual(stateful, pure)

-  def _test_old_and_new_stateless_match(self, case, seed):
-    """Tests that the new stateless ops match the old stateless ones."""
-    with ops.device(get_device().name):
-      _, stateless_op, _ = case
-      with compat.forward_compatibility_horizon(*BEFORE_EXPIRE):
-        old = stateless_op(seed=seed)
-      with compat.forward_compatibility_horizon(*AFTER_EXPIRE):
-        new = stateless_op(seed=seed)
-      self.assertAllClose(old, new)
-
  def _test_determinism(self, case, seed_type):
    # Stateless values should be equal iff the seeds are equal (roughly)
    seeds = [(x, y) for x in range(5) for y in range(5)] * 3  # pylint: disable=g-complex-comprehension
-    with self.test_session(use_gpu=True), ops.device(get_device().name):
-      _, stateless_op, _ = case
+    with self.test_session(use_gpu=True), test_util.use_gpu():
+      stateless_op, _ = case
      if context.executing_eagerly():
        values = [
            (seed, stateless_op(seed=constant_op.constant(seed, seed_type)))
@ -244,172 +183,88 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
        ]
      for s0, v0 in values:
        for s1, v1 in values:
-          if dtypes.as_dtype(v0.dtype) != dtypes.bfloat16:
-            self.assertEqual(s0 == s1, np.all(v0 == v1))
-          elif s0 == s1:
-            # Skip the s0 != s1 case because v0 and v1 can be either equal or
-            # unequal in that case due to bfloat16's low precision
-            self.assertAllEqual(v0, v1)
+          self.assertEqual(s0 == s1, np.all(v0 == v1))

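What _test_determinism asserts, in miniature: stateless outputs are equal iff the seeds are equal (up to an astronomically unlikely collision). A sketch assuming TF 2.x, not part of the diff:

import tensorflow as tf

v0 = tf.random.stateless_uniform([4], seed=[0, 0])
v0_again = tf.random.stateless_uniform([4], seed=[0, 0])
v1 = tf.random.stateless_uniform([4], seed=[0, 1])
assert bool(tf.reduce_all(v0 == v0_again))  # equal seeds, equal values
assert not bool(tf.reduce_all(v0 == v1))    # distinct seeds differ (almost surely)
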
  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(float_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchFloat(self, case, seed):
-    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Skip on XLA because XLA kernels do not support int64 '
-                    'seeds needed by this test.')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(int_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchInt(self, case, seed):
-    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Skip on XLA because XLA kernels do not support int64 '
-                    'seeds needed by this test.')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(multinomial_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchMultinomial(self, case, seed):
-    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchGamma(self, case, seed):
-    if get_device().device_type == 'GPU':
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking GPU kernel')
-    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(poisson_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchPoisson(self, case, seed):
-    if get_device().device_type == 'GPU':
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking GPU kernel')
-    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

-  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
-      for seed_id, seed in enumerate(SEEDS)
-      for case_id, case in enumerate(float_cases()))
-  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
-  def testOldAndNewStatelessMatchFloat(self, case, seed):
-    self._test_old_and_new_stateless_match(case, seed)
-
-  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
-      for seed_id, seed in enumerate(SEEDS)
-      for case_id, case in enumerate(
-          int_cases(minval_maxval=((2, 11111), (None, None)))))
-  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
-  def testOldAndNewStatelessMatchInt(self, case, seed):
-    self._test_old_and_new_stateless_match(case, seed)
-
  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for seed_type in SEED_TYPES
-      for case_id, case in enumerate(
-          float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
+      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for type_id, seed_type in enumerate(SEED_TYPES)
+      for case_id, case in enumerate(float_cases(
+          shape_dtypes=(dtypes.int32, dtypes.int64))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismFloat(self, case, seed_type):
-    if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
-                                                                  'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest(
-          'Skip on XLA because XLA kernels do not support int64 seeds.')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for seed_type in SEED_TYPES
-      for case_id, case in enumerate(
-          int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
+      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for type_id, seed_type in enumerate(SEED_TYPES)
+      for case_id, case in enumerate(int_cases(
+          shape_dtypes=(dtypes.int32, dtypes.int64))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismInt(self, case, seed_type):
-    if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
-                                                                  'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest(
-          'Skip on XLA because XLA kernels do not support int64 seeds.')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for seed_type in SEED_TYPES
+      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for type_id, seed_type in enumerate(SEED_TYPES)
      for case_id, case in enumerate(multinomial_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismMultinomial(self, case, seed_type):
-    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for seed_type in SEED_TYPES
+      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for type_id, seed_type in enumerate(SEED_TYPES)
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismGamma(self, case, seed_type):
-    if get_device().device_type == 'GPU':
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking GPU kernel')
-    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
-      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for seed_type in SEED_TYPES
+      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for type_id, seed_type in enumerate(SEED_TYPES)
      for case_id, case in enumerate(poisson_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismPoisson(self, case, seed_type):
-    if get_device().device_type == 'GPU':
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking GPU kernel')
-    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
-      # This test was passing before because soft placement silently picked the
-      # CPU kernels.
-      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  def assertDTypeEqual(self, a, b):
@ -472,6 +327,4 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):


if __name__ == '__main__':
-  config.set_soft_device_placement(False)
-  context.context().enable_xla_devices()
  test.main()

@ -23,7 +23,6 @@ import enum  # pylint: disable=g-bad-import-order
import numpy as np
import six

-from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
@ -33,14 +32,12 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateful_random_ops
-from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util.tf_export import tf_export


# A seed for random ops (stateful and stateless) will always be 1024
# bits, all of which will be sent to the C++ code. The actual C++
# implementation of some algorithms may only use a lower part of the bits.
@ -139,15 +136,6 @@ def _make_1d_state(state_size, seed):
  return seed


-def _get_counter_size(alg):
-  if alg == RNG_ALG_PHILOX:
-    return 2
-  elif alg == RNG_ALG_THREEFRY:
-    return 1
-  else:
-    raise ValueError("Unsupported algorithm id: %s" % alg)
-
-
def _get_state_size(alg):
  if alg == RNG_ALG_PHILOX:
    return PHILOX_STATE_SIZE
@ -572,10 +560,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    return self._alg

  def _standard_normal(self, shape, dtype):
-    if compat.forward_compatible(2020, 10, 25):
-      key, counter = self._prepare_key_counter(shape)
-      return gen_stateless_random_ops_v2.stateless_random_normal_v2(
-          shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
    return gen_stateful_random_ops.stateful_standard_normal_v2(
        self.state.handle, self.algorithm, shape, dtype=dtype)

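_standard_normal is what backs Generator.normal(); fixing the seed fixes the whole output stream. A sketch of that contract (tf.random.Generator is assumed here; around this release the class was exported as tf.random.experimental.Generator):

import tensorflow as tf

g1 = tf.random.Generator.from_seed(1234)
g2 = tf.random.Generator.from_seed(1234)
# Two generators seeded identically produce identical streams.
assert bool(tf.reduce_all(g1.normal([2, 2]) == g2.normal([2, 2])))
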
@ -602,8 +586,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    else:
      raise ValueError("Unsupported algorithm id: %s" % alg)

-  # TODO(wangpeng): Add "Returns" section to docstring once new version kicks in
-  # pylint: disable=g-doc-return-or-yield
  def skip(self, delta):
    """Advance the counter of a counter-based RNG.

@ -613,24 +595,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    (or any other distribution). The actual increment added to the
    counter is an unspecified implementation detail.
    """
-    if compat.forward_compatible(2020, 10, 25):
-      return gen_stateful_random_ops.rng_read_and_skip(
-          self.state.handle,
-          alg=math_ops.cast(self.algorithm, dtypes.int32),
-          delta=math_ops.cast(delta, dtypes.uint64))
-    gen_stateful_random_ops.rng_skip(
-        self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
-        math_ops.cast(delta, dtypes.int64))
-  # pylint: enable=g-doc-return-or-yield
-
-  def _prepare_key_counter(self, shape):
-    delta = math_ops.reduce_prod(shape)
-    counter_key = self.skip(delta)
-    counter_size = _get_counter_size(self.algorithm)
-    counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
-    key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
-                            dtypes.uint64)
-    return key, counter
+    gen_stateful_random_ops.rng_skip(self.state.handle, self.algorithm, delta)

  # The following functions return a tensor and as a side effect update
  # self._state_var.
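The removed _prepare_key_counter sliced the Philox state vector [counter_lo, counter_hi, key] into the (key, counter) pair the V2 stateless kernels consume. A standalone illustration of that slicing with plain tensors (the layout and bitcast are taken from the code above; not part of the diff):

import tensorflow as tf

# Philox state: two 64-bit counter words followed by one 64-bit key word.
state = tf.constant([57, 283, 0x1234], dtype=tf.int64)
counter = tf.bitcast(state[:2], tf.uint64)  # counter size is 2 for Philox
key = tf.bitcast(state[2:3], tf.uint64)
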
@ -659,14 +624,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    return math_ops.add(rnd * stddev, mean, name=name)

  def _truncated_normal(self, shape, dtype):
-    if compat.forward_compatible(2020, 10, 25):
-      key, counter = self._prepare_key_counter(shape)
-      return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
-          shape=shape,
-          key=key,
-          counter=counter,
-          dtype=dtype,
-          alg=self.algorithm)
    return gen_stateful_random_ops.stateful_truncated_normal(
        self.state.handle, self.algorithm, shape, dtype=dtype)

@ -705,27 +662,10 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    return math_ops.add(mul, mean_tensor, name=name)

  def _uniform(self, shape, dtype):
-    if compat.forward_compatible(2020, 10, 25):
-      key, counter = self._prepare_key_counter(shape)
-      return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
-          shape=shape,
-          key=key,
-          counter=counter,
-          dtype=dtype,
-          alg=self.algorithm)
    return gen_stateful_random_ops.stateful_uniform(
        self.state.handle, self.algorithm, shape=shape, dtype=dtype)

  def _uniform_full_int(self, shape, dtype, name=None):
-    if compat.forward_compatible(2020, 10, 25):
-      key, counter = self._prepare_key_counter(shape)
-      return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
-          shape=shape,
-          key=key,
-          counter=counter,
-          dtype=dtype,
-          alg=self.algorithm,
-          name=name)
    return gen_stateful_random_ops.stateful_uniform_full_int(
        self.state.handle, self.algorithm, shape=shape,
        dtype=dtype, name=name)
@ -789,16 +729,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
    maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
    if dtype.is_integer:
-      if compat.forward_compatible(2020, 10, 25):
-        key, counter = self._prepare_key_counter(shape)
-        return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
-            shape=shape,
-            key=key,
-            counter=counter,
-            minval=minval,
-            maxval=maxval,
-            alg=self.algorithm,
-            name=name)
      return gen_stateful_random_ops.stateful_uniform_int(
          self.state.handle, self.algorithm, shape=shape,
          minval=minval, maxval=maxval, name=name)

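The integer branch above serves Generator.uniform when explicit bounds are given. Through the public API (a sketch; the tf.random.Generator name is an assumption for this release):

import tensorflow as tf

gen = tf.random.Generator.from_seed(5678)
# Integer dtypes require explicit minval/maxval; floats default to [0, 1).
ints = gen.uniform(shape=[2, 3], minval=0, maxval=100, dtype=tf.int32)
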
@ -274,7 +274,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
    with self.cached_session() as sess:
      gen1 = random.Generator.from_seed(seed)
      gen2 = random.Generator.from_non_deterministic_state()
-      sess.run((gen1.state.initializer, gen2.state.initializer))
+      sess.run((gen1._state_var.initializer, gen2._state_var.initializer))
      r1 = gen1.normal(shape, dtype=dtypes.float32)
      r2 = gen2.normal(shape, dtype=dtypes.float32)
      def f():
@ -372,7 +372,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
    gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
    delta = 432
    gen.skip(delta)
-    new_counter = gen.state[0]
+    new_counter = gen._state_var[0]
    self.assertAllEqual(counter + delta * 256, new_counter)

  def _sameAsOldRandomOps(self, device, floats):
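The skip() test above pins down the stride: one unit of delta advances the Philox counter by 256. Restated against the public API (a sketch; the constructor form mirrors the test, and alg=1 standing for Philox is taken from RNG_ALG_PHILOX):

import tensorflow as tf

gen = tf.random.Generator(state=[1000, 0, 0x1234], alg=1)  # alg 1 == Philox
gen.skip(432)
assert int(gen.state[0]) == 1000 + 432 * 256
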
@ -394,7 +394,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
      with ops.device(device):
        return new(dtype, gen)

-    for _ in range(5):
+    for _ in range(100):
      self.assertAllEqual(run_old(), run_new())

    shape = constant_op.constant([4, 7])
@ -582,11 +582,6 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
  @test_util.run_v2_only
  def testGetGlobalGeneratorWithXla(self):
    """Demonstrates using the global generator with XLA."""
-    # This test was passing before because soft placement silently picked the
-    # CPU kernel.
-    # TODO(wangpeng): Remove this skip
-    self.skipTest("NonDeterministicInts lacks XLA kernel.")
-
    if not config.list_physical_devices("XLA_CPU"):
      self.skipTest("No XLA_CPU device available.")

@ -680,16 +675,17 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
        random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int64))

-  @test_util.run_v2_only
-  def testCreateOutsideMirroredStrat(self):
+  @test_util.run_cuda_only
+  def testMirroredStratSeq(self):
    """Tests RNG/MirrorStrategy interaction #1.

-    If an RNG is created outside a DS scope, all replicas will access the
+    If an RNG is created outside strategy.scope(), all replicas will access the
    same RNG object, and accesses are serialized.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    gen = random.Generator.from_seed(1234)
-    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
+    strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
    with strat.scope():
      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
@ -766,5 +762,4 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):


if __name__ == "__main__":
-  config.set_soft_device_placement(False)
  test.main()

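The renamed testMirroredStratSeq covers the case its docstring describes: a Generator created outside strategy.scope() is shared by every replica, and replica accesses to it are serialized. A sketch of that setup (device list is illustrative; not part of the diff):

import tensorflow as tf

gen = tf.random.Generator.from_seed(1234)  # created outside the scope
strat = tf.distribute.MirroredStrategy(devices=["/cpu:0"])

def f():
  # Every replica reads and advances the same underlying RNG state.
  return gen.uniform_full_int(shape=[3, 4], dtype=tf.int32)

results = strat.run(f)
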
@ -20,19 +20,16 @@ from __future__ import print_function

import numpy as np

-from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops
-from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


ops.NotDifferentiable("StatelessMultinomial")
ops.NotDifferentiable("StatelessRandomBinomial")
ops.NotDifferentiable("StatelessRandomNormal")
@ -43,13 +40,6 @@ ops.NotDifferentiable("StatelessRandomUniformFullInt")
ops.NotDifferentiable("StatelessTruncatedNormal")


-ops.NotDifferentiable("StatelessRandomNormalV2")
-ops.NotDifferentiable("StatelessRandomUniformV2")
-ops.NotDifferentiable("StatelessRandomUniformIntV2")
-ops.NotDifferentiable("StatelessRandomUniformFullIntV2")
-ops.NotDifferentiable("StatelessTruncatedNormalV2")
-
-
@tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2):
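split() derives fresh seeds from one seed, so independent stateless streams can be forked deterministically. A usage sketch of the API exported above (assuming TF 2.x):

import tensorflow as tf

seed = tf.constant([1, 2], dtype=tf.int32)
new_seeds = tf.random.experimental.stateless_split(seed, num=3)  # shape [3, 2]
streams = [tf.random.stateless_normal([2], seed=s)
           for s in tf.unstack(new_seeds)]
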
@ -123,10 +113,6 @@ def fold_in(seed, data):
  return array_ops.stack([seed1, data])


-_get_key_counter_alg = (gen_stateless_random_ops_v2
-                        .stateless_random_get_key_counter_alg)
-
-
@tf_export("random.stateless_uniform")
|
||||
@dispatch.add_dispatch_support
|
||||
def stateless_random_uniform(shape,
|
||||
@ -206,35 +192,17 @@ def stateless_random_uniform(shape,
|
||||
[shape, seed, minval, maxval]) as name:
|
||||
shape = tensor_util.shape_tensor(shape)
|
||||
if dtype.is_integer and minval is None:
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
result = (gen_stateless_random_ops_v2
|
||||
.stateless_random_uniform_full_int_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=alg,
|
||||
name=name))
|
||||
else:
|
||||
result = gen_stateless_random_ops.stateless_random_uniform_full_int(
|
||||
shape, seed=seed, dtype=dtype, name=name)
|
||||
result = gen_stateless_random_ops.stateless_random_uniform_full_int(
|
||||
shape, seed=seed, dtype=dtype, name=name)
|
||||
else:
|
||||
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
|
||||
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
|
||||
if dtype.is_integer:
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
|
||||
shape, key=key, counter=counter, minval=minval, maxval=maxval,
|
||||
alg=alg, name=name)
|
||||
else:
|
||||
result = gen_stateless_random_ops.stateless_random_uniform_int(
|
||||
shape, seed=seed, minval=minval, maxval=maxval, name=name)
|
||||
result = gen_stateless_random_ops.stateless_random_uniform_int(
|
||||
shape, seed=seed, minval=minval, maxval=maxval, name=name)
|
||||
else:
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
key, counter, alg = _get_key_counter_alg(seed)
|
||||
rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
|
||||
shape, key=key, counter=counter, dtype=dtype, alg=alg)
|
||||
else:
|
||||
rnd = gen_stateless_random_ops.stateless_random_uniform(
|
||||
shape, seed=seed, dtype=dtype)
|
||||
rnd = gen_stateless_random_ops.stateless_random_uniform(
|
||||
shape, seed=seed, dtype=dtype)
|
||||
result = math_ops.add(rnd * (maxval - minval), minval, name=name)
|
||||
tensor_util.maybe_set_static_shape(result, shape)
|
||||
return result
|
||||
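With an integer dtype and no minval/maxval, the branch above samples the dtype's full range via stateless_random_uniform_full_int. Through the public wrapper (a sketch assuming TF 2.x):

import tensorflow as tf

# minval=None/maxval=None selects the full-range integer path.
full = tf.random.stateless_uniform(
    shape=[4], seed=[7, 7], minval=None, maxval=None, dtype=tf.uint64)
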
@ -508,12 +476,7 @@ def stateless_random_normal(shape,
    shape = tensor_util.shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
-    if compat.forward_compatible(2020, 10, 25):
-      key, counter, alg = _get_key_counter_alg(seed)
-      rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
-          shape, key=key, counter=counter, dtype=dtype, alg=alg)
-    else:
-      rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
+    rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
    result = math_ops.add(rnd * stddev, mean, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result
@ -558,13 +521,8 @@ def stateless_truncated_normal(shape,
    shape = tensor_util.shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
-    if compat.forward_compatible(2020, 10, 25):
-      key, counter, alg = _get_key_counter_alg(seed)
-      rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
-          shape, key=key, counter=counter, dtype=dtype, alg=alg)
-    else:
-      rnd = gen_stateless_random_ops.stateless_truncated_normal(
-          shape, seed, dtype)
+    rnd = gen_stateless_random_ops.stateless_truncated_normal(
+        shape, seed, dtype)
    result = math_ops.add(rnd * stddev, mean, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result

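stateless_truncated_normal re-draws any sample falling more than two standard deviations from the mean, so every output lies within that band. A sketch assuming TF 2.x:

import tensorflow as tf

t = tf.random.stateless_truncated_normal(
    [1000], seed=[3, 5], mean=3.0, stddev=4.0)
assert bool(tf.reduce_all(tf.abs(t - 3.0) <= 2.0 * 4.0))
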
@ -3792,10 +3792,6 @@ tf_module {
    name: "Rint"
    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
-  member_method {
-    name: "RngReadAndSkip"
-    argspec: "args=[\'resource\', \'alg\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
  member_method {
    name: "RngSkip"
    argspec: "args=[\'resource\', \'algorithm\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4548,18 +4544,10 @@ tf_module {
    name: "StatelessRandomGammaV2"
    argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
-  member_method {
-    name: "StatelessRandomGetKeyCounterAlg"
-    argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
  member_method {
    name: "StatelessRandomNormal"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
  }
-  member_method {
-    name: "StatelessRandomNormalV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
-  }
  member_method {
    name: "StatelessRandomPoisson"
    argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4572,22 +4560,10 @@ tf_module {
    name: "StatelessRandomUniformFullInt"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
  }
-  member_method {
-    name: "StatelessRandomUniformFullIntV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
-  }
  member_method {
    name: "StatelessRandomUniformInt"
    argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
-  member_method {
-    name: "StatelessRandomUniformIntV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "StatelessRandomUniformV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
-  }
  member_method {
    name: "StatelessSampleDistortedBoundingBox"
    argspec: "args=[\'image_size\', \'bounding_boxes\', \'min_object_covered\', \'seed\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'[0.75, 1.33]\', \'[0.05, 1]\', \'100\', \'False\', \'None\'], "
@ -4596,10 +4572,6 @@ tf_module {
    name: "StatelessTruncatedNormal"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
  }
-  member_method {
-    name: "StatelessTruncatedNormalV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
-  }
  member_method {
    name: "StatelessWhile"
    argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "

@ -3792,10 +3792,6 @@ tf_module {
    name: "Rint"
    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
-  member_method {
-    name: "RngReadAndSkip"
-    argspec: "args=[\'resource\', \'alg\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
  member_method {
    name: "RngSkip"
    argspec: "args=[\'resource\', \'algorithm\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4548,18 +4544,10 @@ tf_module {
    name: "StatelessRandomGammaV2"
    argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
-  member_method {
-    name: "StatelessRandomGetKeyCounterAlg"
-    argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
  member_method {
    name: "StatelessRandomNormal"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
  }
-  member_method {
-    name: "StatelessRandomNormalV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
-  }
  member_method {
    name: "StatelessRandomPoisson"
    argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4572,22 +4560,10 @@ tf_module {
    name: "StatelessRandomUniformFullInt"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
  }
-  member_method {
-    name: "StatelessRandomUniformFullIntV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
-  }
  member_method {
    name: "StatelessRandomUniformInt"
    argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
-  member_method {
-    name: "StatelessRandomUniformIntV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "StatelessRandomUniformV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
-  }
  member_method {
    name: "StatelessSampleDistortedBoundingBox"
    argspec: "args=[\'image_size\', \'bounding_boxes\', \'min_object_covered\', \'seed\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'[0.75, 1.33]\', \'[0.05, 1]\', \'100\', \'False\', \'None\'], "
@ -4596,10 +4572,6 @@ tf_module {
    name: "StatelessTruncatedNormal"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
  }
-  member_method {
-    name: "StatelessTruncatedNormalV2"
-    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
-  }
  member_method {
    name: "StatelessWhile"
    argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "