Merge pull request from kaixih:layout_opt_other_ops_pr_v2

PiperOrigin-RevId: 344072074
Change-Id: I13a22cdd5595bf8c6699c92a4c8d706f6527a519
TensorFlower Gardener 2020-11-24 09:20:22 -08:00
commit 24ae9d458b
7 changed files with 456 additions and 120 deletions


@@ -91,15 +91,15 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_));
OP_REQUIRES(
ctx, src_format_.size() == 4,
errors::InvalidArgument("Data format should have 4 characters"));
ctx, src_format_.size() == 4 || src_format_.size() == 5,
errors::InvalidArgument("Data format should have 4 or 5 characters"));
TensorFormat data_format;
OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_));
OP_REQUIRES(
ctx, dst_format_.size() == 4,
errors::InvalidArgument("Data format should have 4 characters"));
ctx, dst_format_.size() == 4 || dst_format_.size() == 5,
errors::InvalidArgument("Data format should have 4 or 5 characters"));
OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format),
errors::InvalidArgument("Invalid data format"));
}
@@ -113,9 +113,10 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
input_tensor_shape.DebugString()));
const int dim0 = input_tensor_shape.dim_size(0);
OP_REQUIRES(
ctx, dim0 == 2 || dim0 == 4,
ctx, dim0 == 2 || dim0 == 4 || dim0 == 5,
errors::InvalidArgument(
"First dimension of input must be of size 4, but got shape ",
"First dimension of input must be of size 2, 4 or 5, but got "
"shape ",
input_tensor_shape.DebugString()));
if (input_rank == 2) {
OP_REQUIRES(

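With 5-character formats accepted, the kernel can permute a 5-element layout vector the same way it already handled 4-element ones. A minimal sketch of the intended behavior, assuming a TF 2.x build with eager execution in which this op accepts 5-character formats (the values are illustrative):

    import tensorflow as tf

    # A shape-like vector laid out as [N, D, H, W, C].
    shape_ndhwc = tf.constant([2, 3, 4, 5, 6], dtype=tf.int32)
    shape_ncdhw = tf.raw_ops.DataFormatVecPermute(
        x=shape_ndhwc, src_format="NDHWC", dst_format="NCDHW")
    print(shape_ncdhw.numpy())  # [2 6 3 4 5]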

@@ -113,6 +113,8 @@ bool IsBiasAdd(const NodeDef& node) {
return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
}
bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd"; }
bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }


@@ -45,6 +45,7 @@ bool IsAtan2(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node);
bool IsBetainc(const NodeDef& node);
bool IsBiasAdd(const NodeDef& node);
bool IsBiasAddV2(const NodeDef& node);
bool IsBiasAddGrad(const NodeDef& node);
bool IsBitcast(const NodeDef& node);
bool IsBroadcastTo(const NodeDef& node);


@@ -767,16 +767,51 @@ Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
return context->graph_view->GetMutationBuilder()->Apply();
}
Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsBiasAddGrad(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4)) {
Status BiasAddTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
// This transposer handles BiasAdd but not BiasAddV1, since only BiasAdd
// supports different data formats.
DCHECK(IsBiasAddV2(*node->node()));
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
// BiasAdd itself only needs NCHW/NHWC to determine whether C dim is the
// second or the last dim. Therefore, we use the original 4D data format in
// the context to update the node. For the input/output tensor, the
// corresponding 4D or 5D data format is needed.
TF_RETURN_IF_ERROR(UpdateNode(context, node));
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
return context->graph_view->GetMutationBuilder()->Apply();
}
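The comment above hinges on BiasAdd using the format string only to locate the channel dimension, which is why the original 4-D format in the context is enough to update the node itself. A small sketch of that behavior on a rank-5 tensor, assuming TF 2.x eager execution (the shapes are illustrative):

    import tensorflow as tf

    x = tf.zeros([2, 2, 14, 14, 3])    # rank-5 input
    b = tf.constant([1.0, 2.0, 3.0])   # one bias value per channel
    # With 'NHWC' the channel dimension is taken to be the last one, so the bias
    # broadcasts over the last axis; with 'NCHW' it would broadcast over axis 1.
    y = tf.raw_ops.BiasAdd(value=x, bias=b, data_format="NHWC")
    print(y.shape)                 # (2, 2, 14, 14, 3)
    print(y[0, 0, 0, 0].numpy())   # [1. 2. 3.]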
Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsBiasAddGrad(*node->node()));
const int rank = GetFaninPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
// BiasAddGrad itself only needs NCHW/NHWC to determine whether C dim is the
// second or the last dim. Therefore, we use the original 4D data format in
// the context to update the node. For the input tensor, the corresponding 4D
// or 5D data format is needed.
TF_RETURN_IF_ERROR(UpdateNode(context, node));
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
// No need to update output shape, as it is always of shape 1-D with size the
// feature dimension of `out_backprop`, regardless of whether NCHW or NHWC is
@@ -839,7 +874,12 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
Status Conv3DTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsConv3D(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -854,7 +894,12 @@ Status Conv3DTransposer::TransposeNode(TransposeContext* context,
Status Conv3DBackpropFilterTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsConv3DBackpropFilterV2(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -872,7 +917,12 @@ Status Conv3DBackpropFilterTransposer::TransposeNode(
Status Conv3DBackpropInputTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsConv3DBackpropInputV2(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) {
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -1081,8 +1131,9 @@ bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
return false;
}
std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
const TransposeContext& context, const utils::MutableNodeView& node) const {
std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts(
const TransposeContext& context, const utils::MutableNodeView& node,
int rank) const {
std::vector<int> ports;
const int num_regular_fanins = node.NumRegularFanins();
ports.reserve(num_regular_fanins);
@@ -1090,7 +1141,7 @@ std::vector<int> LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts(
const auto& regular_fanin = node.GetRegularFanin(i);
auto* regular_fanin_node = regular_fanin.node_view();
int regular_fanin_port = regular_fanin.index();
if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, 4) &&
if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) &&
((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
IsLayoutOptimizerAddedDstToSrcTranspose(context,
@@ -1124,10 +1175,18 @@ Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
Status AddNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsAddN(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
@@ -1284,7 +1343,12 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsConcat(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
@@ -1297,6 +1361,9 @@ Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
axis_node = n_attr->i();
}
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
@@ -1320,14 +1387,33 @@ Status FillOpTransposer::TransposeNode(TransposeContext* context,
Status IdentityNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsIdentityN(*node->node()));
const auto ports = GetVariadic4DFaninPorts(*context, *node);
if (!ShouldProcess(*context, *node) || ports.empty()) {
const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4);
// Temporarily upgrade the context to collect the 5D fanin ports.
std::vector<int> ports_5d;
{
ScopedDataFormatUpgrader data_format_upgrader(context, 5);
ports_5d = GetVariadicNDFaninPorts(*context, *node, 5);
}
if (!ShouldProcess(*context, *node)) {
return Status::OK();
}
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
if (!ports_4d.empty()) {
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose));
}
if (!ports_5d.empty()) {
ScopedDataFormatUpgrader data_format_upgrader(context, 5);
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose));
}
return context->graph_view->GetMutationBuilder()->Apply();
}
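For IdentityN, 4-D and 5-D fanins are collected separately and the 5-D ones are rewritten under a temporarily upgraded data format. A minimal sketch of the axis permutations the inserted Transpose nodes apply, assuming the usual NHWC/NCHW and NDHWC/NCDHW format pairs (layout_permutation is a hypothetical helper for illustration, not the C++ API):

    def layout_permutation(src_format, dst_format):
      # Permutation that reorders axes from the source layout to the destination.
      return [src_format.index(axis) for axis in dst_format]

    print(layout_permutation("NHWC", "NCHW"))    # 4-D fanins: [0, 3, 1, 2]
    print(layout_permutation("NDHWC", "NCDHW"))  # 5-D fanins: [0, 4, 1, 2, 3]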
@@ -1528,10 +1614,18 @@ Status SelectTransposer::TransposeNode(TransposeContext* context,
Status ShapeTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsShape(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
const int rank = GetFaninPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
@@ -1541,10 +1635,20 @@ Status ShapeTransposer::TransposeNode(TransposeContext* context,
Status ShapeNTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsShapeN(*node->node()));
const auto ports = GetVariadic4DFaninPorts(*context, *node);
// ShapeN requires all input tensors to have the same dimensions. Therefore,
// we simply use the 0th fanin port.
const int rank = GetFaninPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
const auto ports = GetVariadicNDFaninPorts(*context, *node, rank);
if (!ShouldProcess(*context, *node) || ports.empty()) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
TF_RETURN_IF_ERROR(
@@ -1555,11 +1659,19 @@ Status ShapeNTransposer::TransposeNode(TransposeContext* context,
Status SliceTransposer::TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) {
DCHECK(IsSlice(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
!IsFaninPortsDimsNIfConst(*node, {1, 2}, {4}) ||
const int rank = GetFanoutPortRank(*node, 0);
if (rank != 4 && rank != 5) {
return Status::OK();
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
<< "' with op '" << node->GetOp() << "' from data format '"
<< context->src_format << "' to '" << context->dst_format << "'";
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
@@ -1848,18 +1960,17 @@ string GetDeviceName(const VirtualPlacer* virtual_placer, const NodeDef& node) {
bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
static absl::flat_hash_set<string>* default_layout_sensitive_ops =
new absl::flat_hash_set<std::string>(
{"AvgPool", "BiasAdd", "Conv2D", "DepthwiseConv2dNative",
"DepthToSpace", "FusedBatchNorm", "FusedBatchNormV2",
"FusedBatchNormV3", "FusedConv2DBiasActivation", "MaxPool",
"SpaceToDepth"});
{"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace",
"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
"FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"});
return default_layout_sensitive_ops->find(node.op()) !=
default_layout_sensitive_ops->end();
}
bool IsLayoutSensitiveOp(const NodeDef& node) {
return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
IsBiasAddGrad(node) || IsConv2DBackpropFilter(node) ||
IsConv2DBackpropInput(node) ||
IsBiasAddV2(node) || IsBiasAddGrad(node) ||
IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
IsDepthwiseConv2dNativeBackpropFilter(node) ||
IsDepthwiseConv2dNativeBackpropInput(node) ||
IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||


@@ -210,6 +210,14 @@ class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer {
utils::MutableNodeView* node) override;
};
class BiasAddTransposer : public LayoutSensitiveOpTransposer {
public:
explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {}
Status TransposeNode(TransposeContext* context,
utils::MutableNodeView* node) override;
};
class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer {
public:
explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
@@ -319,9 +327,9 @@ class LayoutAgnosticOpTransposer : public Transposer {
bool IsAfterDstToSrcTransform(const TransposeContext& context,
const utils::MutableNodeView& node) const;
std::vector<int> GetVariadic4DFaninPorts(
const TransposeContext& context,
const utils::MutableNodeView& node) const;
std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context,
const utils::MutableNodeView& node,
int rank) const;
};
class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {


@@ -30,6 +30,9 @@ std::shared_ptr<Transposer> TransposerFactory::GetTransposer(
if (IsAvgPoolGrad(node)) {
return GetOrCreateIfNotFound<AvgPoolGradTransposer>("AvgPoolGrad");
}
if (IsBiasAddV2(node)) {
return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
}
if (IsBiasAddGrad(node)) {
return GetOrCreateIfNotFound<BiasAddGradTransposer>("BiasAddGrad");
}


@@ -782,6 +782,45 @@ class LayoutOptimizerTest(test.TestCase):
self.assertIn('concat-2-LayoutOptimizer', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testConcatWithControlDependencyFor5DTensor(self):
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0)
w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0)
strides = [1, 1, 1, 1, 1]
y = gen_nn_ops.conv3d(x, w, strides, 'SAME')
axis = constant_op.constant(4)
var = variables.Variable(3)
assign = state_ops.assign(var, 6)
with ops.control_dependencies([assign]):
concat = array_ops.concat([y, y], axis)
output = array_ops.identity(concat)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = self.evaluate(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
# Four transposes were initially added in the Expand phase of
# LayoutOptimizer; two of them are cancelled out in the Collapse phase.
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_ncdhw_to_ndhwc('concat-0-0', nodes)
self._assert_map_ndhwc_to_ncdhw('concat-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
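The concat-2 assertion above checks that the axis input was rewritten with a DataFormatDimMap rather than a Transpose: a scalar axis has to be remapped to the new axis order, not permuted. A plain-Python illustration of that remapping under the 5-D formats (dim_map is a hypothetical stand-in for the op, not the actual kernel):

    def dim_map(axis, src_format="NDHWC", dst_format="NCDHW"):
      # Axis 4 is the channel axis under NDHWC; after the data is transposed to
      # NCDHW the same logical axis becomes axis 1.
      return dst_format.index(src_format[axis])

    print(dim_map(4))  # 1
    print(dim_map(0))  # 0 (the batch axis stays in place)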
@test_util.deprecated_graph_mode_only
def testFill(self):
if test.is_gpu_available(cuda_only=True):
@@ -1397,107 +1436,167 @@ class LayoutOptimizerTest(test.TestCase):
@test_util.deprecated_graph_mode_only
def testConv3D(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
filters = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0)
strides_val = [1, 1, 1, 1, 1]
x_3d = array_ops.reshape(conv, [-1, 4, 14, 14, 1])
conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'VALID')
output = array_ops.identity(conv3d)
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0)
w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0)
strides = [1, 1, 1, 1, 1]
y = gen_nn_ops.conv3d(x, w, strides, 'SAME')
output = array_ops.identity(y)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_ncdhw_to_ndhwc('Conv3D-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_ncdhw_to_ndhwc('Conv3D-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testConv3DBackpropInput(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
x_3d = array_ops.reshape(conv, [-1, 4, 14, 14, 1])
filters = random_ops.truncated_normal([2, 2, 2, 1, 1], seed=0)
strides_val = [1, 1, 1, 1, 1]
shape = array_ops.shape(x_3d)
conv3d_grad = gen_nn_ops.conv3d_backprop_input_v2(shape, filters, x_3d,
strides_val, 'SAME')
output = array_ops.identity(conv3d_grad)
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
random_seed.set_random_seed(0)
dy = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0)
w = random_ops.truncated_normal([2, 2, 2, 1, 1], seed=0)
strides = [1, 1, 1, 1, 1]
x_shape = array_ops.shape(dy)
dx = gen_nn_ops.conv3d_backprop_input_v2(x_shape, w, dy, strides, 'SAME')
output = array_ops.identity(dx)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_vec_ndhwc_to_ncdhw('Conv3DBackpropInputV2-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropInputV2-2', nodes)
self._assert_trans_ncdhw_to_ndhwc('Conv3DBackpropInputV2-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_vec_ndhwc_to_ncdhw('Conv3DBackpropInputV2-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropInputV2-2', nodes)
self._assert_trans_ncdhw_to_ndhwc('Conv3DBackpropInputV2-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testConv3DBackpropFilter(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([1, 784], seed=0)
conv = _two_layer_model(x)
x_3d = array_ops.reshape(conv, [-1, 4, 14, 14, 1])
filters = random_ops.truncated_normal([2, 2, 2, 1, 1], seed=0)
strides_val = [1, 1, 1, 1, 1]
shape = constant_op.constant([2, 2, 2, 1, 1], shape=[5])
conv3d_grad = gen_nn_ops.conv3d_backprop_filter_v2(
x_3d, shape, x_3d, strides_val, 'SAME')
output = array_ops.identity(conv3d_grad)
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0)
dy = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0)
strides = [1, 1, 1, 1, 1]
w_shape = constant_op.constant([2, 2, 2, 1, 1], shape=[5])
dw = gen_nn_ops.conv3d_backprop_filter_v2(x, w_shape, dy, strides, 'SAME')
output = array_ops.identity(dw)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropFilterV2-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropFilterV2-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropFilterV2-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropFilterV2-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testBiasAddFor5DTensor(self):
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0)
w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0)
b = random_ops.truncated_normal([2], seed=0)
strides = [1, 1, 1, 1, 1]
y = gen_nn_ops.conv3d(x, w, strides, 'SAME')
y = gen_nn_ops.bias_add(y, b, 'NHWC')
output = array_ops.identity(y)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_ncdhw_to_ndhwc('BiasAdd-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testBiasAddGradFor5DTensor(self):
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
random_seed.set_random_seed(0)
dy = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0)
w = random_ops.truncated_normal([2, 2, 2, 1, 1], seed=0)
strides = [1, 1, 1, 1, 1]
dy_shape = array_ops.shape(dy)
dx = gen_nn_ops.conv3d_backprop_input_v2(dy_shape, w, dy, strides, 'SAME')
db = gen_nn_ops.bias_add_grad(dx, 'NHWC')
output = array_ops.identity(db)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
# The output of Conv3DBackpropInputV2 won't be converted back to NDHWC
# because of the BiasAddGrad.
expected_num_transposes = 1
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_vec_ndhwc_to_ncdhw('Conv3DBackpropInputV2-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('Conv3DBackpropInputV2-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
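The asymmetry in the assertions follows from BiasAddGrad producing a 1-D output of length C regardless of the input rank, so only its input needs a layout transpose. A quick sketch, assuming TF 2.x eager execution:

    import tensorflow as tf

    dy = tf.ones([2, 2, 14, 14, 3])  # rank-5 incoming gradient, channels last
    db = tf.raw_ops.BiasAddGrad(out_backprop=dy, data_format="NHWC")
    print(db.shape)    # (3,)
    print(db.numpy())  # [784. 784. 784.], i.e. 2 * 2 * 14 * 14 elements summed per channel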
@test_util.deprecated_graph_mode_only
def testSliceWithNonConstAxis(self):
@@ -1536,6 +1635,44 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_vec_nhwc_to_nchw('Slice-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testSliceWithNonConstAxisFor5DTensor(self):
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
random_seed.set_random_seed(0)
x = random_ops.truncated_normal([2, 2, 14, 14, 1], seed=0)
w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0)
strides = [1, 1, 1, 1, 1]
y = gen_nn_ops.conv3d(x, w, strides, 'SAME')
size = array_ops.placeholder(dtype='int32')
s = array_ops.slice(y, [0, 0, 0, 0, 0], size)
output = array_ops.identity(s)
size_val = [1, 1, 2, 2, 1]
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output, feed_dict={size: size_val})
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(
output, run_metadata=metadata, feed_dict={size: size_val})
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
# Four transposes were initially added in the Expand phase of
# LayoutOptimizer; two of them are cancelled out in the Collapse phase.
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_ncdhw_to_ndhwc('Slice-0-0', nodes)
self._assert_vec_ndhwc_to_ncdhw('Slice-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
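The Slice-2 assertion covers the size input: unlike the data input, which is a 5-D tensor whose axes get transposed, size is a 1-D vector whose entries must be reordered to match the new axis order, hence a DataFormatVecPermute instead of a Transpose. A minimal sketch with the size value used in the test above, assuming TF 2.x eager execution:

    import tensorflow as tf

    size_ndhwc = tf.constant([1, 1, 2, 2, 1], dtype=tf.int32)
    size_ncdhw = tf.raw_ops.DataFormatVecPermute(
        x=size_ndhwc, src_format="NDHWC", dst_format="NCDHW")
    print(size_ncdhw.numpy())  # [1 1 1 2 2]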
@test_util.deprecated_graph_mode_only
def testStridedSliceWithNonConstAxis(self):
if test.is_gpu_available(cuda_only=True):
@@ -1722,6 +1859,79 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_vec_nchw_to_nhwc('ShapeN-0-0', nodes)
self.assertAllEqual(output_val_ref, output_val)
@test_util.deprecated_graph_mode_only
def testShapeNFor5DTensor(self):
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
h = array_ops.placeholder(dtype='float32')
x = array_ops.reshape(h, [-1, 2, 14, 14, 1])
w = random_ops.truncated_normal([2, 2, 2, 1, 2], seed=0)
strides = [1, 1, 1, 1, 1]
y = gen_nn_ops.conv3d(x, w, strides, 'SAME')
shapen = array_ops.shape_n([y, y])
output = math_ops.add(shapen[0], shapen[1])
x_val = [1.7] * 784
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output, feed_dict={h: x_val})
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata, feed_dict={h: x_val})
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
expected_num_transposes = 1
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_vec_ncdhw_to_ndhwc('ShapeN-0-0', nodes)
self._assert_vec_ncdhw_to_ndhwc('ShapeN-1-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testIdentityNFor4DAnd5DTensors(self):
if not test.is_gpu_available(cuda_only=True):
self.skipTest('GPU required')
h = array_ops.placeholder(dtype='float32')
x = array_ops.reshape(h, [-1, 2, 14, 14, 1])
w = random_ops.truncated_normal([2, 2, 2, 1, 4], seed=0)
strides = [1, 1, 1, 1, 1]
y = gen_nn_ops.conv3d(x, w, strides, 'SAME')
x1 = array_ops.reshape(h, [-1, 784])
y1 = _two_layer_model(x1)
outputs = array_ops.identity_n([y1, y])
new_x0 = array_ops.reshape(outputs[0], [-1, 2, 14, 14, 1])
new_x1 = array_ops.reshape(outputs[1], [-1, 2, 14, 14, 1])
output = math_ops.add(new_x0, new_x1)
x_val = [1.7] * 784
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output, feed_dict={h: x_val})
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata, feed_dict={h: x_val})
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
expected_num_transposes = 4
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self._assert_trans_ncdhw_to_ndhwc('IdentityN-1-0', nodes)
self._assert_trans_nchw_to_nhwc('IdentityN-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testShapeNFollowedByNotConvertibleNodeReshape(self):
if test.is_gpu_available(cuda_only=True):