From 94461ab0560f639fb67bf32a3c546d2c6f46041c Mon Sep 17 00:00:00 2001
From: Kaixi Hou
Date: Tue, 29 Sep 2020 11:01:00 -0700
Subject: [PATCH] Support 5D tensors in binary ops of Layout Opt

---
 .../generic_layout_optimizer_transposer.cc    | 92 +++++++++++++++----
 .../generic_layout_optimizer_transposer.h     |  8 +-
 .../python/grappler/layout_optimizer_test.py  | 48 +++++++++-
 3 files changed, 128 insertions(+), 20 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 74b86e4e9ba..eae86e7c18c 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -1062,12 +1062,41 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
 Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
-      !IsAfterDstToSrcTransform(*context, *node)) {
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
+                  (dst_format == "NHWC" || dst_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) ||
+      !IsAfterDstToSrcTransform(*context, *node)) {
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
+    return Status::OK();
+  }
+  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
+          << "' with op '" << node->GetOp() << "' from data format '"
+          << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(
       UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
@@ -1093,16 +1122,18 @@ bool BinaryOpTransposer::IsFaninShapeSupported(
     const utils::MutableNodeView& node) {
   return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) ||
           IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) ||
-          IsNDOperateWithMD(node, 1, 4));
+          IsNDOperateWithMD(node, 1, 4) || IsNDOperateWithMD(node, 5, 0) ||
+          IsNDOperateWithMD(node, 5, 1) || IsNDOperateWithMD(node, 5, 5) ||
+          IsNDOperateWithMD(node, 0, 5) || IsNDOperateWithMD(node, 1, 5));
 }
 
-std::vector<int> BinaryOpTransposer::Get4DDataFaninPorts(
-    const utils::MutableNodeView& node) {
+std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
+    const utils::MutableNodeView& node, int rank) {
   std::vector<int> values;
-  if (IsFaninPortRankN(node, 0, 4)) {
+  if (IsFaninPortRankN(node, 0, rank)) {
     values.push_back(0);
   }
-  if (IsFaninPortRankN(node, 1, 4)) {
+  if (IsFaninPortRankN(node, 1, rank)) {
     values.push_back(1);
   }
   return values;
@@ -1137,7 +1168,8 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
                                              absl::string_view node_device,
                                              bool node_in_frame,
                                              int num_channels,
-                                             absl::string_view depended_node) {
+                                             absl::string_view depended_node,
+                                             int rank) {
   NodeDef new_node;
   new_node.set_name(string(node_name));
   new_node.set_op(kOpConst);
@@ -1147,8 +1179,9 @@
   new_node.mutable_attr()->insert({"dtype", attr_data_type});
 
   AttrValue attr_tensor;
-  Tensor tensor(DT_INT32, TensorShape({4}));
-  std::vector<int> shape = {1, num_channels, 1, 1};
+  Tensor tensor(DT_INT32, TensorShape({rank}));
+  std::vector<int> shape(rank, 1);
+  shape[1] = num_channels;
   for (int i = 0; i < static_cast<int>(shape.size()); i++) {
     tensor.flat<int>()(i) = shape[i];
   }
@@ -1167,11 +1200,11 @@
 }
 
 Status BinaryOpTransposer::MaybeReshapeVectorFanin(
-    TransposeContext* context, utils::MutableNodeView* node) {
+    TransposeContext* context, utils::MutableNodeView* node, int rank) {
   int vector_index = -1;
-  if (IsNDOperateWithMD(*node, 4, 1)) {
+  if (IsNDOperateWithMD(*node, rank, 1)) {
     vector_index = 1;
-  } else if (IsNDOperateWithMD(*node, 1, 4)) {
+  } else if (IsNDOperateWithMD(*node, 1, rank)) {
     vector_index = 0;
   }
   if (vector_index != -1) {
@@ -1193,7 +1226,7 @@
     TF_RETURN_IF_ERROR(
         AddNodeShapeConst(mutation, shape_const_node_name, node_device,
                           context->frames.IsInFrame(*node->node()), vector_size,
-                          fanin_node->GetName()));
+                          fanin_node->GetName(), rank));
     const auto* t_attr = node->GetAttr(kAttrT);
     if (t_attr == nullptr) {
       return errors::InvalidArgument("Missing attribute ", kAttrT);
@@ -1211,14 +1244,41 @@
 Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
   DCHECK(IsBinaryOp(*node->node()));
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
+                  (dst_format == "NHWC" || dst_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
   if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) ||
       !IsAfterDstToSrcTransform(*context, *node)) {
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
-  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, Get4DDataFaninPorts(*node),
+  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
+          << "' with op '" << node->GetOp() << "' from data format '"
+          << context->src_format << "' to '" << context->dst_format << "'";
+  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetNDDataFaninPorts(*node,
+                                            rank),
                                             node, kOpTranspose));
-  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node));
+  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank));
   TF_RETURN_IF_ERROR(
       UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h
index 61720df791b..8db9ff0e70f 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h
@@ -348,18 +348,20 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
  private:
   bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
   bool IsFaninShapeSupported(const utils::MutableNodeView& node);
-  std::vector<int> Get4DDataFaninPorts(const utils::MutableNodeView& node);
+  std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
+                                       int rank);
   Status AddNodeShapeConst(utils::Mutation* mutation,
                            absl::string_view node_name,
                            absl::string_view node_device, bool node_in_frame,
-                           int num_channels, absl::string_view depended_node);
+                           int num_channels, absl::string_view depended_node,
+                           int rank);
   Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name,
                         absl::string_view node_device,
                         absl::string_view input_name,
                         absl::string_view shape_const_node_name,
                         const DataType& data_type);
   Status MaybeReshapeVectorFanin(TransposeContext* context,
-                                 utils::MutableNodeView* node);
+                                 utils::MutableNodeView* node, int rank);
 };
 
 class ConcatOpTransposer : public LayoutAgnosticOpTransposer {
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index a69ed72db87..35430ea8664 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -193,7 +193,9 @@ def _get_cluster():
 
 def _is_transpose(node):
   return node.endswith('TransposeNHWCToNCHW-LayoutOptimizer') or node.endswith(
-      'TransposeNCHWToNHWC-LayoutOptimizer')
+      'TransposeNCHWToNHWC-LayoutOptimizer') or node.endswith(
+          'TransposeNDHWCToNCDHW-LayoutOptimizer') or node.endswith(
+              'TransposeNCDHWToNDHWC-LayoutOptimizer')
 
 
 def _is_permute(node):
@@ -1230,6 +1232,50 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  @test_util.deprecated_graph_mode_only
+  def testBinaryOpsFor5DTensors(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      mean = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      variance = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
+      y = nn.batch_normalization(conv3d,
+                                 mean=mean,
+                                 variance=variance,
+                                 scale=gamma,
+                                 offset=beta,
+                                 variance_epsilon=0.001)
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # The binary ops mul_1 and add_1 in batch norm need to transpose one of
+      # the two inputs to NCDHW. The other input has already been transposed
+      # via Conv3D.
+      expected_num_transposes = 4
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   @test_util.deprecated_graph_mode_only
   def testConv3D(self):
     if test.is_gpu_available(cuda_only=True):