[Grappler] Fuse Relu6 and Elu into Conv2D
PiperOrigin-RevId: 242907340
Parent: 12b9e3ad86
Commit: 2ace7a7d9d
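The remapper pass already rewrites Conv2D+BiasAdd+Relu and Conv2D+FusedBatchNorm+Relu into a single _FusedConv2D node; this change widens the matched activation set to Relu6 and Elu. As an illustration only (it mirrors the updated tests below and is not code from this commit), a graph that now becomes eligible for fusion can be built with the TF C++ ops API roughly like this:

// Illustrative sketch (not part of this commit): builds the Conv2D -> BiasAdd
// -> Elu chain that the remapper can now collapse into one _FusedConv2D node
// with fused_ops = {"BiasAdd", "Elu"}. Relu and Relu6 follow the same pattern.
#include <vector>

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"

tensorflow::GraphDef BuildConv2DBiasEluGraph() {
  using namespace tensorflow;  // for brevity in this sketch
  using ops::Placeholder;

  Scope s = Scope::NewRootScope();

  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
                           Placeholder::Shape({8, 32, 32, 3}));
  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT,
                            Placeholder::Shape({1, 1, 3, 128}));
  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT,
                          Placeholder::Shape({128}));

  std::vector<int> strides = {1, 1, 1, 1};
  auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
  auto elu = ops::Elu(s.WithOpName("activation"), bias_add);
  auto fetch = ops::Identity(s.WithOpName("fetch"), elu);

  GraphDef graph;
  TF_CHECK_OK(s.ToGraphDef(&graph));
  return graph;
}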
tensorflow/core/grappler/op_types.cc

@@ -211,6 +211,8 @@ bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
   return false;
 }
 
+bool IsElu(const NodeDef& node) { return node.op() == "Elu"; }
+
 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
 
 bool IsEnter(const NodeDef& node) {
@@ -412,6 +414,8 @@ bool IsReduction(const NodeDef& node) {
 
 bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }
 
+bool IsRelu6(const NodeDef& node) { return node.op() == "Relu6"; }
+
 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
 
 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }
tensorflow/core/grappler/op_types.h

@@ -66,6 +66,7 @@ bool IsDequeueOp(const NodeDef& node);
 bool IsDiv(const NodeDef& node);
 bool IsDivNoNan(const NodeDef& node);
 bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
+bool IsElu(const NodeDef& node);
 bool IsEluGrad(const NodeDef& node);
 bool IsEnter(const NodeDef& node);
 bool IsEqual(const NodeDef& node);
@@ -134,6 +135,7 @@ bool IsReciprocalGrad(const NodeDef& node);
 bool IsRecv(const NodeDef& node);
 bool IsReduction(const NodeDef& node);
 bool IsRelu(const NodeDef& node);
+bool IsRelu6(const NodeDef& node);
 bool IsRelu6Grad(const NodeDef& node);
 bool IsReluGrad(const NodeDef& node);
 bool IsReshape(const NodeDef& node);
tensorflow/core/grappler/optimizers/remapper.cc

@@ -78,16 +78,16 @@ struct Conv2DWithBiasAdd {
   const NodeDef* bias_add = nullptr;
 };
 
-// Conv2D node followed by a BiasAdd and Relu.
-struct Conv2DWithBiasAddAndRelu {
-  Conv2DWithBiasAddAndRelu() = default;
-  Conv2DWithBiasAddAndRelu(const NodeDef* conv2d, const NodeDef* bias_add,
-                           const NodeDef* relu)
-      : conv2d(conv2d), bias_add(bias_add), relu(relu) {}
+// Conv2D node followed by a BiasAdd and Activation.
+struct Conv2DWithBiasAddAndActivation {
+  Conv2DWithBiasAddAndActivation() = default;
+  Conv2DWithBiasAddAndActivation(const NodeDef* conv2d, const NodeDef* bias_add,
+                                 const NodeDef* activation)
+      : conv2d(conv2d), bias_add(bias_add), activation(activation) {}
 
   const NodeDef* conv2d = nullptr;
   const NodeDef* bias_add = nullptr;
-  const NodeDef* relu = nullptr;
+  const NodeDef* activation = nullptr;
 };
 
 // Conv2D node followed by a Squeeze and BiasAdd.
@@ -114,20 +114,21 @@ struct Conv2DWithBatchNorm {
   float epsilon = 0.0;
 };
 
-// Conv2D node followed by a FusedBatchNorm and Relu.
-struct Conv2DWithBatchNormAndRelu {
-  Conv2DWithBatchNormAndRelu() = default;
-  Conv2DWithBatchNormAndRelu(const NodeDef* conv2d,
-                             const NodeDef* fused_batch_norm,
-                             const NodeDef* relu, float epsilon = 0.0)
+// Conv2D node followed by a FusedBatchNorm and Activation.
+struct Conv2DWithBatchNormAndActivation {
+  Conv2DWithBatchNormAndActivation() = default;
+  Conv2DWithBatchNormAndActivation(const NodeDef* conv2d,
+                                   const NodeDef* fused_batch_norm,
+                                   const NodeDef* activation,
+                                   float epsilon = 0.0)
       : conv2d(conv2d),
         fused_batch_norm(fused_batch_norm),
-        relu(relu),
+        activation(activation),
         epsilon(epsilon) {}
 
   const NodeDef* conv2d = nullptr;
   const NodeDef* fused_batch_norm = nullptr;
-  const NodeDef* relu = nullptr;
+  const NodeDef* activation = nullptr;
   float epsilon = 0.0;
 };
 
@@ -194,7 +195,7 @@ bool IsCpuCompatible(const Pattern& matched) {
 
 // Checks if we can rewrite a pattern to the `_FusedConv2D` on GPU device.
 bool IsGpuCompatible(const RemapperContext& ctx,
-                     const Conv2DWithBiasAddAndRelu& matched) {
+                     const Conv2DWithBiasAddAndActivation& matched) {
   const std::vector<OpInfo::TensorProperties>& input_props =
       ctx.graph_properties.GetInputProperties(matched.conv2d->name());
   const TensorShapeProto& filter_shape =
@@ -226,6 +227,10 @@ bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
   return IsCpuCompatible(matched) || IsGpuCompatible(ctx, matched);
 }
 
+bool IsSupportedActivation(const NodeDef& node) {
+  return IsRelu(node) || IsRelu6(node) || IsElu(node);
+}
+
 bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add,
                         Conv2DWithBiasAdd* matched,
                         bool check_device_compatible = true) {
@@ -259,31 +264,34 @@ bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add,
   return true;
 }
 
-bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* relu,
-                               Conv2DWithBiasAddAndRelu* matched) {
+bool FindConv2DWithBiasAndActivation(const RemapperContext& ctx,
+                                     const NodeDef* activation,
+                                     Conv2DWithBiasAddAndActivation* matched) {
   if (!EigenSupportsContractionOutputKernel()) return false;
 
-  // Root of the pattern must be a Relu.
-  if (!relu || !IsRelu(*relu) || HasControlFaninOrFanout(ctx.graph_view, relu))
+  // Root of the pattern must be an activation node.
+  if (!activation || !IsSupportedActivation(*activation) ||
+      HasControlFaninOrFanout(ctx.graph_view, activation))
     return false;
 
-  // And input to Relu must match Conv2DWithBiasAdd pattern.
-  const auto input_port = GraphView::InputPort(relu, 0);
+  // And input to the activation node must match Conv2DWithBiasAdd pattern.
+  const auto input_port = GraphView::InputPort(activation, 0);
   const auto bias_add = ctx.graph_view.GetRegularFanin(input_port);
 
   Conv2DWithBiasAdd base;
   if (!FindConv2DWithBias(ctx, bias_add.node, &base,
                           /*check_device_compatible=*/false) ||
       !HasSingleFanoutNode(ctx.graph_view, base.bias_add) ||
-      !HaveSameDataType(relu, base.bias_add) ||
+      !HaveSameDataType(activation, base.bias_add) ||
       IsInPreserveSet(ctx, base.bias_add))
     return false;
 
   // Check that data type and data format are supported on assigned device.
-  const Conv2DWithBiasAddAndRelu pattern{base.conv2d, base.bias_add, relu};
+  const Conv2DWithBiasAddAndActivation pattern{base.conv2d, base.bias_add,
+                                               activation};
   if (!IsDeviceCompatible(ctx, pattern)) return false;
 
-  // We successfully found a Conv2D+BiasAdd+Relu pattern.
+  // We successfully found a Conv2D+BiasAdd+Activation pattern.
   *matched = pattern;
 
   return true;
@@ -387,16 +395,17 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx,
   return true;
 }
 
-bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx,
-                                    const NodeDef* node,
-                                    Conv2DWithBatchNormAndRelu* matched) {
+bool FindConv2DWithBatchNormAndActivation(
+    const RemapperContext& ctx, const NodeDef* node,
+    Conv2DWithBatchNormAndActivation* matched) {
   if (!EigenSupportsContractionOutputKernel()) return false;
 
-  // Root of the pattern must be a Relu.
-  if (!node || !IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node))
+  // Root of the pattern must be an activation node.
+  if (!node || !IsSupportedActivation(*node) ||
+      HasControlFaninOrFanout(ctx.graph_view, node))
     return false;
 
-  // And input to Relu must match Conv2DWithBatchNorm pattern.
+  // And input to the activation node must match Conv2DWithBatchNorm pattern.
   const auto input_port = GraphView::InputPort(node, 0);
   const auto batch_norm = ctx.graph_view.GetRegularFanin(input_port);
 
@@ -407,10 +416,10 @@ bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx,
       IsInPreserveSet(ctx, base.fused_batch_norm))
     return false;
 
-  // We successfully found a Conv2D+FusedBatchNorm+Relu pattern.
+  // We successfully found a Conv2D+FusedBatchNorm+Activation pattern.
   matched->conv2d = base.conv2d;
   matched->fused_batch_norm = base.fused_batch_norm;
-  matched->relu = node;
+  matched->activation = node;
   matched->epsilon = base.epsilon;
 
   return true;
@@ -511,19 +520,19 @@ void AddFusedConv2DNode(
 }
 
 void AddFusedConv2DNode(
-    const RemapperContext& ctx, const Conv2DWithBiasAddAndRelu& matched,
+    const RemapperContext& ctx, const Conv2DWithBiasAddAndActivation& matched,
     GraphDef* optimized_graph,
     absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
   DCHECK(IsDeviceCompatible(ctx, matched))
       << "Unsupported fused Conv2D pattern";
 
-  VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: "
-          << " relu=" << matched.relu->name()
+  VLOG(2) << "Fuse Conv2D with BiasAdd and " << matched.activation->op() << ":"
+          << " activation=" << matched.activation->name()
          << " bias_add=" << matched.bias_add->name()
          << " conv2d=" << matched.conv2d->name();
 
   NodeDef* fused_conv2d = optimized_graph->add_node();
-  fused_conv2d->set_name(matched.relu->name());
+  fused_conv2d->set_name(matched.activation->name());
   fused_conv2d->set_op(kFusedConv2D);
   fused_conv2d->set_device(matched.conv2d->device());
   fused_conv2d->add_input(matched.conv2d->input(0));  // 0: input
@@ -531,9 +540,9 @@ void AddFusedConv2DNode(
   fused_conv2d->add_input(matched.bias_add->input(1));  // 2: bias
 
   CopyConv2DAttributes(matched.conv2d, fused_conv2d);
-  SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", "Relu"});
+  SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", matched.activation->op()});
 
-  invalidated_nodes->insert(matched.relu);
+  invalidated_nodes->insert(matched.activation);
   invalidated_nodes->insert(matched.bias_add);
   invalidated_nodes->insert(matched.conv2d);
 }
@@ -601,15 +610,15 @@ void AddFusedConv2DNode(
 }
 
 void AddFusedConv2DNode(
-    const Conv2DWithBatchNormAndRelu& matched, GraphDef* optimized_graph,
+    const Conv2DWithBatchNormAndActivation& matched, GraphDef* optimized_graph,
     absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
-  VLOG(2) << "Fuse Conv2D with BatchNorm and Relu: relu="
-          << matched.relu->name()
+  VLOG(2) << "Fuse Conv2D with BatchNorm and " << matched.activation->op()
+          << ": activation=" << matched.activation->name()
          << " batch_norm=" << matched.fused_batch_norm->name()
          << " conv2d=" << matched.conv2d->name();
 
   NodeDef* fused_conv2d = optimized_graph->add_node();
-  fused_conv2d->set_name(matched.relu->name());
+  fused_conv2d->set_name(matched.activation->name());
   fused_conv2d->set_op(kFusedConv2D);
   fused_conv2d->set_device(matched.conv2d->device());
   fused_conv2d->add_input(matched.conv2d->input(0));  // 0: input
@@ -620,10 +629,11 @@ void AddFusedConv2DNode(
   fused_conv2d->add_input(matched.fused_batch_norm->input(4));  // 5: variance
 
   CopyConv2DAttributes(matched.conv2d, fused_conv2d);
-  SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm", "Relu"},
+  SetFusedConv2DAttributes(fused_conv2d,
+                           {"FusedBatchNorm", matched.activation->op()},
                            /*num_args=*/4, /*epsilon=*/matched.epsilon);
 
-  invalidated_nodes->insert(matched.relu);
+  invalidated_nodes->insert(matched.activation);
   invalidated_nodes->insert(matched.fused_batch_norm);
   invalidated_nodes->insert(matched.conv2d);
 }
@@ -778,9 +788,9 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
   // clang-format off
   FusedBatchNorm fused_batch_norm;
   Conv2DWithBiasAdd conv2d_with_bias;
-  Conv2DWithBiasAddAndRelu conv2d_with_bias_and_relu;
+  Conv2DWithBiasAddAndActivation conv2d_with_bias_and_activation;
   Conv2DWithBatchNorm conv2d_with_batch_norm;
-  Conv2DWithBatchNormAndRelu conv2d_with_batch_norm_and_relu;
+  Conv2DWithBatchNormAndActivation conv2d_with_batch_norm_and_activation;
   Conv2DWithSqueezeAndBiasAdd conv2d_with_squeeze_and_bias;
   // clang-format on
 
@@ -795,7 +805,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
   RemapperContext ctx(topo_sorted_item);
 
   // Skip nodes that were invalidated by a remapper, e.g. do not process BiasAdd
-  // and Relu nodes that were fused into a Conv2D node.
+  // and Activation nodes that were fused into a Conv2D node.
   absl::flat_hash_set<const NodeDef*> invalidated_nodes;
 
   optimized_graph->mutable_node()->Reserve(topo_sorted_item.graph.node_size());
@@ -810,9 +820,10 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
       continue;
     }
 
-    // Remap Conv2D+BiasAdd+Relu into the _FusedConv2D.
-    if (FindConv2DWithBiasAndRelu(ctx, &node, &conv2d_with_bias_and_relu)) {
-      AddFusedConv2DNode(ctx, conv2d_with_bias_and_relu, optimized_graph,
+    // Remap Conv2D+BiasAdd+Activation into the _FusedConv2D.
+    if (FindConv2DWithBiasAndActivation(ctx, &node,
+                                        &conv2d_with_bias_and_activation)) {
+      AddFusedConv2DNode(ctx, conv2d_with_bias_and_activation, optimized_graph,
                          &invalidated_nodes);
       continue;
     }
@@ -835,10 +846,10 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
      continue;
    }
 
-    // Remap Conv2D+FusedBatchNorm+Relu into the _FusedConv2D;
-    if (FindConv2DWithBatchNormAndRelu(ctx, &node,
-                                       &conv2d_with_batch_norm_and_relu)) {
-      AddFusedConv2DNode(conv2d_with_batch_norm_and_relu, optimized_graph,
+    // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
+    if (FindConv2DWithBatchNormAndActivation(
+            ctx, &node, &conv2d_with_batch_norm_and_activation)) {
+      AddFusedConv2DNode(conv2d_with_batch_norm_and_activation, optimized_graph,
                          &invalidated_nodes);
       continue;
     }
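For reference, the IsSupportedActivation predicate added above accepts exactly three ops, which the fused kernel applies element-wise after the convolution plus bias or batch-norm step. A sketch of their standard definitions (reference only, not code from this change):

// Standard element-wise definitions of the activations that _FusedConv2D can
// now apply (reference only).
#include <algorithm>
#include <cmath>

float Relu(float x) { return std::max(x, 0.0f); }
float Relu6(float x) { return std::min(std::max(x, 0.0f), 6.0f); }
float Elu(float x) { return x > 0.0f ? x : std::expm1(x); }  // exp(x) - 1 for x <= 0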
tensorflow/core/grappler/optimizers/remapper_test.cc

@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/remapper.h"
+
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/grappler/devices.h"
@@ -97,7 +98,7 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
     EXPECT_EQ(1, tensors_expected.size());
     auto tensors = EvaluateNodes(output, item.fetch);
     EXPECT_EQ(1, tensors.size());
-    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
   }
 }
 
@@ -164,69 +165,85 @@ TEST_F(RemapperTest, FuseConv2DWithBias) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(RemapperTest, FuseConv2DWithBiasAndRelu) {
+TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
   if (!EigenSupportsContractionOutputKernel()) return;
 
   using ::tensorflow::ops::Placeholder;
 
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-  auto input_shape = Placeholder::Shape({8, 32, 32, 3});
-  auto filter_shape = Placeholder::Shape({1, 1, 3, 128});
-  auto bias_shape = Placeholder::Shape({128});
-
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
-  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
-  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
-
-  std::vector<int> strides = {1, 1, 1, 1};
-  auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
-  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
-  auto relu = ops::Relu(s.WithOpName("relu"), bias_add);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
-
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
-  auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 128});
-  auto bias_t = GenerateRandomTensor<DT_FLOAT>({128});
-
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
-  // Place all nodes on CPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:CPU:0");
-  }
-
-  Remapper optimizer(RewriterConfig::ON);
-  GraphDef output;
-  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
-
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "relu") {
-      EXPECT_EQ("_FusedConv2D", node.op());
-      EXPECT_EQ("input", node.input(0));
-      EXPECT_EQ("filter", node.input(1));
-
-      EXPECT_EQ(1, node.attr().at("num_args").i());
-      EXPECT_EQ("bias", node.input(2));
-
-      const auto fused_ops = node.attr().at("fused_ops").list().s();
-      ASSERT_EQ(2, fused_ops.size());
-      EXPECT_EQ("BiasAdd", fused_ops[0]);
-      EXPECT_EQ("Relu", fused_ops[1]);
-      found++;
-    }
-  }
-  EXPECT_EQ(1, found);
-
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-  EXPECT_EQ(1, tensors_expected.size());
-  EXPECT_EQ(1, tensors.size());
-  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+    auto input_shape = Placeholder::Shape({8, 32, 32, 3});
+    auto filter_shape = Placeholder::Shape({1, 1, 3, 128});
+    auto bias_shape = Placeholder::Shape({128});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+    auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
+    auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+
+    std::vector<int> strides = {1, 1, 1, 1};
+    auto conv =
+        ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
+    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
+
+    ops::Identity fetch = [&]() -> ops::Identity {
+      auto activate = s.WithOpName("activation");
+      auto fetch = s.WithOpName("fetch");
+
+      if (activation == "Relu") {
+        return ops::Identity(fetch, ops::Relu(activate, bias_add));
+      } else if (activation == "Relu6") {
+        return ops::Identity(fetch, ops::Relu6(activate, bias_add));
+      } else if (activation == "Elu") {
+        return ops::Identity(fetch, ops::Elu(activate, bias_add));
+      }
+
+      return ops::Identity(fetch, bias);
+    }();
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
+    auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 128});
+    auto bias_t = GenerateRandomTensor<DT_FLOAT>({128});
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
+    TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::ON);
+    GraphDef output;
+    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "activation") {
+        EXPECT_EQ("_FusedConv2D", node.op());
+        EXPECT_EQ("input", node.input(0));
+        EXPECT_EQ("filter", node.input(1));
+
+        EXPECT_EQ(1, node.attr().at("num_args").i());
+        EXPECT_EQ("bias", node.input(2));
+
+        const auto fused_ops = node.attr().at("fused_ops").list().s();
+        ASSERT_EQ(2, fused_ops.size());
+        EXPECT_EQ("BiasAdd", fused_ops[0]);
+        EXPECT_EQ(activation, fused_ops[1]);
+        found++;
+      }
+    }
+    EXPECT_EQ(1, found);
+
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+    EXPECT_EQ(1, tensors_expected.size());
+    EXPECT_EQ(1, tensors.size());
+    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  }
 }
 
 TEST_F(RemapperTest, FuseConv2DWithBatchNorm) {
@@ -306,83 +323,100 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNorm) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(RemapperTest, FuseConv2DWithBatchNormAndRelu) {
+TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
   if (!EigenSupportsContractionOutputKernel()) return;
 
   using ops::Placeholder;
 
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-  auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
-  auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128});
-  auto scale_shape = ops::Placeholder::Shape({128});
-
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
-  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
-  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, scale_shape);
-  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, scale_shape);
-  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, scale_shape);
-  auto variance = Placeholder(s.WithOpName("variance"), DT_FLOAT, scale_shape);
-
-  std::vector<int> strides = {1, 1, 1, 1};
-  auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
-  ops::FusedBatchNorm::Attrs attrs;
-  attrs = attrs.IsTraining(false);
-  auto batch_norm = ops::FusedBatchNorm(s.WithOpName("batch_norm"), conv, scale,
-                                        offset, mean, variance, attrs);
-  auto relu = ops::Relu(s.WithOpName("relu"), batch_norm.y);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
-
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
-  auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 128});
-  auto scale_t = GenerateRandomTensor<DT_FLOAT>({128});
-  auto offset_t = GenerateRandomTensor<DT_FLOAT>({128});
-  auto mean_t = GenerateRandomTensor<DT_FLOAT>({128});
-  auto variance_t = GenerateRandomTensor<DT_FLOAT>({128});
-
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"input", input_t}, {"filter", filter_t},
-               {"scale", scale_t}, {"offset", offset_t},
-               {"mean", mean_t}, {"variance", variance_t}};
-  TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
-  // Place all nodes on CPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:CPU:0");
-  }
-
-  Remapper optimizer(RewriterConfig::ON);
-  GraphDef output;
-  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
-
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "relu") {
-      EXPECT_EQ("_FusedConv2D", node.op());
-      EXPECT_EQ("input", node.input(0));
-      EXPECT_EQ("filter", node.input(1));
-
-      EXPECT_EQ(4, node.attr().at("num_args").i());
-      EXPECT_EQ("scale", node.input(2));
-      EXPECT_EQ("offset", node.input(3));
-      EXPECT_EQ("mean", node.input(4));
-      EXPECT_EQ("variance", node.input(5));
-
-      const auto fused_ops = node.attr().at("fused_ops").list().s();
-      EXPECT_EQ(2, fused_ops.size());
-      EXPECT_EQ("FusedBatchNorm", fused_ops[0]);
-      EXPECT_EQ("Relu", fused_ops[1]);
-      found++;
-    }
-  }
-  EXPECT_EQ(1, found);
-
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-  EXPECT_EQ(1, tensors_expected.size());
-  EXPECT_EQ(1, tensors.size());
-  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+    auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
+    auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128});
+    auto scale_shape = ops::Placeholder::Shape({128});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+    auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, scale_shape);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, scale_shape);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, scale_shape);
+    auto variance =
+        Placeholder(s.WithOpName("variance"), DT_FLOAT, scale_shape);
+
+    std::vector<int> strides = {1, 1, 1, 1};
+    auto conv =
+        ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
+    ops::FusedBatchNorm::Attrs attrs;
+    attrs = attrs.IsTraining(false);
+    auto batch_norm = ops::FusedBatchNorm(s.WithOpName("batch_norm"), conv,
+                                          scale, offset, mean, variance, attrs);
+
+    ops::Identity fetch = [&]() -> ops::Identity {
+      auto activate = s.WithOpName("activation");
+      auto fetch = s.WithOpName("fetch");
+
+      if (activation == "Relu") {
+        return ops::Identity(fetch, ops::Relu(activate, batch_norm.y));
+      } else if (activation == "Relu6") {
+        return ops::Identity(fetch, ops::Relu6(activate, batch_norm.y));
+      } else if (activation == "Elu") {
+        return ops::Identity(fetch, ops::Elu(activate, batch_norm.y));
+      }
+
+      return ops::Identity(fetch, batch_norm.y);
+    }();
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
+    auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 128});
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>({128});
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>({128});
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>({128});
+    auto variance_t = GenerateRandomTensor<DT_FLOAT>({128});
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t}, {"filter", filter_t},
+                 {"scale", scale_t}, {"offset", offset_t},
+                 {"mean", mean_t}, {"variance", variance_t}};
+    TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::ON);
+    GraphDef output;
+    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "activation") {
+        EXPECT_EQ("_FusedConv2D", node.op());
+        EXPECT_EQ("input", node.input(0));
+        EXPECT_EQ("filter", node.input(1));
+
+        EXPECT_EQ(4, node.attr().at("num_args").i());
+        EXPECT_EQ("scale", node.input(2));
+        EXPECT_EQ("offset", node.input(3));
+        EXPECT_EQ("mean", node.input(4));
+        EXPECT_EQ("variance", node.input(5));
+
+        const auto fused_ops = node.attr().at("fused_ops").list().s();
+        EXPECT_EQ(2, fused_ops.size());
+        EXPECT_EQ("FusedBatchNorm", fused_ops[0]);
+        EXPECT_EQ(activation, fused_ops[1]);
+        found++;
+      }
+    }
+    EXPECT_EQ(1, found);
+
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+    EXPECT_EQ(1, tensors_expected.size());
+    EXPECT_EQ(1, tensors.size());
+    test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  }
 }
 
 TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
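The test assertions above reduce to one property: after optimization, the node that used to be the activation is a _FusedConv2D whose fused_ops attribute lists the folded ops in order, ending with the activation. A small illustrative helper capturing that check (hypothetical, not part of this commit):

// Illustrative only: true if `node` is a _FusedConv2D whose last entry in the
// fused_ops attribute matches the given activation ("Relu", "Relu6" or "Elu").
#include <string>

#include "tensorflow/core/framework/node_def.pb.h"

bool IsFusedConv2DWithActivation(const tensorflow::NodeDef& node,
                                 const std::string& activation) {
  if (node.op() != "_FusedConv2D") return false;
  const auto& attrs = node.attr();
  const auto it = attrs.find("fused_ops");
  if (it == attrs.end()) return false;
  const auto& list = it->second.list();
  return list.s_size() > 0 && list.s(list.s_size() - 1) == activation;
}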
tensorflow/core/grappler/utils/grappler_test.cc

@@ -88,6 +88,7 @@ GrapplerTest::GrapplerTest() {
   cfg->set_layout_optimizer(RewriterConfig::OFF);
   cfg->set_loop_optimization(RewriterConfig::OFF);
   cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
+  cfg->set_remapping(RewriterConfig::OFF);
 }
 
 std::vector<Tensor> GrapplerTest::EvaluateNodes(