[Grappler] Fuse FusedBatchNorm + <SideInput> + <Activation>.

Resnet50 in NHWC: ~500 img/sec -> ~800 img/sec*

(*) with disabled layout optimizer.

PiperOrigin-RevId: 252108889
This commit is contained in:
Eugene Zhulenev 2019-06-07 13:33:32 -07:00 committed by TensorFlower Gardener
parent 17a2326611
commit 92144e5bd6
8 changed files with 813 additions and 134 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
// XLA implementation of BatchNorm operations.
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -28,7 +29,11 @@ namespace {
class FusedBatchNormOp : public XlaOpKernel {
public:
explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
explicit FusedBatchNormOp(OpKernelConstruction* ctx)
: FusedBatchNormOp(ctx, false) {}
FusedBatchNormOp(OpKernelConstruction* ctx, bool is_batch_norm_ex)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
string data_format_str;
@ -36,6 +41,26 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
if (is_batch_norm_ex) {
int num_side_inputs;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_side_inputs", &num_side_inputs));
OP_REQUIRES(ctx, num_side_inputs >= 0 && num_side_inputs <= 1,
errors::InvalidArgument(
"FusedBatchNormEx supports at most 1 side input."));
add_side_input_ = (num_side_inputs == 1);
string activation_mode;
OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode));
OP_REQUIRES(ctx,
activation_mode == "Identity" || activation_mode == "Relu",
errors::InvalidArgument(
"Unsupported FusedBatchNormEx activation mode: ",
activation_mode));
apply_relu_ = (activation_mode == "Relu");
} else {
add_side_input_ = false;
apply_relu_ = false;
}
is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
}
@ -66,9 +91,18 @@ class FusedBatchNormOp : public XlaOpKernel {
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
// In training mode, outputs the normalized value as well as the
// calculated mean and variance.
ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0),
input_type));
// calculated mean and variance. Optionally we add side input and apply
// relu activation.
xla::XlaOp converted =
xla::ConvertElementType(xla::GetTupleElement(output, 0), input_type);
if (add_side_input_ && apply_relu_) {
ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
} else if (apply_relu_) {
ctx->SetOutput(0, xla::Relu(converted));
} else {
ctx->SetOutput(0, converted);
}
ctx->SetOutput(1, xla::GetTupleElement(output, 1));
xla::XlaOp variance = xla::GetTupleElement(output, 2);
// Apply Bessel's correction.
@ -103,7 +137,16 @@ class FusedBatchNormOp : public XlaOpKernel {
xla::XlaOp output = xla::BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
epsilon_, feature_index);
ctx->SetOutput(0, xla::ConvertElementType(output, input_type));
xla::XlaOp converted = xla::ConvertElementType(output, input_type);
if (add_side_input_ && apply_relu_) {
ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
} else if (apply_relu_) {
ctx->SetOutput(0, xla::Relu(converted));
} else {
ctx->SetOutput(0, converted);
}
// Directly send input to output as mean and variance in inference mode.
ctx->SetOutput(1, ctx->Input(3));
ctx->SetOutput(2, ctx->Input(4));
@ -116,6 +159,8 @@ class FusedBatchNormOp : public XlaOpKernel {
float epsilon_;
TensorFormat data_format_;
bool is_training_;
bool add_side_input_;
bool apply_relu_;
bool is_on_gpu_;
};
@ -131,17 +176,26 @@ class FusedBatchNormOpV3 : public FusedBatchNormOp {
}
ctx->SetConstantOutput(5, Tensor());
}
};
private:
float epsilon_;
TensorFormat data_format_;
bool is_training_;
bool is_on_gpu_;
class FusedBatchNormOpEx : public FusedBatchNormOp {
public:
explicit FusedBatchNormOpEx(OpKernelConstruction* ctx)
: FusedBatchNormOp(ctx, /*is_batch_norm_ex=*/true) {}
void Compile(XlaOpKernelContext* ctx) override {
FusedBatchNormOp::CompileImpl(ctx);
if (!ctx->status().ok()) {
return;
}
ctx->SetConstantOutput(5, Tensor());
}
};
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
REGISTER_XLA_OP(Name("_FusedBatchNormEx"), FusedBatchNormOpEx);
class FusedBatchNormGradOp : public XlaOpKernel {
public:

View File

@ -781,6 +781,7 @@ tf_kernel_library(
":constant_folding",
":graph_optimizer",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",

View File

@ -26,6 +26,11 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/env_var.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
namespace grappler {
@ -40,12 +45,17 @@ namespace grappler {
// MatMul + ... -> _FusedMatMul:
// (1) MatMul + BiasAdd + <Activation>
//
// FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training]
// (1) FusedBatchNorm + <Activation>
// (2) FusedBatchNorm + SideInput + <Activation>
//
// Both Conv2D and MatMul implemented as Tensor contraction (on CPU), so all the
// patterns are "ContractionWith...".
namespace {
constexpr char kFusedConv2D[] = "_FusedConv2D";
constexpr char kFusedMatMul[] = "_FusedMatMul";
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
constexpr char kDataFormat[] = "data_format";
constexpr char kIsTraining[] = "is_training";
@ -81,6 +91,17 @@ struct FusedBatchNorm {
const NodeDef* fused_batch_norm = nullptr;
};
// FusedBatchNorm[$is_training] with fused side input and/or activation.
struct FusedBatchNormEx {
FusedBatchNormEx() = default;
const NodeDef* fused_batch_norm = nullptr;
const NodeDef* side_input = nullptr;
const NodeDef* activation = nullptr;
// Add node that will be invalidated by fusing side input and fused batch norm
const NodeDef* invalidated = nullptr;
};
// Contraction node followed by a BiasAdd.
struct ContractionWithBiasAdd {
ContractionWithBiasAdd() = default;
@ -445,12 +466,11 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx,
ContractionWithBatchNorm* matched) {
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a FusedBatchNorm or a FusedBatchNormV2.
// Root of the pattern must be a FusedBatchNorm.
if (!batch_norm || !IsFusedBatchNorm(*batch_norm)) return false;
// V2 has a separate data type for the scale/offset/mean/variance inputs.
if ((batch_norm->op() == "FusedBatchNormV2" ||
batch_norm->op() == "FusedBatchNormV3") &&
// FusedBatchNormV2 and V3 have an extra type parameter.
if (batch_norm->op() != "FusedBatchNorm" &&
!HasDataType(batch_norm, DT_FLOAT, "U"))
return false;
@ -602,23 +622,15 @@ bool FindContractionWithBiasAndAddActivation(
}
#endif
// Check that given node meets some basic FusedBatchNorm optimization
// preconditions. We use this check to lazily infer graph properties which is
// rather expensive.
bool IsFusedBatchNormCandidate(const NodeDef& node) {
if (!IsFusedBatchNorm(node)) return false;
if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false;
// Check that the node is in inference mode.
const auto& attr = node.attr();
if (attr.count(kIsTraining) > 0 && attr.at(kIsTraining).b()) return false;
return true;
}
bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
FusedBatchNorm* matched) {
if (!IsFusedBatchNormCandidate(*node)) return false;
if (!IsFusedBatchNorm(*node)) return false;
if (GetDataTypeFromAttr(*node, "T") != DT_FLOAT) return false;
// Check that the node is in inference mode.
bool is_training = true;
if (!GetNodeAttr(*node, kIsTraining, &is_training).ok()) return false;
if (is_training) return false;
const auto& props = ctx.graph_properties.GetInputProperties(node->name());
@ -649,6 +661,139 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
return true;
}
// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
// `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details.
bool BatchnormSpatialPersistentEnabled() {
#if CUDNN_VERSION >= 7402
static bool is_enabled = [] {
bool is_enabled = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
"TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
/*default_val=*/false, &is_enabled));
return is_enabled;
}();
return is_enabled;
#else
return false;
#endif
}
bool FindFusedBatchNormEx(const RemapperContext& ctx, const NodeDef* node,
FusedBatchNormEx* matched) {
// Root of the pattern must be a Relu.
// TODO(ezhulenev): Forward control dependencies.
if (!IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node))
return false;
const NodeDef* relu = node;
// Returns true iff the node is a compatible FusedBatchNorm node.
const auto valid_batch_norm = [&](const NodeDef* fused_batch_norm) -> bool {
if (fused_batch_norm == nullptr || !IsFusedBatchNorm(*fused_batch_norm))
return false;
AttrSlice attr(*fused_batch_norm);
// We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with
// contraction (MatMul or Conv2D node).
if (!NodeIsOnGpu(fused_batch_norm)) return false;
DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm, "T");
if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
// Get the FusedBatchNorm training mode.
bool is_training;
if (!GetNodeAttr(attr, kIsTraining, &is_training).ok()) return false;
// TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
if (!is_training) return false;
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
// inputs and activation, and it has its own limitations. In inference mode
// we have a custom CUDA kernel that doesn't not have these constraints.
if (is_training) {
// cuDNN only supports NHWC data layout.
string data_format;
if (!GetNodeAttr(attr, kDataFormat, &data_format).ok()) return false;
if (data_format != "NHWC") return false;
// Data type must be DT_HALF.
if (t_dtype != DT_HALF) return false;
// Channel dimension must be a multiple of 4.
const auto& props =
ctx.graph_properties.GetInputProperties(fused_batch_norm->name());
const bool valid_channel_dim = !props.empty() &&
props[0].shape().dim_size() == 4 &&
props[0].shape().dim(3).size() % 4 == 0;
if (!valid_channel_dim) return false;
// cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
if (!BatchnormSpatialPersistentEnabled()) return false;
}
// FusedBatchNormV2 and V3 have an extra type parameter.
if ((fused_batch_norm->op() != "FusedBatchNorm") &&
!HasDataType(fused_batch_norm, DT_FLOAT, "U"))
return false;
// Check that only one node consumes the output of a FusedBatchNorm.
if (HasControlFaninOrFanout(ctx.graph_view, fused_batch_norm) ||
!HasSingleFanoutNode(ctx.graph_view, fused_batch_norm) ||
IsInPreserveSet(ctx, fused_batch_norm))
return false;
return true;
};
const auto relu_input_port = GraphView::InputPort(relu, 0);
const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port);
if (!relu_fanin.node) return false;
// Input to a Relu can be a FusedBatchNorm.
if (valid_batch_norm(relu_fanin.node)) {
matched->activation = relu;
matched->side_input = nullptr;
matched->fused_batch_norm = relu_fanin.node;
matched->invalidated = nullptr;
return true;
}
// Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
if (IsAdd(*relu_fanin.node)) {
const NodeDef* add = relu_fanin.node;
// Check that only Relu node consumes the output of an Add node.
if (HasControlFaninOrFanout(ctx.graph_view, add) ||
!HasSingleFanoutNode(ctx.graph_view, add) || IsInPreserveSet(ctx, add))
return false;
const auto add_input_port_0 = GraphView::InputPort(add, 0);
const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
const auto add_input_port_1 = GraphView::InputPort(add, 1);
const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
if (valid_batch_norm(add_fanin_0.node)) {
matched->activation = relu;
matched->side_input = add_fanin_1.node;
matched->fused_batch_norm = add_fanin_0.node;
matched->invalidated = add;
return true;
}
if (valid_batch_norm(add_fanin_1.node)) {
matched->activation = relu;
matched->side_input = add_fanin_0.node;
matched->fused_batch_norm = add_fanin_1.node;
matched->invalidated = add;
return true;
}
}
return false;
}
void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
DCHECK(IsConv2D(*conv2d)) << "Input node must be a Conv2D";
@ -664,6 +809,27 @@ void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
}
void CopyFusedBatchNormAttributes(const NodeDef* fused_batch_norm,
NodeDef* fused_batch_norm_ex) {
DCHECK(IsFusedBatchNorm(*fused_batch_norm))
<< "Input node must be a FusedBatchNorm";
auto* attr = fused_batch_norm_ex->mutable_attr();
auto src_attr = fused_batch_norm->attr();
(*attr)["T"] = src_attr.at("T");
(*attr)["is_training"] = src_attr.at("is_training");
(*attr)["data_format"] = src_attr.at("data_format");
(*attr)["epsilon"] = src_attr.at("epsilon");
// FusedBatchNormV2 and V3 have an extra type parameter.
if (fused_batch_norm->op() != "FusedBatchNorm") {
(*attr)["U"] = src_attr.at("U");
} else {
(*attr)["U"] = src_attr.at("T");
}
}
void CopyMatMulAttributes(const NodeDef* matmul, NodeDef* fused_matmul) {
DCHECK(IsMatMul(*matmul)) << "Input node must be a MatMul";
@ -902,6 +1068,55 @@ void AddFusedContractionNode(
}
#endif
void AddFusedBatchNormExNode(
const FusedBatchNormEx& matched, GraphDef* optimized_graph,
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
VLOG(2) << "Fuse " << matched.activation->op() << " with FusedBatchNorm:"
<< " side_input="
<< (matched.side_input ? matched.side_input->name() : "<none>")
<< " activation=" << matched.activation->name()
<< " fused_batch_norm=" << matched.fused_batch_norm->name();
// Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
NodeDef* fused_op = optimized_graph->add_node();
fused_op->set_op(kFusedBatchNormEx);
fused_op->set_name(matched.fused_batch_norm->name());
fused_op->set_device(matched.fused_batch_norm->device());
fused_op->add_input(matched.fused_batch_norm->input(0)); // 0: input
fused_op->add_input(matched.fused_batch_norm->input(1)); // 1: scale
fused_op->add_input(matched.fused_batch_norm->input(2)); // 2: offset
fused_op->add_input(matched.fused_batch_norm->input(3)); // 3: estimated_mean
fused_op->add_input(matched.fused_batch_norm->input(4)); // 4: estimated_var
CopyFusedBatchNormAttributes(matched.fused_batch_norm, fused_op);
auto* attrs = fused_op->mutable_attr();
SetAttrValue(matched.activation->op(), &(*attrs)["activation_mode"]);
if (matched.side_input != nullptr) {
SetAttrValue(1, &(*attrs)["num_side_inputs"]);
fused_op->add_input(matched.side_input->name()); // 5: side_input
} else {
SetAttrValue(0, &(*attrs)["num_side_inputs"]);
}
// Turn activation node into Identity node.
NodeDef* identity_op = optimized_graph->add_node();
identity_op->set_op("Identity");
identity_op->set_name(matched.activation->name());
identity_op->set_device(matched.fused_batch_norm->device());
identity_op->add_input(matched.fused_batch_norm->name());
(*identity_op->mutable_attr())["T"] = attrs->at("T");
// Invalidate all nodes bypassed by this rewrite.
invalidated_nodes->insert(matched.activation);
invalidated_nodes->insert(matched.fused_batch_norm);
if (matched.side_input != nullptr) {
invalidated_nodes->insert(matched.invalidated);
}
}
void AddBatchNormNodes(const FusedBatchNorm& matched,
GraphDef* optimized_graph) {
const NodeDef& fused_node = *matched.fused_batch_norm;
@ -1044,6 +1259,55 @@ void AddBatchNormNodes(const FusedBatchNorm& matched,
*r->add_input() = a->name();
*r->add_input() = c->name();
}
// Check if a node is a candidate to one of the patterns that require inferred
// shapes:
// (1) Splitting FusedBatchNorm into primitives.
// (2) Fusing side input and/or activation into FusedBatchNorm.
bool RequiresInferredShapes(const RemapperContext& ctx, const NodeDef& node) {
// Candidate for a FusedBatchNorm splitting.
const auto is_batch_norm_candidate = [&]() -> bool {
if (!IsFusedBatchNorm(node)) return false;
if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false;
bool is_training = true;
if (!GetNodeAttr(node, kIsTraining, &is_training).ok()) return false;
if (is_training) return false;
return true;
};
// Candidate for a FusedBatchNorm fusion.
const auto is_batch_norm_fusion_candidate = [&]() -> bool {
if (!IsRelu(node)) return false;
const auto relu_input_port = GraphView::InputPort(&node, 0);
const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port);
if (!relu_fanin.node) return false;
if (IsFusedBatchNorm(*relu_fanin.node)) {
// FusedBatchNorm + Relu.
return true;
} else if (IsAdd(*relu_fanin.node)) {
// FusedBatchNorm + Add + Relu.
const NodeDef* add = relu_fanin.node;
const auto add_input_port_0 = GraphView::InputPort(add, 0);
const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
if (IsFusedBatchNorm(*add_fanin_0.node)) return true;
const auto add_input_port_1 = GraphView::InputPort(add, 1);
const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
if (IsFusedBatchNorm(*add_fanin_1.node)) return true;
}
return false;
};
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate();
}
} // namespace
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
@ -1051,6 +1315,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
// Supported graph patterns.
// clang-format off
FusedBatchNorm fused_batch_norm;
FusedBatchNormEx fused_batch_norm_ex;
ContractionWithBiasAdd contract_with_bias;
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
#ifndef INTEL_MKL
@ -1078,9 +1343,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
// and Activation nodes that were fused into a Conv2D node.
absl::flat_hash_set<const NodeDef*> invalidated_nodes;
// _FusedMatMul and _FusedConv2D kernels do not have registered gradient
// function, so we must not perform rewrite if the graph will be
// differentiated later.
// _Fused{...} kernels do not have registered gradient function, so we must
// not perform rewrite if the graph will be differentiated later.
bool allow_non_differentiable_rewrites =
item.optimization_options().allow_non_differentiable_rewrites;
@ -1161,16 +1425,25 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
#endif // !INTEL_MKL
// Infer properties lazily in case they are not needed.
if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) {
if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, node)) {
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
// TODO(rmlarsen): Get rid of tensor value copies.
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
/*assume_valid_feeds=*/false,
assume_valid_feeds,
/*aggressive_shape_inference=*/false,
/*include_input_tensor_values=*/true,
/*include_output_tensor_values=*/false));
ctx.inferred_graph_properties = true;
}
// Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
if (allow_non_differentiable_rewrites &&
FindFusedBatchNormEx(ctx, &node, &fused_batch_norm_ex)) {
AddFusedBatchNormExNode(fused_batch_norm_ex, optimized_graph,
&invalidated_nodes);
continue;
}
// During inference, most of the inputs to FusedBatchNorm are constant, and
// we can therefore replace the op with a much cheaper set of primitives.
if (FindFusedBatchNorm(ctx, &node, &fused_batch_norm)) {

View File

@ -26,7 +26,7 @@ namespace grappler {
// nodes to decrease the amount of operations needed to perform a computation.
class Remapper : public GraphOptimizer {
public:
explicit Remapper(RewriterConfig::Toggle opt_level) {}
explicit Remapper(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {}
~Remapper() override {}
@ -37,6 +37,9 @@ class Remapper : public GraphOptimizer {
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) override;
private:
RewriterConfig::Toggle opt_level_;
};
} // end namespace grappler

View File

@ -22,11 +22,20 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
namespace grappler {
class RemapperTest : public GrapplerTest {
protected:
void SetUp() override {
// This is a requirement for fusing FusedBatchNorm + SideInput + Activation.
setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */);
}
// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
// contractions with non-default contraction output kernels.
bool EigenSupportsContractionOutputKernel() {
@ -102,6 +111,179 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
}
}
TEST_F(RemapperTest, FuseBatchNormWithRelu) {
using ::tensorflow::ops::Placeholder;
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
"CUDNN_VERSION >= 7402.";
#else
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
auto channels_shape = ops::Placeholder::Shape({24});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
float epsilon = 0.1f;
auto fbn = ops::FusedBatchNormV3(
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
"NHWC"));
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t},
{"scale", scale_t},
{"offset", offset_t},
{"mean", mean_t},
{"var", var_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on GPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:GPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "relu") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("fused_batch_norm", node.input(0));
found++;
}
if (node.name() == "fused_batch_norm") {
EXPECT_EQ("_FusedBatchNormEx", node.op());
EXPECT_EQ("input_cast", node.input(0));
EXPECT_EQ("scale", node.input(1));
EXPECT_EQ("offset", node.input(2));
EXPECT_EQ("mean", node.input(3));
EXPECT_EQ("var", node.input(4));
auto attr = node.attr();
EXPECT_EQ(0, attr["num_side_inputs"].i());
EXPECT_EQ("Relu", attr["activation_mode"].s());
found++;
}
}
EXPECT_EQ(2, found);
if (GetNumAvailableGPUs() > 0) {
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2);
}
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
}
TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
using ::tensorflow::ops::Placeholder;
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
"CUDNN_VERSION >= 7402.";
#else
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
auto channels_shape = ops::Placeholder::Shape({24});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
auto side_input =
Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
auto side_input_cast =
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
float epsilon = 0.1f;
auto fbn = ops::FusedBatchNormV3(
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
"NHWC"));
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
auto relu = ops::Relu(s.WithOpName("relu"), add);
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"scale", scale_t},
{"offset", offset_t}, {"mean", mean_t},
{"var", var_t}, {"side_input", side_input_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on GPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:GPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "relu") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("fused_batch_norm", node.input(0));
found++;
}
if (node.name() == "fused_batch_norm") {
EXPECT_EQ("_FusedBatchNormEx", node.op());
EXPECT_EQ("input_cast", node.input(0));
EXPECT_EQ("scale", node.input(1));
EXPECT_EQ("offset", node.input(2));
EXPECT_EQ("mean", node.input(3));
EXPECT_EQ("var", node.input(4));
EXPECT_EQ("side_input_cast", node.input(5));
auto attr = node.attr();
EXPECT_EQ(1, attr["num_side_inputs"].i());
EXPECT_EQ("Relu", attr["activation_mode"].s());
found++;
}
}
EXPECT_EQ(2, found);
if (GetNumAvailableGPUs() > 0) {
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2);
}
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
}
TEST_F(RemapperTest, FuseConv2DWithBias) {
if (!EigenSupportsContractionOutputKernel()) return;

View File

@ -1814,6 +1814,7 @@ tf_cuda_cc_test(
":ops_util",
":relu_op",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
@ -1825,6 +1826,7 @@ tf_cuda_cc_test(
"//tensorflow/core:testlib",
"//tensorflow/stream_executor/cuda:cudnn_plugin",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)

View File

@ -14,8 +14,10 @@ limitations under the License.
==============================================================================*/
#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h"
@ -29,6 +31,10 @@ limitations under the License.
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/public/session.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
template <typename T, typename U>
@ -39,25 +45,46 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
}
protected:
struct FusedBatchNormOutputs {
Tensor y;
Tensor batch_mean;
Tensor batch_variance;
Tensor reserve_space_1;
Tensor reserve_space_2;
Tensor reserve_space_3;
};
struct FusedBatchNormGradOutputs {
Tensor y_backprop;
Tensor x_backprop;
Tensor scale_backprop;
Tensor offset_backprop;
Tensor reserve_space_4;
Tensor reserve_space_5;
};
using GraphRunner = std::function<void(
const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data, const Tensor& side_input_data, Tensor* out)>;
const Tensor& y_backprop, const Tensor& input_data,
const Tensor& scale_data, const Tensor& offset_data,
const Tensor& mean_data, const Tensor& var_data,
const Tensor& side_input_data, FusedBatchNormOutputs* forward,
FusedBatchNormGradOutputs* backward)>;
// Runs a Tensorflow graph defined by the root scope, and fetches the result
// of 'fetch' node into the output Tensor. Optional `fetch_node` parameter
// allows to define a fetch node directly using a NodeDef for the ops that are
// of 'fetch' node into the outputs. Optional `add_nodes` parameter
// allows to define nodes directly using a NodeDef for the ops that are
// not supported by the C++ Api.
// TODO(ezhulenev): RunAndFetch defined in FusedConv2D and FusedMatMul tests.
// Add a base class for all FusedABC kernels and remove code duplication.
void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
Tensor* output, bool allow_gpu_device,
const NodeDef* fetch_node = nullptr) {
void RunAndFetch(const tensorflow::Scope& root,
const std::vector<string>& fetch,
std::vector<Tensor>* outputs, bool allow_gpu_device,
const std::vector<const NodeDef*> add_nodes = {}) {
tensorflow::GraphDef graph;
TF_ASSERT_OK(root.ToGraphDef(&graph));
if (fetch_node) {
*graph.add_node() = *fetch_node;
for (const NodeDef* add_node : add_nodes) {
*graph.add_node() = *add_node;
}
// We really want to make sure that graph executed exactly as we passed it
@ -101,64 +128,22 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
}
TF_ASSERT_OK(session->Create(graph));
std::vector<Tensor> unfused_tensors;
TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));
*output = unfused_tensors[0];
TF_ASSERT_OK(session->Run({}, fetch, {}, outputs));
}
void RunFusedBatchNorm(const Tensor& input_data, const Tensor& scale_data,
void RunFusedBatchNorm(const Tensor& y_backprop_data,
const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data, const Tensor& side_input_data,
const TensorFormat data_format, bool is_training,
bool has_side_input, const string& activation_mode,
Tensor* output, float epsilon = 0.1f) {
FusedBatchNormOutputs* forward,
FusedBatchNormGradOutputs* backward,
float epsilon = 0.1f) {
Scope root = tensorflow::Scope::NewRootScope();
ops::FusedBatchNormV2 fbn = ops::FusedBatchNormV2(
root.WithOpName("fused_batch_norm"),
ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data)),
ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data)),
ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data)),
ops::Const(root.WithOpName("var"), Input::Initializer(var_data)),
ops::FusedBatchNormV2::IsTraining(is_training)
.Epsilon(epsilon)
.DataFormat(ToString(data_format)));
Output with_side_input;
if (has_side_input) {
with_side_input =
ops::Add(root.WithOpName("with_side_input"), fbn.y,
ops::Const(root.WithOpName("side_input"),
Input::Initializer(side_input_data)));
} else {
with_side_input =
ops::Identity(root.WithOpName("with_side_input"), fbn.y);
}
if (activation_mode == "Relu") {
ops::Relu(root.WithOpName("with_activation"), with_side_input);
} else {
ops::Identity(root.WithOpName("with_activation"), with_side_input);
}
RunAndFetch(root, "with_activation", output, /*allow_gpu_device=*/true);
}
void RunFusedBatchNormEx(const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data,
const Tensor& side_input_data,
const TensorFormat data_format, bool is_training,
bool has_side_input, const string& activation_mode,
Tensor* output, float epsilon = 0.1f) {
Scope root = tensorflow::Scope::NewRootScope();
DataType t_dtype = DataTypeToEnum<T>::v();
DataType u_dtype = DataTypeToEnum<U>::v();
Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
Input::Initializer(y_backprop_data));
Output input =
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
Output scale =
@ -172,6 +157,101 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
Output side_input = ops::Const(root.WithOpName("side_input"),
Input::Initializer(side_input_data));
ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3(
root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var,
ops::FusedBatchNormV3::IsTraining(is_training)
.Epsilon(epsilon)
.DataFormat(ToString(data_format)));
Output with_side_input;
if (has_side_input) {
with_side_input =
ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input);
} else {
with_side_input =
ops::Identity(root.WithOpName("with_side_input"), fwd.y);
}
Output activation;
if (activation_mode == "Relu") {
activation =
ops::Relu(root.WithOpName("with_activation"), with_side_input);
} else {
activation =
ops::Identity(root.WithOpName("with_activation"), with_side_input);
}
Output activation_grad;
if (activation_mode == "Relu") {
activation_grad = ops::internal::ReluGrad(
root.WithOpName("activation_grad"), y_backprop, activation);
} else {
activation_grad =
ops::Identity(root.WithOpName("activation_grad"), y_backprop);
}
ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3(
root.WithOpName("fused_batch_norm_grad"), activation_grad, input, scale,
fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3,
ops::FusedBatchNormGradV3::IsTraining(is_training)
.Epsilon(epsilon)
.DataFormat(ToString(data_format)));
std::vector<Tensor> out_tensors;
RunAndFetch(
root,
{"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2",
"fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5",
"activation_grad:0", "fused_batch_norm_grad:0",
"fused_batch_norm_grad:1", "fused_batch_norm_grad:2"},
&out_tensors, /*allow_gpu_device=*/true);
forward->y = out_tensors[0];
forward->batch_mean = out_tensors[1];
forward->batch_variance = out_tensors[2];
forward->reserve_space_1 = out_tensors[3];
forward->reserve_space_2 = out_tensors[4];
forward->reserve_space_3 = out_tensors[5];
backward->y_backprop = out_tensors[6];
backward->x_backprop = out_tensors[7];
backward->scale_backprop = out_tensors[8];
backward->offset_backprop = out_tensors[9];
}
void RunFusedBatchNormEx(const Tensor& y_backprop_data,
const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data,
const Tensor& side_input_data,
const TensorFormat data_format, bool is_training,
bool has_side_input, const string& activation_mode,
FusedBatchNormOutputs* forward,
FusedBatchNormGradOutputs* backward,
float epsilon = 0.1f) {
Scope root = tensorflow::Scope::NewRootScope();
DataType t_dtype = DataTypeToEnum<T>::v();
DataType u_dtype = DataTypeToEnum<U>::v();
Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
Input::Initializer(y_backprop_data));
Output input =
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
Output scale =
ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
Output offset =
ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
Output mean =
ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
Output var =
ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
Output side_input = ops::Const(root.WithOpName("side_input"),
Input::Initializer(side_input_data));
Output empty =
ops::Const(root.WithOpName("empty"),
Input::Initializer(Tensor(DataTypeToEnum<U>::value, {0})));
int num_side_inputs = 0;
std::vector<NodeDefBuilder::NodeOut> side_inputs;
@ -197,8 +277,58 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
.Attr("is_training", is_training)
.Finalize(&fused_batch_norm_ex));
RunAndFetch(root, fused_batch_norm_ex.name(), output,
/*allow_gpu_device=*/true, &fused_batch_norm_ex);
NodeDef activation_grad;
if (activation_mode == "Relu") {
TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad")
.Input({y_backprop.name(), 0, t_dtype})
.Input({fused_batch_norm_ex.name(), 0, t_dtype})
.Attr("T", t_dtype)
.Finalize(&activation_grad));
} else {
TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity")
.Input({y_backprop.name(), 0, t_dtype})
.Attr("T", t_dtype)
.Finalize(&activation_grad));
}
NodeDef fused_batch_norm_grad;
TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3")
.Input({activation_grad.name(), 0, t_dtype})
.Input({input.name(), 0, t_dtype})
.Input({scale.name(), 0, u_dtype})
.Input({fused_batch_norm_ex.name(), 3, u_dtype})
.Input({fused_batch_norm_ex.name(), 4, u_dtype})
.Input({fused_batch_norm_ex.name(), 5, u_dtype})
.Attr("T", t_dtype)
.Attr("U", u_dtype)
.Attr("data_format", ToString(data_format))
.Attr("epsilon", epsilon)
.Attr("is_training", is_training)
.Finalize(&fused_batch_norm_grad));
std::vector<Tensor> out_tensors;
RunAndFetch(
root,
{"fused_batch_norm_ex:0", "fused_batch_norm_ex:1",
"fused_batch_norm_ex:2", "fused_batch_norm_ex:3",
"fused_batch_norm_ex:4", "fused_batch_norm_ex:5", "activation_grad:0",
"fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
"fused_batch_norm_grad:2"},
&out_tensors,
/*allow_gpu_device=*/true,
{&fused_batch_norm_ex, &activation_grad, &fused_batch_norm_grad});
forward->y = out_tensors[0];
forward->batch_mean = out_tensors[1];
forward->batch_variance = out_tensors[2];
forward->reserve_space_1 = out_tensors[3];
forward->reserve_space_2 = out_tensors[4];
forward->reserve_space_3 = out_tensors[5];
backward->y_backprop = out_tensors[6];
backward->x_backprop = out_tensors[7];
backward->scale_backprop = out_tensors[8];
backward->offset_backprop = out_tensors[9];
}
void VerifyTensorsNear(int batch, int height, int width, int channels,
@ -215,6 +345,7 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
Tensor input(t_dtype, input_shape);
input.flat<T>().setRandom();
input.flat<T>() -= input.flat<T>().constant(static_cast<T>(0.5));
Tensor scale(u_dtype, {channels});
scale.flat<U>().setRandom();
@ -228,29 +359,60 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
Tensor var(u_dtype, {channels});
var.flat<U>().setRandom();
Tensor empty(u_dtype, {0});
Tensor fused_batch_norm;
Tensor fused_batch_norm_ex;
Tensor side_input(t_dtype, input_shape);
side_input.flat<T>().setRandom();
side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0));
run_default(input, scale, offset, is_training ? empty : mean,
is_training ? empty : var, side_input, &fused_batch_norm);
Tensor y_backprop(t_dtype, input_shape);
y_backprop.flat<T>().setRandom();
y_backprop.flat<T>() -= y_backprop.flat<T>().constant(static_cast<T>(0.5));
// Write some garbage to the `fused_batch_norm_ex` first to make sure
// that fused kernel actually writes correct results to memory.
run_default(side_input, scale, offset, is_training ? empty : mean,
is_training ? empty : var, input, &fused_batch_norm_ex);
Tensor empty(u_dtype, {0});
run_fused(input, scale, offset, is_training ? empty : mean,
is_training ? empty : var, side_input, &fused_batch_norm_ex);
FusedBatchNormOutputs fbn_forward;
FusedBatchNormOutputs fbn_ex_forward;
ASSERT_EQ(fused_batch_norm.dtype(), fused_batch_norm_ex.dtype());
ASSERT_EQ(fused_batch_norm.shape(), fused_batch_norm_ex.shape());
FusedBatchNormGradOutputs fbn_backward;
FusedBatchNormGradOutputs fbn_ex_backward;
test::ExpectClose(fused_batch_norm, fused_batch_norm_ex, 1e-2);
run_default(y_backprop, input, scale, offset, is_training ? empty : mean,
is_training ? empty : var, side_input, &fbn_forward,
&fbn_backward);
// Write some garbage to the `fbn_ex_forward` and `fbn_ex_backward` first to
// make sure that fused kernel actually writes correct results to memory.
run_default(y_backprop, side_input, scale, offset,
is_training ? empty : mean, is_training ? empty : var, input,
&fbn_ex_forward, &fbn_ex_backward);
run_fused(y_backprop, input, scale, offset, is_training ? empty : mean,
is_training ? empty : var, side_input, &fbn_ex_forward,
&fbn_ex_backward);
std::vector<std::pair<Tensor, Tensor>> tensor_pairs = {
{fbn_forward.y, fbn_ex_forward.y},
{fbn_forward.batch_mean, fbn_ex_forward.batch_mean},
{fbn_forward.batch_variance, fbn_ex_forward.batch_variance},
{fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1},
{fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2},
// NOTE(ezhulenev): We deliberately do not check `reserved_space_3`
// because BatchNormEx with fused side input has different data in it,
// but we make sure that final gradients are the same.
{fbn_backward.y_backprop, fbn_ex_backward.y_backprop},
{fbn_backward.x_backprop, fbn_ex_backward.x_backprop},
{fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop},
{fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop},
};
for (auto& pair : tensor_pairs) {
const Tensor& fbn = pair.first;
const Tensor& fbn_ex = pair.second;
ASSERT_EQ(fbn.dtype(), fbn_ex.dtype());
ASSERT_EQ(fbn.shape(), fbn_ex.shape());
test::ExpectClose(fbn, fbn_ex, 1e-2);
}
}
// Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is
@ -260,25 +422,27 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
bool has_side_input,
const string& activation_mode) {
const GraphRunner run_default =
[&](const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data, const Tensor& side_input_data,
Tensor* out) {
this->RunFusedBatchNorm(input_data, scale_data, offset_data,
mean_data, var_data, side_input_data,
data_format, is_training, has_side_input,
activation_mode, out);
[&](const Tensor& y_backprop, const Tensor& input_data,
const Tensor& scale_data, const Tensor& offset_data,
const Tensor& mean_data, const Tensor& var_data,
const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
FusedBatchNormGradOutputs* bwd) {
this->RunFusedBatchNorm(y_backprop, input_data, scale_data,
offset_data, mean_data, var_data,
side_input_data, data_format, is_training,
has_side_input, activation_mode, fwd, bwd);
};
const GraphRunner run_inference =
[&](const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data, const Tensor& side_input_data,
Tensor* out) {
this->RunFusedBatchNormEx(input_data, scale_data, offset_data,
mean_data, var_data, side_input_data,
data_format, is_training, has_side_input,
activation_mode, out);
[&](const Tensor& y_backprop, const Tensor& input_data,
const Tensor& scale_data, const Tensor& offset_data,
const Tensor& mean_data, const Tensor& var_data,
const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
FusedBatchNormGradOutputs* bwd) {
this->RunFusedBatchNormEx(y_backprop, input_data, scale_data,
offset_data, mean_data, var_data,
side_input_data, data_format, is_training,
has_side_input, activation_mode, fwd, bwd);
};
VerifyTensorsNear(batch, height, width, channels, data_format, is_training,
@ -297,17 +461,17 @@ constexpr bool kWithSideInput = true; // side_input == true
TYPED_TEST_SUITE_P(FusedBatchNormExOpTest);
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingInNHWCTest) {
this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
kNoSideInput, "Identity");
}
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithReluInNHWCTest) {
this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
kNoSideInput, "Relu");
}
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithSideInputAndReluInNHWCTest) {
this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
kWithSideInput, "Relu");
}

View File

@ -692,8 +692,8 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
<< " y_backprop shape: " << y_backprop.shape().DebugString()
<< " x shape: " << x.shape().DebugString()
<< " scale shape: " << scale.shape().DebugString()
<< " tensor format: " << tensor_format
<< " compute format: " << compute_format;
<< " tensor format: " << ToString(tensor_format)
<< " compute format: " << ToString(compute_format);
// Inputs
Tensor y_backprop_maybe_transformed = y_backprop;