From 92144e5bd632d6c7f20b905792451da6dd668e0e Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Fri, 7 Jun 2019 13:33:32 -0700
Subject: [PATCH] [Grappler] Fuse FusedBatchNorm + SideInput + Activation.

Resnet50 in NHWC: ~500 img/sec -> ~800 img/sec*

(*) with disabled layout optimizer.

PiperOrigin-RevId: 252108889
---
 .../compiler/tf2xla/kernels/batch_norm_op.cc  |  74 +++-
 tensorflow/core/grappler/optimizers/BUILD     |   1 +
 .../core/grappler/optimizers/remapper.cc      | 321 ++++++++++++++--
 .../core/grappler/optimizers/remapper.h       |   5 +-
 .../core/grappler/optimizers/remapper_test.cc | 182 +++++++++
 tensorflow/core/kernels/BUILD                 |   2 +
 .../kernels/fused_batch_norm_ex_op_test.cc    | 358 +++++++++++++-----
 .../core/kernels/fused_batch_norm_op.cc       |   4 +-
 8 files changed, 813 insertions(+), 134 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 013a5734863..8ce9a089fcc 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/

 // XLA implementation of BatchNorm operations.
+#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -28,7 +29,11 @@ namespace {

 class FusedBatchNormOp : public XlaOpKernel {
  public:
-  explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+  explicit FusedBatchNormOp(OpKernelConstruction* ctx)
+      : FusedBatchNormOp(ctx, false) {}
+
+  FusedBatchNormOp(OpKernelConstruction* ctx, bool is_batch_norm_ex)
+      : XlaOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
     string data_format_str;
@@ -36,6 +41,26 @@ class FusedBatchNormOp : public XlaOpKernel {
     OP_REQUIRES(
         ctx, FormatFromString(data_format_str, &data_format_),
         errors::InvalidArgument("Invalid data format: ", data_format_str));
+
+    if (is_batch_norm_ex) {
+      int num_side_inputs;
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("num_side_inputs", &num_side_inputs));
+      OP_REQUIRES(ctx, num_side_inputs >= 0 && num_side_inputs <= 1,
+                  errors::InvalidArgument(
+                      "FusedBatchNormEx supports at most 1 side input."));
+      add_side_input_ = (num_side_inputs == 1);
+      string activation_mode;
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode));
+      OP_REQUIRES(ctx,
+                  activation_mode == "Identity" || activation_mode == "Relu",
+                  errors::InvalidArgument(
+                      "Unsupported FusedBatchNormEx activation mode: ",
+                      activation_mode));
+      apply_relu_ = (activation_mode == "Relu");
+    } else {
+      add_side_input_ = false;
+      apply_relu_ = false;
+    }
     is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
   }

@@ -66,9 +91,18 @@ class FusedBatchNormOp : public XlaOpKernel {
         input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);

     // In training mode, outputs the normalized value as well as the
-    // calculated mean and variance.
-    ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0),
-                                              input_type));
+    // calculated mean and variance. Optionally we add side input and apply
+    // relu activation.
+ xla::XlaOp converted = + xla::ConvertElementType(xla::GetTupleElement(output, 0), input_type); + if (add_side_input_ && apply_relu_) { + ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted))); + } else if (apply_relu_) { + ctx->SetOutput(0, xla::Relu(converted)); + } else { + ctx->SetOutput(0, converted); + } + ctx->SetOutput(1, xla::GetTupleElement(output, 1)); xla::XlaOp variance = xla::GetTupleElement(output, 2); // Apply Bessel's correction. @@ -103,7 +137,16 @@ class FusedBatchNormOp : public XlaOpKernel { xla::XlaOp output = xla::BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), epsilon_, feature_index); - ctx->SetOutput(0, xla::ConvertElementType(output, input_type)); + + xla::XlaOp converted = xla::ConvertElementType(output, input_type); + if (add_side_input_ && apply_relu_) { + ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted))); + } else if (apply_relu_) { + ctx->SetOutput(0, xla::Relu(converted)); + } else { + ctx->SetOutput(0, converted); + } + // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -116,6 +159,8 @@ class FusedBatchNormOp : public XlaOpKernel { float epsilon_; TensorFormat data_format_; bool is_training_; + bool add_side_input_; + bool apply_relu_; bool is_on_gpu_; }; @@ -131,17 +176,26 @@ class FusedBatchNormOpV3 : public FusedBatchNormOp { } ctx->SetConstantOutput(5, Tensor()); } +}; - private: - float epsilon_; - TensorFormat data_format_; - bool is_training_; - bool is_on_gpu_; +class FusedBatchNormOpEx : public FusedBatchNormOp { + public: + explicit FusedBatchNormOpEx(OpKernelConstruction* ctx) + : FusedBatchNormOp(ctx, /*is_batch_norm_ex=*/true) {} + + void Compile(XlaOpKernelContext* ctx) override { + FusedBatchNormOp::CompileImpl(ctx); + if (!ctx->status().ok()) { + return; + } + ctx->SetConstantOutput(5, Tensor()); + } }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3); +REGISTER_XLA_OP(Name("_FusedBatchNormEx"), FusedBatchNormOpEx); class FusedBatchNormGradOp : public XlaOpKernel { public: diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 7b411ce9fc0..bf46ec010fe 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -781,6 +781,7 @@ tf_kernel_library( ":constant_folding", ":graph_optimizer", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 781e4cab1d6..f15866e08ac 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -26,6 +26,11 @@ limitations under the License. #include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/env_var.h" + +#if GOOGLE_CUDA +#include "third_party/gpus/cudnn/cudnn.h" +#endif // GOOGLE_CUDA namespace tensorflow { namespace grappler { @@ -40,12 +45,17 @@ namespace grappler { // MatMul + ... -> _FusedMatMul: // (1) MatMul + BiasAdd + // +// FusedBatchNorm[$is_training] + ... 
-> _FusedBatchNormEx[$is_training] +// (1) FusedBatchNorm + +// (2) FusedBatchNorm + SideInput + +// // Both Conv2D and MatMul implemented as Tensor contraction (on CPU), so all the // patterns are "ContractionWith...". namespace { constexpr char kFusedConv2D[] = "_FusedConv2D"; constexpr char kFusedMatMul[] = "_FusedMatMul"; +constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx"; constexpr char kDataFormat[] = "data_format"; constexpr char kIsTraining[] = "is_training"; @@ -81,6 +91,17 @@ struct FusedBatchNorm { const NodeDef* fused_batch_norm = nullptr; }; +// FusedBatchNorm[$is_training] with fused side input and/or activation. +struct FusedBatchNormEx { + FusedBatchNormEx() = default; + + const NodeDef* fused_batch_norm = nullptr; + const NodeDef* side_input = nullptr; + const NodeDef* activation = nullptr; + // Add node that will be invalidated by fusing side input and fused batch norm + const NodeDef* invalidated = nullptr; +}; + // Contraction node followed by a BiasAdd. struct ContractionWithBiasAdd { ContractionWithBiasAdd() = default; @@ -445,12 +466,11 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx, ContractionWithBatchNorm* matched) { if (!EigenSupportsContractionOutputKernel()) return false; - // Root of the pattern must be a FusedBatchNorm or a FusedBatchNormV2. + // Root of the pattern must be a FusedBatchNorm. if (!batch_norm || !IsFusedBatchNorm(*batch_norm)) return false; - // V2 has a separate data type for the scale/offset/mean/variance inputs. - if ((batch_norm->op() == "FusedBatchNormV2" || - batch_norm->op() == "FusedBatchNormV3") && + // FusedBatchNormV2 and V3 have an extra type parameter. + if (batch_norm->op() != "FusedBatchNorm" && !HasDataType(batch_norm, DT_FLOAT, "U")) return false; @@ -602,23 +622,15 @@ bool FindContractionWithBiasAndAddActivation( } #endif -// Check that given node meets some basic FusedBatchNorm optimization -// preconditions. We use this check to lazily infer graph properties which is -// rather expensive. -bool IsFusedBatchNormCandidate(const NodeDef& node) { - if (!IsFusedBatchNorm(node)) return false; - if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false; - - // Check that the node is in inference mode. - const auto& attr = node.attr(); - if (attr.count(kIsTraining) > 0 && attr.at(kIsTraining).b()) return false; - - return true; -} - bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node, FusedBatchNorm* matched) { - if (!IsFusedBatchNormCandidate(*node)) return false; + if (!IsFusedBatchNorm(*node)) return false; + if (GetDataTypeFromAttr(*node, "T") != DT_FLOAT) return false; + + // Check that the node is in inference mode. + bool is_training = true; + if (!GetNodeAttr(*node, kIsTraining, &is_training).ok()) return false; + if (is_training) return false; const auto& props = ctx.graph_properties.GetInputProperties(node->name()); @@ -649,6 +661,139 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node, return true; } +// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the +// `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details. 
+bool BatchnormSpatialPersistentEnabled() { +#if CUDNN_VERSION >= 7402 + static bool is_enabled = [] { + bool is_enabled = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar( + "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", + /*default_val=*/false, &is_enabled)); + return is_enabled; + }(); + return is_enabled; +#else + return false; +#endif +} + +bool FindFusedBatchNormEx(const RemapperContext& ctx, const NodeDef* node, + FusedBatchNormEx* matched) { + // Root of the pattern must be a Relu. + // TODO(ezhulenev): Forward control dependencies. + if (!IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node)) + return false; + + const NodeDef* relu = node; + + // Returns true iff the node is a compatible FusedBatchNorm node. + const auto valid_batch_norm = [&](const NodeDef* fused_batch_norm) -> bool { + if (fused_batch_norm == nullptr || !IsFusedBatchNorm(*fused_batch_norm)) + return false; + + AttrSlice attr(*fused_batch_norm); + + // We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with + // contraction (MatMul or Conv2D node). + if (!NodeIsOnGpu(fused_batch_norm)) return false; + + DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm, "T"); + if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false; + + // Get the FusedBatchNorm training mode. + bool is_training; + if (!GetNodeAttr(attr, kIsTraining, &is_training).ok()) return false; + // TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel. + if (!is_training) return false; + + // In training mode we rely on cuDNN for computing FusedBatchNorm with side + // inputs and activation, and it has its own limitations. In inference mode + // we have a custom CUDA kernel that doesn't not have these constraints. + if (is_training) { + // cuDNN only supports NHWC data layout. + string data_format; + if (!GetNodeAttr(attr, kDataFormat, &data_format).ok()) return false; + if (data_format != "NHWC") return false; + + // Data type must be DT_HALF. + if (t_dtype != DT_HALF) return false; + + // Channel dimension must be a multiple of 4. + const auto& props = + ctx.graph_properties.GetInputProperties(fused_batch_norm->name()); + + const bool valid_channel_dim = !props.empty() && + props[0].shape().dim_size() == 4 && + props[0].shape().dim(3).size() % 4 == 0; + if (!valid_channel_dim) return false; + + // cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode. + if (!BatchnormSpatialPersistentEnabled()) return false; + } + + // FusedBatchNormV2 and V3 have an extra type parameter. + if ((fused_batch_norm->op() != "FusedBatchNorm") && + !HasDataType(fused_batch_norm, DT_FLOAT, "U")) + return false; + + // Check that only one node consumes the output of a FusedBatchNorm. + if (HasControlFaninOrFanout(ctx.graph_view, fused_batch_norm) || + !HasSingleFanoutNode(ctx.graph_view, fused_batch_norm) || + IsInPreserveSet(ctx, fused_batch_norm)) + return false; + + return true; + }; + + const auto relu_input_port = GraphView::InputPort(relu, 0); + const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port); + if (!relu_fanin.node) return false; + + // Input to a Relu can be a FusedBatchNorm. 
+  if (valid_batch_norm(relu_fanin.node)) {
+    matched->activation = relu;
+    matched->side_input = nullptr;
+    matched->fused_batch_norm = relu_fanin.node;
+    matched->invalidated = nullptr;
+    return true;
+  }
+
+  // Input to a Relu can be an Add node with FusedBatchNorm as one of the
+  // inputs.
+  if (IsAdd(*relu_fanin.node)) {
+    const NodeDef* add = relu_fanin.node;
+
+    // Check that only the Relu node consumes the output of an Add node.
+    if (HasControlFaninOrFanout(ctx.graph_view, add) ||
+        !HasSingleFanoutNode(ctx.graph_view, add) || IsInPreserveSet(ctx, add))
+      return false;
+
+    const auto add_input_port_0 = GraphView::InputPort(add, 0);
+    const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
+
+    const auto add_input_port_1 = GraphView::InputPort(add, 1);
+    const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
+
+    if (valid_batch_norm(add_fanin_0.node)) {
+      matched->activation = relu;
+      matched->side_input = add_fanin_1.node;
+      matched->fused_batch_norm = add_fanin_0.node;
+      matched->invalidated = add;
+      return true;
+    }
+
+    if (valid_batch_norm(add_fanin_1.node)) {
+      matched->activation = relu;
+      matched->side_input = add_fanin_0.node;
+      matched->fused_batch_norm = add_fanin_1.node;
+      matched->invalidated = add;
+      return true;
+    }
+  }
+
+  return false;
+}
+
 void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
   DCHECK(IsConv2D(*conv2d)) << "Input node must be a Conv2D";

@@ -664,6 +809,27 @@ void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
   (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
 }

+void CopyFusedBatchNormAttributes(const NodeDef* fused_batch_norm,
+                                  NodeDef* fused_batch_norm_ex) {
+  DCHECK(IsFusedBatchNorm(*fused_batch_norm))
+      << "Input node must be a FusedBatchNorm";
+
+  auto* attr = fused_batch_norm_ex->mutable_attr();
+  auto src_attr = fused_batch_norm->attr();
+
+  (*attr)["T"] = src_attr.at("T");
+  (*attr)["is_training"] = src_attr.at("is_training");
+  (*attr)["data_format"] = src_attr.at("data_format");
+  (*attr)["epsilon"] = src_attr.at("epsilon");
+
+  // FusedBatchNormV2 and V3 have an extra type parameter.
+  if (fused_batch_norm->op() != "FusedBatchNorm") {
+    (*attr)["U"] = src_attr.at("U");
+  } else {
+    (*attr)["U"] = src_attr.at("T");
+  }
+}
+
 void CopyMatMulAttributes(const NodeDef* matmul, NodeDef* fused_matmul) {
   DCHECK(IsMatMul(*matmul)) << "Input node must be a MatMul";

@@ -902,6 +1068,55 @@ void AddFusedContractionNode(
 }
 #endif

+void AddFusedBatchNormExNode(
+    const FusedBatchNormEx& matched, GraphDef* optimized_graph,
+    absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
+  VLOG(2) << "Fuse " << matched.activation->op() << " with FusedBatchNorm:"
+          << " side_input="
+          << (matched.side_input ? matched.side_input->name() : "")
+          << " activation=" << matched.activation->name()
+          << " fused_batch_norm=" << matched.fused_batch_norm->name();
+
+  // Replace FusedBatchNorm with _FusedBatchNormEx + SideInput + Activation.
+ NodeDef* fused_op = optimized_graph->add_node(); + fused_op->set_op(kFusedBatchNormEx); + fused_op->set_name(matched.fused_batch_norm->name()); + fused_op->set_device(matched.fused_batch_norm->device()); + + fused_op->add_input(matched.fused_batch_norm->input(0)); // 0: input + fused_op->add_input(matched.fused_batch_norm->input(1)); // 1: scale + fused_op->add_input(matched.fused_batch_norm->input(2)); // 2: offset + fused_op->add_input(matched.fused_batch_norm->input(3)); // 3: estimated_mean + fused_op->add_input(matched.fused_batch_norm->input(4)); // 4: estimated_var + + CopyFusedBatchNormAttributes(matched.fused_batch_norm, fused_op); + + auto* attrs = fused_op->mutable_attr(); + SetAttrValue(matched.activation->op(), &(*attrs)["activation_mode"]); + + if (matched.side_input != nullptr) { + SetAttrValue(1, &(*attrs)["num_side_inputs"]); + fused_op->add_input(matched.side_input->name()); // 5: side_input + } else { + SetAttrValue(0, &(*attrs)["num_side_inputs"]); + } + + // Turn activation node into Identity node. + NodeDef* identity_op = optimized_graph->add_node(); + identity_op->set_op("Identity"); + identity_op->set_name(matched.activation->name()); + identity_op->set_device(matched.fused_batch_norm->device()); + identity_op->add_input(matched.fused_batch_norm->name()); + (*identity_op->mutable_attr())["T"] = attrs->at("T"); + + // Invalidate all nodes bypassed by this rewrite. + invalidated_nodes->insert(matched.activation); + invalidated_nodes->insert(matched.fused_batch_norm); + if (matched.side_input != nullptr) { + invalidated_nodes->insert(matched.invalidated); + } +} + void AddBatchNormNodes(const FusedBatchNorm& matched, GraphDef* optimized_graph) { const NodeDef& fused_node = *matched.fused_batch_norm; @@ -1044,6 +1259,55 @@ void AddBatchNormNodes(const FusedBatchNorm& matched, *r->add_input() = a->name(); *r->add_input() = c->name(); } + +// Check if a node is a candidate to one of the patterns that require inferred +// shapes: +// (1) Splitting FusedBatchNorm into primitives. +// (2) Fusing side input and/or activation into FusedBatchNorm. +bool RequiresInferredShapes(const RemapperContext& ctx, const NodeDef& node) { + // Candidate for a FusedBatchNorm splitting. + const auto is_batch_norm_candidate = [&]() -> bool { + if (!IsFusedBatchNorm(node)) return false; + if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false; + + bool is_training = true; + if (!GetNodeAttr(node, kIsTraining, &is_training).ok()) return false; + if (is_training) return false; + + return true; + }; + + // Candidate for a FusedBatchNorm fusion. + const auto is_batch_norm_fusion_candidate = [&]() -> bool { + if (!IsRelu(node)) return false; + + const auto relu_input_port = GraphView::InputPort(&node, 0); + const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port); + if (!relu_fanin.node) return false; + + if (IsFusedBatchNorm(*relu_fanin.node)) { + // FusedBatchNorm + Relu. + return true; + + } else if (IsAdd(*relu_fanin.node)) { + // FusedBatchNorm + Add + Relu. 
+ const NodeDef* add = relu_fanin.node; + + const auto add_input_port_0 = GraphView::InputPort(add, 0); + const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0); + if (IsFusedBatchNorm(*add_fanin_0.node)) return true; + + const auto add_input_port_1 = GraphView::InputPort(add, 1); + const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1); + if (IsFusedBatchNorm(*add_fanin_1.node)) return true; + } + + return false; + }; + + return is_batch_norm_candidate() || is_batch_norm_fusion_candidate(); +} + } // namespace Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, @@ -1051,6 +1315,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, // Supported graph patterns. // clang-format off FusedBatchNorm fused_batch_norm; + FusedBatchNormEx fused_batch_norm_ex; ContractionWithBiasAdd contract_with_bias; ContractionWithBiasAddAndActivation contract_with_bias_and_activation; #ifndef INTEL_MKL @@ -1078,9 +1343,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, // and Activation nodes that were fused into a Conv2D node. absl::flat_hash_set invalidated_nodes; - // _FusedMatMul and _FusedConv2D kernels do not have registered gradient - // function, so we must not perform rewrite if the graph will be - // differentiated later. + // _Fused{...} kernels do not have registered gradient function, so we must + // not perform rewrite if the graph will be differentiated later. bool allow_non_differentiable_rewrites = item.optimization_options().allow_non_differentiable_rewrites; @@ -1161,16 +1425,25 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, #endif // !INTEL_MKL // Infer properties lazily in case they are not needed. - if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) { + if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, node)) { + const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE; // TODO(rmlarsen): Get rid of tensor value copies. TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically( - /*assume_valid_feeds=*/false, + assume_valid_feeds, /*aggressive_shape_inference=*/false, /*include_input_tensor_values=*/true, /*include_output_tensor_values=*/false)); ctx.inferred_graph_properties = true; } + // Remap FusedBatchNorm++ into the _FusedBatchNormEx. + if (allow_non_differentiable_rewrites && + FindFusedBatchNormEx(ctx, &node, &fused_batch_norm_ex)) { + AddFusedBatchNormExNode(fused_batch_norm_ex, optimized_graph, + &invalidated_nodes); + continue; + } + // During inference, most of the inputs to FusedBatchNorm are constant, and // we can therefore replace the op with a much cheaper set of primitives. if (FindFusedBatchNorm(ctx, &node, &fused_batch_norm)) { diff --git a/tensorflow/core/grappler/optimizers/remapper.h b/tensorflow/core/grappler/optimizers/remapper.h index 804338f4d21..c18413e4e72 100644 --- a/tensorflow/core/grappler/optimizers/remapper.h +++ b/tensorflow/core/grappler/optimizers/remapper.h @@ -26,7 +26,7 @@ namespace grappler { // nodes to decrease the amount of operations needed to perform a computation. 
class Remapper : public GraphOptimizer { public: - explicit Remapper(RewriterConfig::Toggle opt_level) {} + explicit Remapper(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {} ~Remapper() override {} @@ -37,6 +37,9 @@ class Remapper : public GraphOptimizer { void Feedback(Cluster* cluster, const GrapplerItem& item, const GraphDef& optimized_graph, double result) override; + + private: + RewriterConfig::Toggle opt_level_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index 25914231ab1..00e9a69d507 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -22,11 +22,20 @@ limitations under the License. #include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/platform/test.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cudnn/cudnn.h" +#endif // GOOGLE_CUDA + namespace tensorflow { namespace grappler { class RemapperTest : public GrapplerTest { protected: + void SetUp() override { + // This is a requirement for fusing FusedBatchNorm + SideInput + Activation. + setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */); + } + // TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for // contractions with non-default contraction output kernels. bool EigenSupportsContractionOutputKernel() { @@ -102,6 +111,179 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) { } } +TEST_F(RemapperTest, FuseBatchNormWithRelu) { + using ::tensorflow::ops::Placeholder; + +#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402) + LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires " + "CUDNN_VERSION >= 7402."; +#else + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24}); + auto channels_shape = ops::Placeholder::Shape({24}); + + auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); + auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF); + auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape); + auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape); + auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape); + auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape); + + float epsilon = 0.1f; + auto fbn = ops::FusedBatchNormV3( + s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var, + ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat( + "NHWC")); + auto relu = ops::Relu(s.WithOpName("relu"), fbn.y); + auto fetch = ops::Identity(s.WithOpName("fetch"), relu); + + auto input_t = GenerateRandomTensor({2, 8, 8, 24}); + auto scale_t = GenerateRandomTensor({24}); + auto offset_t = GenerateRandomTensor({24}); + auto mean_t = GenerateRandomTensor({0}); // empty for training + auto var_t = GenerateRandomTensor({0}); // empty for training + + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_t}, + {"scale", scale_t}, + {"offset", offset_t}, + {"mean", mean_t}, + {"var", var_t}}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on GPU. 
+ for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:GPU:0"); + } + + Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape + GraphDef output; + TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "relu") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("fused_batch_norm", node.input(0)); + found++; + } + if (node.name() == "fused_batch_norm") { + EXPECT_EQ("_FusedBatchNormEx", node.op()); + EXPECT_EQ("input_cast", node.input(0)); + EXPECT_EQ("scale", node.input(1)); + EXPECT_EQ("offset", node.input(2)); + EXPECT_EQ("mean", node.input(3)); + EXPECT_EQ("var", node.input(4)); + + auto attr = node.attr(); + EXPECT_EQ(0, attr["num_side_inputs"].i()); + EXPECT_EQ("Relu", attr["activation_mode"].s()); + found++; + } + } + EXPECT_EQ(2, found); + + if (GetNumAvailableGPUs() > 0) { + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2); + } +#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402) +} + +TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) { + using ::tensorflow::ops::Placeholder; + +#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402) + LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires " + "CUDNN_VERSION >= 7402."; +#else + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24}); + auto channels_shape = ops::Placeholder::Shape({24}); + + auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); + auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF); + auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape); + auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape); + auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape); + auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape); + auto side_input = + Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape); + auto side_input_cast = + ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF); + + float epsilon = 0.1f; + auto fbn = ops::FusedBatchNormV3( + s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var, + ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat( + "NHWC")); + auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast); + auto relu = ops::Relu(s.WithOpName("relu"), add); + auto fetch = ops::Identity(s.WithOpName("fetch"), relu); + + auto input_t = GenerateRandomTensor({2, 8, 8, 24}); + auto scale_t = GenerateRandomTensor({24}); + auto offset_t = GenerateRandomTensor({24}); + auto mean_t = GenerateRandomTensor({0}); // empty for training + auto var_t = GenerateRandomTensor({0}); // empty for training + auto side_input_t = GenerateRandomTensor({2, 8, 8, 24}); + + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_t}, {"scale", scale_t}, + {"offset", offset_t}, {"mean", mean_t}, + {"var", var_t}, {"side_input", side_input_t}}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on GPU. 
+ for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:GPU:0"); + } + + Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape + GraphDef output; + TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "relu") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("fused_batch_norm", node.input(0)); + found++; + } + if (node.name() == "fused_batch_norm") { + EXPECT_EQ("_FusedBatchNormEx", node.op()); + EXPECT_EQ("input_cast", node.input(0)); + EXPECT_EQ("scale", node.input(1)); + EXPECT_EQ("offset", node.input(2)); + EXPECT_EQ("mean", node.input(3)); + EXPECT_EQ("var", node.input(4)); + EXPECT_EQ("side_input_cast", node.input(5)); + + auto attr = node.attr(); + EXPECT_EQ(1, attr["num_side_inputs"].i()); + EXPECT_EQ("Relu", attr["activation_mode"].s()); + found++; + } + } + EXPECT_EQ(2, found); + + if (GetNumAvailableGPUs() > 0) { + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2); + } +#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402) +} + TEST_F(RemapperTest, FuseConv2DWithBias) { if (!EigenSupportsContractionOutputKernel()) return; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0abddd5cc4e..a42e6bde7f8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1814,6 +1814,7 @@ tf_cuda_cc_test( ":ops_util", ":relu_op", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:direct_session", "//tensorflow/core:framework", @@ -1825,6 +1826,7 @@ tf_cuda_cc_test( "//tensorflow/core:testlib", "//tensorflow/stream_executor/cuda:cudnn_plugin", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/core/kernels/fused_batch_norm_ex_op_test.cc b/tensorflow/core/kernels/fused_batch_norm_ex_op_test.cc index d4fee88d4b6..9584ab7df7e 100644 --- a/tensorflow/core/kernels/fused_batch_norm_ex_op_test.cc +++ b/tensorflow/core/kernels/fused_batch_norm_ex_op_test.cc @@ -14,8 +14,10 @@ limitations under the License. ==============================================================================*/ #include "absl/algorithm/container.h" +#include "absl/strings/match.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -29,6 +31,10 @@ limitations under the License. 
#include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/public/session.h" +#if GOOGLE_CUDA +#include "third_party/gpus/cudnn/cudnn.h" +#endif // GOOGLE_CUDA + namespace tensorflow { template @@ -39,25 +45,46 @@ class FusedBatchNormExOpTestBase : public OpsTestBase { } protected: + struct FusedBatchNormOutputs { + Tensor y; + Tensor batch_mean; + Tensor batch_variance; + Tensor reserve_space_1; + Tensor reserve_space_2; + Tensor reserve_space_3; + }; + + struct FusedBatchNormGradOutputs { + Tensor y_backprop; + Tensor x_backprop; + Tensor scale_backprop; + Tensor offset_backprop; + Tensor reserve_space_4; + Tensor reserve_space_5; + }; + using GraphRunner = std::function; + const Tensor& y_backprop, const Tensor& input_data, + const Tensor& scale_data, const Tensor& offset_data, + const Tensor& mean_data, const Tensor& var_data, + const Tensor& side_input_data, FusedBatchNormOutputs* forward, + FusedBatchNormGradOutputs* backward)>; // Runs a Tensorflow graph defined by the root scope, and fetches the result - // of 'fetch' node into the output Tensor. Optional `fetch_node` parameter - // allows to define a fetch node directly using a NodeDef for the ops that are + // of 'fetch' node into the outputs. Optional `add_nodes` parameter + // allows to define nodes directly using a NodeDef for the ops that are // not supported by the C++ Api. // TODO(ezhulenev): RunAndFetch defined in FusedConv2D and FusedMatMul tests. // Add a base class for all FusedABC kernels and remove code duplication. - void RunAndFetch(const tensorflow::Scope& root, const string& fetch, - Tensor* output, bool allow_gpu_device, - const NodeDef* fetch_node = nullptr) { + void RunAndFetch(const tensorflow::Scope& root, + const std::vector& fetch, + std::vector* outputs, bool allow_gpu_device, + const std::vector add_nodes = {}) { tensorflow::GraphDef graph; TF_ASSERT_OK(root.ToGraphDef(&graph)); - if (fetch_node) { - *graph.add_node() = *fetch_node; + for (const NodeDef* add_node : add_nodes) { + *graph.add_node() = *add_node; } // We really want to make sure that graph executed exactly as we passed it @@ -101,64 +128,22 @@ class FusedBatchNormExOpTestBase : public OpsTestBase { } TF_ASSERT_OK(session->Create(graph)); - - std::vector unfused_tensors; - TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors)); - - *output = unfused_tensors[0]; + TF_ASSERT_OK(session->Run({}, fetch, {}, outputs)); } - void RunFusedBatchNorm(const Tensor& input_data, const Tensor& scale_data, + void RunFusedBatchNorm(const Tensor& y_backprop_data, + const Tensor& input_data, const Tensor& scale_data, const Tensor& offset_data, const Tensor& mean_data, const Tensor& var_data, const Tensor& side_input_data, const TensorFormat data_format, bool is_training, bool has_side_input, const string& activation_mode, - Tensor* output, float epsilon = 0.1f) { + FusedBatchNormOutputs* forward, + FusedBatchNormGradOutputs* backward, + float epsilon = 0.1f) { Scope root = tensorflow::Scope::NewRootScope(); - ops::FusedBatchNormV2 fbn = ops::FusedBatchNormV2( - root.WithOpName("fused_batch_norm"), - ops::Const(root.WithOpName("input"), Input::Initializer(input_data)), - ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data)), - ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data)), - ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data)), - ops::Const(root.WithOpName("var"), Input::Initializer(var_data)), - ops::FusedBatchNormV2::IsTraining(is_training) - .Epsilon(epsilon) - 
.DataFormat(ToString(data_format))); - - Output with_side_input; - if (has_side_input) { - with_side_input = - ops::Add(root.WithOpName("with_side_input"), fbn.y, - ops::Const(root.WithOpName("side_input"), - Input::Initializer(side_input_data))); - } else { - with_side_input = - ops::Identity(root.WithOpName("with_side_input"), fbn.y); - } - - if (activation_mode == "Relu") { - ops::Relu(root.WithOpName("with_activation"), with_side_input); - } else { - ops::Identity(root.WithOpName("with_activation"), with_side_input); - } - - RunAndFetch(root, "with_activation", output, /*allow_gpu_device=*/true); - } - - void RunFusedBatchNormEx(const Tensor& input_data, const Tensor& scale_data, - const Tensor& offset_data, const Tensor& mean_data, - const Tensor& var_data, - const Tensor& side_input_data, - const TensorFormat data_format, bool is_training, - bool has_side_input, const string& activation_mode, - Tensor* output, float epsilon = 0.1f) { - Scope root = tensorflow::Scope::NewRootScope(); - - DataType t_dtype = DataTypeToEnum::v(); - DataType u_dtype = DataTypeToEnum::v(); - + Output y_backprop = ops::Const(root.WithOpName("y_backprop"), + Input::Initializer(y_backprop_data)); Output input = ops::Const(root.WithOpName("input"), Input::Initializer(input_data)); Output scale = @@ -172,6 +157,101 @@ class FusedBatchNormExOpTestBase : public OpsTestBase { Output side_input = ops::Const(root.WithOpName("side_input"), Input::Initializer(side_input_data)); + ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3( + root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var, + ops::FusedBatchNormV3::IsTraining(is_training) + .Epsilon(epsilon) + .DataFormat(ToString(data_format))); + + Output with_side_input; + if (has_side_input) { + with_side_input = + ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input); + } else { + with_side_input = + ops::Identity(root.WithOpName("with_side_input"), fwd.y); + } + + Output activation; + if (activation_mode == "Relu") { + activation = + ops::Relu(root.WithOpName("with_activation"), with_side_input); + } else { + activation = + ops::Identity(root.WithOpName("with_activation"), with_side_input); + } + + Output activation_grad; + if (activation_mode == "Relu") { + activation_grad = ops::internal::ReluGrad( + root.WithOpName("activation_grad"), y_backprop, activation); + } else { + activation_grad = + ops::Identity(root.WithOpName("activation_grad"), y_backprop); + } + + ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3( + root.WithOpName("fused_batch_norm_grad"), activation_grad, input, scale, + fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3, + ops::FusedBatchNormGradV3::IsTraining(is_training) + .Epsilon(epsilon) + .DataFormat(ToString(data_format))); + + std::vector out_tensors; + RunAndFetch( + root, + {"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2", + "fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5", + "activation_grad:0", "fused_batch_norm_grad:0", + "fused_batch_norm_grad:1", "fused_batch_norm_grad:2"}, + &out_tensors, /*allow_gpu_device=*/true); + + forward->y = out_tensors[0]; + forward->batch_mean = out_tensors[1]; + forward->batch_variance = out_tensors[2]; + forward->reserve_space_1 = out_tensors[3]; + forward->reserve_space_2 = out_tensors[4]; + forward->reserve_space_3 = out_tensors[5]; + + backward->y_backprop = out_tensors[6]; + backward->x_backprop = out_tensors[7]; + backward->scale_backprop = out_tensors[8]; + backward->offset_backprop = out_tensors[9]; + } + + void 
RunFusedBatchNormEx(const Tensor& y_backprop_data, + const Tensor& input_data, const Tensor& scale_data, + const Tensor& offset_data, const Tensor& mean_data, + const Tensor& var_data, + const Tensor& side_input_data, + const TensorFormat data_format, bool is_training, + bool has_side_input, const string& activation_mode, + FusedBatchNormOutputs* forward, + FusedBatchNormGradOutputs* backward, + float epsilon = 0.1f) { + Scope root = tensorflow::Scope::NewRootScope(); + + DataType t_dtype = DataTypeToEnum::v(); + DataType u_dtype = DataTypeToEnum::v(); + + Output y_backprop = ops::Const(root.WithOpName("y_backprop"), + Input::Initializer(y_backprop_data)); + Output input = + ops::Const(root.WithOpName("input"), Input::Initializer(input_data)); + Output scale = + ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data)); + Output offset = + ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data)); + Output mean = + ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data)); + Output var = + ops::Const(root.WithOpName("var"), Input::Initializer(var_data)); + Output side_input = ops::Const(root.WithOpName("side_input"), + Input::Initializer(side_input_data)); + Output empty = + ops::Const(root.WithOpName("empty"), + Input::Initializer(Tensor(DataTypeToEnum::value, {0}))); + int num_side_inputs = 0; std::vector side_inputs; @@ -197,8 +277,58 @@ class FusedBatchNormExOpTestBase : public OpsTestBase { .Attr("is_training", is_training) .Finalize(&fused_batch_norm_ex)); - RunAndFetch(root, fused_batch_norm_ex.name(), output, - /*allow_gpu_device=*/true, &fused_batch_norm_ex); + NodeDef activation_grad; + if (activation_mode == "Relu") { + TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad") + .Input({y_backprop.name(), 0, t_dtype}) + .Input({fused_batch_norm_ex.name(), 0, t_dtype}) + .Attr("T", t_dtype) + .Finalize(&activation_grad)); + } else { + TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity") + .Input({y_backprop.name(), 0, t_dtype}) + .Attr("T", t_dtype) + .Finalize(&activation_grad)); + } + + NodeDef fused_batch_norm_grad; + TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3") + .Input({activation_grad.name(), 0, t_dtype}) + .Input({input.name(), 0, t_dtype}) + .Input({scale.name(), 0, u_dtype}) + .Input({fused_batch_norm_ex.name(), 3, u_dtype}) + .Input({fused_batch_norm_ex.name(), 4, u_dtype}) + .Input({fused_batch_norm_ex.name(), 5, u_dtype}) + .Attr("T", t_dtype) + .Attr("U", u_dtype) + .Attr("data_format", ToString(data_format)) + .Attr("epsilon", epsilon) + .Attr("is_training", is_training) + .Finalize(&fused_batch_norm_grad)); + + std::vector out_tensors; + RunAndFetch( + root, + {"fused_batch_norm_ex:0", "fused_batch_norm_ex:1", + "fused_batch_norm_ex:2", "fused_batch_norm_ex:3", + "fused_batch_norm_ex:4", "fused_batch_norm_ex:5", "activation_grad:0", + "fused_batch_norm_grad:0", "fused_batch_norm_grad:1", + "fused_batch_norm_grad:2"}, + &out_tensors, + /*allow_gpu_device=*/true, + {&fused_batch_norm_ex, &activation_grad, &fused_batch_norm_grad}); + + forward->y = out_tensors[0]; + forward->batch_mean = out_tensors[1]; + forward->batch_variance = out_tensors[2]; + forward->reserve_space_1 = out_tensors[3]; + forward->reserve_space_2 = out_tensors[4]; + forward->reserve_space_3 = out_tensors[5]; + + backward->y_backprop = out_tensors[6]; + backward->x_backprop = out_tensors[7]; + backward->scale_backprop = out_tensors[8]; + backward->offset_backprop = out_tensors[9]; } void VerifyTensorsNear(int batch, int height, 
int width, int channels, @@ -215,6 +345,7 @@ class FusedBatchNormExOpTestBase : public OpsTestBase { Tensor input(t_dtype, input_shape); input.flat().setRandom(); + input.flat() -= input.flat().constant(static_cast(0.5)); Tensor scale(u_dtype, {channels}); scale.flat().setRandom(); @@ -228,29 +359,60 @@ class FusedBatchNormExOpTestBase : public OpsTestBase { Tensor var(u_dtype, {channels}); var.flat().setRandom(); - Tensor empty(u_dtype, {0}); - - Tensor fused_batch_norm; - Tensor fused_batch_norm_ex; - Tensor side_input(t_dtype, input_shape); side_input.flat().setRandom(); + side_input.flat() += side_input.flat().constant(static_cast(5.0)); - run_default(input, scale, offset, is_training ? empty : mean, - is_training ? empty : var, side_input, &fused_batch_norm); + Tensor y_backprop(t_dtype, input_shape); + y_backprop.flat().setRandom(); + y_backprop.flat() -= y_backprop.flat().constant(static_cast(0.5)); - // Write some garbage to the `fused_batch_norm_ex` first to make sure - // that fused kernel actually writes correct results to memory. - run_default(side_input, scale, offset, is_training ? empty : mean, - is_training ? empty : var, input, &fused_batch_norm_ex); + Tensor empty(u_dtype, {0}); - run_fused(input, scale, offset, is_training ? empty : mean, - is_training ? empty : var, side_input, &fused_batch_norm_ex); + FusedBatchNormOutputs fbn_forward; + FusedBatchNormOutputs fbn_ex_forward; - ASSERT_EQ(fused_batch_norm.dtype(), fused_batch_norm_ex.dtype()); - ASSERT_EQ(fused_batch_norm.shape(), fused_batch_norm_ex.shape()); + FusedBatchNormGradOutputs fbn_backward; + FusedBatchNormGradOutputs fbn_ex_backward; - test::ExpectClose(fused_batch_norm, fused_batch_norm_ex, 1e-2); + run_default(y_backprop, input, scale, offset, is_training ? empty : mean, + is_training ? empty : var, side_input, &fbn_forward, + &fbn_backward); + + // Write some garbage to the `fbn_ex_forward` and `fbn_ex_backward` first to + // make sure that fused kernel actually writes correct results to memory. + run_default(y_backprop, side_input, scale, offset, + is_training ? empty : mean, is_training ? empty : var, input, + &fbn_ex_forward, &fbn_ex_backward); + + run_fused(y_backprop, input, scale, offset, is_training ? empty : mean, + is_training ? empty : var, side_input, &fbn_ex_forward, + &fbn_ex_backward); + + std::vector> tensor_pairs = { + {fbn_forward.y, fbn_ex_forward.y}, + {fbn_forward.batch_mean, fbn_ex_forward.batch_mean}, + {fbn_forward.batch_variance, fbn_ex_forward.batch_variance}, + {fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1}, + {fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2}, + // NOTE(ezhulenev): We deliberately do not check `reserved_space_3` + // because BatchNormEx with fused side input has different data in it, + // but we make sure that final gradients are the same. 
+ {fbn_backward.y_backprop, fbn_ex_backward.y_backprop}, + {fbn_backward.x_backprop, fbn_ex_backward.x_backprop}, + {fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop}, + {fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop}, + }; + + for (auto& pair : tensor_pairs) { + const Tensor& fbn = pair.first; + const Tensor& fbn_ex = pair.second; + + ASSERT_EQ(fbn.dtype(), fbn_ex.dtype()); + ASSERT_EQ(fbn.shape(), fbn_ex.shape()); + + test::ExpectClose(fbn, fbn_ex, 1e-2); + } } // Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is @@ -260,25 +422,27 @@ class FusedBatchNormExOpTestBase : public OpsTestBase { bool has_side_input, const string& activation_mode) { const GraphRunner run_default = - [&](const Tensor& input_data, const Tensor& scale_data, - const Tensor& offset_data, const Tensor& mean_data, - const Tensor& var_data, const Tensor& side_input_data, - Tensor* out) { - this->RunFusedBatchNorm(input_data, scale_data, offset_data, - mean_data, var_data, side_input_data, - data_format, is_training, has_side_input, - activation_mode, out); + [&](const Tensor& y_backprop, const Tensor& input_data, + const Tensor& scale_data, const Tensor& offset_data, + const Tensor& mean_data, const Tensor& var_data, + const Tensor& side_input_data, FusedBatchNormOutputs* fwd, + FusedBatchNormGradOutputs* bwd) { + this->RunFusedBatchNorm(y_backprop, input_data, scale_data, + offset_data, mean_data, var_data, + side_input_data, data_format, is_training, + has_side_input, activation_mode, fwd, bwd); }; const GraphRunner run_inference = - [&](const Tensor& input_data, const Tensor& scale_data, - const Tensor& offset_data, const Tensor& mean_data, - const Tensor& var_data, const Tensor& side_input_data, - Tensor* out) { - this->RunFusedBatchNormEx(input_data, scale_data, offset_data, - mean_data, var_data, side_input_data, - data_format, is_training, has_side_input, - activation_mode, out); + [&](const Tensor& y_backprop, const Tensor& input_data, + const Tensor& scale_data, const Tensor& offset_data, + const Tensor& mean_data, const Tensor& var_data, + const Tensor& side_input_data, FusedBatchNormOutputs* fwd, + FusedBatchNormGradOutputs* bwd) { + this->RunFusedBatchNormEx(y_backprop, input_data, scale_data, + offset_data, mean_data, var_data, + side_input_data, data_format, is_training, + has_side_input, activation_mode, fwd, bwd); }; VerifyTensorsNear(batch, height, width, channels, data_format, is_training, @@ -297,17 +461,17 @@ constexpr bool kWithSideInput = true; // side_input == true TYPED_TEST_SUITE_P(FusedBatchNormExOpTest); TYPED_TEST_P(FusedBatchNormExOpTest, TrainingInNHWCTest) { - this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining, + this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining, kNoSideInput, "Identity"); } TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithReluInNHWCTest) { - this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining, + this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining, kNoSideInput, "Relu"); } TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithSideInputAndReluInNHWCTest) { - this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining, + this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining, kWithSideInput, "Relu"); } diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 47334739194..eef1a455797 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ 
b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -692,8 +692,8 @@ struct FusedBatchNormGrad {
             << " y_backprop shape: " << y_backprop.shape().DebugString()
             << " x shape: " << x.shape().DebugString()
             << " scale shape: " << scale.shape().DebugString()
-            << " tensor format: " << tensor_format
-            << " compute format: " << compute_format;
+            << " tensor format: " << ToString(tensor_format)
+            << " compute format: " << ToString(compute_format);

     // Inputs
     Tensor y_backprop_maybe_transformed = y_backprop;
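Reading note (not part of the patch): the diff above shows the plumbing — pattern matching in the Grappler remapper, the XLA kernel, and the test harness — but never states the fused computation in one place. The sketch below is a minimal standalone C++ reference, using plain std::vector<float> in NHWC layout instead of TensorFlow types (all names here are illustrative, not from the patch), of what a single _FusedBatchNormEx node with one side input and Relu activation computes in training mode.

#include <cmath>
#include <cstddef>
#include <vector>

// y = Relu(scale * (x - mean) / sqrt(var + epsilon) + offset + side_input),
// with mean/var computed per channel over the N*H*W elements (training mode).
std::vector<float> FusedBatchNormExReference(
    const std::vector<float>& x, const std::vector<float>& side_input,
    const std::vector<float>& scale, const std::vector<float>& offset,
    std::size_t channels, float epsilon) {
  const std::size_t n = x.size() / channels;  // elements per channel (N*H*W)
  std::vector<float> mean(channels, 0.0f), var(channels, 0.0f);

  // Per-channel mean over batch and spatial dimensions (NHWC: channel = i % C).
  for (std::size_t i = 0; i < x.size(); ++i) mean[i % channels] += x[i];
  for (std::size_t c = 0; c < channels; ++c) mean[c] /= static_cast<float>(n);

  // Per-channel biased variance, which is what the normalization itself uses.
  for (std::size_t i = 0; i < x.size(); ++i) {
    const float d = x[i] - mean[i % channels];
    var[i % channels] += d * d;
  }
  for (std::size_t c = 0; c < channels; ++c) var[c] /= static_cast<float>(n);

  // Normalize, scale/shift, add the side input, then apply the Relu.
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const std::size_t c = i % channels;
    const float normalized = (x[i] - mean[c]) / std::sqrt(var[c] + epsilon);
    const float v = scale[c] * normalized + offset[c] + side_input[i];
    y[i] = v > 0.0f ? v : 0.0f;
  }
  return y;
}

This is the same FusedBatchNormV3 -> Add -> Relu chain that the remapper tests build before checking that the optimizer rewrites it into one _FusedBatchNormEx node.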