Support 5D tensors in binary ops of Layout Opt

Kaixi Hou 2020-09-29 11:01:00 -07:00
parent 11823c6179
commit 94461ab056
3 changed files with 128 additions and 20 deletions
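The change teaches DefaultLayoutAgnosticOpTransposer and BinaryOpTransposer to handle rank-5 outputs: when the configured formats are NHWC/NCHW and a node's output rank is 5, the transposer temporarily promotes the context's formats to NDHWC/NCDHW, rewrites the fanin/fanout edges with transposes, and then restores the original 4D formats. A minimal, self-contained sketch of that format promotion follows; the helper name PromoteTo5D is illustrative only and is not part of the patch, which inlines the same mapping in each TransposeNode.

// Sketch only: mirrors the src_format_3d / dst_format_3d mapping added in
// TransposeNode. PromoteTo5D is a hypothetical helper, not commit code.
#include <iostream>
#include <string>

std::string PromoteTo5D(const std::string& format_4d) {
  if (format_4d == "NHWC") return "NDHWC";
  if (format_4d == "NCHW") return "NCDHW";
  return format_4d;  // Any other format is left as-is (allow_5d stays false).
}

int main() {
  std::cout << PromoteTo5D("NHWC") << "\n";  // prints NDHWC
  std::cout << PromoteTo5D("NCHW") << "\n";  // prints NCDHW
  return 0;
}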

View File

@@ -1062,12 +1062,41 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsAfterDstToSrcTransform(*context, *node)) {
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
if (rank != 4 && rank != 5) {
return Status::OK();
}
std::string src_format = context->src_format;
std::string dst_format = context->dst_format;
// Update the format from 4D to 5D layout if necessary.
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
(dst_format == "NHWC" || dst_format == "NCHW");
if (allow_5d) {
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
dst_format_3d);
}
if (!ShouldProcess(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format,
dst_format);
}
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
// Change back the format from 5D to 4D layout.
if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format,
dst_format);
}
return context->graph_view->GetMutationBuilder()->Apply();
}
@@ -1093,16 +1122,18 @@ bool BinaryOpTransposer::IsFaninShapeSupported(
const utils::MutableNodeView& node) {
return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) ||
IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) ||
IsNDOperateWithMD(node, 1, 4));
IsNDOperateWithMD(node, 1, 4) || IsNDOperateWithMD(node, 5, 0) ||
IsNDOperateWithMD(node, 5, 1) || IsNDOperateWithMD(node, 5, 5) ||
IsNDOperateWithMD(node, 0, 5) || IsNDOperateWithMD(node, 1, 5));
}
std::vector<int> BinaryOpTransposer::Get4DDataFaninPorts(
const utils::MutableNodeView& node) {
std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
const utils::MutableNodeView& node, int rank) {
std::vector<int> values;
if (IsFaninPortRankN(node, 0, 4)) {
if (IsFaninPortRankN(node, 0, rank)) {
values.push_back(0);
}
if (IsFaninPortRankN(node, 1, 4)) {
if (IsFaninPortRankN(node, 1, rank)) {
values.push_back(1);
}
return values;
@@ -1137,7 +1168,8 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
absl::string_view node_device,
bool node_in_frame,
int num_channels,
absl::string_view depended_node) {
absl::string_view depended_node,
int rank) {
NodeDef new_node;
new_node.set_name(string(node_name));
new_node.set_op(kOpConst);
@@ -1147,8 +1179,9 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
new_node.mutable_attr()->insert({"dtype", attr_data_type});
AttrValue attr_tensor;
Tensor tensor(DT_INT32, TensorShape({4}));
std::vector<int> shape = {1, num_channels, 1, 1};
Tensor tensor(DT_INT32, TensorShape({rank}));
std::vector<int> shape(rank, 1);
shape[1] = num_channels;
for (int i = 0; i < static_cast<int>(shape.size()); i++) {
tensor.flat<int>()(i) = shape[i];
}
@@ -1167,11 +1200,11 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
}
Status BinaryOpTransposer::MaybeReshapeVectorFanin(
TransposeContext* context, utils::MutableNodeView* node) {
TransposeContext* context, utils::MutableNodeView* node, int rank) {
int vector_index = -1;
if (IsNDOperateWithMD(*node, 4, 1)) {
if (IsNDOperateWithMD(*node, rank, 1)) {
vector_index = 1;
} else if (IsNDOperateWithMD(*node, 1, 4)) {
} else if (IsNDOperateWithMD(*node, 1, rank)) {
vector_index = 0;
}
if (vector_index != -1) {
@@ -1193,7 +1226,7 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(
TF_RETURN_IF_ERROR(
AddNodeShapeConst(mutation, shape_const_node_name, node_device,
context->frames.IsInFrame(*node->node()), vector_size,
fanin_node->GetName()));
fanin_node->GetName(), rank));
const auto* t_attr = node->GetAttr(kAttrT);
if (t_attr == nullptr) {
return errors::InvalidArgument("Missing attribute ", kAttrT);
@@ -1211,14 +1244,41 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(
Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsBinaryOp(*node->node()));
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
std::string src_format = context->src_format;
std::string dst_format = context->dst_format;
// Update the format from 4D to 5D layout if necessary.
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
(dst_format == "NHWC" || dst_format == "NCHW");
if (allow_5d) {
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
dst_format_3d);
}
if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format,
dst_format);
}
return Status::OK();
}
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, Get4DDataFaninPorts(*node),
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetNDDataFaninPorts(*node,
rank),
node, kOpTranspose));
TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node));
TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
// Change back the format from 5D to 4D layout.
if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format,
dst_format);
}
return context->graph_view->GetMutationBuilder()->Apply();
}

View File

@@ -348,18 +348,20 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
private:
bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
bool IsFaninShapeSupported(const utils::MutableNodeView& node);
std::vector<int> Get4DDataFaninPorts(const utils::MutableNodeView& node);
std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
int rank);
Status AddNodeShapeConst(utils::Mutation* mutation,
absl::string_view node_name,
absl::string_view node_device, bool node_in_frame,
int num_channels, absl::string_view depended_node);
int num_channels, absl::string_view depended_node,
int rank);
Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name,
absl::string_view node_device,
absl::string_view input_name,
absl::string_view shape_const_node_name,
const DataType& data_type);
Status MaybeReshapeVectorFanin(TransposeContext* context,
utils::MutableNodeView* node);
utils::MutableNodeView* node, int rank);
};
class ConcatOpTransposer : public LayoutAgnosticOpTransposer {

View File

@@ -193,7 +193,9 @@ def _get_cluster():
def _is_transpose(node):
return node.endswith('TransposeNHWCToNCHW-LayoutOptimizer') or node.endswith(
'TransposeNCHWToNHWC-LayoutOptimizer')
'TransposeNCHWToNHWC-LayoutOptimizer') or node.endswith(
'TransposeNDHWCToNCDHW-LayoutOptimizer') or node.endswith(
'TransposeNCDHWToNDHWC-LayoutOptimizer')
def _is_permute(node):
@@ -1230,6 +1232,50 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testBinaryOpsFor5DTensors(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
mean = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
variance = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
y = nn.batch_normalization(conv3d,
mean=mean,
variance=variance,
scale=gamma,
offset=beta,
variance_epsilon=0.001)
output = array_ops.identity(y)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
# The binary ops mul_1 and add_1 in batch norm need to transpose one of
the two inputs to NCDHW. The other input has already been transposed via
# Conv3D.
expected_num_transposes = 4
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testConv3D(self):
if test.is_gpu_available(cuda_only=True):