[XLA] Broadcast ParameterizedTruncatedNormal parameters to the uniform's shape
PiperOrigin-RevId: 307231768 Change-Id: I5b06be798fe4be5ae4447e3a0060706ac0e08a26
This commit is contained in:
parent
d80c47cd14
commit
f31621823a
@ -190,6 +190,25 @@ class RandomOpsTest(xla_test.XLATestCase):
|
||||
self._checkTruncatedNormalIsInRange(
|
||||
x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=stat_test)
|
||||
|
||||
def testParameterizedTruncatedNormalBroadcasting(self):
|
||||
for dtype in self._random_types() & {np.float32, np.float64}:
|
||||
with self.session():
|
||||
with self.test_scope():
|
||||
a = -1.
|
||||
b = 1.
|
||||
mu = 0.
|
||||
sigma = 1.
|
||||
count = 10000000
|
||||
x = random_ops.parameterized_truncated_normal(
|
||||
shape=[1, count],
|
||||
dtype=dtype,
|
||||
means=mu,
|
||||
stddevs=sigma,
|
||||
minvals=[a],
|
||||
maxvals=[b])
|
||||
self._checkTruncatedNormalIsInRange(
|
||||
x, a=a, b=b, mu=mu, sigma=sigma, count=count, stat_test=True)
|
||||
|
||||
def testParameterizedTruncatedNormalIsInRangeCenter(self):
|
||||
count = 10000000
|
||||
self._implParameterizedTruncatedNormalIsInRange(
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
// TODO(misard,phawkins): add tests.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/random.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
@ -337,13 +338,20 @@ class ParameterizedTruncatedNormalOp : public XlaOpKernel {
|
||||
"reproducible behavior is desired.";
|
||||
xla::XlaOp uniform = xla::RngUniform(min_positive, one, xla_shape);
|
||||
|
||||
xla::XlaOp means = ctx->Input(1);
|
||||
xla::XlaOp stddevs = ctx->Input(2);
|
||||
xla::XlaOp minvals = ctx->Input(3);
|
||||
xla::XlaOp maxvals = ctx->Input(4);
|
||||
auto result = b->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(xla::XlaOp means,
|
||||
BroadcastTo(ctx->Input(1), shape.dim_sizes()));
|
||||
TF_ASSIGN_OR_RETURN(xla::XlaOp stddevs,
|
||||
BroadcastTo(ctx->Input(2), shape.dim_sizes()));
|
||||
TF_ASSIGN_OR_RETURN(xla::XlaOp minvals,
|
||||
BroadcastTo(ctx->Input(3), shape.dim_sizes()));
|
||||
TF_ASSIGN_OR_RETURN(xla::XlaOp maxvals,
|
||||
BroadcastTo(ctx->Input(4), shape.dim_sizes()));
|
||||
return ParameterizedTruncatedNormal(uniform, means, stddevs, minvals,
|
||||
maxvals);
|
||||
});
|
||||
|
||||
ctx->SetOutput(0, ParameterizedTruncatedNormal(uniform, means, stddevs,
|
||||
minvals, maxvals));
|
||||
ctx->SetOutput(0, result);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user