Allow RandomBinomial op to broadcast parameters.

- Add multiple parameter broadcasting support for BCast. This will allow it to be used in multiparameter broadcasting contexts. This is specifically for ternary ops, but will be used to make other samplers like ParameterizedTruncatedNormal broadcast.

- Add batch index methods for generating a list of batch indices when the input vectors are flattened. This is used to get broadcasting on flattened inputs (which is used in the RandomBinomial sampler).

- Shard on the number of outputs. This allows us to scale better to Tensor inputs.

PiperOrigin-RevId: 281202841
Change-Id: I0b276e983bf31056677a67b4d5ce8ebc98d77930
This commit is contained in:
Srinivas Vasudevan 2019-11-18 18:49:25 -08:00 committed by TensorFlower Gardener
parent b39b1ed24b
commit 5396e7a3cd
10 changed files with 828 additions and 367 deletions

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"
@ -86,7 +87,7 @@ double binomial_inversion(double count, double prob,
return num_geom;
}
double stirling_approx_tail(double k) {
inline double stirling_approx_tail(double k) {
static double kTailValues[] = {0.0810614667953272, 0.0413406959554092,
0.0276779256849983, 0.02079067210376509,
0.0166446911898211, 0.0138761288230707,
@ -102,7 +103,7 @@ double stirling_approx_tail(double k) {
// We use a transformation-rejection algorithm from
// pairs of uniform random variables due to Hormann.
// https://www.tandfonline.com/doi/abs/10.1080/00949659308811496
double btrs(double count, double prob, random::PhiloxRandom* gen) {
inline double btrs(double count, double prob, random::PhiloxRandom* gen) {
using Eigen::numext::abs;
using Eigen::numext::floor;
using Eigen::numext::log;
@ -119,6 +120,9 @@ double btrs(double count, double prob, random::PhiloxRandom* gen) {
const double v_r = 0.92 - 4.2 / b;
const double r = prob / (1 - prob);
const double alpha = (2.83 + 5.1 / b) * stddev;
const double m = floor((count + 1) * prob);
Uniform uniform;
typename Uniform::ResultType uniform_result;
int16 uniform_remaining = 0;
@ -143,8 +147,6 @@ double btrs(double count, double prob, random::PhiloxRandom* gen) {
continue;
}
double alpha = (2.83 + 5.1 / b) * stddev;
double m = floor((count + 1) * prob);
// This deviates from Hormann's BRTS algorithm, as there is a log missing.
// For all (u, v) pairs outside of the bounding box, this calculates the
// transformed-reject ratio.
@ -169,66 +171,83 @@ template <typename T, typename U>
struct RandomBinomialFunctor<CPUDevice, T, U> {
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
typename TTypes<T>::ConstFlat counts,
const BCast& bcast, typename TTypes<T>::ConstFlat counts,
typename TTypes<T>::ConstFlat probs,
const random::PhiloxRandom& gen,
typename TTypes<U>::Flat output) {
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
auto DoWork = [samples_per_batch, num_elements, &counts, &probs, &gen,
&output](int start_batch, int limit_batch) {
// Capturing "gen" by-value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
random::PhiloxRandom gen_copy = gen;
// Skip takes units of 128 bytes. +3 is so rounding doesn't lead to
// us using the same state in different batches.
// The sample from each iteration uses 2 random numbers.
gen_copy.Skip(start_batch * 2 * 3 * (samples_per_batch + 3) / 4);
// The output layout is [B1, ... Bk, H1, ... Hm]. We have [B1, ... Bk] for
// the sample shape and [H1, ... Hm] for the batch shape of the samples.
// We have B1 * ... * Bk samples per batch member we need.
auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs,
&gen, &output](int start_output, int limit_output) {
// Vectorized intermediate calculations for uniform rejection sampling.
// We always generate at most 4 samples.
Eigen::array<T, 4> z;
Eigen::array<T, 4> g;
const bool should_bcast = bcast.IsBroadcastingRequired();
const auto& counts_batch_indices = bcast.x_batch_indices();
const auto& probs_batch_indices = bcast.y_batch_indices();
auto output_flat = output.data();
for (int64 b = start_batch; b < limit_batch; ++b) {
// We are passed a flat array for each of the parameter tensors.
// The input is either a scalar broadcasted to all batches or a vector
// with length num_batches, but the scalar becomes an array of length 1.
T count = counts((counts.dimension(0) == 1) ? 0 : b);
T prob = probs((probs.dimension(0) == 1) ? 0 : b);
// The last batch can be short, if we adjusted num_batches and
// samples_per_batch.
const int64 limit_sample =
std::min((b + 1) * samples_per_batch, num_elements);
int64 sample = b * samples_per_batch;
// We partition work across batches (count, prob) and then across samples
// per batch member, to avoid extra work.
for (int64 output_idx = start_output; output_idx < limit_output;
// output_idx is incremented with the inner loops below.
) {
int64 batch_idx = output_idx / samples_per_batch;
U* const output_batch_offset = output_flat + batch_idx;
// Generate batch counts from BCast, as it has the right indices to loop
// over.
T count, prob;
if (should_bcast) {
count = counts(counts_batch_indices[batch_idx]);
prob = probs(probs_batch_indices[batch_idx]);
} else {
count = counts(batch_idx);
prob = probs(batch_idx);
}
// Calculate normalized samples, then convert them.
// Determine the method to use.
double dcount = static_cast<double>(count);
if (dcount <= 0.0 || prob <= T(0.0)) {
while (sample < limit_sample) {
output(sample) = static_cast<U>(0.0);
sample++;
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;
++sample_idx, ++output_idx) {
output_batch_offset[sample_idx * num_batches] = static_cast<U>(0.0);
}
} else if (prob >= T(1.0)) {
while (sample < limit_sample) {
output(sample) = static_cast<U>(dcount);
sample++;
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;
++sample_idx, ++output_idx) {
output_batch_offset[sample_idx * num_batches] =
static_cast<U>(dcount);
}
} else if (prob <= T(0.5)) {
double dp = static_cast<double>(prob);
if (count * prob >= T(10)) {
while (sample < limit_sample) {
output(sample) = static_cast<U>(btrs(dcount, dp, &gen_copy));
sample++;
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;
++sample_idx, ++output_idx) {
random::PhiloxRandom gen_copy = gen;
gen_copy.Skip(256 * output_idx);
output_batch_offset[sample_idx * num_batches] =
static_cast<U>(btrs(dcount, dp, &gen_copy));
}
} else {
while (sample < limit_sample) {
output(sample) =
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;
++sample_idx, ++output_idx) {
random::PhiloxRandom gen_copy = gen;
// For binomial inversion, we have mean <= 10, variance <= 10.
// This means on average we need at most 10 number of samples,
// and for 10 standard deviations, we need 42 samples. We reserve
// that much.
gen_copy.Skip(42 * output_idx);
output_batch_offset[sample_idx * num_batches] =
static_cast<U>(binomial_inversion(dcount, dp, &gen_copy));
sample++;
}
}
} else if (prob > T(0.5)) {
@ -236,45 +255,41 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
double dcount = static_cast<double>(count);
double dq = static_cast<double>(q);
if (count * q >= T(10)) {
while (sample < limit_sample) {
output(sample) =
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;
++sample_idx, ++output_idx) {
random::PhiloxRandom gen_copy = gen;
gen_copy.Skip(256 * output_idx);
output_batch_offset[sample_idx * num_batches] =
static_cast<U>(dcount - btrs(dcount, dq, &gen_copy));
sample++;
}
} else {
while (sample < limit_sample) {
output(sample) = static_cast<U>(
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;
++sample_idx, ++output_idx) {
random::PhiloxRandom gen_copy = gen;
// For binomial inversion, we have mean <= 10, variance <= 10.
// This means on average we need at most 10 number of samples,
// and for 10 standard deviations, we need 42 samples. We reserve
// that much.
gen_copy.Skip(42 * output_idx);
output_batch_offset[sample_idx * num_batches] = static_cast<U>(
dcount - binomial_inversion(dcount, dq, &gen_copy));
sample++;
}
}
} else { // prob is NaN
// TODO(srvasude): What should happen if prob is NaN but the output
// type is an integer (which doesn't have a sentinel for NaN)? Fail
// the whole batch sample? Return a specialized sentinel like -1?
while (sample < limit_sample) {
output(sample) = static_cast<U>(NAN);
sample++;
for (int64 sample_idx = output_idx % samples_per_batch;
sample_idx < samples_per_batch && output_idx < limit_output;
++sample_idx, ++output_idx) {
output_batch_offset[sample_idx * num_batches] = static_cast<U>(NAN);
}
}
}
};
const int64 batch_init_cost =
// normMin, normMax
(Eigen::TensorOpCost::AddCost<T>() +
Eigen::TensorOpCost::MulCost<T>()) *
2
// sqrtFactor
+ Eigen::TensorOpCost::AddCost<T>() +
Eigen::TensorOpCost::MulCost<T>() +
Eigen::internal::functor_traits<
Eigen::internal::scalar_sqrt_op<T>>::Cost
// cutoff
+ Eigen::TensorOpCost::MulCost<T>() * 4 +
Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
// diff
+ Eigen::TensorOpCost::AddCost<T>();
// This will depend on count * p (or count * q).
// For n * p < 10, on average, O(n * p) calls to uniform are
// needed, with that
@ -290,17 +305,15 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
// 2 uniform generations along with 5 other ops at 3-6 cycles each.
// ~15 / .89 = ~16
//
// In total this should be ~529 + 2 * Uniform::kElementCost.
// In total this (rate >= 10) should be ~329 + 2 * Uniform::kElementCost.
// We assume that half the tensor has rate < 10, so on average 6
// uniform's
// will be needed. We will upper bound the other op cost by the one for
// rate > 10.
static const int kElementCost = 529 + 6 * Uniform::kElementCost +
static const int kElementCost = 329 + 6 * Uniform::kElementCost +
6 * random::PhiloxRandom::kElementCost;
// Assume we use uniform sampling, and accept the 2nd sample on average.
const int64 batch_cost = batch_init_cost + kElementCost * samples_per_batch;
Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
batch_cost, DoWork);
Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
kElementCost, DoWork);
}
};
@ -324,72 +337,60 @@ class RandomBinomialOp : public OpKernel {
const Tensor& counts_tensor = ctx->input(3);
const Tensor& probs_tensor = ctx->input(4);
tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(),
probs_tensor.shape().dim_sizes(),
/*fewer_dims_optimization=*/false,
/*return_flattened_batch_indices=*/true);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
"counts and probs must have compatible batch dimensions: ",
counts_tensor.shape().DebugString(), " vs. ",
probs_tensor.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
errors::InvalidArgument("Input shape should be a vector, got shape: ",
shape_tensor.shape().DebugString()));
OP_REQUIRES(ctx,
(shape_tensor.dtype() == DataType::DT_INT32 ||
shape_tensor.dtype() == DataType::DT_INT64),
errors::InvalidArgument(
"Input shape should have dtype {int32, int64}."));
// Let's check that the shape tensor dominates the broadcasted tensor.
TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
TensorShape output_shape;
if (shape_tensor.dtype() == DataType::DT_INT32) {
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
&output_shape));
} else {
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int64>(),
&output_shape));
}
OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
errors::InvalidArgument(
"Shape passed in must end with broadcasted shape."));
// Now that we have a guarantee, we can get the additional dimensions added
// by sampling.
OP_REQUIRES(ctx, alg_tensor.dims() == 0,
errors::InvalidArgument("algorithm must be of shape [], not ",
alg_tensor.shape().DebugString()));
Algorithm alg = alg_tensor.flat<Algorithm>()(0);
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
errors::InvalidArgument("Input shape should be a vector, got shape: ",
shape_tensor.shape().DebugString()));
int32 num_batches = shape_tensor.flat<int32>()(0);
int32 samples_per_batch = 1;
const int32 num_dims = shape_tensor.dim_size(0);
for (int32 i = 1; i < num_dims; i++) {
int64 samples_per_batch = 1;
const int64 num_sample_dims =
(shape_tensor.dim_size(0) - bcast.output_shape().size());
for (int64 i = 0; i < num_sample_dims; ++i) {
samples_per_batch *= shape_tensor.flat<int32>()(i);
}
const int32 num_elements = num_batches * samples_per_batch;
// Allocate the output before fudging num_batches and samples_per_batch.
auto shape_vec = shape_tensor.flat<int32>();
TensorShape tensor_shape;
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
shape_vec.data(), shape_vec.size(), &tensor_shape));
Tensor* samples_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor));
// Parameters must be 0-d or 1-d.
OP_REQUIRES(ctx, counts_tensor.dims() <= 1,
errors::InvalidArgument(
"Input counts should be a scalar or vector, got shape: ",
counts_tensor.shape().DebugString()));
OP_REQUIRES(ctx, probs_tensor.dims() <= 1,
errors::InvalidArgument(
"Input probs should be a scalar or vector, got shape: ",
probs_tensor.shape().DebugString()));
if ((counts_tensor.dims() == 0 || counts_tensor.dim_size(0) == 1) &&
(probs_tensor.dims() == 0 || probs_tensor.dim_size(0) == 1)) {
// All batches have the same parameters, so we can update the batch size
// to a reasonable value to improve parallelism (ensure enough batches,
// and no very small batches which have high overhead).
int32 size = num_batches * samples_per_batch;
int32 adjusted_samples = kDesiredBatchSize;
// Ensure adjusted_batches * adjusted_samples >= size.
int32 adjusted_batches = Eigen::divup(size, adjusted_samples);
num_batches = adjusted_batches;
samples_per_batch = adjusted_samples;
} else {
// Parameters must be broadcastable to the shape [num_batches].
OP_REQUIRES(
ctx,
TensorShapeUtils::IsScalar(counts_tensor.shape()) ||
counts_tensor.dim_size(0) == 1 ||
counts_tensor.dim_size(0) == num_batches,
errors::InvalidArgument(
"Input counts should have length 1 or shape[0], got shape: ",
counts_tensor.shape().DebugString()));
OP_REQUIRES(
ctx,
TensorShapeUtils::IsScalar(probs_tensor.shape()) ||
probs_tensor.dim_size(0) == 1 ||
probs_tensor.dim_size(0) == num_batches,
errors::InvalidArgument(
"Input probs should have length 1 or shape[0], got shape: ",
probs_tensor.shape().DebugString()));
int64 num_batches = 1;
for (int64 i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
num_batches *= shape_tensor.flat<int32>()(i);
}
const int64 num_elements = num_batches * samples_per_batch;
Tensor* samples_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor));
core::RefCountPtr<Var> var;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
@ -415,7 +416,6 @@ class RandomBinomialOp : public OpKernel {
"For Philox algorithm, the size of state must be at least ",
PHILOX_MIN_STATE_SIZE, "; got ", var_tensor_flat.size()));
// Each worker has the fudge factor for samples_per_batch, so use it here.
OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
ctx, var_tensor, var->copy_on_read_mode.load()));
auto var_data = var_tensor_flat.data();
@ -425,8 +425,9 @@ class RandomBinomialOp : public OpKernel {
auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
samples_per_batch, num_elements, counts_tensor.flat<T>(),
probs_tensor.flat<T>(), philox, samples_tensor->flat<U>());
samples_per_batch, num_elements, bcast,
counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox,
samples_tensor->flat<U>());
}
private:

View File

@ -107,10 +107,6 @@ REGISTER_OP("StatefulRandomBinomial")
.SetShapeFn([](shape_inference::InferenceContext* c) {
using shape_inference::ShapeHandle;
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out));
c->set_output(0, out);

View File

@ -18,153 +18,16 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
/* static */
void BCast::Reverse(Vec* shape) { std::reverse(shape->begin(), shape->end()); }
BCast::BCast(const Vec& sx, const Vec& sy, const bool fewer_dims_optimization) {
if (sx == sy && TF_PREDICT_TRUE(fewer_dims_optimization)) {
// Fast path for common case of identical shapes for sx and sy
int64 elements = 1;
const int n = sx.size();
output_.resize(n);
for (int i = 0; i < n; i++) {
const int64 dim = sx[i];
elements *= dim;
output_[i] = dim;
}
result_.push_back(elements);
x_reshape_.push_back(elements);
y_reshape_.push_back(elements);
x_bcast_.push_back(1);
y_bcast_.push_back(1);
// grad_x_reduce_ and grad_y_reduce_ are left as empty
} else {
// Reverse the shape of x and y for convenience.
// After the reverse, 0-th is the inner-most dimension.
Vec x = sx;
Vec y = sy;
Reverse(&x);
Reverse(&y);
// 1-extend and align x and y so that they are the same size.
if (x.size() > y.size()) {
y.resize(x.size(), 1);
} else {
x.resize(y.size(), 1);
}
// Going through each dimension starting from the inner-most
// dimension, compares dimension of x and y. They are compatible if
// they are equal or either is 1.
enum State {
UNKNOWN,
SAME,
X_ONE,
Y_ONE,
};
State prev = UNKNOWN;
const int64 n = x.size();
for (int i = 0; i < n; ++i) {
// Output shape.
State curr = UNKNOWN;
const int64 x_i = x[i]; // i-th dimension of x.
const int64 y_i = y[i]; // i-th dimension of y.
int64 o_i; // i-th dimension of the output.
int64 bx_i; // i-th broadcast for x.
int64 by_i; // i-th broadcast for y.
// Invariant:
// o_i = x_i * bx_i = y_i * by_i
if (x_i == y_i) {
// No broadcast.
o_i = x_i;
bx_i = 1;
by_i = 1;
curr = SAME;
} else if (x_i == 1) {
// x broadcast to y on this dimension.
o_i = y_i;
bx_i = y_i;
by_i = 1;
grad_x_reduce_idx_.push_back(n - 1 - i);
curr = X_ONE;
} else if (y_i == 1) {
// y broadcast to x on this dimension.
o_i = x_i;
bx_i = 1;
by_i = x_i;
grad_y_reduce_idx_.push_back(n - 1 - i);
curr = Y_ONE;
} else {
valid_ = false;
return;
}
output_.push_back(o_i);
// Reshape/broadcast.
// Invariant:
// result[i] == x_reshape[i] * x_bcast[i] == y_reshape_[i] * y_bcast_[i]
if (curr == SAME && x_i == 1) {
// Both side are 1s.
grad_x_reduce_idx_.push_back(n - 1 - i);
grad_y_reduce_idx_.push_back(n - 1 - i);
if (!TF_PREDICT_TRUE(fewer_dims_optimization)) {
result_.push_back(o_i);
x_reshape_.push_back(x_i);
x_bcast_.push_back(bx_i);
y_reshape_.push_back(y_i);
y_bcast_.push_back(by_i);
}
continue;
} else if (TF_PREDICT_TRUE(fewer_dims_optimization) && prev == curr) {
// It is a run of the same cases(no broadcast, x broadcast to y, y
// broadcast to x). We can reshape the input so that fewer dimensions
// are involved in the intermediate computation.
result_.back() *= o_i;
x_reshape_.back() *= x_i;
x_bcast_.back() *= bx_i;
y_reshape_.back() *= y_i;
y_bcast_.back() *= by_i;
} else {
result_.push_back(o_i);
x_reshape_.push_back(x_i);
x_bcast_.push_back(bx_i);
y_reshape_.push_back(y_i);
y_bcast_.push_back(by_i);
}
prev = curr;
}
if (result_.empty()) {
// Can happen when both x and y are effectively scalar.
result_.push_back(1);
x_reshape_.push_back(1);
x_bcast_.push_back(1);
y_reshape_.push_back(1);
y_bcast_.push_back(1);
}
// Reverse all vectors since x and y were reversed at very
// beginning.
Reverse(&x_reshape_);
Reverse(&x_bcast_);
Reverse(&y_reshape_);
Reverse(&y_bcast_);
Reverse(&result_);
Reverse(&output_);
Reverse(&grad_x_reduce_idx_);
Reverse(&grad_y_reduce_idx_);
}
}
BCast::Vec BCast::FromShape(const TensorShape& shape) {
const int N = shape.dims();
BCast::Vec ret(N);
BCastList::Vec ret(N);
for (int i = 0; i < N; ++i) {
ret[i] = shape.dim_size(i);
}
return ret;
}
TensorShape BCast::ToShape(const BCast::Vec& vec) {
TensorShape BCast::ToShape(const BCastList::Vec& vec) {
TensorShape shape(vec);
return shape;
}

View File

@ -25,6 +25,284 @@ limitations under the License.
namespace tensorflow {
// Returns the mapping from the output batch indices to the corresponding
// input's batch indices, given the input's "reshape" and "bcast" shapes as
// returned by the BCastList helper class. The i'th element denotes the
// (flattened) batch index of the input that must be used to compute the i'th
// batch output.
//
inline void ComputeBatchIndices(const int64 output_batch_size,
const gtl::InlinedVector<int64, 4>& reshape,
const gtl::InlinedVector<int64, 4>& bcast,
std::vector<int64>* out_indices) {
// Populates the mapping in out_indices. This algorithm is identical to
// the following steps:
// - Reshape {0, 1, ..., input_batch_size - 1} to the input shape.
// - Broadcast to the output shape.
// - Reshape back to a flat 1D vector.
out_indices->resize(output_batch_size);
int64 num_output_elements = 1;
int64 num_input_elements = 1;
for (int64 i = reshape.size() - 1; i >= 0; --i) {
// Replicate the already populated mapping an additional (dim - 1) times.
// If we are broadcasting, just copy the existing mapping.
// Otherwise, add another dimension from the input shape.
const int64 dim = std::max(reshape[i], bcast[i]);
const int64 incr = bcast[i] > 1 ? 0 : num_input_elements;
for (int64 k = 0; k < (dim - 1) * num_output_elements; ++k) {
(*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr;
}
num_output_elements *= dim;
num_input_elements *= reshape[i];
}
}
template <int N>
class BCastList {
public:
// A vector of int64 representing the shape of tensor. The 0-th
// element is the outer-most dimension and the last element is the
// inner-most dimension. Note that we do not use TensorShape since
// it's more convenient to manipulate Vec directly for this module.
typedef gtl::InlinedVector<int64, 4> Vec;
// Constructs all helper shapes, following the aforementioned rules.
//
// If "fewer_dims_optimization" is set to true (the default), the
// implementation tries to reduce intermediate dimensions needed to be more
// efficient. This is transparent to the caller.
//
// If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
// the same number of dimensions as the larger of the two inputs.
//
// If return_flattened_batch_indices is true, the implementation will compute
// for each output member of the flattened output, which batch indicies of
// each input correspond to it. This is disabled by default.
explicit BCastList(const Vec (&x)[N],
const bool fewer_dims_optimization = true,
const bool return_flattened_batch_indices = false);
~BCastList() {}
// Returns true iff two operands are compatible according to the
// broadcasting rule.
bool IsValid() const { return valid_; }
bool IsBroadcastingRequired() const { return broadcasting_required_; }
// If and only if IsValid(), the following fields can be used in
// implementing a broadcasted binary tensor operation according to
// the broadcasting rule.
const Vec& reshape(int i) const { return reshape_[i]; }
const Vec& bcast(int i) const { return bcast_[i]; }
const Vec& result_shape() const { return result_; }
const Vec& output_shape() const { return output_; }
const Vec& grad_reduce_idx(int i) const { return grad_reduce_idx_[i]; }
const int64 output_batch_size() const { return output_batch_size_; }
// Returns the mapping from the flattened output batch indices to x's
// flattened batch indices. The result is a vector of length
// output_batch_size(). To compute the i'th batch output, a binary matmul-like
// operation should use the `x_batch_indices()[i]`th batch index of `x`.
// Note: Returns an empty vector if broadcasting is not required. Callers
// should only use this when IsBroadcastingRequired() returns true.
const std::vector<int64>& batch_indices(int i) const {
return batch_indices_[i];
}
protected:
bool valid_ = true;
bool broadcasting_required_ = true;
Vec reshape_[N];
Vec bcast_[N];
Vec result_;
Vec output_;
Vec grad_reduce_idx_[N];
int64 output_batch_size_;
std::vector<int64> batch_indices_[N];
static void Reverse(Vec* shape) {
std::reverse(shape->begin(), shape->end());
}
TF_DISALLOW_COPY_AND_ASSIGN(BCastList);
};
template <int N>
BCastList<N>::BCastList(const BCastList::Vec (&x)[N],
const bool fewer_dims_optimization,
const bool return_flattened_batch_indices) {
typedef BCastList::Vec Vec;
bool all_equal = true;
int largest_rank = 0;
output_batch_size_ = 1;
for (int i = 0; i < N; ++i) {
if (x[i] != x[0]) {
all_equal = false;
}
if (x[i].size() > largest_rank) {
largest_rank = x[i].size();
}
}
if (all_equal) {
broadcasting_required_ = false;
}
if (all_equal && TF_PREDICT_TRUE(fewer_dims_optimization)) {
// Fast path for common case of identical shapes.
int64 elements = 1;
const int rank = x[0].size();
output_.resize(rank);
for (int i = 0; i < rank; i++) {
const int64 dim = x[0][i];
elements *= dim;
output_[i] = dim;
}
result_.push_back(elements);
output_batch_size_ = elements;
for (int i = 0; i < N; ++i) {
reshape_[i].push_back(elements);
bcast_[i].push_back(1);
}
// grad_reduce_ is left as empty
return;
}
// Reverse all the shapes for convenience
// After the reverse, 0-th is the inner-most dimension.
Vec copy[N];
for (int i = 0; i < N; ++i) {
copy[i] = x[i];
Reverse(&copy[i]);
}
// 1-extend and align all vectors.
for (int i = 0; i < N; ++i) {
if (copy[i].size() < largest_rank) {
copy[i].resize(largest_rank, 1);
}
}
// Going through each dimension starting from the inner-most
// dimension, compares dimension of x and y. They are compatible if
// they are equal or either is 1.
// indices of j-th component of each input.
bool prev_is_one[N];
bool current_is_one[N];
for (int i = 0; i < N; ++i) {
prev_is_one[i] = false;
current_is_one[i] = false;
}
Vec output;
bool output_dim_set = false;
int output_dim = -1;
bool none_is_one = true;
bool set_one = false;
for (int j = 0; j < largest_rank; ++j) {
output_dim = -1;
output_dim_set = false;
none_is_one = true;
// Find which indices are 1.
for (int i = 0; i < N; ++i) {
// Keep track of which indices are 1.
if (copy[i][j] == 1) {
current_is_one[i] = true;
none_is_one = false;
} else {
current_is_one[i] = false;
if (!output_dim_set || copy[i][j] == output_dim) {
output_dim = copy[i][j];
output_dim_set = true;
} else {
valid_ = false;
return;
}
}
}
output_.push_back(output_dim_set ? output_dim : 1);
output_batch_size_ *= output_.back();
// All dimensions are 1.
if (!output_dim_set) {
if (!TF_PREDICT_TRUE(fewer_dims_optimization)) {
for (int i = 0; i < N; ++i) {
bcast_[i].push_back(1);
reshape_[i].push_back(1);
}
result_.push_back(1);
}
for (int i = 0; i < N; ++i) {
grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
}
// This will skip updating the previous state to the current one. We'll
// explain why this is safe below.
// Consider the previous state P, current state C and the next state N.
// In the case where N also is all ones (N == C), we'll do the same
// optimization here (push back one dimensions if we need to), which is
// safe and is expected.
//
// When N != C, we'll continue as usual. However, we might trigger the
// next block if N == P (because we didn't update the previous state).
// We trigger the next block if `fewer_dims_optimization` is true.
// This means that we did not modify and broadcast / rehshapes in this
// block (we skipped updating, since the one dimensions can be ignored).
// In essence, we only need to check whether the previous non-one state is
// equal to the current non-one state.
continue;
} else if (TF_PREDICT_TRUE(fewer_dims_optimization) &&
std::equal(current_is_one, current_is_one + N, prev_is_one) &&
set_one) {
// It is a run of the same broadcasting case as last time.
// We can reshape the input so that fewer dimensions
// are involved in the intermediate computation.
result_.back() *= output_dim;
for (int i = 0; i < N; ++i) {
reshape_[i].back() *= copy[i][j];
bcast_[i].back() *= current_is_one[i] ? output_dim : 1;
if (current_is_one[i] && !none_is_one) {
grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
}
}
} else {
result_.push_back(output_dim);
for (int i = 0; i < N; ++i) {
reshape_[i].push_back(copy[i][j]);
bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
if (current_is_one[i] && !none_is_one) {
grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
}
}
}
set_one = true;
for (int i = 0; i < N; ++i) {
prev_is_one[i] = current_is_one[i];
}
}
if (result_.empty()) {
result_.push_back(1);
for (int i = 0; i < N; ++i) {
reshape_[i].push_back(1);
bcast_[i].push_back(1);
}
}
// Do something about batches.
for (int i = 0; i < N; ++i) {
Reverse(&reshape_[i]);
Reverse(&bcast_[i]);
Reverse(&grad_reduce_idx_[i]);
}
Reverse(&result_);
Reverse(&output_);
// Only compute batch indices when we need broadcasting, and we aren't doing
// needless work (when the output size is 0 or the
// return_flattened_batch_indices isn't enabled).
if (return_flattened_batch_indices && broadcasting_required_ &&
output_batch_size_ > 0) {
for (int i = 0; i < N; ++i) {
ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i],
&batch_indices_[i]);
}
}
}
// BCast is a helper for broadcasting binary tensor operation.
// TensorFlow's broadcasting rule follows that of numpy (See
// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
@ -64,16 +342,8 @@ namespace tensorflow {
//
// The multiplication in the grad * backprop_x itself is also
// broadcasting following the same rule.
//
// TODO(zhifengc): Adds support for n-ary (n >= 2).
class BCast {
class BCast : public BCastList<2> {
public:
// A vector of int64 representing the shape of tensor. The 0-th
// element is the outer-most dimension and the last element is the
// inner-most dimension. Note that we do not use TensorShape since
// it's more convenient to manipulate Vec directly for this module.
typedef gtl::InlinedVector<int64, 4> Vec;
// Constructs all helper shapes, following the aforementioned rules.
//
// If "fewer_dims_optimization" is set to true (the default), the
@ -82,28 +352,43 @@ class BCast {
//
// If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
// the same number of dimensions as the larger of the two inputs.
BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true);
~BCast() {}
typedef gtl::InlinedVector<int64, 4> Vec;
// Returns true iff two operands are compatible according to the
// broadcasting rule.
bool IsValid() const { return valid_; }
BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true,
const bool return_flattened_batch_indices = false)
: BCastList<2>({x, y}, fewer_dims_optimization,
return_flattened_batch_indices) {}
~BCast() {}
// If and only if IsValid(), the following fields can be used in
// implementing a broadcasted binary tensor operation according to
// the broadcasting rule.
const Vec& x_reshape() const { return x_reshape_; }
const Vec& x_bcast() const { return x_bcast_; }
const Vec& y_reshape() const { return y_reshape_; }
const Vec& y_bcast() const { return y_bcast_; }
const Vec& x_reshape() const { return reshape_[0]; }
const Vec& x_bcast() const { return bcast_[0]; }
const Vec& y_reshape() const { return reshape_[1]; }
const Vec& y_bcast() const { return bcast_[1]; }
const Vec& result_shape() const { return result_; }
const Vec& output_shape() const { return output_; }
const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }
const Vec& grad_x_reduce_idx() const { return grad_reduce_idx_[0]; }
const Vec& grad_y_reduce_idx() const { return grad_reduce_idx_[1]; }
// Static helpers.
static Vec FromShape(const TensorShape& shape);
static TensorShape ToShape(const BCast::Vec& vec);
// Returns the mapping from the flattened output batch indices to x's
// flattened batch indices. The result is a vector of length
// output_batch_size(). To compute the i'th batch output, a binary matmul-like
// operation should use the `x_batch_indices()[i]`th batch index of `x`.
// Note: Returns an empty vector if broadcasting is not required. Callers
// should only use this when IsBroadcastingRequired() returns true.
const std::vector<int64>& x_batch_indices() const {
return batch_indices_[0];
}
// Returns the mapping from the flattened output batch indices to y's
// flattened batch indices. Similar to x_batch_indices().
// Note: Returns an empty vector if broadcasting is not required. Callers
// should only use this when IsBroadcastingRequired() returns true.
const std::vector<int64>& y_batch_indices() const {
return batch_indices_[1];
}
template <typename IndexType, int NDIMS>
static Eigen::array<IndexType, NDIMS> ToIndexArrayType(
@ -120,19 +405,11 @@ class BCast {
return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
}
// Static helpers.
static Vec FromShape(const TensorShape& shape);
static TensorShape ToShape(const Vec& vec);
private:
bool valid_ = true;
Vec x_reshape_;
Vec x_bcast_;
Vec y_reshape_;
Vec y_bcast_;
Vec result_;
Vec output_;
Vec grad_x_reduce_idx_;
Vec grad_y_reduce_idx_;
static void Reverse(Vec* shape);
TF_DISALLOW_COPY_AND_ASSIGN(BCast);
};

View File

@ -41,6 +41,40 @@ string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y,
return ret;
}
string BCastBatchIndices(const tensorflow::BCast::Vec& x,
const tensorflow::BCast::Vec& y,
const bool fewer_dims_optimization = true) {
tensorflow::BCast b(x, y, fewer_dims_optimization,
/*return_flattened_batch_indices=*/true);
string ret;
strings::StrAppend(&ret, "[", absl::StrJoin(b.x_batch_indices(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.y_batch_indices(), ","), "]");
return ret;
}
string BCastList3(const tensorflow::BCast::Vec& x,
const tensorflow::BCast::Vec& y,
const tensorflow::BCast::Vec& z,
const bool fewer_dims_optimization = true) {
tensorflow::BCastList<3> b({x, y, z}, fewer_dims_optimization);
if (!b.IsValid()) {
return "invalid";
}
string ret;
strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(0), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(0), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(1), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(1), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(2), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(2), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.result_shape(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.output_shape(), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(0), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(1), ","), "]");
strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(2), ","), "]");
return ret;
}
TEST(BCastTest, Invalid) {
for (const bool use_optimization : {true, false}) {
EXPECT_EQ("invalid", BCast({5, 3, 2}, {3}, use_optimization));
@ -51,6 +85,26 @@ TEST(BCastTest, Invalid) {
}
}
TEST(BCastListTest, Invalid) {
for (const bool use_optimization : {true, false}) {
EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {3}, {1}, use_optimization));
EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {2, 2}, {1}, use_optimization));
EXPECT_EQ("invalid",
BCastList3({5, 3, 2}, {10, 1, 1}, {1}, use_optimization));
EXPECT_EQ("invalid", BCastList3({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1}, {1},
use_optimization));
EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {1}, {3}, use_optimization));
EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {1}, {2, 2}, use_optimization));
EXPECT_EQ("invalid",
BCastList3({5, 3, 2}, {1}, {10, 1, 1}, use_optimization));
EXPECT_EQ("invalid", BCastList3({1}, {5, 3, 2}, {3}, use_optimization));
EXPECT_EQ("invalid", BCastList3({1}, {5, 3, 2}, {2, 2}, use_optimization));
EXPECT_EQ("invalid",
BCastList3({1}, {5, 3, 2}, {10, 1, 1}, use_optimization));
}
}
TEST(BCastTest, Basic_SameShape) {
// Effectively no broadcast needed.
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}),
@ -66,6 +120,22 @@ TEST(BCastTest, Basic_SameShape) {
"[][]");
}
TEST(BCastListTest, Basic_SameShape) {
// Effectively no broadcast needed.
EXPECT_EQ(BCastList3({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}),
"[2310][1][2310][1][2310][1]"
"[2310]"
"[11,7,5,3,2]"
"[][][]");
EXPECT_EQ(
BCastList3({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, false),
"[11,7,5,3,2][1,1,1,1,1][11,7,5,3,2][1,1,1,1,1][11,7,5,3,2][1,1,1,1,1]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[][][]");
}
TEST(BCastTest, Basic_SameShapeWithZeroDim) {
// Effectively no broadcast needed.
EXPECT_EQ(BCast({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}),
@ -81,9 +151,32 @@ TEST(BCastTest, Basic_SameShapeWithZeroDim) {
"[][]");
}
TEST(BCastListTest, Basic_SameShapeWithZeroDim) {
// Effectively no broadcast needed.
EXPECT_EQ(BCastList3({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}),
"[0][1][0][1][0][1]"
"[0]"
"[11,7,0,3,2]"
"[][][]");
EXPECT_EQ(
BCastList3({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, false),
"[11,7,0,3,2][1,1,1,1,1][11,7,0,3,2][1,1,1,1,1][11,7,0,3,2][1,1,1,1,1]"
"[11,7,0,3,2]"
"[11,7,0,3,2]"
"[][][]");
}
TEST(BCastTest, Basic_Scalar_Scalar) {
// Effectively it's a scalar and a scalar.
// [1, 1] [1]
//
EXPECT_EQ(BCast({1, 1}, {}),
"[1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1]");
EXPECT_EQ(BCast({1, 1}, {1}),
"[1][1][1][1]"
"[1]"
@ -110,6 +203,151 @@ TEST(BCastTest, Basic_Scalar_Scalar) {
"[0,1][0,1]");
}
TEST(BCastTest, Basic_TrueScalar_Scalar) {
// [] []
EXPECT_EQ(BCast({}, {}),
"[1][1][1][1]"
"[1]"
"[]"
"[][]");
// [] [1]
EXPECT_EQ(BCast({}, {1}),
"[1][1][1][1]"
"[1]"
"[1]"
"[0][0]");
EXPECT_EQ(BCast({}, {1}, false),
"[1][1][1][1]"
"[1]"
"[1]"
"[0][0]");
// [] [1, 1]
EXPECT_EQ(BCast({}, {1, 1}),
"[1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1]");
EXPECT_EQ(BCast({}, {1, 1}, false),
"[1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1]");
// [1] []
EXPECT_EQ(BCast({1}, {}),
"[1][1][1][1]"
"[1]"
"[1]"
"[0][0]");
EXPECT_EQ(BCast({1}, {}, false),
"[1][1][1][1]"
"[1]"
"[1]"
"[0][0]");
// [1, 1] []
EXPECT_EQ(BCast({1, 1}, {}),
"[1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1]");
EXPECT_EQ(BCast({1, 1}, {}, false),
"[1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1]");
}
TEST(BCastListTest, Basic_Scalar_Scalar_Scalar) {
// Effectively it's a scalar and a scalar.
// [1, 1] [1] [1]
EXPECT_EQ(BCastList3({1, 1}, {1}, {1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1, 1}, {1}, {1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
// [1] [1, 1] [1]
EXPECT_EQ(BCastList3({1}, {1, 1}, {1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1}, {1, 1}, {1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
// [1] [1] [1, 1]
EXPECT_EQ(BCastList3({1}, {1}, {1, 1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1}, {1}, {1, 1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
}
TEST(BCastListTest, Basic_TrueScalar_Scalar_Scalar) {
// Effectively it's a scalar and a scalar.
// [1, 1] [1] []
EXPECT_EQ(BCastList3({1, 1}, {1}, {}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1, 1}, {1}, {}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
// [] [1, 1] [1]
EXPECT_EQ(BCastList3({}, {1, 1}, {1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({}, {1, 1}, {1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
// [1] [] [1, 1]
EXPECT_EQ(BCastList3({1}, {}, {1, 1}),
"[1][1][1][1][1][1]"
"[1]"
"[1,1]"
"[0,1][0,1][0,1]");
EXPECT_EQ(BCastList3({1}, {}, {1, 1}, false),
"[1,1][1,1][1,1][1,1][1,1][1,1]"
"[1,1]"
"[1,1]"
"[0,1][0,1][0,1]");
}
TEST(BCastTest, Basic_Tensor_Scalar) {
// Effectively it's a tensor and a scalar.
// [11, 7, 5, 3, 2] [1]
@ -327,6 +565,30 @@ TEST(BCastTest, Complex_BCast_To_Each_Other) {
EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}, false), truth);
}
TEST(BCastListTest, Complex_BCast_To_Each_Other) {
// Rare cases. x, y and z broadcast to each other. x,y and z are of
// different ranks.
// Can be verified in numpy as:
// import numpy as np
// x = np.arange(0,22).reshape([11,1,1,1,2])
// y = np.arange(0,21).reshape([7,1,3,1])
// z = np.arange(0,5).reshape([5,1,1])
// np.shape(x + y + z)
// Out[.]: (11, 7, 5, 3, 2)
//
string truth =
"[11,1,1,1,2][1,7,5,3,1]"
"[1,7,1,3,1][11,1,5,1,2]"
"[1,1,5,1,1][11,7,1,3,2]"
"[11,7,5,3,2]"
"[11,7,5,3,2]"
"[1,2,3][0,2,4][0,1,3,4]";
EXPECT_EQ(BCastList3({11, 1, 1, 1, 2}, {7, 1, 3, 1}, {5, 1, 1}), truth);
EXPECT_EQ(BCastList3({11, 1, 1, 1, 2}, {7, 1, 3, 1}, {5, 1, 1}, false),
truth);
}
TEST(BCastTest, TestZeroDimensionShape) {
// (2,0,5) and (5) in both orders
EXPECT_EQ(BCast({2, 0, 5}, {5}),
@ -398,6 +660,19 @@ TEST(BCastTest, TestZeroDimensionShape) {
"[0,1,3][]");
}
TEST(BCastTest, BatchIndices) {
EXPECT_EQ("[0,0,0,0][0,1,2,3]", BCastBatchIndices({1}, {4}));
// Invalid broadcast.
EXPECT_EQ("[][]", BCastBatchIndices({5}, {7}));
// Same shape, no batch indices.
EXPECT_EQ("[][]", BCastBatchIndices({2, 4, 6}, {2, 4, 6}));
// More complicated broadcasts.
EXPECT_EQ("[0,0,0,0,1,1,1,1,2,2,2,2][0,1,2,3,0,1,2,3,0,1,2,3]",
BCastBatchIndices({3, 1}, {1, 4}));
EXPECT_EQ("[0,0,1,1,2,2,0,0,1,1,2,2][0,1,0,1,0,1,2,3,2,3,2,3]",
BCastBatchIndices({3, 1}, {2, 1, 2}));
}
static void BM_BCastSetup(int iters, int same_shape) {
if (same_shape) {
testing::SetLabel("same_shapes");

View File

@ -16,40 +16,6 @@ limitations under the License.
#include "tensorflow/core/util/matmul_bcast.h"
namespace tensorflow {
namespace {
// Returns the mapping from the output batch indices to the corresponding
// input's batch indices, given the input's "reshape" and "bcast" shapes as
// returned by the BCast helper class. The i'th element denotes the (flattened)
// batch index of the input that must be used to compute the i'th batch output.
void ComputeBatchIndices(const int64 output_batch_size,
const MatMulBCast::Vec& reshape,
const MatMulBCast::Vec& bcast,
std::vector<int64>* out_indices) {
// Populates the mapping in out_indices. This algorithm is identical to
// the following steps:
// - Reshape {0, 1, ..., input_batch_size - 1} to the input shape.
// - Broadcast to the output shape.
// - Reshape back to a flat 1D vector.
out_indices->resize(output_batch_size);
int64 num_output_elements = 1;
int64 num_input_elements = 1;
for (int64 i = reshape.size() - 1; i >= 0; --i) {
// Replicate the already populated mapping an additional (dim - 1) times.
// If we are broadcasting, just copy the existing mapping.
// Otherwise, add another dimension from the input shape.
const int64 dim = std::max(reshape[i], bcast[i]);
const int64 incr = bcast[i] > 1 ? 0 : num_input_elements;
for (int64 k = 0; k < (dim - 1) * num_output_elements; ++k) {
(*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr;
}
num_output_elements *= dim;
num_input_elements *= reshape[i];
}
}
} // namespace
MatMulBCast::MatMulBCast(Vec x, Vec y) {
if (x.size() < 2 || y.size() < 2) return;
x.resize(x.size() - 2);

View File

@ -161,6 +161,7 @@ tf_py_test(
"//tensorflow/python:platform",
"//tensorflow/python:stateful_random_ops",
],
shard_count = 3,
tags = ["no_oss"],
)

View File

@ -35,14 +35,14 @@ _SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64,
class RandomBinomialTest(test.TestCase):
"""This is a large test due to the moments computation taking some time."""
def _Sampler(self, num, counts, probs, dtype, seed=None):
def _Sampler(
self, num, counts, probs, dtype, gen=None, sample_shape=None, seed=None):
def func():
rng = stateful_random_ops.Generator.from_seed(seed).binomial(
shape=[10 * num], counts=counts, probs=probs, dtype=dtype)
ret = array_ops.reshape(rng, [10, num])
ret = self.evaluate(ret)
return ret
shape = [10 * num] if sample_shape is None else sample_shape
generator = gen if gen is not None else (
stateful_random_ops.Generator.from_seed(seed))
return generator.binomial(
shape=shape, counts=counts, probs=probs, dtype=dtype)
return func
@ -57,15 +57,16 @@ class RandomBinomialTest(test.TestCase):
# we want to tolerate. Since the z-test approximates a unit normal
# distribution, it should almost definitely never exceed 6.
z_limit = 6.0
gen = stateful_random_ops.Generator.from_seed(seed=12345)
for dt in _SUPPORTED_DTYPES:
# Test when n * p > 10, and n * p < 10
for stride in 0, 4, 10:
for counts in (1., 10., 22., 50.):
for prob in (0.1, 0.5, 0.8):
sampler = self._Sampler(int(1e5), counts, prob, dt, seed=12345)
sampler = self._Sampler(int(5e4), counts, prob, dt, gen=gen)
z_scores = util.test_moment_matching(
# Use float64 samples.
sampler().astype(np.float64),
self.evaluate(sampler()).astype(np.float64),
number_moments=6,
dist=stats.binom(counts, prob),
stride=stride,
@ -77,7 +78,7 @@ class RandomBinomialTest(test.TestCase):
for dt in dtypes.float16, dtypes.float32, dtypes.float64:
sx = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345)
sy = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345)
self.assertAllEqual(sx(), sy())
self.assertAllEqual(self.evaluate(sx()), self.evaluate(sy()))
def testZeroShape(self):
rnd = stateful_random_ops.Generator.from_seed(12345).binomial([0], [], [])
@ -88,6 +89,8 @@ class RandomBinomialTest(test.TestCase):
# Scalar parameters.
rnd = rng.binomial(shape=[10], counts=np.float32(2.), probs=np.float32(0.5))
self.assertEqual([10], rnd.shape.as_list())
rnd = rng.binomial(shape=[], counts=np.float32(2.), probs=np.float32(0.5))
self.assertEqual([], rnd.shape.as_list())
# Vector parameters.
rnd = rng.binomial(
@ -96,10 +99,10 @@ class RandomBinomialTest(test.TestCase):
probs=0.3 * array_ops.ones([10], dtype=np.float32))
self.assertEqual([10], rnd.shape.as_list())
rnd = rng.binomial(
shape=[2, 5],
shape=[5, 2],
counts=array_ops.ones([2], dtype=np.float32),
probs=0.4 * array_ops.ones([2], dtype=np.float32))
self.assertEqual([2, 5], rnd.shape.as_list())
self.assertEqual([5, 2], rnd.shape.as_list())
# Scalar counts, vector probs.
rnd = rng.binomial(
@ -115,6 +118,20 @@ class RandomBinomialTest(test.TestCase):
probs=np.float32(0.9))
self.assertEqual([10], rnd.shape.as_list())
# Tensor parameters
rnd = rng.binomial(
shape=[10, 2, 3],
counts=array_ops.ones([2, 1], dtype=np.float32),
probs=0.9 * array_ops.ones([1, 3], dtype=np.float32))
self.assertEqual([10, 2, 3], rnd.shape.as_list())
# Tensor parameters
rnd = rng.binomial(
shape=[10, 2, 3, 5],
counts=array_ops.ones([2, 1, 5], dtype=np.float32),
probs=0.9 * array_ops.ones([1, 3, 1], dtype=np.float32))
self.assertEqual([10, 2, 3, 5], rnd.shape.as_list())
@test_util.run_v2_only
def testCornerCases(self):
rng = stateful_random_ops.Generator.from_seed(12345)
@ -126,5 +143,61 @@ class RandomBinomialTest(test.TestCase):
shape=[6], counts=counts, probs=probs, dtype=np.float32)
self.assertAllEqual(expected, self.evaluate(result))
@test_util.run_v2_only
def testMomentsForTensorInputs(self):
try:
from scipy import stats # pylint: disable=g-import-not-at-top
except ImportError as e:
tf_logging.warn("Cannot test moments: %s", e)
return
# The moments test is a z-value test. This is the largest z-value
# we want to tolerate. Since the z-test approximates a unit normal
# distribution, it should almost definitely never exceed 6.
z_limit = 6.0
class ScipyBinomialWrapper(object):
"""Wrapper for stats.binom to support broadcasting."""
def __init__(self, counts, probs):
self.counts = counts
self.probs = probs
def moment(self, i):
counts, probs = np.broadcast_arrays(self.counts, self.probs)
broadcast_shape = counts.shape
counts = np.reshape(counts, (-1,))
probs = np.reshape(probs, (-1,))
counts_and_probs = np.stack([counts, probs], axis=-1)
moments = np.fromiter(
(stats.binom(cp[0], cp[1]).moment(i) for cp in counts_and_probs),
dtype=np.float64)
return np.reshape(moments, broadcast_shape)
gen = stateful_random_ops.Generator.from_seed(seed=23455)
for dt in _SUPPORTED_DTYPES:
# Test when n * p > 10, and n * p < 10
for stride in 0, 4, 10:
counts = np.float64(np.random.randint(low=1, high=20, size=(2, 1, 4)))
probs = np.random.uniform(size=(1, 3, 4))
sampler = self._Sampler(
int(5e4),
counts,
probs,
dt,
gen=gen,
sample_shape=[10 * int(5e4), 2, 3, 4])
# Use float64 samples.
samples = self.evaluate(sampler()).astype(np.float64)
z_scores = util.test_moment_matching(
samples,
number_moments=6,
dist=ScipyBinomialWrapper(counts, probs),
stride=stride,
)
self.assertAllLess(z_scores, z_limit)
if __name__ == "__main__":
test.main()

View File

@ -49,10 +49,12 @@ def test_moment_matching(
sample_moments = []
expected_moments = []
variance_sample_moments = []
x = samples.flat
for i in range(1, number_moments + 1):
strided_range = x[::(i - 1) * stride + 1]
sample_moments.append(np.mean(strided_range ** i))
if len(samples.shape) == 2:
strided_range = samples.flat[::(i - 1) * stride + 1]
else:
strided_range = samples[::(i - 1) * stride + 1, ...]
sample_moments.append(np.mean(strided_range**i, axis=0))
expected_moments.append(dist.moment(i))
variance_sample_moments.append(
(dist.moment(2 * i) - dist.moment(i) ** 2) / len(strided_range))
@ -66,8 +68,7 @@ def test_moment_matching(
i * np.finfo(samples.dtype).eps)
tiny = np.finfo(samples.dtype).tiny
assert np.all(total_variance > 0)
if total_variance < tiny:
total_variance = tiny
total_variance = np.where(total_variance < tiny, tiny, total_variance)
# z_test is approximately a unit normal distribution.
z_test_scores.append(abs(
(sample_moments[i - 1] - expected_moments[i - 1]) / np.sqrt(

View File

@ -598,22 +598,30 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
```python
counts = [10., 20.]
# Probability of success.
probs = [0.8, 0.9]
probs = [0.8]
rng = tf.random.experimental.Generator.from_seed(seed=234)
binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)
counts = ... # Shape [3, 1, 2]
probs = ... # Shape [1, 4, 2]
shape = [3, 4, 3, 4, 2]
rng = tf.random.experimental.Generator.from_seed(seed=1717)
# Sample shape will be [3, 4, 3, 4, 2]
binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs)
```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output
tensor.
counts: A 0/1-D Tensor or Python value. The counts of the binomial
distribution. Must be broadcastable with the leftmost dimension
defined by `shape`.
probs: A 0/1-D Tensor or Python value. The probability of success for the
binomial distribution. Must be broadcastable with the leftmost
dimension defined by `shape`.
counts: Tensor. The counts of the binomial distribution. Must be
broadcastable with `probs`, and broadcastable with the rightmost
dimensions of `shape`.
probs: Tensor. The probability of success for the
binomial distribution. Must be broadcastable with `counts` and
broadcastable with the rightmost dimensions of `shape`.
dtype: The type of the output. Default: tf.int32
name: A name for the operation (optional).