Also check dst_format

This commit is contained in:
Kaixi Hou 2020-09-23 12:04:10 -07:00
parent d602375436
commit 9b947dd637

View File

@ -1379,7 +1379,8 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
std::string src_format = context->src_format;
std::string dst_format = context->dst_format;
// Update the format from 4D to 5D layout if necessary.
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
(dst_format == "NHWC" || dst_format == "NCHW");
if (allow_5d) {
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";