Merge pull request #43652 from kaixih:binary_ops_layout_fix
PiperOrigin-RevId: 334626406
Change-Id: I42370495af7eff160426f549f4479e238af9005f
commit e075feb345
@@ -1062,12 +1062,41 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
 Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
-      !IsAfterDstToSrcTransform(*context, *node)) {
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  if (rank != 4 && rank != 5) {
     return Status::OK();
   }
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
+                  (dst_format == "NHWC" || dst_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) ||
+      !IsAfterDstToSrcTransform(*context, *node)) {
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
+    return Status::OK();
+  }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
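The pattern introduced above is a temporary format promotion: when the node's output rank is 5 and the configured formats are the 4D pair NHWC/NCHW, the transposer swaps in NDHWC/NCDHW for the duration of the rewrite and restores the original pair on every exit path. A minimal standalone sketch of the 4D-to-5D mapping, assuming only the two format pairs handled above (To5DFormat is an illustrative helper, not part of the transposer API):

// Illustrative only: maps a 4D data format string to the 5D counterpart used
// while allow_5d is in effect; any other format is returned unchanged.
#include <string>

std::string To5DFormat(const std::string& format_4d) {
  if (format_4d == "NHWC") return "NDHWC";  // channels-last gains a depth axis
  if (format_4d == "NCHW") return "NCDHW";  // channels-first gains a depth axis
  return format_4d;
}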
@@ -1090,19 +1119,20 @@ bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
 }

 bool BinaryOpTransposer::IsFaninShapeSupported(
-    const utils::MutableNodeView& node) {
-  return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) ||
-          IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) ||
-          IsNDOperateWithMD(node, 1, 4));
+    const utils::MutableNodeView& node, int rank) {
+  return (IsNDOperateWithMD(node, rank, 0) ||
+          IsNDOperateWithMD(node, rank, 1) ||
+          IsNDOperateWithMD(node, rank, rank) ||
+          IsNDOperateWithMD(node, 0, rank) || IsNDOperateWithMD(node, 1, rank));
 }

-std::vector<int> BinaryOpTransposer::Get4DDataFaninPorts(
-    const utils::MutableNodeView& node) {
+std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
+    const utils::MutableNodeView& node, int rank) {
   std::vector<int> values;
-  if (IsFaninPortRankN(node, 0, 4)) {
+  if (IsFaninPortRankN(node, 0, rank)) {
     values.push_back(0);
   }
-  if (IsFaninPortRankN(node, 1, 4)) {
+  if (IsFaninPortRankN(node, 1, rank)) {
     values.push_back(1);
   }
   return values;
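For context, IsNDOperateWithMD(node, n, m) asks whether the binary op's first input has rank n and its second input rank m, so the combinations accepted above are a full-rank tensor paired with a scalar, a vector, or another full-rank tensor, in either argument order. A compact sketch of that predicate over plain rank values (ShapesSupported is a hypothetical stand-in, not the transposer API):

// Illustrative: mirrors the rank combinations accepted by IsFaninShapeSupported,
// where `rank` is the full tensor rank (4 or 5 after this change).
bool ShapesSupported(int rank_a, int rank_b, int rank) {
  auto is = [&](int a, int b) { return rank_a == a && rank_b == b; };
  return is(rank, 0) || is(rank, 1) || is(rank, rank) ||
         is(0, rank) || is(1, rank);
}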
@@ -1132,12 +1162,10 @@ Status BinaryOpTransposer::AddNodeReshape(
   return status;
 }

-Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
-                                             absl::string_view node_name,
-                                             absl::string_view node_device,
-                                             bool node_in_frame,
-                                             int num_channels,
-                                             absl::string_view depended_node) {
+Status BinaryOpTransposer::AddNodeShapeConst(
+    utils::Mutation* mutation, absl::string_view node_name,
+    absl::string_view node_device, bool node_in_frame, int num_channels,
+    absl::string_view depended_node, int rank) {
   NodeDef new_node;
   new_node.set_name(string(node_name));
   new_node.set_op(kOpConst);
@@ -1147,8 +1175,9 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
   new_node.mutable_attr()->insert({"dtype", attr_data_type});

   AttrValue attr_tensor;
-  Tensor tensor(DT_INT32, TensorShape({4}));
-  std::vector<int> shape = {1, num_channels, 1, 1};
+  Tensor tensor(DT_INT32, TensorShape({rank}));
+  std::vector<int> shape(rank, 1);
+  shape[1] = num_channels;
   for (int i = 0; i < static_cast<int>(shape.size()); i++) {
     tensor.flat<int>()(i) = shape[i];
   }
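The shape constant feeding the broadcast Reshape generalizes here from the fixed {1, C, 1, 1} to a rank-length vector of ones with the channel count at index 1, e.g. {1, C, 1, 1, 1} for 5D. A small standalone sketch of that construction (BroadcastShape is an illustrative name, and it assumes rank is at least 2, as in the real code path):

// Illustrative: builds the reshape target {1, C, 1, ..., 1} for a given rank,
// matching the generalized shape construction in AddNodeShapeConst above.
#include <vector>

std::vector<int> BroadcastShape(int rank, int num_channels) {
  std::vector<int> shape(rank, 1);  // start with all ones
  shape[1] = num_channels;          // channels sit at index 1 in NC... layouts
  return shape;                     // e.g. rank = 5, C = 3 -> {1, 3, 1, 1, 1}
}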
@@ -1166,12 +1195,13 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation,
   return status;
 }

-Status BinaryOpTransposer::MaybeReshapeVectorFanin(
-    TransposeContext* context, utils::MutableNodeView* node) {
+Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
+                                                   utils::MutableNodeView* node,
+                                                   int rank) {
   int vector_index = -1;
-  if (IsNDOperateWithMD(*node, 4, 1)) {
+  if (IsNDOperateWithMD(*node, rank, 1)) {
     vector_index = 1;
-  } else if (IsNDOperateWithMD(*node, 1, 4)) {
+  } else if (IsNDOperateWithMD(*node, 1, rank)) {
     vector_index = 0;
   }
   if (vector_index != -1) {
@@ -1193,7 +1223,7 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(
     TF_RETURN_IF_ERROR(
         AddNodeShapeConst(mutation, shape_const_node_name, node_device,
                           context->frames.IsInFrame(*node->node()), vector_size,
-                          fanin_node->GetName()));
+                          fanin_node->GetName(), rank));
     const auto* t_attr = node->GetAttr(kAttrT);
     if (t_attr == nullptr) {
       return errors::InvalidArgument("Missing attribute ", kAttrT);
@@ -1211,14 +1241,40 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(
 Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
   DCHECK(IsBinaryOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) ||
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
+                  (dst_format == "NHWC" || dst_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
       !IsAfterDstToSrcTransform(*context, *node)) {
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
-  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, Get4DDataFaninPorts(*node),
-                                            node, kOpTranspose));
-  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node));
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
+  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
+      context, GetNDDataFaninPorts(*node, rank), node, kOpTranspose));
+  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
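Both TransposeNode overrides follow the same shape: promote the formats, restore them and bail out if the node is skipped, otherwise rewrite the fanin/fanout edges and restore at the end. One way to express that restore-on-every-exit-path obligation compactly is an RAII guard; the sketch below is only an illustrative alternative under the same assumptions, not what this commit implements:

// Illustrative RAII sketch: runs a restore callback on scope exit, covering
// early returns and the normal path alike.
#include <functional>
#include <utility>

class FormatRestorer {
 public:
  explicit FormatRestorer(std::function<void()> restore)
      : restore_(std::move(restore)) {}
  ~FormatRestorer() {
    if (restore_) restore_();  // undo the 5D promotion, if any
  }
  FormatRestorer(const FormatRestorer&) = delete;
  FormatRestorer& operator=(const FormatRestorer&) = delete;

 private:
  std::function<void()> restore_;
};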
@@ -347,19 +347,21 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {

  private:
   bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
-  bool IsFaninShapeSupported(const utils::MutableNodeView& node);
-  std::vector<int> Get4DDataFaninPorts(const utils::MutableNodeView& node);
+  bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
+  std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
+                                       int rank);
   Status AddNodeShapeConst(utils::Mutation* mutation,
                            absl::string_view node_name,
                            absl::string_view node_device, bool node_in_frame,
-                           int num_channels, absl::string_view depended_node);
+                           int num_channels, absl::string_view depended_node,
+                           int rank);
   Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name,
                         absl::string_view node_device,
                         absl::string_view input_name,
                         absl::string_view shape_const_node_name,
                         const DataType& data_type);
   Status MaybeReshapeVectorFanin(TransposeContext* context,
-                                 utils::MutableNodeView* node);
+                                 utils::MutableNodeView* node, int rank);
 };

 class ConcatOpTransposer : public LayoutAgnosticOpTransposer {
@@ -193,7 +193,9 @@ def _get_cluster():

 def _is_transpose(node):
   return node.endswith('TransposeNHWCToNCHW-LayoutOptimizer') or node.endswith(
-      'TransposeNCHWToNHWC-LayoutOptimizer')
+      'TransposeNCHWToNHWC-LayoutOptimizer') or node.endswith(
+      'TransposeNDHWCToNCDHW-LayoutOptimizer') or node.endswith(
+      'TransposeNCDHWToNDHWC-LayoutOptimizer')


 def _is_permute(node):
@@ -1230,6 +1232,51 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)

+  @test_util.deprecated_graph_mode_only
+  def testBinaryOpsFor5DTensors(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      mean = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      variance = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
+      y = nn.batch_normalization(
+          conv3d,
+          mean=mean,
+          variance=variance,
+          scale=gamma,
+          offset=beta,
+          variance_epsilon=0.001)
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # The binary ops mul_1 and add_1 in batch norm need to transpose one of
+      # the two inputs to NCDHW. The other input has already been transposed via
+      # Conv3D.
+      expected_num_transposes = 4
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   @test_util.deprecated_graph_mode_only
   def testConv3D(self):
     if test.is_gpu_available(cuda_only=True):