Modify the XLA Uniform sampler to use a cast instead of a bitcast.

We don't strictly need the bitcast because we ignore the exponent bits anyway.

The logic is equivalent before and after; however, performance may be affected.

PiperOrigin-RevId: 304057589
Change-Id: I2ad9e923b1c966f46eba91ae47e0e632b74cff72
Anudhyan Boral 2020-03-31 15:19:16 -07:00 committed by TensorFlower Gardener
parent 0a704c08b6
commit 21813af360
2 changed files with 46 additions and 8 deletions
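
To make the intent of the change concrete, here is a small standalone NumPy sketch (not part of the commit; the helper names are made up for illustration) contrasting the two strategies for the float32 case: the old path ORs the top 23 random bits into the bit pattern of 1.0f, bitcasts to get a float in [1.0, 2.0), and subtracts 1.0; the new path simply casts the top 23 bits to float and scales by 2**-23. Both land in [0.0, 1.0):

import numpy as np

def uniform_from_bits_bitcast(bits):
  # Old approach (sketch): keep the top 23 bits, OR them into the bit
  # pattern of 1.0f, bitcast to float32 (a value in [1.0, 2.0)), then
  # subtract 1.0 to land in [0.0, 1.0).
  mantissa = bits >> np.uint32(32 - 23)
  one_bits = np.float32(1.0).view(np.uint32)
  return (mantissa | one_bits).view(np.float32) - np.float32(1.0)

def uniform_from_bits_cast(bits):
  # New approach (sketch): keep the top 23 bits, cast the integer to
  # float32, and multiply by 2**-23 to land in [0.0, 1.0).
  mantissa = bits >> np.uint32(32 - 23)
  return mantissa.astype(np.float32) * np.float32(np.ldexp(1.0, -23))

bits = np.random.default_rng(0).integers(0, 2**32, size=4, dtype=np.uint32)
print(uniform_from_bits_bitcast(bits))
print(uniform_from_bits_cast(bits))

Every step in both helpers is exact in float32, so the two print identical values, which is the equivalence the commit message refers to.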


@@ -26,6 +26,7 @@ from tensorflow.python.kernel_tests.random import util as \
     random_test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import stateless_random_ops as stateless
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -132,5 +133,38 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
           variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)
 
 
+class StatelessRandomOpsBenchmark(test.Benchmark):
+  """Microbenchmarks for the stateless random ops."""
+
+  def _benchmarkUniform(self, name, dtype, use_xla_jit):
+
+    def BuilderFn():
+      shape = (10, 1000, 1000)
+      seed_var = variables.Variable((312, 456),
+                                    dtype=dtypes.int32,
+                                    name='input')
+      random_t = stateless.stateless_random_uniform(
+          shape, seed=seed_var, dtype=dtype)
+      return '%s.shape%s' % (name, shape), [random_t]
+
+    xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu')
+
+  def benchmarkUniformF32(self):
+    self._benchmarkUniform(
+        'uniform_f32', dtype=dtypes.float32, use_xla_jit=False)
+
+  def benchmarkUniformF64(self):
+    self._benchmarkUniform(
+        'uniform_f64', dtype=dtypes.float64, use_xla_jit=False)
+
+  def benchmarkUniformF32XLA(self):
+    self._benchmarkUniform(
+        'uniform_f32', dtype=dtypes.float32, use_xla_jit=True)
+
+  def benchmarkUniformF64XLA(self):
+    self._benchmarkUniform(
+        'uniform_f64', dtype=dtypes.float64, use_xla_jit=True)
+
+
 if __name__ == '__main__':
   test.main()


@@ -434,17 +434,21 @@ XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval,
           (value_type == F64 && bit_type == U64));
 
     // Form random mantissa bits for float/double, with a leading 1 bit.
-    int float_bits = primitive_util::BitWidth(value_type);
+    int num_float_bits = primitive_util::BitWidth(value_type);
     // Subtract one as SignificandWidth includes the leading 1 bit.
-    int mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
-    bits = ShiftRightLogical(bits, ScalarLike(bits, float_bits - mantissa_bits)) |
-           BitcastConvertType(ScalarLike(minval, 1.0), bit_type);
-    XlaOp values = BitcastConvertType(bits, value_type);
+    int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
+    // Ignore the exponent bits and convert the mantissa bits to the floating
+    // point type.
+    bits = ShiftRightLogical(
+        bits, ScalarLike(bits, num_float_bits - num_mantissa_bits));
 
-    // We have a floating point number in the range [1.0, 2.0).
-    // Subtract 1.0f to shift to the range [0.0, 1.0)
-    values = values - ScalarLike(values, 1.0);
+    // We have an integer-valued floating point number in the range
+    // [0, 2**{num_mantissa_bits}).
+    XlaOp values = ConvertElementType(bits, value_type);
+
+    // Divide by 2**{num_mantissa_bits} to get a number in the range [0.0, 1.0).
+    values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits));
     // Multiply and add to shift to the range [minval, maxval).
     return values * (maxval - minval) + minval;
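
As a rough sanity check on the constants the new path computes (my own arithmetic, not taken from the diff): BitWidth is 32/64 and SignificandWidth is 24/53 for the two supported type pairs, so the shift amounts and scale factors work out as below, and the largest shifted integer times the scale stays strictly below 1.0, which keeps values * (maxval - minval) + minval in the intended half-open range up to floating-point rounding:

import math

# (types, total bits, SignificandWidth -- which includes the leading 1 bit)
for name, num_float_bits, significand_width in [('F32/U32', 32, 24),
                                                ('F64/U64', 64, 53)]:
  num_mantissa_bits = significand_width - 1    # 23 for F32, 52 for F64
  shift = num_float_bits - num_mantissa_bits   # 9 for F32, 12 for F64
  scale = math.ldexp(1.0, -num_mantissa_bits)  # 2**-23 and 2**-52
  # The shifted integer is at most 2**num_mantissa_bits - 1, so the scaled
  # value is strictly less than 1.0.
  largest = (2**num_mantissa_bits - 1) * scale
  print(name, shift, scale, largest)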