|
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
|
|
|
|
#include "tensorflow/core/grappler/op_types.h"
|
|
|
|
|
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
|
|
|
|
#include "tensorflow/core/grappler/utils.h"
|
|
|
|
|
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
|
|
|
|
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
|
|
|
|
#include "tensorflow/core/platform/logging.h"
|
|
|
|
|
|
|
|
|
@ -60,17 +61,30 @@ struct RemapperContext {
|
|
|
|
|
|
|
|
|
|
// FusedBatchNorm that can be replaced with a cheaper set of primitives.
|
|
|
|
|
struct FusedBatchNorm {
|
|
|
|
|
FusedBatchNorm() = default;
|
|
|
|
|
explicit FusedBatchNorm(const NodeDef* fused_batch_norm)
|
|
|
|
|
: fused_batch_norm(fused_batch_norm) {}
|
|
|
|
|
|
|
|
|
|
const NodeDef* fused_batch_norm = nullptr;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Conv2D node followed by a BiasAdd.
|
|
|
|
|
struct Conv2DWithBiasAdd {
|
|
|
|
|
Conv2DWithBiasAdd() = default;
|
|
|
|
|
Conv2DWithBiasAdd(const NodeDef* conv2d, const NodeDef* bias_add)
|
|
|
|
|
: conv2d(conv2d), bias_add(bias_add) {}
|
|
|
|
|
|
|
|
|
|
const NodeDef* conv2d = nullptr;
|
|
|
|
|
const NodeDef* bias_add = nullptr;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Conv2D node followed by a BiasAdd and Relu.
|
|
|
|
|
struct Conv2DWithBiasAddAndRelu {
|
|
|
|
|
Conv2DWithBiasAddAndRelu() = default;
|
|
|
|
|
Conv2DWithBiasAddAndRelu(const NodeDef* conv2d, const NodeDef* bias_add,
|
|
|
|
|
const NodeDef* relu)
|
|
|
|
|
: conv2d(conv2d), bias_add(bias_add), relu(relu) {}
|
|
|
|
|
|
|
|
|
|
const NodeDef* conv2d = nullptr;
|
|
|
|
|
const NodeDef* bias_add = nullptr;
|
|
|
|
|
const NodeDef* relu = nullptr;
|
|
|
|
@ -78,6 +92,11 @@ struct Conv2DWithBiasAddAndRelu {
|
|
|
|
|
|
|
|
|
|
// Conv2D node followed by a Squeeze and BiasAdd.
|
|
|
|
|
struct Conv2DWithSqueezeAndBiasAdd {
|
|
|
|
|
Conv2DWithSqueezeAndBiasAdd() = default;
|
|
|
|
|
Conv2DWithSqueezeAndBiasAdd(const NodeDef* conv2d, const NodeDef* squeeze,
|
|
|
|
|
const NodeDef* bias_add)
|
|
|
|
|
: conv2d(conv2d), squeeze(squeeze), bias_add(bias_add) {}
|
|
|
|
|
|
|
|
|
|
const NodeDef* conv2d = nullptr;
|
|
|
|
|
const NodeDef* squeeze = nullptr;
|
|
|
|
|
const NodeDef* bias_add = nullptr;
|
|
|
|
@ -85,6 +104,11 @@ struct Conv2DWithSqueezeAndBiasAdd {
|
|
|
|
|
|
|
|
|
|
// Conv2D node followed by a FusedBatchNorm.
|
|
|
|
|
struct Conv2DWithBatchNorm {
|
|
|
|
|
Conv2DWithBatchNorm() = default;
|
|
|
|
|
Conv2DWithBatchNorm(const NodeDef* conv2d, const NodeDef* fused_batch_norm,
|
|
|
|
|
float epsilon = 0.0)
|
|
|
|
|
: conv2d(conv2d), fused_batch_norm(fused_batch_norm), epsilon(epsilon) {}
|
|
|
|
|
|
|
|
|
|
const NodeDef* conv2d = nullptr;
|
|
|
|
|
const NodeDef* fused_batch_norm = nullptr;
|
|
|
|
|
float epsilon = 0.0;
|
|
|
|
@ -92,16 +116,23 @@ struct Conv2DWithBatchNorm {
|
|
|
|
|
|
|
|
|
|
// Conv2D node followed by a FusedBatchNorm and Relu.
|
|
|
|
|
struct Conv2DWithBatchNormAndRelu {
|
|
|
|
|
Conv2DWithBatchNormAndRelu() = default;
|
|
|
|
|
Conv2DWithBatchNormAndRelu(const NodeDef* conv2d,
|
|
|
|
|
const NodeDef* fused_batch_norm,
|
|
|
|
|
const NodeDef* relu, float epsilon = 0.0)
|
|
|
|
|
: conv2d(conv2d),
|
|
|
|
|
fused_batch_norm(fused_batch_norm),
|
|
|
|
|
relu(relu),
|
|
|
|
|
epsilon(epsilon) {}
|
|
|
|
|
|
|
|
|
|
const NodeDef* conv2d = nullptr;
|
|
|
|
|
const NodeDef* fused_batch_norm = nullptr;
|
|
|
|
|
const NodeDef* relu = nullptr;
|
|
|
|
|
float epsilon = 0.0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
bool IsFloatOrDoubleDataType(const NodeDef* node,
|
|
|
|
|
const string& type_attr = "T") {
|
|
|
|
|
DataType dtype = GetDataTypeFromAttr(*node, type_attr);
|
|
|
|
|
return dtype == DT_FLOAT || dtype == DT_DOUBLE;
|
|
|
|
|
bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
|
|
|
|
|
return ctx.nodes_to_preserve.count(node->name()) > 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HaveSameDataType(const NodeDef* lhs, const NodeDef* rhs,
|
|
|
|
@ -119,91 +150,165 @@ bool HasDataType(const NodeDef* node, const DataType& expected,
|
|
|
|
|
return dtype == expected;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
|
|
|
|
|
return ctx.nodes_to_preserve.count(node->name()) > 0;
|
|
|
|
|
bool IsCpuCompatibleDataType(const NodeDef* node,
|
|
|
|
|
const string& type_attr = "T") {
|
|
|
|
|
DataType dtype = GetDataTypeFromAttr(*node, type_attr);
|
|
|
|
|
return dtype == DT_FLOAT || dtype == DT_DOUBLE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* node,
|
|
|
|
|
Conv2DWithBiasAdd* matched) {
|
|
|
|
|
bool IsGpuCompatibleDataType(const NodeDef* node,
|
|
|
|
|
const string& type_attr = "T") {
|
|
|
|
|
DataType dtype = GetDataTypeFromAttr(*node, type_attr);
|
|
|
|
|
return dtype == DT_FLOAT;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsCpuCompatibleDataFormat(const NodeDef* conv2d) {
|
|
|
|
|
DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
|
|
|
|
|
const string& data_format = conv2d->attr().at(kDataFormat).s();
|
|
|
|
|
return data_format == "NHWC";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsGpuCompatibleDataFormat(const NodeDef* conv2d) {
|
|
|
|
|
DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
|
|
|
|
|
const string& data_format = conv2d->attr().at(kDataFormat).s();
|
|
|
|
|
return data_format == "NHWC" || data_format == "NCHW";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsCpuCompatibleConv2D(const NodeDef* conv2d) {
|
|
|
|
|
DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
|
|
|
|
|
return NodeIsOnCpu(conv2d) && IsCpuCompatibleDataType(conv2d) &&
|
|
|
|
|
IsCpuCompatibleDataFormat(conv2d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsGpuCompatibleConv2D(const NodeDef* conv2d) {
|
|
|
|
|
DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
|
|
|
|
|
return NodeIsOnGpu(conv2d) && IsGpuCompatibleDataType(conv2d) &&
|
|
|
|
|
IsGpuCompatibleDataFormat(conv2d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Checks if we can rewrite a pattern to the `_FusedConv2D` on CPU device.
|
|
|
|
|
template <typename Pattern>
|
|
|
|
|
bool IsCpuCompatible(const Pattern& matched) {
|
|
|
|
|
return IsCpuCompatibleConv2D(matched.conv2d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Checks if we can rewrite a pattern to the `_FusedConv2D` on GPU device.
|
|
|
|
|
bool IsGpuCompatible(const RemapperContext& ctx,
|
|
|
|
|
const Conv2DWithBiasAddAndRelu& matched) {
|
|
|
|
|
const std::vector<OpInfo::TensorProperties>& input_props =
|
|
|
|
|
ctx.graph_properties.GetInputProperties(matched.conv2d->name());
|
|
|
|
|
const TensorShapeProto& filter_shape =
|
|
|
|
|
input_props.size() >= 2 ? input_props[1].shape() : TensorShapeProto();
|
|
|
|
|
|
|
|
|
|
// FusedConv2D on GPU with 1x1 convolution is marginally faster than
|
|
|
|
|
// in-graph computation in micro benchmarks (see kernels/conv_ops_test.cc),
|
|
|
|
|
// and significantly slower in large scale benchmarks.
|
|
|
|
|
bool is_spatial_conv = Rank(filter_shape) == 4 && //
|
|
|
|
|
IsKnown(filter_shape.dim(1)) && //
|
|
|
|
|
IsKnown(filter_shape.dim(2)) && //
|
|
|
|
|
filter_shape.dim(1).size() != 1 && //
|
|
|
|
|
filter_shape.dim(2).size() != 1;
|
|
|
|
|
|
|
|
|
|
return is_spatial_conv && IsGpuCompatibleConv2D(matched.conv2d);
|
|
|
|
|
}
|
|
|
|
|
bool IsGpuCompatible(const RemapperContext& ctx,
|
|
|
|
|
const Conv2DWithBiasAdd& matched) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
bool IsGpuCompatible(const RemapperContext& ctx,
|
|
|
|
|
const Conv2DWithSqueezeAndBiasAdd& matched) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns true if the given pattern is supported on the assigned device.
|
|
|
|
|
template <typename Pattern>
|
|
|
|
|
bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
|
|
|
|
|
return IsCpuCompatible(matched) || IsGpuCompatible(ctx, matched);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add,
|
|
|
|
|
Conv2DWithBiasAdd* matched,
|
|
|
|
|
bool check_device_compatible = true) {
|
|
|
|
|
if (!EigenSupportsContractionOutputKernel()) return false;
|
|
|
|
|
|
|
|
|
|
// Root of the pattern must be a BiasAdd.
|
|
|
|
|
if (!node) return false;
|
|
|
|
|
if (!IsBiasAdd(*node)) return false;
|
|
|
|
|
if (!NodeIsOnCpu(node)) return false;
|
|
|
|
|
if (!IsFloatOrDoubleDataType(node)) return false;
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
|
|
|
|
|
if (bias_add == nullptr || !IsBiasAdd(*bias_add) ||
|
|
|
|
|
HasControlFaninOrFanout(ctx.graph_view, bias_add))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// Input to the BiasAdd must be a Conv2D in NHWC format.
|
|
|
|
|
const auto input_port = GraphView::InputPort(node, 0);
|
|
|
|
|
// Input to the BiasAdd must be a Conv2D.
|
|
|
|
|
const auto input_port = GraphView::InputPort(bias_add, 0);
|
|
|
|
|
const auto conv2d = ctx.graph_view.GetRegularFanin(input_port);
|
|
|
|
|
if (!conv2d.node) return false;
|
|
|
|
|
if (!IsConv2D(*conv2d.node)) return false;
|
|
|
|
|
if (conv2d.node->attr().at(kDataFormat).s() != "NHWC") return false;
|
|
|
|
|
if (!NodeIsOnCpu(conv2d.node)) return false;
|
|
|
|
|
if (!HaveSameDataType(node, conv2d.node)) return false;
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false;
|
|
|
|
|
if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false;
|
|
|
|
|
if (IsInPreserveSet(ctx, conv2d.node)) return false;
|
|
|
|
|
|
|
|
|
|
if (!conv2d.node || !IsConv2D(*conv2d.node) ||
|
|
|
|
|
!HaveSameDataType(bias_add, conv2d.node) ||
|
|
|
|
|
HasControlFaninOrFanout(ctx.graph_view, conv2d.node) ||
|
|
|
|
|
!HasSingleFanoutNode(ctx.graph_view, conv2d.node) ||
|
|
|
|
|
IsInPreserveSet(ctx, conv2d.node))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// Check that data type and data format are supported on assigned device.
|
|
|
|
|
const Conv2DWithBiasAdd pattern{conv2d.node, bias_add};
|
|
|
|
|
if (check_device_compatible && !IsDeviceCompatible(ctx, pattern)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// We successfully found a Conv2D+BiasAdd pattern.
|
|
|
|
|
matched->conv2d = conv2d.node;
|
|
|
|
|
matched->bias_add = node;
|
|
|
|
|
*matched = pattern;
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* node,
|
|
|
|
|
bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* relu,
|
|
|
|
|
Conv2DWithBiasAddAndRelu* matched) {
|
|
|
|
|
if (!EigenSupportsContractionOutputKernel()) return false;
|
|
|
|
|
|
|
|
|
|
// Root of the pattern must be a Relu.
|
|
|
|
|
if (!node) return false;
|
|
|
|
|
if (!IsRelu(*node)) return false;
|
|
|
|
|
if (!NodeIsOnCpu(node)) return false;
|
|
|
|
|
if (!IsFloatOrDoubleDataType(node)) return false;
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
|
|
|
|
|
if (!relu || !IsRelu(*relu) || HasControlFaninOrFanout(ctx.graph_view, relu))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// And input to Relu must match Conv2DWithBiasAdd pattern.
|
|
|
|
|
const auto input_port = GraphView::InputPort(node, 0);
|
|
|
|
|
const auto input_port = GraphView::InputPort(relu, 0);
|
|
|
|
|
const auto bias_add = ctx.graph_view.GetRegularFanin(input_port);
|
|
|
|
|
|
|
|
|
|
Conv2DWithBiasAdd base;
|
|
|
|
|
if (!FindConv2DWithBias(ctx, bias_add.node, &base)) return false;
|
|
|
|
|
if (!HasSingleFanoutNode(ctx.graph_view, base.bias_add)) return false;
|
|
|
|
|
if (!HaveSameDataType(node, base.bias_add)) return false;
|
|
|
|
|
if (IsInPreserveSet(ctx, base.bias_add)) return false;
|
|
|
|
|
if (!FindConv2DWithBias(ctx, bias_add.node, &base,
|
|
|
|
|
/*check_device_compatible=*/false) ||
|
|
|
|
|
!HasSingleFanoutNode(ctx.graph_view, base.bias_add) ||
|
|
|
|
|
!HaveSameDataType(relu, base.bias_add) ||
|
|
|
|
|
IsInPreserveSet(ctx, base.bias_add))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// Check that data type and data format are supported on assigned device.
|
|
|
|
|
const Conv2DWithBiasAddAndRelu pattern{base.conv2d, base.bias_add, relu};
|
|
|
|
|
if (!IsDeviceCompatible(ctx, pattern)) return false;
|
|
|
|
|
|
|
|
|
|
// We successfully found a Conv2D+BiasAdd+Relu pattern.
|
|
|
|
|
matched->conv2d = base.conv2d;
|
|
|
|
|
matched->bias_add = base.bias_add;
|
|
|
|
|
matched->relu = node;
|
|
|
|
|
*matched = pattern;
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx,
|
|
|
|
|
const NodeDef* node,
|
|
|
|
|
const NodeDef* bias_add,
|
|
|
|
|
Conv2DWithSqueezeAndBiasAdd* matched) {
|
|
|
|
|
if (!EigenSupportsContractionOutputKernel()) return false;
|
|
|
|
|
|
|
|
|
|
// Root of the pattern must be a BiasAdd.
|
|
|
|
|
if (node == nullptr) return false;
|
|
|
|
|
if (node->op() != "BiasAdd") return false;
|
|
|
|
|
if (!NodeIsOnCpu(node)) return false;
|
|
|
|
|
if (!IsFloatOrDoubleDataType(node)) return false;
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
|
|
|
|
|
if (!bias_add || !IsBiasAdd(*bias_add) ||
|
|
|
|
|
HasControlFaninOrFanout(ctx.graph_view, bias_add))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// Input to the BiasAdd must be a Squeeze.
|
|
|
|
|
const auto bias_input_port = GraphView::InputPort(node, 0);
|
|
|
|
|
const auto bias_input_port = GraphView::InputPort(bias_add, 0);
|
|
|
|
|
const auto squeeze = ctx.graph_view.GetRegularFanin(bias_input_port);
|
|
|
|
|
if (squeeze.node == nullptr) return false;
|
|
|
|
|
if (squeeze.node->op() != "Squeeze") return false;
|
|
|
|
|
if (!NodeIsOnCpu(squeeze.node)) return false;
|
|
|
|
|
if (!HaveSameDataType(node, squeeze.node, "T")) return false;
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, squeeze.node)) return false;
|
|
|
|
|
if (!HasSingleFanoutNode(ctx.graph_view, squeeze.node)) return false;
|
|
|
|
|
if (IsInPreserveSet(ctx, squeeze.node)) return false;
|
|
|
|
|
|
|
|
|
|
if (!squeeze.node || !IsSqueeze(*squeeze.node) ||
|
|
|
|
|
!HaveSameDataType(bias_add, squeeze.node, "T") ||
|
|
|
|
|
HasControlFaninOrFanout(ctx.graph_view, squeeze.node) ||
|
|
|
|
|
!HasSingleFanoutNode(ctx.graph_view, squeeze.node) ||
|
|
|
|
|
IsInPreserveSet(ctx, squeeze.node))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// Squeeze must not squeeze output channel dimension.
|
|
|
|
|
std::vector<int32> dims;
|
|
|
|
@ -212,67 +317,72 @@ bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx,
|
|
|
|
|
if (dim == 3) return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Input to the Squeeze must be a Conv2D in NHWC format.
|
|
|
|
|
// Input to the Squeeze must be a Conv2D.
|
|
|
|
|
const auto squeeze_input_port = GraphView::InputPort(squeeze.node, 0);
|
|
|
|
|
const auto conv2d = ctx.graph_view.GetRegularFanin(squeeze_input_port);
|
|
|
|
|
if (conv2d.node == nullptr) return false;
|
|
|
|
|
if (conv2d.node->op() != "Conv2D") return false;
|
|
|
|
|
if (conv2d.node->attr().at("data_format").s() != "NHWC") return false;
|
|
|
|
|
if (!NodeIsOnCpu(conv2d.node)) return false;
|
|
|
|
|
if (!HaveSameDataType(node, conv2d.node, "T")) return false;
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false;
|
|
|
|
|
if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false;
|
|
|
|
|
if (IsInPreserveSet(ctx, conv2d.node)) return false;
|
|
|
|
|
|
|
|
|
|
if (!conv2d.node || !IsConv2D(*conv2d.node) ||
|
|
|
|
|
!HaveSameDataType(bias_add, conv2d.node, "T") ||
|
|
|
|
|
HasControlFaninOrFanout(ctx.graph_view, conv2d.node) ||
|
|
|
|
|
!HasSingleFanoutNode(ctx.graph_view, conv2d.node) ||
|
|
|
|
|
IsInPreserveSet(ctx, conv2d.node))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// Check that data type and data format are supported on assigned device.
|
|
|
|
|
const Conv2DWithSqueezeAndBiasAdd pattern{conv2d.node, squeeze.node,
|
|
|
|
|
bias_add};
|
|
|
|
|
if (!IsDeviceCompatible(ctx, pattern)) return false;
|
|
|
|
|
|
|
|
|
|
// We successfully found a Conv2D+Squeeze+BiasAdd pattern.
|
|
|
|
|
matched->conv2d = conv2d.node;
|
|
|
|
|
matched->squeeze = squeeze.node;
|
|
|
|
|
matched->bias_add = node;
|
|
|
|
|
*matched = pattern;
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FindConv2DWithBatchNorm(const RemapperContext& ctx, const NodeDef* node,
|
|
|
|
|
bool FindConv2DWithBatchNorm(const RemapperContext& ctx,
|
|
|
|
|
const NodeDef* batch_norm,
|
|
|
|
|
Conv2DWithBatchNorm* matched) {
|
|
|
|
|
if (!EigenSupportsContractionOutputKernel()) return false;
|
|
|
|
|
|
|
|
|
|
// Root of the pattern must be a FusedBatchNorm or a FusedBatchNormV2.
|
|
|
|
|
if (node == nullptr) return false;
|
|
|
|
|
if (!IsFusedBatchNorm(*node)) return false;
|
|
|
|
|
if (!NodeIsOnCpu(node)) return false;
|
|
|
|
|
if (!HasDataType(node, DT_FLOAT)) return false;
|
|
|
|
|
if (!batch_norm || !IsFusedBatchNorm(*batch_norm)) return false;
|
|
|
|
|
|
|
|
|
|
// V2 has a separate data type for the scale/offset/mean/variance inputs.
|
|
|
|
|
if (node->op() == "FusedBatchNormV2" && !HasDataType(node, DT_FLOAT, "U"))
|
|
|
|
|
if (batch_norm->op() == "FusedBatchNormV2" &&
|
|
|
|
|
!HasDataType(batch_norm, DT_FLOAT, "U"))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// Check that batch normalization is in inference mode.
|
|
|
|
|
const auto& attr = node->attr();
|
|
|
|
|
const auto& attr = batch_norm->attr();
|
|
|
|
|
if (attr.count(kIsTraining) > 0 && attr.at(kIsTraining).b()) return false;
|
|
|
|
|
|
|
|
|
|
// Check that only 0th output is consumed by other nodes.
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
|
|
|
|
|
if (HasFanouts(ctx.graph_view, node, 1)) return false; // batch_mean
|
|
|
|
|
if (HasFanouts(ctx.graph_view, node, 2)) return false; // batch_variance
|
|
|
|
|
if (HasFanouts(ctx.graph_view, node, 3)) return false; // reserve_space_1
|
|
|
|
|
if (HasFanouts(ctx.graph_view, node, 4)) return false; // reserve_space_2
|
|
|
|
|
if (HasControlFaninOrFanout(ctx.graph_view, batch_norm) ||
|
|
|
|
|
HasFanouts(ctx.graph_view, batch_norm, 1) || // batch_mean
|
|
|
|
|
HasFanouts(ctx.graph_view, batch_norm, 2) || // batch_variance
|
|
|
|
|
HasFanouts(ctx.graph_view, batch_norm, 3) || // reserve_space_1
|
|
|
|
|
HasFanouts(ctx.graph_view, batch_norm, 4)) // reserve_space_2
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// Input to the FusedBatchNorm must be a Conv2D in NHWC format.
|
|
|
|
|
const auto input_port = GraphView::InputPort(node, 0);
|
|
|
|
|
// Input to the FusedBatchNorm must be a Conv2D.
|
|
|
|
|
const auto input_port = GraphView::InputPort(batch_norm, 0);
|
|
|
|
|
const auto conv2d = ctx.graph_view.GetRegularFanin(input_port);
|
|
|
|
|
if (conv2d.node == nullptr) return false;
|
|
|
|
|
if (!IsConv2D(*conv2d.node)) return false;
|
|
|
|
|
if (conv2d.node->attr().at(kDataFormat).s() != "NHWC") return false;
|
|
|
|
|
if (!NodeIsOnCpu(conv2d.node)) return false;
|
|
|
|
|
if (!HaveSameDataType(node, conv2d.node)) return false;
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false;
|
|
|
|
|
if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false;
|
|
|
|
|
if (IsInPreserveSet(ctx, conv2d.node)) return false;
|
|
|
|
|
|
|
|
|
|
if (!conv2d.node || !IsConv2D(*conv2d.node) || //
|
|
|
|
|
!NodeIsOnCpu(conv2d.node) || //
|
|
|
|
|
!HaveSameDataType(batch_norm, conv2d.node) || //
|
|
|
|
|
!IsCpuCompatibleDataType(conv2d.node) || //
|
|
|
|
|
!IsCpuCompatibleDataFormat(conv2d.node) || //
|
|
|
|
|
HasControlFaninOrFanout(ctx.graph_view, conv2d.node) || //
|
|
|
|
|
!HasSingleFanoutNode(ctx.graph_view, conv2d.node) || //
|
|
|
|
|
IsInPreserveSet(ctx, conv2d.node))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// We successfully found a Conv2D+FusedBatchNorm pattern.
|
|
|
|
|
matched->conv2d = conv2d.node;
|
|
|
|
|
matched->fused_batch_norm = node;
|
|
|
|
|
if (!GetNodeAttr(*node, "epsilon", &matched->epsilon).ok()) return false;
|
|
|
|
|
matched->fused_batch_norm = batch_norm;
|
|
|
|
|
if (!GetNodeAttr(*batch_norm, "epsilon", &matched->epsilon).ok())
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
@ -283,21 +393,19 @@ bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx,
|
|
|
|
|
if (!EigenSupportsContractionOutputKernel()) return false;
|
|
|
|
|
|
|
|
|
|
// Root of the pattern must be a Relu.
|
|
|
|
|
if (node == nullptr) return false;
|
|
|
|
|
if (!IsRelu(*node)) return false;
|
|
|
|
|
if (!NodeIsOnCpu(node)) return false;
|
|
|
|
|
if (!IsFloatOrDoubleDataType(node)) return false;
|
|
|
|
|
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
|
|
|
|
|
if (!node || !IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// And input to Relu must match Conv2DWithBatchNorm pattern.
|
|
|
|
|
const auto input_port = GraphView::InputPort(node, 0);
|
|
|
|
|
const auto batch_norm = ctx.graph_view.GetRegularFanin(input_port);
|
|
|
|
|
|
|
|
|
|
Conv2DWithBatchNorm base;
|
|
|
|
|
if (!FindConv2DWithBatchNorm(ctx, batch_norm.node, &base)) return false;
|
|
|
|
|
if (!HasSingleFanoutNode(ctx.graph_view, base.fused_batch_norm)) return false;
|
|
|
|
|
if (!HaveSameDataType(node, base.fused_batch_norm)) return false;
|
|
|
|
|
if (IsInPreserveSet(ctx, base.fused_batch_norm)) return false;
|
|
|
|
|
if (!FindConv2DWithBatchNorm(ctx, batch_norm.node, &base) ||
|
|
|
|
|
!HasSingleFanoutNode(ctx.graph_view, base.fused_batch_norm) ||
|
|
|
|
|
!HaveSameDataType(node, base.fused_batch_norm) ||
|
|
|
|
|
IsInPreserveSet(ctx, base.fused_batch_norm))
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
// We successfully found a Conv2D+FusedBatchNorm+Relu pattern.
|
|
|
|
|
matched->conv2d = base.conv2d;
|
|
|
|
@ -355,9 +463,7 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d,
|
|
|
|
|
const std::vector<string>& fused_ops = {},
|
|
|
|
|
int num_args = 1, float epsilon = 0.0) {
|
|
|
|
|
void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
|
|
|
|
|
auto* attr = fused_conv2d->mutable_attr();
|
|
|
|
|
auto src_attr = conv2d->attr();
|
|
|
|
|
|
|
|
|
@ -367,53 +473,65 @@ void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d,
|
|
|
|
|
(*attr)["dilations"] = src_attr.at("dilations");
|
|
|
|
|
(*attr)["data_format"] = src_attr.at("data_format");
|
|
|
|
|
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* fused_ops_attr = (*attr)["fused_ops"].mutable_list();
|
|
|
|
|
for (const string& fused_op : fused_ops) {
|
|
|
|
|
fused_ops_attr->add_s(fused_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Writes the fusion-specific attributes onto `fused_conv2d`:
//   "fused_ops" - the list of fused op names given in `fused_ops`,
//   "num_args"  - set from `num_args`,
//   "epsilon"   - set from `epsilon`.
// Each attribute is written exactly once.
void SetFusedConv2DAttributes(
    NodeDef* fused_conv2d, const absl::Span<const absl::string_view> fused_ops,
    int num_args = 1, float epsilon = 0.0) {
  auto* attr = fused_conv2d->mutable_attr();
  SetAttrValue(fused_ops, &(*attr)["fused_ops"]);
  SetAttrValue(num_args, &(*attr)["num_args"]);
  SetAttrValue(epsilon, &(*attr)["epsilon"]);  // required only for BatchNorm
}
|
|
|
|
|
|
|
|
|
|
void AddFusedConv2DNode(
|
|
|
|
|
const Conv2DWithBiasAdd& matched, GraphDef* optimized_graph,
|
|
|
|
|
const RemapperContext& ctx, const Conv2DWithBiasAdd& matched,
|
|
|
|
|
GraphDef* optimized_graph,
|
|
|
|
|
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
|
|
|
|
|
VLOG(2) << "Fuse Conv2D with BiasAdd: bias_add=" << matched.bias_add->name()
|
|
|
|
|
DCHECK(IsDeviceCompatible(ctx, matched))
|
|
|
|
|
<< "Unsupported fused Conv2D pattern";
|
|
|
|
|
|
|
|
|
|
VLOG(2) << "Fuse Conv2D with BiasAdd: "
|
|
|
|
|
<< " bias_add=" << matched.bias_add->name()
|
|
|
|
|
<< " conv2d=" << matched.conv2d->name();
|
|
|
|
|
|
|
|
|
|
NodeDef* fused_conv2d = optimized_graph->add_node();
|
|
|
|
|
fused_conv2d->set_name(matched.bias_add->name());
|
|
|
|
|
fused_conv2d->set_op(kFusedConv2D);
|
|
|
|
|
fused_conv2d->set_device(matched.bias_add->device());
|
|
|
|
|
fused_conv2d->set_name(matched.bias_add->name());
|
|
|
|
|
fused_conv2d->set_device(matched.conv2d->device());
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
|
|
|
|
|
fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias
|
|
|
|
|
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd"});
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
|
|
|
|
|
SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd"});
|
|
|
|
|
|
|
|
|
|
invalidated_nodes->insert(matched.bias_add);
|
|
|
|
|
invalidated_nodes->insert(matched.conv2d);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AddFusedConv2DNode(
|
|
|
|
|
const Conv2DWithBiasAddAndRelu& matched, GraphDef* optimized_graph,
|
|
|
|
|
const RemapperContext& ctx, const Conv2DWithBiasAddAndRelu& matched,
|
|
|
|
|
GraphDef* optimized_graph,
|
|
|
|
|
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
|
|
|
|
|
VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: relu=" << matched.relu->name()
|
|
|
|
|
DCHECK(IsDeviceCompatible(ctx, matched))
|
|
|
|
|
<< "Unsupported fused Conv2D pattern";
|
|
|
|
|
|
|
|
|
|
VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: "
|
|
|
|
|
<< " relu=" << matched.relu->name()
|
|
|
|
|
<< " bias_add=" << matched.bias_add->name()
|
|
|
|
|
<< " conv2d=" << matched.conv2d->name();
|
|
|
|
|
|
|
|
|
|
NodeDef* fused_conv2d = optimized_graph->add_node();
|
|
|
|
|
fused_conv2d->set_name(matched.relu->name());
|
|
|
|
|
fused_conv2d->set_op(kFusedConv2D);
|
|
|
|
|
fused_conv2d->set_device(matched.relu->device());
|
|
|
|
|
fused_conv2d->set_device(matched.conv2d->device());
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
|
|
|
|
|
fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias
|
|
|
|
|
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd", "Relu"});
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
|
|
|
|
|
SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", "Relu"});
|
|
|
|
|
|
|
|
|
|
invalidated_nodes->insert(matched.relu);
|
|
|
|
|
invalidated_nodes->insert(matched.bias_add);
|
|
|
|
@ -421,8 +539,12 @@ void AddFusedConv2DNode(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AddFusedConv2DNode(
|
|
|
|
|
const Conv2DWithSqueezeAndBiasAdd& matched, GraphDef* optimized_graph,
|
|
|
|
|
const RemapperContext& ctx, const Conv2DWithSqueezeAndBiasAdd& matched,
|
|
|
|
|
GraphDef* optimized_graph,
|
|
|
|
|
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
|
|
|
|
|
DCHECK(IsDeviceCompatible(ctx, matched))
|
|
|
|
|
<< "Unsupported fused Conv2D pattern";
|
|
|
|
|
|
|
|
|
|
VLOG(2) << "Fuse Conv2D with Squeeze and BiasAdd: "
|
|
|
|
|
<< " bias_add=" << matched.bias_add->name()
|
|
|
|
|
<< " squeeze=" << matched.squeeze->name()
|
|
|
|
@ -432,13 +554,14 @@ void AddFusedConv2DNode(
|
|
|
|
|
// has single consumer (only the squeeze node).
|
|
|
|
|
NodeDef* fused_conv2d = optimized_graph->add_node();
|
|
|
|
|
fused_conv2d->set_name(matched.conv2d->name());
|
|
|
|
|
fused_conv2d->set_op("_FusedConv2D");
|
|
|
|
|
fused_conv2d->set_op(kFusedConv2D);
|
|
|
|
|
fused_conv2d->set_device(matched.conv2d->device());
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
|
|
|
|
|
fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias
|
|
|
|
|
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd"});
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
|
|
|
|
|
SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd"});
|
|
|
|
|
|
|
|
|
|
// Replace BiasAdd node with a Squeeze.
|
|
|
|
|
NodeDef* remapped_squeeze = optimized_graph->add_node();
|
|
|
|
@ -461,7 +584,7 @@ void AddFusedConv2DNode(
|
|
|
|
|
NodeDef* fused_conv2d = optimized_graph->add_node();
|
|
|
|
|
fused_conv2d->set_name(matched.fused_batch_norm->name());
|
|
|
|
|
fused_conv2d->set_op(kFusedConv2D);
|
|
|
|
|
fused_conv2d->set_device(matched.fused_batch_norm->device());
|
|
|
|
|
fused_conv2d->set_device(matched.conv2d->device());
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
|
|
|
|
|
fused_conv2d->add_input(matched.fused_batch_norm->input(1)); // 2: scale
|
|
|
|
@ -469,8 +592,9 @@ void AddFusedConv2DNode(
|
|
|
|
|
fused_conv2d->add_input(matched.fused_batch_norm->input(3)); // 4: mean
|
|
|
|
|
fused_conv2d->add_input(matched.fused_batch_norm->input(4)); // 5: variance
|
|
|
|
|
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"FusedBatchNorm"},
|
|
|
|
|
/*num_args*/ 4, /*epsilon*/ matched.epsilon);
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
|
|
|
|
|
SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm"},
|
|
|
|
|
/*num_args=*/4, /*epsilon=*/matched.epsilon);
|
|
|
|
|
|
|
|
|
|
invalidated_nodes->insert(matched.fused_batch_norm);
|
|
|
|
|
invalidated_nodes->insert(matched.conv2d);
|
|
|
|
@ -487,7 +611,7 @@ void AddFusedConv2DNode(
|
|
|
|
|
NodeDef* fused_conv2d = optimized_graph->add_node();
|
|
|
|
|
fused_conv2d->set_name(matched.relu->name());
|
|
|
|
|
fused_conv2d->set_op(kFusedConv2D);
|
|
|
|
|
fused_conv2d->set_device(matched.fused_batch_norm->device());
|
|
|
|
|
fused_conv2d->set_device(matched.conv2d->device());
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
|
|
|
|
|
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
|
|
|
|
|
fused_conv2d->add_input(matched.fused_batch_norm->input(1)); // 2: scale
|
|
|
|
@ -495,8 +619,9 @@ void AddFusedConv2DNode(
|
|
|
|
|
fused_conv2d->add_input(matched.fused_batch_norm->input(3)); // 4: mean
|
|
|
|
|
fused_conv2d->add_input(matched.fused_batch_norm->input(4)); // 5: variance
|
|
|
|
|
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"FusedBatchNorm", "Relu"},
|
|
|
|
|
/*num_args*/ 4, /*epsilon*/ matched.epsilon);
|
|
|
|
|
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
|
|
|
|
|
SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm", "Relu"},
|
|
|
|
|
/*num_args=*/4, /*epsilon=*/matched.epsilon);
|
|
|
|
|
|
|
|
|
|
invalidated_nodes->insert(matched.relu);
|
|
|
|
|
invalidated_nodes->insert(matched.fused_batch_norm);
|
|
|
|
@ -680,13 +805,14 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
|
|
|
|
|
|
|
|
|
|
// Remap Conv2D+BiasAdd into the _FusedConv2D.
|
|
|
|
|
if (FindConv2DWithBias(ctx, &node, &conv2d_with_bias)) {
|
|
|
|
|
AddFusedConv2DNode(conv2d_with_bias, optimized_graph, &invalidated_nodes);
|
|
|
|
|
AddFusedConv2DNode(ctx, conv2d_with_bias, optimized_graph,
|
|
|
|
|
&invalidated_nodes);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Remap Conv2D+BiasAdd+Relu into the _FusedConv2D.
|
|
|
|
|
if (FindConv2DWithBiasAndRelu(ctx, &node, &conv2d_with_bias_and_relu)) {
|
|
|
|
|
AddFusedConv2DNode(conv2d_with_bias_and_relu, optimized_graph,
|
|
|
|
|
AddFusedConv2DNode(ctx, conv2d_with_bias_and_relu, optimized_graph,
|
|
|
|
|
&invalidated_nodes);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
@ -694,7 +820,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
|
|
|
|
|
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
|
|
|
|
|
if (FindConv2DWithSqueezeAndBias(ctx, &node,
|
|
|
|
|
&conv2d_with_squeeze_and_bias)) {
|
|
|
|
|
AddFusedConv2DNode(conv2d_with_squeeze_and_bias, optimized_graph,
|
|
|
|
|
AddFusedConv2DNode(ctx, conv2d_with_squeeze_and_bias, optimized_graph,
|
|
|
|
|
&invalidated_nodes);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|