diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 4d0f02f4d6d..55d642612bd 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -27,6 +27,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index ba9d2eb3218..be9b9c36c71 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -66,28 +66,27 @@ int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node, int port) { const auto output = GraphView::OutputPort(node, port); - const auto fanout = graph_view.GetFanout(output); - return fanout.size() <= 1; + return graph_view.GetFanout(output).size() <= 1; } bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port) { const auto output = GraphView::OutputPort(node, port); - const auto fanout = graph_view.GetFanout(output); - return !fanout.empty(); + return !graph_view.GetFanout(output).empty(); } -bool NoControlFanin(const GraphView& graph_view, const NodeDef* node) { - const auto control_port = GraphView::InputPort(node, -1); - return graph_view.GetFanin(control_port).empty(); +bool HasControlFanin(const GraphView& graph_view, const NodeDef* node) { + const auto control_port = GraphView::InputPort(node, Graph::kControlSlot); + return !graph_view.GetFanin(control_port).empty(); } -bool NoControlFanout(const GraphView& graph_view, const NodeDef* node) { - const auto control_port = GraphView::OutputPort(node, -1); - return graph_view.GetFanout(control_port).empty(); +bool HasControlFanout(const GraphView& graph_view, const NodeDef* node) { + const auto control_port = GraphView::OutputPort(node, Graph::kControlSlot); + return 
!graph_view.GetFanout(control_port).empty(); } -bool NoControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node) { - return NoControlFanin(graph_view, node) && NoControlFanout(graph_view, node); +bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node) { + return HasControlFanin(graph_view, node) || + HasControlFanout(graph_view, node); } } // end namespace grappler diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index a17d17524ae..dc4ab93894c 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -369,10 +369,12 @@ bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node, // Returns true if node has at least one fanout node at given output port. bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port = 0); - -bool NoControlFanin(const GraphView& graph_view, const NodeDef* node); -bool NoControlFanout(const GraphView& graph_view, const NodeDef* node); -bool NoControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node); +// Returns true if the node has at least one input control dependency. +bool HasControlFanin(const GraphView& graph_view, const NodeDef* node); +// Returns true if the node has at least one output control dependency. +bool HasControlFanout(const GraphView& graph_view, const NodeDef* node); +// Returns true if the node has at least one input or output control dependency. 
+bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node); } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index e2b9de7cc50..7e29cee86ac 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -700,6 +700,7 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:symbolic_shapes", "//tensorflow/core/grappler/utils:topological_sort", "@com_google_absl//absl/container:flat_hash_set", ], @@ -711,6 +712,7 @@ tf_cuda_cc_test( deps = [ ":remapper", "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index f0c81f29e68..0869e3b49bf 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/platform/logging.h" @@ -60,17 +61,30 @@ struct RemapperContext { // FusedBatchNorm that can be replaced with a cheaper set of primitives. struct FusedBatchNorm { + FusedBatchNorm() = default; + explicit FusedBatchNorm(const NodeDef* fused_batch_norm) + : fused_batch_norm(fused_batch_norm) {} + const NodeDef* fused_batch_norm = nullptr; }; // Conv2D node followed by a BiasAdd. 
struct Conv2DWithBiasAdd { + Conv2DWithBiasAdd() = default; + Conv2DWithBiasAdd(const NodeDef* conv2d, const NodeDef* bias_add) + : conv2d(conv2d), bias_add(bias_add) {} + const NodeDef* conv2d = nullptr; const NodeDef* bias_add = nullptr; }; // Conv2D node followed by a BiasAdd and Relu. struct Conv2DWithBiasAddAndRelu { + Conv2DWithBiasAddAndRelu() = default; + Conv2DWithBiasAddAndRelu(const NodeDef* conv2d, const NodeDef* bias_add, + const NodeDef* relu) + : conv2d(conv2d), bias_add(bias_add), relu(relu) {} + const NodeDef* conv2d = nullptr; const NodeDef* bias_add = nullptr; const NodeDef* relu = nullptr; @@ -78,6 +92,11 @@ struct Conv2DWithBiasAddAndRelu { // Conv2D node followed by a Squeeze and BiasAdd. struct Conv2DWithSqueezeAndBiasAdd { + Conv2DWithSqueezeAndBiasAdd() = default; + Conv2DWithSqueezeAndBiasAdd(const NodeDef* conv2d, const NodeDef* squeeze, + const NodeDef* bias_add) + : conv2d(conv2d), squeeze(squeeze), bias_add(bias_add) {} + const NodeDef* conv2d = nullptr; const NodeDef* squeeze = nullptr; const NodeDef* bias_add = nullptr; @@ -85,6 +104,11 @@ struct Conv2DWithSqueezeAndBiasAdd { // Conv2D node followed by a FusedBatchNorm. struct Conv2DWithBatchNorm { + Conv2DWithBatchNorm() = default; + Conv2DWithBatchNorm(const NodeDef* conv2d, const NodeDef* fused_batch_norm, + float epsilon = 0.0) + : conv2d(conv2d), fused_batch_norm(fused_batch_norm), epsilon(epsilon) {} + const NodeDef* conv2d = nullptr; const NodeDef* fused_batch_norm = nullptr; float epsilon = 0.0; @@ -92,16 +116,23 @@ struct Conv2DWithBatchNorm { // Conv2D node followed by a FusedBatchNorm and Relu. 
struct Conv2DWithBatchNormAndRelu { + Conv2DWithBatchNormAndRelu() = default; + Conv2DWithBatchNormAndRelu(const NodeDef* conv2d, + const NodeDef* fused_batch_norm, + const NodeDef* relu, float epsilon = 0.0) + : conv2d(conv2d), + fused_batch_norm(fused_batch_norm), + relu(relu), + epsilon(epsilon) {} + const NodeDef* conv2d = nullptr; const NodeDef* fused_batch_norm = nullptr; const NodeDef* relu = nullptr; float epsilon = 0.0; }; -bool IsFloatOrDoubleDataType(const NodeDef* node, - const string& type_attr = "T") { - DataType dtype = GetDataTypeFromAttr(*node, type_attr); - return dtype == DT_FLOAT || dtype == DT_DOUBLE; +bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) { + return ctx.nodes_to_preserve.count(node->name()) > 0; } bool HaveSameDataType(const NodeDef* lhs, const NodeDef* rhs, @@ -119,91 +150,165 @@ bool HasDataType(const NodeDef* node, const DataType& expected, return dtype == expected; } -bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) { - return ctx.nodes_to_preserve.count(node->name()) > 0; +bool IsCpuCompatibleDataType(const NodeDef* node, + const string& type_attr = "T") { + DataType dtype = GetDataTypeFromAttr(*node, type_attr); + return dtype == DT_FLOAT || dtype == DT_DOUBLE; } -bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* node, - Conv2DWithBiasAdd* matched) { +bool IsGpuCompatibleDataType(const NodeDef* node, + const string& type_attr = "T") { + DataType dtype = GetDataTypeFromAttr(*node, type_attr); + return dtype == DT_FLOAT; +} + +bool IsCpuCompatibleDataFormat(const NodeDef* conv2d) { + DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op"; + const string& data_format = conv2d->attr().at(kDataFormat).s(); + return data_format == "NHWC"; +} + +bool IsGpuCompatibleDataFormat(const NodeDef* conv2d) { + DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op"; + const string& data_format = conv2d->attr().at(kDataFormat).s(); + return data_format == "NHWC" || data_format == "NCHW"; +} + 
+bool IsCpuCompatibleConv2D(const NodeDef* conv2d) { + DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op"; + return NodeIsOnCpu(conv2d) && IsCpuCompatibleDataType(conv2d) && + IsCpuCompatibleDataFormat(conv2d); +} + +bool IsGpuCompatibleConv2D(const NodeDef* conv2d) { + DCHECK(IsConv2D(*conv2d)) << "Expected Conv2D op"; + return NodeIsOnGpu(conv2d) && IsGpuCompatibleDataType(conv2d) && + IsGpuCompatibleDataFormat(conv2d); +} + +// Checks if we can rewrite a pattern to the `_FusedConv2D` on CPU device. +template +bool IsCpuCompatible(const Pattern& matched) { + return IsCpuCompatibleConv2D(matched.conv2d); +} + +// Checks if we can rewrite a pattern to the `_FusedConv2D` on GPU device. +bool IsGpuCompatible(const RemapperContext& ctx, + const Conv2DWithBiasAddAndRelu& matched) { + const std::vector& input_props = + ctx.graph_properties.GetInputProperties(matched.conv2d->name()); + const TensorShapeProto& filter_shape = + input_props.size() >= 2 ? input_props[1].shape() : TensorShapeProto(); + + // FusedConv2D on GPU with 1x1 convolution is marginally faster than + // in-graph computation in micro benchmarks (see kernels/conv_ops_test.cc), + // and significantly slower in large scale benchmarks. + bool is_spatial_conv = Rank(filter_shape) == 4 && // + IsKnown(filter_shape.dim(1)) && // + IsKnown(filter_shape.dim(2)) && // + filter_shape.dim(1).size() != 1 && // + filter_shape.dim(2).size() != 1; + + return is_spatial_conv && IsGpuCompatibleConv2D(matched.conv2d); +} +bool IsGpuCompatible(const RemapperContext& ctx, + const Conv2DWithBiasAdd& matched) { + return false; +} +bool IsGpuCompatible(const RemapperContext& ctx, + const Conv2DWithSqueezeAndBiasAdd& matched) { + return false; +} + +// Returns true if the given pattern is supported on the assigned device. 
+template +bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) { + return IsCpuCompatible(matched) || IsGpuCompatible(ctx, matched); +} + +bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add, + Conv2DWithBiasAdd* matched, + bool check_device_compatible = true) { if (!EigenSupportsContractionOutputKernel()) return false; // Root of the pattern must be a BiasAdd. - if (!node) return false; - if (!IsBiasAdd(*node)) return false; - if (!NodeIsOnCpu(node)) return false; - if (!IsFloatOrDoubleDataType(node)) return false; - if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false; + if (bias_add == nullptr || !IsBiasAdd(*bias_add) || + HasControlFaninOrFanout(ctx.graph_view, bias_add)) + return false; - // Input to the BiasAdd must be a Conv2D in NHWC format. - const auto input_port = GraphView::InputPort(node, 0); + // Input to the BiasAdd must be a Conv2D. + const auto input_port = GraphView::InputPort(bias_add, 0); const auto conv2d = ctx.graph_view.GetRegularFanin(input_port); - if (!conv2d.node) return false; - if (!IsConv2D(*conv2d.node)) return false; - if (conv2d.node->attr().at(kDataFormat).s() != "NHWC") return false; - if (!NodeIsOnCpu(conv2d.node)) return false; - if (!HaveSameDataType(node, conv2d.node)) return false; - if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false; - if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false; - if (IsInPreserveSet(ctx, conv2d.node)) return false; + + if (!conv2d.node || !IsConv2D(*conv2d.node) || + !HaveSameDataType(bias_add, conv2d.node) || + HasControlFaninOrFanout(ctx.graph_view, conv2d.node) || + !HasSingleFanoutNode(ctx.graph_view, conv2d.node) || + IsInPreserveSet(ctx, conv2d.node)) + return false; + + // Check that data type and data format are supported on assigned device. 
+ const Conv2DWithBiasAdd pattern{conv2d.node, bias_add}; + if (check_device_compatible && !IsDeviceCompatible(ctx, pattern)) { + return false; + } // We successfully found a Conv2D+BiasAdd pattern. - matched->conv2d = conv2d.node; - matched->bias_add = node; + *matched = pattern; return true; } -bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* node, +bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* relu, Conv2DWithBiasAddAndRelu* matched) { if (!EigenSupportsContractionOutputKernel()) return false; // Root of the pattern must be a Relu. - if (!node) return false; - if (!IsRelu(*node)) return false; - if (!NodeIsOnCpu(node)) return false; - if (!IsFloatOrDoubleDataType(node)) return false; - if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false; + if (!relu || !IsRelu(*relu) || HasControlFaninOrFanout(ctx.graph_view, relu)) + return false; // And input to Relu must match Conv2DWithBiasAdd pattern. - const auto input_port = GraphView::InputPort(node, 0); + const auto input_port = GraphView::InputPort(relu, 0); const auto bias_add = ctx.graph_view.GetRegularFanin(input_port); Conv2DWithBiasAdd base; - if (!FindConv2DWithBias(ctx, bias_add.node, &base)) return false; - if (!HasSingleFanoutNode(ctx.graph_view, base.bias_add)) return false; - if (!HaveSameDataType(node, base.bias_add)) return false; - if (IsInPreserveSet(ctx, base.bias_add)) return false; + if (!FindConv2DWithBias(ctx, bias_add.node, &base, + /*check_device_compatible=*/false) || + !HasSingleFanoutNode(ctx.graph_view, base.bias_add) || + !HaveSameDataType(relu, base.bias_add) || + IsInPreserveSet(ctx, base.bias_add)) + return false; + + // Check that data type and data format are supported on assigned device. + const Conv2DWithBiasAddAndRelu pattern{base.conv2d, base.bias_add, relu}; + if (!IsDeviceCompatible(ctx, pattern)) return false; // We successfully found a Conv2D+BiasAdd+Relu pattern. 
- matched->conv2d = base.conv2d; - matched->bias_add = base.bias_add; - matched->relu = node; + *matched = pattern; return true; } bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx, - const NodeDef* node, + const NodeDef* bias_add, Conv2DWithSqueezeAndBiasAdd* matched) { if (!EigenSupportsContractionOutputKernel()) return false; // Root of the pattern must be a BiasAdd. - if (node == nullptr) return false; - if (node->op() != "BiasAdd") return false; - if (!NodeIsOnCpu(node)) return false; - if (!IsFloatOrDoubleDataType(node)) return false; - if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false; + if (!bias_add || !IsBiasAdd(*bias_add) || + HasControlFaninOrFanout(ctx.graph_view, bias_add)) + return false; // Input to the BiasAdd must be a Squeeze. - const auto bias_input_port = GraphView::InputPort(node, 0); + const auto bias_input_port = GraphView::InputPort(bias_add, 0); const auto squeeze = ctx.graph_view.GetRegularFanin(bias_input_port); - if (squeeze.node == nullptr) return false; - if (squeeze.node->op() != "Squeeze") return false; - if (!NodeIsOnCpu(squeeze.node)) return false; - if (!HaveSameDataType(node, squeeze.node, "T")) return false; - if (!NoControlFaninOrFanout(ctx.graph_view, squeeze.node)) return false; - if (!HasSingleFanoutNode(ctx.graph_view, squeeze.node)) return false; - if (IsInPreserveSet(ctx, squeeze.node)) return false; + + if (!squeeze.node || !IsSqueeze(*squeeze.node) || + !HaveSameDataType(bias_add, squeeze.node, "T") || + HasControlFaninOrFanout(ctx.graph_view, squeeze.node) || + !HasSingleFanoutNode(ctx.graph_view, squeeze.node) || + IsInPreserveSet(ctx, squeeze.node)) + return false; // Squeeze must not squeeze output channel dimension. std::vector dims; @@ -212,67 +317,72 @@ bool FindConv2DWithSqueezeAndBias(const RemapperContext& ctx, if (dim == 3) return false; } - // Input to the Squeeze must be a Conv2D in NHWC format. + // Input to the Squeeze must be a Conv2D. 
const auto squeeze_input_port = GraphView::InputPort(squeeze.node, 0); const auto conv2d = ctx.graph_view.GetRegularFanin(squeeze_input_port); - if (conv2d.node == nullptr) return false; - if (conv2d.node->op() != "Conv2D") return false; - if (conv2d.node->attr().at("data_format").s() != "NHWC") return false; - if (!NodeIsOnCpu(conv2d.node)) return false; - if (!HaveSameDataType(node, conv2d.node, "T")) return false; - if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false; - if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false; - if (IsInPreserveSet(ctx, conv2d.node)) return false; + + if (!conv2d.node || !IsConv2D(*conv2d.node) || + !HaveSameDataType(bias_add, conv2d.node, "T") || + HasControlFaninOrFanout(ctx.graph_view, conv2d.node) || + !HasSingleFanoutNode(ctx.graph_view, conv2d.node) || + IsInPreserveSet(ctx, conv2d.node)) + return false; + + // Check that data type and data format are supported on assigned device. + const Conv2DWithSqueezeAndBiasAdd pattern{conv2d.node, squeeze.node, + bias_add}; + if (!IsDeviceCompatible(ctx, pattern)) return false; // We successfully found a Conv2D+Squeeze+BiasAdd pattern. - matched->conv2d = conv2d.node; - matched->squeeze = squeeze.node; - matched->bias_add = node; + *matched = pattern; return true; } -bool FindConv2DWithBatchNorm(const RemapperContext& ctx, const NodeDef* node, +bool FindConv2DWithBatchNorm(const RemapperContext& ctx, + const NodeDef* batch_norm, Conv2DWithBatchNorm* matched) { if (!EigenSupportsContractionOutputKernel()) return false; // Root of the pattern must be a FusedBatchNorm or a FusedBatchNormV2. - if (node == nullptr) return false; - if (!IsFusedBatchNorm(*node)) return false; - if (!NodeIsOnCpu(node)) return false; - if (!HasDataType(node, DT_FLOAT)) return false; + if (!batch_norm || !IsFusedBatchNorm(*batch_norm)) return false; // V2 has a separate data type for the scale/offset/mean/variance inputs. 
- if (node->op() == "FusedBatchNormV2" && !HasDataType(node, DT_FLOAT, "U")) + if (batch_norm->op() == "FusedBatchNormV2" && + !HasDataType(batch_norm, DT_FLOAT, "U")) return false; // Check that batch normalization is in inference mode. - const auto& attr = node->attr(); + const auto& attr = batch_norm->attr(); if (attr.count(kIsTraining) > 0 && attr.at(kIsTraining).b()) return false; // Check that only 0th output is consumed by other nodes. - if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false; - if (HasFanouts(ctx.graph_view, node, 1)) return false; // batch_mean - if (HasFanouts(ctx.graph_view, node, 2)) return false; // batch_variance - if (HasFanouts(ctx.graph_view, node, 3)) return false; // reserve_space_1 - if (HasFanouts(ctx.graph_view, node, 4)) return false; // reserve_space_2 + if (HasControlFaninOrFanout(ctx.graph_view, batch_norm) || + HasFanouts(ctx.graph_view, batch_norm, 1) || // batch_mean + HasFanouts(ctx.graph_view, batch_norm, 2) || // batch_variance + HasFanouts(ctx.graph_view, batch_norm, 3) || // reserve_space_1 + HasFanouts(ctx.graph_view, batch_norm, 4)) // reserve_space_2 + return false; - // Input to the FusedBatchNorm must be a Conv2D in NHWC format. - const auto input_port = GraphView::InputPort(node, 0); + // Input to the FusedBatchNorm must be a Conv2D. 
+ const auto input_port = GraphView::InputPort(batch_norm, 0); const auto conv2d = ctx.graph_view.GetRegularFanin(input_port); - if (conv2d.node == nullptr) return false; - if (!IsConv2D(*conv2d.node)) return false; - if (conv2d.node->attr().at(kDataFormat).s() != "NHWC") return false; - if (!NodeIsOnCpu(conv2d.node)) return false; - if (!HaveSameDataType(node, conv2d.node)) return false; - if (!NoControlFaninOrFanout(ctx.graph_view, conv2d.node)) return false; - if (!HasSingleFanoutNode(ctx.graph_view, conv2d.node)) return false; - if (IsInPreserveSet(ctx, conv2d.node)) return false; + + if (!conv2d.node || !IsConv2D(*conv2d.node) || // + !NodeIsOnCpu(conv2d.node) || // + !HaveSameDataType(batch_norm, conv2d.node) || // + !IsCpuCompatibleDataType(conv2d.node) || // + !IsCpuCompatibleDataFormat(conv2d.node) || // + HasControlFaninOrFanout(ctx.graph_view, conv2d.node) || // + !HasSingleFanoutNode(ctx.graph_view, conv2d.node) || // + IsInPreserveSet(ctx, conv2d.node)) + return false; // We successfully found a Conv2D+FusedBatchNorm pattern. matched->conv2d = conv2d.node; - matched->fused_batch_norm = node; - if (!GetNodeAttr(*node, "epsilon", &matched->epsilon).ok()) return false; + matched->fused_batch_norm = batch_norm; + if (!GetNodeAttr(*batch_norm, "epsilon", &matched->epsilon).ok()) + return false; return true; } @@ -283,21 +393,19 @@ bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx, if (!EigenSupportsContractionOutputKernel()) return false; // Root of the pattern must be a Relu. - if (node == nullptr) return false; - if (!IsRelu(*node)) return false; - if (!NodeIsOnCpu(node)) return false; - if (!IsFloatOrDoubleDataType(node)) return false; - if (!NoControlFaninOrFanout(ctx.graph_view, node)) return false; + if (!node || !IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node)) + return false; // And input to Relu must match Conv2DWithBatchNorm pattern. 
const auto input_port = GraphView::InputPort(node, 0); const auto batch_norm = ctx.graph_view.GetRegularFanin(input_port); Conv2DWithBatchNorm base; - if (!FindConv2DWithBatchNorm(ctx, batch_norm.node, &base)) return false; - if (!HasSingleFanoutNode(ctx.graph_view, base.fused_batch_norm)) return false; - if (!HaveSameDataType(node, base.fused_batch_norm)) return false; - if (IsInPreserveSet(ctx, base.fused_batch_norm)) return false; + if (!FindConv2DWithBatchNorm(ctx, batch_norm.node, &base) || + !HasSingleFanoutNode(ctx.graph_view, base.fused_batch_norm) || + !HaveSameDataType(node, base.fused_batch_norm) || + IsInPreserveSet(ctx, base.fused_batch_norm)) + return false; // We successfully found a Conv2D+FusedBatchNorm+Relu pattern. matched->conv2d = base.conv2d; @@ -355,9 +463,7 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node, return true; } -void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d, - const std::vector& fused_ops = {}, - int num_args = 1, float epsilon = 0.0) { +void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) { auto* attr = fused_conv2d->mutable_attr(); auto src_attr = conv2d->attr(); @@ -367,53 +473,65 @@ void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d, (*attr)["dilations"] = src_attr.at("dilations"); (*attr)["data_format"] = src_attr.at("data_format"); (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu"); +} - auto* fused_ops_attr = (*attr)["fused_ops"].mutable_list(); - for (const string& fused_op : fused_ops) { - fused_ops_attr->add_s(fused_op); - } - +void SetFusedConv2DAttributes( + NodeDef* fused_conv2d, const absl::Span fused_ops, + int num_args = 1, float epsilon = 0.0) { + auto* attr = fused_conv2d->mutable_attr(); + SetAttrValue(fused_ops, &(*attr)["fused_ops"]); SetAttrValue(num_args, &(*attr)["num_args"]); - // Required only for FusedBatchNorm. 
- SetAttrValue(epsilon, &(*attr)["epsilon"]); + SetAttrValue(epsilon, &(*attr)["epsilon"]); // required only for BatchNorm } void AddFusedConv2DNode( - const Conv2DWithBiasAdd& matched, GraphDef* optimized_graph, + const RemapperContext& ctx, const Conv2DWithBiasAdd& matched, + GraphDef* optimized_graph, absl::flat_hash_set* invalidated_nodes) { - VLOG(2) << "Fuse Conv2D with BiasAdd: bias_add=" << matched.bias_add->name() + DCHECK(IsDeviceCompatible(ctx, matched)) + << "Unsupported fused Conv2D pattern"; + + VLOG(2) << "Fuse Conv2D with BiasAdd: " + << " bias_add=" << matched.bias_add->name() << " conv2d=" << matched.conv2d->name(); NodeDef* fused_conv2d = optimized_graph->add_node(); - fused_conv2d->set_name(matched.bias_add->name()); fused_conv2d->set_op(kFusedConv2D); - fused_conv2d->set_device(matched.bias_add->device()); + fused_conv2d->set_name(matched.bias_add->name()); + fused_conv2d->set_device(matched.conv2d->device()); fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias - CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd"}); + CopyConv2DAttributes(matched.conv2d, fused_conv2d); + SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd"}); invalidated_nodes->insert(matched.bias_add); invalidated_nodes->insert(matched.conv2d); } void AddFusedConv2DNode( - const Conv2DWithBiasAddAndRelu& matched, GraphDef* optimized_graph, + const RemapperContext& ctx, const Conv2DWithBiasAddAndRelu& matched, + GraphDef* optimized_graph, absl::flat_hash_set* invalidated_nodes) { - VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: relu=" << matched.relu->name() + DCHECK(IsDeviceCompatible(ctx, matched)) + << "Unsupported fused Conv2D pattern"; + + VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: " + << " relu=" << matched.relu->name() << " bias_add=" << matched.bias_add->name() << " conv2d=" << matched.conv2d->name(); NodeDef* 
fused_conv2d = optimized_graph->add_node(); fused_conv2d->set_name(matched.relu->name()); fused_conv2d->set_op(kFusedConv2D); - fused_conv2d->set_device(matched.relu->device()); + fused_conv2d->set_device(matched.conv2d->device()); fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias - CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd", "Relu"}); + CopyConv2DAttributes(matched.conv2d, fused_conv2d); + SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", "Relu"}); invalidated_nodes->insert(matched.relu); invalidated_nodes->insert(matched.bias_add); @@ -421,8 +539,12 @@ void AddFusedConv2DNode( } void AddFusedConv2DNode( - const Conv2DWithSqueezeAndBiasAdd& matched, GraphDef* optimized_graph, + const RemapperContext& ctx, const Conv2DWithSqueezeAndBiasAdd& matched, + GraphDef* optimized_graph, absl::flat_hash_set* invalidated_nodes) { + DCHECK(IsDeviceCompatible(ctx, matched)) + << "Unsupported fused Conv2D pattern"; + VLOG(2) << "Fuse Conv2D with Squeeze and BiasAdd: " << " bias_add=" << matched.bias_add->name() << " squeeze=" << matched.squeeze->name() @@ -432,13 +554,14 @@ void AddFusedConv2DNode( // has single consumer (only the squeeze node). NodeDef* fused_conv2d = optimized_graph->add_node(); fused_conv2d->set_name(matched.conv2d->name()); - fused_conv2d->set_op("_FusedConv2D"); + fused_conv2d->set_op(kFusedConv2D); fused_conv2d->set_device(matched.conv2d->device()); fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias - CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"BiasAdd"}); + CopyConv2DAttributes(matched.conv2d, fused_conv2d); + SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd"}); // Replace BiasAdd node with a Squeeze. 
NodeDef* remapped_squeeze = optimized_graph->add_node(); @@ -461,7 +584,7 @@ void AddFusedConv2DNode( NodeDef* fused_conv2d = optimized_graph->add_node(); fused_conv2d->set_name(matched.fused_batch_norm->name()); fused_conv2d->set_op(kFusedConv2D); - fused_conv2d->set_device(matched.fused_batch_norm->device()); + fused_conv2d->set_device(matched.conv2d->device()); fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter fused_conv2d->add_input(matched.fused_batch_norm->input(1)); // 2: scale @@ -469,8 +592,9 @@ void AddFusedConv2DNode( fused_conv2d->add_input(matched.fused_batch_norm->input(3)); // 4: mean fused_conv2d->add_input(matched.fused_batch_norm->input(4)); // 5: variance - CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"FusedBatchNorm"}, - /*num_args*/ 4, /*epsilon*/ matched.epsilon); + CopyConv2DAttributes(matched.conv2d, fused_conv2d); + SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm"}, + /*num_args=*/4, /*epsilon=*/matched.epsilon); invalidated_nodes->insert(matched.fused_batch_norm); invalidated_nodes->insert(matched.conv2d); @@ -487,7 +611,7 @@ void AddFusedConv2DNode( NodeDef* fused_conv2d = optimized_graph->add_node(); fused_conv2d->set_name(matched.relu->name()); fused_conv2d->set_op(kFusedConv2D); - fused_conv2d->set_device(matched.fused_batch_norm->device()); + fused_conv2d->set_device(matched.conv2d->device()); fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input fused_conv2d->add_input(matched.conv2d->input(1)); // 1: filter fused_conv2d->add_input(matched.fused_batch_norm->input(1)); // 2: scale @@ -495,8 +619,9 @@ void AddFusedConv2DNode( fused_conv2d->add_input(matched.fused_batch_norm->input(3)); // 4: mean fused_conv2d->add_input(matched.fused_batch_norm->input(4)); // 5: variance - CopyConv2DAttributes(matched.conv2d, fused_conv2d, {"FusedBatchNorm", "Relu"}, - /*num_args*/ 4, /*epsilon*/ matched.epsilon); + CopyConv2DAttributes(matched.conv2d, 
fused_conv2d); + SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm", "Relu"}, + /*num_args=*/4, /*epsilon=*/matched.epsilon); invalidated_nodes->insert(matched.relu); invalidated_nodes->insert(matched.fused_batch_norm); @@ -680,13 +805,14 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, // Remap Conv2D+BiasAdd into the _FusedConv2D. if (FindConv2DWithBias(ctx, &node, &conv2d_with_bias)) { - AddFusedConv2DNode(conv2d_with_bias, optimized_graph, &invalidated_nodes); + AddFusedConv2DNode(ctx, conv2d_with_bias, optimized_graph, + &invalidated_nodes); continue; } // Remap Conv2D+BiasAdd+Relu into the _FusedConv2D. if (FindConv2DWithBiasAndRelu(ctx, &node, &conv2d_with_bias_and_relu)) { - AddFusedConv2DNode(conv2d_with_bias_and_relu, optimized_graph, + AddFusedConv2DNode(ctx, conv2d_with_bias_and_relu, optimized_graph, &invalidated_nodes); continue; } @@ -694,7 +820,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, // Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze. if (FindConv2DWithSqueezeAndBias(ctx, &node, &conv2d_with_squeeze_and_bias)) { - AddFusedConv2DNode(conv2d_with_squeeze_and_bias, optimized_graph, + AddFusedConv2DNode(ctx, conv2d_with_squeeze_and_bias, optimized_graph, &invalidated_nodes); continue; } diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index cc9a38b2d24..375c3e56c80 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -20,6 +20,8 @@ limitations under the License. 
#include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" @@ -166,10 +168,10 @@ string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter) { if (!name.empty()) { if (name[0] == '^') { - return strings::StrCat("^", prefix, delimiter, name.substr(1)); + return absl::StrCat("^", prefix, delimiter, name.substr(1)); } } - return strings::StrCat(prefix, delimiter, name); + return absl::StrCat(prefix, delimiter, name); } string AddPrefixToNodeName(const string& name, const string& prefix) { @@ -193,20 +195,26 @@ bool ExecuteWithTimeout(std::function fn, const int64 timeout_in_ms, } string AsControlDependency(const NodeDef& node) { - return strings::StrCat("^", node.name()); + return absl::StrCat("^", node.name()); } string AsControlDependency(const string& node_name) { CHECK(!node_name.empty()); return (!node_name.empty() && node_name[0] == '^') ? node_name - : strings::StrCat("^", node_name); + : absl::StrCat("^", node_name); } bool NodeIsOnCpu(const NodeDef* node) { string task, device; return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && - str_util::StartsWith(device, DEVICE_CPU); + absl::StartsWith(device, DEVICE_CPU); +} + +bool NodeIsOnGpu(const NodeDef* node) { + string task, device; + return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && + absl::StartsWith(device, DEVICE_GPU); } int NumOutputs(const NodeDef& node, GraphDef* graph) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 1e820977aea..9053ae4c07d 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -242,6 +242,9 @@ string AsControlDependency(const string& node); // Returns true if the node is assigned to run on CPU device. 
bool NodeIsOnCpu(const NodeDef* node); +// Returns true if the node is assigned to run on GPU device. +bool NodeIsOnGpu(const NodeDef* node); + // Returns the number of outputs of a node according to its OpDef. Note that // some of the outputs may be unconnected. int NumOutputs(const NodeDef& node, GraphDef* graph); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 89417f85c23..ff15541cdc9 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -143,6 +143,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index 576494cad55..5ef1cf444be 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/core/grappler/utils/grappler_test.h" #include +#include "absl/algorithm/container.h" +#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" @@ -125,6 +127,31 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const { } } +void GrapplerTest::CompareNodes(const NodeDef& want, const NodeDef& got) const { + EXPECT_EQ(want.name(), got.name()); + EXPECT_EQ(want.op(), got.op()); + + std::vector want_inputs(want.input().begin(), want.input().end()); + std::vector got_inputs(got.input().begin(), got.input().end()); + EXPECT_EQ(want_inputs, got_inputs); + + const auto attr_name = [](const std::pair& attr) { + return attr.first; + }; + + std::vector want_attrs; + std::vector got_attrs; + absl::c_transform(want.attr(), std::back_inserter(want_attrs), attr_name); + absl::c_transform(got.attr(), std::back_inserter(got_attrs), attr_name); + absl::c_sort(want_attrs); + absl::c_sort(got_attrs); + EXPECT_EQ(want_attrs, got_attrs); + + for (const string& attr : want_attrs) { + EXPECT_TRUE(AreAttrValuesEqual(want.attr().at(attr), got.attr().at(attr))); + } +} + bool GrapplerTest::IsNodesDirectlyConnected(const NodeMap& node_map, const string& src, const string& dst, int position) { diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index 0cfd740dcbe..20fd04c1c61 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -49,13 +49,25 @@ class GrapplerTest : public ::testing::Test { const std::vector>& attributes, GraphDef* graph) const; + // Checks if two graphs are equal. Both graphs must have the same set of nodes + // with the same inputs and attributes. Nodes can be in different order. 
+ // + // NOTE: This function uses EXPECT/ASSERT macros to check node properties + // equality, and adds all failures to the current test. void CompareGraphs(GraphDef want, GraphDef got) const; - // Check if node 'src' is directly connected to the input($position) of 'dst'. + // Checks if two nodes have the same name, op, inputs and attributes. + // + // NOTE: This function uses EXPECT/ASSERT macros to check node properties + // equality, and adds all failures to the current test. + void CompareNodes(const NodeDef& want, const NodeDef& got) const; + + // Checks if node 'src' is directly connected to the input($position) of + // 'dst'. bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src, const string& dst, int position = 0); - // Count nodes of the given op-type in a graph. + // Counts nodes of the given op-type in a graph. int CountOpNodes(const GraphDef& graph, const string& op); // Get a random tensor with given shape. diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc index 3e59219f8fc..fc93915e165 100644 --- a/tensorflow/core/kernels/conv_ops_test.cc +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -1497,6 +1497,26 @@ BM_FusedConv2DWithBatchNormAndRelu(32, 32, 32, 128, 3, 3, 1024, cpu, "3x3 /b 32"); #if GOOGLE_CUDA +// -------------------------------------------------------------------------- // +// 1x1 Convolution +// -------------------------------------------------------------------------- // + +BM_Conv2D(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8"); +BM_Conv2D(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16"); +BM_Conv2D(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32"); + +BM_Conv2DWithBiasAndRelu(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8"); +BM_Conv2DWithBiasAndRelu(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16"); +BM_Conv2DWithBiasAndRelu(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32"); + +BM_FusedConv2DWithBiasAndRelu(8, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 8"); 
+BM_FusedConv2DWithBiasAndRelu(16, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 16"); +BM_FusedConv2DWithBiasAndRelu(32, 32, 32, 128, 1, 1, 1024, gpu, "1x1 /b 32"); + +// -------------------------------------------------------------------------- // +// 3x3 Convolution +// -------------------------------------------------------------------------- // + BM_Conv2D(8, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 8"); BM_Conv2D(16, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 16"); BM_Conv2D(32, 32, 32, 128, 3, 3, 1024, gpu, "3x3 /b 32");