Add Conv2D benchmarks for NASNet shapes

PiperOrigin-RevId: 284892240
Change-Id: Ie05cd1a7c3fa04c2df8380b657fd3c0f2c3e4ab4
This commit is contained in:
Eugene Zhulenev 2019-12-10 18:16:23 -08:00 committed by TensorFlower Gardener
parent cdb11818ce
commit 8037ae3e0d

View File

@ -574,7 +574,7 @@ using fp32 = float;
using fp16 = Eigen::half;
// ResNet50-ish convolutions.
#define BENCHMARK_DTYPE(DATA_FORMAT, BATCH, T) \
#define BENCHMARK_RESNET50(DATA_FORMAT, BATCH, T) \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 56, 56, 64, 1, 1, 64, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 56, 56, 64, 1, 1, 256, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 56, 56, 256, 1, 1, 64, gpu); \
@ -588,7 +588,30 @@ using fp16 = Eigen::half;
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 14, 14, 256, 1, 1, 256, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 14, 14, 256, 1, 1, 1024, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 14, 14, 1024, 1, 1, 256, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 14, 14, 256, 3, 3, 256, gpu);
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 14, 14, 256, 3, 3, 256, gpu)
// NASnet-ish convolutions (Tensorflow models: slim/nets/nasnet).
#define BENCHMARK_NASNET(DATA_FORMAT, BATCH, T) \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 165, 165, 96, 1, 1, 96, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 83, 83, 84, 1, 1, 84, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 42, 42, 336, 1, 1, 336, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 42, 42, 168, 1, 1, 168, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 21, 21, 1008, 1, 1, 1008, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 21, 21, 336, 1, 1, 336, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 11, 11, 672, 1, 1, 672, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 21, 21, 2016, 1, 1, 2016, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 11, 11, 2688, 1, 1, 2688, gpu); \
BM_Conv2DFmt(T, DATA_FORMAT, BATCH, 11, 11, 4032, 1, 1, 4032, gpu)
#define BENCHMARK_DTYPE(DATA_FORMAT, BATCH, T) \
BENCHMARK_RESNET50(DATA_FORMAT, BATCH, T); \
BENCHMARK_NASNET(DATA_FORMAT, BATCH, T)
BENCHMARK_DTYPE(NHWC, 16, fp32);
BENCHMARK_DTYPE(NCHW, 16, fp32);
BENCHMARK_DTYPE(NHWC, 16, fp16);
BENCHMARK_DTYPE(NCHW, 16, fp16);
BENCHMARK_DTYPE(NHWC, 32, fp32);
BENCHMARK_DTYPE(NCHW, 32, fp32);