diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index c54418ba648..2d30d41c7a6 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1121,8 +1121,17 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
 }
 
 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
+  string data_format_str;
+  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
+  TensorFormat data_format;
+  if (!FormatFromString(data_format_str, &data_format)) {
+    return errors::InvalidArgument("Invalid data format string: ",
+                                   data_format_str);
+  }
+  const int rank =
+      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
   ShapeHandle x;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
 
   bool is_training;
   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
@@ -1131,14 +1140,8 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
     exponential_avg_factor = 1.0f;  // default value
   }
   int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
-  string data_format_str;
-  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
-  TensorFormat data_format;
-  if (!FormatFromString(data_format_str, &data_format)) {
-    return errors::InvalidArgument("Invalid data format string: ",
-                                   data_format_str);
-  }
-  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+
+  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
 
   // covers scale, offset, and if is_training is false, mean, variance
@@ -1191,13 +1194,6 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
 }
 
 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
-  ShapeHandle y_backprop;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
-  ShapeHandle x;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
-
-  bool is_training;
-  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
   string data_format_str;
   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   TensorFormat data_format;
@@ -1205,7 +1201,17 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
     return errors::InvalidArgument("Invalid data format string: ",
                                    data_format_str);
   }
-  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+  const int rank =
+      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
+  ShapeHandle y_backprop;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
+  ShapeHandle x;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
+
+  bool is_training;
+  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
+
+  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
   DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
   TF_RETURN_IF_ERROR(
       c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 10253f187c0..3c466edc69b 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -670,7 +670,25 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
 Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
+    // Change back to the original layout due to early exit.
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -679,6 +697,11 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
@@ -881,8 +904,26 @@ bool FusedBatchNormGradTransposer::IsTraining(
 Status FusedBatchNormGradTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsFusedBatchNormGrad(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
       !IsTraining(*node)) {
+    // Change back to the original layout due to early exit.
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -892,6 +933,11 @@ Status FusedBatchNormGradTransposer::TransposeNode(
   TF_RETURN_IF_ERROR(
       UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 115428ff5ef..db528da2f6d 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -1438,29 +1438,41 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
 
-  if (fused_node.attr().at(kDataFormat).s() == "NCHW") {
+  string x_format = fused_node.attr().at(kDataFormat).s();
+  if (x_format == "NCHW" or x_format == "NCDHW") {
     // Need to reshape the last 4 inputs
     NodeDef new_shape;
     const string new_shape_name =
-        AddPrefixToNodeName("NCHWShape", fused_node.name());
+        AddPrefixToNodeName(x_format + "Shape", fused_node.name());
     new_shape.set_name(new_shape_name);
     new_shape.set_op("Const");
     new_shape.set_device(fused_node.device());
     *new_shape.add_input() = AsControlDependency(scale);
     (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
-    Tensor t(DT_INT32, {4});
-    t.flat<int32>()(0) = 1;
-    t.flat<int32>()(1) = -1;
-    t.flat<int32>()(2) = 1;
-    t.flat<int32>()(3) = 1;
-    t.AsProtoTensorContent(
-        (*new_shape.mutable_attr())["value"].mutable_tensor());
+    if (x_format == "NCHW") {
+      Tensor t(DT_INT32, {4});
+      t.flat<int32>()(0) = 1;
+      t.flat<int32>()(1) = -1;
+      t.flat<int32>()(2) = 1;
+      t.flat<int32>()(3) = 1;
+      t.AsProtoTensorContent(
+          (*new_shape.mutable_attr())["value"].mutable_tensor());
+    } else {
+      Tensor t(DT_INT32, {5});
+      t.flat<int32>()(0) = 1;
+      t.flat<int32>()(1) = -1;
+      t.flat<int32>()(2) = 1;
+      t.flat<int32>()(3) = 1;
+      t.flat<int32>()(4) = 1;
+      t.AsProtoTensorContent(
+          (*new_shape.mutable_attr())["value"].mutable_tensor());
+    }
     mutation->AddNode(std::move(new_shape), &status);
     TF_RETURN_IF_ERROR(status);
 
     NodeDef reshaped_scale;
     reshaped_scale.set_name(
-        AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
     reshaped_scale.set_op("Reshape");
     reshaped_scale.set_device(fused_node.device());
     *reshaped_scale.add_input() = scale;
@@ -1473,7 +1485,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_offset;
     reshaped_offset.set_name(
-        AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
     reshaped_offset.set_op("Reshape");
     reshaped_offset.set_device(fused_node.device());
     *reshaped_offset.add_input() = offset;
@@ -1486,7 +1498,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_mean;
     reshaped_mean.set_name(
-        AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
     reshaped_mean.set_op("Reshape");
     reshaped_mean.set_device(fused_node.device());
     *reshaped_mean.add_input() = mean;
@@ -1499,7 +1511,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_variance;
     reshaped_variance.set_name(
-        AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
     reshaped_variance.set_op("Reshape");
     reshaped_variance.set_device(fused_node.device());
     *reshaped_variance.add_input() = variance;
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 00ac9be6dcd..d8e58093b07 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1241,15 +1241,15 @@ class FusedBatchNormOpBase : public OpKernel {
   // If use_reserved_space is false, we don't have 5th output.
   virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                         bool use_reserved_space) {
-    const Tensor& x = context->input(0);
+    Tensor x = context->input(0);
     const Tensor& scale = context->input(1);
     const Tensor& offset = context->input(2);
     const Tensor& estimated_mean = context->input(3);
     const Tensor& estimated_variance = context->input(4);
     const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
 
-    OP_REQUIRES(context, x.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
                 errors::InvalidArgument("scale must be 1-dimensional",
@@ -1264,6 +1264,21 @@ class FusedBatchNormOpBase : public OpKernel {
         context, estimated_variance.dims() == 1,
         errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                 estimated_variance.shape().DebugString()));
+    bool use_reshape = (x.dims() == 5);
+    auto x_shape = x.shape();
+    TensorShape dest_shape;
+    if (use_reshape) {
+      const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
+      int64 in_planes = GetTensorDim(x, tensor_format_, '0');
+      int64 in_rows = GetTensorDim(x, tensor_format_, '1');
+      int64 in_cols = GetTensorDim(x, tensor_format_, '2');
+      const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
+      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
+                                   {{in_planes, in_rows * in_cols}}, in_depth);
+      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
+
     if (has_side_input_) {
       OP_REQUIRES(context, side_input->shape() == x.shape(),
                   errors::InvalidArgument(
@@ -1282,8 +1297,10 @@ class FusedBatchNormOpBase : public OpKernel {
     }
 
     Tensor* y = nullptr;
+    auto alloc_shape = use_reshape ? dest_shape : x_shape;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
-                                {0}, 0, x.shape(), &y));
+                                {0}, 0, alloc_shape, &y));
+
     Tensor* batch_mean = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {3}, 1, scale.shape(), &batch_mean));
@@ -1310,6 +1327,10 @@ class FusedBatchNormOpBase : public OpKernel {
           batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
           tensor_format_, use_reserved_space);
     }
+    if (use_reshape) {
+      OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
   }
 
  private:
@@ -1375,8 +1396,8 @@ class FusedBatchNormGradOpBase : public OpKernel {
 
   virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                         bool use_reserved_space) {
-    const Tensor& y_backprop = context->input(0);
-    const Tensor& x = context->input(1);
+    Tensor y_backprop = context->input(0);
+    Tensor x = context->input(1);
     const Tensor& scale = context->input(2);
     // When is_training=True, batch mean and variance/inverted variance are
     // saved in the forward pass to be reused here. When is_training=False,
@@ -1387,11 +1408,11 @@ class FusedBatchNormGradOpBase : public OpKernel {
     // saves inverted variance.
     const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
 
-    OP_REQUIRES(context, y_backprop.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         y_backprop.shape().DebugString()));
-    OP_REQUIRES(context, x.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
                 errors::InvalidArgument("scale must be 1-dimensional",
@@ -1404,10 +1425,27 @@ class FusedBatchNormGradOpBase : public OpKernel {
                 errors::InvalidArgument(
                     "saved variance must be 1-dimensional",
                     saved_maybe_inv_var_or_pop_var.shape().DebugString()));
+    bool use_reshape = (x.dims() == 5);
+    auto x_shape = x.shape();
+    TensorShape dest_shape;
+    if (use_reshape) {
+      const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
+      int64 in_planes = GetTensorDim(x, tensor_format_, '0');
+      int64 in_rows = GetTensorDim(x, tensor_format_, '1');
+      int64 in_cols = GetTensorDim(x, tensor_format_, '2');
+      const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
+      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
+                                   {{in_planes, in_rows * in_cols}}, in_depth);
+      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+      OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
 
     Tensor* x_backprop = nullptr;
+    auto alloc_shape = use_reshape ? dest_shape : x_shape;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(0, x.shape(), &x_backprop));
+                   context->allocate_output(0, alloc_shape, &x_backprop));
 
     const TensorShape& scale_offset_shape = scale.shape();
     Tensor* scale_backprop = nullptr;
@@ -1441,15 +1479,20 @@ class FusedBatchNormGradOpBase : public OpKernel {
                      offset_backprop, use_reserved_space, tensor_format_);
     } else {
       // Necessary layout conversion is currently done in python.
-      CHECK(tensor_format_ == FORMAT_NHWC)
-          << "The implementation of FusedBatchNormGrad with is_training=False "
-             "only support "
-          << "NHWC tensor format for now.";
+      OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
+                  errors::InvalidArgument(
+                      "The implementation of "
+                      "FusedBatchNormGrad with is_training=False only support "
+                      "NHWC tensor format for now."));
       functor::FusedBatchNormFreezeGrad<Device, T, U>()(
           context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop);
     }
+    if (use_reshape) {
+      OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
   }
 
  private:
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 2b6330db4aa..759bf0f0ddf 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -221,7 +221,7 @@ REGISTER_OP("FusedBatchNormV3")
     .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("exponential_avg_factor: float = 1.0")
-    .Attr(GetConvnetDataFormatAttrString())
+    .Attr(GetConvnetDataFormat2D3DAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormV3Shape);
 
@@ -308,7 +308,7 @@ REGISTER_OP("FusedBatchNormGradV3")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
-    .Attr(GetConvnetDataFormatAttrString())
+    .Attr(GetConvnetDataFormat2D3DAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);
 // --------------------------------------------------------------------------
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index c80ab536588..263b05047da 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1275,6 +1275,94 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
      self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
      self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
+
+  @test_util.deprecated_graph_mode_only
+  def testBatchNorm3D(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      strides_val = [1, 1, 1, 1, 1]
+      scale = constant_op.constant(0.1, shape=[3])
+      offset = constant_op.constant(0.3, shape=[3])
+      conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
+      y, _, _ = nn.fused_batch_norm(conv3d, scale, offset, data_format='NDHWC')
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormV3-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+  @test_util.deprecated_graph_mode_only
+  def testBatchNormGrad3D(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      strides_val = [1, 1, 1, 1, 1]
+      scale = constant_op.constant(0.1, shape=[3])
+      offset = constant_op.constant(0.3, shape=[3])
+      mean = constant_op.constant(0.1, shape=[3])
+      variance = constant_op.constant(0.3, shape=[3])
+      conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
+      y, running_mean, running_var, r0, r1, r2 = gen_nn_ops.fused_batch_norm_v3(
+          conv3d,
+          scale,
+          offset,
+          mean,
+          variance,
+          epsilon=1.001e-5,
+          exponential_avg_factor=1.0,
+          data_format='NDHWC',
+          is_training=True,
+          name='batch_norm')
+      dx, dscale, doffset, _, _ = gen_nn_ops.fused_batch_norm_grad_v3(
+          y,
+          x_3d,
+          scale,
+          r0,
+          r1,
+          r2,
+          epsilon=1.001e-5,
+          data_format='NDHWC',
+          is_training=True)
+      output = array_ops.identity(dx)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      expected_num_transposes = 3
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('FusedBatchNormGradV3-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormGradV3-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
   @test_util.deprecated_graph_mode_only
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index dc6eda6dcc3..2809cbb0108 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -330,13 +330,13 @@ class BatchNormalizationBase(Layer):
     # output back to its original shape accordingly.
     if self._USE_V2_BEHAVIOR:
       if self.fused is None:
-        self.fused = (ndims == 4)
-      elif self.fused and ndims != 4:
+        self.fused = ndims in (4, 5)
+      elif self.fused and ndims not in (4, 5):
         raise ValueError('Batch normalization layers with fused=True only '
-                         'support 4D input tensors.')
+                         'support 4D or 5D input tensors.')
     else:
       assert self.fused is not None
-      self.fused = (ndims == 4 and self._fused_can_be_used())
+      self.fused = (ndims in (4, 5) and self._fused_can_be_used())
     # TODO(chrisying): fused batch norm is currently not supported for
     # multi-axis batch norm and by extension virtual batches. In some cases,
     # it might be possible to use fused batch norm but would require reshaping
@@ -345,13 +345,22 @@
     # common use case (turning 5D w/ virtual batch to NCHW)
     if self.fused:
-      if self.axis == [1]:
+      if self.axis == [1] and ndims == 4:
         self._data_format = 'NCHW'
-      elif self.axis == [3]:
+      elif self.axis == [1] and ndims == 5:
+        self._data_format = 'NCDHW'
+      elif self.axis == [3] and ndims == 4:
         self._data_format = 'NHWC'
+      elif self.axis == [4] and ndims == 5:
+        self._data_format = 'NDHWC'
+      elif ndims == 5:
+        # 5D tensors can be passed in, but should not use fused batch norm
+        # because of an unsupported axis.
+        self.fused = False
       else:
         raise ValueError('Unsupported axis, fused batch norm only supports '
-                         'axis == [1] or axis == [3]')
+                         'axis == [1] or axis == [3] for 4D input tensors or '
+                         'axis == [1] or axis == [4] for 5D input tensors')
 
     axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
     for x in axis_to_dim:
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index f89a615bee5..79ecc3c3fe1 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -66,6 +66,15 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
         kwargs={'scale': False, 'center': False},
         input_shape=(3, 3))
 
+    testing_utils.layer_test(
+        keras.layers.BatchNormalization,
+        kwargs={
+            'gamma_initializer': 'ones',
+            'beta_initializer': 'ones',
+            'moving_mean_initializer': 'zeros',
+            'moving_variance_initializer': 'ones'
+        },
+        input_shape=(3, 2, 4, 2))
 
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_batchnorm_weights(self):
@@ -319,7 +328,7 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
     norm = normalization_v2.BatchNormalization(fused=True)
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(4, 4))
-    with self.assertRaisesRegex(ValueError, '4D input tensors'):
+    with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
       norm(inp)
 
   def test_updates_in_wrap_function(self):
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index 1742a919216..0421829bff3 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -43,14 +43,18 @@ class BatchNormalizationTest(test.TestCase):
     return math_ops.cast(y, x.dtype)
 
   def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format):
-    if data_format not in ['NHWC', 'NCHW']:
-      raise ValueError('data_format must be NCHW or NHWC, '
-                       'got %s.' % data_format)
+    if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
+      raise ValueError('data_format must be NCHW or NHWC for 4D tensors or '
+                       'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
     if data_format == 'NCHW':
       x = array_ops.transpose(x, [0, 2, 3, 1])
+    elif data_format == 'NCDHW':
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
     y = self._batch_norm(x, mean, var, offset, scale, epsilon)
     if data_format == 'NCHW':
       y = array_ops.transpose(y, [0, 3, 1, 2])
+    elif data_format == 'NCDHW':
+      y = array_ops.transpose(y, [0, 4, 1, 2, 3])
     return self.evaluate(y)
 
   def _test_inference(self,
@@ -102,17 +106,24 @@ class BatchNormalizationTest(test.TestCase):
   def _training_ref(self, x, scale, offset, old_mean, old_var,
                     exponential_avg_factor, epsilon, data_format):
-    if data_format not in ['NHWC', 'NCHW']:
-      raise ValueError('data_format must be NCHW or NHWC, '
-                       'got %s.' % data_format)
+    if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
+      raise ValueError('data_format must be NCHW or NHWC for 4D tensors or '
+                       'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
+    use_4d_tensor = (x.shape.ndims == 4)
     if data_format == 'NCHW':
       x = array_ops.transpose(x, [0, 2, 3, 1])
+    elif data_format == 'NCDHW':
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
+
+    mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3]
     batch_mean, batch_var = nn_impl.moments(
-        math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
+        math_ops.cast(x, scale.dtype), mean_axis, keep_dims=False)
 
     y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
     if data_format == 'NCHW':
       y = array_ops.transpose(y, [0, 3, 1, 2])
+    elif data_format == 'NCDHW':
+      y = array_ops.transpose(y, [0, 4, 1, 2, 3])
 
     # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
     # the denominator in the formula to calculate variance, while
@@ -377,14 +388,18 @@ class BatchNormalizationTest(test.TestCase):
   def _runtests(self, x_shape, is_training, gradient_test=False,
                 cpu_only=False):
+    if len(x_shape) == 4:
+      data_format_list = ['NHWC', 'NCHW']
+    else:
+      data_format_list = ['NCDHW', 'NDHWC']
     use_gpu_vals = [False]
     if test.is_gpu_available(cuda_only=True) and not cpu_only:
       use_gpu_vals += [True]
     factors = [1.0, 0.6]
     for dtype in [np.float16, np.float32]:
       for use_gpu in use_gpu_vals:
-        for data_format in ['NHWC', 'NCHW']:
-          if data_format == 'NHWC':
+        for data_format in data_format_list:
+          if data_format == 'NHWC' or data_format == 'NDHWC':
             scale_shape = x_shape[-1:]
           else:
             scale_shape = x_shape[1:2]
@@ -444,6 +459,10 @@ class BatchNormalizationTest(test.TestCase):
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, False, cpu_only=True)
+
+  def testInferenceShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, False)
 
   def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
     self._runtests(x_shape, True)
@@ -465,11 +484,16 @@
     x_shape = [0, 131, 127, 6]
     self._runtests(x_shape, True)
 
+  @test_util.run_deprecated_v1
   def testTrainingShape6(self):
     x_shape = [1, 1, 1, 1]
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, True, cpu_only=True)
+
+  def testTrainingShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, True)
 
   @test_util.run_deprecated_v1
   def testBatchNormGradInferenceShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -503,6 +527,11 @@ class BatchNormalizationTest(test.TestCase):
     self._runtests(x_shape, is_training=False, gradient_test=True,
                    cpu_only=True)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradInferenceShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, is_training=False, gradient_test=True)
 
   @test_util.run_deprecated_v1
   def testBatchNormGradTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -535,42 +564,54 @@
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, is_training=True, gradient_test=True,
                    cpu_only=True)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradTrainingShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, is_training=True, gradient_test=True)
 
   def _testBatchNormGradGrad(self, config):
     shape = config['shape']
     err_tolerance = config['err_tolerance']
     dtype = config['dtype']
+    rank = len(shape)
+    if rank == 4:
+      data_format_nhwc, features_nhwc = 'NHWC', shape[3]
+      data_format_nchw, features_nchw = 'NCHW', shape[1]
+    else:
+      data_format_nhwc, features_nhwc = 'NDHWC', shape[4]
+      data_format_nchw, features_nchw = 'NCDHW', shape[1]
     for is_training in [True, False]:
       if test.is_gpu_available(cuda_only=True):
         self._test_grad_grad(
             shape,
-            dtype, [shape[3]],
+            dtype, [features_nhwc],
             np.float32,
             use_gpu=True,
-            data_format='NHWC',
+            data_format=data_format_nhwc,
             is_training=is_training,
             err_tolerance=err_tolerance)
         self._test_grad_grad(
             shape,
-            dtype, [shape[1]],
+            dtype, [features_nchw],
             np.float32,
            use_gpu=True,
-            data_format='NCHW',
+            data_format=data_format_nchw,
            is_training=is_training,
            err_tolerance=err_tolerance)
      self._test_grad_grad(
          shape,
-          dtype, [shape[3]],
+          dtype, [features_nhwc],
          np.float32,
          use_gpu=False,
-          data_format='NHWC',
+          data_format=data_format_nhwc,
          is_training=is_training,
          err_tolerance=err_tolerance)
      self._test_grad_grad(
          shape,
-          dtype, [shape[1]],
+          dtype, [features_nchw],
          np.float32,
          use_gpu=False,
-          data_format='NCHW',
+          data_format=data_format_nchw,
          is_training=is_training,
          err_tolerance=err_tolerance)
@@ -610,6 +651,24 @@ class BatchNormalizationTest(test.TestCase):
     }
     self._testBatchNormGradGrad(config)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradGradConfig5(self):
+    config = {
+        'shape': [2, 3, 2, 2, 2],
+        'err_tolerance': 2e-3,
+        'dtype': np.float32,
+    }
+    self._testBatchNormGradGrad(config)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradGradConfig6(self):
+    config = {
+        'shape': [2, 3, 2, 2, 2],
+        'err_tolerance': 3e-3,
+        'dtype': np.float16,
+    }
+    self._testBatchNormGradGrad(config)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 58dd1852cc5..a02e31f80a5 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -897,6 +897,11 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
     if data_format == b"NCHW":
       x = array_ops.transpose(x, [0, 2, 3, 1])
       grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
+    elif data_format == b"NCDHW":
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
+      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
+    target_data_format = ("NHWC" if data_format in (b"NCHW",
+                                                    b"NHWC") else "NDHWC")
     args = {
         "y_backprop": grad_y,
         "x": x,
@@ -904,7 +909,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
         "reserve_space_1": pop_mean,
         "reserve_space_2": pop_var,
         "epsilon": epsilon,
-        "data_format": "NHWC",
+        "data_format": target_data_format,
         "is_training": is_training
     }
     if version == 2:
@@ -912,6 +917,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
     dx, dscale, doffset, _, _ = grad_fun(**args)
     if data_format == b"NCHW":
       dx = array_ops.transpose(dx, [0, 3, 1, 2])
+    elif data_format == b"NCDHW":
+      dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
   return dx, dscale, doffset, None, None
 
 
@@ -941,8 +948,8 @@ def _BatchNormGrad(grad_y,
   """Returns the gradients for the 3 inputs of BatchNorm.
 
   Args:
-    grad_y: A `Tensor` of 4 dimensions for gradient for y.
-    x: A `Tensor` of 4 dimensions for x.
+    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
+    x: A `Tensor` of 4 or 5 dimensions for x.
     scale: A `Tensor` of 1 dimension for scaling.
     pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
       is_training=False.
@@ -968,11 +975,19 @@ def _BatchNormGrad(grad_y,
     if data_format == b"NHWC":
       keepdims = False
       reduce_axis = [0, 1, 2]
-    else:
+    elif data_format == b"NDHWC":
+      keepdims = False
+      reduce_axis = [0, 1, 2, 3]
+    elif data_format == b"NCHW":
       keepdims = True
       reduce_axis = [0, 2, 3]
       shape = [1, array_ops.size(scale), 1, 1]
       scale = array_ops.reshape(scale, shape)
+    else:
+      keepdims = True
+      reduce_axis = [0, 2, 3, 4]
+      shape = [1, array_ops.size(scale), 1, 1, 1]
+      scale = array_ops.reshape(scale, shape)
     mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
     mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
     var_x = math_ops.reduce_mean(
@@ -987,19 +1002,27 @@ def _BatchNormGrad(grad_y,
         grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
     grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
         grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
-    if data_format == b"NCHW":
+    if data_format == b"NCHW" or data_format == b"NCDHW":
       grad_scale = array_ops.squeeze(grad_scale)
     grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
     return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
   else:
     if data_format == b"NHWC":
       reduce_axis = [0, 1, 2]
-    else:
+    elif data_format == b"NDHWC":
+      reduce_axis = [0, 1, 2, 3]
+    elif data_format == b"NCHW":
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)
+    else:
+      reduce_axis = [0, 2, 3, 4]
+      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
+      pop_mean = array_ops.reshape(pop_mean, shape)
+      pop_var = array_ops.reshape(pop_var, shape)
+      scale = array_ops.reshape(scale, shape)
     grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
     var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 89174b29336..d22fbf3fa4e 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -1585,7 +1585,7 @@ def fused_batch_norm(
   (http://arxiv.org/abs/1502.03167).
 
   Args:
-    x: Input `Tensor` of 4 dimensions.
+    x: Input `Tensor` of 4 or 5 dimensions.
     scale: A `Tensor` of 1 dimension for scaling.
     offset: A `Tensor` of 1 dimension for bias.
     mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
@@ -1611,7 +1611,8 @@
       Variance must be a `Tensor` of the same shape as scale containing
       the exponential running variance.
     epsilon: A small float number added to the variance of x.
-    data_format: The data format for x. Either "NHWC" (default) or "NCHW".
+    data_format: The data format for x. Supports "NHWC" (default) or "NCHW" for
+      4D tensors and "NDHWC" or "NCDHW" for 5D tensors.
     is_training: A bool value to specify if the operation is used for training
       or inference.
     name: A name for this operation (optional).
@@ -1622,7 +1623,7 @@
       returned.
 
   Returns:
-    y: A 4D Tensor for the normalized, scaled, offsetted x.
+    y: A 4D or 5D Tensor for the normalized, scaled, offsetted x.
     running_mean: A 1D Tensor for the exponential running mean of x.
       The output value is (1 - exponential_avg_factor) * mean +
      exponential_avg_factor * batch_mean), where batch_mean