diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc index df27541bb66..f94c5a08ae9 100644 --- a/tensorflow/core/kernels/random_binomial_op.cc +++ b/tensorflow/core/kernels/random_binomial_op.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/bcast.h" #include "tensorflow/core/util/guarded_philox_random.h" #include "tensorflow/core/util/work_sharder.h" @@ -86,7 +87,7 @@ double binomial_inversion(double count, double prob, return num_geom; } -double stirling_approx_tail(double k) { +inline double stirling_approx_tail(double k) { static double kTailValues[] = {0.0810614667953272, 0.0413406959554092, 0.0276779256849983, 0.02079067210376509, 0.0166446911898211, 0.0138761288230707, @@ -102,7 +103,7 @@ double stirling_approx_tail(double k) { // We use a transformation-rejection algorithm from // pairs of uniform random variables due to Hormann. // https://www.tandfonline.com/doi/abs/10.1080/00949659308811496 -double btrs(double count, double prob, random::PhiloxRandom* gen) { +inline double btrs(double count, double prob, random::PhiloxRandom* gen) { using Eigen::numext::abs; using Eigen::numext::floor; using Eigen::numext::log; @@ -119,6 +120,9 @@ double btrs(double count, double prob, random::PhiloxRandom* gen) { const double v_r = 0.92 - 4.2 / b; const double r = prob / (1 - prob); + const double alpha = (2.83 + 5.1 / b) * stddev; + const double m = floor((count + 1) * prob); + Uniform uniform; typename Uniform::ResultType uniform_result; int16 uniform_remaining = 0; @@ -143,8 +147,6 @@ double btrs(double count, double prob, random::PhiloxRandom* gen) { continue; } - double alpha = (2.83 + 5.1 / b) * stddev; - double m = floor((count + 1) * prob); // This deviates from Hormann's BRTS algorithm, as there is a log missing. // For all (u, v) pairs outside of the bounding box, this calculates the // transformed-reject ratio. @@ -169,66 +171,83 @@ template struct RandomBinomialFunctor { void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches, int64 samples_per_batch, int64 num_elements, - typename TTypes::ConstFlat counts, + const BCast& bcast, typename TTypes::ConstFlat counts, typename TTypes::ConstFlat probs, const random::PhiloxRandom& gen, typename TTypes::Flat output) { auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); - auto DoWork = [samples_per_batch, num_elements, &counts, &probs, &gen, - &output](int start_batch, int limit_batch) { - // Capturing "gen" by-value would only make a copy for the _shared_ - // lambda. Since we want to let each worker have its own copy, we pass - // "gen" by reference and explicitly do a copy assignment here. - random::PhiloxRandom gen_copy = gen; - // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to - // us using the same state in different batches. - // The sample from each iteration uses 2 random numbers. - gen_copy.Skip(start_batch * 2 * 3 * (samples_per_batch + 3) / 4); - + // The output layout is [B1, ... Bk, H1, ... Hm]. We have [B1, ... Bk] for + // the sample shape and [H1, ... Hm] for the batch shape of the samples. + // We have B1 * ... * Bk samples per batch member we need. 
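For reference, a minimal NumPy sketch (illustrative sizes, not from the patch) of the index mapping the rewritten DoWork loop below relies on: work is enumerated by a flat, batch-major output_idx (batch_idx = output_idx // samples_per_batch), while the physical output buffer is laid out as [B1, ..., Bk, H1, ..., Hm] with the batch dimensions innermost, so element (sample_idx, batch_idx) is written at flat offset sample_idx * num_batches + batch_idx.

```python
import numpy as np

# Hypothetical sizes: num_batches is the flattened [H1, ..., Hm] batch shape,
# samples_per_batch is the flattened [B1, ..., Bk] sample shape.
num_batches = 2
samples_per_batch = 3
num_elements = num_batches * samples_per_batch

out = np.full(num_elements, -1, dtype=np.int64)
for output_idx in range(num_elements):
    # Work index is batch-major: consecutive output_idx values stay within one
    # (count, prob) pair, like the inner loops of DoWork.
    batch_idx = output_idx // samples_per_batch
    sample_idx = output_idx % samples_per_batch
    # Physical layout is [sample dims..., batch dims...], so the batch index is
    # the fastest-varying axis of the flat output buffer.
    out[sample_idx * num_batches + batch_idx] = batch_idx

# Every slot is written exactly once, and each column of the reshaped output
# holds one batch member's samples, matching output_batch_offset[sample_idx * num_batches].
assert (out >= 0).all()
print(out.reshape(samples_per_batch, num_batches))
# [[0 1]
#  [0 1]
#  [0 1]]
```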
+ auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs, + &gen, &output](int start_output, int limit_output) { // Vectorized intermediate calculations for uniform rejection sampling. // We always generate at most 4 samples. Eigen::array z; Eigen::array g; + const bool should_bcast = bcast.IsBroadcastingRequired(); + const auto& counts_batch_indices = bcast.x_batch_indices(); + const auto& probs_batch_indices = bcast.y_batch_indices(); + auto output_flat = output.data(); - for (int64 b = start_batch; b < limit_batch; ++b) { - // We are passed a flat array for each of the parameter tensors. - // The input is either a scalar broadcasted to all batches or a vector - // with length num_batches, but the scalar becomes an array of length 1. - T count = counts((counts.dimension(0) == 1) ? 0 : b); - T prob = probs((probs.dimension(0) == 1) ? 0 : b); - - // The last batch can be short, if we adjusted num_batches and - // samples_per_batch. - const int64 limit_sample = - std::min((b + 1) * samples_per_batch, num_elements); - int64 sample = b * samples_per_batch; + // We partition work across batches (count, prob) and then across samples + // per batch member, to avoid extra work. + for (int64 output_idx = start_output; output_idx < limit_output; + // output_idx is incremented with the inner loops below. + ) { + int64 batch_idx = output_idx / samples_per_batch; + U* const output_batch_offset = output_flat + batch_idx; + // Generate batch counts from BCast, as it has the right indices to loop + // over. + T count, prob; + if (should_bcast) { + count = counts(counts_batch_indices[batch_idx]); + prob = probs(probs_batch_indices[batch_idx]); + } else { + count = counts(batch_idx); + prob = probs(batch_idx); + } // Calculate normalized samples, then convert them. // Determine the method to use. double dcount = static_cast(count); if (dcount <= 0.0 || prob <= T(0.0)) { - while (sample < limit_sample) { - output(sample) = static_cast(0.0); - sample++; + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output; + ++sample_idx, ++output_idx) { + output_batch_offset[sample_idx * num_batches] = static_cast(0.0); } } else if (prob >= T(1.0)) { - while (sample < limit_sample) { - output(sample) = static_cast(dcount); - sample++; + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output; + ++sample_idx, ++output_idx) { + output_batch_offset[sample_idx * num_batches] = + static_cast(dcount); } } else if (prob <= T(0.5)) { double dp = static_cast(prob); if (count * prob >= T(10)) { - while (sample < limit_sample) { - output(sample) = static_cast(btrs(dcount, dp, &gen_copy)); - sample++; + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output; + ++sample_idx, ++output_idx) { + random::PhiloxRandom gen_copy = gen; + gen_copy.Skip(256 * output_idx); + output_batch_offset[sample_idx * num_batches] = + static_cast(btrs(dcount, dp, &gen_copy)); } } else { - while (sample < limit_sample) { - output(sample) = + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output; + ++sample_idx, ++output_idx) { + random::PhiloxRandom gen_copy = gen; + // For binomial inversion, we have mean <= 10, variance <= 10. + // This means on average we need at most 10 number of samples, + // and for 10 standard deviations, we need 42 samples. We reserve + // that much. 
+ gen_copy.Skip(42 * output_idx); + output_batch_offset[sample_idx * num_batches] = static_cast(binomial_inversion(dcount, dp, &gen_copy)); - sample++; } } } else if (prob > T(0.5)) { @@ -236,45 +255,41 @@ struct RandomBinomialFunctor { double dcount = static_cast(count); double dq = static_cast(q); if (count * q >= T(10)) { - while (sample < limit_sample) { - output(sample) = + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output; + ++sample_idx, ++output_idx) { + random::PhiloxRandom gen_copy = gen; + gen_copy.Skip(256 * output_idx); + output_batch_offset[sample_idx * num_batches] = static_cast(dcount - btrs(dcount, dq, &gen_copy)); - sample++; } } else { - while (sample < limit_sample) { - output(sample) = static_cast( + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output; + ++sample_idx, ++output_idx) { + random::PhiloxRandom gen_copy = gen; + // For binomial inversion, we have mean <= 10, variance <= 10. + // This means on average we need at most 10 number of samples, + // and for 10 standard deviations, we need 42 samples. We reserve + // that much. + gen_copy.Skip(42 * output_idx); + output_batch_offset[sample_idx * num_batches] = static_cast( dcount - binomial_inversion(dcount, dq, &gen_copy)); - sample++; } } } else { // prob is NaN // TODO(srvasude): What should happen if prob is NaN but the output // type is an integer (which doesn't have a sentinel for NaN)? Fail // the whole batch sample? Return a specialized sentinel like -1? - while (sample < limit_sample) { - output(sample) = static_cast(NAN); - sample++; + for (int64 sample_idx = output_idx % samples_per_batch; + sample_idx < samples_per_batch && output_idx < limit_output; + ++sample_idx, ++output_idx) { + output_batch_offset[sample_idx * num_batches] = static_cast(NAN); } } } }; - const int64 batch_init_cost = - // normMin, normMax - (Eigen::TensorOpCost::AddCost() + - Eigen::TensorOpCost::MulCost()) * - 2 - // sqrtFactor - + Eigen::TensorOpCost::AddCost() + - Eigen::TensorOpCost::MulCost() + - Eigen::internal::functor_traits< - Eigen::internal::scalar_sqrt_op>::Cost - // cutoff - + Eigen::TensorOpCost::MulCost() * 4 + - Eigen::internal::functor_traits>::Cost - // diff - + Eigen::TensorOpCost::AddCost(); // This will depend on count * p (or count * q). // For n * p < 10, on average, O(n * p) calls to uniform are // needed, with that @@ -290,17 +305,15 @@ struct RandomBinomialFunctor { // 2 uniform generations along with 5 other ops at 3-6 cycles each. // ~15 / .89 = ~16 // - // In total this should be ~529 + 2 * Uniform::kElementCost. + // In total this (rate >= 10) should be ~329 + 2 * Uniform::kElementCost. // We assume that half the tensor has rate < 10, so on average 6 // uniform's // will be needed. We will upper bound the other op cost by the one for // rate > 10. - static const int kElementCost = 529 + 6 * Uniform::kElementCost + + static const int kElementCost = 329 + 6 * Uniform::kElementCost + 6 * random::PhiloxRandom::kElementCost; - // Assume we use uniform sampling, and accept the 2nd sample on average. 
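For reference, a hedged NumPy/SciPy sketch of the transformed-rejection (BTRS) scheme that btrs() above implements, following Hormann (1993). Only v_r, r, alpha and m are visible in this hunk; the remaining hat constants and the squeeze threshold are reconstructed from the paper and should be read as assumptions, and the acceptance test below uses SciPy's exact log-pmf ratio rather than the kernel's Stirling-tail approximation.

```python
import numpy as np
from scipy import stats


def btrs_sketch(count, prob, rng):
    """One Binomial(count, prob) draw via transformed rejection,
    for prob <= 0.5 and count * prob >= 10 (the regime btrs() handles)."""
    stddev = np.sqrt(count * prob * (1.0 - prob))
    # Hat-function constants from Hormann (1993); only v_r, r, alpha and m
    # appear in the hunk above, the rest are assumptions from the paper.
    b = 1.15 + 2.53 * stddev
    a = -0.0873 + 0.0248 * b + 0.01 * prob
    c = count * prob + 0.5
    v_r = 0.92 - 4.2 / b
    r = prob / (1.0 - prob)
    # Hoisted out of the loop, as in the patch: these depend only on
    # (count, prob), not on the uniforms drawn below.
    alpha = (2.83 + 5.1 / b) * stddev
    m = np.floor((count + 1) * prob)
    log_pmf_m = stats.binom.logpmf(m, count, prob)
    while True:
        u = rng.uniform() - 0.5
        v = rng.uniform()
        us = 0.5 - abs(u)
        k = np.floor((2.0 * a / us + b) * u + c)
        # Cheap "squeeze" acceptance region handles most proposals.
        if us >= 0.07 and v <= v_r:
            return k
        if k < 0 or k > count:
            continue
        # Full test in log space: accept if the transformed v falls below the
        # pmf ratio f(k)/f(m) (exact here; Stirling-tail approx. in the kernel).
        if np.log(v * alpha / (a / (us * us) + b)) <= (
                stats.binom.logpmf(k, count, prob) - log_pmf_m):
            return k


rng = np.random.default_rng(0)
draws = np.array([btrs_sketch(50.0, 0.4, rng) for _ in range(20000)])
print(draws.mean(), draws.var())  # expect roughly 20 and 12
```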
- const int64 batch_cost = batch_init_cost + kElementCost * samples_per_batch; - Shard(worker_threads.num_threads, worker_threads.workers, num_batches, - batch_cost, DoWork); + Shard(worker_threads.num_threads, worker_threads.workers, num_elements, + kElementCost, DoWork); } }; @@ -324,72 +337,60 @@ class RandomBinomialOp : public OpKernel { const Tensor& counts_tensor = ctx->input(3); const Tensor& probs_tensor = ctx->input(4); + tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(), + probs_tensor.shape().dim_sizes(), + /*fewer_dims_optimization=*/false, + /*return_flattened_batch_indices=*/true); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "counts and probs must have compatible batch dimensions: ", + counts_tensor.shape().DebugString(), " vs. ", + probs_tensor.shape().DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(shape_tensor.shape()), + errors::InvalidArgument("Input shape should be a vector, got shape: ", + shape_tensor.shape().DebugString())); + OP_REQUIRES(ctx, + (shape_tensor.dtype() == DataType::DT_INT32 || + shape_tensor.dtype() == DataType::DT_INT64), + errors::InvalidArgument( + "Input shape should have dtype {int32, int64}.")); + + // Let's check that the shape tensor dominates the broadcasted tensor. + TensorShape bcast_shape = BCast::ToShape(bcast.output_shape()); + TensorShape output_shape; + if (shape_tensor.dtype() == DataType::DT_INT32) { + OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec(), + &output_shape)); + } else { + OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec(), + &output_shape)); + } + OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape), + errors::InvalidArgument( + "Shape passed in must end with broadcasted shape.")); + // Now that we have a guarantee, we can get the additional dimensions added + // by sampling. OP_REQUIRES(ctx, alg_tensor.dims() == 0, errors::InvalidArgument("algorithm must be of shape [], not ", alg_tensor.shape().DebugString())); Algorithm alg = alg_tensor.flat()(0); - OP_REQUIRES( - ctx, TensorShapeUtils::IsVector(shape_tensor.shape()), - errors::InvalidArgument("Input shape should be a vector, got shape: ", - shape_tensor.shape().DebugString())); - int32 num_batches = shape_tensor.flat()(0); - - int32 samples_per_batch = 1; - const int32 num_dims = shape_tensor.dim_size(0); - for (int32 i = 1; i < num_dims; i++) { + int64 samples_per_batch = 1; + const int64 num_sample_dims = + (shape_tensor.dim_size(0) - bcast.output_shape().size()); + for (int64 i = 0; i < num_sample_dims; ++i) { samples_per_batch *= shape_tensor.flat()(i); } - const int32 num_elements = num_batches * samples_per_batch; - - // Allocate the output before fudging num_batches and samples_per_batch. - auto shape_vec = shape_tensor.flat(); - TensorShape tensor_shape; - OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( - shape_vec.data(), shape_vec.size(), &tensor_shape)); - Tensor* samples_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor)); - - // Parameters must be 0-d or 1-d. 
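The Skip(42 * output_idx) calls above reserve a fixed per-sample budget of uniforms for the inversion path: count * prob <= 10 means at most about 10 geometric jumps on average, and ten standard deviations (sqrt(10) each) add roughly 32 more, hence 42. Below is a sketch of one standard way to realize binomial sampling by inversion via geometric gaps; it illustrates the idea and is not a copy of binomial_inversion().

```python
import numpy as np


def binomial_inversion_sketch(count, prob, rng):
    """Binomial(count, prob) draw as the number of geometric gaps fitting in count."""
    # Each uniform yields one Geometric(prob) waiting time via inversion;
    # successes are counted until the cumulative gap exceeds `count`.
    log1m_prob = np.log1p(-prob)
    num_geom = 0
    geom_sum = 0.0
    while True:
        u = rng.uniform()
        geom = np.ceil(np.log(u) / log1m_prob)
        geom_sum += geom
        if geom_sum > count:
            return num_geom
        num_geom += 1


rng = np.random.default_rng(1)
count, prob = 18.0, 0.3  # count * prob < 10, the regime that uses inversion
samples = np.array(
    [binomial_inversion_sketch(count, prob, rng) for _ in range(20000)])
print(samples.mean(), samples.var())  # near 5.4 and 3.78

# Rough per-sample uniform budget: about count * prob + 1 draws on average,
# and ten standard deviations on top of a mean of 10 is ~32 more, which is
# where the reserved 42 in the patch comes from.
```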
- OP_REQUIRES(ctx, counts_tensor.dims() <= 1, - errors::InvalidArgument( - "Input counts should be a scalar or vector, got shape: ", - counts_tensor.shape().DebugString())); - OP_REQUIRES(ctx, probs_tensor.dims() <= 1, - errors::InvalidArgument( - "Input probs should be a scalar or vector, got shape: ", - probs_tensor.shape().DebugString())); - - if ((counts_tensor.dims() == 0 || counts_tensor.dim_size(0) == 1) && - (probs_tensor.dims() == 0 || probs_tensor.dim_size(0) == 1)) { - // All batches have the same parameters, so we can update the batch size - // to a reasonable value to improve parallelism (ensure enough batches, - // and no very small batches which have high overhead). - int32 size = num_batches * samples_per_batch; - int32 adjusted_samples = kDesiredBatchSize; - // Ensure adjusted_batches * adjusted_samples >= size. - int32 adjusted_batches = Eigen::divup(size, adjusted_samples); - num_batches = adjusted_batches; - samples_per_batch = adjusted_samples; - } else { - // Parameters must be broadcastable to the shape [num_batches]. - OP_REQUIRES( - ctx, - TensorShapeUtils::IsScalar(counts_tensor.shape()) || - counts_tensor.dim_size(0) == 1 || - counts_tensor.dim_size(0) == num_batches, - errors::InvalidArgument( - "Input counts should have length 1 or shape[0], got shape: ", - counts_tensor.shape().DebugString())); - OP_REQUIRES( - ctx, - TensorShapeUtils::IsScalar(probs_tensor.shape()) || - probs_tensor.dim_size(0) == 1 || - probs_tensor.dim_size(0) == num_batches, - errors::InvalidArgument( - "Input probs should have length 1 or shape[0], got shape: ", - probs_tensor.shape().DebugString())); + int64 num_batches = 1; + for (int64 i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) { + num_batches *= shape_tensor.flat()(i); } + const int64 num_elements = num_batches * samples_per_batch; + + Tensor* samples_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor)); + core::RefCountPtr var; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var)); @@ -415,7 +416,6 @@ class RandomBinomialOp : public OpKernel { "For Philox algorithm, the size of state must be at least ", PHILOX_MIN_STATE_SIZE, "; got ", var_tensor_flat.size())); - // Each worker has the fudge factor for samples_per_batch, so use it here. 
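To make the new shape contract in Compute() concrete, here is a small NumPy sketch of the same bookkeeping with illustrative shapes (not taken from the patch): counts and probs are broadcast to a batch shape, the requested output shape must end with that batch shape (the EndsWith check above), and the leading dimensions become sample dimensions.

```python
import numpy as np

# Hypothetical parameter shapes, similar to the broadcasting tests further below.
counts_shape = (2, 1)
probs_shape = (1, 3)
requested_shape = (10, 2, 3)  # the `shape` argument passed to the op

# Broadcast batch shape of the parameters (BCast::output_shape()).
batch_shape = np.broadcast_shapes(counts_shape, probs_shape)  # (2, 3)

# The requested shape must end with the broadcast batch shape.
assert requested_shape[len(requested_shape) - len(batch_shape):] == batch_shape

# Leading dimensions are sample dimensions; trailing ones are batch dimensions.
num_sample_dims = len(requested_shape) - len(batch_shape)
samples_per_batch = int(np.prod(requested_shape[:num_sample_dims]))  # 10
num_batches = int(np.prod(requested_shape[num_sample_dims:]))        # 6
num_elements = samples_per_batch * num_batches                       # 60
print(num_sample_dims, samples_per_batch, num_batches, num_elements)
```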
OP_REQUIRES_OK(ctx, PrepareToUpdateVariable( ctx, var_tensor, var->copy_on_read_mode.load())); auto var_data = var_tensor_flat.data(); @@ -425,8 +425,9 @@ class RandomBinomialOp : public OpKernel { auto binomial_functor = functor::RandomBinomialFunctor(); binomial_functor(ctx, ctx->eigen_device(), num_batches, - samples_per_batch, num_elements, counts_tensor.flat(), - probs_tensor.flat(), philox, samples_tensor->flat()); + samples_per_batch, num_elements, bcast, + counts_tensor.flat(), probs_tensor.flat(), philox, + samples_tensor->flat()); } private: diff --git a/tensorflow/core/ops/stateful_random_ops.cc b/tensorflow/core/ops/stateful_random_ops.cc index 9537e614069..ecc570a44f3 100644 --- a/tensorflow/core/ops/stateful_random_ops.cc +++ b/tensorflow/core/ops/stateful_random_ops.cc @@ -107,10 +107,6 @@ REGISTER_OP("StatefulRandomBinomial") .SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(3), 1, &unused)); - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); - ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &out)); c->set_output(0, out); diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc index 3a5f1f83af8..eec22899ab1 100644 --- a/tensorflow/core/util/bcast.cc +++ b/tensorflow/core/util/bcast.cc @@ -18,153 +18,16 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace tensorflow { -/* static */ -void BCast::Reverse(Vec* shape) { std::reverse(shape->begin(), shape->end()); } - -BCast::BCast(const Vec& sx, const Vec& sy, const bool fewer_dims_optimization) { - if (sx == sy && TF_PREDICT_TRUE(fewer_dims_optimization)) { - // Fast path for common case of identical shapes for sx and sy - int64 elements = 1; - const int n = sx.size(); - output_.resize(n); - for (int i = 0; i < n; i++) { - const int64 dim = sx[i]; - elements *= dim; - output_[i] = dim; - } - result_.push_back(elements); - x_reshape_.push_back(elements); - y_reshape_.push_back(elements); - x_bcast_.push_back(1); - y_bcast_.push_back(1); - // grad_x_reduce_ and grad_y_reduce_ are left as empty - } else { - // Reverse the shape of x and y for convenience. - // After the reverse, 0-th is the inner-most dimension. - Vec x = sx; - Vec y = sy; - Reverse(&x); - Reverse(&y); - - // 1-extend and align x and y so that they are the same size. - if (x.size() > y.size()) { - y.resize(x.size(), 1); - } else { - x.resize(y.size(), 1); - } - - // Going through each dimension starting from the inner-most - // dimension, compares dimension of x and y. They are compatible if - // they are equal or either is 1. - enum State { - UNKNOWN, - SAME, - X_ONE, - Y_ONE, - }; - State prev = UNKNOWN; - const int64 n = x.size(); - for (int i = 0; i < n; ++i) { - // Output shape. - State curr = UNKNOWN; - const int64 x_i = x[i]; // i-th dimension of x. - const int64 y_i = y[i]; // i-th dimension of y. - int64 o_i; // i-th dimension of the output. - int64 bx_i; // i-th broadcast for x. - int64 by_i; // i-th broadcast for y. - // Invariant: - // o_i = x_i * bx_i = y_i * by_i - if (x_i == y_i) { - // No broadcast. - o_i = x_i; - bx_i = 1; - by_i = 1; - curr = SAME; - } else if (x_i == 1) { - // x broadcast to y on this dimension. - o_i = y_i; - bx_i = y_i; - by_i = 1; - grad_x_reduce_idx_.push_back(n - 1 - i); - curr = X_ONE; - } else if (y_i == 1) { - // y broadcast to x on this dimension. 
- o_i = x_i; - bx_i = 1; - by_i = x_i; - grad_y_reduce_idx_.push_back(n - 1 - i); - curr = Y_ONE; - } else { - valid_ = false; - return; - } - output_.push_back(o_i); - // Reshape/broadcast. - // Invariant: - // result[i] == x_reshape[i] * x_bcast[i] == y_reshape_[i] * y_bcast_[i] - if (curr == SAME && x_i == 1) { - // Both side are 1s. - grad_x_reduce_idx_.push_back(n - 1 - i); - grad_y_reduce_idx_.push_back(n - 1 - i); - if (!TF_PREDICT_TRUE(fewer_dims_optimization)) { - result_.push_back(o_i); - x_reshape_.push_back(x_i); - x_bcast_.push_back(bx_i); - y_reshape_.push_back(y_i); - y_bcast_.push_back(by_i); - } - continue; - } else if (TF_PREDICT_TRUE(fewer_dims_optimization) && prev == curr) { - // It is a run of the same cases(no broadcast, x broadcast to y, y - // broadcast to x). We can reshape the input so that fewer dimensions - // are involved in the intermediate computation. - result_.back() *= o_i; - x_reshape_.back() *= x_i; - x_bcast_.back() *= bx_i; - y_reshape_.back() *= y_i; - y_bcast_.back() *= by_i; - } else { - result_.push_back(o_i); - x_reshape_.push_back(x_i); - x_bcast_.push_back(bx_i); - y_reshape_.push_back(y_i); - y_bcast_.push_back(by_i); - } - prev = curr; - } - - if (result_.empty()) { - // Can happen when both x and y are effectively scalar. - result_.push_back(1); - x_reshape_.push_back(1); - x_bcast_.push_back(1); - y_reshape_.push_back(1); - y_bcast_.push_back(1); - } - - // Reverse all vectors since x and y were reversed at very - // beginning. - Reverse(&x_reshape_); - Reverse(&x_bcast_); - Reverse(&y_reshape_); - Reverse(&y_bcast_); - Reverse(&result_); - Reverse(&output_); - Reverse(&grad_x_reduce_idx_); - Reverse(&grad_y_reduce_idx_); - } -} - BCast::Vec BCast::FromShape(const TensorShape& shape) { const int N = shape.dims(); - BCast::Vec ret(N); + BCastList::Vec ret(N); for (int i = 0; i < N; ++i) { ret[i] = shape.dim_size(i); } return ret; } -TensorShape BCast::ToShape(const BCast::Vec& vec) { +TensorShape BCast::ToShape(const BCastList::Vec& vec) { TensorShape shape(vec); return shape; } diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h index 2d647fd8d86..62d3968aadc 100644 --- a/tensorflow/core/util/bcast.h +++ b/tensorflow/core/util/bcast.h @@ -25,6 +25,284 @@ limitations under the License. namespace tensorflow { +// Returns the mapping from the output batch indices to the corresponding +// input's batch indices, given the input's "reshape" and "bcast" shapes as +// returned by the BCastList helper class. The i'th element denotes the +// (flattened) batch index of the input that must be used to compute the i'th +// batch output. +// +inline void ComputeBatchIndices(const int64 output_batch_size, + const gtl::InlinedVector& reshape, + const gtl::InlinedVector& bcast, + std::vector* out_indices) { + // Populates the mapping in out_indices. This algorithm is identical to + // the following steps: + // - Reshape {0, 1, ..., input_batch_size - 1} to the input shape. + // - Broadcast to the output shape. + // - Reshape back to a flat 1D vector. + out_indices->resize(output_batch_size); + int64 num_output_elements = 1; + int64 num_input_elements = 1; + for (int64 i = reshape.size() - 1; i >= 0; --i) { + // Replicate the already populated mapping an additional (dim - 1) times. + // If we are broadcasting, just copy the existing mapping. + // Otherwise, add another dimension from the input shape. + const int64 dim = std::max(reshape[i], bcast[i]); + const int64 incr = bcast[i] > 1 ? 
0 : num_input_elements; + for (int64 k = 0; k < (dim - 1) * num_output_elements; ++k) { + (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr; + } + num_output_elements *= dim; + num_input_elements *= reshape[i]; + } +} + +template +class BCastList { + public: + // A vector of int64 representing the shape of tensor. The 0-th + // element is the outer-most dimension and the last element is the + // inner-most dimension. Note that we do not use TensorShape since + // it's more convenient to manipulate Vec directly for this module. + typedef gtl::InlinedVector Vec; + + // Constructs all helper shapes, following the aforementioned rules. + // + // If "fewer_dims_optimization" is set to true (the default), the + // implementation tries to reduce intermediate dimensions needed to be more + // efficient. This is transparent to the caller. + // + // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have + // the same number of dimensions as the larger of the two inputs. + // + // If return_flattened_batch_indices is true, the implementation will compute + // for each output member of the flattened output, which batch indicies of + // each input correspond to it. This is disabled by default. + explicit BCastList(const Vec (&x)[N], + const bool fewer_dims_optimization = true, + const bool return_flattened_batch_indices = false); + ~BCastList() {} + + // Returns true iff two operands are compatible according to the + // broadcasting rule. + bool IsValid() const { return valid_; } + bool IsBroadcastingRequired() const { return broadcasting_required_; } + + // If and only if IsValid(), the following fields can be used in + // implementing a broadcasted binary tensor operation according to + // the broadcasting rule. + const Vec& reshape(int i) const { return reshape_[i]; } + const Vec& bcast(int i) const { return bcast_[i]; } + const Vec& result_shape() const { return result_; } + const Vec& output_shape() const { return output_; } + const Vec& grad_reduce_idx(int i) const { return grad_reduce_idx_[i]; } + const int64 output_batch_size() const { return output_batch_size_; } + + // Returns the mapping from the flattened output batch indices to x's + // flattened batch indices. The result is a vector of length + // output_batch_size(). To compute the i'th batch output, a binary matmul-like + // operation should use the `x_batch_indices()[i]`th batch index of `x`. + // Note: Returns an empty vector if broadcasting is not required. Callers + // should only use this when IsBroadcastingRequired() returns true. 
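The documented algorithm for ComputeBatchIndices (reshape an iota to the input shape, broadcast it to the output shape, then flatten) can be reproduced directly in NumPy. The sketch below is an illustration rather than the C++ code; its outputs match the BatchIndices expectations added to bcast_test.cc further down ({3, 1} vs. {1, 4}).

```python
import numpy as np


def compute_batch_indices(input_shape, output_shape):
    """Flattened output index -> flattened input batch index, via broadcasting."""
    # Label every input element with its own flat index...
    iota = np.arange(int(np.prod(input_shape))).reshape(input_shape)
    # ...then broadcast those labels to the output shape and flatten.
    return np.broadcast_to(iota, output_shape).ravel()


out_shape = np.broadcast_shapes((3, 1), (1, 4))  # (3, 4)
print(compute_batch_indices((3, 1), out_shape).tolist())
# [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
print(compute_batch_indices((1, 4), out_shape).tolist())
# [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
```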
+ const std::vector& batch_indices(int i) const { + return batch_indices_[i]; + } + + protected: + bool valid_ = true; + bool broadcasting_required_ = true; + Vec reshape_[N]; + Vec bcast_[N]; + Vec result_; + Vec output_; + Vec grad_reduce_idx_[N]; + + int64 output_batch_size_; + std::vector batch_indices_[N]; + + static void Reverse(Vec* shape) { + std::reverse(shape->begin(), shape->end()); + } + + TF_DISALLOW_COPY_AND_ASSIGN(BCastList); +}; + +template +BCastList::BCastList(const BCastList::Vec (&x)[N], + const bool fewer_dims_optimization, + const bool return_flattened_batch_indices) { + typedef BCastList::Vec Vec; + bool all_equal = true; + int largest_rank = 0; + output_batch_size_ = 1; + for (int i = 0; i < N; ++i) { + if (x[i] != x[0]) { + all_equal = false; + } + if (x[i].size() > largest_rank) { + largest_rank = x[i].size(); + } + } + if (all_equal) { + broadcasting_required_ = false; + } + if (all_equal && TF_PREDICT_TRUE(fewer_dims_optimization)) { + // Fast path for common case of identical shapes. + int64 elements = 1; + const int rank = x[0].size(); + output_.resize(rank); + for (int i = 0; i < rank; i++) { + const int64 dim = x[0][i]; + elements *= dim; + output_[i] = dim; + } + result_.push_back(elements); + output_batch_size_ = elements; + for (int i = 0; i < N; ++i) { + reshape_[i].push_back(elements); + bcast_[i].push_back(1); + } + // grad_reduce_ is left as empty + return; + } + + // Reverse all the shapes for convenience + // After the reverse, 0-th is the inner-most dimension. + Vec copy[N]; + for (int i = 0; i < N; ++i) { + copy[i] = x[i]; + Reverse(©[i]); + } + + // 1-extend and align all vectors. + for (int i = 0; i < N; ++i) { + if (copy[i].size() < largest_rank) { + copy[i].resize(largest_rank, 1); + } + } + // Going through each dimension starting from the inner-most + // dimension, compares dimension of x and y. They are compatible if + // they are equal or either is 1. + + // indices of j-th component of each input. + bool prev_is_one[N]; + bool current_is_one[N]; + for (int i = 0; i < N; ++i) { + prev_is_one[i] = false; + current_is_one[i] = false; + } + Vec output; + bool output_dim_set = false; + int output_dim = -1; + bool none_is_one = true; + bool set_one = false; + for (int j = 0; j < largest_rank; ++j) { + output_dim = -1; + output_dim_set = false; + none_is_one = true; + // Find which indices are 1. + for (int i = 0; i < N; ++i) { + // Keep track of which indices are 1. + if (copy[i][j] == 1) { + current_is_one[i] = true; + none_is_one = false; + } else { + current_is_one[i] = false; + if (!output_dim_set || copy[i][j] == output_dim) { + output_dim = copy[i][j]; + output_dim_set = true; + } else { + valid_ = false; + return; + } + } + } + output_.push_back(output_dim_set ? output_dim : 1); + output_batch_size_ *= output_.back(); + // All dimensions are 1. + if (!output_dim_set) { + if (!TF_PREDICT_TRUE(fewer_dims_optimization)) { + for (int i = 0; i < N; ++i) { + bcast_[i].push_back(1); + reshape_[i].push_back(1); + } + result_.push_back(1); + } + for (int i = 0; i < N; ++i) { + grad_reduce_idx_[i].push_back(largest_rank - 1 - j); + } + // This will skip updating the previous state to the current one. We'll + // explain why this is safe below. + // Consider the previous state P, current state C and the next state N. + // In the case where N also is all ones (N == C), we'll do the same + // optimization here (push back one dimensions if we need to), which is + // safe and is expected. + // + // When N != C, we'll continue as usual. 
However, we might trigger the + // next block if N == P (because we didn't update the previous state). + // We trigger the next block if `fewer_dims_optimization` is true. + // This means that we did not modify and broadcast / rehshapes in this + // block (we skipped updating, since the one dimensions can be ignored). + // In essence, we only need to check whether the previous non-one state is + // equal to the current non-one state. + + continue; + } else if (TF_PREDICT_TRUE(fewer_dims_optimization) && + std::equal(current_is_one, current_is_one + N, prev_is_one) && + set_one) { + // It is a run of the same broadcasting case as last time. + // We can reshape the input so that fewer dimensions + // are involved in the intermediate computation. + result_.back() *= output_dim; + for (int i = 0; i < N; ++i) { + reshape_[i].back() *= copy[i][j]; + bcast_[i].back() *= current_is_one[i] ? output_dim : 1; + if (current_is_one[i] && !none_is_one) { + grad_reduce_idx_[i].push_back(largest_rank - 1 - j); + } + } + } else { + result_.push_back(output_dim); + for (int i = 0; i < N; ++i) { + reshape_[i].push_back(copy[i][j]); + bcast_[i].push_back(current_is_one[i] ? output_dim : 1); + if (current_is_one[i] && !none_is_one) { + grad_reduce_idx_[i].push_back(largest_rank - 1 - j); + } + } + } + set_one = true; + for (int i = 0; i < N; ++i) { + prev_is_one[i] = current_is_one[i]; + } + } + if (result_.empty()) { + result_.push_back(1); + for (int i = 0; i < N; ++i) { + reshape_[i].push_back(1); + bcast_[i].push_back(1); + } + } + // Do something about batches. + for (int i = 0; i < N; ++i) { + Reverse(&reshape_[i]); + Reverse(&bcast_[i]); + Reverse(&grad_reduce_idx_[i]); + } + Reverse(&result_); + Reverse(&output_); + // Only compute batch indices when we need broadcasting, and we aren't doing + // needless work (when the output size is 0 or the + // return_flattened_batch_indices isn't enabled). + if (return_flattened_batch_indices && broadcasting_required_ && + output_batch_size_ > 0) { + for (int i = 0; i < N; ++i) { + ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i], + &batch_indices_[i]); + } + } +} + // BCast is a helper for broadcasting binary tensor operation. // TensorFlow's broadcasting rule follows that of numpy (See // http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). @@ -64,16 +342,8 @@ namespace tensorflow { // // The multiplication in the grad * backprop_x itself is also // broadcasting following the same rule. -// -// TODO(zhifengc): Adds support for n-ary (n >= 2). -class BCast { +class BCast : public BCastList<2> { public: - // A vector of int64 representing the shape of tensor. The 0-th - // element is the outer-most dimension and the last element is the - // inner-most dimension. Note that we do not use TensorShape since - // it's more convenient to manipulate Vec directly for this module. - typedef gtl::InlinedVector Vec; - // Constructs all helper shapes, following the aforementioned rules. // // If "fewer_dims_optimization" is set to true (the default), the @@ -82,28 +352,43 @@ class BCast { // // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have // the same number of dimensions as the larger of the two inputs. - BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true); - ~BCast() {} + typedef gtl::InlinedVector Vec; - // Returns true iff two operands are compatible according to the - // broadcasting rule. 
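A pure-Python sketch of the core shape merge that BCastList above performs, with the fewer-dims coalescing deliberately omitted (i.e. the fewer_dims_optimization=false path): inputs are 1-extended on the left to the largest rank, and on each axis the non-1 sizes must agree. It is a simplification for illustration; the expected values correspond to the Invalid and Complex_BCast_To_Each_Other test cases below (result_ and grad_reduce_idx_ are not modeled).

```python
def bcast_list_sketch(*shapes):
    """Output shape plus per-input (reshape, bcast) vectors, no dim merging."""
    rank = max(len(s) for s in shapes)
    padded = [(1,) * (rank - len(s)) + tuple(s) for s in shapes]
    output = []
    reshapes = [list(p) for p in padded]
    bcasts = [[] for _ in shapes]
    for j in range(rank):
        dims = {p[j] for p in padded if p[j] != 1}
        if len(dims) > 1:
            return None  # incompatible: "invalid" in the tests
        out_dim = dims.pop() if dims else 1
        output.append(out_dim)
        for i, p in enumerate(padded):
            # A size-1 input axis is stretched to the output size.
            bcasts[i].append(out_dim if p[j] == 1 else 1)
    return tuple(output), reshapes, bcasts


print(bcast_list_sketch((5, 3, 2), (3,), (1,)))  # None -> "invalid"
out, reshapes, bcasts = bcast_list_sketch((11, 1, 1, 1, 2), (7, 1, 3, 1), (5, 1, 1))
print(out)       # (11, 7, 5, 3, 2)
print(reshapes)  # [[11, 1, 1, 1, 2], [1, 7, 1, 3, 1], [1, 1, 5, 1, 1]]
print(bcasts)    # [[1, 7, 5, 3, 1], [11, 1, 5, 1, 2], [11, 7, 1, 3, 2]]
```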
- bool IsValid() const { return valid_; } + BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true, + const bool return_flattened_batch_indices = false) + : BCastList<2>({x, y}, fewer_dims_optimization, + return_flattened_batch_indices) {} + + ~BCast() {} // If and only if IsValid(), the following fields can be used in // implementing a broadcasted binary tensor operation according to // the broadcasting rule. - const Vec& x_reshape() const { return x_reshape_; } - const Vec& x_bcast() const { return x_bcast_; } - const Vec& y_reshape() const { return y_reshape_; } - const Vec& y_bcast() const { return y_bcast_; } + const Vec& x_reshape() const { return reshape_[0]; } + const Vec& x_bcast() const { return bcast_[0]; } + const Vec& y_reshape() const { return reshape_[1]; } + const Vec& y_bcast() const { return bcast_[1]; } const Vec& result_shape() const { return result_; } const Vec& output_shape() const { return output_; } - const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; } - const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; } + const Vec& grad_x_reduce_idx() const { return grad_reduce_idx_[0]; } + const Vec& grad_y_reduce_idx() const { return grad_reduce_idx_[1]; } - // Static helpers. - static Vec FromShape(const TensorShape& shape); - static TensorShape ToShape(const BCast::Vec& vec); + // Returns the mapping from the flattened output batch indices to x's + // flattened batch indices. The result is a vector of length + // output_batch_size(). To compute the i'th batch output, a binary matmul-like + // operation should use the `x_batch_indices()[i]`th batch index of `x`. + // Note: Returns an empty vector if broadcasting is not required. Callers + // should only use this when IsBroadcastingRequired() returns true. + const std::vector& x_batch_indices() const { + return batch_indices_[0]; + } + // Returns the mapping from the flattened output batch indices to y's + // flattened batch indices. Similar to x_batch_indices(). + // Note: Returns an empty vector if broadcasting is not required. Callers + // should only use this when IsBroadcastingRequired() returns true. + const std::vector& y_batch_indices() const { + return batch_indices_[1]; + } template static Eigen::array ToIndexArrayType( @@ -120,19 +405,11 @@ class BCast { return ToIndexArrayType(vec); } + // Static helpers. 
+ static Vec FromShape(const TensorShape& shape); + static TensorShape ToShape(const Vec& vec); + private: - bool valid_ = true; - Vec x_reshape_; - Vec x_bcast_; - Vec y_reshape_; - Vec y_bcast_; - Vec result_; - Vec output_; - Vec grad_x_reduce_idx_; - Vec grad_y_reduce_idx_; - - static void Reverse(Vec* shape); - TF_DISALLOW_COPY_AND_ASSIGN(BCast); }; diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc index c73bb2999f3..b6e8bcd706b 100644 --- a/tensorflow/core/util/bcast_test.cc +++ b/tensorflow/core/util/bcast_test.cc @@ -41,6 +41,40 @@ string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y, return ret; } +string BCastBatchIndices(const tensorflow::BCast::Vec& x, + const tensorflow::BCast::Vec& y, + const bool fewer_dims_optimization = true) { + tensorflow::BCast b(x, y, fewer_dims_optimization, + /*return_flattened_batch_indices=*/true); + string ret; + strings::StrAppend(&ret, "[", absl::StrJoin(b.x_batch_indices(), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.y_batch_indices(), ","), "]"); + return ret; +} + +string BCastList3(const tensorflow::BCast::Vec& x, + const tensorflow::BCast::Vec& y, + const tensorflow::BCast::Vec& z, + const bool fewer_dims_optimization = true) { + tensorflow::BCastList<3> b({x, y, z}, fewer_dims_optimization); + if (!b.IsValid()) { + return "invalid"; + } + string ret; + strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(0), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(0), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(1), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(1), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.reshape(2), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.bcast(2), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.result_shape(), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.output_shape(), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(0), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(1), ","), "]"); + strings::StrAppend(&ret, "[", absl::StrJoin(b.grad_reduce_idx(2), ","), "]"); + return ret; +} + TEST(BCastTest, Invalid) { for (const bool use_optimization : {true, false}) { EXPECT_EQ("invalid", BCast({5, 3, 2}, {3}, use_optimization)); @@ -51,6 +85,26 @@ TEST(BCastTest, Invalid) { } } +TEST(BCastListTest, Invalid) { + for (const bool use_optimization : {true, false}) { + EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {3}, {1}, use_optimization)); + EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {2, 2}, {1}, use_optimization)); + EXPECT_EQ("invalid", + BCastList3({5, 3, 2}, {10, 1, 1}, {1}, use_optimization)); + EXPECT_EQ("invalid", BCastList3({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1}, {1}, + use_optimization)); + EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {1}, {3}, use_optimization)); + EXPECT_EQ("invalid", BCastList3({5, 3, 2}, {1}, {2, 2}, use_optimization)); + EXPECT_EQ("invalid", + BCastList3({5, 3, 2}, {1}, {10, 1, 1}, use_optimization)); + + EXPECT_EQ("invalid", BCastList3({1}, {5, 3, 2}, {3}, use_optimization)); + EXPECT_EQ("invalid", BCastList3({1}, {5, 3, 2}, {2, 2}, use_optimization)); + EXPECT_EQ("invalid", + BCastList3({1}, {5, 3, 2}, {10, 1, 1}, use_optimization)); + } +} + TEST(BCastTest, Basic_SameShape) { // Effectively no broadcast needed. 
EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}), @@ -66,6 +120,22 @@ TEST(BCastTest, Basic_SameShape) { "[][]"); } +TEST(BCastListTest, Basic_SameShape) { + // Effectively no broadcast needed. + EXPECT_EQ(BCastList3({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}), + "[2310][1][2310][1][2310][1]" + "[2310]" + "[11,7,5,3,2]" + "[][][]"); + + EXPECT_EQ( + BCastList3({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, false), + "[11,7,5,3,2][1,1,1,1,1][11,7,5,3,2][1,1,1,1,1][11,7,5,3,2][1,1,1,1,1]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[][][]"); +} + TEST(BCastTest, Basic_SameShapeWithZeroDim) { // Effectively no broadcast needed. EXPECT_EQ(BCast({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}), @@ -81,9 +151,32 @@ TEST(BCastTest, Basic_SameShapeWithZeroDim) { "[][]"); } +TEST(BCastListTest, Basic_SameShapeWithZeroDim) { + // Effectively no broadcast needed. + EXPECT_EQ(BCastList3({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}), + "[0][1][0][1][0][1]" + "[0]" + "[11,7,0,3,2]" + "[][][]"); + + EXPECT_EQ( + BCastList3({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, false), + "[11,7,0,3,2][1,1,1,1,1][11,7,0,3,2][1,1,1,1,1][11,7,0,3,2][1,1,1,1,1]" + "[11,7,0,3,2]" + "[11,7,0,3,2]" + "[][][]"); +} + TEST(BCastTest, Basic_Scalar_Scalar) { // Effectively it's a scalar and a scalar. // [1, 1] [1] + // + EXPECT_EQ(BCast({1, 1}, {}), + "[1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1]"); + EXPECT_EQ(BCast({1, 1}, {1}), "[1][1][1][1]" "[1]" @@ -110,6 +203,151 @@ TEST(BCastTest, Basic_Scalar_Scalar) { "[0,1][0,1]"); } +TEST(BCastTest, Basic_TrueScalar_Scalar) { + // [] [] + EXPECT_EQ(BCast({}, {}), + "[1][1][1][1]" + "[1]" + "[]" + "[][]"); + + // [] [1] + EXPECT_EQ(BCast({}, {1}), + "[1][1][1][1]" + "[1]" + "[1]" + "[0][0]"); + + EXPECT_EQ(BCast({}, {1}, false), + "[1][1][1][1]" + "[1]" + "[1]" + "[0][0]"); + + // [] [1, 1] + EXPECT_EQ(BCast({}, {1, 1}), + "[1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1]"); + + EXPECT_EQ(BCast({}, {1, 1}, false), + "[1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1]"); + + // [1] [] + EXPECT_EQ(BCast({1}, {}), + "[1][1][1][1]" + "[1]" + "[1]" + "[0][0]"); + + EXPECT_EQ(BCast({1}, {}, false), + "[1][1][1][1]" + "[1]" + "[1]" + "[0][0]"); + + // [1, 1] [] + EXPECT_EQ(BCast({1, 1}, {}), + "[1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1]"); + + EXPECT_EQ(BCast({1, 1}, {}, false), + "[1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1]"); +} + +TEST(BCastListTest, Basic_Scalar_Scalar_Scalar) { + // Effectively it's a scalar and a scalar. + // [1, 1] [1] [1] + EXPECT_EQ(BCastList3({1, 1}, {1}, {1}), + "[1][1][1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + EXPECT_EQ(BCastList3({1, 1}, {1}, {1}, false), + "[1,1][1,1][1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + // [1] [1, 1] [1] + EXPECT_EQ(BCastList3({1}, {1, 1}, {1}), + "[1][1][1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + EXPECT_EQ(BCastList3({1}, {1, 1}, {1}, false), + "[1,1][1,1][1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + // [1] [1] [1, 1] + EXPECT_EQ(BCastList3({1}, {1}, {1, 1}), + "[1][1][1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + EXPECT_EQ(BCastList3({1}, {1}, {1, 1}, false), + "[1,1][1,1][1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1][0,1]"); +} + +TEST(BCastListTest, Basic_TrueScalar_Scalar_Scalar) { + // Effectively it's a scalar and a scalar. 
+ // [1, 1] [1] [] + EXPECT_EQ(BCastList3({1, 1}, {1}, {}), + "[1][1][1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + EXPECT_EQ(BCastList3({1, 1}, {1}, {}, false), + "[1,1][1,1][1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + // [] [1, 1] [1] + EXPECT_EQ(BCastList3({}, {1, 1}, {1}), + "[1][1][1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + EXPECT_EQ(BCastList3({}, {1, 1}, {1}, false), + "[1,1][1,1][1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + // [1] [] [1, 1] + EXPECT_EQ(BCastList3({1}, {}, {1, 1}), + "[1][1][1][1][1][1]" + "[1]" + "[1,1]" + "[0,1][0,1][0,1]"); + + EXPECT_EQ(BCastList3({1}, {}, {1, 1}, false), + "[1,1][1,1][1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1][0,1]"); +} + TEST(BCastTest, Basic_Tensor_Scalar) { // Effectively it's a tensor and a scalar. // [11, 7, 5, 3, 2] [1] @@ -327,6 +565,30 @@ TEST(BCastTest, Complex_BCast_To_Each_Other) { EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}, false), truth); } +TEST(BCastListTest, Complex_BCast_To_Each_Other) { + // Rare cases. x, y and z broadcast to each other. x,y and z are of + // different ranks. + // Can be verified in numpy as: + // import numpy as np + // x = np.arange(0,22).reshape([11,1,1,1,2]) + // y = np.arange(0,21).reshape([7,1,3,1]) + // z = np.arange(0,5).reshape([5,1,1]) + // np.shape(x + y + z) + // Out[.]: (11, 7, 5, 3, 2) + // + string truth = + "[11,1,1,1,2][1,7,5,3,1]" + "[1,7,1,3,1][11,1,5,1,2]" + "[1,1,5,1,1][11,7,1,3,2]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[1,2,3][0,2,4][0,1,3,4]"; + + EXPECT_EQ(BCastList3({11, 1, 1, 1, 2}, {7, 1, 3, 1}, {5, 1, 1}), truth); + EXPECT_EQ(BCastList3({11, 1, 1, 1, 2}, {7, 1, 3, 1}, {5, 1, 1}, false), + truth); +} + TEST(BCastTest, TestZeroDimensionShape) { // (2,0,5) and (5) in both orders EXPECT_EQ(BCast({2, 0, 5}, {5}), @@ -398,6 +660,19 @@ TEST(BCastTest, TestZeroDimensionShape) { "[0,1,3][]"); } +TEST(BCastTest, BatchIndices) { + EXPECT_EQ("[0,0,0,0][0,1,2,3]", BCastBatchIndices({1}, {4})); + // Invalid broadcast. + EXPECT_EQ("[][]", BCastBatchIndices({5}, {7})); + // Same shape, no batch indices. + EXPECT_EQ("[][]", BCastBatchIndices({2, 4, 6}, {2, 4, 6})); + // More complicated broadcasts. + EXPECT_EQ("[0,0,0,0,1,1,1,1,2,2,2,2][0,1,2,3,0,1,2,3,0,1,2,3]", + BCastBatchIndices({3, 1}, {1, 4})); + EXPECT_EQ("[0,0,1,1,2,2,0,0,1,1,2,2][0,1,0,1,0,1,2,3,2,3,2,3]", + BCastBatchIndices({3, 1}, {2, 1, 2})); +} + static void BM_BCastSetup(int iters, int same_shape) { if (same_shape) { testing::SetLabel("same_shapes"); diff --git a/tensorflow/core/util/matmul_bcast.cc b/tensorflow/core/util/matmul_bcast.cc index 3e5c5cf1750..8bb03616f87 100644 --- a/tensorflow/core/util/matmul_bcast.cc +++ b/tensorflow/core/util/matmul_bcast.cc @@ -16,40 +16,6 @@ limitations under the License. #include "tensorflow/core/util/matmul_bcast.h" namespace tensorflow { -namespace { - -// Returns the mapping from the output batch indices to the corresponding -// input's batch indices, given the input's "reshape" and "bcast" shapes as -// returned by the BCast helper class. The i'th element denotes the (flattened) -// batch index of the input that must be used to compute the i'th batch output. -void ComputeBatchIndices(const int64 output_batch_size, - const MatMulBCast::Vec& reshape, - const MatMulBCast::Vec& bcast, - std::vector* out_indices) { - // Populates the mapping in out_indices. This algorithm is identical to - // the following steps: - // - Reshape {0, 1, ..., input_batch_size - 1} to the input shape. 
- // - Broadcast to the output shape. - // - Reshape back to a flat 1D vector. - out_indices->resize(output_batch_size); - int64 num_output_elements = 1; - int64 num_input_elements = 1; - for (int64 i = reshape.size() - 1; i >= 0; --i) { - // Replicate the already populated mapping an additional (dim - 1) times. - // If we are broadcasting, just copy the existing mapping. - // Otherwise, add another dimension from the input shape. - const int64 dim = std::max(reshape[i], bcast[i]); - const int64 incr = bcast[i] > 1 ? 0 : num_input_elements; - for (int64 k = 0; k < (dim - 1) * num_output_elements; ++k) { - (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr; - } - num_output_elements *= dim; - num_input_elements *= reshape[i]; - } -} - -} // namespace - MatMulBCast::MatMulBCast(Vec x, Vec y) { if (x.size() < 2 || y.size() < 2) return; x.resize(x.size() - 2); diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index c48864be6d7..9e395370662 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -161,6 +161,7 @@ tf_py_test( "//tensorflow/python:platform", "//tensorflow/python:stateful_random_ops", ], + shard_count = 3, tags = ["no_oss"], ) diff --git a/tensorflow/python/kernel_tests/random/random_binomial_test.py b/tensorflow/python/kernel_tests/random/random_binomial_test.py index 5a17602c2fd..11bfd149c3f 100644 --- a/tensorflow/python/kernel_tests/random/random_binomial_test.py +++ b/tensorflow/python/kernel_tests/random/random_binomial_test.py @@ -35,14 +35,14 @@ _SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64, class RandomBinomialTest(test.TestCase): """This is a large test due to the moments computation taking some time.""" - def _Sampler(self, num, counts, probs, dtype, seed=None): - + def _Sampler( + self, num, counts, probs, dtype, gen=None, sample_shape=None, seed=None): def func(): - rng = stateful_random_ops.Generator.from_seed(seed).binomial( - shape=[10 * num], counts=counts, probs=probs, dtype=dtype) - ret = array_ops.reshape(rng, [10, num]) - ret = self.evaluate(ret) - return ret + shape = [10 * num] if sample_shape is None else sample_shape + generator = gen if gen is not None else ( + stateful_random_ops.Generator.from_seed(seed)) + return generator.binomial( + shape=shape, counts=counts, probs=probs, dtype=dtype) return func @@ -57,15 +57,16 @@ class RandomBinomialTest(test.TestCase): # we want to tolerate. Since the z-test approximates a unit normal # distribution, it should almost definitely never exceed 6. z_limit = 6.0 + gen = stateful_random_ops.Generator.from_seed(seed=12345) for dt in _SUPPORTED_DTYPES: # Test when n * p > 10, and n * p < 10 for stride in 0, 4, 10: for counts in (1., 10., 22., 50.): for prob in (0.1, 0.5, 0.8): - sampler = self._Sampler(int(1e5), counts, prob, dt, seed=12345) + sampler = self._Sampler(int(5e4), counts, prob, dt, gen=gen) z_scores = util.test_moment_matching( # Use float64 samples. 
- sampler().astype(np.float64), + self.evaluate(sampler()).astype(np.float64), number_moments=6, dist=stats.binom(counts, prob), stride=stride, @@ -77,7 +78,7 @@ class RandomBinomialTest(test.TestCase): for dt in dtypes.float16, dtypes.float32, dtypes.float64: sx = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345) sy = self._Sampler(1000, counts=10., probs=0.4, dtype=dt, seed=345) - self.assertAllEqual(sx(), sy()) + self.assertAllEqual(self.evaluate(sx()), self.evaluate(sy())) def testZeroShape(self): rnd = stateful_random_ops.Generator.from_seed(12345).binomial([0], [], []) @@ -88,6 +89,8 @@ class RandomBinomialTest(test.TestCase): # Scalar parameters. rnd = rng.binomial(shape=[10], counts=np.float32(2.), probs=np.float32(0.5)) self.assertEqual([10], rnd.shape.as_list()) + rnd = rng.binomial(shape=[], counts=np.float32(2.), probs=np.float32(0.5)) + self.assertEqual([], rnd.shape.as_list()) # Vector parameters. rnd = rng.binomial( @@ -96,10 +99,10 @@ class RandomBinomialTest(test.TestCase): probs=0.3 * array_ops.ones([10], dtype=np.float32)) self.assertEqual([10], rnd.shape.as_list()) rnd = rng.binomial( - shape=[2, 5], + shape=[5, 2], counts=array_ops.ones([2], dtype=np.float32), probs=0.4 * array_ops.ones([2], dtype=np.float32)) - self.assertEqual([2, 5], rnd.shape.as_list()) + self.assertEqual([5, 2], rnd.shape.as_list()) # Scalar counts, vector probs. rnd = rng.binomial( @@ -115,6 +118,20 @@ class RandomBinomialTest(test.TestCase): probs=np.float32(0.9)) self.assertEqual([10], rnd.shape.as_list()) + # Tensor parameters + rnd = rng.binomial( + shape=[10, 2, 3], + counts=array_ops.ones([2, 1], dtype=np.float32), + probs=0.9 * array_ops.ones([1, 3], dtype=np.float32)) + self.assertEqual([10, 2, 3], rnd.shape.as_list()) + + # Tensor parameters + rnd = rng.binomial( + shape=[10, 2, 3, 5], + counts=array_ops.ones([2, 1, 5], dtype=np.float32), + probs=0.9 * array_ops.ones([1, 3, 1], dtype=np.float32)) + self.assertEqual([10, 2, 3, 5], rnd.shape.as_list()) + @test_util.run_v2_only def testCornerCases(self): rng = stateful_random_ops.Generator.from_seed(12345) @@ -126,5 +143,61 @@ class RandomBinomialTest(test.TestCase): shape=[6], counts=counts, probs=probs, dtype=np.float32) self.assertAllEqual(expected, self.evaluate(result)) + @test_util.run_v2_only + def testMomentsForTensorInputs(self): + try: + from scipy import stats # pylint: disable=g-import-not-at-top + except ImportError as e: + tf_logging.warn("Cannot test moments: %s", e) + return + # The moments test is a z-value test. This is the largest z-value + # we want to tolerate. Since the z-test approximates a unit normal + # distribution, it should almost definitely never exceed 6. 
+ z_limit = 6.0 + + class ScipyBinomialWrapper(object): + """Wrapper for stats.binom to support broadcasting.""" + + def __init__(self, counts, probs): + self.counts = counts + self.probs = probs + + def moment(self, i): + counts, probs = np.broadcast_arrays(self.counts, self.probs) + broadcast_shape = counts.shape + + counts = np.reshape(counts, (-1,)) + probs = np.reshape(probs, (-1,)) + counts_and_probs = np.stack([counts, probs], axis=-1) + moments = np.fromiter( + (stats.binom(cp[0], cp[1]).moment(i) for cp in counts_and_probs), + dtype=np.float64) + return np.reshape(moments, broadcast_shape) + + gen = stateful_random_ops.Generator.from_seed(seed=23455) + for dt in _SUPPORTED_DTYPES: + # Test when n * p > 10, and n * p < 10 + for stride in 0, 4, 10: + counts = np.float64(np.random.randint(low=1, high=20, size=(2, 1, 4))) + probs = np.random.uniform(size=(1, 3, 4)) + + sampler = self._Sampler( + int(5e4), + counts, + probs, + dt, + gen=gen, + sample_shape=[10 * int(5e4), 2, 3, 4]) + # Use float64 samples. + samples = self.evaluate(sampler()).astype(np.float64) + z_scores = util.test_moment_matching( + samples, + number_moments=6, + dist=ScipyBinomialWrapper(counts, probs), + stride=stride, + ) + self.assertAllLess(z_scores, z_limit) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/random/util.py b/tensorflow/python/kernel_tests/random/util.py index 9b29f9c5130..7dffbe494f4 100644 --- a/tensorflow/python/kernel_tests/random/util.py +++ b/tensorflow/python/kernel_tests/random/util.py @@ -49,10 +49,12 @@ def test_moment_matching( sample_moments = [] expected_moments = [] variance_sample_moments = [] - x = samples.flat for i in range(1, number_moments + 1): - strided_range = x[::(i - 1) * stride + 1] - sample_moments.append(np.mean(strided_range ** i)) + if len(samples.shape) == 2: + strided_range = samples.flat[::(i - 1) * stride + 1] + else: + strided_range = samples[::(i - 1) * stride + 1, ...] + sample_moments.append(np.mean(strided_range**i, axis=0)) expected_moments.append(dist.moment(i)) variance_sample_moments.append( (dist.moment(2 * i) - dist.moment(i) ** 2) / len(strided_range)) @@ -66,8 +68,7 @@ def test_moment_matching( i * np.finfo(samples.dtype).eps) tiny = np.finfo(samples.dtype).tiny assert np.all(total_variance > 0) - if total_variance < tiny: - total_variance = tiny + total_variance = np.where(total_variance < tiny, tiny, total_variance) # z_test is approximately a unit normal distribution. z_test_scores.append(abs( (sample_moments[i - 1] - expected_moments[i - 1]) / np.sqrt( diff --git a/tensorflow/python/ops/stateful_random_ops.py b/tensorflow/python/ops/stateful_random_ops.py index b95ca5d7d9e..d0f132d91b0 100644 --- a/tensorflow/python/ops/stateful_random_ops.py +++ b/tensorflow/python/ops/stateful_random_ops.py @@ -598,22 +598,30 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): ```python counts = [10., 20.] # Probability of success. - probs = [0.8, 0.9] + probs = [0.8] rng = tf.random.experimental.Generator.from_seed(seed=234) binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs) + + + counts = ... # Shape [3, 1, 2] + probs = ... # Shape [1, 4, 2] + shape = [3, 4, 3, 4, 2] + rng = tf.random.experimental.Generator.from_seed(seed=1717) + # Sample shape will be [3, 4, 3, 4, 2] + binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs) ``` Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. - counts: A 0/1-D Tensor or Python value. 
The counts of the binomial
-        distribution. Must be broadcastable with the leftmost dimension
-        defined by `shape`.
-      probs: A 0/1-D Tensor or Python value. The probability of success for the
-        binomial distribution. Must be broadcastable with the leftmost
-        dimension defined by `shape`.
+      counts: Tensor. The counts of the binomial distribution. Must be
+        broadcastable with `probs`, and broadcastable with the rightmost
+        dimensions of `shape`.
+      probs: Tensor. The probability of success for the
+        binomial distribution. Must be broadcastable with `counts` and
+        broadcastable with the rightmost dimensions of `shape`.
       dtype: The type of the output. Default: tf.int32
       name: A name for the operation (optional).
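Tying back to the util.test_moment_matching changes above, a hedged NumPy/SciPy sketch of the batched z-score moment check: sample moments are taken along the leading sample axis and compared per batch member against exact binomial moments, with the same z-limit of 6. The shapes, seed and sample count here are illustrative, not taken from the test.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
counts = np.array([[5], [12]])   # batch shape (2, 1), integer trial counts
probs = np.array([[0.3, 0.7]])   # batch shape (1, 2)
n_samples = 200_000
z_limit = 6.0                    # same limit the tests use

# Batched samples, sample axis first: shape [n_samples, 2, 2].
samples = rng.binomial(counts, probs, size=(n_samples, 2, 2)).astype(np.float64)


def binom_moment(order, n, p):
    # Per-element scipy moments; frozen binom.moment may not broadcast over
    # array parameters, which is why the test adds ScipyBinomialWrapper.
    n_b, p_b = np.broadcast_arrays(n, p)
    flat = [stats.binom(int(nn), pp).moment(order)
            for nn, pp in zip(n_b.ravel(), p_b.ravel())]
    return np.reshape(flat, n_b.shape)


for i in range(1, 4):
    sample_moment = np.mean(samples**i, axis=0)   # per-batch E[X^i]
    expected_moment = binom_moment(i, counts, probs)
    # Var of the sample moment: (E[X^(2i)] - E[X^i]^2) / n_samples.
    moment_variance = (binom_moment(2 * i, counts, probs)
                       - expected_moment**2) / n_samples
    z = np.abs(sample_moment - expected_moment) / np.sqrt(moment_variance)
    assert np.all(z < z_limit), z
```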