From f31621823a2467daf7cf40e9e2d3a56ce193e695 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Sat, 18 Apr 2020 17:43:57 -0700 Subject: [PATCH] [XLA] Broadcast ParameterizedTruncatedNormal parameters to the uniform's shape PiperOrigin-RevId: 307231768 Change-Id: I5b06be798fe4be5ae4447e3a0060706ac0e08a26 --- tensorflow/compiler/tests/random_ops_test.py | 19 ++++++++++++++++++ .../compiler/tf2xla/kernels/random_ops.cc | 20 +++++++++++++------ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 52f47416ed2..2f304d0a96f 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -190,6 +190,25 @@ class RandomOpsTest(xla_test.XLATestCase): self._checkTruncatedNormalIsInRange( x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=stat_test) + def testParameterizedTruncatedNormalBroadcasting(self): + for dtype in self._random_types() & {np.float32, np.float64}: + with self.session(): + with self.test_scope(): + a = -1. + b = 1. + mu = 0. + sigma = 1. + count = 10000000 + x = random_ops.parameterized_truncated_normal( + shape=[1, count], + dtype=dtype, + means=mu, + stddevs=sigma, + minvals=[a], + maxvals=[b]) + self._checkTruncatedNormalIsInRange( + x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=True) + def testParameterizedTruncatedNormalIsInRangeCenter(self): count = 10000000 self._implParameterizedTruncatedNormalIsInRange( diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 1ccf0b4b125..3acb1d3359b 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -18,6 +18,7 @@ limitations under the License. // TODO(misard,phawkins): add tests. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -337,13 +338,20 @@ class ParameterizedTruncatedNormalOp : public XlaOpKernel { "reproducible behavior is desired."; xla::XlaOp uniform = xla::RngUniform(min_positive, one, xla_shape); - xla::XlaOp means = ctx->Input(1); - xla::XlaOp stddevs = ctx->Input(2); - xla::XlaOp minvals = ctx->Input(3); - xla::XlaOp maxvals = ctx->Input(4); + auto result = b->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::XlaOp means, + BroadcastTo(ctx->Input(1), shape.dim_sizes())); + TF_ASSIGN_OR_RETURN(xla::XlaOp stddevs, + BroadcastTo(ctx->Input(2), shape.dim_sizes())); + TF_ASSIGN_OR_RETURN(xla::XlaOp minvals, + BroadcastTo(ctx->Input(3), shape.dim_sizes())); + TF_ASSIGN_OR_RETURN(xla::XlaOp maxvals, + BroadcastTo(ctx->Input(4), shape.dim_sizes())); + return ParameterizedTruncatedNormal(uniform, means, stddevs, minvals, + maxvals); + }); - ctx->SetOutput(0, ParameterizedTruncatedNormal(uniform, means, stddevs, - minvals, maxvals)); + ctx->SetOutput(0, result); } };