[XLA] Broadcast ParameterizedTruncatedNormal parameters to the uniform's shape

PiperOrigin-RevId: 307231768
Change-Id: I5b06be798fe4be5ae4447e3a0060706ac0e08a26
This commit is contained in:
David Majnemer 2020-04-18 17:43:57 -07:00 committed by TensorFlower Gardener
parent d80c47cd14
commit f31621823a
2 changed files with 33 additions and 6 deletions

View File

@ -190,6 +190,25 @@ class RandomOpsTest(xla_test.XLATestCase):
self._checkTruncatedNormalIsInRange(
x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=stat_test)
def testParameterizedTruncatedNormalBroadcasting(self):
for dtype in self._random_types() & {np.float32, np.float64}:
with self.session():
with self.test_scope():
a = -1.
b = 1.
mu = 0.
sigma = 1.
count = 10000000
x = random_ops.parameterized_truncated_normal(
shape=[1, count],
dtype=dtype,
means=mu,
stddevs=sigma,
minvals=[a],
maxvals=[b])
self._checkTruncatedNormalIsInRange(
x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=True)
def testParameterizedTruncatedNormalIsInRangeCenter(self):
count = 10000000
self._implParameterizedTruncatedNormalIsInRange(

View File

@ -18,6 +18,7 @@ limitations under the License.
// TODO(misard,phawkins): add tests.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@ -337,13 +338,20 @@ class ParameterizedTruncatedNormalOp : public XlaOpKernel {
"reproducible behavior is desired.";
xla::XlaOp uniform = xla::RngUniform(min_positive, one, xla_shape);
xla::XlaOp means = ctx->Input(1);
xla::XlaOp stddevs = ctx->Input(2);
xla::XlaOp minvals = ctx->Input(3);
xla::XlaOp maxvals = ctx->Input(4);
auto result = b->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
TF_ASSIGN_OR_RETURN(xla::XlaOp means,
BroadcastTo(ctx->Input(1), shape.dim_sizes()));
TF_ASSIGN_OR_RETURN(xla::XlaOp stddevs,
BroadcastTo(ctx->Input(2), shape.dim_sizes()));
TF_ASSIGN_OR_RETURN(xla::XlaOp minvals,
BroadcastTo(ctx->Input(3), shape.dim_sizes()));
TF_ASSIGN_OR_RETURN(xla::XlaOp maxvals,
BroadcastTo(ctx->Input(4), shape.dim_sizes()));
return ParameterizedTruncatedNormal(uniform, means, stddevs, minvals,
maxvals);
});
ctx->SetOutput(0, ParameterizedTruncatedNormal(uniform, means, stddevs,
minvals, maxvals));
ctx->SetOutput(0, result);
}
};