Roll forward of https://github.com/tensorflow/tensorflow/pull/42970 with a fix for 5D tensors and for unsupported axes that cannot use fused batch norm.
PiperOrigin-RevId: 335071264
Change-Id: Id9ab44bdba870336f03522e1ca76d88b1f305a10

parent eed3ab97e5
commit 27d26a8d86
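For orientation, a minimal sketch of what this change enables (illustrative only, not part of the commit; the shapes, values, and the tf.compat.v1 entry point are assumptions): fused batch norm accepts 5D inputs in NDHWC/NCDHW layouts, and Keras batch normalization falls back to the non-fused path for 5D inputs whose axis has no fused layout instead of raising.

import tensorflow as tf

# Illustrative 5D input in NDHWC layout: [batch, depth, height, width, channels].
x = tf.random.normal([2, 4, 8, 8, 3])
scale = tf.ones([3])
offset = tf.zeros([3])

# fused_batch_norm now accepts 5D tensors with NDHWC/NCDHW data formats.
y, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
    x, scale, offset, data_format='NDHWC', is_training=True)

# Keras BatchNormalization uses the fused kernel for axis=1 (NCDHW) or the
# last axis (NDHWC) on 5D inputs, and falls back to the non-fused
# implementation for other axes instead of raising an error.
layer = tf.keras.layers.BatchNormalization(axis=-1)
_ = layer(x, training=True)

Under the hood, the kernel handles 5D inputs by collapsing the spatial dimensions into a 4D shape before calling the existing fused implementation and restoring the original shape afterwards, as the hunks below show.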
@@ -1121,8 +1121,17 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
 }
 
 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
+  string data_format_str;
+  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
+  TensorFormat data_format;
+  if (!FormatFromString(data_format_str, &data_format)) {
+    return errors::InvalidArgument("Invalid data format string: ",
+                                   data_format_str);
+  }
+  const int rank =
+      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
   ShapeHandle x;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
 
   bool is_training;
   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
@@ -1131,14 +1140,8 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
     exponential_avg_factor = 1.0f;  // default value
   }
   int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
-  string data_format_str;
-  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
-  TensorFormat data_format;
-  if (!FormatFromString(data_format_str, &data_format)) {
-    return errors::InvalidArgument("Invalid data format string: ",
-                                   data_format_str);
-  }
-  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
 
   // covers scale, offset, and if is_training is false, mean, variance
@@ -1191,13 +1194,6 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
 }
 
 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
-  ShapeHandle y_backprop;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
-  ShapeHandle x;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
-
-  bool is_training;
-  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
   string data_format_str;
   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   TensorFormat data_format;
@@ -1205,7 +1201,17 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
     return errors::InvalidArgument("Invalid data format string: ",
                                    data_format_str);
   }
-  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+  const int rank =
+      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
+  ShapeHandle y_backprop;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
+  ShapeHandle x;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
+
+  bool is_training;
+  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
+
+  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
   DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
   TF_RETURN_IF_ERROR(
       c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
@@ -670,7 +670,25 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
 Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
+    // Change back to the original layout due to early exit.
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -679,6 +697,11 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
@@ -881,8 +904,26 @@ bool FusedBatchNormGradTransposer::IsTraining(
 Status FusedBatchNormGradTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsFusedBatchNormGrad(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
       !IsTraining(*node)) {
+    // Change back to the original layout due to early exit.
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -892,6 +933,11 @@ Status FusedBatchNormGradTransposer::TransposeNode(
   TF_RETURN_IF_ERROR(
       UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
@@ -1438,16 +1438,18 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
 
-  if (fused_node.attr().at(kDataFormat).s() == "NCHW") {
+  string x_format = fused_node.attr().at(kDataFormat).s();
+  if (x_format == "NCHW" or x_format == "NCDHW") {
     // Need to reshape the last 4 inputs
     NodeDef new_shape;
     const string new_shape_name =
-        AddPrefixToNodeName("NCHWShape", fused_node.name());
+        AddPrefixToNodeName(x_format + "Shape", fused_node.name());
     new_shape.set_name(new_shape_name);
     new_shape.set_op("Const");
     new_shape.set_device(fused_node.device());
     *new_shape.add_input() = AsControlDependency(scale);
     (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
+    if (x_format == "NCHW") {
     Tensor t(DT_INT32, {4});
     t.flat<int32>()(0) = 1;
     t.flat<int32>()(1) = -1;
@@ -1455,12 +1457,22 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
     t.flat<int32>()(3) = 1;
     t.AsProtoTensorContent(
         (*new_shape.mutable_attr())["value"].mutable_tensor());
+    } else {
+      Tensor t(DT_INT32, {5});
+      t.flat<int32>()(0) = 1;
+      t.flat<int32>()(1) = -1;
+      t.flat<int32>()(2) = 1;
+      t.flat<int32>()(3) = 1;
+      t.flat<int32>()(4) = 1;
+      t.AsProtoTensorContent(
+          (*new_shape.mutable_attr())["value"].mutable_tensor());
+    }
     mutation->AddNode(std::move(new_shape), &status);
     TF_RETURN_IF_ERROR(status);
 
     NodeDef reshaped_scale;
     reshaped_scale.set_name(
-        AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
     reshaped_scale.set_op("Reshape");
     reshaped_scale.set_device(fused_node.device());
     *reshaped_scale.add_input() = scale;
@@ -1473,7 +1485,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_offset;
     reshaped_offset.set_name(
-        AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
     reshaped_offset.set_op("Reshape");
     reshaped_offset.set_device(fused_node.device());
     *reshaped_offset.add_input() = offset;
@@ -1486,7 +1498,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_mean;
     reshaped_mean.set_name(
-        AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
     reshaped_mean.set_op("Reshape");
     reshaped_mean.set_device(fused_node.device());
     *reshaped_mean.add_input() = mean;
@@ -1499,7 +1511,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_variance;
     reshaped_variance.set_name(
-        AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
     reshaped_variance.set_op("Reshape");
     reshaped_variance.set_device(fused_node.device());
     *reshaped_variance.add_input() = variance;
@@ -1241,15 +1241,15 @@ class FusedBatchNormOpBase : public OpKernel {
   // If use_reserved_space is false, we don't have 5th output.
   virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                         bool use_reserved_space) {
-    const Tensor& x = context->input(0);
+    Tensor x = context->input(0);
     const Tensor& scale = context->input(1);
     const Tensor& offset = context->input(2);
     const Tensor& estimated_mean = context->input(3);
     const Tensor& estimated_variance = context->input(4);
     const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
 
-    OP_REQUIRES(context, x.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
                 errors::InvalidArgument("scale must be 1-dimensional",
@@ -1264,6 +1264,21 @@ class FusedBatchNormOpBase : public OpKernel {
         context, estimated_variance.dims() == 1,
         errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                 estimated_variance.shape().DebugString()));
+    bool use_reshape = (x.dims() == 5);
+    auto x_shape = x.shape();
+    TensorShape dest_shape;
+    if (use_reshape) {
+      const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
+      int64 in_planes = GetTensorDim(x, tensor_format_, '0');
+      int64 in_rows = GetTensorDim(x, tensor_format_, '1');
+      int64 in_cols = GetTensorDim(x, tensor_format_, '2');
+      const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
+      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
+                                   {{in_planes, in_rows * in_cols}}, in_depth);
+      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
+
     if (has_side_input_) {
       OP_REQUIRES(context, side_input->shape() == x.shape(),
                   errors::InvalidArgument(
@@ -1282,8 +1297,10 @@ class FusedBatchNormOpBase : public OpKernel {
     }
 
     Tensor* y = nullptr;
+    auto alloc_shape = use_reshape ? dest_shape : x_shape;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
-                                {0}, 0, x.shape(), &y));
+                                {0}, 0, alloc_shape, &y));
 
     Tensor* batch_mean = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {3}, 1, scale.shape(), &batch_mean));
@@ -1310,6 +1327,10 @@ class FusedBatchNormOpBase : public OpKernel {
           batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
           tensor_format_, use_reserved_space);
     }
+    if (use_reshape) {
+      OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
   }
 
  private:
@@ -1375,8 +1396,8 @@ class FusedBatchNormGradOpBase : public OpKernel {
 
   virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                         bool use_reserved_space) {
-    const Tensor& y_backprop = context->input(0);
-    const Tensor& x = context->input(1);
+    Tensor y_backprop = context->input(0);
+    Tensor x = context->input(1);
     const Tensor& scale = context->input(2);
     // When is_training=True, batch mean and variance/inverted variance are
     // saved in the forward pass to be reused here. When is_training=False,
@@ -1387,11 +1408,11 @@ class FusedBatchNormGradOpBase : public OpKernel {
     // saves inverted variance.
     const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
 
-    OP_REQUIRES(context, y_backprop.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         y_backprop.shape().DebugString()));
-    OP_REQUIRES(context, x.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
                 errors::InvalidArgument("scale must be 1-dimensional",
@@ -1404,10 +1425,27 @@ class FusedBatchNormGradOpBase : public OpKernel {
                 errors::InvalidArgument(
                     "saved variance must be 1-dimensional",
                     saved_maybe_inv_var_or_pop_var.shape().DebugString()));
+    bool use_reshape = (x.dims() == 5);
+    auto x_shape = x.shape();
+    TensorShape dest_shape;
+    if (use_reshape) {
+      const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
+      int64 in_planes = GetTensorDim(x, tensor_format_, '0');
+      int64 in_rows = GetTensorDim(x, tensor_format_, '1');
+      int64 in_cols = GetTensorDim(x, tensor_format_, '2');
+      const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
+      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
+                                   {{in_planes, in_rows * in_cols}}, in_depth);
+      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+      OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
 
     Tensor* x_backprop = nullptr;
+    auto alloc_shape = use_reshape ? dest_shape : x_shape;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(0, x.shape(), &x_backprop));
+                   context->allocate_output(0, alloc_shape, &x_backprop));
 
     const TensorShape& scale_offset_shape = scale.shape();
     Tensor* scale_backprop = nullptr;
@@ -1441,15 +1479,20 @@ class FusedBatchNormGradOpBase : public OpKernel {
           offset_backprop, use_reserved_space, tensor_format_);
     } else {
       // Necessary layout conversion is currently done in python.
-      CHECK(tensor_format_ == FORMAT_NHWC)
-          << "The implementation of FusedBatchNormGrad with is_training=False "
-             "only support "
-          << "NHWC tensor format for now.";
+      OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
+                  errors::InvalidArgument(
+                      "The implementation of "
+                      "FusedBatchNormGrad with is_training=False only support "
+                      "NHWC tensor format for now."));
       functor::FusedBatchNormFreezeGrad<Device, T, U>()(
           context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop);
     }
+    if (use_reshape) {
+      OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
   }
 
  private:
@@ -221,7 +221,7 @@ REGISTER_OP("FusedBatchNormV3")
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
     .Attr("exponential_avg_factor: float = 1.0")
-    .Attr(GetConvnetDataFormatAttrString())
+    .Attr(GetConvnetDataFormat2D3DAttrString())
     .Attr("is_training: bool = true")
     .SetShapeFn(shape_inference::FusedBatchNormV3Shape);
 
@@ -308,7 +308,7 @@ REGISTER_OP("FusedBatchNormGradV3")
     .Attr("T: {half, bfloat16, float}")
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
-    .Attr(GetConvnetDataFormatAttrString())
+    .Attr(GetConvnetDataFormat2D3DAttrString())
     .Attr("is_training: bool = true")
     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
 // --------------------------------------------------------------------------
@@ -1275,6 +1275,94 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
       self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
       self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
+
+  @test_util.deprecated_graph_mode_only
+  def testBatchNorm3D(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      strides_val = [1, 1, 1, 1, 1]
+      scale = constant_op.constant(0.1, shape=[3])
+      offset = constant_op.constant(0.3, shape=[3])
+      conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
+      y, _, _ = nn.fused_batch_norm(conv3d, scale, offset, data_format='NDHWC')
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormV3-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+  @test_util.deprecated_graph_mode_only
+  def testBatchNormGrad3D(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      strides_val = [1, 1, 1, 1, 1]
+      scale = constant_op.constant(0.1, shape=[3])
+      offset = constant_op.constant(0.3, shape=[3])
+      mean = constant_op.constant(0.1, shape=[3])
+      variance = constant_op.constant(0.3, shape=[3])
+      conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
+      y, running_mean, running_var, r0, r1, r2 = gen_nn_ops.fused_batch_norm_v3(
+          conv3d,
+          scale,
+          offset,
+          mean,
+          variance,
+          epsilon=1.001e-5,
+          exponential_avg_factor=1.0,
+          data_format='NDHWC',
+          is_training=True,
+          name='batch_norm')
+      dx, dscale, doffset, _, _ = gen_nn_ops.fused_batch_norm_grad_v3(
+          y,
+          x_3d,
+          scale,
+          r0,
+          r1,
+          r2,
+          epsilon=1.001e-5,
+          data_format='NDHWC',
+          is_training=True)
+      output = array_ops.identity(dx)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      expected_num_transposes = 3
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('FusedBatchNormGradV3-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormGradV3-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
   @test_util.deprecated_graph_mode_only
@@ -330,13 +330,13 @@ class BatchNormalizationBase(Layer):
     # output back to its original shape accordingly.
     if self._USE_V2_BEHAVIOR:
       if self.fused is None:
-        self.fused = (ndims == 4)
-      elif self.fused and ndims != 4:
+        self.fused = ndims in (4, 5)
+      elif self.fused and ndims not in (4, 5):
         raise ValueError('Batch normalization layers with fused=True only '
-                         'support 4D input tensors.')
+                         'support 4D or 5D input tensors.')
     else:
       assert self.fused is not None
-      self.fused = (ndims == 4 and self._fused_can_be_used())
+      self.fused = (ndims in (4, 5) and self._fused_can_be_used())
     # TODO(chrisying): fused batch norm is currently not supported for
     # multi-axis batch norm and by extension virtual batches. In some cases,
     # it might be possible to use fused batch norm but would require reshaping
@@ -345,13 +345,22 @@ class BatchNormalizationBase(Layer):
     # common use case (turning 5D w/ virtual batch to NCHW)
 
     if self.fused:
-      if self.axis == [1]:
+      if self.axis == [1] and ndims == 4:
         self._data_format = 'NCHW'
-      elif self.axis == [3]:
+      elif self.axis == [1] and ndims == 5:
+        self._data_format = 'NCDHW'
+      elif self.axis == [3] and ndims == 4:
         self._data_format = 'NHWC'
+      elif self.axis == [4] and ndims == 5:
+        self._data_format = 'NDHWC'
+      elif ndims == 5:
+        # 5D tensors that can be passed in but should not use fused batch norm
+        # due to unsupported axis.
+        self.fused = False
       else:
         raise ValueError('Unsupported axis, fused batch norm only supports '
-                         'axis == [1] or axis == [3]')
+                         'axis == [1] or axis == [3] for 4D input tensors or '
+                         'axis == [1] or axis == [4] for 5D input tensors')
 
     axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
     for x in axis_to_dim:
@@ -66,6 +66,15 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
         kwargs={'scale': False,
                 'center': False},
         input_shape=(3, 3))
+    testing_utils.layer_test(
+        keras.layers.BatchNormalization,
+        kwargs={
+            'gamma_initializer': 'ones',
+            'beta_initializer': 'ones',
+            'moving_mean_initializer': 'zeros',
+            'moving_variance_initializer': 'ones'
+        },
+        input_shape=(3, 2, 4, 2))
 
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_batchnorm_weights(self):
@@ -319,7 +328,7 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
     norm = normalization_v2.BatchNormalization(fused=True)
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(4, 4))
-    with self.assertRaisesRegex(ValueError, '4D input tensors'):
+    with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
       norm(inp)
 
   def test_updates_in_wrap_function(self):
@@ -43,14 +43,18 @@ class BatchNormalizationTest(test.TestCase):
     return math_ops.cast(y, x.dtype)
 
   def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format):
-    if data_format not in ['NHWC', 'NCHW']:
-      raise ValueError('data_format must be NCHW or NHWC, '
-                       'got %s.' % data_format)
+    if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
+      raise ValueError('data_format must be NCHW or NHWC for 4D tensors or'
+                       'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
     if data_format == 'NCHW':
       x = array_ops.transpose(x, [0, 2, 3, 1])
+    elif data_format == 'NCDHW':
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
     y = self._batch_norm(x, mean, var, offset, scale, epsilon)
     if data_format == 'NCHW':
       y = array_ops.transpose(y, [0, 3, 1, 2])
+    elif data_format == 'NCDHW':
+      y = array_ops.transpose(y, [0, 4, 1, 2, 3])
     return self.evaluate(y)
 
   def _test_inference(self,
@@ -102,17 +106,24 @@ class BatchNormalizationTest(test.TestCase):
 
   def _training_ref(self, x, scale, offset, old_mean, old_var,
                     exponential_avg_factor, epsilon, data_format):
-    if data_format not in ['NHWC', 'NCHW']:
-      raise ValueError('data_format must be NCHW or NHWC, '
-                       'got %s.' % data_format)
+    if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
+      raise ValueError('data_format must be NCHW or NHWC for 4D tensors or'
+                       'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
+    use_4d_tensor = (x.shape.ndims == 4)
     if data_format == 'NCHW':
       x = array_ops.transpose(x, [0, 2, 3, 1])
+    elif data_format == 'NCDHW':
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
+
+    mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3]
     batch_mean, batch_var = nn_impl.moments(
-        math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
+        math_ops.cast(x, scale.dtype), mean_axis, keep_dims=False)
 
     y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
     if data_format == 'NCHW':
       y = array_ops.transpose(y, [0, 3, 1, 2])
+    elif data_format == 'NCDHW':
+      y = array_ops.transpose(y, [0, 4, 1, 2, 3])
+
     # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
     # the denominator in the formula to calculate variance, while
@@ -377,14 +388,18 @@ class BatchNormalizationTest(test.TestCase):
 
   def _runtests(self, x_shape, is_training, gradient_test=False,
                 cpu_only=False):
+    if len(x_shape) == 4:
+      data_format_list = ['NHWC', 'NCHW']
+    else:
+      data_format_list = ['NCDHW', 'NDHWC']
     use_gpu_vals = [False]
     if test.is_gpu_available(cuda_only=True) and not cpu_only:
       use_gpu_vals += [True]
     factors = [1.0, 0.6]
     for dtype in [np.float16, np.float32]:
       for use_gpu in use_gpu_vals:
-        for data_format in ['NHWC', 'NCHW']:
-          if data_format == 'NHWC':
+        for data_format in data_format_list:
+          if data_format == 'NHWC' or data_format == 'NDHWC':
             scale_shape = x_shape[-1:]
           else:
             scale_shape = x_shape[1:2]
@@ -444,6 +459,10 @@ class BatchNormalizationTest(test.TestCase):
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, False, cpu_only=True)
 
+  def testInferenceShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, False)
+
   def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
     self._runtests(x_shape, True)
@@ -465,11 +484,16 @@ class BatchNormalizationTest(test.TestCase):
     x_shape = [0, 131, 127, 6]
     self._runtests(x_shape, True)
 
+  @test_util.run_deprecated_v1
   def testTrainingShape6(self):
     x_shape = [1, 1, 1, 1]
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, True, cpu_only=True)
 
+  def testTrainingShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, True)
+
   @test_util.run_deprecated_v1
   def testBatchNormGradInferenceShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -503,6 +527,11 @@ class BatchNormalizationTest(test.TestCase):
     self._runtests(x_shape, is_training=False, gradient_test=True,
                    cpu_only=True)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradInferenceShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, is_training=False, gradient_test=True)
+
   @test_util.run_deprecated_v1
   def testBatchNormGradTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -535,42 +564,54 @@ class BatchNormalizationTest(test.TestCase):
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradTrainingShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, is_training=True, gradient_test=True)
+
   def _testBatchNormGradGrad(self, config):
     shape = config['shape']
     err_tolerance = config['err_tolerance']
     dtype = config['dtype']
+    rank = len(shape)
+    if rank == 4:
+      data_format_nhwc, features_nhwc = 'NHWC', shape[3]
+      data_format_nchw, features_nchw = 'NCHW', shape[1]
+    else:
+      data_format_nhwc, features_nhwc = 'NDHWC', shape[4]
+      data_format_nchw, features_nchw = 'NCDHW', shape[1]
    for is_training in [True, False]:
      if test.is_gpu_available(cuda_only=True):
        self._test_grad_grad(
            shape,
-            dtype, [shape[3]],
+            dtype, [features_nhwc],
            np.float32,
            use_gpu=True,
-            data_format='NHWC',
+            data_format=data_format_nhwc,
            is_training=is_training,
            err_tolerance=err_tolerance)
        self._test_grad_grad(
            shape,
-            dtype, [shape[1]],
+            dtype, [features_nchw],
            np.float32,
            use_gpu=True,
-            data_format='NCHW',
+            data_format=data_format_nchw,
            is_training=is_training,
            err_tolerance=err_tolerance)
      self._test_grad_grad(
          shape,
-          dtype, [shape[3]],
+          dtype, [features_nhwc],
          np.float32,
          use_gpu=False,
-          data_format='NHWC',
+          data_format=data_format_nhwc,
          is_training=is_training,
          err_tolerance=err_tolerance)
      self._test_grad_grad(
          shape,
-          dtype, [shape[1]],
+          dtype, [features_nchw],
          np.float32,
          use_gpu=False,
-          data_format='NCHW',
+          data_format=data_format_nchw,
          is_training=is_training,
          err_tolerance=err_tolerance)
 
@@ -610,6 +651,24 @@ class BatchNormalizationTest(test.TestCase):
     }
     self._testBatchNormGradGrad(config)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradGradConfig5(self):
+    config = {
+        'shape': [2, 3, 2, 2, 2],
+        'err_tolerance': 2e-3,
+        'dtype': np.float32,
+    }
+    self._testBatchNormGradGrad(config)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradGradConfig6(self):
+    config = {
+        'shape': [2, 3, 2, 2, 2],
+        'err_tolerance': 3e-3,
+        'dtype': np.float16,
+    }
+    self._testBatchNormGradGrad(config)
+
 
 if __name__ == '__main__':
   test.main()
@@ -897,6 +897,11 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
     if data_format == b"NCHW":
       x = array_ops.transpose(x, [0, 2, 3, 1])
       grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
+    elif data_format == b"NCDHW":
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
+      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
+    target_data_format = ("NHWC" if data_format in (b"NCHW",
+                                                    b"NHWC") else "NDHWC")
     args = {
         "y_backprop": grad_y,
        "x": x,
@@ -904,7 +909,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
        "reserve_space_1": pop_mean,
        "reserve_space_2": pop_var,
        "epsilon": epsilon,
-        "data_format": "NHWC",
+        "data_format": target_data_format,
        "is_training": is_training
     }
     if version == 2:
@@ -912,6 +917,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
     dx, dscale, doffset, _, _ = grad_fun(**args)
     if data_format == b"NCHW":
       dx = array_ops.transpose(dx, [0, 3, 1, 2])
+    elif data_format == b"NCDHW":
+      dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
   return dx, dscale, doffset, None, None
 
 
@@ -941,8 +948,8 @@ def _BatchNormGrad(grad_y,
   """Returns the gradients for the 3 inputs of BatchNorm.
 
   Args:
-    grad_y: A `Tensor` of 4 dimensions for gradient for y.
-    x: A `Tensor` of 4 dimensions for x.
+    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
+    x: A `Tensor` of 4 or 5 dimensions for x.
     scale: A `Tensor` of 1 dimension for scaling.
     pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
       is_training=False.
@@ -968,11 +975,19 @@ def _BatchNormGrad(grad_y,
     if data_format == b"NHWC":
       keepdims = False
       reduce_axis = [0, 1, 2]
-    else:
+    elif data_format == b"NDHWC":
+      keepdims = False
+      reduce_axis = [0, 1, 2, 3]
+    elif data_format == b"NCHW":
       keepdims = True
       reduce_axis = [0, 2, 3]
       shape = [1, array_ops.size(scale), 1, 1]
       scale = array_ops.reshape(scale, shape)
+    else:
+      keepdims = True
+      reduce_axis = [0, 2, 3, 4]
+      shape = [1, array_ops.size(scale), 1, 1, 1]
+      scale = array_ops.reshape(scale, shape)
     mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
     mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
     var_x = math_ops.reduce_mean(
@@ -987,19 +1002,27 @@ def _BatchNormGrad(grad_y,
         grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
     grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
         grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
-    if data_format == b"NCHW":
+    if data_format == b"NCHW" or data_format == b"NCDHW":
       grad_scale = array_ops.squeeze(grad_scale)
     grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
     return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
   else:
     if data_format == b"NHWC":
       reduce_axis = [0, 1, 2]
-    else:
+    elif data_format == b"NDHWC":
+      reduce_axis = [0, 1, 2, 3]
+    elif data_format == b"NCHW":
       reduce_axis = [0, 2, 3]
       shape = [1, array_ops.size(pop_mean), 1, 1]
       pop_mean = array_ops.reshape(pop_mean, shape)
       pop_var = array_ops.reshape(pop_var, shape)
       scale = array_ops.reshape(scale, shape)
+    else:
+      reduce_axis = [0, 2, 3, 4]
+      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
+      pop_mean = array_ops.reshape(pop_mean, shape)
+      pop_var = array_ops.reshape(pop_var, shape)
+      scale = array_ops.reshape(scale, shape)
 
     grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
     var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
@@ -1585,7 +1585,7 @@ def fused_batch_norm(
     (http://arxiv.org/abs/1502.03167).
 
   Args:
-    x: Input `Tensor` of 4 dimensions.
+    x: Input `Tensor` of 4 or 5 dimensions.
     scale: A `Tensor` of 1 dimension for scaling.
     offset: A `Tensor` of 1 dimension for bias.
     mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
@@ -1611,7 +1611,8 @@ def fused_batch_norm(
           Variance must be a `Tensor` of the same shape as scale containing
           the exponential running variance.
     epsilon: A small float number added to the variance of x.
-    data_format: The data format for x. Either "NHWC" (default) or "NCHW".
+    data_format: The data format for x. Support "NHWC" (default) or "NCHW" for
+      4D tenors and "NDHWC" or "NCDHW" for 5D tensors.
     is_training: A bool value to specify if the operation is used for
       training or inference.
     name: A name for this operation (optional).
@@ -1622,7 +1623,7 @@ def fused_batch_norm(
       returned.
 
   Returns:
-    y: A 4D Tensor for the normalized, scaled, offsetted x.
+    y: A 4D or 5D Tensor for the normalized, scaled, offsetted x.
     running_mean: A 1D Tensor for the exponential running mean of x.
       The output value is (1 - exponential_avg_factor) * mean +
       exponential_avg_factor * batch_mean), where batch_mean