Remove redundant reshape from BiasOp
PiperOrigin-RevId: 275401597 Change-Id: I940422b3581f74dcf47b23b0ec9c28fe0a87b16b
This commit is contained in:
parent
be536832d9
commit
d758182a15
@ -31,20 +31,19 @@ struct Bias {
|
|||||||
typename TTypes<T>::ConstVec bias,
|
typename TTypes<T>::ConstVec bias,
|
||||||
typename TTypes<T, Dims>::Tensor output) {
|
typename TTypes<T, Dims>::Tensor output) {
|
||||||
if (input.size() >= INT_MAX) {
|
if (input.size() >= INT_MAX) {
|
||||||
const int64_t bias_size = bias.dimension(0);
|
const Eigen::Index bias_size = bias.dimension(0);
|
||||||
const int64_t rest_size = input.size() / bias_size;
|
const Eigen::Index rest_size = input.size() / bias_size;
|
||||||
Eigen::DSizes<int64_t, 1> one_d(input.size());
|
Eigen::DSizes<Eigen::Index, 1> one_d(input.size());
|
||||||
Eigen::DSizes<int64_t, 1> bcast(rest_size);
|
Eigen::DSizes<Eigen::Index, 1> bcast(rest_size);
|
||||||
output.reshape(one_d).device(d) =
|
output.reshape(one_d).device(d) =
|
||||||
input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
|
input.reshape(one_d) + bias.broadcast(bcast);
|
||||||
} else {
|
} else {
|
||||||
const int bias_size = bias.dimension(0);
|
const int bias_size = bias.dimension(0);
|
||||||
const int rest_size = input.size() / bias_size;
|
const int rest_size = input.size() / bias_size;
|
||||||
Eigen::DSizes<int, 1> one_d(input.size());
|
Eigen::DSizes<int, 1> one_d(input.size());
|
||||||
Eigen::DSizes<int, 1> bcast(rest_size);
|
Eigen::DSizes<int, 1> bcast(rest_size);
|
||||||
To32Bit(output).reshape(one_d).device(d) =
|
To32Bit(output).reshape(one_d).device(d) =
|
||||||
To32Bit(input).reshape(one_d) +
|
To32Bit(input).reshape(one_d) + To32Bit(bias).broadcast(bcast);
|
||||||
To32Bit(bias).broadcast(bcast).reshape(one_d);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -13,16 +13,28 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/bias_op.h"
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/kernels/bias_op.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
static Graph* BiasAdd(int d0, int d1, int d2, int d3) {
|
||||||
|
auto* g = new Graph(OpRegistry::Global());
|
||||||
|
Tensor input(DT_FLOAT, TensorShape({d0, d1, d2, d3}));
|
||||||
|
Tensor bias(DT_FLOAT, TensorShape({d3}));
|
||||||
|
input.flat<float>().setRandom();
|
||||||
|
bias.flat<float>().setRandom();
|
||||||
|
test::graph::Binary(g, "BiasAdd", test::graph::Constant(g, input),
|
||||||
|
test::graph::Constant(g, bias));
|
||||||
|
return g;
|
||||||
|
}
|
||||||
|
|
||||||
static Graph* BiasAddGrad(int d0, int d1, int d2, int d3) {
|
static Graph* BiasAddGrad(int d0, int d1, int d2, int d3) {
|
||||||
auto* g = new Graph(OpRegistry::Global());
|
auto* g = new Graph(OpRegistry::Global());
|
||||||
Tensor out_backprop(DT_FLOAT, TensorShape({d0, d1, d2, d3}));
|
Tensor out_backprop(DT_FLOAT, TensorShape({d0, d1, d2, d3}));
|
||||||
@ -31,6 +43,14 @@ static Graph* BiasAddGrad(int d0, int d1, int d2, int d3) {
|
|||||||
return g;
|
return g;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define BM_BiasAddNHWC(N, W, H, C, DEVICE) \
|
||||||
|
static void BM_BiasAddNHWC##_##N##_##H##_##W##_##C##_##DEVICE(int iters) { \
|
||||||
|
testing::UseRealTime(); \
|
||||||
|
testing::ItemsProcessed(static_cast<int64>(iters) * N * H * W * C); \
|
||||||
|
test::Benchmark(#DEVICE, BiasAdd(N, H, W, C)).Run(iters); \
|
||||||
|
} \
|
||||||
|
BENCHMARK(BM_BiasAddNHWC##_##N##_##H##_##W##_##C##_##DEVICE);
|
||||||
|
|
||||||
#define BM_BiasAddGradNHWC(N, W, H, C, DEVICE) \
|
#define BM_BiasAddGradNHWC(N, W, H, C, DEVICE) \
|
||||||
static void BM_BiasAddGradNHWC##_##N##_##H##_##W##_##C##_##DEVICE( \
|
static void BM_BiasAddGradNHWC##_##N##_##H##_##W##_##C##_##DEVICE( \
|
||||||
int iters) { \
|
int iters) { \
|
||||||
@ -41,6 +61,16 @@ static Graph* BiasAddGrad(int d0, int d1, int d2, int d3) {
|
|||||||
BENCHMARK(BM_BiasAddGradNHWC##_##N##_##H##_##W##_##C##_##DEVICE);
|
BENCHMARK(BM_BiasAddGradNHWC##_##N##_##H##_##W##_##C##_##DEVICE);
|
||||||
|
|
||||||
// CPU
|
// CPU
|
||||||
|
BM_BiasAddNHWC(32, 32, 32, 128, cpu);
|
||||||
|
BM_BiasAddNHWC(32, 32, 32, 256, cpu);
|
||||||
|
BM_BiasAddNHWC(32, 32, 32, 512, cpu);
|
||||||
|
BM_BiasAddNHWC(32, 32, 32, 1024, cpu);
|
||||||
|
|
||||||
|
BM_BiasAddNHWC(32, 64, 64, 128, cpu);
|
||||||
|
BM_BiasAddNHWC(32, 64, 64, 256, cpu);
|
||||||
|
BM_BiasAddNHWC(32, 64, 64, 512, cpu);
|
||||||
|
BM_BiasAddNHWC(32, 64, 64, 1024, cpu);
|
||||||
|
|
||||||
BM_BiasAddGradNHWC(32, 32, 32, 128, cpu);
|
BM_BiasAddGradNHWC(32, 32, 32, 128, cpu);
|
||||||
BM_BiasAddGradNHWC(32, 32, 32, 256, cpu);
|
BM_BiasAddGradNHWC(32, 32, 32, 256, cpu);
|
||||||
BM_BiasAddGradNHWC(32, 32, 32, 512, cpu);
|
BM_BiasAddGradNHWC(32, 32, 32, 512, cpu);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user