From 0488a18af4ba1f630d06b685a301f6d94622aad4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 18 Feb 2020 13:41:56 -0800
Subject: [PATCH] Automated rollback of commit 80acd88cd43f09a1a2980792e3955f2ce5147bfd

PiperOrigin-RevId: 295811620
Change-Id: I39a1f7f7dadee2c7ea231df16da1ab5516c8f1fa
---
 .../core/kernels/fused_batch_norm_op.cc      | 290 ++++++++++++------
 .../core/kernels/fused_batch_norm_op_test.cc |  63 ++++
 tensorflow/core/ops/nn_ops.cc                |   4 +
 .../api/golden/v1/tensorflow.raw_ops.pbtxt   |   6 +-
 .../api/golden/v2/tensorflow.raw_ops.pbtxt   |   6 +-
 5 files changed, 264 insertions(+), 105 deletions(-)

diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index cc0ce9b7922..afe3e621fcf 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -81,7 +81,7 @@ Status ParseActivationMode(OpKernelConstruction* context,
 }
 
 // Functor used by FusedBatchNormOp to do the computations.
-template <typename Device, typename T, typename U>
+template <typename Device, typename T, typename U, bool is_training>
 struct FusedBatchNorm;
 // Functor used by FusedBatchNormGradOp to do the computations when
 // is_training=True.
@@ -89,17 +89,155 @@ template <typename Device, typename T, typename U>
 struct FusedBatchNormGrad;
 
 template <typename T, typename U>
-struct FusedBatchNorm<CPUDevice, T, U> {
+struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ true> {
+  void operator()(OpKernelContext* context, const Tensor& x_input,
+                  const Tensor& scale_input, const Tensor& offset_input,
+                  const Tensor& running_mean_input,
+                  const Tensor& running_variance_input,
+                  const Tensor* side_input, U epsilon, U exponential_avg_factor,
+                  FusedBatchNormActivationMode activation_mode,
+                  Tensor* y_output, Tensor* running_mean_output,
+                  Tensor* running_var_output, Tensor* saved_batch_mean_output,
+                  Tensor* saved_batch_var_output, TensorFormat tensor_format,
+                  bool use_reserved_space) {
+    OP_REQUIRES(context, side_input == nullptr,
+                errors::Internal(
+                    "The CPU implementation of FusedBatchNorm does not support "
+                    "side input."));
+    OP_REQUIRES(context,
+                activation_mode == FusedBatchNormActivationMode::kIdentity,
+                errors::Internal("The CPU implementation of FusedBatchNorm "
+                                 "does not support activations."));
+
+    if (use_reserved_space) {
+      Tensor* dummy_reserve_space = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(5, {}, &dummy_reserve_space));
+      // Initialize the memory, to avoid sanitizer alerts.
+      dummy_reserve_space->flat<U>()(0) = U();
+    }
+    Tensor transformed_x;
+    Tensor transformed_y;
+    if (tensor_format == FORMAT_NCHW) {
+      const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
+      const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
+      const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
+      const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
+      OP_REQUIRES_OK(context, context->allocate_temp(
+                                  DataTypeToEnum<T>::value,
+                                  ShapeFromFormat(FORMAT_NHWC, in_batch,
+                                                  in_rows, in_cols, in_depths),
+                                  &transformed_x));
+      OP_REQUIRES_OK(context, context->allocate_temp(
+                                  DataTypeToEnum<T>::value,
+                                  ShapeFromFormat(FORMAT_NHWC, in_batch,
+                                                  in_rows, in_cols, in_depths),
+                                  &transformed_y));
+      // Perform NCHW to NHWC
+      std::vector<int32> perm = {0, 2, 3, 1};
+      OP_REQUIRES_OK(
+          context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
+                                             x_input, perm, &transformed_x));
+    } else {
+      transformed_x = x_input;
+      transformed_y = *y_output;
+    }
+    typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
+    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
+    typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
+    typename TTypes<U>::ConstVec old_mean(running_mean_input.vec<U>());
+    typename TTypes<U>::ConstVec old_variance(running_variance_input.vec<U>());
+    typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
+    typename TTypes<U>::Vec new_mean(running_mean_output->vec<U>());
+    typename TTypes<U>::Vec new_variance(running_var_output->vec<U>());
+    typename TTypes<U>::Vec saved_batch_mean(saved_batch_mean_output->vec<U>());
+    typename TTypes<U>::Vec saved_batch_var(saved_batch_var_output->vec<U>());
+
+    const CPUDevice& d = context->eigen_device<CPUDevice>();
+
+    const int depth = x.dimension(3);
+    const int size = x.size();
+    const int rest_size = size / depth;
+    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
+    Eigen::array<int, 1> reduce_dims({0});
+    Eigen::array<Eigen::Index, 2> bcast_spec({rest_size, 1});
+#else
+    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
+    one_by_depth.set(1, depth);
+    Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
+    Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
+    bcast_spec.set(0, rest_size);
+#endif
+
+    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
+    const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
+    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
+    // This adjustment is for Bessel's correction
+    U rest_size_adjust =
+        static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
+
+    Eigen::Tensor<U, 1, Eigen::RowMajor> batch_mean(depth);
+    Eigen::Tensor<U, 1, Eigen::RowMajor> batch_variance(depth);
+
+    batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
+    auto x_centered = x_rest_by_depth -
+                      batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
+
+    batch_variance.device(d) =
+        x_centered.square().sum(reduce_dims) * rest_size_inv;
+    auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale)
+                              .eval()
+                              .reshape(one_by_depth)
+                              .broadcast(bcast_spec);
+    auto x_scaled = x_centered * scaling_factor;
+    auto x_shifted =
+        (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
+            .template cast<T>();
+
+    y.reshape(rest_by_depth).device(d) = x_shifted;
+    if (exponential_avg_factor == U(1.0)) {
+      saved_batch_var.device(d) = batch_variance;
+      saved_batch_mean.device(d) = batch_mean;
+      new_variance.device(d) = batch_variance * rest_size_adjust;
+      new_mean.device(d) = batch_mean;
+    } else {
+      U one_minus_factor = U(1) - exponential_avg_factor;
+      saved_batch_var.device(d) = batch_variance;
+      saved_batch_mean.device(d) = batch_mean;
+      new_variance.device(d) =
+          one_minus_factor * old_variance +
+          (exponential_avg_factor * rest_size_adjust) * batch_variance;
+      new_mean.device(d) =
+          one_minus_factor * old_mean + exponential_avg_factor * batch_mean;
+    }
+
+    if (tensor_format == FORMAT_NCHW) {
+      // Perform NHWC to NCHW
+      const std::vector<int32> perm = {0, 3, 1, 2};
+      const Status s = ::tensorflow::DoTranspose(
+          context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
+      if (!s.ok()) {
+        context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
+      }
+    }
+  }
+};
+
+template <typename T, typename U>
+struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
   void operator()(OpKernelContext* context, const Tensor& x_input,
                   const Tensor& scale_input, const Tensor& offset_input,
                   const Tensor& estimated_mean_input,
                   const Tensor& estimated_variance_input,
-                  const Tensor* side_input, U epsilon,
+                  const Tensor* side_input, U epsilon, U exponential_avg_factor,
                   FusedBatchNormActivationMode activation_mode,
                   Tensor* y_output, Tensor* batch_mean_output,
                   Tensor* batch_var_output, Tensor* saved_mean_output,
                   Tensor* saved_var_output, TensorFormat tensor_format,
-                  bool use_reserved_space, bool is_training) {
+                  bool use_reserved_space) {
     OP_REQUIRES(context, side_input == nullptr,
                 errors::Internal(
                     "The CPU implementation of FusedBatchNorm does not support "
@@ -150,9 +288,7 @@ struct FusedBatchNorm<CPUDevice, T, U> {
                                       estimated_variance_input.vec<U>());
     typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
     typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
-    typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>());
-    typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>());
-    typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>());
+    typename TTypes<U>::Vec batch_variance(batch_var_output->vec<U>());
 
     const CPUDevice& d = context->eigen_device<CPUDevice>();
 
@@ -168,80 +304,36 @@ struct FusedBatchNorm<CPUDevice, T, U> {
 #else
     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
     one_by_depth.set(1, depth);
-    Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
     bcast_spec.set(0, rest_size);
 #endif
 
     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
-    const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
-    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
-    // This adjustment is for Bessel's correction
-    U rest_size_adjust =
-        static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
+    auto x_centered =
+        x_rest_by_depth -
+        estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
+    auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
+                              .eval()
+                              .reshape(one_by_depth)
+                              .broadcast(bcast_spec);
+    auto x_scaled = x_centered * scaling_factor;
+    auto x_shifted =
+        (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
+            .template cast<T>();
 
-    Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth);
-    Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth);
-    BlockingCounter barrier(1);
-    std::atomic<uint8> task_counter;
-    auto on_done = [&]() {
-      uint8 count = --task_counter;
-      if (count == 0) {
-        if (tensor_format == FORMAT_NCHW) {
-          // Perform NHWC to NCHW
-          const std::vector<int32> perm = {0, 3, 1, 2};
-          const Status s =
-              ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
-                                        transformed_y, perm, y_output);
-          if (!s.ok()) {
-            context->SetStatus(
-                errors::InvalidArgument("Transpose failed: ", s));
-          }
-        }
-        barrier.DecrementCount();
+    y.reshape(rest_by_depth).device(d) = x_shifted;
+    batch_mean.device(d) = estimated_mean;
+    batch_variance.device(d) = estimated_variance;
+
+    if (tensor_format == FORMAT_NCHW) {
+      // Perform NHWC to NCHW
+      const std::vector<int32> perm = {0, 3, 1, 2};
+      const Status s = ::tensorflow::DoTranspose(
+          context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
+      if (!s.ok()) {
+        context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
       }
-    };
-    if (is_training) {
-      // TODO(b/137108598): Extend kernel to allow use of exponential averaging.
-      mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
-      auto x_centered =
-          x_rest_by_depth - mean.reshape(one_by_depth).broadcast(bcast_spec);
-
-      variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
-      auto scaling_factor = ((variance + epsilon).rsqrt() * scale)
-                                .eval()
-                                .reshape(one_by_depth)
-                                .broadcast(bcast_spec);
-      auto x_scaled = x_centered * scaling_factor;
-      auto x_shifted =
-          (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
-              .template cast<T>();
-
-      task_counter = 5;
-      y.reshape(rest_by_depth).device(d, on_done) = x_shifted;
-      batch_var.device(d, on_done) = variance * rest_size_adjust;
-      saved_var.device(d, on_done) = variance;
-      batch_mean.device(d, on_done) = mean;
-      saved_mean.device(d, on_done) = mean;
-    } else {  // is_training == false
-      auto x_centered =
-          x_rest_by_depth -
-          estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
-      auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
-                                .eval()
-                                .reshape(one_by_depth)
-                                .broadcast(bcast_spec);
-      auto x_scaled = x_centered * scaling_factor;
-      auto x_shifted =
-          (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
-              .template cast<T>();
-
-      task_counter = 3;
-      y.reshape(rest_by_depth).device(d, on_done) = x_shifted;
-      mean.device(d, on_done) = estimated_mean;
-      variance.device(d, on_done) = estimated_variance;
     }
-    barrier.Wait();
   }
 };
 
@@ -662,17 +754,17 @@ class CudnnBatchNormAllocatorInOutput : public ScratchAllocator {
   bool output_allocated = false;
 };
 
-template <typename T, typename U>
-struct FusedBatchNorm<GPUDevice, T, U> {
+template <typename T, typename U, bool is_training>
+struct FusedBatchNorm<GPUDevice, T, U, is_training> {
   void operator()(OpKernelContext* context, const Tensor& x,
                   const Tensor& scale, const Tensor& offset,
                   const Tensor& estimated_mean,
                   const Tensor& estimated_variance, const Tensor* side_input,
-                  U epsilon, FusedBatchNormActivationMode activation_mode,
-                  Tensor* y, Tensor* batch_mean, Tensor* batch_var,
-                  Tensor* saved_mean, Tensor* saved_inv_var,
-                  TensorFormat tensor_format, bool use_reserved_space,
-                  bool is_training) {
+                  U epsilon, U exponential_avg_factor,
+                  FusedBatchNormActivationMode activation_mode, Tensor* y,
+                  Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
+                  Tensor* saved_inv_var, TensorFormat tensor_format,
+                  bool use_reserved_space) {
     auto* stream = context->op_device_context()->stream();
     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
 
@@ -837,15 +929,13 @@ struct FusedBatchNorm<GPUDevice, T, U> {
       workspace_allocator.reset(
           new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
     }
-    // TODO(b/137108598): Extend kernel to allow use of exponential averaging.
-    const double exponential_average_factor = 1.0;
     bool cudnn_launch_status =
         stream
             ->ThenBatchNormalizationForward(
                 x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
                 estimated_variance_ptr, side_input_ptr, x_desc,
                 scale_offset_desc, static_cast<double>(epsilon),
-                exponential_average_factor,
+                static_cast<double>(exponential_avg_factor),
                 AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr,
                 &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
                 is_training, reserve_space_allocator.get(),
@@ -1075,6 +1165,10 @@ class FusedBatchNormOpBase : public OpKernel {
     float epsilon;
     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
     epsilon_ = U(epsilon);
+    float exponential_avg_factor;
+    OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
+                                             &exponential_avg_factor));
+    exponential_avg_factor_ = U(exponential_avg_factor);
     string tensor_format;
     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
@@ -1165,17 +1259,6 @@ class FusedBatchNormOpBase : public OpKernel {
                       "channel dimension to be a multiple of 4."));
     }
 
-    if (is_training_) {
-      OP_REQUIRES(
-          context, estimated_mean.dim_size(0) == 0,
-          errors::InvalidArgument("estimated_mean must be empty for training",
-                                  estimated_mean.shape().DebugString()));
-      OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
-                  errors::InvalidArgument(
-                      "estimated_variance must be empty for training",
-                      estimated_variance.shape().DebugString()));
-    }
-
     Tensor* y = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {0}, 0, x.shape(), &y));
@@ -1192,15 +1275,24 @@ class FusedBatchNormOpBase : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
                                                      &saved_maybe_inv_var));
 
-    functor::FusedBatchNorm<Device, T, U>()(
-        context, x, scale, offset, estimated_mean, estimated_variance,
-        side_input, epsilon_, activation_mode_, y, batch_mean, batch_var,
-        saved_mean, saved_maybe_inv_var, tensor_format_, use_reserved_space,
-        is_training_);
+    if (is_training_) {
+      functor::FusedBatchNorm<Device, T, U, true>()(
+          context, x, scale, offset, estimated_mean, estimated_variance,
+          side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
+          batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
+          tensor_format_, use_reserved_space);
+    } else {
+      functor::FusedBatchNorm<Device, T, U, false>()(
+          context, x, scale, offset, estimated_mean, estimated_variance,
+          side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
+          batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
+          tensor_format_, use_reserved_space);
+    }
   }
 
  private:
   U epsilon_;
+  U exponential_avg_factor_;
   TensorFormat tensor_format_;
   bool is_training_;
   bool has_side_input_;
diff --git a/tensorflow/core/kernels/fused_batch_norm_op_test.cc b/tensorflow/core/kernels/fused_batch_norm_op_test.cc
index 7da57143b77..734fb294135 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op_test.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op_test.cc
@@ -40,6 +40,7 @@ TEST_F(FusedBatchNormOpTest, Training) {
                    .Input(FakeInput(DT_FLOAT))
                    .Input(FakeInput(DT_FLOAT))
                    .Input(FakeInput(DT_FLOAT))
+                   .Attr("exponential_avg_factor", 1.0)
                    .Attr("epsilon", 0.001)
                    .Attr("is_training", true)
                    .Finalize(node_def()));
@@ -67,6 +68,41 @@ TEST_F(FusedBatchNormOpTest, Training) {
   test::ExpectTensorNear<float>(expected_variance, *GetOutput(2), 0.01);
 }
 
+TEST_F(FusedBatchNormOpTest, TrainingRunningMean) {
+  TF_EXPECT_OK(NodeDefBuilder("batch_norm_op", "FusedBatchNorm")
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Attr("exponential_avg_factor", 0.5)
+                   .Attr("epsilon", 0.001)
+                   .Attr("is_training", true)
+                   .Finalize(node_def()));
+  TF_EXPECT_OK(InitOp());
+  AddInputFromArray<float>(TensorShape({1, 1, 6, 2}),
+                           {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
+  AddInputFromArray<float>(TensorShape({2}), {4.0, 4.0});
+  AddInputFromArray<float>(TensorShape({2}), {2.0, 2.0});
+  AddInputFromArray<float>(TensorShape({2}), {6.0, 6.0});
+  AddInputFromArray<float>(TensorShape({2}), {16.0, 16.0});
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2}));
+  test::FillValues<float>(&expected, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
+                                      3.17, 3.17, 5.51, 5.51, 7.86, 7.86});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
+
+  Tensor expected_mean(allocator(), DT_FLOAT, TensorShape({2}));
+  test::FillValues<float>(&expected_mean, {8, 8});
+  test::ExpectTensorNear<float>(expected_mean, *GetOutput(1), 0.01);
+
+  Tensor expected_variance(allocator(), DT_FLOAT, TensorShape({2}));
+  test::FillValues<float>(&expected_variance, {15.00, 15.00});
+  test::ExpectTensorNear<float>(expected_variance, *GetOutput(2), 0.01);
+}
+
 TEST_F(FusedBatchNormOpTest, Inference) {
   TF_EXPECT_OK(NodeDefBuilder("batch_norm_op", "FusedBatchNorm")
                    .Input(FakeInput(DT_FLOAT))
@@ -93,6 +129,33 @@ TEST_F(FusedBatchNormOpTest, Inference) {
   test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
 }
 
+TEST_F(FusedBatchNormOpTest, InferenceIgnoreAvgFactor) {
+  TF_EXPECT_OK(NodeDefBuilder("batch_norm_op", "FusedBatchNorm")
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Attr("exponential_avg_factor", 0.5)
+                   .Attr("epsilon", 0.001)
+                   .Attr("is_training", false)
+                   .Finalize(node_def()));
+  TF_EXPECT_OK(InitOp());
+  AddInputFromArray<float>(TensorShape({1, 1, 6, 2}),
+                           {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
+  AddInputFromArray<float>(TensorShape({2}), {4.0, 4.0});
+  AddInputFromArray<float>(TensorShape({2}), {2.0, 2.0});
+  AddInputFromArray<float>(TensorShape({2}), {10, 10});
+  AddInputFromArray<float>(TensorShape({2}), {11.67f, 11.67f});
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2}));
+  test::FillValues<float>(&expected, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
+                                      3.17, 3.17, 5.51, 5.51, 7.86, 7.86});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
+}
+
 class FusedBatchNormGradOpTest : public OpsTestBase {};
 
 TEST_F(FusedBatchNormGradOpTest, Simple) {
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 82adb489f94..84f25347a86 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -179,6 +179,7 @@ REGISTER_OP("FusedBatchNorm")
.Output("reserve_space_2: T") .Attr("T: {float}") .Attr("epsilon: float = 0.0001") + .Attr("exponential_avg_factor: float = 1.0") .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormShape); @@ -197,6 +198,7 @@ REGISTER_OP("FusedBatchNormV2") .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") + .Attr("exponential_avg_factor: float = 1.0") .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormShape); @@ -216,6 +218,7 @@ REGISTER_OP("FusedBatchNormV3") .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") + .Attr("exponential_avg_factor: float = 1.0") .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormV3Shape); @@ -236,6 +239,7 @@ REGISTER_OP("_FusedBatchNormEx") .Attr("T: {half, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") + .Attr("exponential_avg_factor: float = 1.0") .Attr("num_side_inputs: int >= 0 = 0") .Attr("activation_mode: string = \"Identity\"") .Attr(GetConvnetDataFormatAttrString()) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index cf8bf14e42d..853f67c12de 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1590,7 +1590,7 @@ tf_module { } member_method { name: "FusedBatchNorm" - argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], " + argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], " } member_method { name: "FusedBatchNormGrad" @@ -1606,11 +1606,11 @@ tf_module { } member_method { name: "FusedBatchNormV2" - argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], " + argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], " } member_method { name: "FusedBatchNormV3" - argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], " + argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], " } member_method { name: "FusedPadConv2D" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index cf8bf14e42d..853f67c12de 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1590,7 +1590,7 @@ tf_module { } member_method { name: "FusedBatchNorm" - argspec: "args=[\'x\', \'scale\', 
+    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
   }
   member_method {
     name: "FusedBatchNormGrad"
@@ -1606,11 +1606,11 @@ tf_module {
   }
   member_method {
     name: "FusedBatchNormV2"
-    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], "
+    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
   }
   member_method {
     name: "FusedBatchNormV3"
-    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], "
+    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
  }
  member_method {
    name: "FusedPadConv2D"
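
Note on the semantics this patch introduces: with is_training=true, the running statistics are now updated as new = (1 - exponential_avg_factor) * old + exponential_avg_factor * batch_stat, where only the contribution folded into the running variance gets Bessel's correction (rest_size_adjust in the kernel); exponential_avg_factor = 1.0 (the default) keeps the previous behaviour of overwriting the running statistics, and the attribute is ignored when is_training=false. The sketch below is not part of the patch and is not TensorFlow code; it is a minimal standalone C++ check, under those assumptions, that reproduces the expected values of the TrainingRunningMean test above.

```cpp
// Standalone check (plain C++, hypothetical example): per-channel
// running-statistics update with exponential_avg_factor, mirroring the
// arithmetic of the patched CPU kernel.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> x = {5, 7, 9, 11, 13, 15};  // one channel of the test input
  const double scale = 4.0, offset = 2.0, epsilon = 0.001;
  const double old_mean = 6.0, old_var = 16.0, factor = 0.5;  // exponential_avg_factor

  // Batch statistics over the N*H*W elements of the channel.
  double batch_mean = 0.0;
  for (double v : x) batch_mean += v;
  batch_mean /= x.size();

  double batch_var = 0.0;  // biased variance, used to normalize y
  for (double v : x) batch_var += (v - batch_mean) * (v - batch_mean);
  batch_var /= x.size();

  // Bessel's correction is applied only to the value folded into the running
  // variance (rest_size_adjust in the kernel).
  const double adjust = static_cast<double>(x.size()) / (x.size() - 1);

  const double new_mean = (1.0 - factor) * old_mean + factor * batch_mean;
  const double new_var = (1.0 - factor) * old_var + factor * adjust * batch_var;

  std::printf("new_mean = %.2f (test expects 8)\n", new_mean);
  std::printf("new_var  = %.2f (test expects 15)\n", new_var);
  std::printf("y[0]     = %.2f (test expects -3.86)\n",
              scale * (x[0] - batch_mean) / std::sqrt(batch_var + epsilon) + offset);
  return 0;
}
```

With factor = 0.5, old_mean = 6 and old_var = 16, this prints new_mean = 8.00 and new_var = 15.00, matching expected_mean and expected_variance in the test.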