Add benchmarks for Conv2D input gradient with strides not equal to one
PiperOrigin-RevId: 282847663 Change-Id: Iaa201f4f8e74a62380e68167dde75a158689baf8
This commit is contained in:
parent
a10dc73356
commit
74faaeb08f
@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
@ -40,7 +39,7 @@ template <typename T>
|
||||
static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
||||
int in_depth, int filter_h, int filter_w,
|
||||
int out_depth, int stride_h, int stride_w,
|
||||
TensorFormat data_format) {
|
||||
Padding padding, TensorFormat data_format) {
|
||||
auto* graph = new Graph(OpRegistry::Global());
|
||||
|
||||
Tensor input_t = data_format == FORMAT_NHWC
|
||||
@ -53,7 +52,7 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
||||
Conv2DParameters params;
|
||||
params.dilations = {1, 1, 1, 1};
|
||||
params.strides = {1, stride_h, stride_w, 1};
|
||||
params.padding = Padding::SAME;
|
||||
params.padding = padding;
|
||||
params.data_format = data_format;
|
||||
|
||||
Conv2DDimensions conv2d_dims;
|
||||
@ -85,7 +84,9 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
||||
.Input(backprop)
|
||||
.Attr("T", DataTypeToEnum<T>::value)
|
||||
.Attr("strides", {1, stride_h, stride_w, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Attr("padding", padding == Padding::SAME
|
||||
? "SAME"
|
||||
: padding == Padding::VALID ? "VALID" : "N/A")
|
||||
.Attr("data_format", ToString(data_format))
|
||||
.Finalize(graph, &conv2d));
|
||||
|
||||
@ -94,7 +95,7 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
||||
|
||||
// Macro arguments names: --------------------------------------------------- //
|
||||
// T: data type
|
||||
// FORMAT: data format (NHWC or NCHW)
|
||||
// FMT: data format (NHWC or NCHW)
|
||||
// N: batch size
|
||||
// H: height
|
||||
// W: width
|
||||
@ -107,41 +108,50 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
||||
|
||||
#define BM_CONCAT(a, b) a##_##b
|
||||
|
||||
#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FH, FW, FC, SH, SW) \
|
||||
BM_CONCAT(name##_##T##_##FORMAT##_##type##_in##N##x##H##x##W##x##C, \
|
||||
f##FH##x##FW##x##FC##_##s##SH##x##SW)
|
||||
#define BM_NAME(name, type, T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING) \
|
||||
BM_CONCAT(name##_##T##_##FMT##_##type##_in##N##x##H##x##W##x##C, \
|
||||
f##FH##x##FW##x##FC##_##s##SH##x##SW##_##PADDING)
|
||||
|
||||
#define BM_Conv2DBwdInputFmt(T, FORMAT, N, H, W, C, FW, FH, FC, SH, SW, type) \
|
||||
static void BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FH, \
|
||||
FW, FC, SH, SW)(int iters) { \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * \
|
||||
(C)); \
|
||||
test::Benchmark(#type, Conv2DBackpropInput<T>(N, H, W, C, FH, FW, FC, SH, \
|
||||
SW, FORMAT_##FORMAT)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FH, \
|
||||
FW, FC, SH, SW));
|
||||
#define BM_Conv2DBwdInput(T, FMT, N, H, W, C, FW, FH, FC, SH, SW, PADDING, \
|
||||
type) \
|
||||
static void BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, \
|
||||
FW, FC, SH, SW, PADDING)(int iters) { \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * \
|
||||
(C)); \
|
||||
test::Benchmark(#type, Conv2DBackpropInput<T>(N, H, W, C, FH, FW, FC, SH, \
|
||||
SW, PADDING, FORMAT_##FMT)) \
|
||||
.Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, FW, \
|
||||
FC, SH, SW, PADDING));
|
||||
|
||||
using fp32 = float;
|
||||
using fp16 = Eigen::half;
|
||||
|
||||
// ResNet50-ish convolutions.
|
||||
#define BENCHMARK_DTYPE(FORMAT, BATCH, T, D) \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, D); \
|
||||
#define BENCHMARK_DTYPE(FMT, BATCH, T, D) \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, SAME, D); \
|
||||
\
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, VALID, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, VALID, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, VALID, D); \
|
||||
\
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, D); \
|
||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, D);
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, SAME, D); \
|
||||
\
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, SAME, D); \
|
||||
\
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, SAME, D); \
|
||||
BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, SAME, D);
|
||||
|
||||
BENCHMARK_DTYPE(NHWC, 8, fp32, cpu);
|
||||
BENCHMARK_DTYPE(NHWC, 16, fp32, cpu);
|
||||
|
Loading…
Reference in New Issue
Block a user