commit 92144e5bd6
parent 17a2326611

[Grappler] Fuse FusedBatchNorm + <SideInput> + <Activation>.

Resnet50 in NHWC: ~500 img/sec -> ~800 img/sec (*)
(*) with disabled layout optimizer.

PiperOrigin-RevId: 252108889
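The change teaches Grappler's remapper to rewrite a FusedBatchNorm followed by
an optional Add of a side input and a Relu activation into a single
_FusedBatchNormEx node, with matching XLA kernel and test support. For
reference, the fused op computes, per element,

    y = <activation>(scale * (x - mean) / sqrt(var + epsilon) + offset [+ side_input])

A minimal standalone sketch of that math (plain C++, not TensorFlow code;
scalar inputs are assumed for clarity):

    // Illustrative only: models one element of _FusedBatchNormEx.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    float FusedBatchNormEx(float x, float scale, float offset, float mean,
                           float var, float side_input, float epsilon,
                           bool add_side_input, bool apply_relu) {
      // Normalize, then optionally add the side input, then optionally Relu.
      float y = scale * (x - mean) / std::sqrt(var + epsilon) + offset;
      if (add_side_input) y += side_input;
      return apply_relu ? std::max(0.0f, y) : y;
    }

    int main() {
      // x=2, mean=1, var=1, scale=1, offset=0, side_input=-3, epsilon=0:
      // normalization gives 1, the side input takes it to -2, Relu clamps to 0.
      std::printf("%f\n", FusedBatchNormEx(2.f, 1.f, 0.f, 1.f, 1.f, -3.f, 0.f,
                                           /*add_side_input=*/true,
                                           /*apply_relu=*/true));
    }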
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 // XLA implementation of BatchNorm operations.
+#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -28,7 +29,11 @@ namespace {
 
 class FusedBatchNormOp : public XlaOpKernel {
  public:
-  explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+  explicit FusedBatchNormOp(OpKernelConstruction* ctx)
+      : FusedBatchNormOp(ctx, false) {}
+
+  FusedBatchNormOp(OpKernelConstruction* ctx, bool is_batch_norm_ex)
+      : XlaOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
     string data_format_str;
@@ -36,6 +41,26 @@ class FusedBatchNormOp : public XlaOpKernel {
     OP_REQUIRES(
         ctx, FormatFromString(data_format_str, &data_format_),
         errors::InvalidArgument("Invalid data format: ", data_format_str));
 
+    if (is_batch_norm_ex) {
+      int num_side_inputs;
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("num_side_inputs", &num_side_inputs));
+      OP_REQUIRES(ctx, num_side_inputs >= 0 && num_side_inputs <= 1,
+                  errors::InvalidArgument(
+                      "FusedBatchNormEx supports at most 1 side input."));
+      add_side_input_ = (num_side_inputs == 1);
+
+      string activation_mode;
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode));
+      OP_REQUIRES(ctx,
+                  activation_mode == "Identity" || activation_mode == "Relu",
+                  errors::InvalidArgument(
+                      "Unsupported FusedBatchNormEx activation mode: ",
+                      activation_mode));
+      apply_relu_ = (activation_mode == "Relu");
+    } else {
+      add_side_input_ = false;
+      apply_relu_ = false;
+    }
     is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
   }
 
@@ -66,9 +91,18 @@ class FusedBatchNormOp : public XlaOpKernel {
           input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
 
       // In training mode, outputs the normalized value as well as the
-      // calculated mean and variance.
-      ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0),
-                                                input_type));
+      // calculated mean and variance. Optionally we add side input and apply
+      // relu activation.
+      xla::XlaOp converted =
+          xla::ConvertElementType(xla::GetTupleElement(output, 0), input_type);
+      if (add_side_input_ && apply_relu_) {
+        ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
+      } else if (apply_relu_) {
+        ctx->SetOutput(0, xla::Relu(converted));
+      } else {
+        ctx->SetOutput(0, converted);
+      }
+
       ctx->SetOutput(1, xla::GetTupleElement(output, 1));
       xla::XlaOp variance = xla::GetTupleElement(output, 2);
       // Apply Bessel's correction.
@@ -103,7 +137,16 @@ class FusedBatchNormOp : public XlaOpKernel {
       xla::XlaOp output = xla::BatchNormInference(
          input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
          epsilon_, feature_index);
-      ctx->SetOutput(0, xla::ConvertElementType(output, input_type));
+
+      xla::XlaOp converted = xla::ConvertElementType(output, input_type);
+      if (add_side_input_ && apply_relu_) {
+        ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
+      } else if (apply_relu_) {
+        ctx->SetOutput(0, xla::Relu(converted));
+      } else {
+        ctx->SetOutput(0, converted);
+      }
+
       // Directly send input to output as mean and variance in inference mode.
       ctx->SetOutput(1, ctx->Input(3));
       ctx->SetOutput(2, ctx->Input(4));
@@ -116,6 +159,8 @@ class FusedBatchNormOp : public XlaOpKernel {
   float epsilon_;
   TensorFormat data_format_;
   bool is_training_;
+  bool add_side_input_;
+  bool apply_relu_;
   bool is_on_gpu_;
 };
 
@@ -131,17 +176,26 @@ class FusedBatchNormOpV3 : public FusedBatchNormOp {
     }
     ctx->SetConstantOutput(5, Tensor());
   }
+};
 
- private:
-  float epsilon_;
-  TensorFormat data_format_;
-  bool is_training_;
-  bool is_on_gpu_;
+class FusedBatchNormOpEx : public FusedBatchNormOp {
+ public:
+  explicit FusedBatchNormOpEx(OpKernelConstruction* ctx)
+      : FusedBatchNormOp(ctx, /*is_batch_norm_ex=*/true) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    FusedBatchNormOp::CompileImpl(ctx);
+    if (!ctx->status().ok()) {
+      return;
+    }
+    ctx->SetConstantOutput(5, Tensor());
+  }
 };
 
 REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
 REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
 REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
+REGISTER_XLA_OP(Name("_FusedBatchNormEx"), FusedBatchNormOpEx);
 
 class FusedBatchNormGradOp : public XlaOpKernel {
  public:
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -781,6 +781,7 @@ tf_kernel_library(
         ":constant_folding",
         ":graph_optimizer",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/grappler:grappler_item",
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -26,6 +26,11 @@ limitations under the License.
 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/env_var.h"
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cudnn/cudnn.h"
+#endif  // GOOGLE_CUDA
 
 namespace tensorflow {
 namespace grappler {
@@ -40,12 +45,17 @@ namespace grappler {
 // MatMul + ... -> _FusedMatMul:
 //   (1) MatMul + BiasAdd + <Activation>
 //
+// FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training]
+//   (1) FusedBatchNorm + <Activation>
+//   (2) FusedBatchNorm + SideInput + <Activation>
+//
 // Both Conv2D and MatMul implemented as Tensor contraction (on CPU), so all the
 // patterns are "ContractionWith...".
 namespace {
 
 constexpr char kFusedConv2D[] = "_FusedConv2D";
 constexpr char kFusedMatMul[] = "_FusedMatMul";
+constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
 
 constexpr char kDataFormat[] = "data_format";
 constexpr char kIsTraining[] = "is_training";
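Both new patterns are rooted at the activation node and matched by walking
backwards through an optional Add, as implemented in FindFusedBatchNormEx
further below. A toy illustration of that walk (standalone C++; the Node type
and hand-built graph are hypothetical stand-ins for Grappler's
NodeDef/GraphView):

    #include <iostream>
    #include <string>
    #include <vector>

    struct Node {
      std::string op;
      std::vector<const Node*> inputs;
    };

    // Returns the matched FusedBatchNorm for Relu(FusedBatchNorm) or
    // Relu(Add(FusedBatchNorm, side)), and reports the side input if any.
    const Node* MatchFusedBatchNormEx(const Node& relu, const Node** side) {
      if (relu.op != "Relu" || relu.inputs.empty()) return nullptr;
      const Node* in = relu.inputs[0];
      if (in->op == "FusedBatchNorm") {  // Pattern (1): no side input.
        *side = nullptr;
        return in;
      }
      if (in->op == "Add" && in->inputs.size() == 2) {  // Pattern (2).
        for (int i = 0; i < 2; ++i) {
          if (in->inputs[i]->op == "FusedBatchNorm") {
            *side = in->inputs[1 - i];
            return in->inputs[i];
          }
        }
      }
      return nullptr;
    }

    int main() {
      Node fbn{"FusedBatchNorm", {}}, side{"Placeholder", {}};
      Node add{"Add", {&fbn, &side}};
      Node relu{"Relu", {&add}};
      const Node* side_input = nullptr;
      std::cout << (MatchFusedBatchNormEx(relu, &side_input) == &fbn) << "\n";
    }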
@@ -81,6 +91,17 @@ struct FusedBatchNorm {
   const NodeDef* fused_batch_norm = nullptr;
 };
 
+// FusedBatchNorm[$is_training] with fused side input and/or activation.
+struct FusedBatchNormEx {
+  FusedBatchNormEx() = default;
+
+  const NodeDef* fused_batch_norm = nullptr;
+  const NodeDef* side_input = nullptr;
+  const NodeDef* activation = nullptr;
+  // Add node that will be invalidated by fusing side input and fused batch norm
+  const NodeDef* invalidated = nullptr;
+};
+
 // Contraction node followed by a BiasAdd.
 struct ContractionWithBiasAdd {
   ContractionWithBiasAdd() = default;
@@ -445,12 +466,11 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx,
                              ContractionWithBatchNorm* matched) {
   if (!EigenSupportsContractionOutputKernel()) return false;
 
-  // Root of the pattern must be a FusedBatchNorm or a FusedBatchNormV2.
+  // Root of the pattern must be a FusedBatchNorm.
   if (!batch_norm || !IsFusedBatchNorm(*batch_norm)) return false;
 
-  // V2 has a separate data type for the scale/offset/mean/variance inputs.
-  if ((batch_norm->op() == "FusedBatchNormV2" ||
-       batch_norm->op() == "FusedBatchNormV3") &&
+  // FusedBatchNormV2 and V3 have an extra type parameter.
+  if (batch_norm->op() != "FusedBatchNorm" &&
       !HasDataType(batch_norm, DT_FLOAT, "U"))
     return false;
 
@@ -602,23 +622,15 @@ bool FindContractionWithBiasAndAddActivation(
 }
 #endif
 
-// Check that given node meets some basic FusedBatchNorm optimization
-// preconditions. We use this check to lazily infer graph properties which is
-// rather expensive.
-bool IsFusedBatchNormCandidate(const NodeDef& node) {
-  if (!IsFusedBatchNorm(node)) return false;
-  if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false;
-
-  // Check that the node is in inference mode.
-  const auto& attr = node.attr();
-  if (attr.count(kIsTraining) > 0 && attr.at(kIsTraining).b()) return false;
-
-  return true;
-}
-
 bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
                         FusedBatchNorm* matched) {
-  if (!IsFusedBatchNormCandidate(*node)) return false;
+  if (!IsFusedBatchNorm(*node)) return false;
+  if (GetDataTypeFromAttr(*node, "T") != DT_FLOAT) return false;
+
+  // Check that the node is in inference mode.
+  bool is_training = true;
+  if (!GetNodeAttr(*node, kIsTraining, &is_training).ok()) return false;
+  if (is_training) return false;
 
   const auto& props = ctx.graph_properties.GetInputProperties(node->name());
 
@@ -649,6 +661,139 @@ bool FindFusedBatchNorm(const RemapperContext& ctx, const NodeDef* node,
   return true;
 }
 
+// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
+// `tensorflow/stream_executor/cuda/cuda_dnn.cc` for details.
+bool BatchnormSpatialPersistentEnabled() {
+#if CUDNN_VERSION >= 7402
+  static bool is_enabled = [] {
+    bool is_enabled = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
+        "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
+        /*default_val=*/false, &is_enabled));
+    return is_enabled;
+  }();
+  return is_enabled;
+#else
+  return false;
+#endif
+}
+
+bool FindFusedBatchNormEx(const RemapperContext& ctx, const NodeDef* node,
+                          FusedBatchNormEx* matched) {
+  // Root of the pattern must be a Relu.
+  // TODO(ezhulenev): Forward control dependencies.
+  if (!IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node))
+    return false;
+
+  const NodeDef* relu = node;
+
+  // Returns true iff the node is a compatible FusedBatchNorm node.
+  const auto valid_batch_norm = [&](const NodeDef* fused_batch_norm) -> bool {
+    if (fused_batch_norm == nullptr || !IsFusedBatchNorm(*fused_batch_norm))
+      return false;
+
+    AttrSlice attr(*fused_batch_norm);
+
+    // We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with
+    // contraction (MatMul or Conv2D node).
+    if (!NodeIsOnGpu(fused_batch_norm)) return false;
+
+    DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm, "T");
+    if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false;
+
+    // Get the FusedBatchNorm training mode.
+    bool is_training;
+    if (!GetNodeAttr(attr, kIsTraining, &is_training).ok()) return false;
+    // TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
+    if (!is_training) return false;
+
+    // In training mode we rely on cuDNN for computing FusedBatchNorm with side
+    // inputs and activation, and it has its own limitations. In inference mode
+    // we have a custom CUDA kernel that doesn't have these constraints.
+    if (is_training) {
+      // cuDNN only supports NHWC data layout.
+      string data_format;
+      if (!GetNodeAttr(attr, kDataFormat, &data_format).ok()) return false;
+      if (data_format != "NHWC") return false;
+
+      // Data type must be DT_HALF.
+      if (t_dtype != DT_HALF) return false;
+
+      // Channel dimension must be a multiple of 4.
+      const auto& props =
+          ctx.graph_properties.GetInputProperties(fused_batch_norm->name());
+
+      const bool valid_channel_dim = !props.empty() &&
+                                     props[0].shape().dim_size() == 4 &&
+                                     props[0].shape().dim(3).size() % 4 == 0;
+      if (!valid_channel_dim) return false;
+
+      // cuDNN must support CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode.
+      if (!BatchnormSpatialPersistentEnabled()) return false;
+    }
+
+    // FusedBatchNormV2 and V3 have an extra type parameter.
+    if ((fused_batch_norm->op() != "FusedBatchNorm") &&
+        !HasDataType(fused_batch_norm, DT_FLOAT, "U"))
+      return false;
+
+    // Check that only one node consumes the output of a FusedBatchNorm.
+    if (HasControlFaninOrFanout(ctx.graph_view, fused_batch_norm) ||
+        !HasSingleFanoutNode(ctx.graph_view, fused_batch_norm) ||
+        IsInPreserveSet(ctx, fused_batch_norm))
+      return false;
+
+    return true;
+  };
+
+  const auto relu_input_port = GraphView::InputPort(relu, 0);
+  const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port);
+  if (!relu_fanin.node) return false;
+
+  // Input to a Relu can be a FusedBatchNorm.
+  if (valid_batch_norm(relu_fanin.node)) {
+    matched->activation = relu;
+    matched->side_input = nullptr;
+    matched->fused_batch_norm = relu_fanin.node;
+    matched->invalidated = nullptr;
+    return true;
+  }
+
+  // Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs
+  if (IsAdd(*relu_fanin.node)) {
+    const NodeDef* add = relu_fanin.node;
+
+    // Check that only Relu node consumes the output of an Add node.
+    if (HasControlFaninOrFanout(ctx.graph_view, add) ||
+        !HasSingleFanoutNode(ctx.graph_view, add) || IsInPreserveSet(ctx, add))
+      return false;
+
+    const auto add_input_port_0 = GraphView::InputPort(add, 0);
+    const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
+
+    const auto add_input_port_1 = GraphView::InputPort(add, 1);
+    const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
+
+    if (valid_batch_norm(add_fanin_0.node)) {
+      matched->activation = relu;
+      matched->side_input = add_fanin_1.node;
+      matched->fused_batch_norm = add_fanin_0.node;
+      matched->invalidated = add;
+      return true;
+    }
+
+    if (valid_batch_norm(add_fanin_1.node)) {
+      matched->activation = relu;
+      matched->side_input = add_fanin_0.node;
+      matched->fused_batch_norm = add_fanin_1.node;
+      matched->invalidated = add;
+      return true;
+    }
+  }
+
+  return false;
+}
+
 void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
   DCHECK(IsConv2D(*conv2d)) << "Input node must be a Conv2D";
 
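The helper BatchnormSpatialPersistentEnabled above memoizes an environment
variable lookup in a function-local static, so the variable is read at most
once per process. A rough standalone equivalent (plain C++; the real helper
delegates to tensorflow::ReadBoolFromEnvVar, which defaults to false):

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    // Sketch of the memoized env-var gate; accepted spellings here are
    // simplified relative to ReadBoolFromEnvVar.
    bool SpatialPersistentEnabled() {
      static const bool enabled = [] {
        const char* v =
            std::getenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT");
        return v != nullptr &&
               (std::strcmp(v, "1") == 0 || std::strcmp(v, "true") == 0);
      }();
      return enabled;
    }

    int main() { std::printf("%d\n", SpatialPersistentEnabled()); }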
@@ -664,6 +809,27 @@ void CopyConv2DAttributes(const NodeDef* conv2d, NodeDef* fused_conv2d) {
   (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
 }
 
+void CopyFusedBatchNormAttributes(const NodeDef* fused_batch_norm,
+                                  NodeDef* fused_batch_norm_ex) {
+  DCHECK(IsFusedBatchNorm(*fused_batch_norm))
+      << "Input node must be a FusedBatchNorm";
+
+  auto* attr = fused_batch_norm_ex->mutable_attr();
+  auto src_attr = fused_batch_norm->attr();
+
+  (*attr)["T"] = src_attr.at("T");
+  (*attr)["is_training"] = src_attr.at("is_training");
+  (*attr)["data_format"] = src_attr.at("data_format");
+  (*attr)["epsilon"] = src_attr.at("epsilon");
+
+  // FusedBatchNormV2 and V3 have an extra type parameter.
+  if (fused_batch_norm->op() != "FusedBatchNorm") {
+    (*attr)["U"] = src_attr.at("U");
+  } else {
+    (*attr)["U"] = src_attr.at("T");
+  }
+}
+
 void CopyMatMulAttributes(const NodeDef* matmul, NodeDef* fused_matmul) {
   DCHECK(IsMatMul(*matmul)) << "Input node must be a MatMul";
 
@@ -902,6 +1068,55 @@ void AddFusedContractionNode(
 }
 #endif
 
+void AddFusedBatchNormExNode(
+    const FusedBatchNormEx& matched, GraphDef* optimized_graph,
+    absl::flat_hash_set<const NodeDef*>* invalidated_nodes) {
+  VLOG(2) << "Fuse " << matched.activation->op() << " with FusedBatchNorm:"
+          << " side_input="
+          << (matched.side_input ? matched.side_input->name() : "<none>")
+          << " activation=" << matched.activation->name()
+          << " fused_batch_norm=" << matched.fused_batch_norm->name();
+
+  // Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
+  NodeDef* fused_op = optimized_graph->add_node();
+  fused_op->set_op(kFusedBatchNormEx);
+  fused_op->set_name(matched.fused_batch_norm->name());
+  fused_op->set_device(matched.fused_batch_norm->device());
+
+  fused_op->add_input(matched.fused_batch_norm->input(0));  // 0: input
+  fused_op->add_input(matched.fused_batch_norm->input(1));  // 1: scale
+  fused_op->add_input(matched.fused_batch_norm->input(2));  // 2: offset
+  fused_op->add_input(matched.fused_batch_norm->input(3));  // 3: estimated_mean
+  fused_op->add_input(matched.fused_batch_norm->input(4));  // 4: estimated_var
+
+  CopyFusedBatchNormAttributes(matched.fused_batch_norm, fused_op);
+
+  auto* attrs = fused_op->mutable_attr();
+  SetAttrValue(matched.activation->op(), &(*attrs)["activation_mode"]);
+
+  if (matched.side_input != nullptr) {
+    SetAttrValue(1, &(*attrs)["num_side_inputs"]);
+    fused_op->add_input(matched.side_input->name());  // 5: side_input
+  } else {
+    SetAttrValue(0, &(*attrs)["num_side_inputs"]);
+  }
+
+  // Turn activation node into Identity node.
+  NodeDef* identity_op = optimized_graph->add_node();
+  identity_op->set_op("Identity");
+  identity_op->set_name(matched.activation->name());
+  identity_op->set_device(matched.fused_batch_norm->device());
+  identity_op->add_input(matched.fused_batch_norm->name());
+  (*identity_op->mutable_attr())["T"] = attrs->at("T");
+
+  // Invalidate all nodes bypassed by this rewrite.
+  invalidated_nodes->insert(matched.activation);
+  invalidated_nodes->insert(matched.fused_batch_norm);
+  if (matched.side_input != nullptr) {
+    invalidated_nodes->insert(matched.invalidated);
+  }
+}
+
 void AddBatchNormNodes(const FusedBatchNorm& matched,
                        GraphDef* optimized_graph) {
   const NodeDef& fused_node = *matched.fused_batch_norm;
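The rewrite above keeps both original node names alive: the fused node takes
over the FusedBatchNorm's name, so consumers of its other outputs keep
resolving, while the activation becomes an Identity that forwards the fused
output, preserving the activation's name for fetches. Schematically:

    before:  relu = Relu(Add(fbn = FusedBatchNorm(x, ...), side_input))
    after:   relu = Identity(fbn = _FusedBatchNormEx(x, ..., side_input))

The bypassed Add node is recorded in invalidated_nodes and removed from the
optimized graph.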
@@ -1044,6 +1259,55 @@ void AddBatchNormNodes(const FusedBatchNorm& matched,
     *r->add_input() = a->name();
     *r->add_input() = c->name();
   }
 
+// Check if a node is a candidate to one of the patterns that require inferred
+// shapes:
+//   (1) Splitting FusedBatchNorm into primitives.
+//   (2) Fusing side input and/or activation into FusedBatchNorm.
+bool RequiresInferredShapes(const RemapperContext& ctx, const NodeDef& node) {
+  // Candidate for a FusedBatchNorm splitting.
+  const auto is_batch_norm_candidate = [&]() -> bool {
+    if (!IsFusedBatchNorm(node)) return false;
+    if (GetDataTypeFromAttr(node, "T") != DT_FLOAT) return false;
+
+    bool is_training = true;
+    if (!GetNodeAttr(node, kIsTraining, &is_training).ok()) return false;
+    if (is_training) return false;
+
+    return true;
+  };
+
+  // Candidate for a FusedBatchNorm fusion.
+  const auto is_batch_norm_fusion_candidate = [&]() -> bool {
+    if (!IsRelu(node)) return false;
+
+    const auto relu_input_port = GraphView::InputPort(&node, 0);
+    const auto relu_fanin = ctx.graph_view.GetRegularFanin(relu_input_port);
+    if (!relu_fanin.node) return false;
+
+    if (IsFusedBatchNorm(*relu_fanin.node)) {
+      // FusedBatchNorm + Relu.
+      return true;
+
+    } else if (IsAdd(*relu_fanin.node)) {
+      // FusedBatchNorm + Add + Relu.
+      const NodeDef* add = relu_fanin.node;
+
+      const auto add_input_port_0 = GraphView::InputPort(add, 0);
+      const auto add_fanin_0 = ctx.graph_view.GetRegularFanin(add_input_port_0);
+      if (IsFusedBatchNorm(*add_fanin_0.node)) return true;
+
+      const auto add_input_port_1 = GraphView::InputPort(add, 1);
+      const auto add_fanin_1 = ctx.graph_view.GetRegularFanin(add_input_port_1);
+      if (IsFusedBatchNorm(*add_fanin_1.node)) return true;
+    }
+
+    return false;
+  };
+
+  return is_batch_norm_candidate() || is_batch_norm_fusion_candidate();
+}
+
 }  // namespace
 
 Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
@@ -1051,6 +1315,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
   // Supported graph patterns.
   // clang-format off
   FusedBatchNorm fused_batch_norm;
+  FusedBatchNormEx fused_batch_norm_ex;
   ContractionWithBiasAdd contract_with_bias;
   ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
 #ifndef INTEL_MKL
@@ -1078,9 +1343,8 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
   // and Activation nodes that were fused into a Conv2D node.
   absl::flat_hash_set<const NodeDef*> invalidated_nodes;
 
-  // _FusedMatMul and _FusedConv2D kernels do not have registered gradient
-  // function, so we must not perform rewrite if the graph will be
-  // differentiated later.
+  // _Fused{...} kernels do not have registered gradient function, so we must
+  // not perform rewrite if the graph will be differentiated later.
   bool allow_non_differentiable_rewrites =
       item.optimization_options().allow_non_differentiable_rewrites;
 
@@ -1161,16 +1425,25 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
 #endif  // !INTEL_MKL
 
     // Infer properties lazily in case they are not needed.
-    if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) {
+    if (!ctx.inferred_graph_properties && RequiresInferredShapes(ctx, node)) {
+      const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
       // TODO(rmlarsen): Get rid of tensor value copies.
       TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
-          /*assume_valid_feeds=*/false,
+          assume_valid_feeds,
           /*aggressive_shape_inference=*/false,
           /*include_input_tensor_values=*/true,
          /*include_output_tensor_values=*/false));
       ctx.inferred_graph_properties = true;
     }
 
+    // Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
+    if (allow_non_differentiable_rewrites &&
+        FindFusedBatchNormEx(ctx, &node, &fused_batch_norm_ex)) {
+      AddFusedBatchNormExNode(fused_batch_norm_ex, optimized_graph,
+                              &invalidated_nodes);
+      continue;
+    }
+
     // During inference, most of the inputs to FusedBatchNorm are constant, and
     // we can therefore replace the op with a much cheaper set of primitives.
     if (FindFusedBatchNorm(ctx, &node, &fused_batch_norm)) {
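Threading opt_level_ through Remapper (see the remapper.h hunks below) lets
shape inference run with assume_valid_feeds only at the AGGRESSIVE rewriter
level; the new tests construct Remapper(RewriterConfig::AGGRESSIVE) precisely
so that placeholder feed shapes are trusted when matching the fusion pattern.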
--- a/tensorflow/core/grappler/optimizers/remapper.h
+++ b/tensorflow/core/grappler/optimizers/remapper.h
@@ -26,7 +26,7 @@ namespace grappler {
 // nodes to decrease the amount of operations needed to perform a computation.
 class Remapper : public GraphOptimizer {
  public:
-  explicit Remapper(RewriterConfig::Toggle opt_level) {}
+  explicit Remapper(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {}
 
   ~Remapper() override {}
 
@@ -37,6 +37,9 @@ class Remapper : public GraphOptimizer {
 
   void Feedback(Cluster* cluster, const GrapplerItem& item,
                 const GraphDef& optimized_graph, double result) override;
 
+ private:
+  RewriterConfig::Toggle opt_level_;
 };
 
 }  // end namespace grappler
--- a/tensorflow/core/grappler/optimizers/remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -22,11 +22,20 @@ limitations under the License.
 #include "tensorflow/core/grappler/utils/grappler_test.h"
 #include "tensorflow/core/platform/test.h"
 
+#if GOOGLE_CUDA
+#include "third_party/gpus/cudnn/cudnn.h"
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 namespace grappler {
 
 class RemapperTest : public GrapplerTest {
  protected:
+  void SetUp() override {
+    // This is a requirement for fusing FusedBatchNorm + SideInput + Activation.
+    setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", 1 /* replace */);
+  }
+
   // TODO(b/119765980): Upgrade upstream Eigen to set `m_can_use_xsmm=false` for
   // contractions with non-default contraction output kernels.
   bool EigenSupportsContractionOutputKernel() {
@@ -102,6 +111,179 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
   }
 }
 
+TEST_F(RemapperTest, FuseBatchNormWithRelu) {
+  using ::tensorflow::ops::Placeholder;
+
+#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+  LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
+               "CUDNN_VERSION >= 7402.";
+#else
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
+  auto channels_shape = ops::Placeholder::Shape({24});
+
+  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+  auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
+  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
+  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
+  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
+  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
+
+  float epsilon = 0.1f;
+  auto fbn = ops::FusedBatchNormV3(
+      s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+      ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
+          "NHWC"));
+  auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
+  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+  auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
+  auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
+  auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
+  auto mean_t = GenerateRandomTensor<DT_FLOAT>({0});  // empty for training
+  auto var_t = GenerateRandomTensor<DT_FLOAT>({0});   // empty for training
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  item.feed = {{"input", input_t},
+               {"scale", scale_t},
+               {"offset", offset_t},
+               {"mean", mean_t},
+               {"var", var_t}};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // Place all nodes on GPU.
+  for (int i = 0; i < item.graph.node_size(); ++i) {
+    item.graph.mutable_node(i)->set_device("/device:GPU:0");
+  }
+
+  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+  GraphDef output;
+  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "relu") {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ("fused_batch_norm", node.input(0));
+      found++;
+    }
+    if (node.name() == "fused_batch_norm") {
+      EXPECT_EQ("_FusedBatchNormEx", node.op());
+      EXPECT_EQ("input_cast", node.input(0));
+      EXPECT_EQ("scale", node.input(1));
+      EXPECT_EQ("offset", node.input(2));
+      EXPECT_EQ("mean", node.input(3));
+      EXPECT_EQ("var", node.input(4));
+
+      auto attr = node.attr();
+      EXPECT_EQ(0, attr["num_side_inputs"].i());
+      EXPECT_EQ("Relu", attr["activation_mode"].s());
+      found++;
+    }
+  }
+  EXPECT_EQ(2, found);
+
+  if (GetNumAvailableGPUs() > 0) {
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+    EXPECT_EQ(1, tensors_expected.size());
+    EXPECT_EQ(1, tensors.size());
+    test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2);
+  }
+#endif  // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+}
+
+TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
+  using ::tensorflow::ops::Placeholder;
+
+#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+  LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
+               "CUDNN_VERSION >= 7402.";
+#else
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
+  auto channels_shape = ops::Placeholder::Shape({24});
+
+  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+  auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
+  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
+  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
+  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
+  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
+  auto side_input =
+      Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
+  auto side_input_cast =
+      ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
+
+  float epsilon = 0.1f;
+  auto fbn = ops::FusedBatchNormV3(
+      s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+      ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
+          "NHWC"));
+  auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
+  auto relu = ops::Relu(s.WithOpName("relu"), add);
+  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+  auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
+  auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
+  auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
+  auto mean_t = GenerateRandomTensor<DT_FLOAT>({0});  // empty for training
+  auto var_t = GenerateRandomTensor<DT_FLOAT>({0});   // empty for training
+  auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  item.feed = {{"input", input_t}, {"scale", scale_t},
+               {"offset", offset_t}, {"mean", mean_t},
+               {"var", var_t}, {"side_input", side_input_t}};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // Place all nodes on GPU.
+  for (int i = 0; i < item.graph.node_size(); ++i) {
+    item.graph.mutable_node(i)->set_device("/device:GPU:0");
+  }
+
+  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+  GraphDef output;
+  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "relu") {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ("fused_batch_norm", node.input(0));
+      found++;
+    }
+    if (node.name() == "fused_batch_norm") {
+      EXPECT_EQ("_FusedBatchNormEx", node.op());
+      EXPECT_EQ("input_cast", node.input(0));
+      EXPECT_EQ("scale", node.input(1));
+      EXPECT_EQ("offset", node.input(2));
+      EXPECT_EQ("mean", node.input(3));
+      EXPECT_EQ("var", node.input(4));
+      EXPECT_EQ("side_input_cast", node.input(5));
+
+      auto attr = node.attr();
+      EXPECT_EQ(1, attr["num_side_inputs"].i());
+      EXPECT_EQ("Relu", attr["activation_mode"].s());
+      found++;
+    }
+  }
+  EXPECT_EQ(2, found);
+
+  if (GetNumAvailableGPUs() > 0) {
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+    EXPECT_EQ(1, tensors_expected.size());
+    EXPECT_EQ(1, tensors.size());
+    test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, /*rtol=*/1e-2);
+  }
+#endif  // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+}
+
 TEST_F(RemapperTest, FuseConv2DWithBias) {
   if (!EigenSupportsContractionOutputKernel()) return;
 
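Both tests above verify the rewrite structurally on any build: after
optimization the relu node must be an Identity reading from fused_batch_norm,
which in turn must be a _FusedBatchNormEx with the expected inputs and
attributes. When a GPU is available they additionally compare numerics against
the unfused graph, with a relaxed 1e-2 tolerance since the fused cuDNN kernel
runs in DT_HALF.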
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1814,6 +1814,7 @@ tf_cuda_cc_test(
         ":ops_util",
         ":relu_op",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:direct_session",
         "//tensorflow/core:framework",
@@ -1825,6 +1826,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:testlib",
         "//tensorflow/stream_executor/cuda:cudnn_plugin",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
     ],
 )
 
--- a/tensorflow/core/kernels/fused_batch_norm_ex_op_test.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_ex_op_test.cc
@@ -14,8 +14,10 @@ limitations under the License.
 ==============================================================================*/
 
 #include "absl/algorithm/container.h"
+#include "absl/strings/match.h"
 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
 #include "tensorflow/core/framework/node_def_builder.h"
@@ -29,6 +31,10 @@ limitations under the License.
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 #include "tensorflow/core/public/session.h"
 
+#if GOOGLE_CUDA
+#include "third_party/gpus/cudnn/cudnn.h"
+#endif  // GOOGLE_CUDA
+
 namespace tensorflow {
 
 template <typename T, typename U>
@@ -39,25 +45,46 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
   }
 
  protected:
+  struct FusedBatchNormOutputs {
+    Tensor y;
+    Tensor batch_mean;
+    Tensor batch_variance;
+    Tensor reserve_space_1;
+    Tensor reserve_space_2;
+    Tensor reserve_space_3;
+  };
+
+  struct FusedBatchNormGradOutputs {
+    Tensor y_backprop;
+    Tensor x_backprop;
+    Tensor scale_backprop;
+    Tensor offset_backprop;
+    Tensor reserve_space_4;
+    Tensor reserve_space_5;
+  };
+
   using GraphRunner = std::function<void(
-      const Tensor& input_data, const Tensor& scale_data,
-      const Tensor& offset_data, const Tensor& mean_data,
-      const Tensor& var_data, const Tensor& side_input_data, Tensor* out)>;
+      const Tensor& y_backprop, const Tensor& input_data,
+      const Tensor& scale_data, const Tensor& offset_data,
+      const Tensor& mean_data, const Tensor& var_data,
+      const Tensor& side_input_data, FusedBatchNormOutputs* forward,
+      FusedBatchNormGradOutputs* backward)>;
 
   // Runs a Tensorflow graph defined by the root scope, and fetches the result
-  // of 'fetch' node into the output Tensor. Optional `fetch_node` parameter
-  // allows to define a fetch node directly using a NodeDef for the ops that are
+  // of 'fetch' node into the outputs. Optional `add_nodes` parameter
+  // allows to define nodes directly using a NodeDef for the ops that are
   // not supported by the C++ Api.
   // TODO(ezhulenev): RunAndFetch defined in FusedConv2D and FusedMatMul tests.
   // Add a base class for all FusedABC kernels and remove code duplication.
-  void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
-                   Tensor* output, bool allow_gpu_device,
-                   const NodeDef* fetch_node = nullptr) {
+  void RunAndFetch(const tensorflow::Scope& root,
+                   const std::vector<string>& fetch,
+                   std::vector<Tensor>* outputs, bool allow_gpu_device,
+                   const std::vector<const NodeDef*> add_nodes = {}) {
     tensorflow::GraphDef graph;
     TF_ASSERT_OK(root.ToGraphDef(&graph));
 
-    if (fetch_node) {
-      *graph.add_node() = *fetch_node;
+    for (const NodeDef* add_node : add_nodes) {
+      *graph.add_node() = *add_node;
     }
 
     // We really want to make sure that graph executed exactly as we passed it
@@ -101,64 +128,22 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
     }
 
     TF_ASSERT_OK(session->Create(graph));
-
-    std::vector<Tensor> unfused_tensors;
-    TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));
-
-    *output = unfused_tensors[0];
+    TF_ASSERT_OK(session->Run({}, fetch, {}, outputs));
   }
 
-  void RunFusedBatchNorm(const Tensor& input_data, const Tensor& scale_data,
+  void RunFusedBatchNorm(const Tensor& y_backprop_data,
+                         const Tensor& input_data, const Tensor& scale_data,
                          const Tensor& offset_data, const Tensor& mean_data,
                          const Tensor& var_data, const Tensor& side_input_data,
                          const TensorFormat data_format, bool is_training,
                          bool has_side_input, const string& activation_mode,
-                         Tensor* output, float epsilon = 0.1f) {
+                         FusedBatchNormOutputs* forward,
+                         FusedBatchNormGradOutputs* backward,
+                         float epsilon = 0.1f) {
     Scope root = tensorflow::Scope::NewRootScope();
 
-    ops::FusedBatchNormV2 fbn = ops::FusedBatchNormV2(
-        root.WithOpName("fused_batch_norm"),
-        ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
-        ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data)),
-        ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data)),
-        ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data)),
-        ops::Const(root.WithOpName("var"), Input::Initializer(var_data)),
-        ops::FusedBatchNormV2::IsTraining(is_training)
-            .Epsilon(epsilon)
-            .DataFormat(ToString(data_format)));
-
-    Output with_side_input;
-    if (has_side_input) {
-      with_side_input =
-          ops::Add(root.WithOpName("with_side_input"), fbn.y,
-                   ops::Const(root.WithOpName("side_input"),
-                              Input::Initializer(side_input_data)));
-    } else {
-      with_side_input =
-          ops::Identity(root.WithOpName("with_side_input"), fbn.y);
-    }
-
-    if (activation_mode == "Relu") {
-      ops::Relu(root.WithOpName("with_activation"), with_side_input);
-    } else {
-      ops::Identity(root.WithOpName("with_activation"), with_side_input);
-    }
-
-    RunAndFetch(root, "with_activation", output, /*allow_gpu_device=*/true);
-  }
-
-  void RunFusedBatchNormEx(const Tensor& input_data, const Tensor& scale_data,
-                           const Tensor& offset_data, const Tensor& mean_data,
-                           const Tensor& var_data,
-                           const Tensor& side_input_data,
-                           const TensorFormat data_format, bool is_training,
-                           bool has_side_input, const string& activation_mode,
-                           Tensor* output, float epsilon = 0.1f) {
-    Scope root = tensorflow::Scope::NewRootScope();
-
-    DataType t_dtype = DataTypeToEnum<T>::v();
-    DataType u_dtype = DataTypeToEnum<U>::v();
+    Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
+                                   Input::Initializer(y_backprop_data));
 
     Output input =
         ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
     Output scale =
@@ -172,6 +157,101 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
     Output side_input = ops::Const(root.WithOpName("side_input"),
                                    Input::Initializer(side_input_data));
 
+    ops::FusedBatchNormV3 fwd = ops::FusedBatchNormV3(
+        root.WithOpName("fused_batch_norm"), input, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat(ToString(data_format)));
+
+    Output with_side_input;
+    if (has_side_input) {
+      with_side_input =
+          ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input);
+    } else {
+      with_side_input =
+          ops::Identity(root.WithOpName("with_side_input"), fwd.y);
+    }
+
+    Output activation;
+    if (activation_mode == "Relu") {
+      activation =
+          ops::Relu(root.WithOpName("with_activation"), with_side_input);
+    } else {
+      activation =
+          ops::Identity(root.WithOpName("with_activation"), with_side_input);
+    }
+
+    Output activation_grad;
+    if (activation_mode == "Relu") {
+      activation_grad = ops::internal::ReluGrad(
+          root.WithOpName("activation_grad"), y_backprop, activation);
+    } else {
+      activation_grad =
+          ops::Identity(root.WithOpName("activation_grad"), y_backprop);
+    }
+
+    ops::FusedBatchNormGradV3 bwd = ops::FusedBatchNormGradV3(
+        root.WithOpName("fused_batch_norm_grad"), activation_grad, input, scale,
+        fwd.reserve_space_1, fwd.reserve_space_2, fwd.reserve_space_3,
+        ops::FusedBatchNormGradV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat(ToString(data_format)));
+
+    std::vector<Tensor> out_tensors;
+    RunAndFetch(
+        root,
+        {"with_activation:0", "fused_batch_norm:1", "fused_batch_norm:2",
+         "fused_batch_norm:3", "fused_batch_norm:4", "fused_batch_norm:5",
+         "activation_grad:0", "fused_batch_norm_grad:0",
+         "fused_batch_norm_grad:1", "fused_batch_norm_grad:2"},
+        &out_tensors, /*allow_gpu_device=*/true);
+
+    forward->y = out_tensors[0];
+    forward->batch_mean = out_tensors[1];
+    forward->batch_variance = out_tensors[2];
+    forward->reserve_space_1 = out_tensors[3];
+    forward->reserve_space_2 = out_tensors[4];
+    forward->reserve_space_3 = out_tensors[5];
+
+    backward->y_backprop = out_tensors[6];
+    backward->x_backprop = out_tensors[7];
+    backward->scale_backprop = out_tensors[8];
+    backward->offset_backprop = out_tensors[9];
+  }
+
+  void RunFusedBatchNormEx(const Tensor& y_backprop_data,
+                           const Tensor& input_data, const Tensor& scale_data,
+                           const Tensor& offset_data, const Tensor& mean_data,
+                           const Tensor& var_data,
+                           const Tensor& side_input_data,
+                           const TensorFormat data_format, bool is_training,
+                           bool has_side_input, const string& activation_mode,
+                           FusedBatchNormOutputs* forward,
+                           FusedBatchNormGradOutputs* backward,
+                           float epsilon = 0.1f) {
+    Scope root = tensorflow::Scope::NewRootScope();
+
+    DataType t_dtype = DataTypeToEnum<T>::v();
+    DataType u_dtype = DataTypeToEnum<U>::v();
+
+    Output y_backprop = ops::Const(root.WithOpName("y_backprop"),
+                                   Input::Initializer(y_backprop_data));
+    Output input =
+        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
+    Output scale =
+        ops::Const(root.WithOpName("scale"), Input::Initializer(scale_data));
+    Output offset =
+        ops::Const(root.WithOpName("offset"), Input::Initializer(offset_data));
+    Output mean =
+        ops::Const(root.WithOpName("mean"), Input::Initializer(mean_data));
+    Output var =
+        ops::Const(root.WithOpName("var"), Input::Initializer(var_data));
+    Output side_input = ops::Const(root.WithOpName("side_input"),
+                                   Input::Initializer(side_input_data));
+    Output empty =
+        ops::Const(root.WithOpName("empty"),
+                   Input::Initializer(Tensor(DataTypeToEnum<U>::value, {0})));
+
     int num_side_inputs = 0;
     std::vector<NodeDefBuilder::NodeOut> side_inputs;
 
@ -197,8 +277,58 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
                      .Attr("is_training", is_training)
                      .Finalize(&fused_batch_norm_ex));
 
-    RunAndFetch(root, fused_batch_norm_ex.name(), output,
-                /*allow_gpu_device=*/true, &fused_batch_norm_ex);
+    NodeDef activation_grad;
+    if (activation_mode == "Relu") {
+      TF_EXPECT_OK(NodeDefBuilder("activation_grad", "ReluGrad")
+                       .Input({y_backprop.name(), 0, t_dtype})
+                       .Input({fused_batch_norm_ex.name(), 0, t_dtype})
+                       .Attr("T", t_dtype)
+                       .Finalize(&activation_grad));
+    } else {
+      TF_EXPECT_OK(NodeDefBuilder("activation_grad", "Identity")
+                       .Input({y_backprop.name(), 0, t_dtype})
+                       .Attr("T", t_dtype)
+                       .Finalize(&activation_grad));
+    }
+
+    NodeDef fused_batch_norm_grad;
+    TF_EXPECT_OK(NodeDefBuilder("fused_batch_norm_grad", "FusedBatchNormGradV3")
+                     .Input({activation_grad.name(), 0, t_dtype})
+                     .Input({input.name(), 0, t_dtype})
+                     .Input({scale.name(), 0, u_dtype})
+                     .Input({fused_batch_norm_ex.name(), 3, u_dtype})
+                     .Input({fused_batch_norm_ex.name(), 4, u_dtype})
+                     .Input({fused_batch_norm_ex.name(), 5, u_dtype})
+                     .Attr("T", t_dtype)
+                     .Attr("U", u_dtype)
+                     .Attr("data_format", ToString(data_format))
+                     .Attr("epsilon", epsilon)
+                     .Attr("is_training", is_training)
+                     .Finalize(&fused_batch_norm_grad));
+
+    std::vector<Tensor> out_tensors;
+    RunAndFetch(
+        root,
+        {"fused_batch_norm_ex:0", "fused_batch_norm_ex:1",
+         "fused_batch_norm_ex:2", "fused_batch_norm_ex:3",
+         "fused_batch_norm_ex:4", "fused_batch_norm_ex:5", "activation_grad:0",
+         "fused_batch_norm_grad:0", "fused_batch_norm_grad:1",
+         "fused_batch_norm_grad:2"},
+        &out_tensors,
+        /*allow_gpu_device=*/true,
+        {&fused_batch_norm_ex, &activation_grad, &fused_batch_norm_grad});
+
+    forward->y = out_tensors[0];
+    forward->batch_mean = out_tensors[1];
+    forward->batch_variance = out_tensors[2];
+    forward->reserve_space_1 = out_tensors[3];
+    forward->reserve_space_2 = out_tensors[4];
+    forward->reserve_space_3 = out_tensors[5];
+
+    backward->y_backprop = out_tensors[6];
+    backward->x_backprop = out_tensors[7];
+    backward->scale_backprop = out_tensors[8];
+    backward->offset_backprop = out_tensors[9];
   }
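The backward half of the reference graph first undoes the fused activation: when `activation_mode` is "Relu" it routes the upstream gradient through `ReluGrad` using the forward output of `fused_batch_norm_ex`, and only then feeds `FusedBatchNormGradV3`. A small sketch of the ReluGrad semantics this wiring relies on (an illustration, not the TensorFlow kernel):

#include <cstddef>

// dx[i] = dy[i] wherever the forward Relu produced a positive value, else 0.
// Passing the Relu output (as the test does) is equivalent to passing its
// input, since both are positive at exactly the same indices.
void ReluGradSketch(const float* dy, const float* y, float* dx,
                    std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    dx[i] = (y[i] > 0.0f) ? dy[i] : 0.0f;
  }
}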
 
   void VerifyTensorsNear(int batch, int height, int width, int channels,
@ -215,6 +345,7 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
     Tensor input(t_dtype, input_shape);
     input.flat<T>().setRandom();
+    input.flat<T>() -= input.flat<T>().constant(static_cast<T>(0.5));
 
     Tensor scale(u_dtype, {channels});
     scale.flat<U>().setRandom();
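The shift by 0.5 after `setRandom()` keeps the generated data from being confined to a non-negative range, so a downstream Relu actually clips some elements and both activation branches get exercised. A plain C++ analogue of the recipe (ranges and names are illustrative, not Eigen's exact contract):

#include <cstddef>
#include <random>
#include <vector>

// Fill a buffer with uniform values, then shift so both signs occur.
std::vector<float> ShiftedRandom(std::size_t n, float shift) {
  std::mt19937 rng(42);  // fixed seed for reproducible test data
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  std::vector<float> data(n);
  for (float& v : data) v = uniform(rng) - shift;
  return data;
}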
@ -228,29 +359,60 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
     Tensor var(u_dtype, {channels});
     var.flat<U>().setRandom();
 
-    Tensor empty(u_dtype, {0});
-
-    Tensor fused_batch_norm;
-    Tensor fused_batch_norm_ex;
-
     Tensor side_input(t_dtype, input_shape);
     side_input.flat<T>().setRandom();
+    side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0));
 
-    run_default(input, scale, offset, is_training ? empty : mean,
-                is_training ? empty : var, side_input, &fused_batch_norm);
-
-    // Write some garbage to the `fused_batch_norm_ex` first to make sure
-    // that fused kernel actually writes correct results to memory.
-    run_default(side_input, scale, offset, is_training ? empty : mean,
-                is_training ? empty : var, input, &fused_batch_norm_ex);
-
-    run_fused(input, scale, offset, is_training ? empty : mean,
-              is_training ? empty : var, side_input, &fused_batch_norm_ex);
-
-    ASSERT_EQ(fused_batch_norm.dtype(), fused_batch_norm_ex.dtype());
-    ASSERT_EQ(fused_batch_norm.shape(), fused_batch_norm_ex.shape());
-
-    test::ExpectClose(fused_batch_norm, fused_batch_norm_ex, 1e-2);
+    Tensor y_backprop(t_dtype, input_shape);
+    y_backprop.flat<T>().setRandom();
+    y_backprop.flat<T>() -= y_backprop.flat<T>().constant(static_cast<T>(0.5));
+
+    Tensor empty(u_dtype, {0});
+
+    FusedBatchNormOutputs fbn_forward;
+    FusedBatchNormOutputs fbn_ex_forward;
+    FusedBatchNormGradOutputs fbn_backward;
+    FusedBatchNormGradOutputs fbn_ex_backward;
+
+    run_default(y_backprop, input, scale, offset, is_training ? empty : mean,
+                is_training ? empty : var, side_input, &fbn_forward,
+                &fbn_backward);
+
+    // Write some garbage to the `fbn_ex_forward` and `fbn_ex_backward` first to
+    // make sure that fused kernel actually writes correct results to memory.
+    run_default(y_backprop, side_input, scale, offset,
+                is_training ? empty : mean, is_training ? empty : var, input,
+                &fbn_ex_forward, &fbn_ex_backward);
+
+    run_fused(y_backprop, input, scale, offset, is_training ? empty : mean,
+              is_training ? empty : var, side_input, &fbn_ex_forward,
+              &fbn_ex_backward);
+
+    std::vector<std::pair<Tensor, Tensor>> tensor_pairs = {
+        {fbn_forward.y, fbn_ex_forward.y},
+        {fbn_forward.batch_mean, fbn_ex_forward.batch_mean},
+        {fbn_forward.batch_variance, fbn_ex_forward.batch_variance},
+        {fbn_forward.reserve_space_1, fbn_ex_forward.reserve_space_1},
+        {fbn_forward.reserve_space_2, fbn_ex_forward.reserve_space_2},
+        // NOTE(ezhulenev): We deliberately do not check `reserved_space_3`
+        // because BatchNormEx with fused side input has different data in it,
+        // but we make sure that final gradients are the same.
+        {fbn_backward.y_backprop, fbn_ex_backward.y_backprop},
+        {fbn_backward.x_backprop, fbn_ex_backward.x_backprop},
+        {fbn_backward.scale_backprop, fbn_ex_backward.scale_backprop},
+        {fbn_backward.offset_backprop, fbn_ex_backward.offset_backprop},
+    };
+
+    for (auto& pair : tensor_pairs) {
+      const Tensor& fbn = pair.first;
+      const Tensor& fbn_ex = pair.second;
+
+      ASSERT_EQ(fbn.dtype(), fbn_ex.dtype());
+      ASSERT_EQ(fbn.shape(), fbn_ex.shape());
+
+      test::ExpectClose(fbn, fbn_ex, 1e-2);
+    }
   }
 
 // Verifies that computing FusedBatchNormOp+{SideInput}+{Activation} is
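The second `run_default` call above is deliberate: it runs the reference graph on permuted inputs purely to pre-fill `fbn_ex_forward` and `fbn_ex_backward` with wrong-but-plausible values, so a fused kernel that fails to write one of its outputs is caught by the comparison instead of accidentally matching zero-initialized memory. A self-contained sketch of the pattern (the kernel here is a stand-in, not TF code):

#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for the kernel under test: writes 2*x for every element.
void RunKernel(const std::vector<float>& in, std::vector<float>* out) {
  out->resize(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) (*out)[i] = 2.0f * in[i];
}

int main() {
  const std::vector<float> real = {1.0f, 2.0f, 3.0f};
  const std::vector<float> decoy = {7.0f, 8.0f, 9.0f};

  std::vector<float> got;
  RunKernel(decoy, &got);  // 1) pollute the output buffer ("garbage")
  RunKernel(real, &got);   // 2) run under test; it must overwrite everything

  std::vector<float> want;
  RunKernel(real, &want);  // 3) fresh reference for comparison
  for (std::size_t i = 0; i < want.size(); ++i) assert(got[i] == want[i]);
  return 0;
}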
@ -260,25 +422,27 @@ class FusedBatchNormExOpTestBase : public OpsTestBase {
                              bool has_side_input,
                              const string& activation_mode) {
     const GraphRunner run_default =
-        [&](const Tensor& input_data, const Tensor& scale_data,
-            const Tensor& offset_data, const Tensor& mean_data,
-            const Tensor& var_data, const Tensor& side_input_data,
-            Tensor* out) {
-          this->RunFusedBatchNorm(input_data, scale_data, offset_data,
-                                  mean_data, var_data, side_input_data,
-                                  data_format, is_training, has_side_input,
-                                  activation_mode, out);
+        [&](const Tensor& y_backprop, const Tensor& input_data,
+            const Tensor& scale_data, const Tensor& offset_data,
+            const Tensor& mean_data, const Tensor& var_data,
+            const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
+            FusedBatchNormGradOutputs* bwd) {
+          this->RunFusedBatchNorm(y_backprop, input_data, scale_data,
+                                  offset_data, mean_data, var_data,
+                                  side_input_data, data_format, is_training,
+                                  has_side_input, activation_mode, fwd, bwd);
         };
 
     const GraphRunner run_inference =
-        [&](const Tensor& input_data, const Tensor& scale_data,
-            const Tensor& offset_data, const Tensor& mean_data,
-            const Tensor& var_data, const Tensor& side_input_data,
-            Tensor* out) {
-          this->RunFusedBatchNormEx(input_data, scale_data, offset_data,
-                                    mean_data, var_data, side_input_data,
-                                    data_format, is_training, has_side_input,
-                                    activation_mode, out);
+        [&](const Tensor& y_backprop, const Tensor& input_data,
+            const Tensor& scale_data, const Tensor& offset_data,
+            const Tensor& mean_data, const Tensor& var_data,
+            const Tensor& side_input_data, FusedBatchNormOutputs* fwd,
+            FusedBatchNormGradOutputs* bwd) {
+          this->RunFusedBatchNormEx(y_backprop, input_data, scale_data,
+                                    offset_data, mean_data, var_data,
+                                    side_input_data, data_format, is_training,
+                                    has_side_input, activation_mode, fwd, bwd);
         };
 
     VerifyTensorsNear(batch, height, width, channels, data_format, is_training,
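Both lambdas expose the unfused reference path and the fused path behind an identical signature, so `VerifyTensorsNear` can drive either one; `GraphRunner` is presumably a `std::function` alias defined earlier in the file. A simplified sketch of the seam (types reduced to vectors for illustration):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// One callable type hides both implementations from the comparison logic.
using Runner =
    std::function<void(const std::vector<float>& in, std::vector<float>* out)>;

void VerifyNear(const Runner& reference, const Runner& fused,
                const std::vector<float>& in, float tolerance) {
  std::vector<float> want, got;
  reference(in, &want);
  fused(in, &got);
  assert(want.size() == got.size());
  for (std::size_t i = 0; i < want.size(); ++i)
    assert(std::fabs(want[i] - got[i]) <= tolerance);
}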
@ -297,17 +461,17 @@ constexpr bool kWithSideInput = true;  // side_input == true
 TYPED_TEST_SUITE_P(FusedBatchNormExOpTest);
 
 TYPED_TEST_P(FusedBatchNormExOpTest, TrainingInNHWCTest) {
-  this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
+  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                                kNoSideInput, "Identity");
 }
 
 TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithReluInNHWCTest) {
-  this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
+  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                                kNoSideInput, "Relu");
 }
 
 TYPED_TEST_P(FusedBatchNormExOpTest, TrainingWithSideInputAndReluInNHWCTest) {
-  this->VerifyFusedBatchNormEx(2, 2, 2, 4, FORMAT_NHWC, kInTraining,
+  this->VerifyFusedBatchNormEx(4, 28, 28, 256, FORMAT_NHWC, kInTraining,
                                kWithSideInput, "Relu");
 }
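Note the enlarged test shapes: the 2x2x2x4 toy tensor becomes 4x28x28x256, a more ResNet-like layer shape. For context, `TYPED_TEST_P` only declares these parameterized tests; they execute once the suite is registered and instantiated with concrete element types, which the file presumably does further down. A minimal googletest sketch of that step (the type list here is an assumption, not necessarily what the file uses):

#include "gtest/gtest.h"

REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormExOpTest,  //
                            TrainingInNHWCTest, TrainingWithReluInNHWCTest,
                            TrainingWithSideInputAndReluInNHWCTest);

// Instantiate the suite for the element types under test (assumed list).
using FusedBatchNormExTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormExOpTest,
                               FusedBatchNormExTypes);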
@ -692,8 +692,8 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
             << " y_backprop shape: " << y_backprop.shape().DebugString()
             << " x shape: " << x.shape().DebugString()
             << " scale shape: " << scale.shape().DebugString()
-            << " tensor format: " << tensor_format
-            << " compute format: " << compute_format;
+            << " tensor format: " << ToString(tensor_format)
+            << " compute format: " << ToString(compute_format);
 
     // Inputs
     Tensor y_backprop_maybe_transformed = y_backprop;