Internal change

PiperOrigin-RevId: 332361104
Change-Id: I1f66d7fa0a7fa5e48656232278ae9e22f26f4747
A. Unique TensorFlower 2020-09-17 18:44:10 -07:00 committed by TensorFlower Gardener
parent 4ec907ddd1
commit b1c97a0bb2
46 changed files with 203 additions and 2212 deletions

View File

@@ -1995,8 +1995,6 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"ResourceScatterNdUpdate",
"ResourceScatterSub",
"ResourceScatterUpdate",
"RngReadAndSkip",
"RngSkip",
"Roll",
"ScatterNd",
"SelfAdjointEigV2",
@@ -2019,17 +2017,11 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"StatelessCase",
"StatelessIf",
"StatelessMultinomial",
"StatelessRandomGetKeyCounterAlg",
"StatelessRandomNormal",
"StatelessRandomNormalV2",
"StatelessRandomUniform",
"StatelessRandomUniformV2",
"StatelessRandomUniformInt",
"StatelessRandomUniformIntV2",
"StatelessRandomUniformFullInt",
"StatelessRandomUniformFullIntV2",
"StatelessTruncatedNormal",
"StatelessTruncatedNormalV2",
"StatelessWhile",
"Svd",
"SymbolicGradient",

View File

@@ -25,9 +25,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import device_lib
from tensorflow.python.compat import compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
@@ -158,10 +156,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def testNewStateThreeFry(self):
"""Tests that the new state is correct (for ThreeFry).
"""
if compat.forward_compatible(2020, 10, 25):
self.skipTest("The expected values in this test is inconsistent with "
"CPU/GPU. testXLAEqualsCPU has the correct checks of the "
"new states for the new version.")
with ops.device(xla_device_name()):
counter = 57
key = 0x1234
@@ -177,10 +171,6 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def testNewStatePhilox(self):
"""Tests that the new state is correct (for Philox).
"""
if compat.forward_compatible(2020, 10, 25):
self.skipTest("The expected values in this test is inconsistent with "
"CPU/GPU. testXLAEqualsCPU has the correct checks of the "
"new states for the new version.")
with ops.device(xla_device_name()):
counter_low = 57
counter_high = 283
@@ -214,39 +204,13 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
"""Tests that XLA and CPU kernels generate the same integers."""
seed = 1234
shape = [315, 49]
if compat.forward_compatible(2020, 10, 25):
with ops.device("/device:CPU:0"):
cpu_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
with ops.device(xla_device_name()):
xla_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
# Repeat multiple times to make sure that the states after
# number-generation are the same between CPU and XLA.
for _ in range(5):
with ops.device("/device:CPU:0"):
# Test both number-generation and skip
cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
cpu_gen.skip(100)
with ops.device(xla_device_name()):
xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
xla_gen.skip(100)
self.assertAllEqual(cpu, xla)
self.assertAllEqual(cpu_gen.state, xla_gen.state)
else:
# The old version doesn't guarantee that CPU and XLA are in the same state
# after number-generation, which is a bug.
with ops.device("/device:CPU:0"):
cpu = (
random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
shape=shape, dtype=dtype))
with ops.device(xla_device_name()):
xla = (
random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
shape=shape, dtype=dtype))
self.assertAllEqual(cpu, xla)
with ops.device("/device:CPU:0"):
cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
.uniform_full_int(shape=shape, dtype=dtype))
with ops.device(xla_device_name()):
xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
.uniform_full_int(shape=shape, dtype=dtype))
self.assertAllEqual(cpu, xla)
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
@@ -400,5 +364,4 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
if __name__ == "__main__":
ops.enable_eager_execution()
config.set_soft_device_placement(False)
test.main()

View File

@@ -21,11 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
random_test_util
from tensorflow.python.ops import array_ops
@@ -43,26 +39,6 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
allowed_types.update({dtypes.int32, dtypes.int64})
return self.all_tf_types & allowed_types
@test_util.run_v2_only
def testForcedCompile(self):
"""Tests whole-function forced-compilation.
This test checks that stateless_random_* can be used in forced-compilation
scenarios (e.g. TPU). The new version of stateless_random_* requires the
intermediate tensor `alg` to be compile-time constant, so we need to check
that this requirement is met. We use xla.compile instead of tf.function's
experimental_compile because the latter doesn't throw an error even if the
compile-time-constant constraint is not met.
"""
if config.list_logical_devices('TPU'):
self.skipTest('To accommodate OSS, xla.compile support for TPU is not '
'linked in.')
@def_function.function
def f(x):
return xla.compile(
lambda x: stateless.stateless_random_normal([], seed=x), [x])
f([1, 2])
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
with self.session(), self.test_scope():
@@ -162,7 +138,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark):
def _benchmarkUniform(self, name, dtype, use_xla_jit):
def builder_fn():
def BuilderFn():
shape = (10, 1000, 1000)
seed_var = variables.Variable((312, 456),
dtype=dtypes.int32,
@@ -171,7 +147,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark):
shape, seed=seed_var, dtype=dtype)
return '%s.shape%s' % (name, shape), [random_t]
xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu')
xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu')
def benchmarkUniformF32(self):
self._benchmarkUniform(
@@ -191,5 +167,4 @@ class StatelessRandomOpsBenchmark(test.Benchmark):
if __name__ == '__main__':
config.set_soft_device_placement(False)
test.main()

View File

@@ -108,7 +108,6 @@ tf_kernel_library(
"stack_ops.cc",
"stateful_random_ops.cc",
"stateless_random_ops.cc",
"stateless_random_ops_v2.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
"tensor_list_ops.cc",
@@ -188,7 +187,6 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:stateful_random_ops_header",
"//tensorflow/core/kernels:stateless_random_ops_v2_header",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",

View File

@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -181,7 +180,7 @@ Status CompileImpl(
}
xla::Literal alg_literal;
TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
Algorithm alg = Algorithm(alg_literal.Get<int>({}));
auto alg = alg_literal.Get<Algorithm>({});
if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) {
return errors::InvalidArgument("Unsupported algorithm id: ", alg);
}
@@ -408,80 +407,5 @@ REGISTER_XLA_OP(Name("StatefulUniformFullInt")
{DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
StatefulUniformFullIntOp);
xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter,
xla::XlaOp delta) {
// Multiplying by 256 to be consistent with the CPU/GPU kernels
delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256);
if (alg == RNG_ALG_PHILOX) {
return xla::PhiloxIncreaseCounter(counter, delta);
} else {
return counter + delta;
}
}
xla::XlaOp PadRight(xla::XlaOp a, int n) {
return xla::Pad(a, xla::ScalarLike(a, 0),
xla::MakeEdgePaddingConfig({{0, n}}));
}
template <typename AlgEnumType = int64, bool read_old_value = false>
class RngSkipOp : public XlaOpKernel {
public:
explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const int state_input_idx = 0;
const int alg_input_idx = 1;
const int delta_input_idx = 2;
xla::XlaOp var;
TensorShape var_shape;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE,
&var_shape, &var));
xla::Literal alg_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal));
Algorithm alg = Algorithm(alg_literal.Get<AlgEnumType>({}));
OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape));
if (read_old_value) {
auto counter_size = GetCounterSize(alg);
xla::XlaOp output = var;
if (RNG_MAX_COUNTER_SIZE > counter_size) {
// Because the size of `var` depends on the algorithm while we want the
// output to have a fixed size (to help shape inference), we fix the
// output size to be the maximal state size among algorithms, and right-
// pad it with zeros if var's size is smaller than that.
output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size);
}
ctx->SetOutput(0, output);
}
xla::XlaOp counter;
xla::XlaOp key;
std::tie(counter, key) = StateAndKeyFromVariable(alg, var);
xla::XlaOp delta = ctx->Input(delta_input_idx);
delta = BitcastConvertType(delta, xla::U64);
auto new_counter = IncreaseCounter(alg, counter, delta);
var = StateAndKeyToVariable(alg, new_counter, key);
xla::PrimitiveType state_element_type;
OP_REQUIRES_OK(
ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
var = BitcastConvertType(var, state_element_type);
OP_REQUIRES_OK(
ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp);
};
REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"),
RngSkipOp<>);
using RngReadAndSkipOp = RngSkipOp<int32, true>;
REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"),
RngReadAndSkipOp);
} // namespace
} // namespace tensorflow
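The removed RngSkipOp above advances the 128-bit Philox counter by delta * 256 (a uint64 delta added to a uint128 counter, matching the CPU/GPU kernels). A minimal plain-Python sketch of just that counter arithmetic, assuming the two-uint64 counter layout the kernel uses; this is an illustration, not TensorFlow code:

def philox_skip(counter_lo, counter_hi, delta):
    """Advances a 128-bit counter held as two uint64 halves by delta * 256."""
    total = (counter_hi << 64) | counter_lo
    total = (total + delta * 256) % (1 << 128)  # uint128 add with wraparound
    return total & 0xFFFFFFFFFFFFFFFF, total >> 64

# Mirrors the Philox test values above (counter_low=57, counter_high=283).
new_lo, new_hi = philox_skip(57, 283, delta=100)
assert (new_lo, new_hi) == (57 + 100 * 256, 283)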

View File

@@ -111,8 +111,6 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
}
}
namespace {
xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
xla::XlaOp seeds,
const xla::Shape& shape) {
@@ -142,6 +140,8 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
}
}
namespace {
class StatelessRandomUniformOp : public XlaOpKernel {
public:
explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)

View File

@@ -1,485 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/stateless_random_ops_v2.h"
#include <cmath>
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace tensorflow {
namespace {
inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
if (alg == RNG_ALG_PHILOX) {
return xla::RandomAlgorithm::RNG_PHILOX;
}
return xla::RandomAlgorithm::RNG_THREE_FRY;
}
inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
return RNG_ALG_PHILOX;
}
return RNG_ALG_THREEFRY;
}
xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) {
Algorithm alg_ = RandomAlgorithmToAlgorithm(alg);
return xla::Slice(state, {RNG_KEY_SIZE},
{RNG_KEY_SIZE + GetCounterSize(alg_)}, {1});
}
xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
xla::XlaOp counter, const xla::Shape& shape) {
key = BitcastConvertType(key, xla::U64);
counter = BitcastConvertType(counter, xla::U64);
xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0);
xla::XlaOp result = xla::RngBitGenerator(alg, state, shape);
auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0));
new_counter = BitcastConvertType(new_counter, xla::S64);
return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
/*state=*/new_counter};
}
std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
absl::string_view device_type_string, xla::XlaOp key) {
// The Philox algorithm may cause performance regression on other devices.
// Turn on the Philox algorithm for the CPU and GPU backends only.
if (device_type_string == DEVICE_GPU_XLA_JIT ||
device_type_string == DEVICE_CPU_XLA_JIT) {
auto counter_key = xla::ScramblePhiloxKey(key);
return std::make_tuple(counter_key.second, counter_key.first,
RNG_ALG_PHILOX);
} else {
auto counter_shape =
xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
auto counter = xla::Zeros(key.builder(), counter_shape);
return std::make_tuple(key, counter, RNG_ALG_THREEFRY);
}
}
} // namespace
xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg,
xla::XlaOp key, xla::XlaOp counter,
const xla::Shape& shape, xla::XlaOp minval,
xla::XlaOp maxval) {
xla::XlaBuilder* builder = key.builder();
xla::PrimitiveType type = shape.element_type();
using std::placeholders::_1;
using std::placeholders::_2;
using std::placeholders::_3;
auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
switch (type) {
case xla::F32:
case xla::F64:
return xla::UniformFloatingPointDistribution(key, counter, generator,
minval, maxval, shape);
case xla::S32:
case xla::S64:
case xla::U32:
case xla::U64:
return UniformIntDistribution(key, counter, generator, minval, maxval,
shape);
break;
default:
return {builder->ReportError(xla::Unimplemented(
"Types other than F32, S32, S64, U32 and U64 are not "
"implemented by "
"StatelessRngUniformV2; got %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
counter};
}
}
namespace {
xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg,
xla::XlaOp key, xla::XlaOp counter,
const xla::Shape& shape) {
xla::XlaBuilder* builder = key.builder();
xla::PrimitiveType type = shape.element_type();
xla::RngOutput output = BitGenerator(alg, key, counter, shape);
switch (type) {
case xla::U32:
case xla::U64:
return output;
case xla::S32:
case xla::S64:
return xla::RngOutput{BitcastConvertType(output.value, type),
output.state};
default:
return {
builder->ReportError(xla::Unimplemented(
"Types other than U32, S32, U64 and S64 are not implemented by "
"StatelessRngUniformFullInt; got: %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
output.state};
}
}
Status GetAlgorithm(XlaOpKernelContext* ctx, int alg_input_idx,
xla::RandomAlgorithm* alg) {
auto alg_shape = ctx->InputShape(alg_input_idx);
if (alg_shape.dims() != 0) {
return errors::InvalidArgument("algorithm must be of shape [], not ",
alg_shape.DebugString());
}
xla::Literal alg_literal;
TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
auto alg_ = Algorithm(alg_literal.Get<int>({}));
*alg = AlgorithmToRandomAlgorithm(alg_);
return Status::OK();
}
xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg,
TensorShape const& counter_shape,
xla::XlaOp counter) {
auto input_counter_size = counter_shape.dim_size(0);
auto real_counter_size = GetCounterSize(RandomAlgorithmToAlgorithm(alg));
if (input_counter_size > real_counter_size) {
counter = xla::Slice(counter, {0}, {real_counter_size}, {1});
}
return counter;
}
class StatelessRandomUniformOp : public XlaOpKernel {
public:
explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* builder = ctx->builder();
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
const int key_input_idx = 1;
const int counter_input_idx = 2;
const int alg_input_idx = 3;
xla::XlaOp key = ctx->Input(key_input_idx);
xla::XlaOp counter = ctx->Input(counter_input_idx);
xla::RandomAlgorithm alg;
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
auto counter_shape = ctx->InputShape(counter_input_idx);
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
ctx->InputShape(key_input_idx),
counter_shape));
DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
xla::PrimitiveType rng_primitive_type = xla_shape.element_type();
counter = MaybeSliceCounter(alg, counter_shape, counter);
auto result = StatelessRngUniformV2(
alg, key, counter, xla_shape,
xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
auto uniform = MaybeConvertF32ToBF16(result.value, dtype_);
ctx->SetOutput(0, uniform);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};
REGISTER_XLA_OP(Name("StatelessRandomUniformV2")
.CompileTimeConstantInput("shape")
.CompileTimeConstantInput("alg")
.TypeConstraint("dtype",
{DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
StatelessRandomUniformOp);
class StatelessRandomUniformIntOp : public XlaOpKernel {
public:
explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
const int key_input_idx = 1;
const int counter_input_idx = 2;
const int alg_input_idx = 3;
xla::XlaOp key = ctx->Input(key_input_idx);
xla::XlaOp counter = ctx->Input(counter_input_idx);
xla::RandomAlgorithm alg;
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
auto counter_shape = ctx->InputShape(counter_input_idx);
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
ctx->InputShape(key_input_idx),
counter_shape));
const int minval_input_idx = 4;
const int maxval_input_idx = 5;
TensorShape minval_shape = ctx->InputShape(minval_input_idx);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
errors::InvalidArgument("minval must be scalar, got shape ",
minval_shape.DebugString()));
TensorShape maxval_shape = ctx->InputShape(maxval_input_idx);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
errors::InvalidArgument("maxval must be scalar, got shape ",
maxval_shape.DebugString()));
xla::XlaOp minval = ctx->Input(minval_input_idx);
xla::XlaOp maxval = ctx->Input(maxval_input_idx);
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
counter = MaybeSliceCounter(alg, counter_shape, counter);
auto result =
StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval);
ctx->SetOutput(0, result.value);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};
REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2")
.CompileTimeConstantInput("shape")
.CompileTimeConstantInput("alg")
.TypeConstraint("dtype",
{DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
StatelessRandomUniformIntOp);
class StatelessRandomUniformFullIntOp : public XlaOpKernel {
public:
explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
const int key_input_idx = 1;
const int counter_input_idx = 2;
const int alg_input_idx = 3;
xla::XlaOp key = ctx->Input(key_input_idx);
xla::XlaOp counter = ctx->Input(counter_input_idx);
xla::RandomAlgorithm alg;
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
auto counter_shape = ctx->InputShape(counter_input_idx);
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
ctx->InputShape(key_input_idx),
counter_shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
counter = MaybeSliceCounter(alg, counter_shape, counter);
auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape);
ctx->SetOutput(0, result.value);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};
REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2")
.CompileTimeConstantInput("shape")
.CompileTimeConstantInput("alg")
.TypeConstraint("dtype",
{DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
StatelessRandomUniformFullIntOp);
class StatelessRandomNormalOp : public XlaOpKernel {
public:
explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
const int key_input_idx = 1;
const int counter_input_idx = 2;
const int alg_input_idx = 3;
xla::XlaOp key = ctx->Input(key_input_idx);
xla::XlaOp counter = ctx->Input(counter_input_idx);
xla::RandomAlgorithm alg;
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
auto counter_shape = ctx->InputShape(counter_input_idx);
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
ctx->InputShape(key_input_idx),
counter_shape));
DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
using std::placeholders::_1;
using std::placeholders::_2;
using std::placeholders::_3;
auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
counter = MaybeSliceCounter(alg, counter_shape, counter);
auto result = xla::NormalFloatingPointDistribution(key, counter, generator,
xla_shape);
auto normal = MaybeConvertF32ToBF16(result.value, dtype_);
ctx->SetOutput(0, normal);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};
REGISTER_XLA_OP(Name("StatelessRandomNormalV2")
.CompileTimeConstantInput("shape")
.CompileTimeConstantInput("alg")
.TypeConstraint("dtype",
{DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
StatelessRandomNormalOp);
class StatelessTruncatedNormalOp : public XlaOpKernel {
public:
explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
const int key_input_idx = 1;
const int counter_input_idx = 2;
const int alg_input_idx = 3;
xla::XlaOp key = ctx->Input(key_input_idx);
xla::XlaOp counter = ctx->Input(counter_input_idx);
xla::RandomAlgorithm alg;
OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));
auto counter_shape = ctx->InputShape(counter_input_idx);
OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
ctx->InputShape(key_input_idx),
counter_shape));
xla::XlaBuilder* builder = ctx->builder();
DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
counter = MaybeSliceCounter(alg, counter_shape, counter);
auto result = StatelessRngUniformV2(
alg, key, counter, xla_shape,
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
xla::One(builder, xla_shape.element_type()));
xla::XlaOp truncated_normal = TruncatedNormal(result.value);
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
ctx->SetOutput(0, truncated_normal);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};
REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2")
.CompileTimeConstantInput("shape")
.CompileTimeConstantInput("alg")
.TypeConstraint("dtype",
{DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
StatelessTruncatedNormalOp);
class GetKeyCounterAlgOp : public XlaOpKernel {
public:
explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx),
device_type_string_(ctx->device_type().type_string()) {}
void Compile(XlaOpKernelContext* ctx) override {
TensorShape seed_shape = ctx->InputShape(0);
OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
xla::XlaOp seed = ctx->Input(0);
xla::XlaBuilder* builder = seed.builder();
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
auto key_counter_alg = GetKeyCounterAlg(device_type_string_, key);
key = std::get<0>(key_counter_alg);
auto counter = std::get<1>(key_counter_alg);
auto alg = std::get<2>(key_counter_alg);
key = xla::Reshape(key, {RNG_KEY_SIZE});
ctx->SetOutput(0, key);
ctx->SetOutput(1, counter);
ctx->SetOutput(2, ConstantR0(builder, static_cast<int>(alg)));
}
private:
string device_type_string_;
TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
};
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);
} // namespace
} // namespace tensorflow

View File

@@ -83,8 +83,6 @@ CreateResourceOpInfoMap() {
add("ResourceScatterSub" , kReadWrite, kVariable);
add("ResourceScatterUpdate" , kReadWrite, kVariable);
add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
add("RngReadAndSkip" , kReadWrite, kVariable);
add("RngSkip" , kReadWrite, kVariable);
add("StatefulStandardNormalV2" , kReadWrite, kVariable);
add("StatefulTruncatedNormal" , kReadWrite, kVariable);
add("StatefulUniform" , kReadWrite, kVariable);

View File

@@ -487,10 +487,6 @@ std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
} // namespace
XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
}
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
const Shape& shape) {
PrimitiveType type = shape.element_type();

View File

@@ -89,9 +89,6 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
absl::Span<const xla::XlaOp> scalars);
// Increases the Philox counter (a uint128) by a delta (a uint64).
xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_

View File

@@ -488,7 +488,6 @@ tf_cuda_library(
"//tensorflow/core/framework:register_types_traits.h",
"//tensorflow/core/framework:resource_mgr.h",
"//tensorflow/core/framework:resource_op_kernel.h",
"//tensorflow/core/framework:rng_alg.h",
"//tensorflow/core/framework:selective_registration.h",
"//tensorflow/core/framework:session_state.h",
"//tensorflow/core/framework:shape_inference.h",
@@ -653,7 +652,6 @@ tf_gen_op_libs(
"spectral_ops",
"state_ops",
"stateless_random_ops",
"stateless_random_ops_v2",
"summary_ops",
"training_ops",
],
@@ -873,7 +871,6 @@ cc_library(
":spectral_ops_op_lib",
":state_ops_op_lib",
":stateless_random_ops_op_lib",
":stateless_random_ops_v2_op_lib",
":string_ops_op_lib",
":training_ops_op_lib",
":user_ops_op_lib",

View File

@@ -1,35 +0,0 @@
op {
graph_op_name: "RngReadAndSkip"
visibility: HIDDEN
in_arg {
name: "resource"
description: <<END
The handle of the resource variable that stores the state of the RNG.
END
}
in_arg {
name: "alg"
description: <<END
The RNG algorithm.
END
}
in_arg {
name: "delta"
description: <<END
The amount of advancement.
END
}
out_arg {
name: "value"
description: <<END
The old value of the resource variable, before incrementing. Since state size is algorithm-dependent, this output will be right-padded with zeros to reach shape int64[3] (the current maximal state size among algorithms).
END
}
summary: "Advance the counter of a counter-based RNG."
description: <<END
The state of the RNG after
`rng_read_and_skip(n)` will be the same as that after `uniform([n])`
(or any other distribution). The actual increment added to the
counter is an unspecified implementation choice.
END
}
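A hedged sketch of the advancement contract stated above, through the public tf.random.Generator API (Generator.skip is documented with the same contract and wraps the RngSkip/RngReadAndSkip kernels); assumes a TensorFlow 2.x build where tf.random.Generator is available:

import tensorflow as tf

g1 = tf.random.Generator.from_seed(1234)
g2 = tf.random.Generator.from_seed(1234)
g1.uniform([10])  # draw 10 samples
g2.skip(10)       # advance as if 10 samples had been drawn
# Per the contract above, both generators should now hold the same state.
assert (g1.state.numpy() == g2.state.numpy()).all()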

View File

@@ -1,32 +0,0 @@
op {
graph_op_name: "StatelessRandomGetKeyCounterAlg"
visibility: HIDDEN
in_arg {
name: "seed"
description: <<END
2 seeds (shape [2]).
END
}
out_arg {
name: "key"
description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
}
out_arg {
name: "counter"
description: <<END
Counter for the counter-based RNG algorithm. Since counter size is algorithm-dependent, this output will be right-padded with zeros to reach shape uint64[2] (the current maximal counter size among algorithms).
END
}
out_arg {
name: "alg"
description: <<END
The RNG algorithm (shape int32[]).
END
}
summary: "Picks the best algorithm based on device, and scrambles seed into key and counter."
description: <<END
This op picks the best counter-based RNG algorithm based on device, and scrambles a shape-[2] seed into a key and a counter, both needed by the counter-based algorithm. The scrambling is opaque but approximately satisfies the property that different seeds result in different key/counter pairs (which will in turn result in different random numbers).
END
}
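The key derivation is visible in the deleted GetKeyCounterAlgOp kernel above: the two 32-bit seeds are packed into one uint64 key, seed[0] in the low 32 bits and seed[1] in the high 32 bits, and for Philox the key is then further scrambled by xla::ScramblePhiloxKey. A plain-Python sketch of just the packing step (the scrambling is omitted here):

def pack_seed_to_key(seed0, seed1):
    # seed0 occupies the low 32 bits, seed1 the high 32 bits.
    return (seed0 & 0xFFFFFFFF) | ((seed1 & 0xFFFFFFFF) << 32)

assert pack_seed_to_key(0x1234, 0x5678) == 0x0000_5678_0000_1234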

View File

@@ -1,46 +0,0 @@
op {
graph_op_name: "StatelessRandomNormalV2"
visibility: HIDDEN
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "key"
description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
}
in_arg {
name: "counter"
description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
}
in_arg {
name: "alg"
description: <<END
The RNG algorithm (shape int32[]).
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs deterministic pseudorandom values from a normal distribution."
description: <<END
The generated values will have mean 0 and standard deviation 1.
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
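A hedged sketch of the determinism property stated above, using the generated raw-op endpoint; tf.raw_ops.StatelessRandomNormalV2 is assumed here and exists only in TensorFlow builds that include these V2 ops (which this commit removes). The same shape/key/counter/alg always yields the same values, and the same holds for the other *V2 ops below:

import tensorflow as tf

key = tf.constant([42], dtype=tf.uint64)        # shape uint64[1]
counter = tf.constant([0, 0], dtype=tf.uint64)  # shape uint64[2] (Philox)
a = tf.raw_ops.StatelessRandomNormalV2(shape=[2, 3], key=key, counter=counter, alg=1)
b = tf.raw_ops.StatelessRandomNormalV2(shape=[2, 3], key=key, counter=counter, alg=1)
assert (a.numpy() == b.numpy()).all()  # deterministic in (shape, key, counter, alg)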

View File

@@ -1,46 +0,0 @@
op {
graph_op_name: "StatelessRandomUniformFullIntV2"
visibility: HIDDEN
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "key"
description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
}
in_arg {
name: "counter"
description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
}
in_arg {
name: "alg"
description: <<END
The RNG algorithm (shape int32[]).
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
description: <<END
The generated values are uniform integers covering the whole range of `dtype`.
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}

View File

@@ -1,58 +0,0 @@
op {
graph_op_name: "StatelessRandomUniformIntV2"
visibility: HIDDEN
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "key"
description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
}
in_arg {
name: "counter"
description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
}
in_arg {
name: "alg"
description: <<END
The RNG algorithm (shape int32[]).
END
}
in_arg {
name: "minval"
description: <<END
Minimum value (inclusive, scalar).
END
}
in_arg {
name: "maxval"
description: <<END
Maximum value (exclusive, scalar).
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
description: <<END
The generated values follow a uniform distribution in the range `[minval, maxval)`.
The outputs are a deterministic function of `shape`, `key`, `counter`, `alg`, `minval` and `maxval`.
END
}

View File

@@ -1,47 +0,0 @@
op {
graph_op_name: "StatelessRandomUniformV2"
visibility: HIDDEN
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "key"
description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
}
in_arg {
name: "counter"
description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
}
in_arg {
name: "alg"
description: <<END
The RNG algorithm (shape int32[]).
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs deterministic pseudorandom random values from a uniform distribution."
description: <<END
The generated values follow a uniform distribution in the range `[0, 1)`. The
lower bound 0 is included in the range, while the upper bound 1 is excluded.
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}

View File

@@ -1,48 +0,0 @@
op {
graph_op_name: "StatelessTruncatedNormalV2"
visibility: HIDDEN
in_arg {
name: "shape"
description: <<END
The shape of the output tensor.
END
}
in_arg {
name: "key"
description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
}
in_arg {
name: "counter"
description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
}
in_arg {
name: "alg"
description: <<END
The RNG algorithm (shape int32[]).
END
}
out_arg {
name: "output"
description: <<END
Random values with specified shape.
END
}
attr {
name: "dtype"
description: <<END
The type of the output.
END
}
summary: "Outputs deterministic pseudorandom values from a truncated normal distribution."
description: <<END
The generated values follow a normal distribution with mean 0 and standard
deviation 1, except that values whose magnitude is more than 2 standard
deviations from the mean are dropped and re-picked.
The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
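The truncation property described above (values beyond two standard deviations are dropped and re-picked) can be checked through the public stateless API; a hedged sketch assuming tf.random.stateless_truncated_normal is available:

import tensorflow as tf

x = tf.random.stateless_truncated_normal(shape=[10000], seed=[1, 2])
# All samples fall within two standard deviations of the mean.
assert float(tf.reduce_max(tf.abs(x))) <= 2.0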

View File

@@ -68,7 +68,6 @@ exports_files(
"resource_mgr.h",
"resource_op_kernel.h",
"resource_var.h",
"rng_alg.h",
"run_handler.h",
"run_handler_util.h",
"session_state.h",
@@ -386,7 +385,6 @@ filegroup(
"resource_mgr.h",
"resource_op_kernel.h",
"resource_var.h",
"rng_alg.h",
"run_handler.cc",
"run_handler.h",
"run_handler_util.cc",

View File

@@ -1,34 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
#define TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
namespace tensorflow {
enum Algorithm { RNG_ALG_PHILOX = 1, RNG_ALG_THREEFRY = 2 };
static constexpr int RNG_KEY_SIZE = 1;
static constexpr int RNG_MAX_COUNTER_SIZE = 2;
inline int GetCounterSize(Algorithm alg) {
if (alg == RNG_ALG_PHILOX) {
return 2;
}
return 1;
}
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
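A plain-Python restatement of the sizes defined in the deleted rng_alg.h above, using the counter-then-key state layout noted in the stateful kernels later in this commit:

RNG_ALG_PHILOX, RNG_ALG_THREEFRY = 1, 2
RNG_KEY_SIZE = 1
RNG_MAX_COUNTER_SIZE = 2

def counter_size(alg):
    return 2 if alg == RNG_ALG_PHILOX else 1

def state_size(alg):
    # The RNG state is laid out as counter followed by key.
    return counter_size(alg) + RNG_KEY_SIZE

assert state_size(RNG_ALG_PHILOX) == 3    # 128-bit counter + 64-bit key
assert state_size(RNG_ALG_THREEFRY) == 2  # 64-bit counter + 64-bit key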

View File

@@ -4457,22 +4457,12 @@ cc_library(
],
)
cc_library(
name = "stateless_random_ops_v2_header",
hdrs = ["stateless_random_ops_v2.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_kernel_library(
name = "stateful_random_ops",
prefix = "stateful_random_ops",
deps = [
":bounds_check",
":dense_update_functor",
":fill_functor",
":gather_functor",
":mutex_ops",
":random_op",
@@ -5347,7 +5337,7 @@ tf_kernel_library(
prefix = "random_binomial_op",
deps = [
":cwise_op",
":random_op",
":random_ops",
":resource_variable_ops",
":stateful_random_ops",
":stateless_random_ops",
@@ -6169,7 +6159,6 @@ filegroup(
"ragged_tensor_to_tensor_op.cc",
"random_op.cc",
"random_op_cpu.h",
"random_ops_util.h",
"random_poisson_op.cc",
"reduce_join_op.cc",
"reduction_ops_all.cc",

View File

@@ -66,9 +66,8 @@ struct MultinomialFunctor<GPUDevice, T, OutputType> {
typename TTypes<OutputType>::Matrix output) {
// Uniform, [0, 1).
typedef random::UniformDistribution<random::PhiloxRandom, float> Dist;
functor::FillPhiloxRandom<GPUDevice, Dist>()(
ctx, d, /*key=*/nullptr, /*counter=*/nullptr, gen, noises.data(),
noises.size(), Dist());
functor::FillPhiloxRandom<GPUDevice, Dist>()(ctx, d, gen, noises.data(),
noises.size(), Dist());
#if defined(EIGEN_HAS_INDEX_LIST)
Eigen::IndexList<int, int, int> bsc;

View File

@@ -30,10 +30,8 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
@@ -377,7 +375,7 @@ class RandomBinomialOp : public OpKernel {
OP_REQUIRES(ctx, alg_tensor.dims() == 0,
errors::InvalidArgument("algorithm must be of shape [], not ",
alg_tensor.shape().DebugString()));
Algorithm alg = Algorithm(alg_tensor.flat<int64>()(0));
Algorithm alg = alg_tensor.flat<Algorithm>()(0);
int64 samples_per_batch = 1;
const int64 num_sample_dims =

View File

@@ -74,7 +74,7 @@ class PhiloxRandomOp : public OpKernel {
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<T>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
ctx, ctx->eigen_device<Device>(),
// Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
// it just here.
generator_.ReserveRandomOutputs(output_flat.size(), 256),
@@ -123,7 +123,7 @@ class RandomUniformIntOp : public OpKernel {
auto output_flat = output->flat<IntType>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
ctx, ctx->eigen_device<Device>(),
// Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
// it just here.
generator_.ReserveRandomOutputs(output_flat.size(), 256),

View File

@@ -34,14 +34,10 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
// NOTE: Due to inlining done by the compiler, you may need to add
// explicit instantiation of the functor in random_op.cc. See example
// functor::FillPhiloxRandom<CPUDevice, random::UniformDistribution>.
//
// This functor can take the PhiloxRandom input from either device memory
// (`key` and `counter`) or a stack value (`gen`). If neither `key` nor
// `counter` is nullptr, they provide the input; otherwise `gen` provides it.
template <class Distribution>
struct FillPhiloxRandom<CPUDevice, Distribution> {
void operator()(OpKernelContext* ctx, const CPUDevice& d, const uint64* key,
const uint64* counter, random::PhiloxRandom gen,
void operator()(OpKernelContext* ctx, const CPUDevice& d,
random::PhiloxRandom gen,
typename Distribution::ResultElementType* data, int64 size,
Distribution dist);
};
@@ -51,13 +47,14 @@ typedef Eigen::GpuDevice GPUDevice;
// Declares the partially GPU-specialized functor struct.
template <class Distribution>
struct FillPhiloxRandom<GPUDevice, Distribution> {
void operator()(OpKernelContext* ctx, const GPUDevice& d, const uint64* key,
const uint64* counter, random::PhiloxRandom gen,
void operator()(OpKernelContext* ctx, const GPUDevice& d,
random::PhiloxRandom gen,
typename Distribution::ResultElementType* data, int64 size,
Distribution dist);
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace functor
} // namespace tensorflow

View File

@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
@@ -60,9 +59,8 @@ using random::SingleSampleAdapter;
template <typename Device, class Distribution>
struct FillPhiloxRandom {
typedef typename Distribution::ResultElementType T;
void operator()(OpKernelContext* ctx, const Device&, const uint64* key,
const uint64* counter, random::PhiloxRandom gen, T* data,
int64 size, Distribution dist) {
void operator()(OpKernelContext* ctx, const Device&, random::PhiloxRandom gen,
T* data, int64 size, Distribution dist) {
OP_REQUIRES(
ctx, false,
errors::Internal(
@@ -156,24 +154,18 @@ struct FillPhiloxRandomTask<Distribution, true> {
// It splits the work into several tasks and run them in parallel
template <class Distribution>
void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
OpKernelContext* ctx, const CPUDevice&, const uint64* key,
const uint64* counter, random::PhiloxRandom gen,
OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
typename Distribution::ResultElementType* data, int64 size,
Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
const int kGroupCost =
random::PhiloxRandom::kResultElementCount *
(random::PhiloxRandom::kElementCost + Distribution::kElementCost);
if (key != nullptr && counter != nullptr) {
gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
}
Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
kGroupCost,
[&gen, data, size, dist](int64 start_group, int64 limit_group) {

View File

@@ -19,7 +19,6 @@ limitations under the License.
#if defined(__CUDACC__) || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
@@ -34,16 +33,14 @@ struct FillPhiloxRandomKernel;
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
typedef typename Distribution::ResultElementType T;
PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter,
random::PhiloxRandom gen, T* data, int64 size,
PHILOX_DEVICE_INLINE void Run(random::PhiloxRandom gen, T* data, int64 size,
Distribution dist);
};
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
typedef typename Distribution::ResultElementType T;
PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter,
random::PhiloxRandom base_gen, T* data,
PHILOX_DEVICE_INLINE void Run(const random::PhiloxRandom& base_gen, T* data,
int64 size, Distribution dist);
};
@@ -139,16 +136,12 @@ class SampleCopier<int64, 2> {
// distribution. Each output takes a fixed number of samples.
template <class Distribution>
PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
const uint64* key, const uint64* counter, random::PhiloxRandom gen, T* data,
int64 size, Distribution dist) {
random::PhiloxRandom gen, T* data, int64 size, Distribution dist) {
const int kGroupSize = Distribution::kResultElementCount;
const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int32 total_thread_count = gridDim.x * blockDim.x;
int32 offset = thread_id * kGroupSize;
if (key != nullptr && counter != nullptr) {
gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
}
gen.Skip(thread_id);
const SampleCopier<T, kGroupSize> copier;
@@ -174,8 +167,8 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
// distribution. Each output takes a variable number of samples.
template <class Distribution>
PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
const uint64* key, const uint64* counter, random::PhiloxRandom base_gen,
T* data, int64 size, Distribution dist) {
const random::PhiloxRandom& base_gen, T* data, int64 size,
Distribution dist) {
using random::PhiloxRandom;
using random::SingleSampleAdapter;
@@ -190,9 +183,6 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
int64 group_index = thread_id;
int64 offset = group_index * kGroupSize;
if (key != nullptr && counter != nullptr) {
base_gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
}
while (offset < size) {
// Since each output takes a variable number of samples, we need to
// realign the generator to the beginning for the current output group
@@ -218,20 +208,18 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
// A simple launch pad to call the correct function templates to fill the data
template <class Distribution>
__global__ void __launch_bounds__(1024)
FillPhiloxRandomKernelLaunch(const uint64* key, const uint64* counter,
random::PhiloxRandom base_gen,
FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
typename Distribution::ResultElementType* data,
int64 size, Distribution dist) {
FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>()
.Run(key, counter, base_gen, data, size, dist);
.Run(base_gen, data, size, dist);
}
// Partial specialization for GPU
template <class Distribution>
void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
OpKernelContext*, const GPUDevice& d, const uint64* key,
const uint64* counter, random::PhiloxRandom gen,
OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
typename Distribution::ResultElementType* data, int64 size,
Distribution dist) {
const int32 block_size = d.maxGpuThreadsPerBlock();
@@ -240,8 +228,8 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
block_size;
TF_CHECK_OK(GpuLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
num_blocks, block_size, 0, d.stream(), key,
counter, gen, data, size, dist));
num_blocks, block_size, 0, d.stream(), gen, data,
size, dist));
}
} // namespace functor

View File

@@ -1,72 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
#include "tensorflow/core/lib/random/philox_random.h"
namespace tensorflow {
using random::PhiloxRandom;
// The following 2 functions use the contract "lower 32 bits for the first
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
// unlike a direct memory copy `memcpy(output, &input, 8)`.
PHILOX_DEVICE_INLINE void Uint64ToUint32s(uint64 input, uint32* output1,
uint32* output2) {
*output1 = static_cast<uint32>(input);
*output2 = static_cast<uint32>(input >> 32);
}
PHILOX_DEVICE_INLINE uint64 Uint32sToUint64(uint32 input1, uint32 input2) {
auto u64_1 = static_cast<uint64>(input1);
auto u64_2 = static_cast<uint64>(input2);
return u64_1 | (u64_2 << 32);
}
PHILOX_DEVICE_INLINE PhiloxRandom::ResultType GetCounterFromMem(
uint64 const* ptr) {
PhiloxRandom::ResultType counter;
Uint64ToUint32s(ptr[0], &counter[0], &counter[1]);
Uint64ToUint32s(ptr[1], &counter[2], &counter[3]);
return counter;
}
PHILOX_DEVICE_INLINE void WriteCounterToMem(
PhiloxRandom::ResultType const& counter, uint64* ptr) {
ptr[0] = Uint32sToUint64(counter[0], counter[1]);
ptr[1] = Uint32sToUint64(counter[2], counter[3]);
}
PHILOX_DEVICE_INLINE PhiloxRandom::Key GetKeyFromMem(uint64 const* ptr) {
PhiloxRandom::Key key;
Uint64ToUint32s(ptr[0], &key[0], &key[1]);
return key;
}
PHILOX_DEVICE_INLINE void WriteKeyToMem(PhiloxRandom::Key const& key,
uint64* ptr) {
*ptr = Uint32sToUint64(key[0], key[1]);
}
PHILOX_DEVICE_INLINE PhiloxRandom GetPhiloxRandomFromCounterKeyMem(
uint64 const* counter_ptr, uint64 const* key_ptr) {
return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr));
}
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
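A plain-Python check of the 64/32-bit packing contract stated above: the first uint32 carries the low 32 bits and the second the high 32 bits, so the round trip is exact and independent of host endianness:

def uint64_to_uint32s(x):
    return x & 0xFFFFFFFF, (x >> 32) & 0xFFFFFFFF

def uint32s_to_uint64(lo, hi):
    return lo | (hi << 32)

x = 0x0123_4567_89AB_CDEF
assert uint32s_to_uint64(*uint64_to_uint32s(x)) == x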

View File

@@ -15,9 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
@@ -25,8 +23,6 @@ namespace tensorflow {
namespace tensorflow {
namespace functor {
template <typename Distribution>
struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
void operator()(OpKernelContext* ctx, const CPUDevice& device,
@@ -46,13 +42,10 @@ struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
// No longer needs the lock.
state_var_guard->Release();
functor::FillPhiloxRandom<CPUDevice, Distribution>()(
ctx, device, /*key=*/nullptr, /*counter=*/nullptr, philox, output_data,
output_size, dist);
ctx, device, philox, output_data, output_size, dist);
}
};
} // end namespace functor
Status CheckState(const Tensor& state) {
if (state.dtype() != STATE_ELEMENT_DTYPE) {
return errors::InvalidArgument("dtype of RNG state variable must be ",
@@ -71,12 +64,11 @@ Status CheckPhiloxState(const Tensor& state, int64 alg_tag_skip = 0) {
"StateElementType must be int64");
static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
"PhiloxRandom::ResultElementType must be uint32");
auto min_size = alg_tag_skip + PHILOX_MIN_STATE_SIZE;
if (state.NumElements() < min_size) {
if (state.NumElements() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) {
return errors::InvalidArgument(
"For the Philox algorithm, the size of state"
" must be at least ",
min_size, "; got ", state.NumElements());
alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ", state.NumElements());
}
return Status::OK();
}
@@ -103,7 +95,7 @@ Status UpdateVariableAndFill(
if (var_tensor_flat.size() < 1) {
return errors::InvalidArgument("Size of tensor must be at least 1");
}
alg = Algorithm(var_tensor_flat(0));
alg = var_tensor_flat(0);
}
if (alg == RNG_ALG_PHILOX) {
TF_RETURN_IF_ERROR(CheckPhiloxState(*var_tensor, alg_tag_skip));
@@ -115,7 +107,7 @@ Status UpdateVariableAndFill(
arg.alg_tag_skip = alg_tag_skip;
arg.not_used = &state_var_guard;
arg.state_tensor = var_tensor;
functor::UpdateVariableAndFill_Philox<Device, Distribution>()(
UpdateVariableAndFill_Philox<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), dist, &arg, output_data);
return Status::OK();
} else {
@@ -146,8 +138,7 @@ class StatefulRandomOp : public OpKernel {
explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true,
RNG_ALG_PHILOX /*dummy*/);
StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true, 0);
}
};
@@ -168,14 +159,6 @@ Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
return Status::OK();
}
template <typename AlgEnumType>
Status GetAlg(OpKernelContext* ctx, int input_idx, Algorithm* alg) {
AlgEnumType alg_id;
TF_RETURN_IF_ERROR(GetScalar(ctx->input(input_idx), input_idx, &alg_id));
*alg = Algorithm(alg_id);
return Status::OK();
}
template <typename Device, class Distribution>
class StatefulRandomOpV2 : public OpKernel {
public:
@@ -183,7 +166,7 @@ class StatefulRandomOpV2 : public OpKernel {
void Compute(OpKernelContext* ctx) override {
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
StatefulRandomCompute<Device>(ctx, Distribution(), /*state_input_idx=*/0,
/*shape_input_idx=*/2,
/*read_alg_from_state=*/false, alg);
@@ -197,7 +180,7 @@ class StatefulUniformIntOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
const Tensor& minval = ctx->input(3);
const Tensor& maxval = ctx->input(4);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
@@ -234,7 +217,7 @@ class StatefulUniformFullIntOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
StatefulRandomCompute<Device>(
ctx,
random::UniformFullIntDistribution<random::PhiloxRandom, IntType>(),
@@ -243,66 +226,38 @@ class StatefulUniformFullIntOp : public OpKernel {
}
};
namespace functor {
template <>
struct RngSkip_Philox<CPUDevice> {
void operator()(const CPUDevice& device, const StateElementType* in_data,
uint64 delta, StateElementType* out_data) {
void operator()(const CPUDevice& device, int64 delta, Tensor* state_tensor) {
auto state_data = state_tensor->flat<StateElementType>().data();
// Delegates to PhiloxRandom to do the actual increasing.
auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
auto philox = GetPhiloxRandomFromMem(state_data);
UpdateMemWithPhiloxRandom(philox, delta, state_data);
}
};
} // end namespace functor
template <typename Device, typename AlgEnumType = int64,
typename DeltaType = int64, bool read_old_value = false>
template <typename Device>
class RngSkipOp : public OpKernel {
public:
explicit RngSkipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
auto state_input_idx = 0;
auto alg_input_idx = 1;
auto delta_input_idx = 2;
Algorithm alg;
OP_REQUIRES_OK(ctx, GetAlg<AlgEnumType>(ctx, alg_input_idx, &alg));
DeltaType delta_;
OP_REQUIRES_OK(
ctx, GetScalar(ctx->input(delta_input_idx), delta_input_idx, &delta_));
uint64 delta = static_cast<uint64>(delta_);
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
int64 delta;
OP_REQUIRES_OK(ctx, GetScalar(ctx->input(2), 2, &delta));
Var* var = nullptr;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
ScopedUnlockUnrefVar state_var_guard(var);
Tensor* var_tensor = var->tensor();
OP_REQUIRES_OK(ctx, CheckState(*var_tensor));
using T = StateElementType;
OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, T>(
ctx, var_tensor, var->copy_on_read_mode.load()));
if (read_old_value) {
Tensor* output;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, {RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE},
&output));
auto output_flat = output->flat<T>();
if (RNG_MAX_COUNTER_SIZE > GetCounterSize(alg)) {
functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
output_flat);
}
functor::DenseUpdate<Device, T, ASSIGN>()(
ctx->eigen_device<Device>(), output_flat,
const_cast<const Tensor*>(var_tensor)->flat<T>());
}
if (alg == RNG_ALG_PHILOX) {
OP_REQUIRES_OK(ctx, CheckPhiloxState(*var_tensor));
// var_tensor layout is counter+key, so var_tensor data is also counter
// data.
auto counter_data = var_tensor->flat<T>().data();
functor::RngSkip_Philox<Device>()(ctx->eigen_device<Device>(),
counter_data, delta, counter_data);
OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
ctx, var_tensor, var->copy_on_read_mode.load()));
RngSkip_Philox<Device>()(ctx->eigen_device<Device>(), delta, var_tensor);
} else {
OP_REQUIRES(ctx, false,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
@ -438,20 +393,13 @@ TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU);
TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU);
// TODO(wangpeng): Remove `HostMemory("delta")` for RngReadAndSkip
#define REGISTER_RngSkip(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("RngSkip") \
.Device(DEVICE_##DEVICE) \
.HostMemory("resource") \
.HostMemory("algorithm") \
.HostMemory("delta"), \
RngSkipOp<DEVICE##Device>); \
REGISTER_KERNEL_BUILDER(Name("RngReadAndSkip") \
.Device(DEVICE_##DEVICE) \
.HostMemory("resource") \
.HostMemory("alg") \
.HostMemory("delta"), \
RngSkipOp<DEVICE##Device, int32, uint64, true>);
RngSkipOp<DEVICE##Device>);
REGISTER_RngSkip(CPU);


@ -22,12 +22,15 @@ limitations under the License.
namespace tensorflow {
// 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained
// in b/111604096 and cl/171681867), so we use signed int here. We choose int64
// instead of int32 because `VarHandleOp` doesn't support int32 on GPU, and
// because of the "int32 problem".
// in b/111604096 and cl/171681867), so we use signed int here. We choose int64
// instead of int32 because `VarHandleOp` doesn't support int32 on GPU.
using StateElementType = int64;
static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64;
using Algorithm = StateElementType;
static constexpr DataType ALGORITHM_DTYPE = STATE_ELEMENT_DTYPE;
static constexpr Algorithm RNG_ALG_PHILOX = 1;
static constexpr Algorithm RNG_ALG_THREEFRY = 2;
using random::PhiloxRandom;


@ -17,51 +17,59 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/stateful_random_ops.h"
namespace tensorflow {
// The following 5 functions are made templates to avoid duplicate symbols when
// linking.
// The following 2 functions use the contract "lower 32 bits for the first
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
// unlike a direct memory copy `memcpy(output, &input, 8)`.
PHILOX_DEVICE_INLINE void Int64ToUint32s(int64 input, uint32* output1,
uint32* output2) {
auto u64 = static_cast<uint64>(input);
*output1 = static_cast<uint32>(u64);
*output2 = static_cast<uint32>(u64 >> 32);
}
PHILOX_DEVICE_INLINE int64 Uint32sToInt64(uint32 input1, uint32 input2) {
auto u64_1 = static_cast<uint64>(input1);
auto u64_2 = static_cast<uint64>(input2);
return static_cast<int64>(u64_1 | (u64_2 << 32));
}
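For reference, the "lower 32 bits first" contract above can be checked with a quick standalone sketch (plain Python, illustrative only; these helpers mirror, but are not, the TensorFlow functions):
def int64_to_uint32s(x):
  u64 = x & 0xFFFFFFFFFFFFFFFF        # reinterpret as unsigned 64-bit
  return u64 & 0xFFFFFFFF, u64 >> 32  # (lower 32 bits, upper 32 bits)
def uint32s_to_int64(lo, hi):
  u64 = lo | (hi << 32)
  return u64 - (1 << 64) if u64 >= (1 << 63) else u64  # back to signed
lo, hi = int64_to_uint32s(-2)          # -2 == 0xFFFFFFFFFFFFFFFE
assert (lo, hi) == (0xFFFFFFFE, 0xFFFFFFFF)
assert uint32s_to_int64(lo, hi) == -2  # round-trips exactly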
PHILOX_DEVICE_INLINE PhiloxRandom
GetPhiloxRandomFromMem(StateElementType const* ptr) {
auto ptr_ = reinterpret_cast<uint64 const*>(ptr);
return GetPhiloxRandomFromCounterKeyMem(ptr_, ptr_ + 2);
PhiloxRandom::ResultType counter;
PhiloxRandom::Key key;
Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
Int64ToUint32s(ptr[2], &key[0], &key[1]);
return PhiloxRandom(counter, key);
}
PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox,
StateElementType* ptr) {
auto ptr_ = reinterpret_cast<uint64*>(ptr);
WriteCounterToMem(philox.counter(), ptr_);
WriteKeyToMem(philox.key(), ptr_ + 2);
}
PHILOX_DEVICE_INLINE PhiloxRandom SkipPhiloxRandom(PhiloxRandom const& philox,
uint64 output_size) {
auto new_philox = philox;
// Multiplier 256 is the same as in FillPhiloxRandomTask; do not change it
// here without also changing it there.
auto delta = output_size * 256;
new_philox.Skip(delta); // do the actual incrementing
return new_philox;
PhiloxRandom::ResultType const& counter = philox.counter();
PhiloxRandom::Key const& key = philox.key();
ptr[0] = Uint32sToInt64(counter[0], counter[1]);
ptr[1] = Uint32sToInt64(counter[2], counter[3]);
ptr[2] = Uint32sToInt64(key[0], key[1]);
}
PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox,
uint64 output_size,
int64 output_size,
StateElementType* ptr) {
auto new_philox = SkipPhiloxRandom(philox, output_size);
auto new_philox = philox;
// Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
// it here without also changing it there.
auto delta = output_size * 256;
new_philox.Skip(delta); // do the actual incrementing
WritePhiloxRandomToMem(new_philox, ptr);
}
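The net effect of the update above, given the three-int64 state layout [counter_lo, counter_hi, key] used by GetPhiloxRandomFromMem and WritePhiloxRandomToMem, is a 128-bit add on the counter with the key left untouched. A rough Python sketch of that semantics (treating the counter words as unsigned; illustrative only, not the TensorFlow code):
def skip_philox_state(state, output_size):
  counter = (state[1] << 64) | state[0]  # assemble the 128-bit counter
  counter = (counter + output_size * 256) % (1 << 128)  # same multiplier
  return [counter & ((1 << 64) - 1), counter >> 64, state[2]]  # key kept
state = [0, 0, 0x1234]  # counter = 0, key = 0x1234
assert skip_philox_state(state, 1) == [256, 0, 0x1234]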
PHILOX_DEVICE_INLINE void UpdateCounterMemWithPhiloxRandom(
PhiloxRandom::ResultType const& counter, uint64 output_size,
StateElementType* ptr) {
auto philox = PhiloxRandom(counter, PhiloxRandom::Key() /*dummy*/);
auto new_philox = SkipPhiloxRandom(philox, output_size);
WriteCounterToMem(new_philox.counter(), reinterpret_cast<uint64*>(ptr));
}
namespace functor {
// A per-device helper function that does the actual work for
// `UpdateVariableAndFill`.
// Reason to use functor: C++ doesn't allow function-template partial
@ -72,8 +80,6 @@ struct UpdateVariableAndFill_Philox;
template <typename Device>
struct RngSkip_Philox;
} // end namespace functor
using CPUDevice = Eigen::ThreadPoolDevice;
struct UpdateVariableAndFill_Philox_Arg {
@ -87,8 +93,6 @@ struct UpdateVariableAndFill_Philox_Arg {
using GPUDevice = Eigen::GpuDevice;
namespace functor {
// Declares the partially GPU-specialized functor structs.
// must be kept at <=6 arguments because of a gcc/clang ABI incompatibility bug
template <typename Distribution>
@ -100,12 +104,9 @@ struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
template <>
struct RngSkip_Philox<GPUDevice> {
void operator()(const GPUDevice& device, const StateElementType* in_data,
uint64 delta, StateElementType* out_data);
void operator()(const GPUDevice& device, int64 delta, Tensor* state_tensor);
};
} // end namespace functor
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // end namespace tensorflow


@ -31,8 +31,6 @@ __device__ int tensorflow_philox_thread_counter;
namespace tensorflow {
namespace functor {
using random::PhiloxRandom;
template <typename Distribution>
@ -50,8 +48,7 @@ __global__ void FillKernel(
__syncthreads();
functor::FillPhiloxRandomKernel<Distribution,
Distribution::kVariableSamplesPerOutput>()
.Run(/*key=*/nullptr, /*counter=*/nullptr, *philox, output_data,
output_size, dist);
.Run(*philox, output_data, output_size, dist);
// The last thread updates the state.
auto total_thread_count = gridDim.x * blockDim.x;
auto old_counter_value = atomicAdd(&tensorflow_philox_thread_counter, 1);
@ -99,19 +96,16 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
}
// Precondition: there is only 1 block and 1 thread.
__global__ void SkipKernel(const StateElementType* __restrict__ in_data,
uint64 delta,
StateElementType* __restrict__ out_data) {
auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
__global__ void SkipKernel(int64 delta,
StateElementType* __restrict__ state_data) {
auto philox = GetPhiloxRandomFromMem(state_data);
UpdateMemWithPhiloxRandom(philox, delta, state_data);
}
void RngSkip_Philox<GPUDevice>::operator()(const GPUDevice& d,
const StateElementType* in_data,
uint64 delta,
StateElementType* out_data) {
TF_CHECK_OK(GpuLaunchKernel(SkipKernel, 1, 1, 0, d.stream(), in_data, delta,
out_data));
void RngSkip_Philox<GPUDevice>::operator()(const GPUDevice& d, int64 delta,
Tensor* state_tensor) {
TF_CHECK_OK(GpuLaunchKernel(SkipKernel, 1, 1, 0, d.stream(), delta,
state_tensor->flat<StateElementType>().data()));
}
// Explicit instantiation of the GPU distributions functors.
@ -160,7 +154,6 @@ template struct UpdateVariableAndFill_Philox<
random::PhiloxRandom, uint64> >;
// clang-format on
} // end namespace functor
} // end namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -121,8 +121,8 @@ class StatelessRandomOp : public StatelessRandomOpBase {
auto flat = output->flat<T>();
// Reuse the compute kernels from the stateful random ops
functor::FillPhiloxRandom<Device, Distribution>()(
context, context->eigen_device<Device>(), /*key=*/nullptr,
/*counter=*/nullptr, random, flat.data(), flat.size(), Distribution());
context, context->eigen_device<Device>(), random, flat.data(),
flat.size(), Distribution());
}
};
@ -158,8 +158,8 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
auto flat = output->flat<IntType>();
// Reuse the compute kernels from the stateful random ops
functor::FillPhiloxRandom<Device, Distribution>()(
context, context->eigen_device<Device>(), /*key=*/nullptr,
/*counter=*/nullptr, random, flat.data(), flat.size(), dist);
context, context->eigen_device<Device>(), random, flat.data(),
flat.size(), dist);
}
};
@ -178,8 +178,8 @@ class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
auto flat = output->flat<IntType>();
// Reuse the compute kernels from the stateful random ops
functor::FillPhiloxRandom<Device, Distribution>()(
context, context->eigen_device<Device>(), /*key=*/nullptr,
/*counter=*/nullptr, random, flat.data(), flat.size(), dist);
context, context->eigen_device<Device>(), random, flat.data(),
flat.size(), dist);
}
};


@ -1,330 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/stateless_random_ops_v2.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/random_poisson_op.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
_Pragma("GCC diagnostic push") \
_Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
namespace {
template <typename T>
Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
auto dtype = DataTypeToEnum<T>::v();
if (tensor.dims() != 0) {
return errors::InvalidArgument("input ", std::to_string(input_idx),
" (0-based) must have shape [], not ",
tensor.shape().DebugString());
}
if (tensor.dtype() != dtype) {
return errors::InvalidArgument("dtype of input ", std::to_string(input_idx),
" (0-based) must be ", DataTypeString(dtype),
", not ", DataTypeString(tensor.dtype()));
}
*result = tensor.flat<T>()(0);
return Status::OK();
}
class StatelessRandomOpBase : public OpKernel {
public:
explicit StatelessRandomOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// Sanitize input
const Tensor& shape_t = ctx->input(0);
const Tensor& key_t = ctx->input(1);
const Tensor& counter_t = ctx->input(2);
const int alg_input_idx = 3;
const Tensor& alg_t = ctx->input(alg_input_idx);
int alg_id;
OP_REQUIRES_OK(ctx, GetScalar(alg_t, alg_input_idx, &alg_id));
Algorithm alg = Algorithm(alg_id);
TensorShape shape;
OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
OP_REQUIRES_OK(ctx,
CheckKeyCounterShape(alg, key_t.shape(), counter_t.shape()));
// Allocate output
Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
if (shape.num_elements() == 0) {
return;
}
// Fill in the random numbers
Fill(ctx, alg, key_t, counter_t, output);
}
// The part of Compute that depends on device, type, and distribution
virtual void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
const Tensor& counter, Tensor* output) = 0;
};
template <typename Device, typename Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;
void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
const Tensor& counter, Tensor* output) override {
typedef typename Distribution::ResultElementType T;
auto flat = output->flat<T>();
if (alg == RNG_ALG_PHILOX) {
// Reuse the compute kernels from the stateful random ops
auto key_data = key.flat<uint64>().data();
auto counter_data = counter.flat<uint64>().data();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(),
Distribution());
} else {
OP_REQUIRES(ctx, false,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
}
}
};
template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;
void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
const Tensor& counter, Tensor* output) override {
const Tensor& minval = ctx->input(4);
const Tensor& maxval = ctx->input(5);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
errors::InvalidArgument("minval must be 0-D, got shape ",
minval.shape().DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));
// Verify that minval < maxval. Note that we'll never reach this point for
// empty output. Zero impossible things are fine.
const auto lo = minval.scalar<IntType>()();
const auto hi = maxval.scalar<IntType>()();
OP_REQUIRES(
ctx, lo < hi,
errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
// Build distribution
typedef random::UniformDistribution<random::PhiloxRandom, IntType>
Distribution;
Distribution dist(lo, hi);
auto flat = output->flat<IntType>();
if (alg == RNG_ALG_PHILOX) {
// Reuse the compute kernels from the stateful random ops
auto key_data = key.flat<uint64>().data();
auto counter_data = counter.flat<uint64>().data();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
} else {
OP_REQUIRES(ctx, false,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
}
}
};
template <typename Device, typename IntType>
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
public:
using StatelessRandomOpBase::StatelessRandomOpBase;
void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
const Tensor& counter, Tensor* output) override {
// Build distribution
typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
Distribution;
Distribution dist;
auto flat = output->flat<IntType>();
if (alg == RNG_ALG_PHILOX) {
// Reuse the compute kernels from the stateful random ops
auto key_data = key.flat<uint64>().data();
auto counter_data = counter.flat<uint64>().data();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(), key_data, counter_data,
random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
} else {
OP_REQUIRES(ctx, false,
errors::InvalidArgument("Unsupported algorithm id: ", alg));
}
}
};
class GetKeyCounterAlgOp : public OpKernel {
public:
explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
const Tensor& seed_t = ctx->input(0);
OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_t.shape().DebugString()));
// Allocate outputs
Tensor* key_output;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output));
Tensor* counter_output;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}),
&counter_output));
Tensor* alg_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({}), &alg_output));
random::PhiloxRandom::Key key;
random::PhiloxRandom::ResultType counter;
OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
WriteKeyToMem(key, key_output->flat<uint64>().data());
WriteCounterToMem(counter, counter_output->flat<uint64>().data());
alg_output->flat<int>()(0) = RNG_ALG_PHILOX;
}
};
#define REGISTER(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniformV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
random::PhiloxRandom, TYPE> >); \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomNormalV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
random::PhiloxRandom, TYPE> >); \
REGISTER_KERNEL_BUILDER( \
Name("StatelessTruncatedNormalV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomOp< \
DEVICE##Device, \
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
#define REGISTER_FULL_INT(DEVICE, TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("StatelessRandomUniformFullIntV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)
#define REGISTER_INT(DEVICE, TYPE) \
REGISTER_FULL_INT(DEVICE, TYPE); \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformIntV2") \
.Device(DEVICE_##DEVICE) \
.HostMemory("shape") \
.HostMemory("alg") \
.HostMemory("minval") \
.HostMemory("maxval") \
.TypeConstraint<TYPE>("dtype"), \
StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)
#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_INT_CPU);
TF_CALL_int64(REGISTER_INT_CPU);
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
TF_CALL_uint64(REGISTER_FULL_INT_CPU);
#define REGISTER_GET_KCA(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
.Device(DEVICE_##DEVICE) \
.HostMemory("seed") \
.HostMemory("key") \
.HostMemory("counter") \
.HostMemory("alg"), \
GetKeyCounterAlgOp)
REGISTER_GET_KCA(CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_half(REGISTER_GPU);
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_int32(REGISTER_INT_GPU);
TF_CALL_int64(REGISTER_INT_GPU);
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
TF_CALL_uint64(REGISTER_FULL_INT_GPU);
REGISTER_GET_KCA(GPU);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_CPU
#undef REGISTER_GPU
#undef REGISTER_INT_CPU
#undef REGISTER_INT_GPU
#undef REGISTER_FULL_INT_CPU
#undef REGISTER_FULL_INT_GPU
#undef REGISTER_GET_KCA
} // namespace
} // namespace tensorflow


@ -1,46 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_
#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
inline Status CheckKeyCounterShape(Algorithm const& alg,
TensorShape const& key_shape,
TensorShape const& counter_shape) {
if (!(key_shape.dims() == 1 && key_shape.dim_size(0) == RNG_KEY_SIZE)) {
return errors::InvalidArgument(
"key must have shape [", RNG_KEY_SIZE, "], not ",
key_shape.DebugString(),
". (Note that batched keys are not supported yet.)");
}
auto counter_size = GetCounterSize(alg);
if (!(counter_shape.dims() == 1 &&
counter_shape.dim_size(0) >= counter_size)) {
return errors::InvalidArgument(
"counter must be a vector with length at least ", counter_size,
"; got shape: ", counter_shape.DebugString(),
". (Note that batched counters are not supported yet.)");
}
return Status::OK();
}
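Concretely, assuming RNG_KEY_SIZE == 1 and a Philox counter size of 2 (both vectors of 64-bit words), the contract enforced above reduces to a check like this hypothetical Python sketch:
def check_key_counter_shape(key_shape, counter_shape, counter_size=2):
  # The key must be exactly [RNG_KEY_SIZE]; batched keys are unsupported.
  if tuple(key_shape) != (1,):
    raise ValueError("key must have shape [1], not %s" % (key_shape,))
  # The counter must be a vector at least as long as the algorithm needs.
  if len(counter_shape) != 1 or counter_shape[0] < counter_size:
    raise ValueError("counter must be a vector with length at least %d; "
                     "got shape: %s" % (counter_size, counter_shape))
check_key_counter_shape(key_shape=(1,), counter_shape=(2,))  # Philox: OK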
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_


@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
@ -91,19 +90,6 @@ REGISTER_OP("RngSkip")
return Status::OK();
});
REGISTER_OP("RngReadAndSkip")
.Input("resource: resource")
.Input("alg: int32")
.Input("delta: uint64")
.Output("value: int64")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
c->set_output(0, c->MakeShape({RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE}));
return Status::OK();
});
REGISTER_OP("NonDeterministicInts")
.Input("shape: shape_dtype")
.SetIsStateful()


@ -1,119 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rng_alg.h"
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
static Status StatelessShapeV2(InferenceContext* c) {
// Check key and counter shapes
ShapeHandle key;
ShapeHandle counter;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &key));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &counter));
shape_inference::ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), RNG_KEY_SIZE, &unused));
// Set output shape
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
c->set_output(0, out);
return Status::OK();
}
#define REGISTER_STATELESS_OP(name) \
REGISTER_OP(name) \
.Input("shape: Tshape") \
.Input("key: uint64") \
.Input("counter: uint64") \
.Input("alg: int32") \
.Output("output: dtype") \
.Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
.Attr("Tshape: {int32, int64} = DT_INT32") \
.SetShapeFn(StatelessShapeV2)
REGISTER_STATELESS_OP("StatelessRandomUniformV2");
REGISTER_STATELESS_OP("StatelessRandomNormalV2");
REGISTER_STATELESS_OP("StatelessTruncatedNormalV2");
#undef REGISTER_STATELESS_OP
REGISTER_OP("StatelessRandomUniformIntV2")
.Input("shape: Tshape")
.Input("key: uint64")
.Input("counter: uint64")
.Input("alg: int32")
.Input("minval: dtype")
.Input("maxval: dtype")
.Output("output: dtype")
.Attr("dtype: {int32, int64, uint32, uint64}")
.Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
Status s = c->WithRank(c->input(4), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"minval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(4)));
}
s = c->WithRank(c->input(5), 0, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"maxval must be a scalar; got a tensor of shape ",
c->DebugString(c->input(5)));
}
return StatelessShapeV2(c);
});
REGISTER_OP("StatelessRandomUniformFullIntV2")
.Input("shape: Tshape")
.Input("key: uint64")
.Input("counter: uint64")
.Input("alg: int32")
.Output("output: dtype")
.Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
.Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn(StatelessShapeV2);
REGISTER_OP("StatelessRandomGetKeyCounterAlg")
.Input("seed: Tseed")
.Output("key: uint64")
.Output("counter: uint64")
.Output("alg: int32")
.Attr("Tseed: {int32, int64} = DT_INT64")
.SetIsStateful() // because outputs depend on device
.SetShapeFn([](InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
// Set output shapes
c->set_output(0, c->MakeShape({RNG_KEY_SIZE}));
c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE}));
c->set_output(2, c->MakeShape({}));
return Status::OK();
});
} // namespace tensorflow


@ -3183,10 +3183,6 @@ tf_gen_op_wrapper_private_py(
],
)
tf_gen_op_wrapper_private_py(
name = "stateless_random_ops_v2_gen",
)
tf_gen_op_wrapper_private_py(
name = "list_ops_gen",
)
@ -4444,7 +4440,6 @@ py_library(
":framework_ops",
":math_ops",
":stateful_random_ops_gen",
":stateless_random_ops_v2_gen",
":variables",
"//third_party/py/numpy",
],
@ -4479,7 +4474,6 @@ py_library(
":math_ops",
":random_ops",
":stateless_random_ops_gen",
":stateless_random_ops_v2_gen",
],
)


@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 357> a = {{
static std::array<OpIndexInfo, 352> a = {{
{"Acosh"},
{"AllToAll", 1, {0}},
{"ApproximateEqual"},
@ -332,16 +332,11 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
{"StatelessRandomBinomial"},
{"StatelessRandomGammaV2", 1, {1}},
{"StatelessRandomNormal"},
{"StatelessRandomNormalV2"},
{"StatelessRandomPoisson"},
{"StatelessRandomUniform"},
{"StatelessRandomUniformFullInt"},
{"StatelessRandomUniformFullIntV2"},
{"StatelessRandomUniformInt"},
{"StatelessRandomUniformIntV2"},
{"StatelessRandomUniformV2"},
{"StatelessTruncatedNormal"},
{"StatelessTruncatedNormalV2"},
{"StopGradient"},
{"StridedSliceGrad", 2, {0, 4}},
{"StringSplit"},
@ -420,7 +415,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 473> a = {{
static std::array<OpIndexInfo, 468> a = {{
{"Abs"},
{"AccumulateNV2"},
{"Acos"},
@ -798,16 +793,11 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
{"StatelessMultinomial"},
{"StatelessRandomBinomial"},
{"StatelessRandomNormal"},
{"StatelessRandomNormalV2"},
{"StatelessRandomPoisson"},
{"StatelessRandomUniform"},
{"StatelessRandomUniformFullInt"},
{"StatelessRandomUniformFullIntV2"},
{"StatelessRandomUniformInt"},
{"StatelessRandomUniformIntV2"},
{"StatelessRandomUniformV2"},
{"StatelessTruncatedNormal"},
{"StatelessTruncatedNormalV2"},
{"StopGradient"},
{"StridedSlice"},
{"StridedSliceGrad"},


@ -30,7 +30,6 @@ from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import image_ops_impl as image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@ -1238,14 +1237,11 @@ class RandomHeightTest(keras_parameterized.TestCase):
mock_factor = 0
with test.mock.patch.object(
gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
with test.mock.patch.object(
gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
return_value=mock_factor):
with testing_utils.use_gpu():
img = np.random.random((12, 5, 8, 3))
layer = image_preprocessing.RandomHeight(.4)
img_out = layer(img, training=True)
self.assertEqual(img_out.shape[1], 3)
with testing_utils.use_gpu():
img = np.random.random((12, 5, 8, 3))
layer = image_preprocessing.RandomHeight(.4)
img_out = layer(img, training=True)
self.assertEqual(img_out.shape[1], 3)
def test_random_height_longer_numeric(self):
for dtype in (np.int64, np.float32):
@ -1332,14 +1328,11 @@ class RandomWidthTest(keras_parameterized.TestCase):
mock_factor = 0
with test.mock.patch.object(
gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
with test.mock.patch.object(
gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
return_value=mock_factor):
with testing_utils.use_gpu():
img = np.random.random((12, 8, 5, 3))
layer = image_preprocessing.RandomWidth(.4)
img_out = layer(img, training=True)
self.assertEqual(img_out.shape[2], 3)
with testing_utils.use_gpu():
img = np.random.random((12, 8, 5, 3))
layer = image_preprocessing.RandomWidth(.4)
img_out = layer(img, training=True)
self.assertEqual(img_out.shape[2], 3)
def test_random_width_longer_numeric(self):
for dtype in (np.int64, np.float32):


@ -22,13 +22,10 @@ import functools
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@ -38,31 +35,6 @@ from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.platform import test
# Note that in theory each test will reset the eager context and may choose to
# hide some devices, so we shouldn't cache this transient info. Tests in this
# file don't make those config changes, so caching is fine. It provides a good
# speed-up.
_cached_device = None
def get_device():
global _cached_device
if _cached_device is not None:
return _cached_device
# Precedence from high to low
for device_type in ('XLA_GPU', 'GPU', 'XLA_CPU', 'CPU'):
devices = config.list_logical_devices(device_type)
if devices:
_cached_device = devices[0]
return _cached_device
raise ValueError('Cannot find any suitable device. Available devices: %s' %
config.list_logical_devices())
BEFORE_EXPIRE = (2020, 10, 24)
AFTER_EXPIRE = (2020, 10, 26)
def invert_philox(key, value):
"""Invert the Philox bijection."""
key = np.array(key, dtype=np.uint32)
@ -87,71 +59,47 @@ SEED_TYPES = [dtypes.int32, dtypes.int64]
def float_cases(shape_dtypes=(None,)):
cases = (
# Uniform distribution, with and without range
('uniform', stateless.stateless_random_uniform, random_ops.random_uniform,
{}),
('uniform2', stateless.stateless_random_uniform,
random_ops.random_uniform, dict(minval=2.2, maxval=7.1)),
(stateless.stateless_random_uniform, random_ops.random_uniform, {}),
(stateless.stateless_random_uniform, random_ops.random_uniform,
dict(minval=2.2, maxval=7.1)),
# Normal distribution, with and without mean+stddev
('normal', stateless.stateless_random_normal, random_ops.random_normal,
{}),
('normal2', stateless.stateless_random_normal, random_ops.random_normal,
(stateless.stateless_random_normal, random_ops.random_normal, {}),
(stateless.stateless_random_normal, random_ops.random_normal,
dict(mean=2, stddev=3)),
# Truncated normal distribution, with and without mean+stddev
('trnorm', stateless.stateless_truncated_normal,
random_ops.truncated_normal, {}),
('trnorm2', stateless.stateless_truncated_normal,
random_ops.truncated_normal, dict(mean=3, stddev=4)),
(stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
(stateless.stateless_truncated_normal, random_ops.truncated_normal,
dict(mean=3, stddev=4)),
)
# Explicitly passing in params because capturing cell variable from loop is
# problematic in Python
def wrap(op, dtype, shape, shape_dtype, kwds, seed):
device_type = get_device().device_type
# Some dtypes are not supported on some devices
if (dtype == dtypes.float16 and device_type in ('XLA_GPU', 'XLA_CPU') or
dtype == dtypes.bfloat16 and device_type == 'GPU'):
dtype = dtypes.float32
shape_ = (constant_op.constant(shape, dtype=shape_dtype)
if shape_dtype is not None else shape)
return op(seed=seed, shape=shape_, dtype=dtype, **kwds)
def _name(a):
if hasattr(a, 'name'):
return a.name
else:
return a
for dtype in dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64:
for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
for shape_dtype in shape_dtypes:
for shape in (), (3,), (2, 5):
for name, stateless_op, stateful_op, kwds in cases:
yield (('%s_%s_%s_%s' %
(name, _name(dtype), shape, _name(shape_dtype))).replace(
' ', ''),
functools.partial(wrap, stateless_op, dtype, shape,
for stateless_op, stateful_op, kwds in cases:
yield (functools.partial(wrap, stateless_op, dtype, shape,
shape_dtype, kwds),
functools.partial(wrap, stateful_op, dtype, shape, shape_dtype,
kwds))
functools.partial(wrap, stateful_op, dtype, shape,
shape_dtype, kwds))
def int_cases(shape_dtypes=(None,), minval_maxval=None):
def wrap(op, minval, maxval, shape, shape_dtype, dtype, seed):
def int_cases(shape_dtypes=(None,)):
def wrap(op, shape, shape_dtype, dtype, seed):
shape_ = (constant_op.constant(shape, dtype=shape_dtype)
if shape_dtype is not None else shape)
return op(
seed=seed, shape=shape_, minval=minval, maxval=maxval, dtype=dtype)
if minval_maxval is None:
minval_maxval = ((2, 11111),)
for minval, maxval in minval_maxval:
for shape_dtype in shape_dtypes:
for shape in (), (3,), (2, 5):
for dtype in dtypes.int32, dtypes.int64:
yield ('uniform_%s_%s' % (minval, maxval),
functools.partial(wrap, stateless.stateless_random_uniform,
minval, maxval, shape, shape_dtype, dtype),
functools.partial(wrap, random_ops.random_uniform, minval,
maxval, shape, shape_dtype, dtype))
return op(seed=seed, shape=shape_, minval=2, maxval=11111,
dtype=dtype)
for shape_dtype in shape_dtypes:
for shape in (), (3,), (2, 5):
for dtype in dtypes.int32, dtypes.int64:
yield (functools.partial(wrap, stateless.stateless_random_uniform,
shape, shape_dtype, dtype),
functools.partial(wrap, random_ops.random_uniform,
shape, shape_dtype, dtype))
def multinomial_cases():
@ -164,8 +112,7 @@ def multinomial_cases():
for output_dtype in dtypes.int32, dtypes.int64:
for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
[0.25, 0.75]]):
yield ('multinomial',
functools.partial(wrap, stateless.stateless_multinomial, logits,
yield (functools.partial(wrap, stateless.stateless_multinomial, logits,
logits_dtype, output_dtype),
functools.partial(wrap, random_ops.multinomial, logits,
logits_dtype, output_dtype))
@ -177,11 +124,10 @@ def gamma_cases():
alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
for dtype in np.float16, np.float32, np.float64:
for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
yield ('gamma',
functools.partial(wrap, stateless.stateless_random_gamma, alpha,
yield (functools.partial(wrap, stateless.stateless_random_gamma, alpha,
dtype, (10,) + tuple(np.shape(alpha))),
functools.partial(wrap, random_ops.random_gamma, alpha, dtype,
(10,)))
functools.partial(wrap, random_ops.random_gamma, alpha,
dtype, (10,)))
def poisson_cases():
@ -192,8 +138,7 @@ def poisson_cases():
for lam_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
for out_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
for lam in ([[5.5, 1., 2.]], [[7.5, 10.5], [3.8, 8.2], [1.25, 9.75]]):
yield ('poisson',
functools.partial(wrap, stateless.stateless_random_poisson, lam,
yield (functools.partial(wrap, stateless.stateless_random_poisson, lam,
lam_dtype, out_dtype,
(10,) + tuple(np.shape(lam))),
functools.partial(wrap, random_ops.random_poisson, lam,
@ -208,28 +153,22 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
key = 0x3ec8f720, 0x02461e29
preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
preseed = preseed[::2] | preseed[1::2] << 32
with ops.device(get_device().name):
_, stateless_op, stateful_op = case
random_seed.set_random_seed(seed[0])
random_seed.set_random_seed(seed[0])
with test_util.use_gpu():
stateless_op, stateful_op = case
if context.executing_eagerly():
# Call set_random_seed in order to clear kernel cache, to prevent
# kernel reuse for the stateful op
random_seed.set_random_seed(seed[0])
stateful = stateful_op(seed=seed[1])
pure = stateless_op(seed=preseed)
self.assertAllEqual(stateful, pure)
def _test_old_and_new_stateless_match(self, case, seed):
"""Tests that the new stateless ops match the old stateless ones."""
with ops.device(get_device().name):
_, stateless_op, _ = case
with compat.forward_compatibility_horizon(*BEFORE_EXPIRE):
old = stateless_op(seed=seed)
with compat.forward_compatibility_horizon(*AFTER_EXPIRE):
new = stateless_op(seed=seed)
self.assertAllClose(old, new)
def _test_determinism(self, case, seed_type):
# Stateless values should be equal iff the seeds are equal (roughly)
seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension
with self.test_session(use_gpu=True), ops.device(get_device().name):
_, stateless_op, _ = case
with self.test_session(use_gpu=True), test_util.use_gpu():
stateless_op, _ = case
if context.executing_eagerly():
values = [
(seed, stateless_op(seed=constant_op.constant(seed, seed_type)))
@ -244,172 +183,88 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
]
for s0, v0 in values:
for s1, v1 in values:
if dtypes.as_dtype(v0.dtype) != dtypes.bfloat16:
self.assertEqual(s0 == s1, np.all(v0 == v1))
elif s0 == s1:
# Skip the s0 != s1 case because v0 and v1 can be either equal or
# unequal in that case due to bfloat16's low precision
self.assertAllEqual(v0, v1)
self.assertEqual(s0 == s1, np.all(v0 == v1))
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
for case_id, case in enumerate(float_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testMatchFloat(self, case, seed):
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Skip on XLA because XLA kernels do not support int64 '
'seeds needed by this test.')
self._test_match(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
for case_id, case in enumerate(int_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testMatchInt(self, case, seed):
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Skip on XLA because XLA kernels do not support int64 '
'seeds needed by this test.')
self._test_match(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
for case_id, case in enumerate(multinomial_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testMatchMultinomial(self, case, seed):
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking XLA kernel')
self._test_match(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
for case_id, case in enumerate(gamma_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testMatchGamma(self, case, seed):
if get_device().device_type == 'GPU':
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking GPU kernel')
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking XLA kernel')
self._test_match(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
('_%s_%s' % (case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
for case_id, case in enumerate(poisson_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testMatchPoisson(self, case, seed):
if get_device().device_type == 'GPU':
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking GPU kernel')
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking XLA kernel')
self._test_match(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
for case_id, case in enumerate(float_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testOldAndNewStatelessMatchFloat(self, case, seed):
self._test_old_and_new_stateless_match(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed) # pylint: disable=g-complex-comprehension
for seed_id, seed in enumerate(SEEDS)
for case_id, case in enumerate(
int_cases(minval_maxval=((2, 11111), (None, None)))))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testOldAndNewStatelessMatchInt(self, case, seed):
self._test_old_and_new_stateless_match(case, seed)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
for seed_type in SEED_TYPES
for case_id, case in enumerate(
float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
for type_id, seed_type in enumerate(SEED_TYPES)
for case_id, case in enumerate(float_cases(
shape_dtypes=(dtypes.int32, dtypes.int64))))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testDeterminismFloat(self, case, seed_type):
if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest(
'Skip on XLA because XLA kernels do not support int64 seeds.')
self._test_determinism(case, seed_type)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
for seed_type in SEED_TYPES
for case_id, case in enumerate(
int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
for type_id, seed_type in enumerate(SEED_TYPES)
for case_id, case in enumerate(int_cases(
shape_dtypes=(dtypes.int32, dtypes.int64))))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testDeterminismInt(self, case, seed_type):
if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest(
'Skip on XLA because XLA kernels do not support int64 seeds.')
self._test_determinism(case, seed_type)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
for seed_type in SEED_TYPES
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
for type_id, seed_type in enumerate(SEED_TYPES)
for case_id, case in enumerate(multinomial_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testDeterminismMultinomial(self, case, seed_type):
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking XLA kernel')
self._test_determinism(case, seed_type)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
for seed_type in SEED_TYPES
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
for type_id, seed_type in enumerate(SEED_TYPES)
for case_id, case in enumerate(gamma_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testDeterminismGamma(self, case, seed_type):
if get_device().device_type == 'GPU':
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking GPU kernel')
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking XLA kernel')
self._test_determinism(case, seed_type)
@parameterized.named_parameters(
('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type) # pylint: disable=g-complex-comprehension
for seed_type in SEED_TYPES
('_%s_%s' % (case_id, type_id), case, seed_type) # pylint: disable=g-complex-comprehension
for type_id, seed_type in enumerate(SEED_TYPES)
for case_id, case in enumerate(poisson_cases()))
@test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
def testDeterminismPoisson(self, case, seed_type):
if get_device().device_type == 'GPU':
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking GPU kernel')
if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
# This test was passing before because soft placement silently picked the
# CPU kernels.
self.skipTest('Lacking XLA kernel')
self._test_determinism(case, seed_type)
def assertDTypeEqual(self, a, b):
@ -472,6 +327,4 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
if __name__ == '__main__':
config.set_soft_device_placement(False)
context.context().enable_xla_devices()
test.main()


@ -23,7 +23,6 @@ import enum # pylint: disable=g-bad-import-order
import numpy as np
import six
from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
@ -33,14 +32,12 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util.tf_export import tf_export
# A seed for random ops (stateful and stateless) will always be 1024
# bits, all of which will be sent to the C++ code. The actual C++
# implementation of some algorithms may only use a lower part of the bits.
@ -139,15 +136,6 @@ def _make_1d_state(state_size, seed):
return seed
def _get_counter_size(alg):
if alg == RNG_ALG_PHILOX:
return 2
elif alg == RNG_ALG_THREEFRY:
return 1
else:
raise ValueError("Unsupported algorithm id: %s" % alg)
def _get_state_size(alg):
if alg == RNG_ALG_PHILOX:
return PHILOX_STATE_SIZE
@ -572,10 +560,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
return self._alg
def _standard_normal(self, shape, dtype):
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
return gen_stateful_random_ops.stateful_standard_normal_v2(
self.state.handle, self.algorithm, shape, dtype=dtype)
@ -602,8 +586,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
else:
raise ValueError("Unsupported algorithm id: %s" % alg)
# TODO(wangpeng): Add "Returns" section to docstring once new version kicks in
# pylint: disable=g-doc-return-or-yield
def skip(self, delta):
"""Advance the counter of a counter-based RNG.
@ -613,24 +595,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
(or any other distribution). The actual increment added to the
counter is an unspecified implementation detail.
"""
if compat.forward_compatible(2020, 10, 25):
return gen_stateful_random_ops.rng_read_and_skip(
self.state.handle,
alg=math_ops.cast(self.algorithm, dtypes.int32),
delta=math_ops.cast(delta, dtypes.uint64))
gen_stateful_random_ops.rng_skip(
self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
math_ops.cast(delta, dtypes.int64))
# pylint: enable=g-doc-return-or-yield
def _prepare_key_counter(self, shape):
delta = math_ops.reduce_prod(shape)
counter_key = self.skip(delta)
counter_size = _get_counter_size(self.algorithm)
counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
dtypes.uint64)
return key, counter
gen_stateful_random_ops.rng_skip(self.state.handle, self.algorithm, delta)
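A short usage sketch of the contract in the docstring above (values are hypothetical): after skip(n), a generator is in the same state as one that drew n samples, without the cost of generating them.
import tensorflow as tf
g1 = tf.random.experimental.Generator.from_seed(1234)
g2 = tf.random.experimental.Generator.from_seed(1234)
g1.normal(shape=[10])  # consumes a batch of 10 samples
g2.skip(10)            # advances the counter without sampling
assert (g1.state.numpy() == g2.state.numpy()).all()  # states now match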
# The following functions return a tensor and as a side effect update
# self._state_var.
@ -659,14 +624,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
return math_ops.add(rnd * stddev, mean, name=name)
def _truncated_normal(self, shape, dtype):
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm)
return gen_stateful_random_ops.stateful_truncated_normal(
self.state.handle, self.algorithm, shape, dtype=dtype)
@ -705,27 +662,10 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
return math_ops.add(mul, mean_tensor, name=name)
def _uniform(self, shape, dtype):
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm)
return gen_stateful_random_ops.stateful_uniform(
self.state.handle, self.algorithm, shape=shape, dtype=dtype)
def _uniform_full_int(self, shape, dtype, name=None):
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm,
name=name)
return gen_stateful_random_ops.stateful_uniform_full_int(
self.state.handle, self.algorithm, shape=shape,
dtype=dtype, name=name)
@ -789,16 +729,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
if dtype.is_integer:
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
shape=shape,
key=key,
counter=counter,
minval=minval,
maxval=maxval,
alg=self.algorithm,
name=name)
return gen_stateful_random_ops.stateful_uniform_int(
self.state.handle, self.algorithm, shape=shape,
minval=minval, maxval=maxval, name=name)


@ -274,7 +274,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
with self.cached_session() as sess:
gen1 = random.Generator.from_seed(seed)
gen2 = random.Generator.from_non_deterministic_state()
sess.run((gen1.state.initializer, gen2.state.initializer))
sess.run((gen1._state_var.initializer, gen2._state_var.initializer))
r1 = gen1.normal(shape, dtype=dtypes.float32)
r2 = gen2.normal(shape, dtype=dtypes.float32)
def f():
@ -372,7 +372,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
delta = 432
gen.skip(delta)
new_counter = gen.state[0]
new_counter = gen._state_var[0]
self.assertAllEqual(counter + delta * 256, new_counter)
def _sameAsOldRandomOps(self, device, floats):
@ -394,7 +394,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
with ops.device(device):
return new(dtype, gen)
for _ in range(5):
for _ in range(100):
self.assertAllEqual(run_old(), run_new())
shape = constant_op.constant([4, 7])
@ -582,11 +582,6 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
@test_util.run_v2_only
def testGetGlobalGeneratorWithXla(self):
"""Demonstrates using the global generator with XLA."""
# This test was passing before because soft placement silently picked the
# CPU kernel.
# TODO(wangpeng): Remove this skip
self.skipTest("NonDeterministicInts lacks XLA kernel.")
if not config.list_physical_devices("XLA_CPU"):
self.skipTest("No XLA_CPU device available.")
@ -680,16 +675,17 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int64))
@test_util.run_v2_only
def testCreateOutsideMirroredStrat(self):
@test_util.run_cuda_only
def testMirroredStratSeq(self):
"""Tests RNG/MirrorStrategy interaction #1.
If an RNG is created outside a DS scope, all replicas will access the
If an RNG is created outside strategy.scope(), all replicas will access the
same RNG object, and accesses are serialized.
"""
shape = [3, 4]
dtype = dtypes.int32
gen = random.Generator.from_seed(1234)
strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
with strat.scope():
def f():
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
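A hedged sketch of the behavior the docstring describes; the device list is illustrative (two logical CPUs must be configured, as TF tests do):

import tensorflow as tf

strat = tf.distribute.MirroredStrategy(devices=["/cpu:0", "/cpu:1"])
gen = tf.random.Generator.from_seed(1234)  # created outside strat.scope()

def f():
  # Every replica reads and advances the same state variable, so
  # accesses are serialized rather than replica-independent.
  return gen.uniform_full_int(shape=[3, 4], dtype=tf.int32)

results = strat.run(f)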
@ -766,5 +762,4 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
if __name__ == "__main__":
config.set_soft_device_placement(False)
test.main()

View File

@ -20,19 +20,16 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable("StatelessMultinomial")
ops.NotDifferentiable("StatelessRandomBinomial")
ops.NotDifferentiable("StatelessRandomNormal")
@ -43,13 +40,6 @@ ops.NotDifferentiable("StatelessRandomUniformFullInt")
ops.NotDifferentiable("StatelessTruncatedNormal")
ops.NotDifferentiable("StatelessRandomNormalV2")
ops.NotDifferentiable("StatelessRandomUniformV2")
ops.NotDifferentiable("StatelessRandomUniformIntV2")
ops.NotDifferentiable("StatelessRandomUniformFullIntV2")
ops.NotDifferentiable("StatelessTruncatedNormalV2")
@tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2):
@ -123,10 +113,6 @@ def fold_in(seed, data):
return array_ops.stack([seed1, data])
_get_key_counter_alg = (gen_stateless_random_ops_v2
.stateless_random_get_key_counter_alg)
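The split/fold_in helpers above give callers explicit seed management for the stateless ops. A hedged usage sketch, assuming fold_in is exported next to split as tf.random.experimental.stateless_fold_in:

import tensorflow as tf

seed = [1, 2]
# Two statistically independent child seeds from one parent seed.
child_seeds = tf.random.experimental.stateless_split(seed, num=2)
# Deterministically mix auxiliary data (e.g. a step index) into a seed.
step_seed = tf.random.experimental.stateless_fold_in(seed, 7)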
@tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(shape,
@ -206,35 +192,17 @@ def stateless_random_uniform(shape,
[shape, seed, minval, maxval]) as name:
shape = tensor_util.shape_tensor(shape)
if dtype.is_integer and minval is None:
if compat.forward_compatible(2020, 10, 25):
key, counter, alg = _get_key_counter_alg(seed)
result = (gen_stateless_random_ops_v2
.stateless_random_uniform_full_int_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg,
name=name))
else:
result = gen_stateless_random_ops.stateless_random_uniform_full_int(
shape, seed=seed, dtype=dtype, name=name)
result = gen_stateless_random_ops.stateless_random_uniform_full_int(
shape, seed=seed, dtype=dtype, name=name)
else:
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
if dtype.is_integer:
if compat.forward_compatible(2020, 10, 25):
key, counter, alg = _get_key_counter_alg(seed)
result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
shape, key=key, counter=counter, minval=minval, maxval=maxval,
alg=alg, name=name)
else:
result = gen_stateless_random_ops.stateless_random_uniform_int(
shape, seed=seed, minval=minval, maxval=maxval, name=name)
result = gen_stateless_random_ops.stateless_random_uniform_int(
shape, seed=seed, minval=minval, maxval=maxval, name=name)
else:
if compat.forward_compatible(2020, 10, 25):
key, counter, alg = _get_key_counter_alg(seed)
rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg)
else:
rnd = gen_stateless_random_ops.stateless_random_uniform(
shape, seed=seed, dtype=dtype)
rnd = gen_stateless_random_ops.stateless_random_uniform(
shape, seed=seed, dtype=dtype)
result = math_ops.add(rnd * (maxval - minval), minval, name=name)
tensor_util.maybe_set_static_shape(result, shape)
return result
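stateless_random_uniform dispatches on dtype and bounds: full-range integers when minval is None, bounded integers otherwise, and scaled [0, 1) floats in the default case. A hedged sketch of the three public-API call shapes:

import tensorflow as tf

seed = [1, 2]
# Integer dtype, no bounds -> full-range integers.
full = tf.random.stateless_uniform(
    [4], seed=seed, minval=None, maxval=None, dtype=tf.uint64)
# Integer dtype with bounds -> bounded integers.
ints = tf.random.stateless_uniform(
    [4], seed=seed, minval=0, maxval=10, dtype=tf.int32)
# Float dtype -> [0, 1) samples, then rnd * (maxval - minval) + minval.
floats = tf.random.stateless_uniform([4], seed=seed)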
@ -508,12 +476,7 @@ def stateless_random_normal(shape,
shape = tensor_util.shape_tensor(shape)
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
if compat.forward_compatible(2020, 10, 25):
key, counter, alg = _get_key_counter_alg(seed)
rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg)
else:
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
result = math_ops.add(rnd * stddev, mean, name=name)
tensor_util.maybe_set_static_shape(result, shape)
return result
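As the last lines show, the kernel draws standard normals and mean/stddev are applied afterwards as an affine transform, so the same seed yields exactly related samples. A hedged check:

import tensorflow as tf

seed = [3, 4]
z = tf.random.stateless_normal([1000], seed=seed)  # ~ N(0, 1)
x = tf.random.stateless_normal([1000], seed=seed, mean=5.0, stddev=2.0)
# Same seed, same underlying draws: x == z * 2.0 + 5.0 elementwise
# (up to float rounding).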
@ -558,13 +521,8 @@ def stateless_truncated_normal(shape,
shape = tensor_util.shape_tensor(shape)
mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
if compat.forward_compatible(2020, 10, 25):
key, counter, alg = _get_key_counter_alg(seed)
rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
shape, key=key, counter=counter, dtype=dtype, alg=alg)
else:
rnd = gen_stateless_random_ops.stateless_truncated_normal(
shape, seed, dtype)
rnd = gen_stateless_random_ops.stateless_truncated_normal(
shape, seed, dtype)
result = math_ops.add(rnd * stddev, mean, name=name)
tensor_util.maybe_set_static_shape(result, shape)
return result
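The truncated variant applies the same affine transform to standard truncated normals, whose draws are confined to two standard deviations. A hedged bound check:

import tensorflow as tf

t = tf.random.stateless_truncated_normal(
    [1000], seed=[5, 6], mean=1.0, stddev=0.5)
# Samples lie within mean +/- 2 * stddev, i.e. [0.0, 2.0].
assert float(tf.reduce_min(t)) >= 0.0
assert float(tf.reduce_max(t)) <= 2.0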

View File

@ -3792,10 +3792,6 @@ tf_module {
name: "Rint"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RngReadAndSkip"
argspec: "args=[\'resource\', \'alg\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RngSkip"
argspec: "args=[\'resource\', \'algorithm\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4548,18 +4544,10 @@ tf_module {
name: "StatelessRandomGammaV2"
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomGetKeyCounterAlg"
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomNormal"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomNormalV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomPoisson"
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4572,22 +4560,10 @@ tf_module {
name: "StatelessRandomUniformFullInt"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomUniformFullIntV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomUniformInt"
argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomUniformIntV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomUniformV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessSampleDistortedBoundingBox"
argspec: "args=[\'image_size\', \'bounding_boxes\', \'min_object_covered\', \'seed\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'[0.75, 1.33]\', \'[0.05, 1]\', \'100\', \'False\', \'None\'], "
@ -4596,10 +4572,6 @@ tf_module {
name: "StatelessTruncatedNormal"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessTruncatedNormalV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessWhile"
argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "

View File

@ -3792,10 +3792,6 @@ tf_module {
name: "Rint"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RngReadAndSkip"
argspec: "args=[\'resource\', \'alg\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "RngSkip"
argspec: "args=[\'resource\', \'algorithm\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4548,18 +4544,10 @@ tf_module {
name: "StatelessRandomGammaV2"
argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomGetKeyCounterAlg"
argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomNormal"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomNormalV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomPoisson"
argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -4572,22 +4560,10 @@ tf_module {
name: "StatelessRandomUniformFullInt"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomUniformFullIntV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
}
member_method {
name: "StatelessRandomUniformInt"
argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomUniformIntV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "StatelessRandomUniformV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessSampleDistortedBoundingBox"
argspec: "args=[\'image_size\', \'bounding_boxes\', \'min_object_covered\', \'seed\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'[0.75, 1.33]\', \'[0.05, 1]\', \'100\', \'False\', \'None\'], "
@ -4596,10 +4572,6 @@ tf_module {
name: "StatelessTruncatedNormal"
argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessTruncatedNormalV2"
argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "StatelessWhile"
argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "