Merge pull request #43652 from kaixih:binary_ops_layout_fix

PiperOrigin-RevId: 334626406
Change-Id: I42370495af7eff160426f549f4479e238af9005f
Author: TensorFlower Gardener
Date: 2020-09-30 10:26:17 -07:00
commit e075feb345
3 changed files with 137 additions and 32 deletions


@@ -1062,12 +1062,41 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
 Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
-      !IsAfterDstToSrcTransform(*context, *node)) {
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
+                  (dst_format == "NHWC" || dst_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) ||
+      !IsAfterDstToSrcTransform(*context, *node)) {
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
+    return Status::OK();
+  }
+  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
+          << "' with op '" << node->GetOp() << "' from data format '"
+          << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
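When the node's first output has rank 5, the transposer above temporarily rewrites the context's data formats (NHWC becomes NDHWC, NCHW becomes NCDHW) and restores the original 4D formats once the node has been handled. A minimal Python sketch of that remapping, plus the axis permutations the corresponding transposes use, follows; the helper names are illustrative and not part of the TensorFlow API.

# Illustrative only: mirrors the 4D -> 5D data-format remapping used above.
# These helpers are not part of the TensorFlow API.
def to_5d_format(fmt):
  """Maps a 4D data format string to its volumetric (5D) counterpart."""
  return {'NHWC': 'NDHWC', 'NCHW': 'NCDHW'}[fmt]

def permutation(src, dst):
  """Axis permutation that converts a tensor from src layout to dst layout."""
  return [src.index(axis) for axis in dst]

src_5d, dst_5d = to_5d_format('NHWC'), to_5d_format('NCHW')
print(permutation(src_5d, dst_5d))  # [0, 4, 1, 2, 3]  (NDHWC -> NCDHW)
print(permutation(dst_5d, src_5d))  # [0, 2, 3, 4, 1]  (NCDHW -> NDHWC)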
@@ -1090,19 +1119,20 @@ bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
 }
 
 bool BinaryOpTransposer::IsFaninShapeSupported(
-    const utils::MutableNodeView& node) {
-  return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) ||
-          IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) ||
-          IsNDOperateWithMD(node, 1, 4));
+    const utils::MutableNodeView& node, int rank) {
+  return (IsNDOperateWithMD(node, rank, 0) ||
+          IsNDOperateWithMD(node, rank, 1) ||
+          IsNDOperateWithMD(node, rank, rank) ||
+          IsNDOperateWithMD(node, 0, rank) || IsNDOperateWithMD(node, 1, rank));
 }
 
-std::vector<int> BinaryOpTransposer::Get4DDataFaninPorts(
-    const utils::MutableNodeView& node) {
+std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
+    const utils::MutableNodeView& node, int rank) {
   std::vector<int> values;
-  if (IsFaninPortRankN(node, 0, 4)) {
+  if (IsFaninPortRankN(node, 0, rank)) {
     values.push_back(0);
   }
-  if (IsFaninPortRankN(node, 1, 4)) {
+  if (IsFaninPortRankN(node, 1, rank)) {
     values.push_back(1);
   }
   return values;
@@ -1132,12 +1162,10 @@ Status BinaryOpTransposer::AddNodeReshape(
   return status;
 }
 
-Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
-                                             absl::string_view node_name,
-                                             absl::string_view node_device,
-                                             bool node_in_frame,
-                                             int num_channels,
-                                             absl::string_view depended_node) {
+Status BinaryOpTransposer::AddNodeShapeConst(
+    utils::Mutation* mutation, absl::string_view node_name,
+    absl::string_view node_device, bool node_in_frame, int num_channels,
+    absl::string_view depended_node, int rank) {
   NodeDef new_node;
   new_node.set_name(string(node_name));
   new_node.set_op(kOpConst);
@@ -1147,8 +1175,9 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
   new_node.mutable_attr()->insert({"dtype", attr_data_type});
 
   AttrValue attr_tensor;
-  Tensor tensor(DT_INT32, TensorShape({4}));
-  std::vector<int> shape = {1, num_channels, 1, 1};
+  Tensor tensor(DT_INT32, TensorShape({rank}));
+  std::vector<int> shape(rank, 1);
+  shape[1] = num_channels;
   for (int i = 0; i < static_cast<int>(shape.size()); i++) {
     tensor.flat<int>()(i) = shape[i];
   }
@@ -1166,12 +1195,13 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
   return status;
 }
 
-Status BinaryOpTransposer::MaybeReshapeVectorFanin(
-    TransposeContext* context, utils::MutableNodeView* node) {
+Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
+                                                   utils::MutableNodeView* node,
+                                                   int rank) {
   int vector_index = -1;
-  if (IsNDOperateWithMD(*node, 4, 1)) {
+  if (IsNDOperateWithMD(*node, rank, 1)) {
     vector_index = 1;
-  } else if (IsNDOperateWithMD(*node, 1, 4)) {
+  } else if (IsNDOperateWithMD(*node, 1, rank)) {
     vector_index = 0;
   }
   if (vector_index != -1) {
@@ -1193,7 +1223,7 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(
     TF_RETURN_IF_ERROR(
         AddNodeShapeConst(mutation, shape_const_node_name, node_device,
                           context->frames.IsInFrame(*node->node()), vector_size,
-                          fanin_node->GetName()));
+                          fanin_node->GetName(), rank));
     const auto* t_attr = node->GetAttr(kAttrT);
     if (t_attr == nullptr) {
       return errors::InvalidArgument("Missing attribute ", kAttrT);
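With the rank parameter threaded through, AddNodeShapeConst emits a shape constant of length rank that is all ones except for the channel count at index 1 ([1, C, 1, 1] for 4D, [1, C, 1, 1, 1] for 5D), so the per-channel vector operand reshaped by MaybeReshapeVectorFanin still broadcasts against a channels-first tensor. A small NumPy sketch of that broadcast, with made-up sizes, is below.

# Sketch of the broadcast enabled by the [1, C, 1, ...] shape constant above.
# Sizes are made up; only the broadcasting behaviour is shown.
import numpy as np

def channels_first_broadcast_shape(num_channels, rank):
  shape = [1] * rank
  shape[1] = num_channels  # channels sit at index 1 in NCHW / NCDHW
  return shape

x_ncdhw = np.zeros((2, 3, 4, 5, 6))    # rank-5 tensor in NCDHW layout
bias = np.arange(3.0)                  # per-channel vector, C = 3
bias_5d = bias.reshape(channels_first_broadcast_shape(3, rank=5))  # (1, 3, 1, 1, 1)
print((x_ncdhw + bias_5d).shape)       # (2, 3, 4, 5, 6)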
@@ -1211,14 +1241,40 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(
 Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
   DCHECK(IsBinaryOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) ||
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
+                  (dst_format == "NHWC" || dst_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
       !IsAfterDstToSrcTransform(*context, *node)) {
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
-  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, Get4DDataFaninPorts(*node),
-                                            node, kOpTranspose));
-  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node));
+  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
+          << "' with op '" << node->GetOp() << "' from data format '"
+          << context->src_format << "' to '" << context->dst_format << "'";
+  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
+      context, GetNDDataFaninPorts(*node, rank), node, kOpTranspose));
+  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }


@@ -347,19 +347,21 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
  private:
   bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
-  bool IsFaninShapeSupported(const utils::MutableNodeView& node);
-  std::vector<int> Get4DDataFaninPorts(const utils::MutableNodeView& node);
+  bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
+  std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
+                                       int rank);
   Status AddNodeShapeConst(utils::Mutation* mutation,
                            absl::string_view node_name,
                            absl::string_view node_device, bool node_in_frame,
-                           int num_channels, absl::string_view depended_node);
+                           int num_channels, absl::string_view depended_node,
+                           int rank);
   Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name,
                         absl::string_view node_device,
                         absl::string_view input_name,
                         absl::string_view shape_const_node_name,
                         const DataType& data_type);
   Status MaybeReshapeVectorFanin(TransposeContext* context,
-                                 utils::MutableNodeView* node);
+                                 utils::MutableNodeView* node, int rank);
 };
 
 class ConcatOpTransposer : public LayoutAgnosticOpTransposer {


@@ -193,7 +193,9 @@ def _get_cluster():
 
 def _is_transpose(node):
   return node.endswith('TransposeNHWCToNCHW-LayoutOptimizer') or node.endswith(
-      'TransposeNCHWToNHWC-LayoutOptimizer')
+      'TransposeNCHWToNHWC-LayoutOptimizer') or node.endswith(
+          'TransposeNDHWCToNCDHW-LayoutOptimizer') or node.endswith(
+              'TransposeNCDHWToNDHWC-LayoutOptimizer')
 
 
 def _is_permute(node):
@@ -1230,6 +1232,51 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  @test_util.deprecated_graph_mode_only
+  def testBinaryOpsFor5DTensors(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      mean = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      variance = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
+      y = nn.batch_normalization(
+          conv3d,
+          mean=mean,
+          variance=variance,
+          scale=gamma,
+          offset=beta,
+          variance_epsilon=0.001)
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # The binary ops mul_1 and add_1 in batch norm need to transpose one of
+      # the two inputs to NCDHW. The other input has already been transposed
+      # via Conv3D.
+      expected_num_transposes = 4
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   @test_util.deprecated_graph_mode_only
   def testConv3D(self):
     if test.is_gpu_available(cuda_only=True):
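The new test builds its session configs with _get_config, a helper defined earlier in the test file and not shown in this diff; it presumably toggles Grappler's layout optimizer. As a hedged, standalone illustration (an assumption about what such a helper does, not a copy of it), a config with the layout optimizer explicitly switched on or off can be built from the standard RewriterConfig proto:

# Standalone sketch: build a tf.compat.v1 Session config with Grappler's
# layout optimizer explicitly ON or OFF. This is an assumption about what the
# test's _get_config helper does, not a copy of it.
from tensorflow.core.protobuf import config_pb2, rewriter_config_pb2

def make_config(layout_optimizer=True):
  toggle = (rewriter_config_pb2.RewriterConfig.ON
            if layout_optimizer else rewriter_config_pb2.RewriterConfig.OFF)
  rewrite_options = rewriter_config_pb2.RewriterConfig(layout_optimizer=toggle)
  graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
  return config_pb2.ConfigProto(graph_options=graph_options)

# Usage mirrors the test: run the graph once with the optimizer off to get a
# reference value, then with it on, and compare the two outputs.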