Modify the XLA Uniform sampler to use a cast instead of a bitcast. We didn't strictly need a bitcast, because we ignore the exponent bits anyway.
The logic before and after is equivalent; however, performance may be affected. PiperOrigin-RevId: 304057589 Change-Id: I2ad9e923b1c966f46eba91ae47e0e632b74cff72
This commit is contained in:
parent
0a704c08b6
commit
21813af360
@ -26,6 +26,7 @@ from tensorflow.python.kernel_tests.random import util as \
|
||||
random_test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import stateless_random_ops as stateless
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -132,5 +133,38 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
|
||||
variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)
|
||||
|
||||
|
||||
class StatelessRandomOpsBenchmark(test.Benchmark):
  """Microbenchmarks for the stateless random ops."""

  def _benchmarkUniform(self, name, dtype, use_xla_jit):
    """Runs a stateless_random_uniform microbenchmark on CPU.

    Args:
      name: Label prefix used in the reported benchmark name.
      dtype: Element type of the generated random tensor.
      use_xla_jit: Whether to compile the benchmarked graph with XLA JIT.
    """

    def builder_fn():
      # Fixed 10M-element tensor. The seed is held in a variable so it is
      # supplied as a runtime input rather than being constant-folded away.
      shape = (10, 1000, 1000)
      seed_var = variables.Variable(
          (312, 456), dtype=dtypes.int32, name='input')
      random_t = stateless.stateless_random_uniform(
          shape, seed=seed_var, dtype=dtype)
      return '%s.shape%s' % (name, shape), [random_t]

    xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu')

  def benchmarkUniformF32(self):
    self._benchmarkUniform('uniform_f32', dtype=dtypes.float32,
                           use_xla_jit=False)

  def benchmarkUniformF64(self):
    self._benchmarkUniform('uniform_f64', dtype=dtypes.float64,
                           use_xla_jit=False)

  def benchmarkUniformF32XLA(self):
    self._benchmarkUniform('uniform_f32', dtype=dtypes.float32,
                           use_xla_jit=True)

  def benchmarkUniformF64XLA(self):
    self._benchmarkUniform('uniform_f64', dtype=dtypes.float64,
                           use_xla_jit=True)
|
||||
|
||||
|
||||
# Run the test/benchmark suite when this file is executed as a script.
if __name__ == '__main__':
  test.main()
|
||||
|
@ -434,17 +434,21 @@ XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval,
|
||||
(value_type == F64 && bit_type == U64));
|
||||
|
||||
// Form random mantissa bits for float/double, with a leading 1 bit.
|
||||
int float_bits = primitive_util::BitWidth(value_type);
|
||||
int num_float_bits = primitive_util::BitWidth(value_type);
|
||||
// Subtract one as SignificandWidth includes the leading 1 bit.
|
||||
int mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
|
||||
int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;
|
||||
|
||||
bits = ShiftRightLogical(bits, ScalarLike(bits, float_bits - mantissa_bits)) |
|
||||
BitcastConvertType(ScalarLike(minval, 1.0), bit_type);
|
||||
XlaOp values = BitcastConvertType(bits, value_type);
|
||||
// Ignore the exponent bits and convert the mantissa bits to the floating
|
||||
// point type.
|
||||
bits = ShiftRightLogical(
|
||||
bits, ScalarLike(bits, num_float_bits - num_mantissa_bits));
|
||||
|
||||
// We have a floating point number in the range [1.0, 2.0).
|
||||
// Subtract 1.0f to shift to the range [0.0, 1.0)
|
||||
values = values - ScalarLike(values, 1.0);
|
||||
// We have an integer-valued floating point number in the range
|
||||
// [0, 2**{num_mantissa_bits}).
|
||||
XlaOp values = ConvertElementType(bits, value_type);
|
||||
|
||||
// Divide by 2**{-num_mantissa_bits} to get a number in the range [0.0, 1.0).
|
||||
values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits));
|
||||
|
||||
// Multiply and add to shift to the range [minval, maxval).
|
||||
return values * (maxval - minval) + minval;
|
||||
|
Loading…
Reference in New Issue
Block a user