[Grappler] Fuse FusedBatchNorm + <SideInput> + <Activation>.

Resnet50 in NHWC: ~500 img/sec -> ~800 img/sec*

(*) with disabled layout optimizer.

PiperOrigin-RevId: 252108889
Eugene Zhulenev, 2019-06-07 13:33:32 -07:00, committed by TensorFlower Gardener
parent 17a2326611
commit 92144e5bd6
8 changed files with 813 additions and 134 deletions
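For context, the pattern this optimization targets is the FusedBatchNorm -> Add(side input) -> Relu subgraph found in ResNet-style residual blocks. The sketch below is not part of the commit's diff; it is a minimal, hedged illustration (with a hypothetical helper name and illustrative shapes, mirroring the remapper test added further down) of the subgraph that the remapper rewrites into a single _FusedBatchNormEx node:

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"

namespace tensorflow {

// Hypothetical helper (illustration only): builds the
// FusedBatchNorm + <SideInput> + <Activation> subgraph that the remapper
// rewrites into a single _FusedBatchNormEx node. NHWC layout, fp16 data and
// channels % 4 == 0 match the cuDNN preconditions checked in
// FindFusedBatchNormEx below.
GraphDef BuildBatchNormSideInputReluPattern() {
  using ops::Placeholder;
  Scope s = Scope::NewRootScope();
  auto input_shape = Placeholder::Shape({2, 8, 8, 24});
  auto channels_shape = Placeholder::Shape({24});
  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
  auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
  auto side_input =
      Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
  auto side_input_cast =
      ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
  // FusedBatchNorm in training mode, followed by Add(side input) and Relu:
  // exactly the pattern matched by FindFusedBatchNormEx.
  auto fbn = ops::FusedBatchNormV3(
      s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
      ops::FusedBatchNormV3::IsTraining(true).Epsilon(0.1f).DataFormat("NHWC"));
  auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
  auto relu = ops::Relu(s.WithOpName("relu"), add);
  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
  GraphDef graph;
  TF_CHECK_OK(s.ToGraphDef(&graph));
  return graph;
}

}  // namespace tensorflow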


@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
// XLA implementation of BatchNorm operations.
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -28,7 +29,11 @@ namespace {
class FusedBatchNormOp : public XlaOpKernel {
public:
explicit FusedBatchNormOp(OpKernelConstruction* ctx)
: FusedBatchNormOp(ctx, false) {}
FusedBatchNormOp(OpKernelConstruction* ctx, bool is_batch_norm_ex)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
string data_format_str; string data_format_str;
@ -36,6 +41,26 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
if (is_batch_norm_ex) {
int num_side_inputs;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_side_inputs", &num_side_inputs));
OP_REQUIRES(ctx, num_side_inputs >= 0 && num_side_inputs <= 1,
errors::InvalidArgument(
"FusedBatchNormEx supports at most 1 side input."));
add_side_input_ = (num_side_inputs == 1);
string activation_mode;
OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode));
OP_REQUIRES(ctx,
activation_mode == "Identity" || activation_mode == "Relu",
errors::InvalidArgument(
"Unsupported FusedBatchNormEx activation mode: ",
activation_mode));
apply_relu_ = (activation_mode == "Relu");
} else {
add_side_input_ = false;
apply_relu_ = false;
}
is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
}
@ -66,9 +91,18 @@ class FusedBatchNormOp : public XlaOpKernel {
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
// In training mode, outputs the normalized value as well as the
// calculated mean and variance. Optionally we add side input and apply
// relu activation.
xla::XlaOp converted =
xla::ConvertElementType(xla::GetTupleElement(output, 0), input_type);
if (add_side_input_ && apply_relu_) {
ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
} else if (apply_relu_) {
ctx->SetOutput(0, xla::Relu(converted));
} else {
ctx->SetOutput(0, converted);
}
ctx->SetOutput(1, xla::GetTupleElement(output, 1));
xla::XlaOp variance = xla::GetTupleElement(output, 2);
// Apply Bessel's correction.
@ -103,7 +137,16 @@ class FusedBatchNormOp : public XlaOpKernel {
xla::XlaOp output = xla::BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
epsilon_, feature_index);
ctx->SetOutput(0, xla::ConvertElementType(output, input_type));
xla::XlaOp converted = xla::ConvertElementType(output, input_type);
if (add_side_input_ && apply_relu_) {
ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
} else if (apply_relu_) {
ctx->SetOutput(0, xla::Relu(converted));
} else {
ctx->SetOutput(0, converted);
}
// Directly send input to output as mean and variance in inference mode.
ctx->SetOutput(1, ctx->Input(3));
ctx->SetOutput(2, ctx->Input(4));
@ -116,6 +159,8 @@ class FusedBatchNormOp : public XlaOpKernel {
float epsilon_;
TensorFormat data_format_;
bool is_training_;
bool add_side_input_;
bool apply_relu_;
bool is_on_gpu_;
};
@ -131,17 +176,26 @@ class FusedBatchNormOpV3 : public FusedBatchNormOp {
}
ctx->SetConstantOutput(5, Tensor());
}
};
class FusedBatchNormOpEx : public FusedBatchNormOp {
public:
explicit FusedBatchNormOpEx(OpKernelConstruction* ctx)
: FusedBatchNormOp(ctx, /*is_batch_norm_ex=*/true) {}
void Compile(XlaOpKernelContext* ctx) override {
FusedBatchNormOp::CompileImpl(ctx);
if (!ctx->status().ok()) {
return;
}
ctx->SetConstantOutput(5, Tensor());
}
};
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
REGISTER_XLA_OP(Name("_FusedBatchNormEx"), FusedBatchNormOpEx);
class FusedBatchNormGradOp : public XlaOpKernel {
public:


@ -781,6 +781,7 @@ tf_kernel_library(
":constant_folding", ":constant_folding",
":graph_optimizer", ":graph_optimizer",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",


@ -26,6 +26,11 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/env_var.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
namespace grappler {
@ -40,12 +45,17 @@ namespace grappler {
// MatMul + ... -> _FusedMatMul:
// (1) MatMul + BiasAdd + <Activation>
//
// FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training]
// (1) FusedBatchNorm + <Activation>
// (2) FusedBatchNorm + SideInput + <Activation>
//
// Both Conv2D and MatMul implemented as Tensor contraction (on CPU), so all the
// patterns are "ContractionWith...".
namespace {
constexpr char kFusedConv2D[] = "_FusedConv2D";
constexpr char kFusedMatMul[] = "_FusedMatMul";
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
constexpr char kDataFormat[] = "data_format";
constexpr char kIsTraining[] = "is_training";
@ -81,6 +91,17 @@ struct FusedBatchNorm {
const NodeDef* fused_batch_norm = nullptr;
};
// FusedBatchNorm[$is_training] with fused side input and/or activation.
struct FusedBatchNormEx {
FusedBatchNormEx() = default;
const NodeDef* fused_batch_norm = nullptr;
const NodeDef* side_input = nullptr;
const NodeDef* activation = nullptr;
// Add node that will be invalidated by fusing side input and fused batch norm
const NodeDef* invalidated = nullptr;
};
// Contraction node followed by a BiasAdd.
struct ContractionWithBiasAdd {
ContractionWithBiasAdd() = default;
@ -445,12 +466,11 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx,
ContractionWithBatchNorm* matched) {
if (!EigenSupportsContractionOutputKernel()) return false;
// Root of the pattern must be a FusedBatchNorm.
if (!batch_norm || !IsFusedBatchNorm(*batch_norm)) return false;
// FusedBatchNormV2 and V3 have an extra type parameter.
if (batch_norm->op() != "FusedBatchNorm" &&
!HasDataType(batch_norm, DT_FLOAT, "U"))
return false;
@ -602,23 +622,15 @@ bool FindContractionWithBiasAndAddActivation(
}
#endif
// Check that given node meets some basic FusedBatchNorm optimization
// preconditions. We use this check to lazily infer graph properties which is
// rather expensive.
bool IsFusedBatchNormCandidate(const NodeDef& node) {
if (!IsFusedBatchNorm(node)) return false;
if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false;
// Check that the node is in inference mode.
const auto& attr = node.attr();
if (attr.count(kIsTraining) > 0 && attr.at(kIsTraining).b()) return false;
return true;
}
bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
FusedBatchNorm* matched) {
if (!IsFusedBatchNorm(*node)) return false;
if (GetDataTypeFromAttr(*node, "T") != DT_FLOAT) return false;
// Check that the node is in inference mode.
bool is_training = true;
if (!GetNodeAttr(*node, kIsTraining, &is_training).ok()) return false;
if (is_training) return false;
const auto& props = ctx.graph_properties.GetInputProperties(node->name());
@ -649,6 +661,139 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
return true;
}
// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
// `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details.
bool BatchnormSpatialPersistentEnabled() {
#if CUDNN_VERSION >= 7402
static bool is_enabled = [] {
bool is_enabled = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
"TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
/*default_val=*/false, &is_enabled));
return is_enabled;
}();
return is_enabled;
#else
return false;
#endif
}
bool FindFusedBatchNormEx(const RemapperContext& ctx, const NodeDef* node,
FusedBatchNormEx* matched) {
// Root of the pattern must be a Relu.
// TODO(ezhulenev): Forward control dependencies.
if (!IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node))
return false;
const NodeDef* relu = node;
// Returns true iff the node is a compatible FusedBatchNorm node.
const auto valid_batch_norm = [&](const NodeDef* fused_batch_norm) -> bool {
if (fused_batch_norm == nullptr || !IsFusedBatchNorm(*fused_batch_norm))
return false;
AttrSlice attr(*fused_batch_norm);
// We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with
// contraction (MatMul or Conv2D node).
if (!NodeIsOnGpu(fused_batch_norm)) return false;
DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm, "T");
if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
// Get the FusedBatchNorm training mode.
bool is_training;
if (!GetNodeAttr(attr, kIsTraining, &is_training).ok()) return false;
// TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
if (!is_training) return false;
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
// inputs and activation, and it has its own limitations. In inference mode
// we have a custom CUDA kernel that doesn't have these constraints.
if (is_training) {
// cuDNN only supports NHWC data layout.
string data_format;
if (!GetNodeAttr(attr, kDataFormat, &data_format).ok()) return false;
if (data_format != "NHWC") return false;
// Data type must be DT_HALF.
if (t_dtype != DT_HALF) return false;
// Channel dimension must be a multiple of 4.
const auto& props =
ctx.graph_properties.GetInputProperties(fused_batch_norm->name());
const bool valid_channel_dim = !props.empty() &&
props[0].shape().dim_size() == 4 &&
props[0].shape().dim(3).size() % 4 == 0;
if (!valid_channel_dim) return false;
// cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
if (!BatchnormSpatialPersistentEnabled()) return false;
}
// FusedBatchNormV2 and V3 have an extra type parameter.
if ((fused_batch_norm->op() != "FusedBatchNorm") &&
!HasDataType(fused_batch_norm, DT_FLOAT, "U"))
return false;
// Check that only one node consumes the output of a FusedBatchNorm.
if (HasControlFaninOrFanout(ctx.graph_view, fused_batch_norm) ||
!HasSingleFanoutNode(ctx.graph_view, fused_batch_norm) ||
IsInPreserveSet(ctx, fused_batch_norm))
return false;
return true;
};
const auto relu_input_port = GraphView::InputPort(relu, 0);
const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port);
if (!relu_fanin.node) return false;
// Input to a Relu can be a FusedBatchNorm.
if (valid_batch_norm(relu_fanin.node)) {
matched->activation = relu;
matched->side_input = nullptr;
matched->fused_batch_norm = relu_fanin.node;
matched->invalidated = nullptr;
return true;
}
// Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
if (IsAdd(*relu_fanin.node)) {
const NodeDef* add = relu_fanin.node;
// Check that only Relu node consumes the output of an Add node.
if (HasControlFaninOrFanout(ctx.graph_view, add) ||
!HasSingleFanoutNode(ctx.graph_view, add) || IsInPreserveSet(ctx, add))
return false;
const auto add_input_port_0 = GraphView::InputPort(add, 0);
const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
const auto add_input_port_1 = GraphView::InputPort(add, 1);
const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
if (valid_batch_norm(add_fanin_0.node)) {
matched->activation = relu;
matched->side_input = add_fanin_1.node;
matched->fused_batch_norm = add_fanin_0.node;
matched->invalidated = add;
return true;
}
if (valid_batch_norm(add_fanin_1.node)) {
matched->activation = relu;
matched->side_input = add_fanin_0.node;
matched->fused_batch_norm = add_fanin_1.node;
matched->invalidated = add;
return true;
}
}
return false;
}
void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
DCHECK(IsConv2D(*conv2d)) << "Input node must be a Conv2D";
@ -664,6 +809,27 @@ void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu"); (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
} }
void CopyFusedBatchNormAttributes(const NodeDef* fused_batch_norm,
NodeDef* fused_batch_norm_ex) {
DCHECK(IsFusedBatchNorm(*fused_batch_norm))
<< "Input node must be a FusedBatchNorm";
auto* attr = fused_batch_norm_ex->mutable_attr();
auto src_attr = fused_batch_norm->attr();
(*attr)["T"] = src_attr.at("T");
(*attr)["is_training"] = src_attr.at("is_training");
(*attr)["data_format"] = src_attr.at("data_format");
(*attr)["epsilon"] = src_attr.at("epsilon");
// FusedBatchNormV2 and V3 have an extra type parameter.
if (fused_batch_norm->op() != "FusedBatchNorm") {
(*attr)["U"] = src_attr.at("U");
} else {
(*attr)["U"] = src_attr.at("T");
}
}
void CopyMatMulAttributes(const NodeDef* matmul, NodeDef* fused_matmul) {
DCHECK(IsMatMul(*matmul)) << "Input node must be a MatMul";
@ -902,6 +1068,55 @@ void AddFusedContractionNode(
}
#endif
void AddFusedBatchNormExNode(
const FusedBatchNormEx& matched, GraphDef* optimized_graph,
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
VLOG(2) << "Fuse " << matched.activation->op() << " with FusedBatchNorm:"
<< " side_input="
<< (matched.side_input ? matched.side_input->name() : "<none>")
<< " activation=" << matched.activation->name()
<< " fused_batch_norm=" << matched.fused_batch_norm->name();
// Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
NodeDef* fused_op = optimized_graph->add_node();
fused_op->set_op(kFusedBatchNormEx);
fused_op->set_name(matched.fused_batch_norm->name());
fused_op->set_device(matched.fused_batch_norm->device());
fused_op->add_input(matched.fused_batch_norm->input(0)); // 0: input
fused_op->add_input(matched.fused_batch_norm->input(1)); // 1: scale
fused_op->add_input(matched.fused_batch_norm->input(2)); // 2: offset
fused_op->add_input(matched.fused_batch_norm->input(3)); // 3: estimated_mean
fused_op->add_input(matched.fused_batch_norm->input(4)); // 4: estimated_var
CopyFusedBatchNormAttributes(matched.fused_batch_norm, fused_op);
auto* attrs = fused_op->mutable_attr();
SetAttrValue(matched.activation->op(), &(*attrs)["activation_mode"]);
if (matched.side_input != nullptr) {
SetAttrValue(1, &(*attrs)["num_side_inputs"]);
fused_op->add_input(matched.side_input->name()); // 5: side_input
} else {
SetAttrValue(0, &(*attrs)["num_side_inputs"]);
}
// Turn activation node into Identity node.
NodeDef* identity_op = optimized_graph->add_node();
identity_op->set_op("Identity");
identity_op->set_name(matched.activation->name());
identity_op->set_device(matched.fused_batch_norm->device());
identity_op->add_input(matched.fused_batch_norm->name());
(*identity_op->mutable_attr())["T"] = attrs->at("T");
// Invalidate all nodes bypassed by this rewrite.
invalidated_nodes->insert(matched.activation);
invalidated_nodes->insert(matched.fused_batch_norm);
if (matched.side_input != nullptr) {
invalidated_nodes->insert(matched.invalidated);
}
}
void AddBatchNormNodes(const FusedBatchNorm& matched,
GraphDef* optimized_graph) {
const NodeDef& fused_node = *matched.fused_batch_norm;
@ -1044,6 +1259,55 @@ void AddBatchNormNodes(const FusedBatchNorm& matched,
*r->add_input() = a->name();
*r->add_input() = c->name();
}
// Check if a node is a candidate to one of the patterns that require inferred
// shapes:
// (1) Splitting FusedBatchNorm into primitives.
// (2) Fusing side input and/or activation into FusedBatchNorm.
bool RequiresInferredShapes(const RemapperContext& ctx, const NodeDef& node) {
// Candidate for a FusedBatchNorm splitting.
const auto is_batch_norm_candidate = [&]() -> bool {
if (!IsFusedBatchNorm(node)) return false;
if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false;
bool is_training = true;
if (!GetNodeAttr(node, kIsTraining, &is_training).ok()) return false;
if (is_training) return false;
return true;
};
// Candidate for a FusedBatchNorm fusion.
const auto is_batch_norm_fusion_candidate = [&]() -> bool {
if (!IsRelu(node)) return false;
const auto relu_input_port = GraphView::InputPort(&node, 0);
const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port);
if (!relu_fanin.node) return false;
if (IsFusedBatchNorm(*relu_fanin.node)) {
// FusedBatchNorm + Relu.
return true;
} else if (IsAdd(*relu_fanin.node)) {
// FusedBatchNorm + Add + Relu.
const NodeDef* add = relu_fanin.node;
const auto add_input_port_0 = GraphView::InputPort(add, 0);
const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
if (IsFusedBatchNorm(*add_fanin_0.node)) return true;
const auto add_input_port_1 = GraphView::InputPort(add, 1);
const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
if (IsFusedBatchNorm(*add_fanin_1.node)) return true;
}
return false;
};
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate();
}
} // namespace
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
@ -1051,6 +1315,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
// Supported graph patterns.
// clang-format off
FusedBatchNorm fused_batch_norm;
FusedBatchNormEx fused_batch_norm_ex;
ContractionWithBiasAdd contract_with_bias;
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
#ifndef INTEL_MKL
@ -1078,9 +1343,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
// and Activation nodes that were fused into a Conv2D node.
absl::flat_hash_set<const NodeDef*> invalidated_nodes;
// _Fused{...} kernels do not have registered gradient function, so we must
// not perform rewrite if the graph will be differentiated later.
bool allow_non_differentiable_rewrites =
item.optimization_options().allow_non_differentiable_rewrites;
@ -1161,16 +1425,25 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
#endif // !INTEL_MKL
// Infer properties lazily in case they are not needed.
if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, node)) {
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
// TODO(rmlarsen): Get rid of tensor value copies.
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
assume_valid_feeds,
/*aggressive_shape_inference=*/false,
/*include_input_tensor_values=*/true,
/*include_output_tensor_values=*/false));
ctx.inferred_graph_properties = true;
}
// Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
if (allow_non_differentiable_rewrites &&
FindFusedBatchNormEx(ctx, &node, &fused_batch_norm_ex)) {
AddFusedBatchNormExNode(fused_batch_norm_ex, optimized_graph,
&invalidated_nodes);
continue;
}
// During inference, most of the inputs to FusedBatchNorm are constant, and
// we can therefore replace the op with a much cheaper set of primitives.
if (FindFusedBatchNorm(ctx, &node, &fused_batch_norm)) {


@ -26,7 +26,7 @@ namespace grappler {
// nodes to decrease the amount of operations needed to perform a computation.
class Remapper : public GraphOptimizer {
public:
explicit Remapper(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {}
~Remapper() override {}
@ -37,6 +37,9 @@ class Remapper : public GraphOptimizer {
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) override;
private:
RewriterConfig::Toggle opt_level_;
};
} // end namespace grappler


@ -22,11 +22,20 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
namespace grappler {
class RemapperTest : public GrapplerTest {
protected:
void SetUp() override {
// This is a requirement for fusing FusedBatchNorm + SideInput + Activation.
setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */);
}
// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
// contractions with non-default contraction output kernels.
bool EigenSupportsContractionOutputKernel() {
@ -102,6 +111,179 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
}
}
TEST_F(RemapperTest, FuseBatchNormWithRelu) {
using ::tensorflow::ops::Placeholder;
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
"CUDNN_VERSION >= 7402.";
#else
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
auto channels_shape = ops::Placeholder::Shape({24});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
float epsilon = 0.1f;
auto fbn = ops::FusedBatchNormV3(
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
"NHWC"));
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t},
{"scale", scale_t},
{"offset", offset_t},
{"mean", mean_t},
{"var", var_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on GPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:GPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "relu") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("fused_batch_norm", node.input(0));
found++;
}
if (node.name() == "fused_batch_norm") {
EXPECT_EQ("_FusedBatchNormEx", node.op());
EXPECT_EQ("input_cast", node.input(0));
EXPECT_EQ("scale", node.input(1));
EXPECT_EQ("offset", node.input(2));
EXPECT_EQ("mean", node.input(3));
EXPECT_EQ("var", node.input(4));
auto attr = node.attr();
EXPECT_EQ(0, attr["num_side_inputs"].i());
EXPECT_EQ("Relu", attr["activation_mode"].s());
found++;
}
}
EXPECT_EQ(2, found);
if (GetNumAvailableGPUs() > 0) {
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2);
}
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
}
TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
using ::tensorflow::ops::Placeholder;
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
"CUDNN_VERSION >= 7402.";
#else
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
auto channels_shape = ops::Placeholder::Shape({24});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
auto side_input =
Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
auto side_input_cast =
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
float epsilon = 0.1f;
auto fbn = ops::FusedBatchNormV3(
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
"NHWC"));
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
auto relu = ops::Relu(s.WithOpName("relu"), add);
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"scale", scale_t},
{"offset", offset_t}, {"mean", mean_t},
{"var", var_t}, {"side_input", side_input_t}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on GPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:GPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "relu") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("fused_batch_norm", node.input(0));
found++;
}
if (node.name() == "fused_batch_norm") {
EXPECT_EQ("_FusedBatchNormEx", node.op());
EXPECT_EQ("input_cast", node.input(0));
EXPECT_EQ("scale", node.input(1));
EXPECT_EQ("offset", node.input(2));
EXPECT_EQ("mean", node.input(3));
EXPECT_EQ("var", node.input(4));
EXPECT_EQ("side_input_cast", node.input(5));
auto attr = node.attr();
EXPECT_EQ(1, attr["num_side_inputs"].i());
EXPECT_EQ("Relu", attr["activation_mode"].s());
found++;
}
}
EXPECT_EQ(2, found);
if (GetNumAvailableGPUs() > 0) {
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2);
}
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
}
TEST_F(RemapperTest, FuseConv2DWithBias) {
if (!EigenSupportsContractionOutputKernel()) return;


@ -1814,6 +1814,7 @@ tf_cuda_cc_test(
":ops_util", ":ops_util",
":relu_op", ":relu_op",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session", "//tensorflow/core:direct_session",
"//tensorflow/core:framework", "//tensorflow/core:framework",
@ -1825,6 +1826,7 @@ tf_cuda_cc_test(
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"//tensorflow/stream_executor/cuda:cudnn_plugin", "//tensorflow/stream_executor/cuda:cudnn_plugin",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)


@ -14,8 +14,10 @@ limitations under the License.
==============================================================================*/
#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_builder.h"
@ -29,6 +31,10 @@ limitations under the License.
#include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
template <typename T, typename U>
@ -39,25 +45,46 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
}
protected:
struct FusedBatchNormOutputs {
Tensor y;
Tensor batch_mean;
Tensor batch_variance;
Tensor reserve_space_1;
Tensor reserve_space_2;
Tensor reserve_space_3;
};
struct FusedBatchNormGradOutputs {
Tensor y_backprop;
Tensor x_backprop;
Tensor scale_backprop;
Tensor offset_backprop;
Tensor reserve_space_4;
Tensor reserve_space_5;
};
using GraphRunner = std::function<void(
const Tensor& y_backprop, const Tensor& input_data,
const Tensor& scale_data, const Tensor& offset_data,
const Tensor& mean_data, const Tensor& var_data,
const Tensor& side_input_data, FusedBatchNormOutputs* forward,
FusedBatchNormGradOutputs* backward)>;
// Runs a Tensorflow graph defined by the root scope, and fetches the result
// of 'fetch' node into the outputs. Optional `add_nodes` parameter
// allows to define nodes directly using a NodeDef for the ops that are
// not supported by the C++ Api.
// TODO(ezhulenev): RunAndFetch defined in FusedConv2D and FusedMatMul tests.
// Add a base class for all FusedABC kernels and remove code duplication.
void RunAndFetch(const tensorflow::Scope& root,
const std::vector<string>& fetch,
std::vector<Tensor>* outputs, bool allow_gpu_device,
const std::vector<const NodeDef*> add_nodes = {}) {
tensorflow::GraphDef graph;
TF_ASSERT_OK(root.ToGraphDef(&graph));
for (const NodeDef* add_node : add_nodes) {
*graph.add_node() = *add_node;
}
// We really want to make sure that graph executed exactly as we passed it
@ -101,64 +128,22 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
}
TF_ASSERT_OK(session->Create(graph));
TF_ASSERT_OK(session->Run({}, fetch, {}, outputs));
std::vector<Tensor> unfused_tensors;
TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));
*output = unfused_tensors[0];
}
void RunFusedBatchNorm(const Tensor& y_backprop_data,
const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data, const Tensor& side_input_data,
const TensorFormat data_format, bool is_training,
bool has_side_input, const string& activation_mode,
FusedBatchNormOutputs* forward,
FusedBatchNormGradOutputs* backward,
float epsilon = 0.1f) {
Scope root = tensorflow::Scope::NewRootScope();
Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
Input::Initializer(y_backprop_data));
ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data)),
ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data)),
ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data)),
ops::Const(root.WithOpName("var"), Input::Initializer(var_data)),
ops::FusedBatchNormV2::IsTraining(is_training)
.Epsilon(epsilon)
.DataFormat(ToString(data_format)));
Output with_side_input;
if (has_side_input) {
with_side_input =
ops::Add(root.WithOpName("with_side_input"), fbn.y,
ops::Const(root.WithOpName("side_input"),
Input::Initializer(side_input_data)));
} else {
with_side_input =
ops::Identity(root.WithOpName("with_side_input"), fbn.y);
}
if (activation_mode == "Relu") {
ops::Relu(root.WithOpName("with_activation"), with_side_input);
} else {
ops::Identity(root.WithOpName("with_activation"), with_side_input);
}
RunAndFetch(root, "with_activation", output, /*allow_gpu_device=*/true);
}
void RunFusedBatchNormEx(const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data,
const Tensor& side_input_data,
const TensorFormat data_format, bool is_training,
bool has_side_input, const string& activation_mode,
Tensor* output, float epsilon = 0.1f) {
Scope root = tensorflow::Scope::NewRootScope();
DataType t_dtype = DataTypeToEnum<T>::v();
DataType u_dtype = DataTypeToEnum<U>::v();
Output input =
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
Output scale =
@ -172,6 +157,101 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
Output side_input = ops::Const(root.WithOpName("side_input"),
Input::Initializer(side_input_data));
ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3(
root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var,
ops::FusedBatchNormV3::IsTraining(is_training)
.Epsilon(epsilon)
.DataFormat(ToString(data_format)));
Output with_side_input;
if (has_side_input) {
with_side_input =
ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input);
} else {
with_side_input =
ops::Identity(root.WithOpName("with_side_input"), fwd.y);
}
Output activation;
if (activation_mode == "Relu") {
activation =
ops::Relu(root.WithOpName("with_activation"), with_side_input);
} else {
activation =
ops::Identity(root.WithOpName("with_activation"), with_side_input);
}
Output activation_grad;
if (activation_mode == "Relu") {
activation_grad = ops::internal::ReluGrad(
root.WithOpName("activation_grad"), y_backprop, activation);
} else {
activation_grad =
ops::Identity(root.WithOpName("activation_grad"), y_backprop);
}
ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3(
root.WithOpName("fused_batch_norm_grad"), activation_grad, input, scale,
fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3,
ops::FusedBatchNormGradV3::IsTraining(is_training)
.Epsilon(epsilon)
.DataFormat(ToString(data_format)));
std::vector<Tensor> out_tensors;
RunAndFetch(
root,
{"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2",
"fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5",
"activation_grad:0", "fused_batch_norm_grad:0",
"fused_batch_norm_grad:1", "fused_batch_norm_grad:2"},
&out_tensors, /*allow_gpu_device=*/true);
forward->y = out_tensors[0];
forward->batch_mean = out_tensors[1];
forward->batch_variance = out_tensors[2];
forward->reserve_space_1 = out_tensors[3];
forward->reserve_space_2 = out_tensors[4];
forward->reserve_space_3 = out_tensors[5];
backward->y_backprop = out_tensors[6];
backward->x_backprop = out_tensors[7];
backward->scale_backprop = out_tensors[8];
backward->offset_backprop = out_tensors[9];
}
void RunFusedBatchNormEx(const Tensor& y_backprop_data,
const Tensor& input_data, const Tensor& scale_data,
const Tensor& offset_data, const Tensor& mean_data,
const Tensor& var_data,
const Tensor& side_input_data,
const TensorFormat data_format, bool is_training,
bool has_side_input, const string& activation_mode,
FusedBatchNormOutputs* forward,
FusedBatchNormGradOutputs* backward,
float epsilon = 0.1f) {
Scope root = tensorflow::Scope::NewRootScope();
DataType t_dtype = DataTypeToEnum<T>::v();
DataType u_dtype = DataTypeToEnum<U>::v();
Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
Input::Initializer(y_backprop_data));
Output input =
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
Output scale =
ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
Output offset =
ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
Output mean =
ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
Output var =
ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
Output side_input = ops::Const(root.WithOpName("side_input"),
Input::Initializer(side_input_data));
Output empty =
ops::Const(root.WithOpName("empty"),
Input::Initializer(Tensor(DataTypeToEnum<U>::value, {0})));
int num_side_inputs = 0;
std::vector<NodeDefBuilder::NodeOut> side_inputs;
@ -197,8 +277,58 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
.Attr("is_training", is_training) .Attr("is_training", is_training)
.Finalize(&fused_batch_norm_ex)); .Finalize(&fused_batch_norm_ex));
RunAndFetch(root, fused_batch_norm_ex.name(), output, NodeDef activation_grad;
/*allow_gpu_device=*/true, &fused_batch_norm_ex); if (activation_mode == "Relu") {
TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad")
.Input({y_backprop.name(), 0, t_dtype})
.Input({fused_batch_norm_ex.name(), 0, t_dtype})
.Attr("T", t_dtype)
.Finalize(&activation_grad));
} else {
TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity")
.Input({y_backprop.name(), 0, t_dtype})
.Attr("T", t_dtype)
.Finalize(&activation_grad));
}
NodeDef fused_batch_norm_grad;
TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3")
.Input({activation_grad.name(), 0, t_dtype})
.Input({input.name(), 0, t_dtype})
.Input({scale.name(), 0, u_dtype})
.Input({fused_batch_norm_ex.name(), 3, u_dtype})
.Input({fused_batch_norm_ex.name(), 4, u_dtype})
.Input({fused_batch_norm_ex.name(), 5, u_dtype})
.Attr("T", t_dtype)
.Attr("U", u_dtype)
.Attr("data_format", ToString(data_format))
.Attr("epsilon", epsilon)
.Attr("is_training", is_training)
.Finalize(&fused_batch_norm_grad));
std::vector<Tensor> out_tensors;
RunAndFetch(
root,
{"fused_batch_norm_ex:0", "fused_batch_norm_ex:1",
"fused_batch_norm_ex:2", "fused_batch_norm_ex:3",
"fused_batch_norm_ex:4", "fused_batch_norm_ex:5", "activation_grad:0",
"fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
"fused_batch_norm_grad:2"},
&out_tensors,
/*allow_gpu_device=*/true,
{&fused_batch_norm_ex, &activation_grad, &fused_batch_norm_grad});
forward->y = out_tensors[0];
forward->batch_mean = out_tensors[1];
forward->batch_variance = out_tensors[2];
forward->reserve_space_1 = out_tensors[3];
forward->reserve_space_2 = out_tensors[4];
forward->reserve_space_3 = out_tensors[5];
backward->y_backprop = out_tensors[6];
backward->x_backprop = out_tensors[7];
backward->scale_backprop = out_tensors[8];
backward->offset_backprop = out_tensors[9];
}
void VerifyTensorsNear(int batch, int height, int width, int channels,
@ -215,6 +345,7 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
Tensor input(t_dtype, input_shape);
input.flat<T>().setRandom();
input.flat<T>() -= input.flat<T>().constant(static_cast<T>(0.5));
Tensor scale(u_dtype, {channels});
scale.flat<U>().setRandom();
@ -228,29 +359,60 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
Tensor var(u_dtype, {channels});
var.flat<U>().setRandom();
Tensor empty(u_dtype, {0});
Tensor fused_batch_norm;
Tensor fused_batch_norm_ex;
Tensor side_input(t_dtype, input_shape);
side_input.flat<T>().setRandom();
side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0));
Tensor y_backprop(t_dtype, input_shape);
y_backprop.flat<T>().setRandom();
y_backprop.flat<T>() -= y_backprop.flat<T>().constant(static_cast<T>(0.5));
Tensor empty(u_dtype, {0});
FusedBatchNormOutputs fbn_forward;
FusedBatchNormOutputs fbn_ex_forward;
FusedBatchNormGradOutputs fbn_backward;
FusedBatchNormGradOutputs fbn_ex_backward;
run_default(y_backprop, input, scale, offset, is_training ? empty : mean,
is_training ? empty : var, side_input, &fbn_forward,
&fbn_backward);
// Write some garbage to the `fbn_ex_forward` and `fbn_ex_backward` first to
// make sure that fused kernel actually writes correct results to memory.
run_default(y_backprop, side_input, scale, offset,
is_training ? empty : mean, is_training ? empty : var, input,
&fbn_ex_forward, &fbn_ex_backward);
run_fused(y_backprop, input, scale, offset, is_training ? empty : mean,
is_training ? empty : var, side_input, &fbn_ex_forward,
&fbn_ex_backward);
std::vector<std::pair<Tensor, Tensor>> tensor_pairs = {
{fbn_forward.y, fbn_ex_forward.y},
{fbn_forward.batch_mean, fbn_ex_forward.batch_mean},
{fbn_forward.batch_variance, fbn_ex_forward.batch_variance},
{fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1},
{fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2},
// NOTE(ezhulenev): We deliberately do not check `reserve_space_3`
// because BatchNormEx with fused side input has different data in it,
// but we make sure that final gradients are the same.
{fbn_backward.y_backprop, fbn_ex_backward.y_backprop},
{fbn_backward.x_backprop, fbn_ex_backward.x_backprop},
{fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop},
{fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop},
};
for (auto& pair : tensor_pairs) {
const Tensor& fbn = pair.first;
const Tensor& fbn_ex = pair.second;
ASSERT_EQ(fbn.dtype(), fbn_ex.dtype());
ASSERT_EQ(fbn.shape(), fbn_ex.shape());
test::ExpectClose(fbn, fbn_ex, 1e-2);
}
}
// Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is
@ -260,25 +422,27 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
bool has_side_input,
const string& activation_mode) {
const GraphRunner run_default =
[&](const Tensor& y_backprop, const Tensor& input_data,
const Tensor& scale_data, const Tensor& offset_data,
const Tensor& mean_data, const Tensor& var_data,
const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
FusedBatchNormGradOutputs* bwd) {
this->RunFusedBatchNorm(y_backprop, input_data, scale_data,
offset_data, mean_data, var_data,
side_input_data, data_format, is_training,
has_side_input, activation_mode, fwd, bwd);
};
const GraphRunner run_inference =
[&](const Tensor& y_backprop, const Tensor& input_data,
const Tensor& scale_data, const Tensor& offset_data,
const Tensor& mean_data, const Tensor& var_data,
const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
FusedBatchNormGradOutputs* bwd) {
this->RunFusedBatchNormEx(y_backprop, input_data, scale_data,
offset_data, mean_data, var_data,
side_input_data, data_format, is_training,
has_side_input, activation_mode, fwd, bwd);
};
VerifyTensorsNear(batch, height, width, channels, data_format, is_training,
@ -297,17 +461,17 @@ constexpr bool kWithSideInput = true; // side_input == true
TYPED_TEST_SUITE_P(FusedBatchNormExOpTest);
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingInNHWCTest) {
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
kNoSideInput, "Identity");
}
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithReluInNHWCTest) {
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
kNoSideInput, "Relu");
}
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithSideInputAndReluInNHWCTest) {
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
kWithSideInput, "Relu");
}


@ -692,8 +692,8 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
<< " y_backprop shape: " << y_backprop.shape().DebugString() << " y_backprop shape: " << y_backprop.shape().DebugString()
<< " x shape: " << x.shape().DebugString() << " x shape: " << x.shape().DebugString()
<< " scale shape: " << scale.shape().DebugString() << " scale shape: " << scale.shape().DebugString()
<< " tensor format: " << tensor_format << " tensor format: " << ToString(tensor_format)
<< " compute format: " << compute_format; << " compute format: " << ToString(compute_format);
// Inputs // Inputs
Tensor y_backprop_maybe_transformed = y_backprop; Tensor y_backprop_maybe_transformed = y_backprop;