[Grappler] Remapper for GPU Conv2d+BiasAdd+Activation

PiperOrigin-RevId: 228214620
Eugene Zhulenev 2019-01-07 12:31:43 -08:00 committed by TensorFlower Gardener
parent 9dfc1ddf24
commit bd2610f638
11 changed files with 353 additions and 152 deletions

View File

@@ -27,6 +27,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

View File

@@ -66,28 +66,27 @@ int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node,
int port) {
const auto output = GraphView::OutputPort(node, port);
const auto fanout = graph_view.GetFanout(output);
return fanout.size() <= 1;
return graph_view.GetFanout(output).size() <= 1;
}
bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port) {
const auto output = GraphView::OutputPort(node, port);
const auto fanout = graph_view.GetFanout(output);
return !fanout.empty();
return !graph_view.GetFanout(output).empty();
}
bool NoControlFanin(const GraphView& graph_view, const NodeDef* node) {
const auto control_port = GraphView::InputPort(node, -1);
return graph_view.GetFanin(control_port).empty();
bool HasControlFanin(const GraphView& graph_view, const NodeDef* node) {
const auto control_port = GraphView::InputPort(node, Graph::kControlSlot);
return !graph_view.GetFanin(control_port).empty();
}
bool NoControlFanout(const GraphView& graph_view, const NodeDef* node) {
const auto control_port = GraphView::OutputPort(node, -1);
return graph_view.GetFanout(control_port).empty();
bool HasControlFanout(const GraphView& graph_view, const NodeDef* node) {
const auto control_port = GraphView::OutputPort(node, Graph::kControlSlot);
return !graph_view.GetFanout(control_port).empty();
}
bool NoControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node) {
return NoControlFanin(graph_view, node) && NoControlFanout(graph_view, node);
bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node) {
return HasControlFanin(graph_view, node) ||
HasControlFanout(graph_view, node);
}
} // end namespace grappler

View File

@@ -369,10 +369,12 @@ bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node,
// Returns true if the node has at least one fanout node at the given
// output port.
bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port = 0);
bool NoControlFanin(const GraphView& graph_view, const NodeDef* node);
bool NoControlFanout(const GraphView& graph_view, const NodeDef* node);
bool NoControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node);
// Returns true if the node has at least one input control dependency.
bool HasControlFanin(const GraphView& graph_view, const NodeDef* node);
// Returns true if the node has at least one output control dependency.
bool HasControlFanout(const GraphView& graph_view, const NodeDef* node);
// Returns true if the node has at least one input or output control dependency.
bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node);
} // end namespace grappler
} // end namespace tensorflow
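
A minimal usage sketch of these predicates (hypothetical helper, assuming the includes and namespaces of the file above; not part of this change):

// Fusing a node rewires its inputs and outputs, so control edges would be
// silently dropped, and extra fanouts would still need the original node.
bool SafeToFuse(const GraphView& graph_view, const NodeDef* node) {
  if (HasControlFaninOrFanout(graph_view, node)) return false;
  return HasSingleFanoutNode(graph_view, node, /*port=*/0);
}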

View File

@@ -700,6 +700,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
"@com_google_absl//absl/container:flat_hash_set",
],
@@ -711,6 +712,7 @@ tf_cuda_cc_test(
deps = [
":remapper",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/platform/logging.h"
@@ -60,17 +61,30 @@ struct RemapperContext {
// FusedBatchNorm that can be replaced with a cheaper set of primitives.
struct FusedBatchNorm {
FusedBatchNorm() = default;
explicit FusedBatchNorm(const NodeDef* fused_batch_norm)
: fused_batch_norm(fused_batch_norm) {}
const NodeDef* fused_batch_norm = nullptr;
};
// Conv2D node followed by a BiasAdd.
struct Conv2DWithBiasAdd {
Conv2DWithBiasAdd() = default;
Conv2DWithBiasAdd(const NodeDef* conv2d, const NodeDef* bias_add)
: conv2d(conv2d), bias_add(bias_add) {}
const NodeDef* conv2d = nullptr;
const NodeDef* bias_add = nullptr;
};
// Conv2D node followed by a BiasAdd and Relu.
struct Conv2DWithBiasAddAndRelu {
Conv2DWithBiasAddAndRelu() = default;
Conv2DWithBiasAddAndRelu(const NodeDef* conv2d, const NodeDef* bias_add,
const NodeDef* relu)
: conv2d(conv2d), bias_add(bias_add), relu(relu) {}
const NodeDef* conv2d = nullptr;
const NodeDef* bias_add = nullptr;
const NodeDef* relu = nullptr;
@@ -78,6 +92,11 @@ struct Conv2DWithBiasAddAndRelu {
// Conv2D node followed by a Squeeze and BiasAdd.
struct Conv2DWithSqueezeAndBiasAdd {
Conv2DWithSqueezeAndBiasAdd() = default;
Conv2DWithSqueezeAndBiasAdd(const NodeDef* conv2d, const NodeDef* squeeze,
const NodeDef* bias_add)
: conv2d(conv2d), squeeze(squeeze), bias_add(bias_add) {}
const NodeDef* conv2d = nullptr;
const NodeDef* squeeze = nullptr;
const NodeDef* bias_add = nullptr;
@@ -85,6 +104,11 @@ struct Conv2DWithSqueezeAndBiasAdd {
// Conv2D node followed by a FusedBatchNorm.
struct Conv2DWithBatchNorm {
Conv2DWithBatchNorm() = default;
Conv2DWithBatchNorm(const NodeDef* conv2d, const NodeDef* fused_batch_norm,
float epsilon = 0.0)
: conv2d(conv2d), fused_batch_norm(fused_batch_norm), epsilon(epsilon) {}
const NodeDef* conv2d = nullptr;
const NodeDef* fused_batch_norm = nullptr;
float epsilon = 0.0;
@@ -92,16 +116,23 @@ struct Conv2DWithBatchNorm {
// Conv2D node followed by a FusedBatchNorm and Relu.
struct Conv2DWithBatchNormAndRelu {
Conv2DWithBatchNormAndRelu() = default;
Conv2DWithBatchNormAndRelu(const NodeDef* conv2d,
const NodeDef* fused_batch_norm,
const NodeDef* relu, float epsilon = 0.0)
: conv2d(conv2d),
fused_batch_norm(fused_batch_norm),
relu(relu),
epsilon(epsilon) {}
const NodeDef* conv2d = nullptr;
const NodeDef* fused_batch_norm = nullptr;
const NodeDef* relu = nullptr;
float epsilon = 0.0;
};
bool IsFloatOrDoubleDataType(const NodeDef* node,
const string& type_attr = "T") {
DataType dtype = GetDataTypeFromAttr(*node, type_attr);
return dtype == DT_FLOAT || dtype == DT_DOUBLE;
bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
return ctx.nodes_to_preserve.count(node->name()) > 0;
}
bool HaveSameDataType(const NodeDef* lhs, const NodeDef* rhs,
@@ -119,91 +150,165 @@ bool HasDataType(const NodeDef* node, const DataType& expected,
return dtype == expected;
}
bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
return ctx.nodes_to_preserve.count(node->name()) > 0;
bool IsCpuCompatibleDataType(const NodeDef* node,
const string& type_attr = "T") {
DataType dtype = GetDataTypeFromAttr(*node, type_attr);
return dtype == DT_FLOAT || dtype == DT_DOUBLE;
}
bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* node,
Conv2DWithBiasAdd* matched) {
bool IsGpuCompatibleDataType(const NodeDef* node,
const string& type_attr = "T") {
DataType dtype = GetDataTypeFromAttr(*node, type_attr);
return dtype == DT_FLOAT;
}
bool IsCpuCompatibleDataFormat(const NodeDef* conv2d) {
DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
const string& data_format = conv2d->attr().at(kDataFormat).s();
return data_format == "NHWC";
}
bool IsGpuCompatibleDataFormat(const NodeDef* conv2d) {
DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
const string& data_format = conv2d->attr().at(kDataFormat).s();
return data_format == "NHWC" || data_format == "NCHW";
}
bool IsCpuCompatibleConv2D(const NodeDef* conv2d) {
DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
return NodeIsOnCpu(conv2d) && IsCpuCompatibleDataType(conv2d) &&
IsCpuCompatibleDataFormat(conv2d);
}
bool IsGpuCompatibleConv2D(const NodeDef* conv2d) {
DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op";
return NodeIsOnGpu(conv2d) && IsGpuCompatibleDataType(conv2d) &&
IsGpuCompatibleDataFormat(conv2d);
}
// Checks if we can rewrite a pattern to the `_FusedConv2D` on CPU device.
template <typename Pattern>
bool IsCpuCompatible(const Pattern& matched) {
return IsCpuCompatibleConv2D(matched.conv2d);
}
// Checks if we can rewrite a pattern to the `_FusedConv2D` on GPU device.
bool IsGpuCompatible(const RemapperContext& ctx,
const Conv2DWithBiasAddAndRelu& matched) {
const std::vector<OpInfo::TensorProperties>& input_props =
ctx.graph_properties.GetInputProperties(matched.conv2d->name());
const TensorShapeProto& filter_shape =
input_props.size() >= 2 ? input_props[1].shape() : TensorShapeProto();
// FusedConv2D on GPU with 1x1 convolution is marginally faster than
// in-graph computation in micro benchmarks (see kernels/conv_ops_test.cc),
// and significantly slower in large scale benchmarks.
bool is_spatial_conv = Rank(filter_shape) == 4 && //
IsKnown(filter_shape.dim(1)) && //
IsKnown(filter_shape.dim(2)) && //
filter_shape.dim(1).size() != 1 && //
filter_shape.dim(2).size() != 1;
return is_spatial_conv && IsGpuCompatibleConv2D(matched.conv2d);
}
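
For intuition, a sketch of what the spatial-filter check accepts and rejects (hypothetical shapes; Rank and IsKnown come from the symbolic_shapes dependency added above):

TensorShapeProto spatial;  // a 3x3 filter: dims {3, 3, 64, 128}
for (int64 d : {3, 3, 64, 128}) spatial.add_dim()->set_size(d);
// Rank == 4, and dim(1) == 3 and dim(2) == 64 are known and != 1, so a
// Conv2D+BiasAdd+Relu pattern with this filter is eligible for GPU fusion.

TensorShapeProto pointwise;  // a 1x1 filter: dims {1, 1, 64, 128}
for (int64 d : {1, 1, 64, 128}) pointwise.add_dim()->set_size(d);
// dim(1) == 1 fails the check: 1x1 fused convolutions are slower at scale.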
bool IsGpuCompatible(const RemapperContext& ctx,
const Conv2DWithBiasAdd& matched) {
return false;
}
bool IsGpuCompatible(const RemapperContext& ctx,
const Conv2DWithSqueezeAndBiasAdd& matched) {
return false;
}
// Returns true if the given pattern is supported on the assigned device.
template <typename Pattern>
bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
return IsCpuCompatible(matched) || IsGpuCompatible(ctx, matched);
}
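
Usage sketch: overload resolution picks the pattern-specific GPU check, so (assuming conv2d_node, bias_add_node and relu_node point at matched NodeDefs):

Conv2DWithBiasAdd bias_only(conv2d_node, bias_add_node);
IsDeviceCompatible(ctx, bias_only);  // GPU overload returns false, so only
                                     // a CPU placement can qualify.
Conv2DWithBiasAddAndRelu with_relu(conv2d_node, bias_add_node, relu_node);
IsDeviceCompatible(ctx, with_relu);  // can be true on GPU, but only for
                                     // spatial (non-1x1) filters.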
bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add,
Conv2DWithBiasAdd* matched,
bool check_device_compatible = true) {
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a BiasAdd.
if (!node) return false;
if (!IsBiasAdd(*node)) return false;
if (!NodeIsOnCpu(node)) return false;
if (!IsFloatOrDoubleDataType(node)) return false;
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
if (bias_add == nullptr || !IsBiasAdd(*bias_add) ||
HasControlFaninOrFanout(ctx.graph_view, bias_add))
return false;
// Input to the BiasAdd must be a Conv2D in NHWC format.
const auto input_port = GraphView::InputPort(node, 0);
// Input to the BiasAdd must be a Conv2D.
const auto input_port = GraphView::InputPort(bias_add, 0);
const auto conv2d = ctx.graph_view.GetRegularFanin(input_port);
if (!conv2d.node) return false;
if (!IsConv2D(*conv2d.node)) return false;
if (conv2d.node->attr().at(kDataFormat).s() != "NHWC") return false;
if (!NodeIsOnCpu(conv2d.node)) return false;
if (!HaveSameDataType(node, conv2d.node)) return false;
if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false;
if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false;
if (IsInPreserveSet(ctx, conv2d.node)) return false;
if (!conv2d.node || !IsConv2D(*conv2d.node) ||
!HaveSameDataType(bias_add, conv2d.node) ||
HasControlFaninOrFanout(ctx.graph_view, conv2d.node) ||
!HasSingleFanoutNode(ctx.graph_view, conv2d.node) ||
IsInPreserveSet(ctx, conv2d.node))
return false;
// Check that data type and data format are supported on assigned device.
const Conv2DWithBiasAdd pattern{conv2d.node, bias_add};
if (check_device_compatible && !IsDeviceCompatible(ctx, pattern)) {
return false;
}
// We successfully found a Conv2D+BiasAdd pattern.
matched->conv2d = conv2d.node;
matched->bias_add = node;
*matched = pattern;
return true;
}
bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* node,
bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* relu,
Conv2DWithBiasAddAndRelu* matched) {
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a Relu.
if (!node) return false;
if (!IsRelu(*node)) return false;
if (!NodeIsOnCpu(node)) return false;
if (!IsFloatOrDoubleDataType(node)) return false;
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
if (!relu || !IsRelu(*relu) || HasControlFaninOrFanout(ctx.graph_view, relu))
return false;
// And input to Relu must match Conv2DWithBiasAdd pattern.
const auto input_port = GraphView::InputPort(node, 0);
const auto input_port = GraphView::InputPort(relu, 0);
const auto bias_add = ctx.graph_view.GetRegularFanin(input_port);
Conv2DWithBiasAdd base;
if (!FindConv2DWithBias(ctx, bias_add.node, &base)) return false;
if (!HasSingleFanoutNode(ctx.graph_view, base.bias_add)) return false;
if (!HaveSameDataType(node, base.bias_add)) return false;
if (IsInPreserveSet(ctx, base.bias_add)) return false;
if (!FindConv2DWithBias(ctx, bias_add.node, &base,
/*check_device_compatible=*/false) ||
!HasSingleFanoutNode(ctx.graph_view, base.bias_add) ||
!HaveSameDataType(relu, base.bias_add) ||
IsInPreserveSet(ctx, base.bias_add))
return false;
// Check that data type and data format are supported on assigned device.
const Conv2DWithBiasAddAndRelu pattern{base.conv2d, base.bias_add, relu};
if (!IsDeviceCompatible(ctx, pattern)) return false;
// We successfully found a Conv2D+BiasAdd+Relu pattern.
matched->conv2d = base.conv2d;
matched->bias_add = base.bias_add;
matched->relu = node;
*matched = pattern;
return true;
}
bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx,
const NodeDef* node,
const NodeDef* bias_add,
Conv2DWithSqueezeAndBiasAdd* matched) {
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a BiasAdd.
if (node == nullptr) return false;
if (node->op() != "BiasAdd") return false;
if (!NodeIsOnCpu(node)) return false;
if (!IsFloatOrDoubleDataType(node)) return false;
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
if (!bias_add || !IsBiasAdd(*bias_add) ||
HasControlFaninOrFanout(ctx.graph_view, bias_add))
return false;
// Input to the BiasAdd must be a Squeeze.
const auto bias_input_port = GraphView::InputPort(node, 0);
const auto bias_input_port = GraphView::InputPort(bias_add, 0);
const auto squeeze = ctx.graph_view.GetRegularFanin(bias_input_port);
if (squeeze.node == nullptr) return false;
if (squeeze.node->op() != "Squeeze") return false;
if (!NodeIsOnCpu(squeeze.node)) return false;
if (!HaveSameDataType(node, squeeze.node, "T")) return false;
if (!NoControlFaninOrFanout(ctx.graph_view, squeeze.node)) return false;
if (!HasSingleFanoutNode(ctx.graph_view, squeeze.node)) return false;
if (IsInPreserveSet(ctx, squeeze.node)) return false;
if (!squeeze.node || !IsSqueeze(*squeeze.node) ||
!HaveSameDataType(bias_add, squeeze.node, "T") ||
HasControlFaninOrFanout(ctx.graph_view, squeeze.node) ||
!HasSingleFanoutNode(ctx.graph_view, squeeze.node) ||
IsInPreserveSet(ctx, squeeze.node))
return false;
// Squeeze must not squeeze the output channel dimension.
std::vector<int32> dims;
@@ -212,67 +317,72 @@ bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx,
if (dim == 3) return false;
}
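
A concrete illustration of the constraint (hypothetical attribute values):

// Squeeze(squeeze_dims = {1, 2}) on NHWC drops the spatial dimensions and
// keeps channels: eligible for fusion.
// Squeeze(squeeze_dims = {3}) would drop the channel dimension that the
// following BiasAdd broadcasts over: rejected.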
// Input to the Squeeze must be a Conv2D in NHWC format.
// Input to the Squeeze must be a Conv2D.
const auto squeeze_input_port = GraphView::InputPort(squeeze.node, 0);
const auto conv2d = ctx.graph_view.GetRegularFanin(squeeze_input_port);
if (conv2d.node == nullptr) return false;
if (conv2d.node->op() != "Conv2D") return false;
if (conv2d.node->attr().at("data_format").s() != "NHWC") return false;
if (!NodeIsOnCpu(conv2d.node)) return false;
if (!HaveSameDataType(node, conv2d.node, "T")) return false;
if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false;
if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false;
if (IsInPreserveSet(ctx, conv2d.node)) return false;
if (!conv2d.node || !IsConv2D(*conv2d.node) ||
!HaveSameDataType(bias_add, conv2d.node, "T") ||
HasControlFaninOrFanout(ctx.graph_view, conv2d.node) ||
!HasSingleFanoutNode(ctx.graph_view, conv2d.node) ||
IsInPreserveSet(ctx, conv2d.node))
return false;
// Check that data type and data format are supported on assigned device.
const Conv2DWithSqueezeAndBiasAdd pattern{conv2d.node, squeeze.node,
bias_add};
if (!IsDeviceCompatible(ctx, pattern)) return false;
// We successfully found a Conv2D+Squeeze+BiasAdd pattern.
matched->conv2d = conv2d.node;
matched->squeeze = squeeze.node;
matched->bias_add = node;
*matched = pattern;
return true;
}
bool FindConv2DWithBatchNorm(const RemapperContext& ctx, const NodeDef* node,
bool FindConv2DWithBatchNorm(const RemapperContext& ctx,
const NodeDef* batch_norm,
Conv2DWithBatchNorm* matched) {
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a FusedBatchNorm or a FusedBatchNormV2.
if (node == nullptr) return false;
if (!IsFusedBatchNorm(*node)) return false;
if (!NodeIsOnCpu(node)) return false;
if (!HasDataType(node, DT_FLOAT)) return false;
if (!batch_norm || !IsFusedBatchNorm(*batch_norm)) return false;
// V2 has a separate data type for the scale/offset/mean/variance inputs.
if (node->op() == "FusedBatchNormV2" && !HasDataType(node, DT_FLOAT, "U"))
if (batch_norm->op() == "FusedBatchNormV2" &&
!HasDataType(batch_norm, DT_FLOAT, "U"))
return false;
// Check that batch normalization is in inference mode.
const auto& attr = node->attr();
const auto& attr = batch_norm->attr();
if (attr.count(kIsTraining) > 0 && attr.at(kIsTraining).b()) return false;
// Check that only the 0th output is consumed by other nodes.
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
if (HasFanouts(ctx.graph_view, node, 1)) return false; // batch_mean
if (HasFanouts(ctx.graph_view, node, 2)) return false; // batch_variance
if (HasFanouts(ctx.graph_view, node, 3)) return false; // reserve_space_1
if (HasFanouts(ctx.graph_view, node, 4)) return false; // reserve_space_2
if (HasControlFaninOrFanout(ctx.graph_view, batch_norm) ||
HasFanouts(ctx.graph_view, batch_norm, 1) || // batch_mean
HasFanouts(ctx.graph_view, batch_norm, 2) || // batch_variance
HasFanouts(ctx.graph_view, batch_norm, 3) || // reserve_space_1
HasFanouts(ctx.graph_view, batch_norm, 4)) // reserve_space_2
return false;
// Input to the FusedBatchNorm must be a Conv2D in NHWC format.
const auto input_port = GraphView::InputPort(node, 0);
// Input to the FusedBatchNorm must be a Conv2D.
const auto input_port = GraphView::InputPort(batch_norm, 0);
const auto conv2d = ctx.graph_view.GetRegularFanin(input_port);
if (conv2d.node == nullptr) return false;
if (!IsConv2D(*conv2d.node)) return false;
if (conv2d.node->attr().at(kDataFormat).s() != "NHWC") return false;
if (!NodeIsOnCpu(conv2d.node)) return false;
if (!HaveSameDataType(node, conv2d.node)) return false;
if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false;
if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false;
if (IsInPreserveSet(ctx, conv2d.node)) return false;
if (!conv2d.node || !IsConv2D(*conv2d.node) || //
!NodeIsOnCpu(conv2d.node) || //
!HaveSameDataType(batch_norm, conv2d.node) || //
!IsCpuCompatibleDataType(conv2d.node) || //
!IsCpuCompatibleDataFormat(conv2d.node) || //
HasControlFaninOrFanout(ctx.graph_view, conv2d.node) || //
!HasSingleFanoutNode(ctx.graph_view, conv2d.node) || //
IsInPreserveSet(ctx, conv2d.node))
return false;
// We successfully found a Conv2D+FusedBatchNorm pattern.
matched->conv2d = conv2d.node;
matched->fused_batch_norm = node;
if (!GetNodeAttr(*node, "epsilon", &matched->epsilon).ok()) return false;
matched->fused_batch_norm = batch_norm;
if (!GetNodeAttr(*batch_norm, "epsilon", &matched->epsilon).ok())
return false;
return true;
}
@@ -283,21 +393,19 @@ bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx,
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a Relu.
if (node == nullptr) return false;
if (!IsRelu(*node)) return false;
if (!NodeIsOnCpu(node)) return false;
if (!IsFloatOrDoubleDataType(node)) return false;
if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false;
if (!node || !IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node))
return false;
// And input to Relu must match Conv2DWithBatchNorm pattern.
const auto input_port = GraphView::InputPort(node, 0);
const auto batch_norm = ctx.graph_view.GetRegularFanin(input_port);
Conv2DWithBatchNorm base;
if (!FindConv2DWithBatchNorm(ctx, batch_norm.node, &base)) return false;
if (!HasSingleFanoutNode(ctx.graph_view, base.fused_batch_norm)) return false;
if (!HaveSameDataType(node, base.fused_batch_norm)) return false;
if (IsInPreserveSet(ctx, base.fused_batch_norm)) return false;
if (!FindConv2DWithBatchNorm(ctx, batch_norm.node, &base) ||
!HasSingleFanoutNode(ctx.graph_view, base.fused_batch_norm) ||
!HaveSameDataType(node, base.fused_batch_norm) ||
IsInPreserveSet(ctx, base.fused_batch_norm))
return false;
// We successfully found a Conv2D+FusedBatchNorm+Relu pattern.
matched->conv2d = base.conv2d;
@@ -355,9 +463,7 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
return true;
}
void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d,
const std::vector<string>& fused_ops = {},
int num_args = 1, float epsilon = 0.0) {
void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
auto* attr = fused_conv2d->mutable_attr();
auto src_attr = conv2d->attr();
@@ -367,53 +473,65 @@ void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d,
(*attr)["dilations"] = src_attr.at("dilations");
(*attr)["data_format"] = src_attr.at("data_format");
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
}
auto* fused_ops_attr = (*attr)["fused_ops"].mutable_list();
for (const string& fused_op : fused_ops) {
fused_ops_attr->add_s(fused_op);
}
void SetFusedConv2DAttributes(
NodeDef* fused_conv2d, const absl::Span<const absl::string_view> fused_ops,
int num_args = 1, float epsilon = 0.0) {
auto* attr = fused_conv2d->mutable_attr();
SetAttrValue(fused_ops, &(*attr)["fused_ops"]);
SetAttrValue(num_args, &(*attr)["num_args"]);
// Required only for FusedBatchNorm.
SetAttrValue(epsilon, &(*attr)["epsilon"]);
SetAttrValue(epsilon, &(*attr)["epsilon"]); // required only for BatchNorm
}
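
As used by the call sites below, a sketch of the resulting node attributes (values taken from this diff):

// SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", "Relu"});
//   -> fused_ops = ["BiasAdd", "Relu"], num_args = 1, epsilon = 0.0
// SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm"},
//                          /*num_args=*/4, /*epsilon=*/matched.epsilon);
//   -> fused_ops = ["FusedBatchNorm"], num_args = 4 (scale, offset, mean,
//      variance), epsilon forwarded to the fused kernel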
void AddFusedConv2DNode(
const Conv2DWithBiasAdd& matched, GraphDef* optimized_graph,
const RemapperContext& ctx, const Conv2DWithBiasAdd& matched,
GraphDef* optimized_graph,
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
VLOG(2) << "Fuse Conv2D with BiasAdd: bias_add=" << matched.bias_add->name()
DCHECK(IsDeviceCompatible(ctx, matched))
<< "Unsupported fused Conv2D pattern";
VLOG(2) << "Fuse Conv2D with BiasAdd: "
<< " bias_add=" << matched.bias_add->name()
<< " conv2d=" << matched.conv2d->name();
NodeDef* fused_conv2d = optimized_graph->add_node();
fused_conv2d->set_name(matched.bias_add->name());
fused_conv2d->set_op(kFusedConv2D);
fused_conv2d->set_device(matched.bias_add->device());
fused_conv2d->set_name(matched.bias_add->name());
fused_conv2d->set_device(matched.conv2d->device());
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd"});
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd"});
invalidated_nodes->insert(matched.bias_add);
invalidated_nodes->insert(matched.conv2d);
}
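
The rewrite performed here, sketched on a toy graph (hypothetical node names):

// Before:  conv = Conv2D(input, filter)
//          out  = BiasAdd(conv, bias)            // node name: "out"
// After:   out  = _FusedConv2D(input, filter, bias,
//                              fused_ops = {"BiasAdd"})
// The fused node reuses the BiasAdd's name, so downstream consumers keep
// their existing input strings; both original nodes are invalidated.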
void AddFusedConv2DNode(
const Conv2DWithBiasAddAndRelu& matched, GraphDef* optimized_graph,
const RemapperContext& ctx, const Conv2DWithBiasAddAndRelu& matched,
GraphDef* optimized_graph,
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: relu=" << matched.relu->name()
DCHECK(IsDeviceCompatible(ctx, matched))
<< "Unsupported fused Conv2D pattern";
VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: "
<< " relu=" << matched.relu->name()
<< " bias_add=" << matched.bias_add->name()
<< " conv2d=" << matched.conv2d->name();
NodeDef* fused_conv2d = optimized_graph->add_node();
fused_conv2d->set_name(matched.relu->name());
fused_conv2d->set_op(kFusedConv2D);
fused_conv2d->set_device(matched.relu->device());
fused_conv2d->set_device(matched.conv2d->device());
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd", "Relu"});
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", "Relu"});
invalidated_nodes->insert(matched.relu);
invalidated_nodes->insert(matched.bias_add);
@@ -421,8 +539,12 @@ void AddFusedConv2DNode(
}
void AddFusedConv2DNode(
const Conv2DWithSqueezeAndBiasAdd& matched, GraphDef* optimized_graph,
const RemapperContext& ctx, const Conv2DWithSqueezeAndBiasAdd& matched,
GraphDef* optimized_graph,
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
DCHECK(IsDeviceCompatible(ctx, matched))
<< "Unsupported fused Conv2D pattern";
VLOG(2) << "Fuse Conv2D with Squeeze and BiasAdd: "
<< " bias_add=" << matched.bias_add->name()
<< " squeeze=" << matched.squeeze->name()
@@ -432,13 +554,14 @@ void AddFusedConv2DNode(
// has single consumer (only the squeeze node).
NodeDef* fused_conv2d = optimized_graph->add_node();
fused_conv2d->set_name(matched.conv2d->name());
fused_conv2d->set_op("_FusedConv2D");
fused_conv2d->set_op(kFusedConv2D);
fused_conv2d->set_device(matched.conv2d->device());
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd"});
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd"});
// Replace BiasAdd node with a Squeeze.
NodeDef* remapped_squeeze = optimized_graph->add_node();
@@ -461,7 +584,7 @@ void AddFusedConv2DNode(
NodeDef* fused_conv2d = optimized_graph->add_node();
fused_conv2d->set_name(matched.fused_batch_norm->name());
fused_conv2d->set_op(kFusedConv2D);
fused_conv2d->set_device(matched.fused_batch_norm->device());
fused_conv2d->set_device(matched.conv2d->device());
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
fused_conv2d->add_input(matched.fused_batch_norm->input(1)); // 2: scale
@@ -469,8 +592,9 @@ void AddFusedConv2DNode(
fused_conv2d->add_input(matched.fused_batch_norm->input(3)); // 4: mean
fused_conv2d->add_input(matched.fused_batch_norm->input(4)); // 5: variance
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"FusedBatchNorm"},
/*num_args*/ 4, /*epsilon*/ matched.epsilon);
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm"},
/*num_args=*/4, /*epsilon=*/matched.epsilon);
invalidated_nodes->insert(matched.fused_batch_norm);
invalidated_nodes->insert(matched.conv2d);
@@ -487,7 +611,7 @@ void AddFusedConv2DNode(
NodeDef* fused_conv2d = optimized_graph->add_node();
fused_conv2d->set_name(matched.relu->name());
fused_conv2d->set_op(kFusedConv2D);
fused_conv2d->set_device(matched.fused_batch_norm->device());
fused_conv2d->set_device(matched.conv2d->device());
fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input
fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter
fused_conv2d->add_input(matched.fused_batch_norm->input(1)); // 2: scale
@@ -495,8 +619,9 @@ void AddFusedConv2DNode(
fused_conv2d->add_input(matched.fused_batch_norm->input(3)); // 4: mean
fused_conv2d->add_input(matched.fused_batch_norm->input(4)); // 5: variance
CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"FusedBatchNorm", "Relu"},
/*num_args*/ 4, /*epsilon*/ matched.epsilon);
CopyConv2DAttributes(matched.conv2d, fused_conv2d);
SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm", "Relu"},
/*num_args=*/4, /*epsilon=*/matched.epsilon);
invalidated_nodes->insert(matched.relu);
invalidated_nodes->insert(matched.fused_batch_norm);
@@ -680,13 +805,14 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
// Remap Conv2D+BiasAdd into the _FusedConv2D.
if (FindConv2DWithBias(ctx, &node, &conv2d_with_bias)) {
AddFusedConv2DNode(conv2d_with_bias, optimized_graph, &invalidated_nodes);
AddFusedConv2DNode(ctx, conv2d_with_bias, optimized_graph,
&invalidated_nodes);
continue;
}
// Remap Conv2D+BiasAdd+Relu into the _FusedConv2D.
if (FindConv2DWithBiasAndRelu(ctx, &node, &conv2d_with_bias_and_relu)) {
AddFusedConv2DNode(conv2d_with_bias_and_relu, optimized_graph,
AddFusedConv2DNode(ctx, conv2d_with_bias_and_relu, optimized_graph,
&invalidated_nodes);
continue;
}
@@ -694,7 +820,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
if (FindConv2DWithSqueezeAndBias(ctx, &node,
&conv2d_with_squeeze_and_bias)) {
AddFusedConv2DNode(conv2d_with_squeeze_and_bias, optimized_graph,
AddFusedConv2DNode(ctx, conv2d_with_squeeze_and_bias, optimized_graph,
&invalidated_nodes);
continue;
}

View File

@@ -20,6 +20,8 @@ limitations under the License.
#include <queue>
#include <vector>
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -166,10 +168,10 @@ string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter) {
if (!name.empty()) {
if (name[0] == '^') {
return strings::StrCat("^", prefix, delimiter, name.substr(1));
return absl::StrCat("^", prefix, delimiter, name.substr(1));
}
}
return strings::StrCat(prefix, delimiter, name);
return absl::StrCat(prefix, delimiter, name);
}
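
Behavior sketch (hypothetical inputs):

// AddPrefixToNodeName("conv", "opt", "/")   -> "opt/conv"
// AddPrefixToNodeName("^conv", "opt", "/")  -> "^opt/conv"
// The '^' control-input marker stays in front, so the result remains a
// valid control dependency string.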
string AddPrefixToNodeName(const string& name, const string& prefix) {
@@ -193,20 +195,26 @@ bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
}
string AsControlDependency(const NodeDef& node) {
return strings::StrCat("^", node.name());
return absl::StrCat("^", node.name());
}
string AsControlDependency(const string& node_name) {
CHECK(!node_name.empty());
return (!node_name.empty() && node_name[0] == '^')
? node_name
: strings::StrCat("^", node_name);
: absl::StrCat("^", node_name);
}
bool NodeIsOnCpu(const NodeDef* node) {
string task, device;
return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
str_util::StartsWith(device, DEVICE_CPU);
absl::StartsWith(device, DEVICE_CPU);
}
bool NodeIsOnGpu(const NodeDef* node) {
string task, device;
return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
absl::StartsWith(device, DEVICE_GPU);
}
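
What the two predicates return for typical device strings (hypothetical values):

// NodeDef node;
// node.set_device("/job:w/device:CPU:0");  NodeIsOnCpu(&node);  // true
// node.set_device("/device:GPU:1");        NodeIsOnGpu(&node);  // true
// node.set_device("");  // unplaced: SplitDeviceName fails, both are false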
int NumOutputs(const NodeDef& node, GraphDef* graph) {

View File

@@ -242,6 +242,9 @@ string AsControlDependency(const string& node);
// Returns true if the node is assigned to run on CPU device.
bool NodeIsOnCpu(const NodeDef* node);
// Returns true if the node is assigned to run on GPU device.
bool NodeIsOnGpu(const NodeDef* node);
// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
int NumOutputs(const NodeDef& node, GraphDef* graph);

View File

@@ -143,6 +143,7 @@ cc_library(
"//tensorflow/core:test",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"@com_google_absl//absl/algorithm:container",
],
)

View File

@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include <memory>
#include "absl/algorithm/container.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
@@ -125,6 +127,31 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
}
}
void GrapplerTest::CompareNodes(const NodeDef& want, const NodeDef& got) const {
EXPECT_EQ(want.name(), got.name());
EXPECT_EQ(want.op(), got.op());
std::vector<string> want_inputs(want.input().begin(), want.input().end());
std::vector<string> got_inputs(got.input().begin(), got.input().end());
EXPECT_EQ(want_inputs, got_inputs);
const auto attr_name = [](const std::pair<const string, AttrValue>& attr) {
return attr.first;
};
std::vector<string> want_attrs;
std::vector<string> got_attrs;
absl::c_transform(want.attr(), std::back_inserter(want_attrs), attr_name);
absl::c_transform(got.attr(), std::back_inserter(got_attrs), attr_name);
absl::c_sort(want_attrs);
absl::c_sort(got_attrs);
EXPECT_EQ(want_attrs, got_attrs);
for (const string& attr : want_attrs) {
EXPECT_TRUE(AreAttrValuesEqual(want.attr().at(attr), got.attr().at(attr)));
}
}
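
Usage sketch inside a test body (hypothetical nodes):

// NodeDef want = ...;       // expected node, e.g. op() == "_FusedConv2D"
// NodeDef got = output.node(i);
// CompareNodes(want, got);  // emits EXPECT failures for any mismatch in
//                           // name, op, inputs, or attribute values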
bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map,
const string& src,
const string& dst, int position) {

View File

@@ -49,13 +49,25 @@ class GrapplerTest : public ::testing::Test {
const std::vector<std::pair<string, AttrValue>>& attributes,
GraphDef* graph) const;
// Checks if two graphs are equal. Both graphs must have the same set of nodes
// with the same inputs and attributes. Nodes can be in a different order.
//
// NOTE: This function uses EXPECT/ASSERT macros to check the equality of
// node properties, and adds all failures to the current test.
void CompareGraphs(GraphDef want, GraphDef got) const;
// Check if node 'src' is directly connected to the input($position) of 'dst'.
// Checks if two nodes have the same name, op, inputs and attributes.
//
// NOTE: This function uses EXPECT/ASSERT macros to check the equality of
// node properties, and adds all failures to the current test.
void CompareNodes(const NodeDef& want, const NodeDef& got) const;
// Checks if node 'src' is directly connected to the input($position) of
// 'dst'.
bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src,
const string& dst, int position = 0);
// Count nodes of the given op-type in a graph.
// Counts nodes of the given op-type in a graph.
int CountOpNodes(const GraphDef& graph, const string& op);
// Gets a random tensor with the given shape.

View File

@@ -1497,6 +1497,26 @@ BM_FusedConv2DWithBatchNormAndRelu(32, 32, 32, 128, 3, 3, 1024, cpu,
"3x3 /b 32");
#if GOOGLE_CUDA
// -------------------------------------------------------------------------- //
// 1x1 Convolution
// -------------------------------------------------------------------------- //
BM_Conv2D(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8");
BM_Conv2D(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16");
BM_Conv2D(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32");
BM_Conv2DWithBiasAndRelu(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8");
BM_Conv2DWithBiasAndRelu(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16");
BM_Conv2DWithBiasAndRelu(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32");
BM_FusedConv2DWithBiasAndRelu(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8");
BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16");
BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32");
// -------------------------------------------------------------------------- //
// 3x3 Convolution
// -------------------------------------------------------------------------- //
BM_Conv2D(8, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 8");
BM_Conv2D(16, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 16");
BM_Conv2D(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");