Remove redundant reshape from BiasOp
PiperOrigin-RevId: 275401597 Change-Id: I940422b3581f74dcf47b23b0ec9c28fe0a87b16b
This commit is contained in:
parent
be536832d9
commit
d758182a15
tensorflow/core/kernels
@ -31,20 +31,19 @@ struct Bias {
|
||||
typename TTypes<T>::ConstVec bias,
|
||||
typename TTypes<T, Dims>::Tensor output) {
|
||||
if (input.size() >= INT_MAX) {
|
||||
const int64_t bias_size = bias.dimension(0);
|
||||
const int64_t rest_size = input.size() / bias_size;
|
||||
Eigen::DSizes<int64_t, 1> one_d(input.size());
|
||||
Eigen::DSizes<int64_t, 1> bcast(rest_size);
|
||||
const Eigen::Index bias_size = bias.dimension(0);
|
||||
const Eigen::Index rest_size = input.size() / bias_size;
|
||||
Eigen::DSizes<Eigen::Index, 1> one_d(input.size());
|
||||
Eigen::DSizes<Eigen::Index, 1> bcast(rest_size);
|
||||
output.reshape(one_d).device(d) =
|
||||
input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
|
||||
input.reshape(one_d) + bias.broadcast(bcast);
|
||||
} else {
|
||||
const int bias_size = bias.dimension(0);
|
||||
const int rest_size = input.size() / bias_size;
|
||||
Eigen::DSizes<int, 1> one_d(input.size());
|
||||
Eigen::DSizes<int, 1> bcast(rest_size);
|
||||
To32Bit(output).reshape(one_d).device(d) =
|
||||
To32Bit(input).reshape(one_d) +
|
||||
To32Bit(bias).broadcast(bcast).reshape(one_d);
|
||||
To32Bit(input).reshape(one_d) + To32Bit(bias).broadcast(bcast);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -13,16 +13,28 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/bias_op.h"
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/bias_op.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Builds a heap-allocated graph for benchmarking the "BiasAdd" op.
//
// Two float tensors are created and filled through setRandom(): an input of
// shape {d0, d1, d2, d3} and a bias of shape {d3}. Each is wrapped in a
// constant node and the pair is passed to test::graph::Binary under the op
// name "BiasAdd".
//
// NOTE(review): the graph is returned as a raw pointer obtained from `new`;
// ownership presumably transfers to the caller (e.g. test::Benchmark) --
// confirm before reusing this helper elsewhere.
static Graph* BiasAdd(int d0, int d1, int d2, int d3) {
  auto* graph = new Graph(OpRegistry::Global());

  // The input is randomized before the bias, so both draw from the random
  // source in a fixed order.
  Tensor input_tensor(DT_FLOAT, TensorShape({d0, d1, d2, d3}));
  input_tensor.flat<float>().setRandom();

  // The bias length matches the innermost input dimension (d3).
  Tensor bias_tensor(DT_FLOAT, TensorShape({d3}));
  bias_tensor.flat<float>().setRandom();

  auto input_node = test::graph::Constant(graph, input_tensor);
  auto bias_node = test::graph::Constant(graph, bias_tensor);
  test::graph::Binary(graph, "BiasAdd", input_node, bias_node);

  return graph;
}
|
||||
|
||||
static Graph* BiasAddGrad(int d0, int d1, int d2, int d3) {
|
||||
auto* g = new Graph(OpRegistry::Global());
|
||||
Tensor out_backprop(DT_FLOAT, TensorShape({d0, d1, d2, d3}));
|
||||
@ -31,6 +43,14 @@ static Graph* BiasAddGrad(int d0, int d1, int d2, int d3) {
|
||||
return g;
|
||||
}
|
||||
|
||||
// Defines and registers one BiasAdd benchmark.
//
// The generated function is named BM_BiasAddNHWC_<N>_<H>_<W>_<C>_<DEVICE>. It
// calls testing::UseRealTime(), reports iters * N * H * W * C processed items,
// builds the graph with BiasAdd(N, H, W, C), and runs it for `iters`
// iterations on the device named by the stringified DEVICE argument.
//
// Note the macro parameter list is (N, W, H, C, DEVICE): the second positional
// argument is W and the third is H, while the generated name and the tensor
// shape use N, H, W, C order.
#define BM_BiasAddNHWC(N, W, H, C, DEVICE)                                    \
  static void BM_BiasAddNHWC##_##N##_##H##_##W##_##C##_##DEVICE(int iters) {  \
    testing::UseRealTime();                                                   \
    const int64 total_items = static_cast<int64>(iters) * N * H * W * C;      \
    testing::ItemsProcessed(total_items);                                     \
    auto* graph = BiasAdd(N, H, W, C);                                        \
    test::Benchmark(#DEVICE, graph).Run(iters);                               \
  }                                                                           \
  BENCHMARK(BM_BiasAddNHWC##_##N##_##H##_##W##_##C##_##DEVICE);
|
||||
|
||||
#define BM_BiasAddGradNHWC(N, W, H, C, DEVICE) \
|
||||
static void BM_BiasAddGradNHWC##_##N##_##H##_##W##_##C##_##DEVICE( \
|
||||
int iters) { \
|
||||
@ -41,6 +61,16 @@ static Graph* BiasAddGrad(int d0, int d1, int d2, int d3) {
|
||||
BENCHMARK(BM_BiasAddGradNHWC##_##N##_##H##_##W##_##C##_##DEVICE);
|
||||
|
||||
// CPU
// BiasAdd forward benchmarks on the cpu device. Arguments follow the macro's
// (N, W, H, C, DEVICE) parameter list: the first group uses N = 32 with 32x32
// spatial dimensions, the second N = 32 with 64x64, and each group sweeps the
// channel count C over 128, 256, 512 and 1024.
|
||||
BM_BiasAddNHWC(32, 32, 32, 128, cpu);
|
||||
BM_BiasAddNHWC(32, 32, 32, 256, cpu);
|
||||
BM_BiasAddNHWC(32, 32, 32, 512, cpu);
|
||||
BM_BiasAddNHWC(32, 32, 32, 1024, cpu);
|
||||
|
||||
BM_BiasAddNHWC(32, 64, 64, 128, cpu);
|
||||
BM_BiasAddNHWC(32, 64, 64, 256, cpu);
|
||||
BM_BiasAddNHWC(32, 64, 64, 512, cpu);
|
||||
BM_BiasAddNHWC(32, 64, 64, 1024, cpu);
|
||||
|
||||
BM_BiasAddGradNHWC(32, 32, 32, 128, cpu);
|
||||
BM_BiasAddGradNHWC(32, 32, 32, 256, cpu);
|
||||
BM_BiasAddGradNHWC(32, 32, 32, 512, cpu);
|
||||
|
Loading…
Reference in New Issue
Block a user