Roll forward of https://github.com/tensorflow/tensorflow/pull/42970 with a fix for 5D tensors and for unsupported axes that cannot use fused batch norm.

PiperOrigin-RevId: 335071264
Change-Id: Id9ab44bdba870336f03522e1ca76d88b1f305a10
Andy Ly 2020-10-02 11:54:14 -07:00 committed by TensorFlower Gardener
parent eed3ab97e5
commit 27d26a8d86
11 changed files with 379 additions and 83 deletions
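For context, a minimal sketch of what this change enables at the op level (illustrative only; assumes the post-change behavior where the fused batch norm ops accept rank-5 NDHWC/NCDHW inputs):

import tensorflow as tf

# Rank-5 input in NDHWC layout: [batch, depth, height, width, channels].
x = tf.random.normal([2, 4, 4, 4, 3])
scale = tf.ones([3])    # one scale per channel
offset = tf.zeros([3])  # one offset per channel
# Previously this required a 4D input; with this change 5D inputs are accepted.
y, batch_mean, batch_var = tf.compat.v1.nn.fused_batch_norm(
    x, scale, offset, data_format='NDHWC', is_training=True)
print(y.shape)  # (2, 4, 4, 4, 3)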

View File

@@ -1121,8 +1121,17 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
}
Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
if (!FormatFromString(data_format_str, &data_format)) {
return errors::InvalidArgument("Invalid data format string: ",
data_format_str);
}
const int rank =
(data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
@@ -1131,14 +1140,8 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
exponential_avg_factor = 1.0f; // default value
}
int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
if (!FormatFromString(data_format_str, &data_format)) {
return errors::InvalidArgument("Invalid data format string: ",
data_format_str);
}
int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
// covers scale, offset, and if is_training is false, mean, variance
@@ -1191,13 +1194,6 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
}
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
ShapeHandle y_backprop;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
@@ -1205,7 +1201,17 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
return errors::InvalidArgument("Invalid data format string: ",
data_format_str);
}
int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
const int rank =
(data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
ShapeHandle y_backprop;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
TF_RETURN_IF_ERROR(
c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
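The shape functions above now derive the expected input rank and the channel dimension from the data_format attribute instead of hard-coding 4. A rough Python analogue of that inference (the helper name below is illustrative, not a TensorFlow API):

def expected_rank_and_channel_axis(data_format):
  # 3D volumetric formats imply rank-5 inputs; channels sit at index 1 for
  # channels-first formats and at the last index otherwise.
  rank = 5 if data_format in ('NDHWC', 'NCDHW') else 4
  channel_axis = 1 if data_format.startswith('NC') else rank - 1
  return rank, channel_axis

assert expected_rank_and_channel_axis('NHWC') == (4, 3)
assert expected_rank_and_channel_axis('NCHW') == (4, 1)
assert expected_rank_and_channel_axis('NDHWC') == (5, 4)
assert expected_rank_and_channel_axis('NCDHW') == (5, 1)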

View File

@@ -670,7 +670,25 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
std::string src_format = context->src_format;
std::string dst_format = context->dst_format;
// Update the format from 4D to 5D layout if necessary.
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
if (allow_5d) {
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
dst_format_3d);
}
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
// Change back to the original layout due to early exit.
if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format,
dst_format);
}
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -679,6 +697,11 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
TF_RETURN_IF_ERROR(UpdateNode(context, node));
TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
// Change back the format from 5D to 4D layout.
if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format,
dst_format);
}
return context->graph_view->GetMutationBuilder()->Apply();
}
@@ -881,8 +904,26 @@ bool FusedBatchNormGradTransposer::IsTraining(
Status FusedBatchNormGradTransposer::TransposeNode(
TransposeContext* context, utils::MutableNodeView* node) {
DCHECK(IsFusedBatchNormGrad(*node->node()));
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
const auto& shape = output_shape_attr->list().shape(0);
const int rank = shape.dim_size();
std::string src_format = context->src_format;
std::string dst_format = context->dst_format;
// Update the format from 4D to 5D layout if necessary.
bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
if (allow_5d) {
std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
dst_format_3d);
}
if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
!IsTraining(*node)) {
// Change back to the original layout due to early exit.
if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format,
dst_format);
}
return Status::OK();
}
VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -892,6 +933,11 @@ Status FusedBatchNormGradTransposer::TransposeNode(
TF_RETURN_IF_ERROR(
UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
// Change back the format from 5D to 4D layout.
if (allow_5d) {
context->AssignDeviceAndDataFormats(context->target_device, src_format,
dst_format);
}
return context->graph_view->GetMutationBuilder()->Apply();
}
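Both transposers handle rank-5 nodes by temporarily swapping the context's 4D formats (NHWC/NCHW) for their 3D counterparts (NDHWC/NCDHW), rewriting the node, and then restoring the original formats. The Transpose nodes they insert correspond to the standard 5D permutations sketched below (the permutation constants are illustrative, not code from this file):

import tensorflow as tf

NDHWC_TO_NCDHW = [0, 4, 1, 2, 3]  # channels-last -> channels-first
NCDHW_TO_NDHWC = [0, 2, 3, 4, 1]  # channels-first -> channels-last

x = tf.random.normal([1, 4, 2, 3, 3])             # NDHWC
x_ncdhw = tf.transpose(x, NDHWC_TO_NCDHW)         # what the inserted fanin Transpose does
x_roundtrip = tf.transpose(x_ncdhw, NCDHW_TO_NDHWC)
print(x_ncdhw.shape, x_roundtrip.shape)           # (1, 3, 4, 2, 3) (1, 4, 2, 3, 3)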

View File

@@ -1438,29 +1438,41 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status;
if (fused_node.attr().at(kDataFormat).s() == "NCHW") {
string x_format = fused_node.attr().at(kDataFormat).s();
if (x_format == "NCHW" or x_format == "NCDHW") {
// Need to reshape the last 4 inputs
NodeDef new_shape;
const string new_shape_name =
AddPrefixToNodeName("NCHWShape", fused_node.name());
AddPrefixToNodeName(x_format + "Shape", fused_node.name());
new_shape.set_name(new_shape_name);
new_shape.set_op("Const");
new_shape.set_device(fused_node.device());
*new_shape.add_input() = AsControlDependency(scale);
(*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
Tensor t(DT_INT32, {4});
t.flat<int32>()(0) = 1;
t.flat<int32>()(1) = -1;
t.flat<int32>()(2) = 1;
t.flat<int32>()(3) = 1;
t.AsProtoTensorContent(
(*new_shape.mutable_attr())["value"].mutable_tensor());
if (x_format == "NCHW") {
Tensor t(DT_INT32, {4});
t.flat<int32>()(0) = 1;
t.flat<int32>()(1) = -1;
t.flat<int32>()(2) = 1;
t.flat<int32>()(3) = 1;
t.AsProtoTensorContent(
(*new_shape.mutable_attr())["value"].mutable_tensor());
} else {
Tensor t(DT_INT32, {5});
t.flat<int32>()(0) = 1;
t.flat<int32>()(1) = -1;
t.flat<int32>()(2) = 1;
t.flat<int32>()(3) = 1;
t.flat<int32>()(4) = 1;
t.AsProtoTensorContent(
(*new_shape.mutable_attr())["value"].mutable_tensor());
}
mutation->AddNode(std::move(new_shape), &status);
TF_RETURN_IF_ERROR(status);
NodeDef reshaped_scale;
reshaped_scale.set_name(
AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
reshaped_scale.set_op("Reshape");
reshaped_scale.set_device(fused_node.device());
*reshaped_scale.add_input() = scale;
@@ -1473,7 +1485,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
NodeDef reshaped_offset;
reshaped_offset.set_name(
AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
reshaped_offset.set_op("Reshape");
reshaped_offset.set_device(fused_node.device());
*reshaped_offset.add_input() = offset;
@@ -1486,7 +1498,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
NodeDef reshaped_mean;
reshaped_mean.set_name(
AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
reshaped_mean.set_op("Reshape");
reshaped_mean.set_device(fused_node.device());
*reshaped_mean.add_input() = mean;
@@ -1499,7 +1511,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
NodeDef reshaped_variance;
reshaped_variance.set_name(
AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
reshaped_variance.set_op("Reshape");
reshaped_variance.set_device(fused_node.device());
*reshaped_variance.add_input() = variance;
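The Reshape nodes added here turn the 1D scale/offset/mean/variance vectors into shapes that broadcast against a channels-first input ([1, -1, 1, 1] for NCHW, [1, -1, 1, 1, 1] for NCDHW) when the fused op is decomposed into primitive arithmetic. A small sketch of that broadcasting (illustrative only):

import tensorflow as tf

x = tf.random.normal([2, 3, 4, 5, 6])            # NCDHW input, C = 3
scale = tf.constant([1.0, 2.0, 3.0])
offset = tf.zeros([3])
mean = tf.zeros([3])
variance = tf.ones([3])
shape = [1, -1, 1, 1, 1]                         # the Const tensor built above for NCDHW
y = (x - tf.reshape(mean, shape)) * tf.math.rsqrt(tf.reshape(variance, shape) + 1e-3)
y = y * tf.reshape(scale, shape) + tf.reshape(offset, shape)
print(y.shape)                                   # (2, 3, 4, 5, 6)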

View File

@@ -1241,15 +1241,15 @@ class FusedBatchNormOpBase : public OpKernel {
// If use_reserved_space is false, we don't have 5th output.
virtual void ComputeWithReservedSpace(OpKernelContext* context,
bool use_reserved_space) {
const Tensor& x = context->input(0);
Tensor x = context->input(0);
const Tensor& scale = context->input(1);
const Tensor& offset = context->input(2);
const Tensor& estimated_mean = context->input(3);
const Tensor& estimated_variance = context->input(4);
const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
OP_REQUIRES(context, x.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
errors::InvalidArgument("input must be 4 or 5-dimensional",
x.shape().DebugString()));
OP_REQUIRES(context, scale.dims() == 1,
errors::InvalidArgument("scale must be 1-dimensional",
@@ -1264,6 +1264,21 @@ class FusedBatchNormOpBase : public OpKernel {
context, estimated_variance.dims() == 1,
errors::InvalidArgument("estimated_variance must be 1-dimensional",
estimated_variance.shape().DebugString()));
bool use_reshape = (x.dims() == 5);
auto x_shape = x.shape();
TensorShape dest_shape;
if (use_reshape) {
const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
int64 in_planes = GetTensorDim(x, tensor_format_, '0');
int64 in_rows = GetTensorDim(x, tensor_format_, '1');
int64 in_cols = GetTensorDim(x, tensor_format_, '2');
const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
dest_shape = ShapeFromFormat(tensor_format_, in_batch,
{{in_planes, in_rows * in_cols}}, in_depth);
OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
errors::InvalidArgument("Error during tensor copy."));
}
if (has_side_input_) {
OP_REQUIRES(context, side_input->shape() == x.shape(),
errors::InvalidArgument(
@@ -1282,8 +1297,10 @@ class FusedBatchNormOpBase : public OpKernel {
}
Tensor* y = nullptr;
auto alloc_shape = use_reshape ? dest_shape : x_shape;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, x.shape(), &y));
{0}, 0, alloc_shape, &y));
Tensor* batch_mean = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{3}, 1, scale.shape(), &batch_mean));
@@ -1310,6 +1327,10 @@ class FusedBatchNormOpBase : public OpKernel {
batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
tensor_format_, use_reserved_space);
}
if (use_reshape) {
OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
errors::InvalidArgument("Error during tensor copy."));
}
}
private:
@@ -1375,8 +1396,8 @@ class FusedBatchNormGradOpBase : public OpKernel {
virtual void ComputeWithReservedSpace(OpKernelContext* context,
bool use_reserved_space) {
const Tensor& y_backprop = context->input(0);
const Tensor& x = context->input(1);
Tensor y_backprop = context->input(0);
Tensor x = context->input(1);
const Tensor& scale = context->input(2);
// When is_training=True, batch mean and variance/inverted variance are
// saved in the forward pass to be reused here. When is_training=False,
@@ -1387,11 +1408,11 @@ class FusedBatchNormGradOpBase : public OpKernel {
// saves inverted variance.
const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
OP_REQUIRES(context, y_backprop.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
errors::InvalidArgument("input must be 4 or 5-dimensional",
y_backprop.shape().DebugString()));
OP_REQUIRES(context, x.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
errors::InvalidArgument("input must be 4 or 5-dimensional",
x.shape().DebugString()));
OP_REQUIRES(context, scale.dims() == 1,
errors::InvalidArgument("scale must be 1-dimensional",
@@ -1404,10 +1425,27 @@ class FusedBatchNormGradOpBase : public OpKernel {
errors::InvalidArgument(
"saved variance must be 1-dimensional",
saved_maybe_inv_var_or_pop_var.shape().DebugString()));
bool use_reshape = (x.dims() == 5);
auto x_shape = x.shape();
TensorShape dest_shape;
if (use_reshape) {
const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
int64 in_planes = GetTensorDim(x, tensor_format_, '0');
int64 in_rows = GetTensorDim(x, tensor_format_, '1');
int64 in_cols = GetTensorDim(x, tensor_format_, '2');
const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
dest_shape = ShapeFromFormat(tensor_format_, in_batch,
{{in_planes, in_rows * in_cols}}, in_depth);
OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
errors::InvalidArgument("Error during tensor copy."));
OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
errors::InvalidArgument("Error during tensor copy."));
}
Tensor* x_backprop = nullptr;
auto alloc_shape = use_reshape ? dest_shape : x_shape;
OP_REQUIRES_OK(context,
context->allocate_output(0, x.shape(), &x_backprop));
context->allocate_output(0, alloc_shape, &x_backprop));
const TensorShape& scale_offset_shape = scale.shape();
Tensor* scale_backprop = nullptr;
@@ -1441,15 +1479,20 @@ class FusedBatchNormGradOpBase : public OpKernel {
offset_backprop, use_reserved_space, tensor_format_);
} else {
// Necessary layout conversion is currently done in python.
CHECK(tensor_format_ == FORMAT_NHWC)
<< "The implementation of FusedBatchNormGrad with is_training=False "
"only support "
<< "NHWC tensor format for now.";
OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
errors::InvalidArgument(
"The implementation of "
"FusedBatchNormGrad with is_training=False only support "
"NHWC tensor format for now."));
functor::FusedBatchNormFreezeGrad<Device, T, U>()(
context, y_backprop, x, scale, saved_mean_or_pop_mean,
saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
offset_backprop);
}
if (use_reshape) {
OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
errors::InvalidArgument("Error during tensor copy."));
}
}
private:
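Rather than adding a separate 5D kernel, the op collapses the two innermost spatial dimensions so the existing 4D implementation can be reused: an NDHWC tensor of shape [N, D, H, W, C] is viewed as [N, D, H*W, C] (analogously for NCDHW), processed, and the result is reshaped back. This is valid because batch norm reduces over every non-channel dimension, so merging spatial axes leaves the per-channel statistics unchanged. A NumPy sketch of that equivalence (illustrative, not the kernel code):

import numpy as np

def batch_norm(x, stats_axes, eps=1e-3):
  # Normalize over all non-channel axes.
  mean = x.mean(axis=stats_axes, keepdims=True)
  var = x.var(axis=stats_axes, keepdims=True)
  return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(2, 3, 4, 5, 6).astype(np.float32)       # [N, D, H, W, C]
y_5d = batch_norm(x, (0, 1, 2, 3))                           # direct 5D normalization
x_4d = x.reshape(2, 3, 4 * 5, 6)                             # collapse H and W, as the kernel does
y_4d = batch_norm(x_4d, (0, 1, 2)).reshape(x.shape)          # 4D path, reshaped back
np.testing.assert_allclose(y_5d, y_4d, rtol=1e-5, atol=1e-5)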

View File

@@ -221,7 +221,7 @@ REGISTER_OP("FusedBatchNormV3")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("exponential_avg_factor: float = 1.0")
.Attr(GetConvnetDataFormatAttrString())
.Attr(GetConvnetDataFormat2D3DAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormV3Shape);
@@ -308,7 +308,7 @@ REGISTER_OP("FusedBatchNormGradV3")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr(GetConvnetDataFormatAttrString())
.Attr(GetConvnetDataFormat2D3DAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
// --------------------------------------------------------------------------

View File

@@ -1275,6 +1275,94 @@ class LayoutOptimizerTest(test.TestCase):
self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
@test_util.deprecated_graph_mode_only
def testBatchNorm3D(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
strides_val = [1, 1, 1, 1, 1]
scale = constant_op.constant(0.1, shape=[3])
offset = constant_op.constant(0.3, shape=[3])
conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
y, _, _ = nn.fused_batch_norm(conv3d, scale, offset, data_format='NDHWC')
output = array_ops.identity(y)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormV3-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only
def testBatchNormGrad3D(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
strides_val = [1, 1, 1, 1, 1]
scale = constant_op.constant(0.1, shape=[3])
offset = constant_op.constant(0.3, shape=[3])
mean = constant_op.constant(0.1, shape=[3])
variance = constant_op.constant(0.3, shape=[3])
conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
y, running_mean, running_var, r0, r1, r2 = gen_nn_ops.fused_batch_norm_v3(
conv3d,
scale,
offset,
mean,
variance,
epsilon=1.001e-5,
exponential_avg_factor=1.0,
data_format='NDHWC',
is_training=True,
name='batch_norm')
dx, dscale, doffset, _, _ = gen_nn_ops.fused_batch_norm_grad_v3(
y,
x_3d,
scale,
r0,
r1,
r2,
epsilon=1.001e-5,
data_format='NDHWC',
is_training=True)
output = array_ops.identity(dx)
with session.Session(config=_get_config(False)) as sess:
output_val_ref = sess.run(output)
with session.Session(config=_get_config()) as sess:
metadata = config_pb2.RunMetadata()
output_val = sess.run(output, run_metadata=metadata)
nodes = []
num_transposes = 0
for node in metadata.cost_graph.node:
if _is_transpose(node.name):
num_transposes += 1
nodes.append(node.name)
expected_num_transposes = 3
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
self._assert_trans_ndhwc_to_ncdhw('FusedBatchNormGradV3-1', nodes)
self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormGradV3-0-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
@test_util.deprecated_graph_mode_only

View File

@@ -330,13 +330,13 @@ class BatchNormalizationBase(Layer):
# output back to its original shape accordingly.
if self._USE_V2_BEHAVIOR:
if self.fused is None:
self.fused = (ndims == 4)
elif self.fused and ndims != 4:
self.fused = ndims in (4, 5)
elif self.fused and ndims not in (4, 5):
raise ValueError('Batch normalization layers with fused=True only '
'support 4D input tensors.')
'support 4D or 5D input tensors.')
else:
assert self.fused is not None
self.fused = (ndims == 4 and self._fused_can_be_used())
self.fused = (ndims in (4, 5) and self._fused_can_be_used())
# TODO(chrisying): fused batch norm is currently not supported for
# multi-axis batch norm and by extension virtual batches. In some cases,
# it might be possible to use fused batch norm but would require reshaping
@@ -345,13 +345,22 @@ class BatchNormalizationBase(Layer):
# common use case (turning 5D w/ virtual batch to NCHW)
if self.fused:
if self.axis == [1]:
if self.axis == [1] and ndims == 4:
self._data_format = 'NCHW'
elif self.axis == [3]:
elif self.axis == [1] and ndims == 5:
self._data_format = 'NCDHW'
elif self.axis == [3] and ndims == 4:
self._data_format = 'NHWC'
elif self.axis == [4] and ndims == 5:
self._data_format = 'NDHWC'
elif ndims == 5:
# 5D tensors that can be passed in but should not use fused batch norm
# due to unsupported axis.
self.fused = False
else:
raise ValueError('Unsupported axis, fused batch norm only supports '
'axis == [1] or axis == [3]')
'axis == [1] or axis == [3] for 4D input tensors or '
'axis == [1] or axis == [4] for 5D input tensors')
axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
for x in axis_to_dim:
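The net effect for Keras users: on rank-5 inputs, axis=1 maps to NCDHW and axis=4 (or -1) maps to NDHWC and keeps the fused kernel, while any other axis quietly disables fusion instead of raising. A brief sketch of the resulting behavior (assumes this change is present; the .fused attribute check mirrors the logic above):

import tensorflow as tf

x = tf.random.normal([2, 4, 4, 4, 3])                     # 5D channels-last input
bn_fused = tf.keras.layers.BatchNormalization(axis=-1)    # channel axis -> fused kernel (NDHWC)
bn_fallback = tf.keras.layers.BatchNormalization(axis=2)  # unsupported axis -> fused silently disabled
_ = bn_fused(x, training=True)
_ = bn_fallback(x, training=True)
print(bn_fused.fused, bn_fallback.fused)                  # True False
# axis=1 on a channels-first [N, C, D, H, W] input would similarly keep the
# fused path with data_format='NCDHW'.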

View File

@@ -66,6 +66,15 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
kwargs={'scale': False,
'center': False},
input_shape=(3, 3))
testing_utils.layer_test(
keras.layers.BatchNormalization,
kwargs={
'gamma_initializer': 'ones',
'beta_initializer': 'ones',
'moving_mean_initializer': 'zeros',
'moving_variance_initializer': 'ones'
},
input_shape=(3, 2, 4, 2))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_batchnorm_weights(self):
@@ -319,7 +328,7 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
norm = normalization_v2.BatchNormalization(fused=True)
self.assertEqual(norm.fused, True)
inp = keras.layers.Input(shape=(4, 4))
with self.assertRaisesRegex(ValueError, '4D input tensors'):
with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
norm(inp)
def test_updates_in_wrap_function(self):

View File

@@ -43,14 +43,18 @@ class BatchNormalizationTest(test.TestCase):
return math_ops.cast(y, x.dtype)
def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format):
if data_format not in ['NHWC', 'NCHW']:
raise ValueError('data_format must be NCHW or NHWC, '
'got %s.' % data_format)
if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
raise ValueError('data_format must be NCHW or NHWC for 4D tensors or '
'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
if data_format == 'NCHW':
x = array_ops.transpose(x, [0, 2, 3, 1])
elif data_format == 'NCDHW':
x = array_ops.transpose(x, [0, 2, 3, 4, 1])
y = self._batch_norm(x, mean, var, offset, scale, epsilon)
if data_format == 'NCHW':
y = array_ops.transpose(y, [0, 3, 1, 2])
elif data_format == 'NCDHW':
y = array_ops.transpose(y, [0, 4, 1, 2, 3])
return self.evaluate(y)
def _test_inference(self,
@@ -102,17 +106,24 @@ class BatchNormalizationTest(test.TestCase):
def _training_ref(self, x, scale, offset, old_mean, old_var,
exponential_avg_factor, epsilon, data_format):
if data_format not in ['NHWC', 'NCHW']:
raise ValueError('data_format must be NCHW or NHWC, '
'got %s.' % data_format)
if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
raise ValueError('data_format must be NCHW or NHWC for 4D tensors or '
'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
use_4d_tensor = (x.shape.ndims == 4)
if data_format == 'NCHW':
x = array_ops.transpose(x, [0, 2, 3, 1])
elif data_format == 'NCDHW':
x = array_ops.transpose(x, [0, 2, 3, 4, 1])
mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3]
batch_mean, batch_var = nn_impl.moments(
math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
math_ops.cast(x, scale.dtype), mean_axis, keep_dims=False)
y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
if data_format == 'NCHW':
y = array_ops.transpose(y, [0, 3, 1, 2])
elif data_format == 'NCDHW':
y = array_ops.transpose(y, [0, 4, 1, 2, 3])
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
# the denominator in the formula to calculate variance, while
@@ -377,14 +388,18 @@ class BatchNormalizationTest(test.TestCase):
def _runtests(self, x_shape, is_training, gradient_test=False,
cpu_only=False):
if len(x_shape) == 4:
data_format_list = ['NHWC', 'NCHW']
else:
data_format_list = ['NCDHW', 'NDHWC']
use_gpu_vals = [False]
if test.is_gpu_available(cuda_only=True) and not cpu_only:
use_gpu_vals += [True]
factors = [1.0, 0.6]
for dtype in [np.float16, np.float32]:
for use_gpu in use_gpu_vals:
for data_format in ['NHWC', 'NCHW']:
if data_format == 'NHWC':
for data_format in data_format_list:
if data_format == 'NHWC' or data_format == 'NDHWC':
scale_shape = x_shape[-1:]
else:
scale_shape = x_shape[1:2]
@@ -444,6 +459,10 @@ class BatchNormalizationTest(test.TestCase):
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
self._runtests(x_shape, False, cpu_only=True)
def testInferenceShape7(self):
x_shape = [1, 2, 6, 1, 3]
self._runtests(x_shape, False)
def testTrainingShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, True)
@@ -465,11 +484,16 @@ class BatchNormalizationTest(test.TestCase):
x_shape = [0, 131, 127, 6]
self._runtests(x_shape, True)
@test_util.run_deprecated_v1
def testTrainingShape6(self):
x_shape = [1, 1, 1, 1]
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
self._runtests(x_shape, True, cpu_only=True)
def testTrainingShape7(self):
x_shape = [1, 2, 6, 1, 3]
self._runtests(x_shape, True)
@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape1(self):
x_shape = [1, 1, 6, 1]
@@ -503,6 +527,11 @@ class BatchNormalizationTest(test.TestCase):
self._runtests(x_shape, is_training=False, gradient_test=True,
cpu_only=True)
@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape7(self):
x_shape = [1, 2, 6, 1, 3]
self._runtests(x_shape, is_training=False, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape1(self):
x_shape = [1, 1, 6, 1]
@@ -535,42 +564,54 @@ class BatchNormalizationTest(test.TestCase):
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape7(self):
x_shape = [1, 2, 6, 1, 3]
self._runtests(x_shape, is_training=True, gradient_test=True)
def _testBatchNormGradGrad(self, config):
shape = config['shape']
err_tolerance = config['err_tolerance']
dtype = config['dtype']
rank = len(shape)
if rank == 4:
data_format_nhwc, features_nhwc = 'NHWC', shape[3]
data_format_nchw, features_nchw = 'NCHW', shape[1]
else:
data_format_nhwc, features_nhwc = 'NDHWC', shape[4]
data_format_nchw, features_nchw = 'NCDHW', shape[1]
for is_training in [True, False]:
if test.is_gpu_available(cuda_only=True):
self._test_grad_grad(
shape,
dtype, [shape[3]],
dtype, [features_nhwc],
np.float32,
use_gpu=True,
data_format='NHWC',
data_format=data_format_nhwc,
is_training=is_training,
err_tolerance=err_tolerance)
self._test_grad_grad(
shape,
dtype, [shape[1]],
dtype, [features_nchw],
np.float32,
use_gpu=True,
data_format='NCHW',
data_format=data_format_nchw,
is_training=is_training,
err_tolerance=err_tolerance)
self._test_grad_grad(
shape,
dtype, [shape[3]],
dtype, [features_nhwc],
np.float32,
use_gpu=False,
data_format='NHWC',
data_format=data_format_nhwc,
is_training=is_training,
err_tolerance=err_tolerance)
self._test_grad_grad(
shape,
dtype, [shape[1]],
dtype, [features_nchw],
np.float32,
use_gpu=False,
data_format='NCHW',
data_format=data_format_nchw,
is_training=is_training,
err_tolerance=err_tolerance)
@@ -610,6 +651,24 @@ class BatchNormalizationTest(test.TestCase):
}
self._testBatchNormGradGrad(config)
@test_util.run_deprecated_v1
def testBatchNormGradGradConfig5(self):
config = {
'shape': [2, 3, 2, 2, 2],
'err_tolerance': 2e-3,
'dtype': np.float32,
}
self._testBatchNormGradGrad(config)
@test_util.run_deprecated_v1
def testBatchNormGradGradConfig6(self):
config = {
'shape': [2, 3, 2, 2, 2],
'err_tolerance': 3e-3,
'dtype': np.float16,
}
self._testBatchNormGradGrad(config)
if __name__ == '__main__':
test.main()

View File

@@ -897,6 +897,11 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
if data_format == b"NCHW":
x = array_ops.transpose(x, [0, 2, 3, 1])
grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
elif data_format == b"NCDHW":
x = array_ops.transpose(x, [0, 2, 3, 4, 1])
grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
target_data_format = ("NHWC" if data_format in (b"NCHW",
b"NHWC") else "NDHWC")
args = {
"y_backprop": grad_y,
"x": x,
@@ -904,7 +909,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
"reserve_space_1": pop_mean,
"reserve_space_2": pop_var,
"epsilon": epsilon,
"data_format": "NHWC",
"data_format": target_data_format,
"is_training": is_training
}
if version == 2:
@@ -912,6 +917,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
dx, dscale, doffset, _, _ = grad_fun(**args)
if data_format == b"NCHW":
dx = array_ops.transpose(dx, [0, 3, 1, 2])
elif data_format == b"NCDHW":
dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
return dx, dscale, doffset, None, None
@@ -941,8 +948,8 @@ def _BatchNormGrad(grad_y,
"""Returns the gradients for the 3 inputs of BatchNorm.
Args:
grad_y: A `Tensor` of 4 dimensions for gradient for y.
x: A `Tensor` of 4 dimensions for x.
grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
x: A `Tensor` of 4 or 5 dimensions for x.
scale: A `Tensor` of 1 dimension for scaling.
pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
is_training=False.
@@ -968,11 +975,19 @@ def _BatchNormGrad(grad_y,
if data_format == b"NHWC":
keepdims = False
reduce_axis = [0, 1, 2]
else:
elif data_format == b"NDHWC":
keepdims = False
reduce_axis = [0, 1, 2, 3]
elif data_format == b"NCHW":
keepdims = True
reduce_axis = [0, 2, 3]
shape = [1, array_ops.size(scale), 1, 1]
scale = array_ops.reshape(scale, shape)
else:
keepdims = True
reduce_axis = [0, 2, 3, 4]
shape = [1, array_ops.size(scale), 1, 1, 1]
scale = array_ops.reshape(scale, shape)
mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
var_x = math_ops.reduce_mean(
@@ -987,19 +1002,27 @@ def _BatchNormGrad(grad_y,
grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
if data_format == b"NCHW":
if data_format == b"NCHW" or data_format == b"NCDHW":
grad_scale = array_ops.squeeze(grad_scale)
grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
else:
if data_format == b"NHWC":
reduce_axis = [0, 1, 2]
else:
elif data_format == b"NDHWC":
reduce_axis = [0, 1, 2, 3]
elif data_format == b"NCHW":
reduce_axis = [0, 2, 3]
shape = [1, array_ops.size(pop_mean), 1, 1]
pop_mean = array_ops.reshape(pop_mean, shape)
pop_var = array_ops.reshape(pop_var, shape)
scale = array_ops.reshape(scale, shape)
else:
reduce_axis = [0, 2, 3, 4]
shape = [1, array_ops.size(pop_mean), 1, 1, 1]
pop_mean = array_ops.reshape(pop_mean, shape)
pop_var = array_ops.reshape(pop_var, shape)
scale = array_ops.reshape(scale, shape)
grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
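A quick way to exercise the new 5D gradient paths end to end (illustrative; uses only public APIs, and channels-last so it also runs on CPU):

import tensorflow as tf

x = tf.random.normal([2, 4, 4, 4, 3])          # NDHWC
scale = tf.ones([3])
offset = tf.zeros([3])
with tf.GradientTape() as tape:
  tape.watch([x, scale, offset])
  y, _, _ = tf.compat.v1.nn.fused_batch_norm(
      x, scale, offset, data_format='NDHWC', is_training=True)
  loss = tf.reduce_sum(tf.square(y))
dx, dscale, doffset = tape.gradient(loss, [x, scale, offset])
print(dx.shape, dscale.shape, doffset.shape)   # (2, 4, 4, 4, 3) (3,) (3,)
# With data_format='NCDHW' the registered gradient transposes to NDHWC for the
# is_training=False case and reduces over [0, 2, 3, 4] in the unfused fallback.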

View File

@@ -1585,7 +1585,7 @@ def fused_batch_norm(
(http://arxiv.org/abs/1502.03167).
Args:
x: Input `Tensor` of 4 dimensions.
x: Input `Tensor` of 4 or 5 dimensions.
scale: A `Tensor` of 1 dimension for scaling.
offset: A `Tensor` of 1 dimension for bias.
mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
@@ -1611,7 +1611,8 @@ def fused_batch_norm(
Variance must be a `Tensor` of the same shape as scale containing
the exponential running variance.
epsilon: A small float number added to the variance of x.
data_format: The data format for x. Either "NHWC" (default) or "NCHW".
data_format: The data format for x. Supports "NHWC" (default) or "NCHW" for
4D tensors and "NDHWC" or "NCDHW" for 5D tensors.
is_training: A bool value to specify if the operation is used for
training or inference.
name: A name for this operation (optional).
@@ -1622,7 +1623,7 @@ def fused_batch_norm(
returned.
Returns:
y: A 4D Tensor for the normalized, scaled, offsetted x.
y: A 4D or 5D Tensor for the normalized, scaled, offsetted x.
running_mean: A 1D Tensor for the exponential running mean of x.
The output value is (1 - exponential_avg_factor) * mean +
exponential_avg_factor * batch_mean), where batch_mean