STT-tensorflow/tensorflow/core/util/bcast.cc
Srinivas Vasudevan 5396e7a3cd Allow RandomBinomial op to broadcast parameters.
- Add multi-parameter broadcasting support to BCast so that it can be used with more than two inputs. This targets ternary ops specifically, but will also be used to make other samplers, such as ParameterizedTruncatedNormal, broadcast their parameters.

- Add batch index methods that generate the list of batch indices to use when the input vectors are flattened. This provides broadcasting on flattened inputs, which the RandomBinomial sampler relies on.

- Shard on the number of outputs. This allows us to scale better to Tensor inputs.
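
To illustrate the multi-parameter broadcasting and batch-index ideas described above, here is a minimal standalone sketch. It does not use TensorFlow and is not the BCast implementation from this commit: the names Shape, BroadcastShape and BatchIndices are hypothetical, and the sketch assumes NumPy-style right-aligned broadcasting with row-major flattening.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a tensor shape; not a TensorFlow type.
using Shape = std::vector<int64_t>;

// Broadcasts any number of shapes. Shapes are right-aligned, and each
// dimension must either match the others or be 1.
Shape BroadcastShape(const std::vector<Shape>& shapes) {
  std::size_t rank = 0;
  for (const Shape& s : shapes) rank = std::max(rank, s.size());
  Shape out(rank, 1);
  for (const Shape& s : shapes) {
    const std::size_t offset = rank - s.size();
    for (std::size_t i = 0; i < s.size(); ++i) {
      int64_t& o = out[offset + i];
      if (s[i] == 1) continue;  // A size-1 dimension always broadcasts.
      if (o == 1) {
        o = s[i];
      } else if (o != s[i]) {
        throw std::invalid_argument("incompatible shapes");
      }
    }
  }
  return out;
}

// For every row-major flat index of the broadcast output `out`, returns the
// flat index of the element of the flattened input `in` that feeds it.
std::vector<int64_t> BatchIndices(const Shape& in, const Shape& out) {
  assert(in.size() <= out.size());
  const std::size_t rank = out.size();
  const std::size_t offset = rank - in.size();
  // Row-major strides of `in`, aligned to `out`. Broadcast (size-1 or
  // missing) dimensions keep stride 0, so they never advance the input.
  std::vector<int64_t> stride(rank, 0);
  int64_t running = 1;
  for (std::size_t i = in.size(); i-- > 0;) {
    if (in[i] != 1) stride[offset + i] = running;
    running *= in[i];
  }
  int64_t total = 1;
  for (int64_t d : out) total *= d;
  std::vector<int64_t> result(static_cast<std::size_t>(total));
  for (int64_t flat = 0; flat < total; ++flat) {
    int64_t rem = flat;
    int64_t idx = 0;
    // Peel off output coordinates from the innermost dimension outwards.
    for (std::size_t i = rank; i-- > 0;) {
      idx += (rem % out[i]) * stride[i];
      rem /= out[i];
    }
    result[static_cast<std::size_t>(flat)] = idx;
  }
  return result;
}
```

With three parameters of shapes {2, 1}, {3} and {} (a scalar), BroadcastShape yields {2, 3}. A sampler could then shard over the six outputs and use BatchIndices to look up each parameter in its flattened input without materializing the broadcast tensors.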

PiperOrigin-RevId: 281202841
Change-Id: I0b276e983bf31056677a67b4d5ce8ebc98d77930
2019-11-18 20:18:33 -08:00

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/bcast.h"

#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Copies the dimension sizes of `shape` into a Vec, one entry per dimension.
BCast::Vec BCast::FromShape(const TensorShape& shape) {
  const int N = shape.dims();
  BCastList::Vec ret(N);
  for (int i = 0; i < N; ++i) {
    ret[i] = shape.dim_size(i);
  }
  return ret;
}

// Builds a TensorShape whose dimension sizes are the entries of `vec`.
TensorShape BCast::ToShape(const BCastList::Vec& vec) {
  TensorShape shape(vec);
  return shape;
}

}  // end namespace tensorflow