From 21813af36014ef706f622b4757f4c28b928e018f Mon Sep 17 00:00:00 2001 From: Anudhyan Boral Date: Tue, 31 Mar 2020 15:19:16 -0700 Subject: [PATCH] Modify the XLA Uniform sampler to use cast instead of bitcasts. We didn't strictly need a bitcast because we are ignoring the exponent bits anyway. Before and after logic is equivalent. However, performance could have an impact. PiperOrigin-RevId: 304057589 Change-Id: I2ad9e923b1c966f46eba91ae47e0e632b74cff72 --- .../tests/stateless_random_ops_test.py | 34 +++++++++++++++++++ tensorflow/compiler/xla/client/lib/prng.cc | 20 ++++++----- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 14b062e5cba..f9d792806b0 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.python.kernel_tests.random import util as \ random_test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import stateless_random_ops as stateless +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -132,5 +133,38 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3) +class StatelessRandomOpsBenchmark(test.Benchmark): + """Microbenchmarks for the stateless random ops.""" + + def _benchmarkUniform(self, name, dtype, use_xla_jit): + + def BuilderFn(): + shape = (10, 1000, 1000) + seed_var = variables.Variable((312, 456), + dtype=dtypes.int32, + name='input') + random_t = stateless.stateless_random_uniform( + shape, seed=seed_var, dtype=dtype) + return '%s.shape%s' % (name, shape), [random_t] + + xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu') + + def benchmarkUniformF32(self): + self._benchmarkUniform( + 'uniform_f32', dtype=dtypes.float32, use_xla_jit=False) + + def 
benchmarkUniformF64(self): + self._benchmarkUniform( + 'uniform_f64', dtype=dtypes.float64, use_xla_jit=False) + + def benchmarkUniformF32XLA(self): + self._benchmarkUniform( + 'uniform_f32', dtype=dtypes.float32, use_xla_jit=True) + + def benchmarkUniformF64XLA(self): + self._benchmarkUniform( + 'uniform_f64', dtype=dtypes.float64, use_xla_jit=True) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 17fb4c3c369..044a742eddd 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -434,17 +434,21 @@ XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval, (value_type == F64 && bit_type == U64)); // Form random mantissa bits for float/double, with a leading 1 bit. - int float_bits = primitive_util::BitWidth(value_type); + int num_float_bits = primitive_util::BitWidth(value_type); // Subtract one as SignificandWidth includes the leading 1 bit. - int mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; + int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; - bits = ShiftRightLogical(bits, ScalarLike(bits, float_bits - mantissa_bits)) | - BitcastConvertType(ScalarLike(minval, 1.0), bit_type); - XlaOp values = BitcastConvertType(bits, value_type); + // Ignore the exponent bits and convert the mantissa bits to the floating + // point type. + bits = ShiftRightLogical( + bits, ScalarLike(bits, num_float_bits - num_mantissa_bits)); - // We have a floating point number in the range [1.0, 2.0). - // Subtract 1.0f to shift to the range [0.0, 1.0) - values = values - ScalarLike(values, 1.0); + // We have an integer-valued floating point number in the range + // [0, 2**{num_mantissa_bits}). + XlaOp values = ConvertElementType(bits, value_type); + + // Multiply by 2**{-num_mantissa_bits} to get a number in the range [0.0, 1.0). 
+ values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits)); // Multiply and add to shift to the range [minval, maxval). return values * (maxval - minval) + minval;