diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index c54418ba648..2d30d41c7a6 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1121,8 +1121,17 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
 }
 
 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
+  string data_format_str;
+  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
+  TensorFormat data_format;
+  if (!FormatFromString(data_format_str, &data_format)) {
+    return errors::InvalidArgument("Invalid data format string: ",
+                                   data_format_str);
+  }
+  const int rank =
+      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
   ShapeHandle x;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
 
   bool is_training;
   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
@@ -1131,14 +1140,8 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
     exponential_avg_factor = 1.0f;  // default value
   }
   int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
-  string data_format_str;
-  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
-  TensorFormat data_format;
-  if (!FormatFromString(data_format_str, &data_format)) {
-    return errors::InvalidArgument("Invalid data format string: ",
-                                   data_format_str);
-  }
-  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+
+  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
 
   // covers scale, offset, and if is_training is false, mean, variance
@@ -1191,13 +1194,6 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
 }
 
 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
-  ShapeHandle y_backprop;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
-  ShapeHandle x;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
-
-  bool is_training;
-  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
   string data_format_str;
   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   TensorFormat data_format;
@@ -1205,7 +1201,17 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
     return errors::InvalidArgument("Invalid data format string: ",
                                    data_format_str);
   }
-  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+  const int rank =
+      (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4;
+  ShapeHandle y_backprop;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
+  ShapeHandle x;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
+
+  bool is_training;
+  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
+
+  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
   DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
   TF_RETURN_IF_ERROR(
       c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 10253f187c0..3c466edc69b 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -670,7 +670,25 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
 Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
+    // Change back to the original layout due to early exit.
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -679,6 +697,11 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
@@ -881,8 +904,26 @@ bool FusedBatchNormGradTransposer::IsTraining(
 Status FusedBatchNormGradTransposer::TransposeNode(
     TransposeContext* context, utils::MutableNodeView* node) {
   DCHECK(IsFusedBatchNormGrad(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) ||
       !IsTraining(*node)) {
+    // Change back to the original layout due to early exit.
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
@@ -892,6 +933,11 @@ Status FusedBatchNormGradTransposer::TransposeNode(
   TF_RETURN_IF_ERROR(
       UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 115428ff5ef..db528da2f6d 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -1438,29 +1438,41 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
 
-  if (fused_node.attr().at(kDataFormat).s() == "NCHW") {
+  string x_format = fused_node.attr().at(kDataFormat).s();
+  if (x_format == "NCHW" or x_format == "NCDHW") {
     // Need to reshape the last 4 inputs
     NodeDef new_shape;
     const string new_shape_name =
-        AddPrefixToNodeName("NCHWShape", fused_node.name());
+        AddPrefixToNodeName(x_format + "Shape", fused_node.name());
     new_shape.set_name(new_shape_name);
     new_shape.set_op("Const");
     new_shape.set_device(fused_node.device());
     *new_shape.add_input() = AsControlDependency(scale);
     (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32);
-    Tensor t(DT_INT32, {4});
-    t.flat<int32>()(0) = 1;
-    t.flat<int32>()(1) = -1;
-    t.flat<int32>()(2) = 1;
-    t.flat<int32>()(3) = 1;
-    t.AsProtoTensorContent(
-        (*new_shape.mutable_attr())["value"].mutable_tensor());
+    if (x_format == "NCHW") {
+      Tensor t(DT_INT32, {4});
+      t.flat<int32>()(0) = 1;
+      t.flat<int32>()(1) = -1;
+      t.flat<int32>()(2) = 1;
+      t.flat<int32>()(3) = 1;
+      t.AsProtoTensorContent(
+          (*new_shape.mutable_attr())["value"].mutable_tensor());
+    } else {
+      Tensor t(DT_INT32, {5});
+      t.flat<int32>()(0) = 1;
+      t.flat<int32>()(1) = -1;
+      t.flat<int32>()(2) = 1;
+      t.flat<int32>()(3) = 1;
+      t.flat<int32>()(4) = 1;
+      t.AsProtoTensorContent(
+          (*new_shape.mutable_attr())["value"].mutable_tensor());
+    }
     mutation->AddNode(std::move(new_shape), &status);
     TF_RETURN_IF_ERROR(status);
 
     NodeDef reshaped_scale;
     reshaped_scale.set_name(
-        AddPrefixToNodeName("NCHWShapedScale", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name()));
     reshaped_scale.set_op("Reshape");
     reshaped_scale.set_device(fused_node.device());
     *reshaped_scale.add_input() = scale;
@@ -1473,7 +1485,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_offset;
     reshaped_offset.set_name(
-        AddPrefixToNodeName("NCHWShapedOffset", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name()));
     reshaped_offset.set_op("Reshape");
     reshaped_offset.set_device(fused_node.device());
     *reshaped_offset.add_input() = offset;
@@ -1486,7 +1498,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_mean;
     reshaped_mean.set_name(
-        AddPrefixToNodeName("NCHWShapedMean", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name()));
     reshaped_mean.set_op("Reshape");
     reshaped_mean.set_device(fused_node.device());
     *reshaped_mean.add_input() = mean;
@@ -1499,7 +1511,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 
     NodeDef reshaped_variance;
     reshaped_variance.set_name(
-        AddPrefixToNodeName("NCHWShapedVariance", fused_node.name()));
+        AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name()));
     reshaped_variance.set_op("Reshape");
     reshaped_variance.set_device(fused_node.device());
     *reshaped_variance.add_input() = variance;
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 00ac9be6dcd..d8e58093b07 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1241,15 +1241,15 @@ class FusedBatchNormOpBase : public OpKernel {
   // If use_reserved_space is false, we don't have 5th output.
   virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                         bool use_reserved_space) {
-    const Tensor& x = context->input(0);
+    Tensor x = context->input(0);
     const Tensor& scale = context->input(1);
     const Tensor& offset = context->input(2);
     const Tensor& estimated_mean = context->input(3);
     const Tensor& estimated_variance = context->input(4);
     const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
 
-    OP_REQUIRES(context, x.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
                 errors::InvalidArgument("scale must be 1-dimensional",
@@ -1264,6 +1264,21 @@ class FusedBatchNormOpBase : public OpKernel {
         context, estimated_variance.dims() == 1,
         errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                 estimated_variance.shape().DebugString()));
+    bool use_reshape = (x.dims() == 5);
+    auto x_shape = x.shape();
+    TensorShape dest_shape;
+    if (use_reshape) {
+      const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
+      int64 in_planes = GetTensorDim(x, tensor_format_, '0');
+      int64 in_rows = GetTensorDim(x, tensor_format_, '1');
+      int64 in_cols = GetTensorDim(x, tensor_format_, '2');
+      const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
+      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
+                                   {{in_planes, in_rows * in_cols}}, in_depth);
+      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
+
     if (has_side_input_) {
       OP_REQUIRES(context, side_input->shape() == x.shape(),
                   errors::InvalidArgument(
@@ -1282,8 +1297,10 @@ class FusedBatchNormOpBase : public OpKernel {
     }
 
     Tensor* y = nullptr;
+    auto alloc_shape = use_reshape ? dest_shape : x_shape;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
-                                {0}, 0, x.shape(), &y));
+                                {0}, 0, alloc_shape, &y));
+
     Tensor* batch_mean = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {3}, 1, scale.shape(), &batch_mean));
@@ -1310,6 +1327,10 @@ class FusedBatchNormOpBase : public OpKernel {
           batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
           tensor_format_, use_reserved_space);
     }
+    if (use_reshape) {
+      OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
   }
 
  private:
@@ -1375,8 +1396,8 @@ class FusedBatchNormGradOpBase : public OpKernel {
 
   virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                         bool use_reserved_space) {
-    const Tensor& y_backprop = context->input(0);
-    const Tensor& x = context->input(1);
+    Tensor y_backprop = context->input(0);
+    Tensor x = context->input(1);
     const Tensor& scale = context->input(2);
     // When is_training=True, batch mean and variance/inverted variance are
     // saved in the forward pass to be reused here. When is_training=False,
@@ -1387,11 +1408,11 @@ class FusedBatchNormGradOpBase : public OpKernel {
     // saves inverted variance.
     const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
 
-    OP_REQUIRES(context, y_backprop.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         y_backprop.shape().DebugString()));
-    OP_REQUIRES(context, x.dims() == 4,
-                errors::InvalidArgument("input must be 4-dimensional",
+    OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5,
+                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                         x.shape().DebugString()));
     OP_REQUIRES(context, scale.dims() == 1,
                 errors::InvalidArgument("scale must be 1-dimensional",
@@ -1404,10 +1425,27 @@ class FusedBatchNormGradOpBase : public OpKernel {
                 errors::InvalidArgument(
                     "saved variance must be 1-dimensional",
                     saved_maybe_inv_var_or_pop_var.shape().DebugString()));
+    bool use_reshape = (x.dims() == 5);
+    auto x_shape = x.shape();
+    TensorShape dest_shape;
+    if (use_reshape) {
+      const int64 in_batch = GetTensorDim(x, tensor_format_, 'N');
+      int64 in_planes = GetTensorDim(x, tensor_format_, '0');
+      int64 in_rows = GetTensorDim(x, tensor_format_, '1');
+      int64 in_cols = GetTensorDim(x, tensor_format_, '2');
+      const int64 in_depth = GetTensorDim(x, tensor_format_, 'C');
+      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
+                                   {{in_planes, in_rows * in_cols}}, in_depth);
+      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+      OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
 
     Tensor* x_backprop = nullptr;
+    auto alloc_shape = use_reshape ? dest_shape : x_shape;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(0, x.shape(), &x_backprop));
+                   context->allocate_output(0, alloc_shape, &x_backprop));
 
     const TensorShape& scale_offset_shape = scale.shape();
     Tensor* scale_backprop = nullptr;
@@ -1441,15 +1479,20 @@ class FusedBatchNormGradOpBase : public OpKernel {
           offset_backprop, use_reserved_space, tensor_format_);
     } else {
       // Necessary layout conversion is currently done in python.
-      CHECK(tensor_format_ == FORMAT_NHWC)
-          << "The implementation of FusedBatchNormGrad with is_training=False "
-             "only support "
-          << "NHWC tensor format for now.";
+      OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
+                  errors::InvalidArgument(
+                      "The implementation of "
+                      "FusedBatchNormGrad with is_training=False only support "
+                      "NHWC tensor format for now."));
       functor::FusedBatchNormFreezeGrad<Device, T, U>()(
           context, y_backprop, x, scale, saved_mean_or_pop_mean,
           saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
           offset_backprop);
     }
+    if (use_reshape) {
+      OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
+                  errors::InvalidArgument("Error during tensor copy."));
+    }
   }
 
  private:
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 2b6330db4aa..759bf0f0ddf 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -221,7 +221,7 @@ REGISTER_OP("FusedBatchNormV3")
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
     .Attr("exponential_avg_factor: float = 1.0")
-    .Attr(GetConvnetDataFormatAttrString())
+    .Attr(GetConvnetDataFormat2D3DAttrString())
     .Attr("is_training: bool = true")
     .SetShapeFn(shape_inference::FusedBatchNormV3Shape);
 
@@ -308,7 +308,7 @@ REGISTER_OP("FusedBatchNormGradV3")
     .Attr("T: {half, bfloat16, float}")
     .Attr("U: {float}")
     .Attr("epsilon: float = 0.0001")
-    .Attr(GetConvnetDataFormatAttrString())
+    .Attr(GetConvnetDataFormat2D3DAttrString())
     .Attr("is_training: bool = true")
     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
 // --------------------------------------------------------------------------
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index c80ab536588..263b05047da 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1275,6 +1275,94 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes)
       self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes)
       self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes)
+
+  @test_util.deprecated_graph_mode_only
+  def testBatchNorm3D(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      strides_val = [1, 1, 1, 1, 1]
+      scale = constant_op.constant(0.1, shape=[3])
+      offset = constant_op.constant(0.3, shape=[3])
+      conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
+      y, _, _ = nn.fused_batch_norm(conv3d, scale, offset, data_format='NDHWC')
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormV3-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+  @test_util.deprecated_graph_mode_only
+  def testBatchNormGrad3D(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      strides_val = [1, 1, 1, 1, 1]
+      scale = constant_op.constant(0.1, shape=[3])
+      offset = constant_op.constant(0.3, shape=[3])
+      mean = constant_op.constant(0.1, shape=[3])
+      variance = constant_op.constant(0.3, shape=[3])
+      conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME')
+      y, running_mean, running_var, r0, r1, r2 = gen_nn_ops.fused_batch_norm_v3(
+          conv3d,
+          scale,
+          offset,
+          mean,
+          variance,
+          epsilon=1.001e-5,
+          exponential_avg_factor=1.0,
+          data_format='NDHWC',
+          is_training=True,
+          name='batch_norm')
+      dx, dscale, doffset, _, _ = gen_nn_ops.fused_batch_norm_grad_v3(
+          y,
+          x_3d,
+          scale,
+          r0,
+          r1,
+          r2,
+          epsilon=1.001e-5,
+          data_format='NDHWC',
+          is_training=True)
+      output = array_ops.identity(dx)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      expected_num_transposes = 3
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_trans_ndhwc_to_ncdhw('FusedBatchNormGradV3-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormGradV3-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
   @test_util.deprecated_graph_mode_only
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index dc6eda6dcc3..2809cbb0108 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -330,13 +330,13 @@ class BatchNormalizationBase(Layer):
       # output back to its original shape accordingly.
       if self._USE_V2_BEHAVIOR:
         if self.fused is None:
-          self.fused = (ndims == 4)
-        elif self.fused and ndims != 4:
+          self.fused = ndims in (4, 5)
+        elif self.fused and ndims not in (4, 5):
           raise ValueError('Batch normalization layers with fused=True only '
-                           'support 4D input tensors.')
+                           'support 4D or 5D input tensors.')
       else:
         assert self.fused is not None
-        self.fused = (ndims == 4 and self._fused_can_be_used())
+        self.fused = (ndims in (4, 5) and self._fused_can_be_used())
       # TODO(chrisying): fused batch norm is currently not supported for
       # multi-axis batch norm and by extension virtual batches. In some cases,
       # it might be possible to use fused batch norm but would require reshaping
@@ -345,13 +345,22 @@ class BatchNormalizationBase(Layer):
       # common use case (turning 5D w/ virtual batch to NCHW)
 
     if self.fused:
-      if self.axis == [1]:
+      if self.axis == [1] and ndims == 4:
         self._data_format = 'NCHW'
-      elif self.axis == [3]:
+      elif self.axis == [1] and ndims == 5:
+        self._data_format = 'NCDHW'
+      elif self.axis == [3] and ndims == 4:
         self._data_format = 'NHWC'
+      elif self.axis == [4] and ndims == 5:
+        self._data_format = 'NDHWC'
+      elif ndims == 5:
+        # 5D tensors that can be passed in but should not use fused batch norm
+        # due to unsupported axis.
+        self.fused = False
       else:
         raise ValueError('Unsupported axis, fused batch norm only supports '
-                         'axis == [1] or axis == [3]')
+                         'axis == [1] or axis == [3] for 4D input tensors or '
+                         'axis == [1] or axis == [4] for 5D input tensors')
 
     axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
     for x in axis_to_dim:
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index f89a615bee5..79ecc3c3fe1 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -66,6 +66,15 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
         kwargs={'scale': False,
                 'center': False},
         input_shape=(3, 3))
+    testing_utils.layer_test(
+        keras.layers.BatchNormalization,
+        kwargs={
+            'gamma_initializer': 'ones',
+            'beta_initializer': 'ones',
+            'moving_mean_initializer': 'zeros',
+            'moving_variance_initializer': 'ones'
+        },
+        input_shape=(3, 2, 4, 2))
 
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_batchnorm_weights(self):
@@ -319,7 +328,7 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
     norm = normalization_v2.BatchNormalization(fused=True)
     self.assertEqual(norm.fused, True)
     inp = keras.layers.Input(shape=(4, 4))
-    with self.assertRaisesRegex(ValueError, '4D input tensors'):
+    with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
       norm(inp)
 
   def test_updates_in_wrap_function(self):
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index 1742a919216..0421829bff3 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -43,14 +43,18 @@ class BatchNormalizationTest(test.TestCase):
     return math_ops.cast(y, x.dtype)
 
   def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format):
-    if data_format not in ['NHWC', 'NCHW']:
-      raise ValueError('data_format must be NCHW or NHWC, '
-                       'got %s.' % data_format)
+    if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
+      raise ValueError('data_format must be NCHW or NHWC for 4D tensors or'
+                       'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
     if data_format == 'NCHW':
       x = array_ops.transpose(x, [0, 2, 3, 1])
+    elif data_format == 'NCDHW':
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
     y = self._batch_norm(x, mean, var, offset, scale, epsilon)
     if data_format == 'NCHW':
       y = array_ops.transpose(y, [0, 3, 1, 2])
+    elif data_format == 'NCDHW':
+      y = array_ops.transpose(y, [0, 4, 1, 2, 3])
     return self.evaluate(y)
 
   def _test_inference(self,
@@ -102,17 +106,24 @@ class BatchNormalizationTest(test.TestCase):
 
   def _training_ref(self, x, scale, offset, old_mean, old_var,
                     exponential_avg_factor, epsilon, data_format):
-    if data_format not in ['NHWC', 'NCHW']:
-      raise ValueError('data_format must be NCHW or NHWC, '
-                       'got %s.' % data_format)
+    if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
+      raise ValueError('data_format must be NCHW or NHWC for 4D tensors or'
+                       'NCDHW or NDHWC for 5D tensors, got %s.' % data_format)
+    use_4d_tensor = (x.shape.ndims == 4)
     if data_format == 'NCHW':
       x = array_ops.transpose(x, [0, 2, 3, 1])
+    elif data_format == 'NCDHW':
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
+
+    mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3]
     batch_mean, batch_var = nn_impl.moments(
-        math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
+        math_ops.cast(x, scale.dtype), mean_axis, keep_dims=False)
 
     y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
     if data_format == 'NCHW':
       y = array_ops.transpose(y, [0, 3, 1, 2])
+    elif data_format == 'NCDHW':
+      y = array_ops.transpose(y, [0, 4, 1, 2, 3])
 
     # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
     # the denominator in the formula to calculate variance, while
@@ -377,14 +388,18 @@ class BatchNormalizationTest(test.TestCase):
 
   def _runtests(self, x_shape, is_training, gradient_test=False,
                 cpu_only=False):
+    if len(x_shape) == 4:
+      data_format_list = ['NHWC', 'NCHW']
+    else:
+      data_format_list = ['NCDHW', 'NDHWC']
     use_gpu_vals = [False]
     if test.is_gpu_available(cuda_only=True) and not cpu_only:
       use_gpu_vals += [True]
     factors = [1.0, 0.6]
     for dtype in [np.float16, np.float32]:
       for use_gpu in use_gpu_vals:
-        for data_format in ['NHWC', 'NCHW']:
-          if data_format == 'NHWC':
+        for data_format in data_format_list:
+          if data_format == 'NHWC' or data_format == 'NDHWC':
             scale_shape = x_shape[-1:]
           else:
             scale_shape = x_shape[1:2]
@@ -444,6 +459,10 @@ class BatchNormalizationTest(test.TestCase):
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, False, cpu_only=True)
 
+  def testInferenceShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, False)
+
   def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
     self._runtests(x_shape, True)
@@ -465,11 +484,16 @@ class BatchNormalizationTest(test.TestCase):
     x_shape = [0, 131, 127, 6]
     self._runtests(x_shape, True)
 
+  @test_util.run_deprecated_v1
   def testTrainingShape6(self):
     x_shape = [1, 1, 1, 1]
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, True, cpu_only=True)
 
+  def testTrainingShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, True)
+
   @test_util.run_deprecated_v1
   def testBatchNormGradInferenceShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -503,6 +527,11 @@ class BatchNormalizationTest(test.TestCase):
     self._runtests(x_shape, is_training=False, gradient_test=True,
                    cpu_only=True)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradInferenceShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, is_training=False, gradient_test=True)
+
   @test_util.run_deprecated_v1
   def testBatchNormGradTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -535,42 +564,54 @@ class BatchNormalizationTest(test.TestCase):
     # GPU kernel doesn't properly handle case where non-channel dimensions are 1
     self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradTrainingShape7(self):
+    x_shape = [1, 2, 6, 1, 3]
+    self._runtests(x_shape, is_training=True, gradient_test=True)
+
   def _testBatchNormGradGrad(self, config):
     shape = config['shape']
     err_tolerance = config['err_tolerance']
     dtype = config['dtype']
+    rank = len(shape)
+    if rank == 4:
+      data_format_nhwc, features_nhwc = 'NHWC', shape[3]
+      data_format_nchw, features_nchw = 'NCHW', shape[1]
+    else:
+      data_format_nhwc, features_nhwc = 'NDHWC', shape[4]
+      data_format_nchw, features_nchw = 'NCDHW', shape[1]
     for is_training in [True, False]:
       if test.is_gpu_available(cuda_only=True):
         self._test_grad_grad(
             shape,
-            dtype, [shape[3]],
+            dtype, [features_nhwc],
             np.float32,
             use_gpu=True,
-            data_format='NHWC',
+            data_format=data_format_nhwc,
             is_training=is_training,
             err_tolerance=err_tolerance)
         self._test_grad_grad(
             shape,
-            dtype, [shape[1]],
+            dtype, [features_nchw],
             np.float32,
             use_gpu=True,
-            data_format='NCHW',
+            data_format=data_format_nchw,
             is_training=is_training,
             err_tolerance=err_tolerance)
       self._test_grad_grad(
           shape,
-          dtype, [shape[3]],
+          dtype, [features_nhwc],
           np.float32,
           use_gpu=False,
-          data_format='NHWC',
+          data_format=data_format_nhwc,
           is_training=is_training,
           err_tolerance=err_tolerance)
       self._test_grad_grad(
           shape,
-          dtype, [shape[1]],
+          dtype, [features_nchw],
           np.float32,
           use_gpu=False,
-          data_format='NCHW',
+          data_format=data_format_nchw,
           is_training=is_training,
           err_tolerance=err_tolerance)
 
@@ -610,6 +651,24 @@ class BatchNormalizationTest(test.TestCase):
     }
     self._testBatchNormGradGrad(config)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradGradConfig5(self):
+    config = {
+        'shape': [2, 3, 2, 2, 2],
+        'err_tolerance': 2e-3,
+        'dtype': np.float32,
+    }
+    self._testBatchNormGradGrad(config)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradGradConfig6(self):
+    config = {
+        'shape': [2, 3, 2, 2, 2],
+        'err_tolerance': 3e-3,
+        'dtype': np.float16,
+    }
+    self._testBatchNormGradGrad(config)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 58dd1852cc5..a02e31f80a5 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -897,6 +897,11 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
     if data_format == b"NCHW":
       x = array_ops.transpose(x, [0, 2, 3, 1])
       grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
+    elif data_format == b"NCDHW":
+      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
+      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
+    target_data_format = ("NHWC" if data_format in (b"NCHW",
+                                                    b"NHWC") else "NDHWC")
     args = {
         "y_backprop": grad_y,
         "x": x,
@@ -904,7 +909,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
         "reserve_space_1": pop_mean,
         "reserve_space_2": pop_var,
         "epsilon": epsilon,
-        "data_format": "NHWC",
+        "data_format": target_data_format,
         "is_training": is_training
     }
     if version == 2:
@@ -912,6 +917,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
     dx, dscale, doffset, _, _ = grad_fun(**args)
     if data_format == b"NCHW":
       dx = array_ops.transpose(dx, [0, 3, 1, 2])
+    elif data_format == b"NCDHW":
+      dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
   return dx, dscale, doffset, None, None
 
 
@@ -941,8 +948,8 @@ def _BatchNormGrad(grad_y,
   """Returns the gradients for the 3 inputs of BatchNorm.
 
   Args:
-    grad_y: A `Tensor` of 4 dimensions for gradient for y.
-    x: A `Tensor` of 4 dimensions for x.
+    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
+    x: A `Tensor` of 4 or 5 dimensions for x.
     scale: A `Tensor` of 1 dimension for scaling.
     pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
       is_training=False.
@@ -968,11 +975,19 @@ def _BatchNormGrad(grad_y,
     if data_format == b"NHWC":
       keepdims = False
       reduce_axis = [0, 1, 2]
-    else:
+    elif data_format == b"NDHWC":
+      keepdims = False
+      reduce_axis = [0, 1, 2, 3]
+    elif data_format == b"NCHW":
       keepdims = True
       reduce_axis = [0, 2, 3]
       shape = [1, array_ops.size(scale), 1, 1]
       scale = array_ops.reshape(scale, shape)
+    else:
+      keepdims = True
+      reduce_axis = [0, 2, 3, 4]
+      shape = [1, array_ops.size(scale), 1, 1, 1]
+      scale = array_ops.reshape(scale, shape)
     mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
     mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
     var_x = math_ops.reduce_mean(
@@ -987,19 +1002,27 @@ def _BatchNormGrad(grad_y,
         grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
     grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
         grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
-    if data_format == b"NCHW":
+    if data_format == b"NCHW" or data_format == b"NCDHW":
       grad_scale = array_ops.squeeze(grad_scale)
     grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
     return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
   else:
     if data_format == b"NHWC":
       reduce_axis = [0, 1, 2]
-    else:
+    elif data_format == b"NDHWC":
+      reduce_axis = [0, 1, 2, 3]
+    elif data_format == b"NCHW":
       reduce_axis = [0, 2, 3]
       shape = [1, array_ops.size(pop_mean), 1, 1]
       pop_mean = array_ops.reshape(pop_mean, shape)
       pop_var = array_ops.reshape(pop_var, shape)
       scale = array_ops.reshape(scale, shape)
+    else:
+      reduce_axis = [0, 2, 3, 4]
+      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
+      pop_mean = array_ops.reshape(pop_mean, shape)
+      pop_var = array_ops.reshape(pop_var, shape)
+      scale = array_ops.reshape(scale, shape)
 
     grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
     var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 89174b29336..d22fbf3fa4e 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -1585,7 +1585,7 @@ def fused_batch_norm(
   (http://arxiv.org/abs/1502.03167).
 
   Args:
-    x: Input `Tensor` of 4 dimensions.
+    x: Input `Tensor` of 4 or 5 dimensions.
     scale: A `Tensor` of 1 dimension for scaling.
     offset: A `Tensor` of 1 dimension for bias.
     mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
@@ -1611,7 +1611,8 @@ def fused_batch_norm(
             Variance must be a `Tensor` of the same shape as scale containing
             the exponential running variance.
     epsilon: A small float number added to the variance of x.
-    data_format: The data format for x. Either "NHWC" (default) or "NCHW".
+    data_format: The data format for x. Support "NHWC" (default) or "NCHW" for
+                 4D tenors and "NDHWC" or "NCDHW" for 5D tensors.
     is_training: A bool value to specify if the operation is used for
                  training or inference.
     name: A name for this operation (optional).
@@ -1622,7 +1623,7 @@ def fused_batch_norm(
                             returned.
 
   Returns:
-    y: A 4D Tensor for the normalized, scaled, offsetted x.
+    y: A 4D or 5D Tensor for the normalized, scaled, offsetted x.
     running_mean: A 1D Tensor for the exponential running mean of x.
                   The output value is (1 - exponential_avg_factor) * mean +
                   exponential_avg_factor * batch_mean), where batch_mean