Fix handling of negative seeds in random number generator op kernels for XLA
Casting negative s32 number to u64 directly will have leading 1s in the representation which is not what we want to get a single u64 out of two s32 seeds. Fixed this by first getting unsigned number of the same bit-width. PiperOrigin-RevId: 345902167 Change-Id: I4f2f6d5415a82ac49db197a64216f951cf1b059d
This commit is contained in:
parent
f31efcb824
commit
a5d5a36e4c
@ -518,11 +518,19 @@ void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand,
|
||||
}
|
||||
|
||||
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getOperand().getType() == getResult().getType()) return getOperand();
|
||||
auto operand_ty = getOperand().getType().cast<TensorType>();
|
||||
auto result_ty = getResult().getType().cast<TensorType>();
|
||||
if (operand_ty == result_ty) return getOperand();
|
||||
|
||||
// If the result has non-static shape, a convert op is necessary to go from
|
||||
// static shape to non-static shape.
|
||||
if (!getResult().getType().cast<TensorType>().hasStaticShape()) return {};
|
||||
if (!result_ty.hasStaticShape()) return {};
|
||||
|
||||
// TODO(hinsu): Handle unsigned types.
|
||||
if (operand_ty.getElementType().isUnsignedInteger() ||
|
||||
result_ty.getElementType().isUnsignedInteger()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// If the operand is constant, we can do the conversion now.
|
||||
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
|
||||
|
@ -104,7 +104,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
||||
# Stateless values should be equal iff the seeds are equal (roughly)
|
||||
with self.session(), self.test_scope():
|
||||
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
|
||||
seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension
|
||||
seeds = [(x, y) for x in range(-2, 3) for y in range(-2, 3)] * 3 # pylint: disable=g-complex-comprehension
|
||||
for stateless_op in [
|
||||
stateless.stateless_random_uniform, stateless.stateless_random_normal
|
||||
]:
|
||||
|
@ -76,6 +76,7 @@ tf_kernel_library(
|
||||
"qr_op.cc",
|
||||
"quantize_and_dequantize_op.cc",
|
||||
"random_ops.cc",
|
||||
"random_ops_util.cc",
|
||||
"random_ops_util.h",
|
||||
"reduce_window_op.cc",
|
||||
"reduction_ops.cc",
|
||||
|
33
tensorflow/compiler/tf2xla/kernels/random_ops_util.cc
Normal file
33
tensorflow/compiler/tf2xla/kernels/random_ops_util.cc
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
xla::XlaOp GetU64FromS32Seeds(xla::XlaOp seed0, xla::XlaOp seed1) {
|
||||
// Here, the seeds are cast to unsigned type of the same width to have leading
|
||||
// zeros in the 64 bit representation.
|
||||
xla::XlaOp u64_seed0 =
|
||||
ConvertElementType(ConvertElementType(seed0, xla::U32), xla::U64);
|
||||
xla::XlaOp u64_seed1 =
|
||||
ConvertElementType(ConvertElementType(seed1, xla::U32), xla::U64);
|
||||
return u64_seed0 |
|
||||
(u64_seed1 << ConstantR0WithType(seed0.builder(), xla::U64, 32));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -37,6 +37,9 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
|
||||
// addition, the distribution near the limit is not uniform.
|
||||
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype);
|
||||
|
||||
// Combines two signed 32-bit seeds into a single unsigned 64 bit seed.
|
||||
xla::XlaOp GetU64FromS32Seeds(xla::XlaOp seed0, xla::XlaOp seed1);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_
|
||||
|
@ -82,9 +82,7 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string,
|
||||
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
@ -120,9 +118,7 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string,
|
||||
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
xla::RngOutput output =
|
||||
@ -307,9 +303,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
|
||||
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
|
||||
xla::XlaOp normal =
|
||||
xla::NormalFloatingPointDistribution(
|
||||
key, initial_state, GetBitGeneratorForDevice(device_type_string_),
|
||||
|
@ -478,9 +478,7 @@ class GetKeyCounterAlgOp : public XlaOpKernel {
|
||||
xla::XlaBuilder* builder = seed.builder();
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
|
||||
auto key_counter = GetKeyCounter(device_type_string_, key);
|
||||
key = std::get<0>(key_counter);
|
||||
auto counter = std::get<1>(key_counter);
|
||||
@ -497,6 +495,7 @@ class GetKeyCounterAlgOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp);
|
||||
};
|
||||
|
||||
// TODO(hinsu): Dis-allow unsupported int64 seed types.
|
||||
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp);
|
||||
|
||||
class GetKeyCounterOp : public XlaOpKernel {
|
||||
@ -512,12 +511,9 @@ class GetKeyCounterOp : public XlaOpKernel {
|
||||
seed_shape.DebugString()));
|
||||
xla::XlaOp seed = ctx->Input(0);
|
||||
|
||||
xla::XlaBuilder* builder = seed.builder();
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1);
|
||||
auto key_counter = GetKeyCounter(device_type_string_, key);
|
||||
key = std::get<0>(key_counter);
|
||||
auto counter = std::get<1>(key_counter);
|
||||
@ -532,6 +528,7 @@ class GetKeyCounterOp : public XlaOpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp);
|
||||
};
|
||||
|
||||
// TODO(hinsu): Dis-allow unsupported int64 seed types.
|
||||
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp);
|
||||
|
||||
class GetAlgOp : public XlaOpKernel {
|
||||
|
Loading…
x
Reference in New Issue
Block a user