Add benchmarks for Conv2D input gradient with strides not equal to one
PiperOrigin-RevId: 282847663 Change-Id: Iaa201f4f8e74a62380e68167dde75a158689baf8
This commit is contained in:
parent
a10dc73356
commit
74faaeb08f
@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
@ -40,7 +39,7 @@ template <typename T>
|
|||||||
static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
||||||
int in_depth, int filter_h, int filter_w,
|
int in_depth, int filter_h, int filter_w,
|
||||||
int out_depth, int stride_h, int stride_w,
|
int out_depth, int stride_h, int stride_w,
|
||||||
TensorFormat data_format) {
|
Padding padding, TensorFormat data_format) {
|
||||||
auto* graph = new Graph(OpRegistry::Global());
|
auto* graph = new Graph(OpRegistry::Global());
|
||||||
|
|
||||||
Tensor input_t = data_format == FORMAT_NHWC
|
Tensor input_t = data_format == FORMAT_NHWC
|
||||||
@ -53,7 +52,7 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
|||||||
Conv2DParameters params;
|
Conv2DParameters params;
|
||||||
params.dilations = {1, 1, 1, 1};
|
params.dilations = {1, 1, 1, 1};
|
||||||
params.strides = {1, stride_h, stride_w, 1};
|
params.strides = {1, stride_h, stride_w, 1};
|
||||||
params.padding = Padding::SAME;
|
params.padding = padding;
|
||||||
params.data_format = data_format;
|
params.data_format = data_format;
|
||||||
|
|
||||||
Conv2DDimensions conv2d_dims;
|
Conv2DDimensions conv2d_dims;
|
||||||
@ -85,7 +84,9 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
|||||||
.Input(backprop)
|
.Input(backprop)
|
||||||
.Attr("T", DataTypeToEnum<T>::value)
|
.Attr("T", DataTypeToEnum<T>::value)
|
||||||
.Attr("strides", {1, stride_h, stride_w, 1})
|
.Attr("strides", {1, stride_h, stride_w, 1})
|
||||||
.Attr("padding", "SAME")
|
.Attr("padding", padding == Padding::SAME
|
||||||
|
? "SAME"
|
||||||
|
: padding == Padding::VALID ? "VALID" : "N/A")
|
||||||
.Attr("data_format", ToString(data_format))
|
.Attr("data_format", ToString(data_format))
|
||||||
.Finalize(graph, &conv2d));
|
.Finalize(graph, &conv2d));
|
||||||
|
|
||||||
@ -94,7 +95,7 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
|||||||
|
|
||||||
// Macro argument names: ---------------------------------------------------- //
|
// Macro argument names: ---------------------------------------------------- //
|
||||||
// T: data type
|
// T: data type
|
||||||
// FORMAT: data format (NHWC or NCHW)
|
// FMT: data format (NHWC or NCHW)
|
||||||
// N: batch size
|
// N: batch size
|
||||||
// H: height
|
// H: height
|
||||||
// W: width
|
// W: width
|
||||||
@ -107,41 +108,50 @@ static Graph* Conv2DBackpropInput(int batch, int height, int width,
|
|||||||
|
|
||||||
#define BM_CONCAT(a, b) a##_##b
|
#define BM_CONCAT(a, b) a##_##b
|
||||||
|
|
||||||
#define BM_NAME(name, type, T, FORMAT, N, H, W, C, FH, FW, FC, SH, SW) \
|
#define BM_NAME(name, type, T, FMT, N, H, W, C, FH, FW, FC, SH, SW, PADDING) \
|
||||||
BM_CONCAT(name##_##T##_##FORMAT##_##type##_in##N##x##H##x##W##x##C, \
|
BM_CONCAT(name##_##T##_##FMT##_##type##_in##N##x##H##x##W##x##C, \
|
||||||
f##FH##x##FW##x##FC##_##s##SH##x##SW)
|
f##FH##x##FW##x##FC##_##s##SH##x##SW##_##PADDING)
|
||||||
|
|
||||||
#define BM_Conv2DBwdInputFmt(T, FORMAT, N, H, W, C, FW, FH, FC, SH, SW, type) \
|
#define BM_Conv2DBwdInput(T, FMT, N, H, W, C, FW, FH, FC, SH, SW, PADDING, \
|
||||||
static void BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FH, \
|
type) \
|
||||||
FW, FC, SH, SW)(int iters) { \
|
static void BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, \
|
||||||
testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * \
|
FW, FC, SH, SW, PADDING)(int iters) { \
|
||||||
(C)); \
|
testing::ItemsProcessed(static_cast<int64>(iters) * (N) * (H) * (W) * \
|
||||||
test::Benchmark(#type, Conv2DBackpropInput<T>(N, H, W, C, FH, FW, FC, SH, \
|
(C)); \
|
||||||
SW, FORMAT_##FORMAT)) \
|
test::Benchmark(#type, Conv2DBackpropInput<T>(N, H, W, C, FH, FW, FC, SH, \
|
||||||
.Run(iters); \
|
SW, PADDING, FORMAT_##FMT)) \
|
||||||
} \
|
.Run(iters); \
|
||||||
BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FORMAT, N, H, W, C, FH, \
|
} \
|
||||||
FW, FC, SH, SW));
|
BENCHMARK(BM_NAME(BM_Conv2DBackpropInput, type, T, FMT, N, H, W, C, FH, FW, \
|
||||||
|
FC, SH, SW, PADDING));
|
||||||
|
|
||||||
using fp32 = float;
|
using fp32 = float;
|
||||||
using fp16 = Eigen::half;
|
using fp16 = Eigen::half;
|
||||||
|
|
||||||
// ResNet50-ish convolutions.
|
// ResNet50-ish convolutions.
|
||||||
#define BENCHMARK_DTYPE(FORMAT, BATCH, T, D) \
|
#define BENCHMARK_DTYPE(FMT, BATCH, T, D) \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, SAME, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, SAME, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, SAME, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, D); \
|
|
||||||
\
|
\
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 112, 112, 64, 2, 2, 64, 2, 2, VALID, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 128, 2, 2, 128, 2, 2, VALID, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 256, 2, 2, 256, 2, 2, VALID, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, D); \
|
|
||||||
\
|
\
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 1, 1, 64, 1, 1, SAME, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 1, 1, 256, 1, 1, SAME, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, D); \
|
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 256, 1, 1, 64, 1, 1, SAME, D); \
|
||||||
BM_Conv2DBwdInputFmt(T, FORMAT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, D);
|
BM_Conv2DBwdInput(T, FMT, BATCH, 56, 56, 64, 3, 3, 64, 1, 1, SAME, D); \
|
||||||
|
\
|
||||||
|
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 128, 1, 1, 128, 1, 1, SAME, D); \
|
||||||
|
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 128, 1, 1, 512, 1, 1, SAME, D); \
|
||||||
|
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 512, 1, 1, 128, 1, 1, SAME, D); \
|
||||||
|
BM_Conv2DBwdInput(T, FMT, BATCH, 28, 28, 512, 3, 3, 128, 1, 1, SAME, D); \
|
||||||
|
\
|
||||||
|
BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 1, 1, 256, 1, 1, SAME, D); \
|
||||||
|
BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 1, 1, 1024, 1, 1, SAME, D); \
|
||||||
|
BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 1024, 1, 1, 256, 1, 1, SAME, D); \
|
||||||
|
BM_Conv2DBwdInput(T, FMT, BATCH, 14, 14, 256, 3, 3, 256, 1, 1, SAME, D);
|
||||||
|
|
||||||
BENCHMARK_DTYPE(NHWC, 8, fp32, cpu);
|
BENCHMARK_DTYPE(NHWC, 8, fp32, cpu);
|
||||||
BENCHMARK_DTYPE(NHWC, 16, fp32, cpu);
|
BENCHMARK_DTYPE(NHWC, 16, fp32, cpu);
|
||||||
|
Loading…
Reference in New Issue
Block a user