Internal change

PiperOrigin-RevId: 332379487
Change-Id: Ie43ff74a010bcc9893f6c358b2e56b27ef67fcac

Commit 6d605dde0a (parent ea83ee9a82)
@@ -1995,6 +1995,8 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
       "ResourceScatterNdUpdate",
       "ResourceScatterSub",
       "ResourceScatterUpdate",
+      "RngReadAndSkip",
+      "RngSkip",
       "Roll",
       "ScatterNd",
       "SelfAdjointEigV2",
@@ -2017,11 +2019,17 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
       "StatelessCase",
       "StatelessIf",
       "StatelessMultinomial",
+      "StatelessRandomGetKeyCounterAlg",
       "StatelessRandomNormal",
+      "StatelessRandomNormalV2",
       "StatelessRandomUniform",
+      "StatelessRandomUniformV2",
       "StatelessRandomUniformInt",
+      "StatelessRandomUniformIntV2",
       "StatelessRandomUniformFullInt",
+      "StatelessRandomUniformFullIntV2",
       "StatelessTruncatedNormal",
+      "StatelessTruncatedNormalV2",
       "StatelessWhile",
       "Svd",
       "SymbolicGradient",
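The two hunks above add the new RNG ops to GetKnownXLAAllowlistOp, so graphs that use them can be clustered and compiled by XLA. As a rough illustration (not part of this commit; `experimental_compile` is the TF 2.3/2.4-era spelling of what later became `jit_compile`, and the exact lowering is an assumption here), a compiled function driving a tf.random.Generator exercises these ops:

import tensorflow as tf

g = tf.random.Generator.from_seed(1234)

@tf.function(experimental_compile=True)  # force XLA compilation
def draw():
  # Inside XLA, the generator's counter update and sampling are expected
  # to lower to the newly allowlisted RNG ops.
  return g.uniform_full_int(shape=[4, 4], dtype=tf.uint64)

print(draw())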
@@ -25,7 +25,9 @@ import numpy as np

 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.client import device_lib
+from tensorflow.python.compat import compat
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
@@ -156,6 +158,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
   def testNewStateThreeFry(self):
     """Tests that the new state is correct (for ThreeFry).
     """
+    if compat.forward_compatible(2020, 10, 25):
+      self.skipTest("The expected values in this test is inconsistent with "
+                    "CPU/GPU. testXLAEqualsCPU has the correct checks of the "
+                    "new states for the new version.")
     with ops.device(xla_device_name()):
       counter = 57
       key = 0x1234
@@ -171,6 +177,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
   def testNewStatePhilox(self):
     """Tests that the new state is correct (for Philox).
     """
+    if compat.forward_compatible(2020, 10, 25):
+      self.skipTest("The expected values in this test is inconsistent with "
+                    "CPU/GPU. testXLAEqualsCPU has the correct checks of the "
+                    "new states for the new version.")
     with ops.device(xla_device_name()):
       counter_low = 57
       counter_high = 283
@@ -204,13 +214,39 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
     """Tests that XLA and CPU kernels generate the same integers."""
     seed = 1234
     shape = [315, 49]
-    with ops.device("/device:CPU:0"):
-      cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
-             .uniform_full_int(shape=shape, dtype=dtype))
-    with ops.device(xla_device_name()):
-      xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
-             .uniform_full_int(shape=shape, dtype=dtype))
-    self.assertAllEqual(cpu, xla)
+    if compat.forward_compatible(2020, 10, 25):
+      with ops.device("/device:CPU:0"):
+        cpu_gen = random.Generator.from_seed(
+            seed=seed, alg=random.RNG_ALG_PHILOX)
+      with ops.device(xla_device_name()):
+        xla_gen = random.Generator.from_seed(
+            seed=seed, alg=random.RNG_ALG_PHILOX)
+      # Repeat multiple times to make sure that the state after
+      # number-generation are the same between CPU and XLA.
+      for _ in range(5):
+        with ops.device("/device:CPU:0"):
+          # Test both number-generation and skip
+          cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
+          cpu_gen.skip(100)
+        with ops.device(xla_device_name()):
+          xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
+          xla_gen.skip(100)
+        self.assertAllEqual(cpu, xla)
+        self.assertAllEqual(cpu_gen.state, xla_gen.state)
+    else:
+      # The old version doesn't guarantee that CPU and XLA are in the same state
+      # after number-generation, which is a bug.
+      with ops.device("/device:CPU:0"):
+        cpu = (
+            random.Generator.from_seed(
+                seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
+                    shape=shape, dtype=dtype))
+      with ops.device(xla_device_name()):
+        xla = (
+            random.Generator.from_seed(
+                seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
+                    shape=shape, dtype=dtype))
+      self.assertAllEqual(cpu, xla)

   def _testRngIsNotConstant(self, rng, dtype):
     # Tests that 'rng' does not always return the same value.
@@ -364,4 +400,5 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):

 if __name__ == "__main__":
   ops.enable_eager_execution()
+  config.set_soft_device_placement(False)
   test.main()
@@ -21,7 +21,11 @@ from __future__ import print_function
 import numpy as np

 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.compiler.xla import xla
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.kernel_tests.random import util as \
     random_test_util
 from tensorflow.python.ops import array_ops
@@ -39,6 +43,26 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
     allowed_types.update({dtypes.int32, dtypes.int64})
     return self.all_tf_types & allowed_types

+  @test_util.run_v2_only
+  def testForcedCompile(self):
+    """Tests whole-function forced-compilation.
+
+    This test checks that stateless_random_* can be used in forced-compilation
+    scenarios (e.g. TPU). The new version of stateless_random_* requires the
+    intermediate tensor `alg` to be compile-time constant, so we need to check
+    that this requirement is met. We use xla.compile instead of tf.function's
+    experimental_compile because the latter doesn't throw an error even if the
+    compile-time-constant constraint is not met.
+    """
+    if config.list_logical_devices('TPU'):
+      self.skipTest('To accommodate OSS, xla.compile support for TPU is not '
+                    'linked in.')
+    @def_function.function
+    def f(x):
+      return xla.compile(
+          lambda x: stateless.stateless_random_normal([], seed=x), [x])
+    f([1, 2])
+
   def testDeterminism(self):
     # Stateless values should be equal iff the seeds are equal (roughly)
     with self.session(), self.test_scope():
@@ -138,7 +162,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark):

   def _benchmarkUniform(self, name, dtype, use_xla_jit):

-    def BuilderFn():
+    def builder_fn():
       shape = (10, 1000, 1000)
       seed_var = variables.Variable((312, 456),
                                     dtype=dtypes.int32,
@@ -147,7 +171,7 @@ class StatelessRandomOpsBenchmark(test.Benchmark):
           shape, seed=seed_var, dtype=dtype)
       return '%s.shape%s' % (name, shape), [random_t]

-    xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu')
+    xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu')

   def benchmarkUniformF32(self):
     self._benchmarkUniform(
@@ -167,4 +191,5 @@ class StatelessRandomOpsBenchmark(test.Benchmark):


 if __name__ == '__main__':
+  config.set_soft_device_placement(False)
   test.main()
@@ -108,6 +108,7 @@ tf_kernel_library(
         "stack_ops.cc",
         "stateful_random_ops.cc",
        "stateless_random_ops.cc",
+        "stateless_random_ops_v2.cc",
        "strided_slice_op.cc",
        "tensor_array_ops.cc",
        "tensor_list_ops.cc",
@@ -187,6 +188,7 @@ tf_kernel_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/kernels:stateful_random_ops_header",
+        "//tensorflow/core/kernels:stateless_random_ops_v2_header",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/rng_alg.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/lib/math/math_util.h"
@@ -180,7 +181,7 @@ Status CompileImpl(
   }
   xla::Literal alg_literal;
   TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
-  auto alg = alg_literal.Get<Algorithm>({});
+  Algorithm alg = Algorithm(alg_literal.Get<int>({}));
   if (!(alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX)) {
     return errors::InvalidArgument("Unsupported algorithm id: ", alg);
   }
@@ -407,5 +408,80 @@ REGISTER_XLA_OP(Name("StatefulUniformFullInt")
                     {DT_INT32, DT_UINT32, DT_INT64, DT_UINT64}),
                 StatefulUniformFullIntOp);

+xla::XlaOp IncreaseCounter(Algorithm const& alg, xla::XlaOp counter,
+                           xla::XlaOp delta) {
+  // Multiplying 256 to be consistent with the CPU/GPU kernels
+  delta = delta * ConstantR0WithType(delta.builder(), xla::U64, 256);
+  if (alg == RNG_ALG_PHILOX) {
+    return xla::PhiloxIncreaseCounter(counter, delta);
+  } else {
+    return counter + delta;
+  }
+}
+
+xla::XlaOp PadRight(xla::XlaOp a, int n) {
+  return xla::Pad(a, xla::ScalarLike(a, 0),
+                  xla::MakeEdgePaddingConfig({{0, n}}));
+}
+
+template <typename AlgEnumType = int64, bool read_old_value = false>
+class RngSkipOp : public XlaOpKernel {
+ public:
+  explicit RngSkipOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    const int state_input_idx = 0;
+    const int alg_input_idx = 1;
+    const int delta_input_idx = 2;
+    xla::XlaOp var;
+    TensorShape var_shape;
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(state_input_idx, STATE_ELEMENT_DTYPE,
+                                          &var_shape, &var));
+    xla::Literal alg_literal;
+    OP_REQUIRES_OK(ctx, ctx->ConstantInput(alg_input_idx, &alg_literal));
+    Algorithm alg = Algorithm(alg_literal.Get<AlgEnumType>({}));
+    OP_REQUIRES(ctx, alg == RNG_ALG_THREEFRY || alg == RNG_ALG_PHILOX,
+                errors::InvalidArgument("Unsupported algorithm id: ", alg));
+    OP_REQUIRES_OK(ctx, CheckStateShape(alg, var_shape));
+    if (read_old_value) {
+      auto counter_size = GetCounterSize(alg);
+      xla::XlaOp output = var;
+      if (RNG_MAX_COUNTER_SIZE > counter_size) {
+        // Because the size of `var` depends on the algorithm while we want the
+        // output to have a fixed size (to help shape inference), we fix the
+        // output size to be the maximal state size among algorithms, and right-
+        // pad it with zeros if var's size is smaller than that.
+        output = PadRight(output, RNG_MAX_COUNTER_SIZE - counter_size);
+      }
+      ctx->SetOutput(0, output);
+    }
+    xla::XlaOp counter;
+    xla::XlaOp key;
+    std::tie(counter, key) = StateAndKeyFromVariable(alg, var);
+    xla::XlaOp delta = ctx->Input(delta_input_idx);
+    delta = BitcastConvertType(delta, xla::U64);
+    auto new_counter = IncreaseCounter(alg, counter, delta);
+    var = StateAndKeyToVariable(alg, new_counter, key);
+    xla::PrimitiveType state_element_type;
+    OP_REQUIRES_OK(
+        ctx, DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
+    var = BitcastConvertType(var, state_element_type);
+    OP_REQUIRES_OK(
+        ctx, ctx->AssignVariable(state_input_idx, STATE_ELEMENT_DTYPE, var));
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(RngSkipOp);
+};
+
+REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"),
+                RngSkipOp<>);
+
+using RngReadAndSkipOp = RngSkipOp<int32, true>;
+
+REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"),
+                RngReadAndSkipOp);
+
 }  // namespace
 }  // namespace tensorflow
@@ -111,6 +111,8 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
   }
 }

+namespace {
+
 xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
                                       xla::XlaOp seeds,
                                       const xla::Shape& shape) {
@@ -140,8 +142,6 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
   }
 }

-namespace {
-
 class StatelessRandomUniformOp : public XlaOpKernel {
  public:
   explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc (new file, 485 lines)
@@ -0,0 +1,485 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include <cmath>

#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace tensorflow {

namespace {

inline xla::RandomAlgorithm AlgorithmToRandomAlgorithm(Algorithm const& alg) {
  if (alg == RNG_ALG_PHILOX) {
    return xla::RandomAlgorithm::RNG_PHILOX;
  }
  return xla::RandomAlgorithm::RNG_THREE_FRY;
}

inline Algorithm RandomAlgorithmToAlgorithm(xla::RandomAlgorithm const& alg) {
  if (alg == xla::RandomAlgorithm::RNG_PHILOX) {
    return RNG_ALG_PHILOX;
  }
  return RNG_ALG_THREEFRY;
}

xla::XlaOp GetCounter(xla::RandomAlgorithm const& alg, xla::XlaOp state) {
  Algorithm alg_ = RandomAlgorithmToAlgorithm(alg);
  return xla::Slice(state, {RNG_KEY_SIZE},
                    {RNG_KEY_SIZE + GetCounterSize(alg_)}, {1});
}

xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key,
                            xla::XlaOp counter, const xla::Shape& shape) {
  key = BitcastConvertType(key, xla::U64);
  counter = BitcastConvertType(counter, xla::U64);
  xla::XlaOp state = xla::ConcatInDim(key.builder(), {key, counter}, 0);
  xla::XlaOp result = xla::RngBitGenerator(alg, state, shape);
  auto new_counter = GetCounter(alg, xla::GetTupleElement(result, 0));
  new_counter = BitcastConvertType(new_counter, xla::S64);
  return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
                        /*state=*/new_counter};
}

std::tuple<xla::XlaOp, xla::XlaOp, Algorithm> GetKeyCounterAlg(
    absl::string_view device_type_string, xla::XlaOp key) {
  // The Philox algorithm may cause performance regression on other devices.
  // Turn on the Philox algorithm for the CPU and GPU backends only.
  if (device_type_string == DEVICE_GPU_XLA_JIT ||
      device_type_string == DEVICE_CPU_XLA_JIT) {
    auto counter_key = xla::ScramblePhiloxKey(key);
    return std::make_tuple(counter_key.second, counter_key.first,
                           RNG_ALG_PHILOX);
  } else {
    auto counter_shape =
        xla::ShapeUtil::MakeShape(xla::U64, {RNG_MAX_COUNTER_SIZE});
    auto counter = xla::Zeros(key.builder(), counter_shape);
    return std::make_tuple(key, counter, RNG_ALG_THREEFRY);
  }
}

}  // namespace

xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg,
                                     xla::XlaOp key, xla::XlaOp counter,
                                     const xla::Shape& shape, xla::XlaOp minval,
                                     xla::XlaOp maxval) {
  xla::XlaBuilder* builder = key.builder();
  xla::PrimitiveType type = shape.element_type();
  using std::placeholders::_1;
  using std::placeholders::_2;
  using std::placeholders::_3;
  auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
  switch (type) {
    case xla::F32:
    case xla::F64:
      return xla::UniformFloatingPointDistribution(key, counter, generator,
                                                   minval, maxval, shape);
    case xla::S32:
    case xla::S64:
    case xla::U32:
    case xla::U64:
      return UniformIntDistribution(key, counter, generator, minval, maxval,
                                    shape);
      break;
    default:
      return {builder->ReportError(xla::Unimplemented(
                  "Types other than F32, S32, S64, U32 and U64 are not "
                  "implemented by "
                  "StatelessRngUniformV2; got %s",
                  xla::primitive_util::LowercasePrimitiveTypeName(type))),
              counter};
  }
}

namespace {

xla::RngOutput StatelessRngUniformFullInt(xla::RandomAlgorithm const& alg,
                                          xla::XlaOp key, xla::XlaOp counter,
                                          const xla::Shape& shape) {
  xla::XlaBuilder* builder = key.builder();

  xla::PrimitiveType type = shape.element_type();
  xla::RngOutput output = BitGenerator(alg, key, counter, shape);
  switch (type) {
    case xla::U32:
    case xla::U64:
      return output;
    case xla::S32:
    case xla::S64:
      return xla::RngOutput{BitcastConvertType(output.value, type),
                            output.state};
    default:
      return {
          builder->ReportError(xla::Unimplemented(
              "Types other than U32, S32, U64 and S64 are not implemented by "
              "StatelessRngUniformFullInt; got: %s",
              xla::primitive_util::LowercasePrimitiveTypeName(type))),
          output.state};
  }
}

Status GetAlgorithm(XlaOpKernelContext* ctx, int alg_input_idx,
                    xla::RandomAlgorithm* alg) {
  auto alg_shape = ctx->InputShape(alg_input_idx);
  if (alg_shape.dims() != 0) {
    return errors::InvalidArgument("algorithm must be of shape [], not ",
                                   alg_shape.DebugString());
  }
  xla::Literal alg_literal;
  TF_RETURN_IF_ERROR(ctx->ConstantInput(alg_input_idx, &alg_literal));
  auto alg_ = Algorithm(alg_literal.Get<int>({}));
  *alg = AlgorithmToRandomAlgorithm(alg_);
  return Status::OK();
}

xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg,
                             TensorShape const& counter_shape,
                             xla::XlaOp counter) {
  auto input_counter_size = counter_shape.dim_size(0);
  auto real_counter_size = GetCounterSize(RandomAlgorithmToAlgorithm(alg));
  if (input_counter_size > real_counter_size) {
    counter = xla::Slice(counter, {0}, {real_counter_size}, {1});
  }
  return counter;
}

class StatelessRandomUniformOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));
    xla::PrimitiveType rng_primitive_type = xla_shape.element_type();

    counter = MaybeSliceCounter(alg, counter_shape, counter);

    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::ConstantR0WithType(builder, rng_primitive_type, 0.0),
        xla::ConstantR0WithType(builder, rng_primitive_type, 1.0));
    auto uniform = MaybeConvertF32ToBF16(result.value, dtype_);
    ctx->SetOutput(0, uniform);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomUniformOp);

class StatelessRandomUniformIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    const int minval_input_idx = 4;
    const int maxval_input_idx = 5;
    TensorShape minval_shape = ctx->InputShape(minval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
                errors::InvalidArgument("minval must be scalar, got shape ",
                                        minval_shape.DebugString()));
    TensorShape maxval_shape = ctx->InputShape(maxval_input_idx);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
                errors::InvalidArgument("maxval must be scalar, got shape ",
                                        maxval_shape.DebugString()));

    xla::XlaOp minval = ctx->Input(minval_input_idx);
    xla::XlaOp maxval = ctx->Input(maxval_input_idx);

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result =
        StatelessRngUniformV2(alg, key, counter, xla_shape, minval, maxval);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformIntOp);

class StatelessRandomUniformFullIntOp : public XlaOpKernel {
 public:
  explicit StatelessRandomUniformFullIntOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformFullInt(alg, key, counter, xla_shape);
    ctx->SetOutput(0, result.value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformFullIntOp);
};

REGISTER_XLA_OP(Name("StatelessRandomUniformFullIntV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_INT32, DT_INT64, DT_UINT32, DT_UINT64}),
                StatelessRandomUniformFullIntOp);

class StatelessRandomNormalOp : public XlaOpKernel {
 public:
  explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;

    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    using std::placeholders::_1;
    using std::placeholders::_2;
    using std::placeholders::_3;
    auto generator = std::bind(BitGenerator, alg, _1, _2, _3);
    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = xla::NormalFloatingPointDistribution(key, counter, generator,
                                                       xla_shape);
    auto normal = MaybeConvertF32ToBF16(result.value, dtype_);
    ctx->SetOutput(0, normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
};

REGISTER_XLA_OP(Name("StatelessRandomNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessRandomNormalOp);

class StatelessTruncatedNormalOp : public XlaOpKernel {
 public:
  explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));

    const int key_input_idx = 1;
    const int counter_input_idx = 2;
    const int alg_input_idx = 3;
    xla::XlaOp key = ctx->Input(key_input_idx);
    xla::XlaOp counter = ctx->Input(counter_input_idx);

    xla::RandomAlgorithm alg;
    OP_REQUIRES_OK(ctx, GetAlgorithm(ctx, alg_input_idx, &alg));

    auto counter_shape = ctx->InputShape(counter_input_idx);
    OP_REQUIRES_OK(ctx, CheckKeyCounterShape(RandomAlgorithmToAlgorithm(alg),
                                             ctx->InputShape(key_input_idx),
                                             counter_shape));

    xla::XlaBuilder* builder = ctx->builder();

    DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT;
    xla::Shape xla_shape;
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape));

    counter = MaybeSliceCounter(alg, counter_shape, counter);
    auto result = StatelessRngUniformV2(
        alg, key, counter, xla_shape,
        xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
        xla::One(builder, xla_shape.element_type()));
    xla::XlaOp truncated_normal = TruncatedNormal(result.value);
    truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
    ctx->SetOutput(0, truncated_normal);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
};

REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2")
                    .CompileTimeConstantInput("shape")
                    .CompileTimeConstantInput("alg")
                    .TypeConstraint("dtype",
                                    {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}),
                StatelessTruncatedNormalOp);

class GetKeyCounterAlgOp : public XlaOpKernel {
 public:
  explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        device_type_string_(ctx->device_type().type_string()) {}

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape seed_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_shape.DebugString()));
    xla::XlaOp seed = ctx->Input(0);

    xla::XlaBuilder* builder = seed.builder();
    xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
    xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
    xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
                     ShiftLeft(ConvertElementType(seed1, xla::U64),
                               ConstantR0WithType(builder, xla::U64, 32));
    auto key_counter_alg = GetKeyCounterAlg(device_type_string_, key);
    key = std::get<0>(key_counter_alg);
    auto counter = std::get<1>(key_counter_alg);
    auto alg = std::get<2>(key_counter_alg);
    key = xla::Reshape(key, {RNG_KEY_SIZE});
    ctx->SetOutput(0, key);
    ctx->SetOutput(1, counter);
    ctx->SetOutput(2, ConstantR0(builder, static_cast<int>(alg)));
  }

 private:
  string device_type_string_;

  TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
};

REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);

}  // namespace
}  // namespace tensorflow
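The new kernel file above implements the V2 stateless ops in terms of an explicit key, counter and algorithm id. A minimal sketch of how they are meant to be driven (not part of the commit; it assumes the tf.raw_ops endpoints generated from the op definitions added later in this change):

import tensorflow as tf

seed = tf.constant([1, 2], dtype=tf.int32)

# Scramble the seed into a key/counter pair and pick an algorithm for the
# current device (Philox on the CPU/GPU backends, ThreeFry otherwise).
key, counter, alg = tf.raw_ops.StatelessRandomGetKeyCounterAlg(seed=seed)

# Deterministic: the same (shape, key, counter, alg) yields the same tensor.
x = tf.raw_ops.StatelessRandomNormalV2(
    shape=[2, 3], key=key, counter=counter, alg=alg, dtype=tf.float32)
y = tf.raw_ops.StatelessRandomNormalV2(
    shape=[2, 3], key=key, counter=counter, alg=alg, dtype=tf.float32)
tf.debugging.assert_equal(x, y)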
@@ -83,6 +83,8 @@ CreateResourceOpInfoMap() {
   add("ResourceScatterSub" , kReadWrite, kVariable);
   add("ResourceScatterUpdate" , kReadWrite, kVariable);
   add("ResourceStridedSliceAssign" , kReadWrite, kVariable);
+  add("RngReadAndSkip" , kReadWrite, kVariable);
+  add("RngSkip" , kReadWrite, kVariable);
   add("StatefulStandardNormalV2" , kReadWrite, kVariable);
   add("StatefulTruncatedNormal" , kReadWrite, kVariable);
   add("StatefulUniform" , kReadWrite, kVariable);
@@ -487,6 +487,10 @@ std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {

 }  // namespace

+XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
+  return Uint128ToOp(Uint128AddUint64(Uint128FromOp(counter), delta));
+}
+
 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
                                const Shape& shape) {
   PrimitiveType type = shape.element_type();
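PhiloxIncreaseCounter treats the Philox counter as a uint128 made of two uint64 words and adds a uint64 delta with carry propagation. A small Python model of that arithmetic (illustration only; the word order used by the XLA Uint128 helpers is not shown in this diff):

def philox_increase_counter(counter_lo, counter_hi, delta):
  # Add a uint64 delta to a uint128 counter held as (lo, hi) uint64 words.
  mask = (1 << 64) - 1
  new_lo = (counter_lo + delta) & mask
  carry = 1 if new_lo < counter_lo else 0
  new_hi = (counter_hi + carry) & mask
  return new_lo, new_hi

print(philox_increase_counter((1 << 64) - 1, 0, 1))  # -> (0, 1): carry into hi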
@@ -89,6 +89,9 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
                          absl::Span<const xla::XlaOp> scalars);

+// Increases Philox counter (an uint128) by a delta (an uint64).
+xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta);
+
 }  // namespace xla

 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_
@@ -488,6 +488,7 @@ tf_cuda_library(
        "//tensorflow/core/framework:register_types_traits.h",
        "//tensorflow/core/framework:resource_mgr.h",
        "//tensorflow/core/framework:resource_op_kernel.h",
+        "//tensorflow/core/framework:rng_alg.h",
        "//tensorflow/core/framework:selective_registration.h",
        "//tensorflow/core/framework:session_state.h",
        "//tensorflow/core/framework:shape_inference.h",
@@ -652,6 +653,7 @@ tf_gen_op_libs(
        "spectral_ops",
        "state_ops",
        "stateless_random_ops",
+        "stateless_random_ops_v2",
        "summary_ops",
        "training_ops",
    ],
@@ -871,6 +873,7 @@ cc_library(
        ":spectral_ops_op_lib",
        ":state_ops_op_lib",
        ":stateless_random_ops_op_lib",
+        ":stateless_random_ops_v2_op_lib",
        ":string_ops_op_lib",
        ":training_ops_op_lib",
        ":user_ops_op_lib",
@@ -0,0 +1,35 @@
op {
  graph_op_name: "RngReadAndSkip"
  visibility: HIDDEN
  in_arg {
    name: "resource"
    description: <<END
The handle of the resource variable that stores the state of the RNG.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm.
END
  }
  in_arg {
    name: "delta"
    description: <<END
The amount of advancement.
END
  }
  out_arg {
    name: "value"
    description: <<END
The old value of the resource variable, before incrementing. Since state size is algorithm-dependent, this output will be right-padded with zeros to reach shape int64[3] (the current maximal state size among algorithms).
END
  }
  summary: "Advance the counter of a counter-based RNG."
  description: <<END
The state of the RNG after
`rng_read_and_skip(n)` will be the same as that after `uniform([n])`
(or any other distribution). The actual increment added to the
counter is an unspecified implementation choice.
END
}
@@ -0,0 +1,32 @@
op {
  graph_op_name: "StatelessRandomGetKeyCounterAlg"
  visibility: HIDDEN
  in_arg {
    name: "seed"
    description: <<END
2 seeds (shape [2]).
END
  }
  out_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  out_arg {
    name: "counter"
    description: <<END
Counter for the counter-based RNG algorithm. Since counter size is algorithm-dependent, this output will be right-padded with zeros to reach shape uint64[2] (the current maximal counter size among algorithms).
END
  }
  out_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  summary: "Picks the best algorithm based on device, and scrambles seed into key and counter."
  description: <<END
This op picks the best counter-based RNG algorithm based on device, and scrambles a shape-[2] seed into a key and a counter, both needed by the counter-based algorithm. The scrambling is opaque but approximately satisfies the property that different seed results in different key/counter pair (which will in turn result in different random numbers).
END
}
@@ -0,0 +1,46 @@
op {
  graph_op_name: "StatelessRandomNormalV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom values from a normal distribution."
  description: <<END
The generated values will have mean 0 and standard deviation 1.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
@@ -0,0 +1,46 @@
op {
  graph_op_name: "StatelessRandomUniformFullIntV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
  description: <<END
The generated values are uniform integers covering the whole range of `dtype`.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
@@ -0,0 +1,58 @@
op {
  graph_op_name: "StatelessRandomUniformIntV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  in_arg {
    name: "minval"
    description: <<END
Minimum value (inclusive, scalar).
END
  }
  in_arg {
    name: "maxval"
    description: <<END
Maximum value (exclusive, scalar).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
  description: <<END
The generated values follow a uniform distribution in the range `[minval, maxval)`.

The outputs are a deterministic function of `shape`, `key`, `counter`, `alg`, `minval` and `maxval`.
END
}
@@ -0,0 +1,47 @@
op {
  graph_op_name: "StatelessRandomUniformV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom random values from a uniform distribution."
  description: <<END
The generated values follow a uniform distribution in the range `[0, 1)`. The
lower bound 0 is included in the range, while the upper bound 1 is excluded.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
@@ -0,0 +1,48 @@
op {
  graph_op_name: "StatelessTruncatedNormalV2"
  visibility: HIDDEN
  in_arg {
    name: "shape"
    description: <<END
The shape of the output tensor.
END
  }
  in_arg {
    name: "key"
    description: <<END
Key for the counter-based RNG algorithm (shape uint64[1]).
END
  }
  in_arg {
    name: "counter"
    description: <<END
Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.
END
  }
  in_arg {
    name: "alg"
    description: <<END
The RNG algorithm (shape int32[]).
END
  }
  out_arg {
    name: "output"
    description: <<END
Random values with specified shape.
END
  }
  attr {
    name: "dtype"
    description: <<END
The type of the output.
END
  }
  summary: "Outputs deterministic pseudorandom values from a truncated normal distribution."
  description: <<END
The generated values follow a normal distribution with mean 0 and standard
deviation 1, except that values whose magnitude is more than 2 standard
deviations from the mean are dropped and re-picked.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
END
}
@@ -68,6 +68,7 @@ exports_files(
        "resource_mgr.h",
        "resource_op_kernel.h",
        "resource_var.h",
+        "rng_alg.h",
        "run_handler.h",
        "run_handler_util.h",
        "session_state.h",
@@ -385,6 +386,7 @@ filegroup(
        "resource_mgr.h",
        "resource_op_kernel.h",
        "resource_var.h",
+        "rng_alg.h",
        "run_handler.cc",
        "run_handler.h",
        "run_handler_util.cc",
tensorflow/core/framework/rng_alg.h (new file, 34 lines)
@@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
#define TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_

namespace tensorflow {

enum Algorithm { RNG_ALG_PHILOX = 1, RNG_ALG_THREEFRY = 2 };

static constexpr int RNG_KEY_SIZE = 1;
static constexpr int RNG_MAX_COUNTER_SIZE = 2;
inline int GetCounterSize(Algorithm alg) {
  if (alg == RNG_ALG_PHILOX) {
    return 2;
  }
  return 1;
}

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_
@@ -4457,12 +4457,22 @@ cc_library(
    ],
 )

+cc_library(
+    name = "stateless_random_ops_v2_header",
+    hdrs = ["stateless_random_ops_v2.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
 tf_kernel_library(
    name = "stateful_random_ops",
    prefix = "stateful_random_ops",
    deps = [
        ":bounds_check",
        ":dense_update_functor",
+        ":fill_functor",
        ":gather_functor",
        ":mutex_ops",
        ":random_op",
@@ -5337,7 +5347,7 @@ tf_kernel_library(
    prefix = "random_binomial_op",
    deps = [
        ":cwise_op",
-        ":random_ops",
+        ":random_op",
        ":resource_variable_ops",
        ":stateful_random_ops",
        ":stateless_random_ops",
@@ -6159,6 +6169,7 @@ filegroup(
        "ragged_tensor_to_tensor_op.cc",
        "random_op.cc",
        "random_op_cpu.h",
+        "random_ops_util.h",
        "random_poisson_op.cc",
        "reduce_join_op.cc",
        "reduction_ops_all.cc",
@@ -66,8 +66,9 @@ struct MultinomialFunctor<GPUDevice, T, OutputType> {
                   typename TTypes<OutputType>::Matrix output) {
     // Uniform, [0, 1).
     typedef random::UniformDistribution<random::PhiloxRandom, float> Dist;
-    functor::FillPhiloxRandom<GPUDevice, Dist>()(ctx, d, gen, noises.data(),
-                                                 noises.size(), Dist());
+    functor::FillPhiloxRandom<GPUDevice, Dist>()(
+        ctx, d, /*key=*/nullptr, /*counter=*/nullptr, gen, noises.data(),
+        noises.size(), Dist());
 
 #if defined(EIGEN_HAS_INDEX_LIST)
     Eigen::IndexList<int, int, int> bsc;
@@ -30,8 +30,10 @@ limitations under the License.
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/rng_alg.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/random_ops_util.h"
 #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
 #include "tensorflow/core/kernels/stateless_random_ops.h"
 #include "tensorflow/core/kernels/training_op_helpers.h"
@@ -375,7 +377,7 @@ class RandomBinomialOp : public OpKernel {
     OP_REQUIRES(ctx, alg_tensor.dims() == 0,
                 errors::InvalidArgument("algorithm must be of shape [], not ",
                                         alg_tensor.shape().DebugString()));
-    Algorithm alg = alg_tensor.flat<Algorithm>()(0);
+    Algorithm alg = Algorithm(alg_tensor.flat<int64>()(0));
 
     int64 samples_per_batch = 1;
     const int64 num_sample_dims =
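The `Algorithm(...)` cast in the hunk above (read the scalar as int64, then construct the enum) is the same decode step the stateful kernels below wrap in a `GetAlg<AlgEnumType>` helper. A tiny standalone sketch of that step; everything in the `example` namespace is hypothetical and only depends on the new rng_alg.h header:

#include <cstdint>

#include "tensorflow/core/framework/rng_alg.h"

namespace example {

// Hypothetical mirror of the decode used by the kernels in this change:
// the algorithm arrives as a plain integer scalar and is cast to the enum.
inline tensorflow::Algorithm DecodeAlgorithm(int64_t alg_id) {
  return static_cast<tensorflow::Algorithm>(alg_id);
}

// Callers still have to reject ids outside the known set, as the kernels do
// with their "Unsupported algorithm id" errors.
inline bool IsKnownAlgorithm(int64_t alg_id) {
  return alg_id == tensorflow::RNG_ALG_PHILOX ||
         alg_id == tensorflow::RNG_ALG_THREEFRY;
}

}  // end namespace example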
@@ -74,7 +74,7 @@ class PhiloxRandomOp : public OpKernel {
     OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
     auto output_flat = output->flat<T>();
     functor::FillPhiloxRandom<Device, Distribution>()(
-        ctx, ctx->eigen_device<Device>(),
+        ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
         // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
         // it just here.
         generator_.ReserveRandomOutputs(output_flat.size(), 256),
@@ -123,7 +123,7 @@ class RandomUniformIntOp : public OpKernel {
 
     auto output_flat = output->flat<IntType>();
     functor::FillPhiloxRandom<Device, Distribution>()(
-        ctx, ctx->eigen_device<Device>(),
+        ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
         // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
         // it just here.
         generator_.ReserveRandomOutputs(output_flat.size(), 256),
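The two call sites above keep the legacy behavior by passing null key/counter; the stateless V2 kernels added later in this change pass real device pointers instead. A rough, hedged sketch contrasting the two call styles against the widened functor signature; the wrapper functions and all names in the `example` namespace are hypothetical:

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"

namespace example {

using CPUDevice = Eigen::ThreadPoolDevice;
using Dist = tensorflow::random::UniformDistribution<
    tensorflow::random::PhiloxRandom, float>;

// Legacy style: key/counter are null, so the on-stack generator `gen`
// supplies the Philox state (what PhiloxRandomOp does above).
void FillFromStackGenerator(tensorflow::OpKernelContext* ctx,
                            const CPUDevice& d,
                            tensorflow::random::PhiloxRandom gen, float* data,
                            tensorflow::int64 size) {
  tensorflow::functor::FillPhiloxRandom<CPUDevice, Dist>()(
      ctx, d, /*key=*/nullptr, /*counter=*/nullptr, gen, data, size, Dist());
}

// New style: the state comes from uint64 key/counter words in memory and the
// PhiloxRandom argument is only a dummy placeholder.
void FillFromMemory(tensorflow::OpKernelContext* ctx, const CPUDevice& d,
                    const tensorflow::uint64* key,
                    const tensorflow::uint64* counter, float* data,
                    tensorflow::int64 size) {
  tensorflow::functor::FillPhiloxRandom<CPUDevice, Dist>()(
      ctx, d, key, counter, tensorflow::random::PhiloxRandom() /*dummy*/, data,
      size, Dist());
}

}  // end namespace example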
@@ -34,10 +34,14 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 // NOTE: Due to inlining done by the compiler, you may need to add
 // explicit instantiation of the functor in random_op.cc. See example
 // functor::FillPhiloxRandom<CPUDevice, random::UniformDistribution>.
+//
+// This functor can take the PhiloxRandom input from either device memory `key`
+// and `counter` or a stack value `gen`. If both `key` and `counter` are not
+// nullptr, they provide the input; otherwise `gen` provides the input.
 template <class Distribution>
 struct FillPhiloxRandom<CPUDevice, Distribution> {
-  void operator()(OpKernelContext* ctx, const CPUDevice& d,
-                  random::PhiloxRandom gen,
+  void operator()(OpKernelContext* ctx, const CPUDevice& d, const uint64* key,
+                  const uint64* counter, random::PhiloxRandom gen,
                   typename Distribution::ResultElementType* data, int64 size,
                   Distribution dist);
 };
@@ -47,14 +51,13 @@ typedef Eigen::GpuDevice GPUDevice;
 // Declares the partially GPU-specialized functor struct.
 template <class Distribution>
 struct FillPhiloxRandom<GPUDevice, Distribution> {
-  void operator()(OpKernelContext* ctx, const GPUDevice& d,
-                  random::PhiloxRandom gen,
+  void operator()(OpKernelContext* ctx, const GPUDevice& d, const uint64* key,
+                  const uint64* counter, random::PhiloxRandom gen,
                   typename Distribution::ResultElementType* data, int64 size,
                   Distribution dist);
 };
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-
 }  // namespace functor
 }  // namespace tensorflow
 
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/random_op.h"
+#include "tensorflow/core/kernels/random_ops_util.h"
 #include "tensorflow/core/lib/hash/crc32c.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/lib/random/simple_philox.h"
@@ -59,8 +60,9 @@ using random::SingleSampleAdapter;
 template <typename Device, class Distribution>
 struct FillPhiloxRandom {
   typedef typename Distribution::ResultElementType T;
-  void operator()(OpKernelContext* ctx, const Device&, random::PhiloxRandom gen,
-                  T* data, int64 size, Distribution dist) {
+  void operator()(OpKernelContext* ctx, const Device&, const uint64* key,
+                  const uint64* counter, random::PhiloxRandom gen, T* data,
+                  int64 size, Distribution dist) {
     OP_REQUIRES(
         ctx, false,
         errors::Internal(
@@ -154,18 +156,24 @@ struct FillPhiloxRandomTask<Distribution, true> {
 // It splits the work into several tasks and run them in parallel
 template <class Distribution>
 void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
-    OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
+    OpKernelContext* ctx, const CPUDevice&, const uint64* key,
+    const uint64* counter, random::PhiloxRandom gen,
     typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
   const int kGroupSize = Distribution::kResultElementCount;
 
-  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+  auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
 
   int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
 
   const int kGroupCost =
       random::PhiloxRandom::kResultElementCount *
       (random::PhiloxRandom::kElementCost + Distribution::kElementCost);
+
+  if (key != nullptr && counter != nullptr) {
+    gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
+  }
+
   Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
         kGroupCost,
         [&gen, data, size, dist](int64 start_group, int64 limit_group) {
@@ -19,6 +19,7 @@ limitations under the License.
 #if defined(__CUDACC__) || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/kernels/random_op.h"
+#include "tensorflow/core/kernels/random_ops_util.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"
@@ -33,14 +34,16 @@ struct FillPhiloxRandomKernel;
 template <class Distribution>
 struct FillPhiloxRandomKernel<Distribution, false> {
   typedef typename Distribution::ResultElementType T;
-  PHILOX_DEVICE_INLINE void Run(random::PhiloxRandom gen, T* data, int64 size,
+  PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter,
+                                random::PhiloxRandom gen, T* data, int64 size,
                                 Distribution dist);
 };
 
 template <class Distribution>
 struct FillPhiloxRandomKernel<Distribution, true> {
   typedef typename Distribution::ResultElementType T;
-  PHILOX_DEVICE_INLINE void Run(const random::PhiloxRandom& base_gen, T* data,
+  PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter,
+                                random::PhiloxRandom base_gen, T* data,
                                 int64 size, Distribution dist);
 };
 
@@ -136,12 +139,16 @@ class SampleCopier<int64, 2> {
 // distribution. Each output takes a fixed number of samples.
 template <class Distribution>
 PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
-    random::PhiloxRandom gen, T* data, int64 size, Distribution dist) {
+    const uint64* key, const uint64* counter, random::PhiloxRandom gen, T* data,
+    int64 size, Distribution dist) {
   const int kGroupSize = Distribution::kResultElementCount;
 
   const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
   const int32 total_thread_count = gridDim.x * blockDim.x;
   int32 offset = thread_id * kGroupSize;
+  if (key != nullptr && counter != nullptr) {
+    gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
+  }
   gen.Skip(thread_id);
 
   const SampleCopier<T, kGroupSize> copier;
@@ -167,8 +174,8 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, false>::Run(
 // distribution. Each output takes a variable number of samples.
 template <class Distribution>
 PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
-    const random::PhiloxRandom& base_gen, T* data, int64 size,
-    Distribution dist) {
+    const uint64* key, const uint64* counter, random::PhiloxRandom base_gen,
+    T* data, int64 size, Distribution dist) {
   using random::PhiloxRandom;
   using random::SingleSampleAdapter;
 
@@ -183,6 +190,9 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
   int64 group_index = thread_id;
   int64 offset = group_index * kGroupSize;
 
+  if (key != nullptr && counter != nullptr) {
+    base_gen = GetPhiloxRandomFromCounterKeyMem(counter, key);
+  }
   while (offset < size) {
     // Since each output takes a variable number of samples, we need to
     // realign the generator to the beginning for the current output group
@@ -208,18 +218,20 @@ PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel<Distribution, true>::Run(
 // A simple launch pad to call the correct function templates to fill the data
 template <class Distribution>
 __global__ void __launch_bounds__(1024)
-    FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
+    FillPhiloxRandomKernelLaunch(const uint64* key, const uint64* counter,
+                                 random::PhiloxRandom base_gen,
                                  typename Distribution::ResultElementType* data,
                                  int64 size, Distribution dist) {
   FillPhiloxRandomKernel<Distribution,
                          Distribution::kVariableSamplesPerOutput>()
-      .Run(base_gen, data, size, dist);
+      .Run(key, counter, base_gen, data, size, dist);
 }
 
 // Partial specialization for GPU
 template <class Distribution>
 void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
-    OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
+    OpKernelContext*, const GPUDevice& d, const uint64* key,
+    const uint64* counter, random::PhiloxRandom gen,
     typename Distribution::ResultElementType* data, int64 size,
     Distribution dist) {
   const int32 block_size = d.maxGpuThreadsPerBlock();
@@ -228,8 +240,8 @@ void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
       block_size;
 
   TF_CHECK_OK(GpuLaunchKernel(FillPhiloxRandomKernelLaunch<Distribution>,
-                              num_blocks, block_size, 0, d.stream(), gen, data,
-                              size, dist));
+                              num_blocks, block_size, 0, d.stream(), key,
+                              counter, gen, data, size, dist));
 }
 
 }  // namespace functor

tensorflow/core/kernels/random_ops_util.h  (new file, 72 lines)
@@ -0,0 +1,72 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_

#include "tensorflow/core/lib/random/philox_random.h"

namespace tensorflow {

using random::PhiloxRandom;

// The following 2 functions use the contract "lower 32 bits for the first
// uint32, higher 32 bits for the second". Note that this is endian-neutral,
// unlike a direct memory copy `memcpy(output, &input, 8)`.
PHILOX_DEVICE_INLINE void Uint64ToUint32s(uint64 input, uint32* output1,
                                          uint32* output2) {
  *output1 = static_cast<uint32>(input);
  *output2 = static_cast<uint32>(input >> 32);
}

PHILOX_DEVICE_INLINE uint64 Uint32sToUint64(uint32 input1, uint32 input2) {
  auto u64_1 = static_cast<uint64>(input1);
  auto u64_2 = static_cast<uint64>(input2);
  return u64_1 | (u64_2 << 32);
}

PHILOX_DEVICE_INLINE PhiloxRandom::ResultType GetCounterFromMem(
    uint64 const* ptr) {
  PhiloxRandom::ResultType counter;
  Uint64ToUint32s(ptr[0], &counter[0], &counter[1]);
  Uint64ToUint32s(ptr[1], &counter[2], &counter[3]);
  return counter;
}

PHILOX_DEVICE_INLINE void WriteCounterToMem(
    PhiloxRandom::ResultType const& counter, uint64* ptr) {
  ptr[0] = Uint32sToUint64(counter[0], counter[1]);
  ptr[1] = Uint32sToUint64(counter[2], counter[3]);
}

PHILOX_DEVICE_INLINE PhiloxRandom::Key GetKeyFromMem(uint64 const* ptr) {
  PhiloxRandom::Key key;
  Uint64ToUint32s(ptr[0], &key[0], &key[1]);
  return key;
}

PHILOX_DEVICE_INLINE void WriteKeyToMem(PhiloxRandom::Key const& key,
                                        uint64* ptr) {
  *ptr = Uint32sToUint64(key[0], key[1]);
}

PHILOX_DEVICE_INLINE PhiloxRandom GetPhiloxRandomFromCounterKeyMem(
    uint64 const* counter_ptr, uint64 const* key_ptr) {
  return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr));
}

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_

@@ -15,7 +15,9 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
+#include "tensorflow/core/framework/rng_alg.h"
 #include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/kernels/random_op_cpu.h"
 #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
 #include "tensorflow/core/kernels/training_op_helpers.h"
@@ -23,6 +25,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+namespace functor {
+
 template <typename Distribution>
 struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
   void operator()(OpKernelContext* ctx, const CPUDevice& device,
@@ -42,10 +46,13 @@ struct UpdateVariableAndFill_Philox<CPUDevice, Distribution> {
     // No longer needs the lock.
     state_var_guard->Release();
     functor::FillPhiloxRandom<CPUDevice, Distribution>()(
-        ctx, device, philox, output_data, output_size, dist);
+        ctx, device, /*key=*/nullptr, /*counter=*/nullptr, philox, output_data,
+        output_size, dist);
   }
 };
 
+}  // end namespace functor
+
 Status CheckState(const Tensor& state) {
   if (state.dtype() != STATE_ELEMENT_DTYPE) {
     return errors::InvalidArgument("dtype of RNG state variable must be ",
@@ -64,11 +71,12 @@ Status CheckPhiloxState(const Tensor& state, int64 alg_tag_skip = 0) {
                 "StateElementType must be int64");
   static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
                 "PhiloxRandom::ResultElementType must be uint32");
-  if (state.NumElements() < alg_tag_skip + PHILOX_MIN_STATE_SIZE) {
+  auto min_size = alg_tag_skip + PHILOX_MIN_STATE_SIZE;
+  if (state.NumElements() < min_size) {
     return errors::InvalidArgument(
         "For the Philox algorithm, the size of state"
         " must be at least ",
-        alg_tag_skip + PHILOX_MIN_STATE_SIZE, "; got ", state.NumElements());
+        min_size, "; got ", state.NumElements());
   }
   return Status::OK();
 }
@@ -95,7 +103,7 @@ Status UpdateVariableAndFill(
     if (var_tensor_flat.size() < 1) {
       return errors::InvalidArgument("Size of tensor must be at least 1");
     }
-    alg = var_tensor_flat(0);
+    alg = Algorithm(var_tensor_flat(0));
   }
   if (alg == RNG_ALG_PHILOX) {
     TF_RETURN_IF_ERROR(CheckPhiloxState(*var_tensor, alg_tag_skip));
@@ -107,7 +115,7 @@ Status UpdateVariableAndFill(
     arg.alg_tag_skip = alg_tag_skip;
     arg.not_used = &state_var_guard;
     arg.state_tensor = var_tensor;
-    UpdateVariableAndFill_Philox<Device, Distribution>()(
+    functor::UpdateVariableAndFill_Philox<Device, Distribution>()(
         ctx, ctx->eigen_device<Device>(), dist, &arg, output_data);
     return Status::OK();
   } else {
@@ -138,7 +146,8 @@ class StatefulRandomOp : public OpKernel {
   explicit StatefulRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
-    StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true, 0);
+    StatefulRandomCompute<Device>(ctx, Distribution(), 0, 1, true,
+                                  RNG_ALG_PHILOX /*dummy*/);
   }
 };
 
@@ -159,6 +168,14 @@ Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
   return Status::OK();
 }
 
+template <typename AlgEnumType>
+Status GetAlg(OpKernelContext* ctx, int input_idx, Algorithm* alg) {
+  AlgEnumType alg_id;
+  TF_RETURN_IF_ERROR(GetScalar(ctx->input(input_idx), input_idx, &alg_id));
+  *alg = Algorithm(alg_id);
+  return Status::OK();
+}
+
 template <typename Device, class Distribution>
 class StatefulRandomOpV2 : public OpKernel {
  public:
@@ -166,7 +183,7 @@ class StatefulRandomOpV2 : public OpKernel {
 
   void Compute(OpKernelContext* ctx) override {
     Algorithm alg;
-    OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
+    OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
     StatefulRandomCompute<Device>(ctx, Distribution(), /*state_input_idx=*/0,
                                   /*shape_input_idx=*/2,
                                   /*read_alg_from_state=*/false, alg);
@@ -180,7 +197,7 @@ class StatefulUniformIntOp : public OpKernel {
 
   void Compute(OpKernelContext* ctx) override {
     Algorithm alg;
-    OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
+    OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
     const Tensor& minval = ctx->input(3);
     const Tensor& maxval = ctx->input(4);
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
@@ -217,7 +234,7 @@ class StatefulUniformFullIntOp : public OpKernel {
 
   void Compute(OpKernelContext* ctx) override {
     Algorithm alg;
-    OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
+    OP_REQUIRES_OK(ctx, GetAlg<int64>(ctx, 1, &alg));
     StatefulRandomCompute<Device>(
         ctx,
         random::UniformFullIntDistribution<random::PhiloxRandom, IntType>(),
@@ -226,38 +243,66 @@ class StatefulUniformFullIntOp : public OpKernel {
   }
 };
 
+namespace functor {
+
 template <>
 struct RngSkip_Philox<CPUDevice> {
-  void operator()(const CPUDevice& device, int64 delta, Tensor* state_tensor) {
-    auto state_data = state_tensor->flat<StateElementType>().data();
+  void operator()(const CPUDevice& device, const StateElementType* in_data,
+                  uint64 delta, StateElementType* out_data) {
     // Delegates to PhiloxRandom to do the actual increasing.
-    auto philox = GetPhiloxRandomFromMem(state_data);
-    UpdateMemWithPhiloxRandom(philox, delta, state_data);
+    auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
+    UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
   }
 };
 
-template <typename Device>
+}  // end namespace functor
+
+template <typename Device, typename AlgEnumType = int64,
+          typename DeltaType = int64, bool read_old_value = false>
 class RngSkipOp : public OpKernel {
  public:
   explicit RngSkipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
   void Compute(OpKernelContext* ctx) override {
     auto state_input_idx = 0;
+    auto alg_input_idx = 1;
+    auto delta_input_idx = 2;
     Algorithm alg;
-    OP_REQUIRES_OK(ctx, GetScalar(ctx->input(1), 1, &alg));
-    int64 delta;
-    OP_REQUIRES_OK(ctx, GetScalar(ctx->input(2), 2, &delta));
+    OP_REQUIRES_OK(ctx, GetAlg<AlgEnumType>(ctx, alg_input_idx, &alg));
+    DeltaType delta_;
+    OP_REQUIRES_OK(
+        ctx, GetScalar(ctx->input(delta_input_idx), delta_input_idx, &delta_));
+    uint64 delta = static_cast<uint64>(delta_);
     Var* var = nullptr;
     OP_REQUIRES_OK(
         ctx, LookupResource(ctx, HandleFromInput(ctx, state_input_idx), &var));
     ScopedUnlockUnrefVar state_var_guard(var);
     Tensor* var_tensor = var->tensor();
     OP_REQUIRES_OK(ctx, CheckState(*var_tensor));
+    using T = StateElementType;
+    OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, T>(
+                            ctx, var_tensor, var->copy_on_read_mode.load()));
+    if (read_old_value) {
+      Tensor* output;
+      OP_REQUIRES_OK(
+          ctx, ctx->allocate_output(0, {RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE},
                                     &output));
+      auto output_flat = output->flat<T>();
+      if (RNG_MAX_COUNTER_SIZE > GetCounterSize(alg)) {
+        functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                             output_flat);
+      }
+      functor::DenseUpdate<Device, T, ASSIGN>()(
+          ctx->eigen_device<Device>(), output_flat,
+          const_cast<const Tensor*>(var_tensor)->flat<T>());
+    }
     if (alg == RNG_ALG_PHILOX) {
       OP_REQUIRES_OK(ctx, CheckPhiloxState(*var_tensor));
-      OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
-                              ctx, var_tensor, var->copy_on_read_mode.load()));
-      RngSkip_Philox<Device>()(ctx->eigen_device<Device>(), delta, var_tensor);
+      // var_tensor layout is counter+key, so var_tensor data is also counter
+      // data.
+      auto counter_data = var_tensor->flat<T>().data();
+      functor::RngSkip_Philox<Device>()(ctx->eigen_device<Device>(),
                                        counter_data, delta, counter_data);
     } else {
       OP_REQUIRES(ctx, false,
                   errors::InvalidArgument("Unsupported algorithm id: ", alg));
@@ -393,13 +438,20 @@ TF_CALL_int64(REGISTER_StatefulUniformFullInt_CPU);
 TF_CALL_uint32(REGISTER_StatefulUniformFullInt_CPU);
 TF_CALL_uint64(REGISTER_StatefulUniformFullInt_CPU);
 
+// TODO(wangpeng): Remove `HostMemory("delta")` for RngReadAndSkip
 #define REGISTER_RngSkip(DEVICE)                       \
   REGISTER_KERNEL_BUILDER(Name("RngSkip")              \
                               .Device(DEVICE_##DEVICE) \
                               .HostMemory("resource")  \
                               .HostMemory("algorithm") \
                               .HostMemory("delta"),    \
-                          RngSkipOp<DEVICE##Device>);
+                          RngSkipOp<DEVICE##Device>);  \
+  REGISTER_KERNEL_BUILDER(Name("RngReadAndSkip")       \
+                              .Device(DEVICE_##DEVICE) \
+                              .HostMemory("resource")  \
+                              .HostMemory("alg")       \
+                              .HostMemory("delta"),    \
+                          RngSkipOp<DEVICE##Device, int32, uint64, true>);
 
 REGISTER_RngSkip(CPU);
 
@@ -22,15 +22,12 @@ limitations under the License.
 namespace tensorflow {
 
 // 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained
-// in b/111604096 and cl/171681867), so I use signed int here. I choose int64
-// instead of int32 because `VarHandleOp` doesn't support int32 on GPU.
+// in b/111604096 and cl/171681867), so we use signed int here. We choose int64
+// instead of int32 because `VarHandleOp` doesn't support int32 on GPU, and
+// because of the "int32 problem".
 using StateElementType = int64;
 static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64;
 
-using Algorithm = StateElementType;
 static constexpr DataType ALGORITHM_DTYPE = STATE_ELEMENT_DTYPE;
-static constexpr Algorithm RNG_ALG_PHILOX = 1;
-static constexpr Algorithm RNG_ALG_THREEFRY = 2;
 
 using random::PhiloxRandom;
 
@@ -17,59 +17,51 @@ limitations under the License.
 #define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_
 
 #include "tensorflow/core/framework/resource_var.h"
+#include "tensorflow/core/kernels/random_ops_util.h"
 #include "tensorflow/core/kernels/stateful_random_ops.h"
 
 namespace tensorflow {
 
-// The following 5 functions are made templates to avoid duplicate symbols when
-// linking.
-
-// The following 2 functions use the contract "lower 32 bits for the first
-// uint32, higher 32 bits for the second". Note that this is endian-neutral,
-// unlike a direct memory copy `memcpy(output, &input, 8)`.
-PHILOX_DEVICE_INLINE void Int64ToUint32s(int64 input, uint32* output1,
-                                         uint32* output2) {
-  auto u64 = static_cast<uint64>(input);
-  *output1 = static_cast<uint32>(u64);
-  *output2 = static_cast<uint32>(u64 >> 32);
-}
-
-PHILOX_DEVICE_INLINE int64 Uint32sToInt64(uint32 input1, uint32 input2) {
-  auto u64_1 = static_cast<uint64>(input1);
-  auto u64_2 = static_cast<uint64>(input2);
-  return static_cast<int64>(u64_1 | (u64_2 << 32));
-}
-
 PHILOX_DEVICE_INLINE PhiloxRandom
 GetPhiloxRandomFromMem(StateElementType const* ptr) {
-  PhiloxRandom::ResultType counter;
-  PhiloxRandom::Key key;
-  Int64ToUint32s(ptr[0], &counter[0], &counter[1]);
-  Int64ToUint32s(ptr[1], &counter[2], &counter[3]);
-  Int64ToUint32s(ptr[2], &key[0], &key[1]);
-  return PhiloxRandom(counter, key);
+  auto ptr_ = reinterpret_cast<uint64 const*>(ptr);
+  return GetPhiloxRandomFromCounterKeyMem(ptr_, ptr_ + 2);
 }
 
 PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox,
                                                  StateElementType* ptr) {
-  PhiloxRandom::ResultType const& counter = philox.counter();
-  PhiloxRandom::Key const& key = philox.key();
-  ptr[0] = Uint32sToInt64(counter[0], counter[1]);
-  ptr[1] = Uint32sToInt64(counter[2], counter[3]);
-  ptr[2] = Uint32sToInt64(key[0], key[1]);
+  auto ptr_ = reinterpret_cast<uint64*>(ptr);
+  WriteCounterToMem(philox.counter(), ptr_);
+  WriteKeyToMem(philox.key(), ptr_ + 2);
+}
+
+PHILOX_DEVICE_INLINE PhiloxRandom SkipPhiloxRandom(PhiloxRandom const& philox,
+                                                   uint64 output_size) {
+  auto new_philox = philox;
+  // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change it
+  // just here.
+  auto delta = output_size * 256;
+  new_philox.Skip(delta);  // do the actual increasing
+  return new_philox;
 }
 
 PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox,
-                                                    int64 output_size,
+                                                    uint64 output_size,
                                                     StateElementType* ptr) {
-  auto new_philox = philox;
-  // Multiplier 256 is the same as in `FillPhiloxRandomTask`; do not change
-  // it just here.
-  auto delta = output_size * 256;
-  new_philox.Skip(delta);  // do the actual increasing
+  auto new_philox = SkipPhiloxRandom(philox, output_size);
   WritePhiloxRandomToMem(new_philox, ptr);
 }
 
+PHILOX_DEVICE_INLINE void UpdateCounterMemWithPhiloxRandom(
+    PhiloxRandom::ResultType const& counter, uint64 output_size,
+    StateElementType* ptr) {
+  auto philox = PhiloxRandom(counter, PhiloxRandom::Key() /*dummy*/);
+  auto new_philox = SkipPhiloxRandom(philox, output_size);
+  WriteCounterToMem(new_philox.counter(), reinterpret_cast<uint64*>(ptr));
+}
+
+namespace functor {
+
 // A per-device helper function that does the actual work for
 // `UpdateVariableAndFill`.
 // Reason to use functor: C++ doesn't allow function-template partial
@@ -80,6 +72,8 @@ struct UpdateVariableAndFill_Philox;
 template <typename Device>
 struct RngSkip_Philox;
 
+}  // end namespace functor
+
 using CPUDevice = Eigen::ThreadPoolDevice;
 
 struct UpdateVariableAndFill_Philox_Arg {
@@ -93,6 +87,8 @@ struct UpdateVariableAndFill_Philox_Arg {
 
 using GPUDevice = Eigen::GpuDevice;
 
+namespace functor {
+
 // Declares the partially GPU-specialized functor structs.
 // must be kept at <=6 arguments because of a gcc/clang ABI incompatibility bug
 template <typename Distribution>
@@ -104,9 +100,12 @@ struct UpdateVariableAndFill_Philox<GPUDevice, Distribution> {
 
 template <>
 struct RngSkip_Philox<GPUDevice> {
-  void operator()(const GPUDevice& device, int64 delta, Tensor* state_tensor);
+  void operator()(const GPUDevice& device, const StateElementType* in_data,
+                  uint64 delta, StateElementType* out_data);
 };
 
+}  // end namespace functor
+
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // end namespace tensorflow
@@ -31,6 +31,8 @@ __device__ int tensorflow_philox_thread_counter;
 
 namespace tensorflow {
 
+namespace functor {
+
 using random::PhiloxRandom;
 
 template <typename Distribution>
@@ -48,7 +50,8 @@ __global__ void FillKernel(
   __syncthreads();
   functor::FillPhiloxRandomKernel<Distribution,
                                   Distribution::kVariableSamplesPerOutput>()
-      .Run(*philox, output_data, output_size, dist);
+      .Run(/*key=*/nullptr, /*counter=*/nullptr, *philox, output_data,
+           output_size, dist);
   // The last thread updates the state.
   auto total_thread_count = gridDim.x * blockDim.x;
   auto old_counter_value = atomicAdd(&tensorflow_philox_thread_counter, 1);
@@ -96,16 +99,19 @@ void UpdateVariableAndFill_Philox<GPUDevice, Distribution>::operator()(
 }
 
 // Precondition: there is only 1 block and 1 thread.
-__global__ void SkipKernel(int64 delta,
-                           StateElementType* __restrict__ state_data) {
-  auto philox = GetPhiloxRandomFromMem(state_data);
-  UpdateMemWithPhiloxRandom(philox, delta, state_data);
+__global__ void SkipKernel(const StateElementType* __restrict__ in_data,
+                           uint64 delta,
+                           StateElementType* __restrict__ out_data) {
+  auto counter = GetCounterFromMem(reinterpret_cast<const uint64*>(in_data));
+  UpdateCounterMemWithPhiloxRandom(counter, delta, out_data);
 }
 
-void RngSkip_Philox<GPUDevice>::operator()(const GPUDevice& d, int64 delta,
-                                           Tensor* state_tensor) {
-  TF_CHECK_OK(GpuLaunchKernel(SkipKernel, 1, 1, 0, d.stream(), delta,
-                              state_tensor->flat<StateElementType>().data()));
+void RngSkip_Philox<GPUDevice>::operator()(const GPUDevice& d,
                                           const StateElementType* in_data,
                                           uint64 delta,
                                           StateElementType* out_data) {
+  TF_CHECK_OK(GpuLaunchKernel(SkipKernel, 1, 1, 0, d.stream(), in_data, delta,
                              out_data));
 }
 
 // Explicit instantiation of the GPU distributions functors.
@@ -154,6 +160,7 @@ template struct UpdateVariableAndFill_Philox<
                                  random::PhiloxRandom, uint64> >;
 // clang-format on
 
+}  // end namespace functor
 }  // end namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -121,8 +121,8 @@ class StatelessRandomOp : public StatelessRandomOpBase {
     auto flat = output->flat<T>();
     // Reuse the compute kernels from the stateful random ops
     functor::FillPhiloxRandom<Device, Distribution>()(
-        context, context->eigen_device<Device>(), random, flat.data(),
-        flat.size(), Distribution());
+        context, context->eigen_device<Device>(), /*key=*/nullptr,
+        /*counter=*/nullptr, random, flat.data(), flat.size(), Distribution());
   }
 };
 
@@ -158,8 +158,8 @@ class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
     auto flat = output->flat<IntType>();
     // Reuse the compute kernels from the stateful random ops
     functor::FillPhiloxRandom<Device, Distribution>()(
-        context, context->eigen_device<Device>(), random, flat.data(),
-        flat.size(), dist);
+        context, context->eigen_device<Device>(), /*key=*/nullptr,
+        /*counter=*/nullptr, random, flat.data(), flat.size(), dist);
   }
 };
 
@@ -178,8 +178,8 @@ class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
    auto flat = output->flat<IntType>();
     // Reuse the compute kernels from the stateful random ops
     functor::FillPhiloxRandom<Device, Distribution>()(
-        context, context->eigen_device<Device>(), random, flat.data(),
-        flat.size(), dist);
+        context, context->eigen_device<Device>(), /*key=*/nullptr,
+        /*counter=*/nullptr, random, flat.data(), flat.size(), dist);
  }
 };
 

tensorflow/core/kernels/stateless_random_ops_v2.cc  (new file, 330 lines)
@@ -0,0 +1,330 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stateless_random_ops_v2.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/random_poisson_op.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace {

template <typename T>
Status GetScalar(const Tensor& tensor, int input_idx, T* result) {
  auto dtype = DataTypeToEnum<T>::v();
  if (tensor.dims() != 0) {
    return errors::InvalidArgument("input ", std::to_string(input_idx),
                                   " (0-based) must have shape [], not ",
                                   tensor.shape().DebugString());
  }
  if (tensor.dtype() != dtype) {
    return errors::InvalidArgument("dtype of input ", std::to_string(input_idx),
                                   " (0-based) must be ", DataTypeString(dtype),
                                   ", not ", DataTypeString(tensor.dtype()));
  }
  *result = tensor.flat<T>()(0);
  return Status::OK();
}

class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Sanitize input
    const Tensor& shape_t = ctx->input(0);
    const Tensor& key_t = ctx->input(1);
    const Tensor& counter_t = ctx->input(2);
    const int alg_input_idx = 3;
    const Tensor& alg_t = ctx->input(alg_input_idx);

    int alg_id;
    OP_REQUIRES_OK(ctx, GetScalar(alg_t, alg_input_idx, &alg_id));
    Algorithm alg = Algorithm(alg_id);

    TensorShape shape;
    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
    OP_REQUIRES_OK(ctx,
                   CheckKeyCounterShape(alg, key_t.shape(), counter_t.shape()));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) {
      return;
    }

    // Fill in the random numbers
    Fill(ctx, alg, key_t, counter_t, output);
  }

  // The part of Compute that depends on device, type, and distribution
  virtual void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
                    const Tensor& counter, Tensor* output) = 0;
};

template <typename Device, typename Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
            const Tensor& counter, Tensor* output) override {
    typedef typename Distribution::ResultElementType T;
    auto flat = output->flat<T>();
    if (alg == RNG_ALG_PHILOX) {
      // Reuse the compute kernels from the stateful random ops
      auto key_data = key.flat<uint64>().data();
      auto counter_data = counter.flat<uint64>().data();
      functor::FillPhiloxRandom<Device, Distribution>()(
          ctx, ctx->eigen_device<Device>(), key_data, counter_data,
          random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(),
          Distribution());
    } else {
      OP_REQUIRES(ctx, false,
                  errors::InvalidArgument("Unsupported algorithm id: ", alg));
    }
  }
};

template <typename Device, typename IntType>
class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
            const Tensor& counter, Tensor* output) override {
    const Tensor& minval = ctx->input(4);
    const Tensor& maxval = ctx->input(5);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval. Note that we'll never reach this point for
    // empty output. Zero impossible things are fine.
    const auto lo = minval.scalar<IntType>()();
    const auto hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        ctx, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build distribution
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    auto flat = output->flat<IntType>();
    if (alg == RNG_ALG_PHILOX) {
      // Reuse the compute kernels from the stateful random ops
      auto key_data = key.flat<uint64>().data();
      auto counter_data = counter.flat<uint64>().data();
      functor::FillPhiloxRandom<Device, Distribution>()(
          ctx, ctx->eigen_device<Device>(), key_data, counter_data,
          random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
    } else {
      OP_REQUIRES(ctx, false,
                  errors::InvalidArgument("Unsupported algorithm id: ", alg));
    }
  }
};

template <typename Device, typename IntType>
class StatelessRandomUniformFullIntOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key,
            const Tensor& counter, Tensor* output) override {
    // Build distribution
    typedef random::UniformFullIntDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist;

    auto flat = output->flat<IntType>();
    if (alg == RNG_ALG_PHILOX) {
      // Reuse the compute kernels from the stateful random ops
      auto key_data = key.flat<uint64>().data();
      auto counter_data = counter.flat<uint64>().data();
      functor::FillPhiloxRandom<Device, Distribution>()(
          ctx, ctx->eigen_device<Device>(), key_data, counter_data,
          random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist);
    } else {
      OP_REQUIRES(ctx, false,
                  errors::InvalidArgument("Unsupported algorithm id: ", alg));
    }
  }
};

class GetKeyCounterAlgOp : public OpKernel {
 public:
  explicit GetKeyCounterAlgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& seed_t = ctx->input(0);
    OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));
    // Allocate outputs
    Tensor* key_output;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({RNG_KEY_SIZE}), &key_output));
    Tensor* counter_output;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, TensorShape({RNG_MAX_COUNTER_SIZE}),
                                        &counter_output));
    Tensor* alg_output;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({}), &alg_output));

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));
    WriteKeyToMem(key, key_output->flat<uint64>().data());
    WriteCounterToMem(counter, counter_output->flat<uint64>().data());
    alg_output->flat<int>()(0) = RNG_ALG_PHILOX;
  }
};

#define REGISTER(DEVICE, TYPE)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomUniformV2")                                      \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("alg")                                                \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::UniformDistribution<        \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessRandomNormalV2")                                       \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("alg")                                                \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<DEVICE##Device, random::NormalDistribution<         \
                                            random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessTruncatedNormalV2")                                    \
          .Device(DEVICE_##DEVICE)                                          \
          .HostMemory("shape")                                              \
          .HostMemory("alg")                                                \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessRandomOp<                                                    \
          DEVICE##Device,                                                   \
          random::TruncatedNormalDistribution<                              \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)

#define REGISTER_FULL_INT(DEVICE, TYPE)           \
  REGISTER_KERNEL_BUILDER(                        \
      Name("StatelessRandomUniformFullIntV2")     \
          .Device(DEVICE_##DEVICE)                \
          .HostMemory("shape")                    \
          .HostMemory("alg")                      \
          .TypeConstraint<TYPE>("dtype"),         \
      StatelessRandomUniformFullIntOp<DEVICE##Device, TYPE>)

#define REGISTER_INT(DEVICE, TYPE)                            \
  REGISTER_FULL_INT(DEVICE, TYPE);                            \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformIntV2") \
                              .Device(DEVICE_##DEVICE)        \
                              .HostMemory("shape")            \
                              .HostMemory("alg")              \
                              .HostMemory("minval")           \
                              .HostMemory("maxval")           \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomUniformIntOp<DEVICE##Device, TYPE>)

#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
|
||||||
|
#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
|
||||||
|
#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
|
||||||
|
#define REGISTER_FULL_INT_CPU(TYPE) REGISTER_FULL_INT(CPU, TYPE)
|
||||||
|
#define REGISTER_FULL_INT_GPU(TYPE) REGISTER_FULL_INT(GPU, TYPE)
|
||||||
|
|
||||||
|
TF_CALL_half(REGISTER_CPU);
|
||||||
|
TF_CALL_bfloat16(REGISTER_CPU);
|
||||||
|
TF_CALL_float(REGISTER_CPU);
|
||||||
|
TF_CALL_double(REGISTER_CPU);
|
||||||
|
TF_CALL_int32(REGISTER_INT_CPU);
|
||||||
|
TF_CALL_int64(REGISTER_INT_CPU);
|
||||||
|
TF_CALL_uint32(REGISTER_FULL_INT_CPU);
|
||||||
|
TF_CALL_uint64(REGISTER_FULL_INT_CPU);
|
||||||
|
|
||||||
|
#define REGISTER_GET_KCA(DEVICE) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("StatelessRandomGetKeyCounterAlg") \
|
||||||
|
.Device(DEVICE_##DEVICE) \
|
||||||
|
.HostMemory("seed") \
|
||||||
|
.HostMemory("key") \
|
||||||
|
.HostMemory("counter") \
|
||||||
|
.HostMemory("alg"), \
|
||||||
|
GetKeyCounterAlgOp)
|
||||||
|
|
||||||
|
REGISTER_GET_KCA(CPU);
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
TF_CALL_half(REGISTER_GPU);
|
||||||
|
TF_CALL_float(REGISTER_GPU);
|
||||||
|
TF_CALL_double(REGISTER_GPU);
|
||||||
|
TF_CALL_int32(REGISTER_INT_GPU);
|
||||||
|
TF_CALL_int64(REGISTER_INT_GPU);
|
||||||
|
TF_CALL_uint32(REGISTER_FULL_INT_GPU);
|
||||||
|
TF_CALL_uint64(REGISTER_FULL_INT_GPU);
|
||||||
|
|
||||||
|
REGISTER_GET_KCA(GPU);
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
#undef REGISTER_INT
|
||||||
|
#undef REGISTER_CPU
|
||||||
|
#undef REGISTER_GPU
|
||||||
|
#undef REGISTER_INT_CPU
|
||||||
|
#undef REGISTER_INT_GPU
|
||||||
|
#undef REGISTER_FULL_INT_CPU
|
||||||
|
#undef REGISTER_FULL_INT_GPU
|
||||||
|
|
||||||
|
#undef REGISTER_GET_KCA
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace tensorflow
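
Editor's note: the intended dataflow through the kernels above is easier to see end to end from Python. The sketch below is illustrative only and is not part of the commit; the wrapper name stateless_random_get_key_counter_alg is an assumption based on the usual op-name-to-snake_case rule for generated wrappers, and the seed and shape values are made up. StatelessRandomGetKeyCounterAlg does the device-dependent work of turning a [2] seed into a key, a counter and an algorithm id, and the *V2 sampling kernels then consume those explicitly.

# Illustrative sketch only (not part of this commit); wrapper names are assumed.
import tensorflow as tf
from tensorflow.python.ops import gen_stateless_random_ops_v2 as rng_v2

seed = tf.constant([1, 2], dtype=tf.int64)
# Device-dependent step: derive key, counter and algorithm id from the seed.
key, counter, alg = rng_v2.stateless_random_get_key_counter_alg(seed)
# Device-independent step: sample using the explicit key/counter/alg.
sample = rng_v2.stateless_random_uniform_v2(
    shape=[2, 3], key=key, counter=counter, alg=alg, dtype=tf.float32)
print(sample.shape)  # (2, 3)
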

tensorflow/core/kernels/stateless_random_ops_v2.h (new file, 46 lines)
@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_
#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_

#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

inline Status CheckKeyCounterShape(Algorithm const& alg,
                                   TensorShape const& key_shape,
                                   TensorShape const& counter_shape) {
  if (!(key_shape.dims() == 1 && key_shape.dim_size(0) == RNG_KEY_SIZE)) {
    return errors::InvalidArgument(
        "key must have shape [", RNG_KEY_SIZE, "], not ",
        key_shape.DebugString(),
        ". (Note that batched keys are not supported yet.)");
  }
  auto counter_size = GetCounterSize(alg);
  if (!(counter_shape.dims() == 1 &&
        counter_shape.dim_size(0) >= counter_size)) {
    return errors::InvalidArgument(
        "counter must be a vector with length at least ", counter_size,
        "; got shape: ", counter_shape.DebugString(),
        ". (Note that batched counters are not supported yet.)");
  }
  return Status::OK();
}

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_

@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
@ -90,6 +91,19 @@ REGISTER_OP("RngSkip")
      return Status::OK();
    });

+REGISTER_OP("RngReadAndSkip")
+    .Input("resource: resource")
+    .Input("alg: int32")
+    .Input("delta: uint64")
+    .Output("value: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      c->set_output(0, c->MakeShape({RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE}));
+      return Status::OK();
+    });
+
REGISTER_OP("NonDeterministicInts")
    .Input("shape: shape_dtype")
    .SetIsStateful()

tensorflow/core/ops/stateless_random_ops_v2.cc (new file, 119 lines)
@ -0,0 +1,119 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rng_alg.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

static Status StatelessShapeV2(InferenceContext* c) {
  // Check key and counter shapes
  ShapeHandle key;
  ShapeHandle counter;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &key));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &counter));
  shape_inference::ShapeHandle unused_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(key, 0), RNG_KEY_SIZE, &unused));

  // Set output shape
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
  c->set_output(0, out);
  return Status::OK();
}

#define REGISTER_STATELESS_OP(name)                             \
  REGISTER_OP(name)                                             \
      .Input("shape: Tshape")                                   \
      .Input("key: uint64")                                     \
      .Input("counter: uint64")                                 \
      .Input("alg: int32")                                      \
      .Output("output: dtype")                                  \
      .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT")   \
      .Attr("Tshape: {int32, int64} = DT_INT32")                \
      .SetShapeFn(StatelessShapeV2)

REGISTER_STATELESS_OP("StatelessRandomUniformV2");
REGISTER_STATELESS_OP("StatelessRandomNormalV2");
REGISTER_STATELESS_OP("StatelessTruncatedNormalV2");

#undef REGISTER_STATELESS_OP

REGISTER_OP("StatelessRandomUniformIntV2")
    .Input("shape: Tshape")
    .Input("key: uint64")
    .Input("counter: uint64")
    .Input("alg: int32")
    .Input("minval: dtype")
    .Input("maxval: dtype")
    .Output("output: dtype")
    .Attr("dtype: {int32, int64, uint32, uint64}")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      Status s = c->WithRank(c->input(4), 0, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "minval must be a scalar; got a tensor of shape ",
            c->DebugString(c->input(4)));
      }
      s = c->WithRank(c->input(5), 0, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "maxval must be a scalar; got a tensor of shape ",
            c->DebugString(c->input(5)));
      }
      return StatelessShapeV2(c);
    });

REGISTER_OP("StatelessRandomUniformFullIntV2")
    .Input("shape: Tshape")
    .Input("key: uint64")
    .Input("counter: uint64")
    .Input("alg: int32")
    .Output("output: dtype")
    .Attr("dtype: {int32, int64, uint32, uint64} = DT_UINT64")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn(StatelessShapeV2);

REGISTER_OP("StatelessRandomGetKeyCounterAlg")
    .Input("seed: Tseed")
    .Output("key: uint64")
    .Output("counter: uint64")
    .Output("alg: int32")
    .Attr("Tseed: {int32, int64} = DT_INT64")
    .SetIsStateful()  // because outputs depend on device
    .SetShapeFn([](InferenceContext* c) {
      // Check seed shape
      ShapeHandle seed;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &seed));
      DimensionHandle unused;
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));

      // Set output shapes
      c->set_output(0, c->MakeShape({RNG_KEY_SIZE}));
      c->set_output(1, c->MakeShape({RNG_MAX_COUNTER_SIZE}));
      c->set_output(2, c->MakeShape({}));
      return Status::OK();
    });

}  // namespace tensorflow

@ -3183,6 +3183,10 @@ tf_gen_op_wrapper_private_py(
    ],
)

+tf_gen_op_wrapper_private_py(
+    name = "stateless_random_ops_v2_gen",
+)
+
tf_gen_op_wrapper_private_py(
    name = "list_ops_gen",
)
@ -4440,6 +4444,7 @@ py_library(
        ":framework_ops",
        ":math_ops",
        ":stateful_random_ops_gen",
+        ":stateless_random_ops_v2_gen",
        ":variables",
        "//third_party/py/numpy",
    ],
@ -4474,6 +4479,7 @@ py_library(
        ":math_ops",
        ":random_ops",
        ":stateless_random_ops_gen",
+        ":stateless_random_ops_v2_gen",
    ],
)
@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) {

absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
    const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 352> a = {{
+  static std::array<OpIndexInfo, 357> a = {{
      {"Acosh"},
      {"AllToAll", 1, {0}},
      {"ApproximateEqual"},
@ -332,11 +332,16 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
      {"StatelessRandomBinomial"},
      {"StatelessRandomGammaV2", 1, {1}},
      {"StatelessRandomNormal"},
+      {"StatelessRandomNormalV2"},
      {"StatelessRandomPoisson"},
      {"StatelessRandomUniform"},
      {"StatelessRandomUniformFullInt"},
+      {"StatelessRandomUniformFullIntV2"},
      {"StatelessRandomUniformInt"},
+      {"StatelessRandomUniformIntV2"},
+      {"StatelessRandomUniformV2"},
      {"StatelessTruncatedNormal"},
+      {"StatelessTruncatedNormalV2"},
      {"StopGradient"},
      {"StridedSliceGrad", 2, {0, 4}},
      {"StringSplit"},
@ -415,7 +420,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(

absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
    const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 468> a = {{
+  static std::array<OpIndexInfo, 473> a = {{
      {"Abs"},
      {"AccumulateNV2"},
      {"Acos"},
@ -793,11 +798,16 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
      {"StatelessMultinomial"},
      {"StatelessRandomBinomial"},
      {"StatelessRandomNormal"},
+      {"StatelessRandomNormalV2"},
      {"StatelessRandomPoisson"},
      {"StatelessRandomUniform"},
      {"StatelessRandomUniformFullInt"},
+      {"StatelessRandomUniformFullIntV2"},
      {"StatelessRandomUniformInt"},
+      {"StatelessRandomUniformIntV2"},
+      {"StatelessRandomUniformV2"},
      {"StatelessTruncatedNormal"},
+      {"StatelessTruncatedNormalV2"},
      {"StopGradient"},
      {"StridedSlice"},
      {"StridedSliceGrad"},
@ -30,6 +30,7 @@ from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import gen_stateful_random_ops
+from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import image_ops_impl as image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@ -1237,11 +1238,14 @@ class RandomHeightTest(keras_parameterized.TestCase):
    mock_factor = 0
    with test.mock.patch.object(
        gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
-      with testing_utils.use_gpu():
-        img = np.random.random((12, 5, 8, 3))
-        layer = image_preprocessing.RandomHeight(.4)
-        img_out = layer(img, training=True)
-        self.assertEqual(img_out.shape[1], 3)
+      with test.mock.patch.object(
+          gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
+          return_value=mock_factor):
+        with testing_utils.use_gpu():
+          img = np.random.random((12, 5, 8, 3))
+          layer = image_preprocessing.RandomHeight(.4)
+          img_out = layer(img, training=True)
+          self.assertEqual(img_out.shape[1], 3)

  def test_random_height_longer_numeric(self):
    for dtype in (np.int64, np.float32):
@ -1328,11 +1332,14 @@ class RandomWidthTest(keras_parameterized.TestCase):
    mock_factor = 0
    with test.mock.patch.object(
        gen_stateful_random_ops, 'stateful_uniform', return_value=mock_factor):
-      with testing_utils.use_gpu():
-        img = np.random.random((12, 8, 5, 3))
-        layer = image_preprocessing.RandomWidth(.4)
-        img_out = layer(img, training=True)
-        self.assertEqual(img_out.shape[2], 3)
+      with test.mock.patch.object(
+          gen_stateless_random_ops_v2, 'stateless_random_uniform_v2',
+          return_value=mock_factor):
+        with testing_utils.use_gpu():
+          img = np.random.random((12, 8, 5, 3))
+          layer = image_preprocessing.RandomWidth(.4)
+          img_out = layer(img, training=True)
+          self.assertEqual(img_out.shape[2], 3)

  def test_random_width_longer_numeric(self):
    for dtype in (np.int64, np.float32):
@ -22,10 +22,13 @@ import functools

from absl.testing import parameterized
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
+from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@ -35,6 +38,31 @@ from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.platform import test


+# Note that in theory each test will reset the eager context and may choose to
+# hide some devices, so we shouldn't cache this transient info. Tests in this
+# file don't make those config changes, so caching is fine. It provides a good
+# speed-up.
+_cached_device = None
+
+
+def get_device():
+  global _cached_device
+  if _cached_device is not None:
+    return _cached_device
+  # Precedence from high to low
+  for device_type in ('XLA_GPU', 'GPU', 'XLA_CPU', 'CPU'):
+    devices = config.list_logical_devices(device_type)
+    if devices:
+      _cached_device = devices[0]
+      return _cached_device
+  raise ValueError('Cannot find any suitable device. Available devices: %s' %
+                   config.list_logical_devices())
+
+
+BEFORE_EXPIRE = (2020, 10, 24)
+AFTER_EXPIRE = (2020, 10, 26)
+
+
def invert_philox(key, value):
  """Invert the Philox bijection."""
  key = np.array(key, dtype=np.uint32)
@ -59,47 +87,71 @@ SEED_TYPES = [dtypes.int32, dtypes.int64]
def float_cases(shape_dtypes=(None,)):
  cases = (
      # Uniform distribution, with and without range
-      (stateless.stateless_random_uniform, random_ops.random_uniform, {}),
-      (stateless.stateless_random_uniform, random_ops.random_uniform,
-       dict(minval=2.2, maxval=7.1)),
+      ('uniform', stateless.stateless_random_uniform, random_ops.random_uniform,
+       {}),
+      ('uniform2', stateless.stateless_random_uniform,
+       random_ops.random_uniform, dict(minval=2.2, maxval=7.1)),
      # Normal distribution, with and without mean+stddev
-      (stateless.stateless_random_normal, random_ops.random_normal, {}),
-      (stateless.stateless_random_normal, random_ops.random_normal,
+      ('normal', stateless.stateless_random_normal, random_ops.random_normal,
+       {}),
+      ('normal2', stateless.stateless_random_normal, random_ops.random_normal,
       dict(mean=2, stddev=3)),
      # Truncated normal distribution, with and without mean+stddev
-      (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
-      (stateless.stateless_truncated_normal, random_ops.truncated_normal,
-       dict(mean=3, stddev=4)),
+      ('trnorm', stateless.stateless_truncated_normal,
+       random_ops.truncated_normal, {}),
+      ('trnorm2', stateless.stateless_truncated_normal,
+       random_ops.truncated_normal, dict(mean=3, stddev=4)),
  )
  # Explicitly passing in params because capturing cell variable from loop is
  # problematic in Python
  def wrap(op, dtype, shape, shape_dtype, kwds, seed):
+    device_type = get_device().device_type
+    # Some dtypes are not supported on some devices
+    if (dtype == dtypes.float16 and device_type in ('XLA_GPU', 'XLA_CPU') or
+        dtype == dtypes.bfloat16 and device_type == 'GPU'):
+      dtype = dtypes.float32
    shape_ = (constant_op.constant(shape, dtype=shape_dtype)
              if shape_dtype is not None else shape)
    return op(seed=seed, shape=shape_, dtype=dtype, **kwds)
-  for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
+
+  def _name(a):
+    if hasattr(a, 'name'):
+      return a.name
+    else:
+      return a
+
+  for dtype in dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64:
    for shape_dtype in shape_dtypes:
      for shape in (), (3,), (2, 5):
-        for stateless_op, stateful_op, kwds in cases:
-          yield (functools.partial(wrap, stateless_op, dtype, shape,
+        for name, stateless_op, stateful_op, kwds in cases:
+          yield (('%s_%s_%s_%s' %
+                  (name, _name(dtype), shape, _name(shape_dtype))).replace(
+                      ' ', ''),
+                 functools.partial(wrap, stateless_op, dtype, shape,
                                   shape_dtype, kwds),
-                 functools.partial(wrap, stateful_op, dtype, shape,
-                                   shape_dtype, kwds))
+                 functools.partial(wrap, stateful_op, dtype, shape, shape_dtype,
+                                   kwds))


-def int_cases(shape_dtypes=(None,)):
-  def wrap(op, shape, shape_dtype, dtype, seed):
+def int_cases(shape_dtypes=(None,), minval_maxval=None):
+
+  def wrap(op, minval, maxval, shape, shape_dtype, dtype, seed):
    shape_ = (constant_op.constant(shape, dtype=shape_dtype)
              if shape_dtype is not None else shape)
-    return op(seed=seed, shape=shape_, minval=2, maxval=11111,
-              dtype=dtype)
-  for shape_dtype in shape_dtypes:
-    for shape in (), (3,), (2, 5):
-      for dtype in dtypes.int32, dtypes.int64:
-        yield (functools.partial(wrap, stateless.stateless_random_uniform,
-                                 shape, shape_dtype, dtype),
-               functools.partial(wrap, random_ops.random_uniform,
-                                 shape, shape_dtype, dtype))
+    return op(
+        seed=seed, shape=shape_, minval=minval, maxval=maxval, dtype=dtype)
+
+  if minval_maxval is None:
+    minval_maxval = ((2, 11111),)
+  for minval, maxval in minval_maxval:
+    for shape_dtype in shape_dtypes:
+      for shape in (), (3,), (2, 5):
+        for dtype in dtypes.int32, dtypes.int64:
+          yield ('uniform_%s_%s' % (minval, maxval),
+                 functools.partial(wrap, stateless.stateless_random_uniform,
+                                   minval, maxval, shape, shape_dtype, dtype),
+                 functools.partial(wrap, random_ops.random_uniform, minval,
+                                   maxval, shape, shape_dtype, dtype))
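
Editor's note: the "capturing cell variable from loop is problematic in Python" comment in float_cases above refers to Python's late-binding closures, which is why every parameter is bound through functools.partial rather than a plain closure. A standalone illustration with toy values (not part of the commit):

# Illustrative only: each lambda closes over the same loop variable, so all of
# them observe its final value once the loop has finished.
import functools

late_bound = [lambda: i for i in range(3)]
print([f() for f in late_bound])   # [2, 2, 2]

# Binding the current value explicitly (as the test does with
# functools.partial) freezes the argument at definition time.
bound = [functools.partial(lambda x: x, i) for i in range(3)]
print([f() for f in bound])        # [0, 1, 2]
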
@ -112,7 +164,8 @@ def multinomial_cases():
    for output_dtype in dtypes.int32, dtypes.int64:
      for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
                     [0.25, 0.75]]):
-        yield (functools.partial(wrap, stateless.stateless_multinomial, logits,
+        yield ('multinomial',
+               functools.partial(wrap, stateless.stateless_multinomial, logits,
                                 logits_dtype, output_dtype),
               functools.partial(wrap, random_ops.multinomial, logits,
                                 logits_dtype, output_dtype))
@ -124,10 +177,11 @@ def gamma_cases():
        alpha=constant_op.constant(alpha, dtype=dtype), dtype=dtype)
  for dtype in np.float16, np.float32, np.float64:
    for alpha in ([[.5, 1., 2.]], [[0.5, 0.5], [0.8, 0.2], [0.25, 0.75]]):
-      yield (functools.partial(wrap, stateless.stateless_random_gamma, alpha,
-                               dtype, (10,) + tuple(np.shape(alpha))),
-             functools.partial(wrap, random_ops.random_gamma, alpha,
-                               dtype, (10,)))
+      yield ('gamma',
+             functools.partial(wrap, stateless.stateless_random_gamma, alpha,
+                               dtype, (10,) + tuple(np.shape(alpha))),
+             functools.partial(wrap, random_ops.random_gamma, alpha, dtype,
+                               (10,)))
@ -138,7 +192,8 @@ def poisson_cases():
  for lam_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
    for out_dtype in np.float16, np.float32, np.float64, np.int32, np.int64:
      for lam in ([[5.5, 1., 2.]], [[7.5, 10.5], [3.8, 8.2], [1.25, 9.75]]):
-        yield (functools.partial(wrap, stateless.stateless_random_poisson, lam,
+        yield ('poisson',
+               functools.partial(wrap, stateless.stateless_random_poisson, lam,
                                 lam_dtype, out_dtype,
                                 (10,) + tuple(np.shape(lam))),
               functools.partial(wrap, random_ops.random_poisson, lam,
@ -153,22 +208,28 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
    key = 0x3ec8f720, 0x02461e29
    preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
    preseed = preseed[::2] | preseed[1::2] << 32
-    random_seed.set_random_seed(seed[0])
-    with test_util.use_gpu():
-      stateless_op, stateful_op = case
-      if context.executing_eagerly():
-        # Call set_random_seed in order to clear kernel cache, to prevent
-        # kernel reusing for the stateful op
-        random_seed.set_random_seed(seed[0])
+    with ops.device(get_device().name):
+      _, stateless_op, stateful_op = case
+      random_seed.set_random_seed(seed[0])
      stateful = stateful_op(seed=seed[1])
      pure = stateless_op(seed=preseed)
      self.assertAllEqual(stateful, pure)

+  def _test_old_and_new_stateless_match(self, case, seed):
+    """Tests that the new stateless ops match the old stateless ones."""
+    with ops.device(get_device().name):
+      _, stateless_op, _ = case
+      with compat.forward_compatibility_horizon(*BEFORE_EXPIRE):
+        old = stateless_op(seed=seed)
+      with compat.forward_compatibility_horizon(*AFTER_EXPIRE):
+        new = stateless_op(seed=seed)
+      self.assertAllClose(old, new)
+
  def _test_determinism(self, case, seed_type):
    # Stateless values should be equal iff the seeds are equal (roughly)
    seeds = [(x, y) for x in range(5) for y in range(5)] * 3  # pylint: disable=g-complex-comprehension
-    with self.test_session(use_gpu=True), test_util.use_gpu():
-      stateless_op, _ = case
+    with self.test_session(use_gpu=True), ops.device(get_device().name):
+      _, stateless_op, _ = case
      if context.executing_eagerly():
        values = [
            (seed, stateless_op(seed=constant_op.constant(seed, seed_type)))
@ -183,88 +244,172 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):
        ]
      for s0, v0 in values:
        for s1, v1 in values:
-          self.assertEqual(s0 == s1, np.all(v0 == v1))
+          if dtypes.as_dtype(v0.dtype) != dtypes.bfloat16:
+            self.assertEqual(s0 == s1, np.all(v0 == v1))
+          elif s0 == s1:
+            # Skip the s0 != s1 case because v0 and v1 can be either equal or
+            # unequal in that case due to bfloat16's low precision
+            self.assertAllEqual(v0, v1)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(float_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchFloat(self, case, seed):
+    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Skip on XLA because XLA kernels do not support int64 '
+                    'seeds needed by this test.')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(int_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchInt(self, case, seed):
+    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Skip on XLA because XLA kernels do not support int64 '
+                    'seeds needed by this test.')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(multinomial_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchMultinomial(self, case, seed):
+    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchGamma(self, case, seed):
+    if get_device().device_type == 'GPU':
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking GPU kernel')
+    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
      for seed_id, seed in enumerate(SEEDS)
      for case_id, case in enumerate(poisson_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testMatchPoisson(self, case, seed):
+    if get_device().device_type == 'GPU':
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking GPU kernel')
+    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking XLA kernel')
    self._test_match(case, seed)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for type_id, seed_type in enumerate(SEED_TYPES)
-      for case_id, case in enumerate(float_cases(
-          shape_dtypes=(dtypes.int32, dtypes.int64))))
+      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      for seed_id, seed in enumerate(SEEDS)
+      for case_id, case in enumerate(float_cases()))
+  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
+  def testOldAndNewStatelessMatchFloat(self, case, seed):
+    self._test_old_and_new_stateless_match(case, seed)
+
+  @parameterized.named_parameters(
+      ('_%s_%s_%s' % (case[0], case_id, seed_id), case, seed)  # pylint: disable=g-complex-comprehension
+      for seed_id, seed in enumerate(SEEDS)
+      for case_id, case in enumerate(
+          int_cases(minval_maxval=((2, 11111), (None, None)))))
+  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
+  def testOldAndNewStatelessMatchInt(self, case, seed):
+    self._test_old_and_new_stateless_match(case, seed)
+
+  @parameterized.named_parameters(
+      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for seed_type in SEED_TYPES
+      for case_id, case in enumerate(
+          float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismFloat(self, case, seed_type):
+    if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
+                                                                  'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest(
+          'Skip on XLA because XLA kernels do not support int64 seeds.')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for type_id, seed_type in enumerate(SEED_TYPES)
-      for case_id, case in enumerate(int_cases(
-          shape_dtypes=(dtypes.int32, dtypes.int64))))
+      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for seed_type in SEED_TYPES
+      for case_id, case in enumerate(
+          int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismInt(self, case, seed_type):
+    if seed_type == dtypes.int64 and get_device().device_type in ('XLA_GPU',
+                                                                  'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest(
+          'Skip on XLA because XLA kernels do not support int64 seeds.')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for type_id, seed_type in enumerate(SEED_TYPES)
+      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for seed_type in SEED_TYPES
      for case_id, case in enumerate(multinomial_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismMultinomial(self, case, seed_type):
+    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for type_id, seed_type in enumerate(SEED_TYPES)
+      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for seed_type in SEED_TYPES
      for case_id, case in enumerate(gamma_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismGamma(self, case, seed_type):
+    if get_device().device_type == 'GPU':
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking GPU kernel')
+    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  @parameterized.named_parameters(
-      ('_%s_%s' % (case_id, type_id), case, seed_type)  # pylint: disable=g-complex-comprehension
-      for type_id, seed_type in enumerate(SEED_TYPES)
+      ('_%s_%s_%s' % (case[0], seed_type.name, case_id), case, seed_type)  # pylint: disable=g-complex-comprehension
+      for seed_type in SEED_TYPES
      for case_id, case in enumerate(poisson_cases()))
  @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396')
  def testDeterminismPoisson(self, case, seed_type):
+    if get_device().device_type == 'GPU':
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking GPU kernel')
+    if get_device().device_type in ('XLA_GPU', 'XLA_CPU'):
+      # This test was passing before because soft placement silently picked the
+      # CPU kernels.
+      self.skipTest('Lacking XLA kernel')
    self._test_determinism(case, seed_type)

  def assertDTypeEqual(self, a, b):
@ -327,4 +472,6 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase):


if __name__ == '__main__':
+  config.set_soft_device_placement(False)
+  context.context().enable_xla_devices()
  test.main()
@ -23,6 +23,7 @@ import enum  # pylint: disable=g-bad-import-order
import numpy as np
import six

+from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
@ -32,12 +33,14 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateful_random_ops
+from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util.tf_export import tf_export


# A seed for random ops (stateful and stateless) will always be 1024
# bits, all of which will be sent to the C++ code. The actual C++
# implementation of some algorithms may only use a lower part of the bits.
@ -136,6 +139,15 @@ def _make_1d_state(state_size, seed):
  return seed


+def _get_counter_size(alg):
+  if alg == RNG_ALG_PHILOX:
+    return 2
+  elif alg == RNG_ALG_THREEFRY:
+    return 1
+  else:
+    raise ValueError("Unsupported algorithm id: %s" % alg)
+
+
def _get_state_size(alg):
  if alg == RNG_ALG_PHILOX:
    return PHILOX_STATE_SIZE
@ -560,6 +572,10 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    return self._alg

  def _standard_normal(self, shape, dtype):
+    if compat.forward_compatible(2020, 10, 25):
+      key, counter = self._prepare_key_counter(shape)
+      return gen_stateless_random_ops_v2.stateless_random_normal_v2(
+          shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
    return gen_stateful_random_ops.stateful_standard_normal_v2(
        self.state.handle, self.algorithm, shape, dtype=dtype)

@ -586,6 +602,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    else:
      raise ValueError("Unsupported algorithm id: %s" % alg)

+  # TODO(wangpeng): Add "Returns" section to docstring once new version kicks in
+  # pylint: disable=g-doc-return-or-yield
  def skip(self, delta):
    """Advance the counter of a counter-based RNG.

@ -595,7 +613,24 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
    (or any other distribution). The actual increment added to the
    counter is an unspecified implementation detail.
    """
-    gen_stateful_random_ops.rng_skip(self.state.handle, self.algorithm, delta)
+    if compat.forward_compatible(2020, 10, 25):
+      return gen_stateful_random_ops.rng_read_and_skip(
+          self.state.handle,
+          alg=math_ops.cast(self.algorithm, dtypes.int32),
+          delta=math_ops.cast(delta, dtypes.uint64))
+    gen_stateful_random_ops.rng_skip(
+        self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
+        math_ops.cast(delta, dtypes.int64))
+  # pylint: enable=g-doc-return-or-yield
+
+  def _prepare_key_counter(self, shape):
+    delta = math_ops.reduce_prod(shape)
+    counter_key = self.skip(delta)
+    counter_size = _get_counter_size(self.algorithm)
+    counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
+    key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
+                            dtypes.uint64)
+    return key, counter

  # The following functions return a tensor and as a side effect update
  # self._state_var.
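
Editor's note: under the new code path, skip(delta) returns the raw RNG state via rng_read_and_skip, and _prepare_key_counter above slices off the first _get_counter_size(alg) words as the counter and the next word as the key, bitcasting both to uint64. A rough numpy analogue of that split, assuming the Philox layout used elsewhere in this file (two counter words followed by one key word) and made-up values; the stateful test further below checks that skip(delta) advances the stored Philox counter by delta * 256.

# Illustrative only (not part of the commit); values are made up.
import numpy as np

# For Philox, the returned int64 state reads [counter_lo, counter_hi, key].
counter_key = np.array([7, 0, 0x1234], dtype=np.int64)
counter_size = 2  # what _get_counter_size(RNG_ALG_PHILOX) would return
# numpy's view() plays the role of array_ops.bitcast here.
counter = counter_key[:counter_size].view(np.uint64)
key = counter_key[counter_size:counter_size + 1].view(np.uint64)
print(counter, key)  # [7 0] [4660]
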
@ -624,6 +659,14 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
return math_ops.add(rnd * stddev, mean, name=name)
|
return math_ops.add(rnd * stddev, mean, name=name)
|
||||||
|
|
||||||
def _truncated_normal(self, shape, dtype):
|
def _truncated_normal(self, shape, dtype):
|
||||||
|
if compat.forward_compatible(2020, 10, 25):
|
||||||
|
key, counter = self._prepare_key_counter(shape)
|
||||||
|
return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
|
||||||
|
shape=shape,
|
||||||
|
key=key,
|
||||||
|
counter=counter,
|
||||||
|
dtype=dtype,
|
||||||
|
alg=self.algorithm)
|
||||||
return gen_stateful_random_ops.stateful_truncated_normal(
|
return gen_stateful_random_ops.stateful_truncated_normal(
|
||||||
self.state.handle, self.algorithm, shape, dtype=dtype)
|
self.state.handle, self.algorithm, shape, dtype=dtype)
|
||||||
|
|
||||||
@@ -662,10 +705,27 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
     return math_ops.add(mul, mean_tensor, name=name)
 
   def _uniform(self, shape, dtype):
+    if compat.forward_compatible(2020, 10, 25):
+      key, counter = self._prepare_key_counter(shape)
+      return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
+          shape=shape,
+          key=key,
+          counter=counter,
+          dtype=dtype,
+          alg=self.algorithm)
     return gen_stateful_random_ops.stateful_uniform(
         self.state.handle, self.algorithm, shape=shape, dtype=dtype)
 
   def _uniform_full_int(self, shape, dtype, name=None):
+    if compat.forward_compatible(2020, 10, 25):
+      key, counter = self._prepare_key_counter(shape)
+      return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
+          shape=shape,
+          key=key,
+          counter=counter,
+          dtype=dtype,
+          alg=self.algorithm,
+          name=name)
     return gen_stateful_random_ops.stateful_uniform_full_int(
         self.state.handle, self.algorithm, shape=shape,
         dtype=dtype, name=name)
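Every sampling method follows the same gate: once `compat.forward_compatible(2020, 10, 25)` returns true, the stateless V2 kernel is emitted; until then the legacy stateful kernel is kept so that graphs stay loadable by older binaries. A simplified, self-contained sketch of that dispatch idea (the date check below is only a stand-in for `compat.forward_compatible`, not TensorFlow's implementation):

import datetime

def forward_compatible(year, month, day):
  # Stand-in for tensorflow.python.compat.compat.forward_compatible: roughly,
  # "has the compatibility window for this change already expired?"
  return datetime.date.today() > datetime.date(year, month, day)

def sample_uniform(shape):
  # Dispatch between the new and the legacy code path, mirroring the pattern
  # used by the Generator methods above.
  if forward_compatible(2020, 10, 25):
    return ("v2 stateless kernel", shape)
  return ("legacy stateful kernel", shape)

print(sample_uniform([2, 3]))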
@@ -729,6 +789,16 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
     minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
     maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
     if dtype.is_integer:
+      if compat.forward_compatible(2020, 10, 25):
+        key, counter = self._prepare_key_counter(shape)
+        return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
+            shape=shape,
+            key=key,
+            counter=counter,
+            minval=minval,
+            maxval=maxval,
+            alg=self.algorithm,
+            name=name)
       return gen_stateful_random_ops.stateful_uniform_int(
           self.state.handle, self.algorithm, shape=shape,
           minval=minval, maxval=maxval, name=name)
@@ -274,7 +274,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
     with self.cached_session() as sess:
       gen1 = random.Generator.from_seed(seed)
       gen2 = random.Generator.from_non_deterministic_state()
-      sess.run((gen1._state_var.initializer, gen2._state_var.initializer))
+      sess.run((gen1.state.initializer, gen2.state.initializer))
       r1 = gen1.normal(shape, dtype=dtypes.float32)
       r2 = gen2.normal(shape, dtype=dtypes.float32)
       def f():
@@ -372,7 +372,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
     gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
     delta = 432
     gen.skip(delta)
-    new_counter = gen._state_var[0]
+    new_counter = gen.state[0]
     self.assertAllEqual(counter + delta * 256, new_counter)
 
   def _sameAsOldRandomOps(self, device, floats):
@@ -394,7 +394,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
         with ops.device(device):
           return new(dtype, gen)
 
-      for _ in range(100):
+      for _ in range(5):
         self.assertAllEqual(run_old(), run_new())
 
     shape = constant_op.constant([4, 7])
@@ -582,6 +582,11 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
   @test_util.run_v2_only
   def testGetGlobalGeneratorWithXla(self):
     """Demonstrates using the global generator with XLA."""
+    # This test was passing before because soft placement silently picked the
+    # CPU kernel.
+    # TODO(wangpeng): Remove this skip
+    self.skipTest("NonDeterministicInts lacks XLA kernel.")
+
     if not config.list_physical_devices("XLA_CPU"):
       self.skipTest("No XLA_CPU device available.")
 
@@ -675,17 +680,16 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
         random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int64))
 
   @test_util.run_v2_only
-  @test_util.run_cuda_only
-  def testMirroredStratSeq(self):
+  def testCreateOutsideMirroredStrat(self):
     """Tests RNG/MirrorStrategy interaction #1.
 
-    If an RNG is created outside strategy.scope(), all replicas will access the
+    If an RNG is created outside a DS scope, all replicas will access the
     same RNG object, and accesses are serialized.
     """
     shape = [3, 4]
     dtype = dtypes.int32
     gen = random.Generator.from_seed(1234)
-    strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
+    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
     with strat.scope():
       def f():
         t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
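The renamed test exercises a generator created outside the strategy scope, now on two CPU replicas instead of CPU+GPU. A hedged sketch of that setup (the two logical CPU devices are an assumption of this sketch, not something this hunk configures):

import tensorflow as tf

# Split the single physical CPU into two logical devices so MirroredStrategy
# gets two replicas on one host (assumption for this sketch).
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)

gen = tf.random.Generator.from_seed(1234)  # created outside the strategy scope
strat = tf.distribute.MirroredStrategy(["cpu:0", "cpu:1"])

def f():
  # Both replicas share the same generator object; accesses are serialized.
  return gen.uniform_full_int(shape=[3, 4], dtype=tf.int32)

print(strat.run(f))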
@@ -762,4 +766,5 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
 
 
 if __name__ == "__main__":
+  config.set_soft_device_placement(False)
   test.main()
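Turning soft device placement off in the test's `main` makes a missing kernel fail loudly instead of silently falling back to another device, which is exactly what hid the XLA gap noted above. The equivalent public-API call, shown only as an illustration:

import tensorflow as tf

# With soft placement disabled, an op pinned to a device that has no kernel
# for it raises an error instead of silently running elsewhere.
tf.config.set_soft_device_placement(False)
print(tf.config.get_soft_device_placement())  # False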
@@ -20,16 +20,19 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_stateless_random_ops
+from tensorflow.python.ops import gen_stateless_random_ops_v2
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import dispatch
 from tensorflow.python.util.tf_export import tf_export
 
 
 ops.NotDifferentiable("StatelessMultinomial")
 ops.NotDifferentiable("StatelessRandomBinomial")
 ops.NotDifferentiable("StatelessRandomNormal")
@@ -40,6 +43,13 @@ ops.NotDifferentiable("StatelessRandomUniformFullInt")
 ops.NotDifferentiable("StatelessTruncatedNormal")
 
 
+ops.NotDifferentiable("StatelessRandomNormalV2")
+ops.NotDifferentiable("StatelessRandomUniformV2")
+ops.NotDifferentiable("StatelessRandomUniformIntV2")
+ops.NotDifferentiable("StatelessRandomUniformFullIntV2")
+ops.NotDifferentiable("StatelessTruncatedNormalV2")
+
+
 @tf_export("random.experimental.stateless_split")
 @dispatch.add_dispatch_support
 def split(seed, num=2):
@@ -113,6 +123,10 @@ def fold_in(seed, data):
   return array_ops.stack([seed1, data])
 
 
+_get_key_counter_alg = (gen_stateless_random_ops_v2
+                        .stateless_random_get_key_counter_alg)
+
+
 @tf_export("random.stateless_uniform")
 @dispatch.add_dispatch_support
 def stateless_random_uniform(shape,
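`_get_key_counter_alg` turns the usual shape-[2] stateless seed into the `(key, counter, alg)` triple the V2 kernels consume; callers of the public stateless ops see no difference, since those ops remain pure functions of the seed. A small illustration with the public API:

import tensorflow as tf

seed = tf.constant([7, 42], dtype=tf.int32)  # the usual shape-[2] stateless seed

# Stateless ops are deterministic in the seed: repeating the call with the
# same seed yields identical values.
a = tf.random.stateless_normal([2, 3], seed=seed)
b = tf.random.stateless_normal([2, 3], seed=seed)
print(bool(tf.reduce_all(tf.equal(a, b))))  # True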
@@ -192,17 +206,35 @@ def stateless_random_uniform(shape,
                       [shape, seed, minval, maxval]) as name:
     shape = tensor_util.shape_tensor(shape)
     if dtype.is_integer and minval is None:
-      result = gen_stateless_random_ops.stateless_random_uniform_full_int(
-          shape, seed=seed, dtype=dtype, name=name)
+      if compat.forward_compatible(2020, 10, 25):
+        key, counter, alg = _get_key_counter_alg(seed)
+        result = (gen_stateless_random_ops_v2
+                  .stateless_random_uniform_full_int_v2(
+                      shape, key=key, counter=counter, dtype=dtype, alg=alg,
+                      name=name))
+      else:
+        result = gen_stateless_random_ops.stateless_random_uniform_full_int(
+            shape, seed=seed, dtype=dtype, name=name)
     else:
       minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
       maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
       if dtype.is_integer:
-        result = gen_stateless_random_ops.stateless_random_uniform_int(
-            shape, seed=seed, minval=minval, maxval=maxval, name=name)
+        if compat.forward_compatible(2020, 10, 25):
+          key, counter, alg = _get_key_counter_alg(seed)
+          result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
+              shape, key=key, counter=counter, minval=minval, maxval=maxval,
+              alg=alg, name=name)
+        else:
+          result = gen_stateless_random_ops.stateless_random_uniform_int(
+              shape, seed=seed, minval=minval, maxval=maxval, name=name)
       else:
-        rnd = gen_stateless_random_ops.stateless_random_uniform(
-            shape, seed=seed, dtype=dtype)
+        if compat.forward_compatible(2020, 10, 25):
+          key, counter, alg = _get_key_counter_alg(seed)
+          rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
+              shape, key=key, counter=counter, dtype=dtype, alg=alg)
+        else:
+          rnd = gen_stateless_random_ops.stateless_random_uniform(
+              shape, seed=seed, dtype=dtype)
         result = math_ops.add(rnd * (maxval - minval), minval, name=name)
     tensor_util.maybe_set_static_shape(result, shape)
     return result
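The three branches above map onto the three ways `tf.random.stateless_uniform` can be called: full-range integers (integer dtype with `minval=None`), bounded integers, and floats. A usage illustration, not part of the diff:

import tensorflow as tf

seed = tf.constant([1, 2], dtype=tf.int32)

# Integer dtype with minval=None and maxval=None samples uniformly over the
# dtype's whole range (the "full int" branch).
full_range = tf.random.stateless_uniform(
    [4], seed=seed, minval=None, maxval=None, dtype=tf.uint64)

# Bounded integers and floats take the other two branches.
bounded = tf.random.stateless_uniform(
    [4], seed=seed, minval=0, maxval=10, dtype=tf.int32)
floats = tf.random.stateless_uniform([4], seed=seed)

print(full_range.numpy(), bounded.numpy(), floats.numpy())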
@@ -476,7 +508,12 @@ def stateless_random_normal(shape,
     shape = tensor_util.shape_tensor(shape)
     mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
     stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
-    rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
+    if compat.forward_compatible(2020, 10, 25):
+      key, counter, alg = _get_key_counter_alg(seed)
+      rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
+          shape, key=key, counter=counter, dtype=dtype, alg=alg)
+    else:
+      rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
     result = math_ops.add(rnd * stddev, mean, name=name)
     tensor_util.maybe_set_static_shape(result, shape)
     return result
@@ -521,8 +558,13 @@ def stateless_truncated_normal(shape,
     shape = tensor_util.shape_tensor(shape)
     mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
     stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
-    rnd = gen_stateless_random_ops.stateless_truncated_normal(
-        shape, seed, dtype)
+    if compat.forward_compatible(2020, 10, 25):
+      key, counter, alg = _get_key_counter_alg(seed)
+      rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
+          shape, key=key, counter=counter, dtype=dtype, alg=alg)
+    else:
+      rnd = gen_stateless_random_ops.stateless_truncated_normal(
+          shape, seed, dtype)
     result = math_ops.add(rnd * stddev, mean, name=name)
     tensor_util.maybe_set_static_shape(result, shape)
     return result
@@ -3792,6 +3792,10 @@ tf_module {
     name: "Rint"
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "RngReadAndSkip"
+    argspec: "args=[\'resource\', \'alg\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "RngSkip"
     argspec: "args=[\'resource\', \'algorithm\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -4544,10 +4548,18 @@ tf_module {
     name: "StatelessRandomGammaV2"
     argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "StatelessRandomGetKeyCounterAlg"
+    argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "StatelessRandomNormal"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "StatelessRandomNormalV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
+  }
   member_method {
     name: "StatelessRandomPoisson"
     argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -4560,10 +4572,22 @@ tf_module {
     name: "StatelessRandomUniformFullInt"
    argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
   }
+  member_method {
+    name: "StatelessRandomUniformFullIntV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
+  }
   member_method {
     name: "StatelessRandomUniformInt"
     argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "StatelessRandomUniformIntV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "StatelessRandomUniformV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
+  }
   member_method {
     name: "StatelessSampleDistortedBoundingBox"
     argspec: "args=[\'image_size\', \'bounding_boxes\', \'min_object_covered\', \'seed\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'[0.75, 1.33]\', \'[0.05, 1]\', \'100\', \'False\', \'None\'], "
@@ -4572,6 +4596,10 @@ tf_module {
     name: "StatelessTruncatedNormal"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "StatelessTruncatedNormalV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
+  }
   member_method {
     name: "StatelessWhile"
     argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "
@@ -3792,6 +3792,10 @@ tf_module {
     name: "Rint"
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "RngReadAndSkip"
+    argspec: "args=[\'resource\', \'alg\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "RngSkip"
     argspec: "args=[\'resource\', \'algorithm\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -4544,10 +4548,18 @@ tf_module {
     name: "StatelessRandomGammaV2"
     argspec: "args=[\'shape\', \'seed\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "StatelessRandomGetKeyCounterAlg"
+    argspec: "args=[\'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "StatelessRandomNormal"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "StatelessRandomNormalV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
+  }
   member_method {
     name: "StatelessRandomPoisson"
     argspec: "args=[\'shape\', \'seed\', \'lam\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -4560,10 +4572,22 @@ tf_module {
     name: "StatelessRandomUniformFullInt"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
   }
+  member_method {
+    name: "StatelessRandomUniformFullIntV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'uint64\'>\", \'None\'], "
+  }
   member_method {
     name: "StatelessRandomUniformInt"
     argspec: "args=[\'shape\', \'seed\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "StatelessRandomUniformIntV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'minval\', \'maxval\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "StatelessRandomUniformV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
+  }
   member_method {
     name: "StatelessSampleDistortedBoundingBox"
     argspec: "args=[\'image_size\', \'bounding_boxes\', \'min_object_covered\', \'seed\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'[0.75, 1.33]\', \'[0.05, 1]\', \'100\', \'False\', \'None\'], "
@@ -4572,6 +4596,10 @@ tf_module {
     name: "StatelessTruncatedNormal"
     argspec: "args=[\'shape\', \'seed\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
   }
+  member_method {
+    name: "StatelessTruncatedNormalV2"
+    argspec: "args=[\'shape\', \'key\', \'counter\', \'alg\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
+  }
   member_method {
     name: "StatelessWhile"
     argspec: "args=[\'input\', \'cond\', \'body\', \'output_shapes\', \'parallel_iterations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'10\', \'None\'], "
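The golden-file entries above record the signatures of the new raw ops. Assuming the usual `tf.raw_ops` exposure of registered ops, the key/counter pipeline can also be exercised directly, mirroring what the Python wrappers in this change do (a hedged sketch; argument names follow the argspecs recorded above):

import tensorflow as tf

seed = tf.constant([7, 42], dtype=tf.int32)

# Derive a (key, counter, alg) triple from a stateless seed, then feed it to a
# V2 sampling kernel, much as stateless_random_normal does after this change.
key, counter, alg = tf.raw_ops.StatelessRandomGetKeyCounterAlg(seed=seed)
rnd = tf.raw_ops.StatelessRandomNormalV2(
    shape=[2, 3], key=key, counter=counter, alg=alg, dtype=tf.float32)
print(rnd.numpy())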