Fix handling of negative seeds in random number generator op kernels for XLA

Casting negative s32 number to u64 directly will have leading 1s in the representation which is not what we want to get a single u64 out of two s32 seeds. Fixed this by first getting unsigned number of the same bit-width.

PiperOrigin-RevId: 345902167
Change-Id: I4f2f6d5415a82ac49db197a64216f951cf1b059d
This commit is contained in:
Smit Hinsu 2020-12-05 18:54:37 -08:00 committed by TensorFlower Gardener
parent f31efcb824
commit a5d5a36e4c
7 changed files with 56 additions and 19 deletions

View File

@ -518,11 +518,19 @@ void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
}
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
if (getOperand().getType() == getResult().getType()) return getOperand();
auto operand_ty = getOperand().getType().cast<TensorType>();
auto result_ty = getResult().getType().cast<TensorType>();
if (operand_ty == result_ty) return getOperand();
// If the result has non-static shape, a convert op is necessary to go from
// static shape to non-static shape.
if (!getResult().getType().cast<TensorType>().hasStaticShape()) return {};
if (!result_ty.hasStaticShape()) return {};
// TODO(hinsu): Handle unsigned types.
if (operand_ty.getElementType().isUnsignedInteger() ||
result_ty.getElementType().isUnsignedInteger()) {
return {};
}
// If the operand is constant, we can do the conversion now.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {

View File

@ -104,7 +104,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# Stateless values should be equal iff the seeds are equal (roughly)
with self.session(), self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension
seeds = [(x, y) for x in range(-2, 3) for y in range(-2, 3)] * 3 # pylint: disable=g-complex-comprehension
for stateless_op in [
stateless.stateless_random_uniform, stateless.stateless_random_normal
]:

View File

@ -76,6 +76,7 @@ tf_kernel_library(
"qr_op.cc",
"quantize_and_dequantize_op.cc",
"random_ops.cc",
"random_ops_util.cc",
"random_ops_util.h",
"reduce_window_op.cc",
"reduction_ops.cc",

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
namespace tensorflow {
xla::XlaOp GetU64FromS32Seeds(xla::XlaOp seed0, xla::XlaOp seed1) {
// Here, the seeds are cast to unsigned type of the same width to have leading
// zeros in the 64 bit representation.
xla::XlaOp u64_seed0 =
ConvertElementType(ConvertElementType(seed0, xla::U32), xla::U64);
xla::XlaOp u64_seed1 =
ConvertElementType(ConvertElementType(seed1, xla::U32), xla::U64);
return u64_seed0 |
(u64_seed1 << ConstantR0WithType(seed0.builder(), xla::U64, 32));
}
} // namespace tensorflow

View File

@ -37,6 +37,9 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
// addition, the distribution near the limit is not uniform.
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype);
// Combines two signed 32-bit seeds into a single unsigned 64 bit seed.
xla::XlaOp GetU64FromS32Seeds(xla::XlaOp seed0, xla::XlaOp seed1);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_

View File

@ -82,9 +82,7 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
xla::PrimitiveType type = shape.element_type();
switch (type) {
@ -120,9 +118,7 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
xla::PrimitiveType type = shape.element_type();
xla::RngOutput output =
@ -307,9 +303,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
xla::XlaOp normal =
xla::NormalFloatingPointDistribution(
key, initial_state, GetBitGeneratorForDevice(device_type_string_),

View File

@ -478,9 +478,7 @@ class GetKeyCounterAlgOp : public XlaOpKernel {
xla::XlaBuilder* builder = seed.builder();
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
auto key_counter = GetKeyCounter(device_type_string_, key);
key = std::get<0>(key_counter);
auto counter = std::get<1>(key_counter);
@ -497,6 +495,7 @@ class GetKeyCounterAlgOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
};
// TODO(hinsu): Dis-allow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);
class GetKeyCounterOp : public XlaOpKernel {
@ -512,12 +511,9 @@ class GetKeyCounterOp : public XlaOpKernel {
seed_shape.DebugString()));
xla::XlaOp seed = ctx->Input(0);
xla::XlaBuilder* builder = seed.builder();
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
auto key_counter = GetKeyCounter(device_type_string_, key);
key = std::get<0>(key_counter);
auto counter = std::get<1>(key_counter);
@ -532,6 +528,7 @@ class GetKeyCounterOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp);
};
// TODO(hinsu): Dis-allow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp);
class GetAlgOp : public XlaOpKernel {