[Grappler] Fuse FusedBatchNorm + <SideInput> + <Activation>.
Resnet50 in NHWC: ~500 img/sec -> ~800 img/sec* (*) with disabled layout optimizer. PiperOrigin-RevId: 252108889
This commit is contained in:
parent
17a2326611
commit
92144e5bd6
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// XLA implementation of BatchNorm operations.
|
||||
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
@ -28,7 +29,11 @@ namespace {
|
||||
|
||||
class FusedBatchNormOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
explicit FusedBatchNormOp(OpKernelConstruction* ctx)
|
||||
: FusedBatchNormOp(ctx, false) {}
|
||||
|
||||
FusedBatchNormOp(OpKernelConstruction* ctx, bool is_batch_norm_ex)
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
|
||||
string data_format_str;
|
||||
@ -36,6 +41,26 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
OP_REQUIRES(
|
||||
ctx, FormatFromString(data_format_str, &data_format_),
|
||||
errors::InvalidArgument("Invalid data format: ", data_format_str));
|
||||
|
||||
if (is_batch_norm_ex) {
|
||||
int num_side_inputs;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_side_inputs", &num_side_inputs));
|
||||
OP_REQUIRES(ctx, num_side_inputs >= 0 && num_side_inputs <= 1,
|
||||
errors::InvalidArgument(
|
||||
"FusedBatchNormEx supports at most 1 side input."));
|
||||
add_side_input_ = (num_side_inputs == 1);
|
||||
string activation_mode;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode));
|
||||
OP_REQUIRES(ctx,
|
||||
activation_mode == "Identity" || activation_mode == "Relu",
|
||||
errors::InvalidArgument(
|
||||
"Unsupported FusedBatchNormEx activation mode: ",
|
||||
activation_mode));
|
||||
apply_relu_ = (activation_mode == "Relu");
|
||||
} else {
|
||||
add_side_input_ = false;
|
||||
apply_relu_ = false;
|
||||
}
|
||||
is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
|
||||
}
|
||||
|
||||
@ -66,9 +91,18 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
|
||||
|
||||
// In training mode, outputs the normalized value as well as the
|
||||
// calculated mean and variance.
|
||||
ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0),
|
||||
input_type));
|
||||
// calculated mean and variance. Optionally we add side input and apply
|
||||
// relu activation.
|
||||
xla::XlaOp converted =
|
||||
xla::ConvertElementType(xla::GetTupleElement(output, 0), input_type);
|
||||
if (add_side_input_ && apply_relu_) {
|
||||
ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
|
||||
} else if (apply_relu_) {
|
||||
ctx->SetOutput(0, xla::Relu(converted));
|
||||
} else {
|
||||
ctx->SetOutput(0, converted);
|
||||
}
|
||||
|
||||
ctx->SetOutput(1, xla::GetTupleElement(output, 1));
|
||||
xla::XlaOp variance = xla::GetTupleElement(output, 2);
|
||||
// Apply Bessel's correction.
|
||||
@ -103,7 +137,16 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
xla::XlaOp output = xla::BatchNormInference(
|
||||
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
|
||||
epsilon_, feature_index);
|
||||
ctx->SetOutput(0, xla::ConvertElementType(output, input_type));
|
||||
|
||||
xla::XlaOp converted = xla::ConvertElementType(output, input_type);
|
||||
if (add_side_input_ && apply_relu_) {
|
||||
ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
|
||||
} else if (apply_relu_) {
|
||||
ctx->SetOutput(0, xla::Relu(converted));
|
||||
} else {
|
||||
ctx->SetOutput(0, converted);
|
||||
}
|
||||
|
||||
// Directly send input to output as mean and variance in inference mode.
|
||||
ctx->SetOutput(1, ctx->Input(3));
|
||||
ctx->SetOutput(2, ctx->Input(4));
|
||||
@ -116,6 +159,8 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
float epsilon_;
|
||||
TensorFormat data_format_;
|
||||
bool is_training_;
|
||||
bool add_side_input_;
|
||||
bool apply_relu_;
|
||||
bool is_on_gpu_;
|
||||
};
|
||||
|
||||
@ -131,17 +176,26 @@ class FusedBatchNormOpV3 : public FusedBatchNormOp {
|
||||
}
|
||||
ctx->SetConstantOutput(5, Tensor());
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
float epsilon_;
|
||||
TensorFormat data_format_;
|
||||
bool is_training_;
|
||||
bool is_on_gpu_;
|
||||
class FusedBatchNormOpEx : public FusedBatchNormOp {
|
||||
public:
|
||||
explicit FusedBatchNormOpEx(OpKernelConstruction* ctx)
|
||||
: FusedBatchNormOp(ctx, /*is_batch_norm_ex=*/true) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
FusedBatchNormOp::CompileImpl(ctx);
|
||||
if (!ctx->status().ok()) {
|
||||
return;
|
||||
}
|
||||
ctx->SetConstantOutput(5, Tensor());
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
|
||||
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
|
||||
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
|
||||
REGISTER_XLA_OP(Name("_FusedBatchNormEx"), FusedBatchNormOpEx);
|
||||
|
||||
class FusedBatchNormGradOp : public XlaOpKernel {
|
||||
public:
|
||||
|
@ -781,6 +781,7 @@ tf_kernel_library(
|
||||
":constant_folding",
|
||||
":graph_optimizer",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
|
@ -26,6 +26,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cudnn/cudnn.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -40,12 +45,17 @@ namespace grappler {
|
||||
// MatMul + ... -> _FusedMatMul:
|
||||
// (1) MatMul + BiasAdd + <Activation>
|
||||
//
|
||||
// FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training]
|
||||
// (1) FusedBatchNorm + <Activation>
|
||||
// (2) FusedBatchNorm + SideInput + <Activation>
|
||||
//
|
||||
// Both Conv2D and MatMul implemented as Tensor contraction (on CPU), so all the
|
||||
// patterns are "ContractionWith...".
|
||||
namespace {
|
||||
|
||||
constexpr char kFusedConv2D[] = "_FusedConv2D";
|
||||
constexpr char kFusedMatMul[] = "_FusedMatMul";
|
||||
constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
|
||||
|
||||
constexpr char kDataFormat[] = "data_format";
|
||||
constexpr char kIsTraining[] = "is_training";
|
||||
@ -81,6 +91,17 @@ struct FusedBatchNorm {
|
||||
const NodeDef* fused_batch_norm = nullptr;
|
||||
};
|
||||
|
||||
// FusedBatchNorm[$is_training] with fused side input and/or activation.
|
||||
struct FusedBatchNormEx {
|
||||
FusedBatchNormEx() = default;
|
||||
|
||||
const NodeDef* fused_batch_norm = nullptr;
|
||||
const NodeDef* side_input = nullptr;
|
||||
const NodeDef* activation = nullptr;
|
||||
// Add node that will be invalidated by fusing side input and fused batch norm
|
||||
const NodeDef* invalidated = nullptr;
|
||||
};
|
||||
|
||||
// Contraction node followed by a BiasAdd.
|
||||
struct ContractionWithBiasAdd {
|
||||
ContractionWithBiasAdd() = default;
|
||||
@ -445,12 +466,11 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx,
|
||||
ContractionWithBatchNorm* matched) {
|
||||
if (!EigenSupportsContractionOutputKernel()) return false;
|
||||
|
||||
// Root of the pattern must be a FusedBatchNorm or a FusedBatchNormV2.
|
||||
// Root of the pattern must be a FusedBatchNorm.
|
||||
if (!batch_norm || !IsFusedBatchNorm(*batch_norm)) return false;
|
||||
|
||||
// V2 has a separate data type for the scale/offset/mean/variance inputs.
|
||||
if ((batch_norm->op() == "FusedBatchNormV2" ||
|
||||
batch_norm->op() == "FusedBatchNormV3") &&
|
||||
// FusedBatchNormV2 and V3 have an extra type parameter.
|
||||
if (batch_norm->op() != "FusedBatchNorm" &&
|
||||
!HasDataType(batch_norm, DT_FLOAT, "U"))
|
||||
return false;
|
||||
|
||||
@ -602,23 +622,15 @@ bool FindContractionWithBiasAndAddActivation(
|
||||
}
|
||||
#endif
|
||||
|
||||
// Check that given node meets some basic FusedBatchNorm optimization
|
||||
// preconditions. We use this check to lazily infer graph properties which is
|
||||
// rather expensive.
|
||||
bool IsFusedBatchNormCandidate(const NodeDef& node) {
|
||||
if (!IsFusedBatchNorm(node)) return false;
|
||||
if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false;
|
||||
|
||||
// Check that the node is in inference mode.
|
||||
const auto& attr = node.attr();
|
||||
if (attr.count(kIsTraining) > 0 && attr.at(kIsTraining).b()) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
|
||||
FusedBatchNorm* matched) {
|
||||
if (!IsFusedBatchNormCandidate(*node)) return false;
|
||||
if (!IsFusedBatchNorm(*node)) return false;
|
||||
if (GetDataTypeFromAttr(*node, "T") != DT_FLOAT) return false;
|
||||
|
||||
// Check that the node is in inference mode.
|
||||
bool is_training = true;
|
||||
if (!GetNodeAttr(*node, kIsTraining, &is_training).ok()) return false;
|
||||
if (is_training) return false;
|
||||
|
||||
const auto& props = ctx.graph_properties.GetInputProperties(node->name());
|
||||
|
||||
@ -649,6 +661,139 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
|
||||
return true;
|
||||
}
|
||||
|
||||
// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
|
||||
// `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details.
|
||||
bool BatchnormSpatialPersistentEnabled() {
|
||||
#if CUDNN_VERSION >= 7402
|
||||
static bool is_enabled = [] {
|
||||
bool is_enabled = false;
|
||||
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
|
||||
"TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
|
||||
/*default_val=*/false, &is_enabled));
|
||||
return is_enabled;
|
||||
}();
|
||||
return is_enabled;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool FindFusedBatchNormEx(const RemapperContext& ctx, const NodeDef* node,
|
||||
FusedBatchNormEx* matched) {
|
||||
// Root of the pattern must be a Relu.
|
||||
// TODO(ezhulenev): Forward control dependencies.
|
||||
if (!IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node))
|
||||
return false;
|
||||
|
||||
const NodeDef* relu = node;
|
||||
|
||||
// Returns true iff the node is a compatible FusedBatchNorm node.
|
||||
const auto valid_batch_norm = [&](const NodeDef* fused_batch_norm) -> bool {
|
||||
if (fused_batch_norm == nullptr || !IsFusedBatchNorm(*fused_batch_norm))
|
||||
return false;
|
||||
|
||||
AttrSlice attr(*fused_batch_norm);
|
||||
|
||||
// We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with
|
||||
// contraction (MatMul or Conv2D node).
|
||||
if (!NodeIsOnGpu(fused_batch_norm)) return false;
|
||||
|
||||
DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm, "T");
|
||||
if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
|
||||
|
||||
// Get the FusedBatchNorm training mode.
|
||||
bool is_training;
|
||||
if (!GetNodeAttr(attr, kIsTraining, &is_training).ok()) return false;
|
||||
// TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
|
||||
if (!is_training) return false;
|
||||
|
||||
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
|
||||
// inputs and activation, and it has its own limitations. In inference mode
|
||||
// we have a custom CUDA kernel that doesn't not have these constraints.
|
||||
if (is_training) {
|
||||
// cuDNN only supports NHWC data layout.
|
||||
string data_format;
|
||||
if (!GetNodeAttr(attr, kDataFormat, &data_format).ok()) return false;
|
||||
if (data_format != "NHWC") return false;
|
||||
|
||||
// Data type must be DT_HALF.
|
||||
if (t_dtype != DT_HALF) return false;
|
||||
|
||||
// Channel dimension must be a multiple of 4.
|
||||
const auto& props =
|
||||
ctx.graph_properties.GetInputProperties(fused_batch_norm->name());
|
||||
|
||||
const bool valid_channel_dim = !props.empty() &&
|
||||
props[0].shape().dim_size() == 4 &&
|
||||
props[0].shape().dim(3).size() % 4 == 0;
|
||||
if (!valid_channel_dim) return false;
|
||||
|
||||
// cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
|
||||
if (!BatchnormSpatialPersistentEnabled()) return false;
|
||||
}
|
||||
|
||||
// FusedBatchNormV2 and V3 have an extra type parameter.
|
||||
if ((fused_batch_norm->op() != "FusedBatchNorm") &&
|
||||
!HasDataType(fused_batch_norm, DT_FLOAT, "U"))
|
||||
return false;
|
||||
|
||||
// Check that only one node consumes the output of a FusedBatchNorm.
|
||||
if (HasControlFaninOrFanout(ctx.graph_view, fused_batch_norm) ||
|
||||
!HasSingleFanoutNode(ctx.graph_view, fused_batch_norm) ||
|
||||
IsInPreserveSet(ctx, fused_batch_norm))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
const auto relu_input_port = GraphView::InputPort(relu, 0);
|
||||
const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port);
|
||||
if (!relu_fanin.node) return false;
|
||||
|
||||
// Input to a Relu can be a FusedBatchNorm.
|
||||
if (valid_batch_norm(relu_fanin.node)) {
|
||||
matched->activation = relu;
|
||||
matched->side_input = nullptr;
|
||||
matched->fused_batch_norm = relu_fanin.node;
|
||||
matched->invalidated = nullptr;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
|
||||
if (IsAdd(*relu_fanin.node)) {
|
||||
const NodeDef* add = relu_fanin.node;
|
||||
|
||||
// Check that only Relu node consumes the output of an Add node.
|
||||
if (HasControlFaninOrFanout(ctx.graph_view, add) ||
|
||||
!HasSingleFanoutNode(ctx.graph_view, add) || IsInPreserveSet(ctx, add))
|
||||
return false;
|
||||
|
||||
const auto add_input_port_0 = GraphView::InputPort(add, 0);
|
||||
const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
|
||||
|
||||
const auto add_input_port_1 = GraphView::InputPort(add, 1);
|
||||
const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
|
||||
|
||||
if (valid_batch_norm(add_fanin_0.node)) {
|
||||
matched->activation = relu;
|
||||
matched->side_input = add_fanin_1.node;
|
||||
matched->fused_batch_norm = add_fanin_0.node;
|
||||
matched->invalidated = add;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (valid_batch_norm(add_fanin_1.node)) {
|
||||
matched->activation = relu;
|
||||
matched->side_input = add_fanin_0.node;
|
||||
matched->fused_batch_norm = add_fanin_1.node;
|
||||
matched->invalidated = add;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
|
||||
DCHECK(IsConv2D(*conv2d)) << "Input node must be a Conv2D";
|
||||
|
||||
@ -664,6 +809,27 @@ void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
|
||||
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
|
||||
}
|
||||
|
||||
void CopyFusedBatchNormAttributes(const NodeDef* fused_batch_norm,
|
||||
NodeDef* fused_batch_norm_ex) {
|
||||
DCHECK(IsFusedBatchNorm(*fused_batch_norm))
|
||||
<< "Input node must be a FusedBatchNorm";
|
||||
|
||||
auto* attr = fused_batch_norm_ex->mutable_attr();
|
||||
auto src_attr = fused_batch_norm->attr();
|
||||
|
||||
(*attr)["T"] = src_attr.at("T");
|
||||
(*attr)["is_training"] = src_attr.at("is_training");
|
||||
(*attr)["data_format"] = src_attr.at("data_format");
|
||||
(*attr)["epsilon"] = src_attr.at("epsilon");
|
||||
|
||||
// FusedBatchNormV2 and V3 have an extra type parameter.
|
||||
if (fused_batch_norm->op() != "FusedBatchNorm") {
|
||||
(*attr)["U"] = src_attr.at("U");
|
||||
} else {
|
||||
(*attr)["U"] = src_attr.at("T");
|
||||
}
|
||||
}
|
||||
|
||||
void CopyMatMulAttributes(const NodeDef* matmul, NodeDef* fused_matmul) {
|
||||
DCHECK(IsMatMul(*matmul)) << "Input node must be a MatMul";
|
||||
|
||||
@ -902,6 +1068,55 @@ void AddFusedContractionNode(
|
||||
}
|
||||
#endif
|
||||
|
||||
void AddFusedBatchNormExNode(
|
||||
const FusedBatchNormEx& matched, GraphDef* optimized_graph,
|
||||
absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
|
||||
VLOG(2) << "Fuse " << matched.activation->op() << " with FusedBatchNorm:"
|
||||
<< " side_input="
|
||||
<< (matched.side_input ? matched.side_input->name() : "<none>")
|
||||
<< " activation=" << matched.activation->name()
|
||||
<< " fused_batch_norm=" << matched.fused_batch_norm->name();
|
||||
|
||||
// Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
|
||||
NodeDef* fused_op = optimized_graph->add_node();
|
||||
fused_op->set_op(kFusedBatchNormEx);
|
||||
fused_op->set_name(matched.fused_batch_norm->name());
|
||||
fused_op->set_device(matched.fused_batch_norm->device());
|
||||
|
||||
fused_op->add_input(matched.fused_batch_norm->input(0)); // 0: input
|
||||
fused_op->add_input(matched.fused_batch_norm->input(1)); // 1: scale
|
||||
fused_op->add_input(matched.fused_batch_norm->input(2)); // 2: offset
|
||||
fused_op->add_input(matched.fused_batch_norm->input(3)); // 3: estimated_mean
|
||||
fused_op->add_input(matched.fused_batch_norm->input(4)); // 4: estimated_var
|
||||
|
||||
CopyFusedBatchNormAttributes(matched.fused_batch_norm, fused_op);
|
||||
|
||||
auto* attrs = fused_op->mutable_attr();
|
||||
SetAttrValue(matched.activation->op(), &(*attrs)["activation_mode"]);
|
||||
|
||||
if (matched.side_input != nullptr) {
|
||||
SetAttrValue(1, &(*attrs)["num_side_inputs"]);
|
||||
fused_op->add_input(matched.side_input->name()); // 5: side_input
|
||||
} else {
|
||||
SetAttrValue(0, &(*attrs)["num_side_inputs"]);
|
||||
}
|
||||
|
||||
// Turn activation node into Identity node.
|
||||
NodeDef* identity_op = optimized_graph->add_node();
|
||||
identity_op->set_op("Identity");
|
||||
identity_op->set_name(matched.activation->name());
|
||||
identity_op->set_device(matched.fused_batch_norm->device());
|
||||
identity_op->add_input(matched.fused_batch_norm->name());
|
||||
(*identity_op->mutable_attr())["T"] = attrs->at("T");
|
||||
|
||||
// Invalidate all nodes bypassed by this rewrite.
|
||||
invalidated_nodes->insert(matched.activation);
|
||||
invalidated_nodes->insert(matched.fused_batch_norm);
|
||||
if (matched.side_input != nullptr) {
|
||||
invalidated_nodes->insert(matched.invalidated);
|
||||
}
|
||||
}
|
||||
|
||||
void AddBatchNormNodes(const FusedBatchNorm& matched,
|
||||
GraphDef* optimized_graph) {
|
||||
const NodeDef& fused_node = *matched.fused_batch_norm;
|
||||
@ -1044,6 +1259,55 @@ void AddBatchNormNodes(const FusedBatchNorm& matched,
|
||||
*r->add_input() = a->name();
|
||||
*r->add_input() = c->name();
|
||||
}
|
||||
|
||||
// Check if a node is a candidate to one of the patterns that require inferred
|
||||
// shapes:
|
||||
// (1) Splitting FusedBatchNorm into primitives.
|
||||
// (2) Fusing side input and/or activation into FusedBatchNorm.
|
||||
bool RequiresInferredShapes(const RemapperContext& ctx, const NodeDef& node) {
|
||||
// Candidate for a FusedBatchNorm splitting.
|
||||
const auto is_batch_norm_candidate = [&]() -> bool {
|
||||
if (!IsFusedBatchNorm(node)) return false;
|
||||
if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false;
|
||||
|
||||
bool is_training = true;
|
||||
if (!GetNodeAttr(node, kIsTraining, &is_training).ok()) return false;
|
||||
if (is_training) return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
// Candidate for a FusedBatchNorm fusion.
|
||||
const auto is_batch_norm_fusion_candidate = [&]() -> bool {
|
||||
if (!IsRelu(node)) return false;
|
||||
|
||||
const auto relu_input_port = GraphView::InputPort(&node, 0);
|
||||
const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port);
|
||||
if (!relu_fanin.node) return false;
|
||||
|
||||
if (IsFusedBatchNorm(*relu_fanin.node)) {
|
||||
// FusedBatchNorm + Relu.
|
||||
return true;
|
||||
|
||||
} else if (IsAdd(*relu_fanin.node)) {
|
||||
// FusedBatchNorm + Add + Relu.
|
||||
const NodeDef* add = relu_fanin.node;
|
||||
|
||||
const auto add_input_port_0 = GraphView::InputPort(add, 0);
|
||||
const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
|
||||
if (IsFusedBatchNorm(*add_fanin_0.node)) return true;
|
||||
|
||||
const auto add_input_port_1 = GraphView::InputPort(add, 1);
|
||||
const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
|
||||
if (IsFusedBatchNorm(*add_fanin_1.node)) return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
return is_batch_norm_candidate() || is_batch_norm_fusion_candidate();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
|
||||
@ -1051,6 +1315,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
|
||||
// Supported graph patterns.
|
||||
// clang-format off
|
||||
FusedBatchNorm fused_batch_norm;
|
||||
FusedBatchNormEx fused_batch_norm_ex;
|
||||
ContractionWithBiasAdd contract_with_bias;
|
||||
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
||||
#ifndef INTEL_MKL
|
||||
@ -1078,9 +1343,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
|
||||
// and Activation nodes that were fused into a Conv2D node.
|
||||
absl::flat_hash_set<const NodeDef*> invalidated_nodes;
|
||||
|
||||
// _FusedMatMul and _FusedConv2D kernels do not have registered gradient
|
||||
// function, so we must not perform rewrite if the graph will be
|
||||
// differentiated later.
|
||||
// _Fused{...} kernels do not have registered gradient function, so we must
|
||||
// not perform rewrite if the graph will be differentiated later.
|
||||
bool allow_non_differentiable_rewrites =
|
||||
item.optimization_options().allow_non_differentiable_rewrites;
|
||||
|
||||
@ -1161,16 +1425,25 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
|
||||
#endif // !INTEL_MKL
|
||||
|
||||
// Infer properties lazily in case they are not needed.
|
||||
if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) {
|
||||
if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, node)) {
|
||||
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
|
||||
// TODO(rmlarsen): Get rid of tensor value copies.
|
||||
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
assume_valid_feeds,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_input_tensor_values=*/true,
|
||||
/*include_output_tensor_values=*/false));
|
||||
ctx.inferred_graph_properties = true;
|
||||
}
|
||||
|
||||
// Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindFusedBatchNormEx(ctx, &node, &fused_batch_norm_ex)) {
|
||||
AddFusedBatchNormExNode(fused_batch_norm_ex, optimized_graph,
|
||||
&invalidated_nodes);
|
||||
continue;
|
||||
}
|
||||
|
||||
// During inference, most of the inputs to FusedBatchNorm are constant, and
|
||||
// we can therefore replace the op with a much cheaper set of primitives.
|
||||
if (FindFusedBatchNorm(ctx, &node, &fused_batch_norm)) {
|
||||
|
@ -26,7 +26,7 @@ namespace grappler {
|
||||
// nodes to decrease the amount of operations needed to perform a computation.
|
||||
class Remapper : public GraphOptimizer {
|
||||
public:
|
||||
explicit Remapper(RewriterConfig::Toggle opt_level) {}
|
||||
explicit Remapper(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {}
|
||||
|
||||
~Remapper() override {}
|
||||
|
||||
@ -37,6 +37,9 @@ class Remapper : public GraphOptimizer {
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimized_graph, double result) override;
|
||||
|
||||
private:
|
||||
RewriterConfig::Toggle opt_level_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -22,11 +22,20 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cudnn/cudnn.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class RemapperTest : public GrapplerTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
// This is a requirement for fusing FusedBatchNorm + SideInput + Activation.
|
||||
setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */);
|
||||
}
|
||||
|
||||
// TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
|
||||
// contractions with non-default contraction output kernels.
|
||||
bool EigenSupportsContractionOutputKernel() {
|
||||
@ -102,6 +111,179 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(RemapperTest, FuseBatchNormWithRelu) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||
LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
|
||||
"CUDNN_VERSION >= 7402.";
|
||||
#else
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
|
||||
auto channels_shape = ops::Placeholder::Shape({24});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
|
||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
|
||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
|
||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
|
||||
|
||||
float epsilon = 0.1f;
|
||||
auto fbn = ops::FusedBatchNormV3(
|
||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
|
||||
"NHWC"));
|
||||
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
|
||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||
|
||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
|
||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
|
||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
||||
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_t},
|
||||
{"scale", scale_t},
|
||||
{"offset", offset_t},
|
||||
{"mean", mean_t},
|
||||
{"var", var_t}};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on GPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
||||
GraphDef output;
|
||||
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "relu") {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ("fused_batch_norm", node.input(0));
|
||||
found++;
|
||||
}
|
||||
if (node.name() == "fused_batch_norm") {
|
||||
EXPECT_EQ("_FusedBatchNormEx", node.op());
|
||||
EXPECT_EQ("input_cast", node.input(0));
|
||||
EXPECT_EQ("scale", node.input(1));
|
||||
EXPECT_EQ("offset", node.input(2));
|
||||
EXPECT_EQ("mean", node.input(3));
|
||||
EXPECT_EQ("var", node.input(4));
|
||||
|
||||
auto attr = node.attr();
|
||||
EXPECT_EQ(0, attr["num_side_inputs"].i());
|
||||
EXPECT_EQ("Relu", attr["activation_mode"].s());
|
||||
found++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(2, found);
|
||||
|
||||
if (GetNumAvailableGPUs() > 0) {
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
EXPECT_EQ(1, tensors_expected.size());
|
||||
EXPECT_EQ(1, tensors.size());
|
||||
test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2);
|
||||
}
|
||||
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||
}
|
||||
|
||||
TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
|
||||
"CUDNN_VERSION >= 7402.";
|
||||
#else
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
|
||||
auto channels_shape = ops::Placeholder::Shape({24});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
|
||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
|
||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
|
||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
|
||||
auto side_input =
|
||||
Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
|
||||
auto side_input_cast =
|
||||
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
|
||||
|
||||
float epsilon = 0.1f;
|
||||
auto fbn = ops::FusedBatchNormV3(
|
||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
|
||||
"NHWC"));
|
||||
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
|
||||
auto relu = ops::Relu(s.WithOpName("relu"), add);
|
||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||
|
||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
|
||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
|
||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
||||
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
||||
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_t}, {"scale", scale_t},
|
||||
{"offset", offset_t}, {"mean", mean_t},
|
||||
{"var", var_t}, {"side_input", side_input_t}};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on GPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
||||
GraphDef output;
|
||||
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "relu") {
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ("fused_batch_norm", node.input(0));
|
||||
found++;
|
||||
}
|
||||
if (node.name() == "fused_batch_norm") {
|
||||
EXPECT_EQ("_FusedBatchNormEx", node.op());
|
||||
EXPECT_EQ("input_cast", node.input(0));
|
||||
EXPECT_EQ("scale", node.input(1));
|
||||
EXPECT_EQ("offset", node.input(2));
|
||||
EXPECT_EQ("mean", node.input(3));
|
||||
EXPECT_EQ("var", node.input(4));
|
||||
EXPECT_EQ("side_input_cast", node.input(5));
|
||||
|
||||
auto attr = node.attr();
|
||||
EXPECT_EQ(1, attr["num_side_inputs"].i());
|
||||
EXPECT_EQ("Relu", attr["activation_mode"].s());
|
||||
found++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(2, found);
|
||||
|
||||
if (GetNumAvailableGPUs() > 0) {
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
EXPECT_EQ(1, tensors_expected.size());
|
||||
EXPECT_EQ(1, tensors.size());
|
||||
test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2);
|
||||
}
|
||||
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||
}
|
||||
|
||||
TEST_F(RemapperTest, FuseConv2DWithBias) {
|
||||
if (!EigenSupportsContractionOutputKernel()) return;
|
||||
|
||||
|
@ -1814,6 +1814,7 @@ tf_cuda_cc_test(
|
||||
":ops_util",
|
||||
":relu_op",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//tensorflow/core:framework",
|
||||
@ -1825,6 +1826,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/stream_executor/cuda:cudnn_plugin",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -14,8 +14,10 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
@ -29,6 +31,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cudnn/cudnn.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
template <typename T, typename U>
|
||||
@ -39,25 +45,46 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
|
||||
}
|
||||
|
||||
protected:
|
||||
struct FusedBatchNormOutputs {
|
||||
Tensor y;
|
||||
Tensor batch_mean;
|
||||
Tensor batch_variance;
|
||||
Tensor reserve_space_1;
|
||||
Tensor reserve_space_2;
|
||||
Tensor reserve_space_3;
|
||||
};
|
||||
|
||||
struct FusedBatchNormGradOutputs {
|
||||
Tensor y_backprop;
|
||||
Tensor x_backprop;
|
||||
Tensor scale_backprop;
|
||||
Tensor offset_backprop;
|
||||
Tensor reserve_space_4;
|
||||
Tensor reserve_space_5;
|
||||
};
|
||||
|
||||
using GraphRunner = std::function<void(
|
||||
const Tensor& input_data, const Tensor& scale_data,
|
||||
const Tensor& offset_data, const Tensor& mean_data,
|
||||
const Tensor& var_data, const Tensor& side_input_data, Tensor* out)>;
|
||||
const Tensor& y_backprop, const Tensor& input_data,
|
||||
const Tensor& scale_data, const Tensor& offset_data,
|
||||
const Tensor& mean_data, const Tensor& var_data,
|
||||
const Tensor& side_input_data, FusedBatchNormOutputs* forward,
|
||||
FusedBatchNormGradOutputs* backward)>;
|
||||
|
||||
// Runs a Tensorflow graph defined by the root scope, and fetches the result
|
||||
// of 'fetch' node into the output Tensor. Optional `fetch_node` parameter
|
||||
// allows to define a fetch node directly using a NodeDef for the ops that are
|
||||
// of 'fetch' node into the outputs. Optional `add_nodes` parameter
|
||||
// allows to define nodes directly using a NodeDef for the ops that are
|
||||
// not supported by the C++ Api.
|
||||
// TODO(ezhulenev): RunAndFetch defined in FusedConv2D and FusedMatMul tests.
|
||||
// Add a base class for all FusedABC kernels and remove code duplication.
|
||||
void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
|
||||
Tensor* output, bool allow_gpu_device,
|
||||
const NodeDef* fetch_node = nullptr) {
|
||||
void RunAndFetch(const tensorflow::Scope& root,
|
||||
const std::vector<string>& fetch,
|
||||
std::vector<Tensor>* outputs, bool allow_gpu_device,
|
||||
const std::vector<const NodeDef*> add_nodes = {}) {
|
||||
tensorflow::GraphDef graph;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph));
|
||||
|
||||
if (fetch_node) {
|
||||
*graph.add_node() = *fetch_node;
|
||||
for (const NodeDef* add_node : add_nodes) {
|
||||
*graph.add_node() = *add_node;
|
||||
}
|
||||
|
||||
// We really want to make sure that graph executed exactly as we passed it
|
||||
@ -101,64 +128,22 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(session->Create(graph));
|
||||
|
||||
std::vector<Tensor> unfused_tensors;
|
||||
TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));
|
||||
|
||||
*output = unfused_tensors[0];
|
||||
TF_ASSERT_OK(session->Run({}, fetch, {}, outputs));
|
||||
}
|
||||
|
||||
void RunFusedBatchNorm(const Tensor& input_data, const Tensor& scale_data,
|
||||
void RunFusedBatchNorm(const Tensor& y_backprop_data,
|
||||
const Tensor& input_data, const Tensor& scale_data,
|
||||
const Tensor& offset_data, const Tensor& mean_data,
|
||||
const Tensor& var_data, const Tensor& side_input_data,
|
||||
const TensorFormat data_format, bool is_training,
|
||||
bool has_side_input, const string& activation_mode,
|
||||
Tensor* output, float epsilon = 0.1f) {
|
||||
FusedBatchNormOutputs* forward,
|
||||
FusedBatchNormGradOutputs* backward,
|
||||
float epsilon = 0.1f) {
|
||||
Scope root = tensorflow::Scope::NewRootScope();
|
||||
|
||||
ops::FusedBatchNormV2 fbn = ops::FusedBatchNormV2(
|
||||
root.WithOpName("fused_batch_norm"),
|
||||
ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
|
||||
ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data)),
|
||||
ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data)),
|
||||
ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data)),
|
||||
ops::Const(root.WithOpName("var"), Input::Initializer(var_data)),
|
||||
ops::FusedBatchNormV2::IsTraining(is_training)
|
||||
.Epsilon(epsilon)
|
||||
.DataFormat(ToString(data_format)));
|
||||
|
||||
Output with_side_input;
|
||||
if (has_side_input) {
|
||||
with_side_input =
|
||||
ops::Add(root.WithOpName("with_side_input"), fbn.y,
|
||||
ops::Const(root.WithOpName("side_input"),
|
||||
Input::Initializer(side_input_data)));
|
||||
} else {
|
||||
with_side_input =
|
||||
ops::Identity(root.WithOpName("with_side_input"), fbn.y);
|
||||
}
|
||||
|
||||
if (activation_mode == "Relu") {
|
||||
ops::Relu(root.WithOpName("with_activation"), with_side_input);
|
||||
} else {
|
||||
ops::Identity(root.WithOpName("with_activation"), with_side_input);
|
||||
}
|
||||
|
||||
RunAndFetch(root, "with_activation", output, /*allow_gpu_device=*/true);
|
||||
}
|
||||
|
||||
void RunFusedBatchNormEx(const Tensor& input_data, const Tensor& scale_data,
|
||||
const Tensor& offset_data, const Tensor& mean_data,
|
||||
const Tensor& var_data,
|
||||
const Tensor& side_input_data,
|
||||
const TensorFormat data_format, bool is_training,
|
||||
bool has_side_input, const string& activation_mode,
|
||||
Tensor* output, float epsilon = 0.1f) {
|
||||
Scope root = tensorflow::Scope::NewRootScope();
|
||||
|
||||
DataType t_dtype = DataTypeToEnum<T>::v();
|
||||
DataType u_dtype = DataTypeToEnum<U>::v();
|
||||
|
||||
Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
|
||||
Input::Initializer(y_backprop_data));
|
||||
Output input =
|
||||
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
|
||||
Output scale =
|
||||
@ -172,6 +157,101 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
|
||||
Output side_input = ops::Const(root.WithOpName("side_input"),
|
||||
Input::Initializer(side_input_data));
|
||||
|
||||
ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3(
|
||||
root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var,
|
||||
ops::FusedBatchNormV3::IsTraining(is_training)
|
||||
.Epsilon(epsilon)
|
||||
.DataFormat(ToString(data_format)));
|
||||
|
||||
Output with_side_input;
|
||||
if (has_side_input) {
|
||||
with_side_input =
|
||||
ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input);
|
||||
} else {
|
||||
with_side_input =
|
||||
ops::Identity(root.WithOpName("with_side_input"), fwd.y);
|
||||
}
|
||||
|
||||
Output activation;
|
||||
if (activation_mode == "Relu") {
|
||||
activation =
|
||||
ops::Relu(root.WithOpName("with_activation"), with_side_input);
|
||||
} else {
|
||||
activation =
|
||||
ops::Identity(root.WithOpName("with_activation"), with_side_input);
|
||||
}
|
||||
|
||||
Output activation_grad;
|
||||
if (activation_mode == "Relu") {
|
||||
activation_grad = ops::internal::ReluGrad(
|
||||
root.WithOpName("activation_grad"), y_backprop, activation);
|
||||
} else {
|
||||
activation_grad =
|
||||
ops::Identity(root.WithOpName("activation_grad"), y_backprop);
|
||||
}
|
||||
|
||||
ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3(
|
||||
root.WithOpName("fused_batch_norm_grad"), activation_grad, input, scale,
|
||||
fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3,
|
||||
ops::FusedBatchNormGradV3::IsTraining(is_training)
|
||||
.Epsilon(epsilon)
|
||||
.DataFormat(ToString(data_format)));
|
||||
|
||||
std::vector<Tensor> out_tensors;
|
||||
RunAndFetch(
|
||||
root,
|
||||
{"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2",
|
||||
"fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5",
|
||||
"activation_grad:0", "fused_batch_norm_grad:0",
|
||||
"fused_batch_norm_grad:1", "fused_batch_norm_grad:2"},
|
||||
&out_tensors, /*allow_gpu_device=*/true);
|
||||
|
||||
forward->y = out_tensors[0];
|
||||
forward->batch_mean = out_tensors[1];
|
||||
forward->batch_variance = out_tensors[2];
|
||||
forward->reserve_space_1 = out_tensors[3];
|
||||
forward->reserve_space_2 = out_tensors[4];
|
||||
forward->reserve_space_3 = out_tensors[5];
|
||||
|
||||
backward->y_backprop = out_tensors[6];
|
||||
backward->x_backprop = out_tensors[7];
|
||||
backward->scale_backprop = out_tensors[8];
|
||||
backward->offset_backprop = out_tensors[9];
|
||||
}
|
||||
|
||||
void RunFusedBatchNormEx(const Tensor& y_backprop_data,
|
||||
const Tensor& input_data, const Tensor& scale_data,
|
||||
const Tensor& offset_data, const Tensor& mean_data,
|
||||
const Tensor& var_data,
|
||||
const Tensor& side_input_data,
|
||||
const TensorFormat data_format, bool is_training,
|
||||
bool has_side_input, const string& activation_mode,
|
||||
FusedBatchNormOutputs* forward,
|
||||
FusedBatchNormGradOutputs* backward,
|
||||
float epsilon = 0.1f) {
|
||||
Scope root = tensorflow::Scope::NewRootScope();
|
||||
|
||||
DataType t_dtype = DataTypeToEnum<T>::v();
|
||||
DataType u_dtype = DataTypeToEnum<U>::v();
|
||||
|
||||
Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
|
||||
Input::Initializer(y_backprop_data));
|
||||
Output input =
|
||||
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
|
||||
Output scale =
|
||||
ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
|
||||
Output offset =
|
||||
ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
|
||||
Output mean =
|
||||
ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
|
||||
Output var =
|
||||
ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
|
||||
Output side_input = ops::Const(root.WithOpName("side_input"),
|
||||
Input::Initializer(side_input_data));
|
||||
Output empty =
|
||||
ops::Const(root.WithOpName("empty"),
|
||||
Input::Initializer(Tensor(DataTypeToEnum<U>::value, {0})));
|
||||
|
||||
int num_side_inputs = 0;
|
||||
std::vector<NodeDefBuilder::NodeOut> side_inputs;
|
||||
|
||||
@ -197,8 +277,58 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
|
||||
.Attr("is_training", is_training)
|
||||
.Finalize(&fused_batch_norm_ex));
|
||||
|
||||
RunAndFetch(root, fused_batch_norm_ex.name(), output,
|
||||
/*allow_gpu_device=*/true, &fused_batch_norm_ex);
|
||||
NodeDef activation_grad;
|
||||
if (activation_mode == "Relu") {
|
||||
TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad")
|
||||
.Input({y_backprop.name(), 0, t_dtype})
|
||||
.Input({fused_batch_norm_ex.name(), 0, t_dtype})
|
||||
.Attr("T", t_dtype)
|
||||
.Finalize(&activation_grad));
|
||||
} else {
|
||||
TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity")
|
||||
.Input({y_backprop.name(), 0, t_dtype})
|
||||
.Attr("T", t_dtype)
|
||||
.Finalize(&activation_grad));
|
||||
}
|
||||
|
||||
NodeDef fused_batch_norm_grad;
|
||||
TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3")
|
||||
.Input({activation_grad.name(), 0, t_dtype})
|
||||
.Input({input.name(), 0, t_dtype})
|
||||
.Input({scale.name(), 0, u_dtype})
|
||||
.Input({fused_batch_norm_ex.name(), 3, u_dtype})
|
||||
.Input({fused_batch_norm_ex.name(), 4, u_dtype})
|
||||
.Input({fused_batch_norm_ex.name(), 5, u_dtype})
|
||||
.Attr("T", t_dtype)
|
||||
.Attr("U", u_dtype)
|
||||
.Attr("data_format", ToString(data_format))
|
||||
.Attr("epsilon", epsilon)
|
||||
.Attr("is_training", is_training)
|
||||
.Finalize(&fused_batch_norm_grad));
|
||||
|
||||
std::vector<Tensor> out_tensors;
|
||||
RunAndFetch(
|
||||
root,
|
||||
{"fused_batch_norm_ex:0", "fused_batch_norm_ex:1",
|
||||
"fused_batch_norm_ex:2", "fused_batch_norm_ex:3",
|
||||
"fused_batch_norm_ex:4", "fused_batch_norm_ex:5", "activation_grad:0",
|
||||
"fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
|
||||
"fused_batch_norm_grad:2"},
|
||||
&out_tensors,
|
||||
/*allow_gpu_device=*/true,
|
||||
{&fused_batch_norm_ex, &activation_grad, &fused_batch_norm_grad});
|
||||
|
||||
forward->y = out_tensors[0];
|
||||
forward->batch_mean = out_tensors[1];
|
||||
forward->batch_variance = out_tensors[2];
|
||||
forward->reserve_space_1 = out_tensors[3];
|
||||
forward->reserve_space_2 = out_tensors[4];
|
||||
forward->reserve_space_3 = out_tensors[5];
|
||||
|
||||
backward->y_backprop = out_tensors[6];
|
||||
backward->x_backprop = out_tensors[7];
|
||||
backward->scale_backprop = out_tensors[8];
|
||||
backward->offset_backprop = out_tensors[9];
|
||||
}
|
||||
|
||||
void VerifyTensorsNear(int batch, int height, int width, int channels,
|
||||
@ -215,6 +345,7 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
|
||||
|
||||
Tensor input(t_dtype, input_shape);
|
||||
input.flat<T>().setRandom();
|
||||
input.flat<T>() -= input.flat<T>().constant(static_cast<T>(0.5));
|
||||
|
||||
Tensor scale(u_dtype, {channels});
|
||||
scale.flat<U>().setRandom();
|
||||
@ -228,29 +359,60 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
|
||||
Tensor var(u_dtype, {channels});
|
||||
var.flat<U>().setRandom();
|
||||
|
||||
Tensor empty(u_dtype, {0});
|
||||
|
||||
Tensor fused_batch_norm;
|
||||
Tensor fused_batch_norm_ex;
|
||||
|
||||
Tensor side_input(t_dtype, input_shape);
|
||||
side_input.flat<T>().setRandom();
|
||||
side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0));
|
||||
|
||||
run_default(input, scale, offset, is_training ? empty : mean,
|
||||
is_training ? empty : var, side_input, &fused_batch_norm);
|
||||
Tensor y_backprop(t_dtype, input_shape);
|
||||
y_backprop.flat<T>().setRandom();
|
||||
y_backprop.flat<T>() -= y_backprop.flat<T>().constant(static_cast<T>(0.5));
|
||||
|
||||
// Write some garbage to the `fused_batch_norm_ex` first to make sure
|
||||
// that fused kernel actually writes correct results to memory.
|
||||
run_default(side_input, scale, offset, is_training ? empty : mean,
|
||||
is_training ? empty : var, input, &fused_batch_norm_ex);
|
||||
Tensor empty(u_dtype, {0});
|
||||
|
||||
run_fused(input, scale, offset, is_training ? empty : mean,
|
||||
is_training ? empty : var, side_input, &fused_batch_norm_ex);
|
||||
FusedBatchNormOutputs fbn_forward;
|
||||
FusedBatchNormOutputs fbn_ex_forward;
|
||||
|
||||
ASSERT_EQ(fused_batch_norm.dtype(), fused_batch_norm_ex.dtype());
|
||||
ASSERT_EQ(fused_batch_norm.shape(), fused_batch_norm_ex.shape());
|
||||
FusedBatchNormGradOutputs fbn_backward;
|
||||
FusedBatchNormGradOutputs fbn_ex_backward;
|
||||
|
||||
test::ExpectClose(fused_batch_norm, fused_batch_norm_ex, 1e-2);
|
||||
run_default(y_backprop, input, scale, offset, is_training ? empty : mean,
|
||||
is_training ? empty : var, side_input, &fbn_forward,
|
||||
&fbn_backward);
|
||||
|
||||
// Write some garbage to the `fbn_ex_forward` and `fbn_ex_backward` first to
|
||||
// make sure that fused kernel actually writes correct results to memory.
|
||||
run_default(y_backprop, side_input, scale, offset,
|
||||
is_training ? empty : mean, is_training ? empty : var, input,
|
||||
&fbn_ex_forward, &fbn_ex_backward);
|
||||
|
||||
run_fused(y_backprop, input, scale, offset, is_training ? empty : mean,
|
||||
is_training ? empty : var, side_input, &fbn_ex_forward,
|
||||
&fbn_ex_backward);
|
||||
|
||||
std::vector<std::pair<Tensor, Tensor>> tensor_pairs = {
|
||||
{fbn_forward.y, fbn_ex_forward.y},
|
||||
{fbn_forward.batch_mean, fbn_ex_forward.batch_mean},
|
||||
{fbn_forward.batch_variance, fbn_ex_forward.batch_variance},
|
||||
{fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1},
|
||||
{fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2},
|
||||
// NOTE(ezhulenev): We deliberately do not check `reserved_space_3`
|
||||
// because BatchNormEx with fused side input has different data in it,
|
||||
// but we make sure that final gradients are the same.
|
||||
{fbn_backward.y_backprop, fbn_ex_backward.y_backprop},
|
||||
{fbn_backward.x_backprop, fbn_ex_backward.x_backprop},
|
||||
{fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop},
|
||||
{fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop},
|
||||
};
|
||||
|
||||
for (auto& pair : tensor_pairs) {
|
||||
const Tensor& fbn = pair.first;
|
||||
const Tensor& fbn_ex = pair.second;
|
||||
|
||||
ASSERT_EQ(fbn.dtype(), fbn_ex.dtype());
|
||||
ASSERT_EQ(fbn.shape(), fbn_ex.shape());
|
||||
|
||||
test::ExpectClose(fbn, fbn_ex, 1e-2);
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is
|
||||
@ -260,25 +422,27 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
|
||||
bool has_side_input,
|
||||
const string& activation_mode) {
|
||||
const GraphRunner run_default =
|
||||
[&](const Tensor& input_data, const Tensor& scale_data,
|
||||
const Tensor& offset_data, const Tensor& mean_data,
|
||||
const Tensor& var_data, const Tensor& side_input_data,
|
||||
Tensor* out) {
|
||||
this->RunFusedBatchNorm(input_data, scale_data, offset_data,
|
||||
mean_data, var_data, side_input_data,
|
||||
data_format, is_training, has_side_input,
|
||||
activation_mode, out);
|
||||
[&](const Tensor& y_backprop, const Tensor& input_data,
|
||||
const Tensor& scale_data, const Tensor& offset_data,
|
||||
const Tensor& mean_data, const Tensor& var_data,
|
||||
const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
|
||||
FusedBatchNormGradOutputs* bwd) {
|
||||
this->RunFusedBatchNorm(y_backprop, input_data, scale_data,
|
||||
offset_data, mean_data, var_data,
|
||||
side_input_data, data_format, is_training,
|
||||
has_side_input, activation_mode, fwd, bwd);
|
||||
};
|
||||
|
||||
const GraphRunner run_inference =
|
||||
[&](const Tensor& input_data, const Tensor& scale_data,
|
||||
const Tensor& offset_data, const Tensor& mean_data,
|
||||
const Tensor& var_data, const Tensor& side_input_data,
|
||||
Tensor* out) {
|
||||
this->RunFusedBatchNormEx(input_data, scale_data, offset_data,
|
||||
mean_data, var_data, side_input_data,
|
||||
data_format, is_training, has_side_input,
|
||||
activation_mode, out);
|
||||
[&](const Tensor& y_backprop, const Tensor& input_data,
|
||||
const Tensor& scale_data, const Tensor& offset_data,
|
||||
const Tensor& mean_data, const Tensor& var_data,
|
||||
const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
|
||||
FusedBatchNormGradOutputs* bwd) {
|
||||
this->RunFusedBatchNormEx(y_backprop, input_data, scale_data,
|
||||
offset_data, mean_data, var_data,
|
||||
side_input_data, data_format, is_training,
|
||||
has_side_input, activation_mode, fwd, bwd);
|
||||
};
|
||||
|
||||
VerifyTensorsNear(batch, height, width, channels, data_format, is_training,
|
||||
@ -297,17 +461,17 @@ constexpr bool kWithSideInput = true; // side_input == true
|
||||
TYPED_TEST_SUITE_P(FusedBatchNormExOpTest);
|
||||
|
||||
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingInNHWCTest) {
|
||||
this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
|
||||
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
|
||||
kNoSideInput, "Identity");
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithReluInNHWCTest) {
|
||||
this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
|
||||
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
|
||||
kNoSideInput, "Relu");
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithSideInputAndReluInNHWCTest) {
|
||||
this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
|
||||
this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
|
||||
kWithSideInput, "Relu");
|
||||
}
|
||||
|
||||
|
@ -692,8 +692,8 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
|
||||
<< " y_backprop shape: " << y_backprop.shape().DebugString()
|
||||
<< " x shape: " << x.shape().DebugString()
|
||||
<< " scale shape: " << scale.shape().DebugString()
|
||||
<< " tensor format: " << tensor_format
|
||||
<< " compute format: " << compute_format;
|
||||
<< " tensor format: " << ToString(tensor_format)
|
||||
<< " compute format: " << ToString(compute_format);
|
||||
|
||||
// Inputs
|
||||
Tensor y_backprop_maybe_transformed = y_backprop;
|
||||
|
Loading…
Reference in New Issue
Block a user