From a5d5a36e4c93a8d4ef31fb41d0dafe0b474b3752 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Sat, 5 Dec 2020 18:54:37 -0800 Subject: [PATCH] Fix handling of negative seeds in random number generator op kernels for XLA Casting a negative s32 number directly to u64 sign-extends it, leaving leading 1s in the 64-bit representation, which is not what we want when combining two s32 seeds into a single u64. Fixed this by first converting each seed to an unsigned number of the same bit-width. PiperOrigin-RevId: 345902167 Change-Id: I4f2f6d5415a82ac49db197a64216f951cf1b059d --- .../mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc | 12 +++++-- .../tests/stateless_random_ops_test.py | 2 +- tensorflow/compiler/tf2xla/kernels/BUILD | 1 + .../tf2xla/kernels/random_ops_util.cc | 33 +++++++++++++++++++ .../compiler/tf2xla/kernels/random_ops_util.h | 3 ++ .../tf2xla/kernels/stateless_random_ops.cc | 13 +++----- .../tf2xla/kernels/stateless_random_ops_v2.cc | 11 +++---- 7 files changed, 56 insertions(+), 19 deletions(-) create mode 100644 tensorflow/compiler/tf2xla/kernels/random_ops_util.cc diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index e25538e9aa5..6b7b235573f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -518,11 +518,19 @@ void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand, } OpFoldResult ConvertOp::fold(ArrayRef operands) { - if (getOperand().getType() == getResult().getType()) return getOperand(); + auto operand_ty = getOperand().getType().cast(); + auto result_ty = getResult().getType().cast(); + if (operand_ty == result_ty) return getOperand(); // If the result has non-static shape, a convert op is necessary to go from // static shape to non-static shape. - if (!getResult().getType().cast().hasStaticShape()) return {}; + if (!result_ty.hasStaticShape()) return {}; + + // TODO(hinsu): Handle unsigned types.
+ if (operand_ty.getElementType().isUnsignedInteger() || + result_ty.getElementType().isUnsignedInteger()) { + return {}; + } // If the operand is constant, we can do the conversion now. if (auto elementsAttr = operands.front().dyn_cast_or_null()) { diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index a1ce7c016ec..c86b36d846a 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -104,7 +104,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # Stateless values should be equal iff the seeds are equal (roughly) with self.session(), self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - seeds = [(x, y) for x in range(5) for y in range(5)] * 3 # pylint: disable=g-complex-comprehension + seeds = [(x, y) for x in range(-2, 3) for y in range(-2, 3)] * 3 # pylint: disable=g-complex-comprehension for stateless_op in [ stateless.stateless_random_uniform, stateless.stateless_random_normal ]: diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index dd724b3d453..b4f897016d0 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -76,6 +76,7 @@ tf_kernel_library( "qr_op.cc", "quantize_and_dequantize_op.cc", "random_ops.cc", + "random_ops_util.cc", "random_ops_util.h", "reduce_window_op.cc", "reduction_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc new file mode 100644 index 00000000000..d9886b49532 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" + +namespace tensorflow { + +xla::XlaOp GetU64FromS32Seeds(xla::XlaOp seed0, xla::XlaOp seed1) { + // Here, the seeds are cast to an unsigned type of the same width to have + // leading zeros in the 64-bit representation. + xla::XlaOp u64_seed0 = + ConvertElementType(ConvertElementType(seed0, xla::U32), xla::U64); + xla::XlaOp u64_seed1 = + ConvertElementType(ConvertElementType(seed1, xla::U32), xla::U64); + return u64_seed0 | + (u64_seed1 << ConstantR0WithType(seed0.builder(), xla::U64, 32)); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h index 9a6dc37e2c9..799e215ef76 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -37,6 +37,9 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, // addition, the distribution near the limit is not uniform. xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype); +// Combines two signed 32-bit seeds into a single unsigned 64-bit seed.
+xla::XlaOp GetU64FromS32Seeds(xla::XlaOp seed0, xla::XlaOp seed1); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index e606812bc4e..ecf8eda9a5f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -82,9 +82,7 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {}); xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {}); - xla::XlaOp key = ConvertElementType(seed0, xla::U64) | - ShiftLeft(ConvertElementType(seed1, xla::U64), - ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1); xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); xla::PrimitiveType type = shape.element_type(); switch (type) { @@ -120,9 +118,7 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string, xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {}); xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {}); - xla::XlaOp key = ConvertElementType(seed0, xla::U64) | - ShiftLeft(ConvertElementType(seed1, xla::U64), - ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1); xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); xla::PrimitiveType type = shape.element_type(); xla::RngOutput output = @@ -307,9 +303,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); - xla::XlaOp key = ConvertElementType(seed0, xla::U64) | - 
ShiftLeft(ConvertElementType(seed1, xla::U64), - ConstantR0WithType(builder, xla::U64, 32)); + + xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1); xla::XlaOp normal = xla::NormalFloatingPointDistribution( key, initial_state, GetBitGeneratorForDevice(device_type_string_), diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc index 0c9f1f92a0b..2d38dbfaaad 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -478,9 +478,7 @@ class GetKeyCounterAlgOp : public XlaOpKernel { xla::XlaBuilder* builder = seed.builder(); xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - xla::XlaOp key = ConvertElementType(seed0, xla::U64) | - ShiftLeft(ConvertElementType(seed1, xla::U64), - ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1); auto key_counter = GetKeyCounter(device_type_string_, key); key = std::get<0>(key_counter); auto counter = std::get<1>(key_counter); @@ -497,6 +495,7 @@ class GetKeyCounterAlgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterAlgOp); }; +// TODO(hinsu): Disallow unsupported int64 seed types.
REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounterAlg"), GetKeyCounterAlgOp); class GetKeyCounterOp : public XlaOpKernel { @@ -512,12 +511,9 @@ class GetKeyCounterOp : public XlaOpKernel { seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(0); - xla::XlaBuilder* builder = seed.builder(); xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - xla::XlaOp key = ConvertElementType(seed0, xla::U64) | - ShiftLeft(ConvertElementType(seed1, xla::U64), - ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp key = GetU64FromS32Seeds(seed0, seed1); auto key_counter = GetKeyCounter(device_type_string_, key); key = std::get<0>(key_counter); auto counter = std::get<1>(key_counter); @@ -532,6 +528,7 @@ class GetKeyCounterOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(GetKeyCounterOp); }; +// TODO(hinsu): Disallow unsupported int64 seed types. REGISTER_XLA_OP(Name("StatelessRandomGetKeyCounter"), GetKeyCounterOp); class GetAlgOp : public XlaOpKernel {