Support computing exponential running mean and variance in fused_batch_norm.

Simplify unit test code for fused_batch_norm and its gradient.

PiperOrigin-RevId: 296546311
Change-Id: Ieb78b4038a39dd2bcde16302541e733f4e8604bd
A. Unique TensorFlower 2020-02-21 17:52:10 -08:00 committed by TensorFlower Gardener
parent bdf0ea41b9
commit ce9564d430
12 changed files with 329 additions and 237 deletions
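For orientation before the per-file diffs: the "exponential running mean and variance" named in the commit message is a convex combination of the previous running statistic and the current batch statistic, controlled by an exponential_avg_factor (a factor of 1.0 simply returns the batch statistic, as the _running_mean test helper in the diff below spells out). A minimal NumPy sketch of that update rule, with illustrative names only, not the op's implementation:

# Minimal sketch of the running-statistic update controlled by
# exponential_avg_factor; names here are illustrative.
import numpy as np


def update_running_stat(old_stat, batch_stat, exponential_avg_factor):
  # A factor of 1.0 returns the current batch statistic unchanged;
  # otherwise the old running value is blended with the new batch value.
  if exponential_avg_factor == 1.0:
    return batch_stat
  return (1.0 - exponential_avg_factor) * old_stat + (
      exponential_avg_factor * batch_stat)


# Example: a running mean decayed toward each new batch mean.
running_mean = np.zeros(4, dtype=np.float32)
for batch_mean in [np.full(4, 1.0), np.full(4, 3.0)]:
  running_mean = update_running_stat(running_mean, batch_mean, 0.6)
print(running_mean)  # moves toward the recent batch means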


@@ -100,10 +100,10 @@ TEST(FusedBatchnormReserveSpaceTest, Test) {
   Output offset =
       Const(root.WithOpName("offset"), Input::Initializer(offset_data));
-  Tensor mean_data(DT_FLOAT, TensorShape({10}));
-  Output mean = Const(root.WithOpName("mean"), Input::Initializer(mean_data));
-  Tensor variance_data(DT_FLOAT, TensorShape({10}));
+  Tensor mean_data(DT_FLOAT, TensorShape({0}));
+  Output mean = Const(root.WithOpName("offset"), Input::Initializer(mean_data));
+  Tensor variance_data(DT_FLOAT, TensorShape({0}));
   Output variance =
       Const(root.WithOpName("variance"), Input::Initializer(variance_data));


@@ -787,7 +787,6 @@ cc_library(
         "//tensorflow/core/util:padding",
         "//tensorflow/core/util:tensor_format",
         "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
     ],
 )


@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
+#include <unordered_set>
 #include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -1083,11 +1083,7 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
   bool is_training;
   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
-  float exponential_avg_factor;
-  if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
-    exponential_avg_factor = 1.0f;  // default value
-  }
-  int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
+  int number_inputs = (is_training) ? 3 : 5;
   string data_format_str;
   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   TensorFormat data_format;
@@ -1183,8 +1179,13 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
   // Set the correct shapes for reserve_spaces
   // so that gradients can be performed when
   // the op is in a symbolic condition.
+  if (is_training) {
+    c->set_output(3, c->Vector(0));
+    c->set_output(4, c->Vector(0));
+  } else {
   c->set_output(3, c->Vector(channel_dim));
   c->set_output(4, c->Vector(channel_dim));
+  }
   return Status::OK();
 }
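The FusedBatchNormShape hunk above encodes how many of the op's five inputs are required to carry a well-defined channel shape: in the variant with exponential_avg_factor, training with a factor of 1.0 appears to require only x, scale and offset, while any other combination also requires the incoming mean and variance vectors. A hedged Python paraphrase of that rule (not the TensorFlow implementation):

# Hedged paraphrase of the C++ shape-inference rule shown above.
def num_inputs_requiring_channel_shape(is_training, exponential_avg_factor=1.0):
  # Training with factor 1.0: only x, scale, offset must match the channel dim.
  # Otherwise mean and variance must also be rank-1 vectors of channel size.
  return 3 if (is_training and exponential_avg_factor == 1.0) else 5


assert num_inputs_requiring_channel_shape(True) == 3
assert num_inputs_requiring_channel_shape(True, 0.5) == 5
assert num_inputs_requiring_channel_shape(False) == 5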
@@ -2325,7 +2326,7 @@ Status SparseReduceShapeFn(InferenceContext* c) {
   auto axes_vec = axes_tensor->flat<int32>();
   int64 ndims = shape_vec.size();
-  absl::flat_hash_set<int64> axes;
+  std::unordered_set<int64> axes;
   for (int i = 0; i < axes_vec.size(); i++) {
     axes.insert((axes_vec(i) + ndims) % ndims);
   }


@@ -528,9 +528,9 @@ TEST(CommonShapeFnsTest, FusedBatchNormExTest) {
                      .Finalize(&op.node_def));
   // Channels are not multiple of 4.
-  INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[2];[2]");
-  INFER_OK(op, "[2,2,2,4];[4];[4];[4];[4]",
+  INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[0];[0]");
+  INFER_OK(op, "[2,2,2,4];[4];[4];[0];[0]",
            "[d0_0,d0_1,d0_2,d0_3];[d0_3];[d0_3];[d0_3];[d0_3];?");
 }


@@ -859,8 +859,8 @@ class VirtualSchedulerTest : public ::testing::Test {
         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
     auto offset =
         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
-    auto mean = ops::RandomUniform(s.WithOpName("mean"), {depth_in_}, DT_FLOAT);
-    auto var = ops::RandomUniform(s.WithOpName("var"), {depth_in_}, DT_FLOAT);
+    auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+    auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
     auto batch_norm = ops::FusedBatchNorm(
         s.WithOpName("bn"), x, scale, offset, mean, var,
@@ -2146,8 +2146,8 @@ versions {
         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
     auto offset =
         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
-    auto mean = ops::RandomUniform(s.WithOpName("mean"), {depth_in_}, DT_FLOAT);
-    auto var = ops::RandomUniform(s.WithOpName("var"), {depth_in_}, DT_FLOAT);
+    auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+    auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
     auto batch_norm = ops::FusedBatchNorm(
         s.WithOpName("bn"), x, scale, offset, mean, var,


@@ -497,8 +497,8 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
   Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
   Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
   Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
-  Output mean = ops::Const(s.WithOpName("mean"), 5.f, {16});
-  Output variance = ops::Const(s.WithOpName("variance"), 6.f, {16});
+  Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
+  Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
   Output wht1 = ops::Conv2D(s.WithOpName("wht1"), input, weight, {1, 1, 1, 1},
                             "SAME", ops::Conv2D::DataFormat("NHWC"));
   auto fbn1_op =


@@ -181,8 +181,7 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
 TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
   ShapeInferenceTestOp op("FusedBatchNorm");
-  auto set_op = [&op](bool is_training, float exponential_avg_factor,
-                      string data_format) {
+  auto set_op = [&op](bool is_training, string data_format) {
     TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm")
                      .Input(FakeInput(DT_FLOAT))
                      .Input(FakeInput(DT_FLOAT))
@@ -191,11 +190,10 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
                      .Input(FakeInput(DT_FLOAT))
                      .Attr("data_format", data_format)
                      .Attr("is_training", is_training)
-                     .Attr("exponential_avg_factor", exponential_avg_factor)
                      .Finalize(&op.node_def));
   };
-  set_op(true, 1.0, "NHWC");
+  set_op(true, "NHWC");
   // Test rank errors.
   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -209,21 +207,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
            "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
            "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
-  set_op(true, 0.5, "NHWC");
-  // Test rank errors.
-  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
-  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
-  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
-  // Channel dim of first input is merged with the single dim in other 4 inputs.
-  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
-  INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
-  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
-  INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
-           "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
-           "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
-           "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
-  set_op(true, 1.0, "NCHW");
+  set_op(true, "NCHW");
   // Test rank errors.
   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -237,7 +221,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
            "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];"
            "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]");
-  set_op(false, 1.0, "NHWC");
+  set_op(false, "NHWC");
   // Test rank errors.
   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -255,7 +239,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
            "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];"
            "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]");
-  set_op(false, 1.0, "NCHW");
+  set_op(false, "NCHW");
   // Test rank errors.
   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -295,14 +279,13 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
   // Channel dim of first input is merged with the single dim in other 4 inputs.
-  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
-  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
-  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]");
-  INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[d4_0];[d4_0]");
+  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
+  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
+  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
+  INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
   INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
            "[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
-           "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];["
-           "d0_3|d2_0|d3_0|d4_0]");
+           "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
   set_op("NCHW");
   // Test rank errors.
@@ -312,14 +295,13 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
   // Channel dim of first input is merged with the single dim in other 4 inputs.
-  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
-  INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[d2_0];[d2_0]");
-  INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[d3_0];[d3_0]");
-  INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[d4_0];[d4_0]");
+  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
+  INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[0];[0]");
+  INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[0];[0]");
+  INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[0];[0]");
   INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]",
            "[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];"
-           "[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];["
-           "d0_1|d2_0|d3_0|d4_0]");
+           "[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]");
 }
 TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {


@@ -4999,7 +4999,7 @@ cuda_py_test(
     size = "large",
     srcs = ["ops/nn_fused_batchnorm_test.py"],
     python_version = "PY3",
-    shard_count = 24,
+    shard_count = 16,
     deps = [
         ":array_ops",
         ":client_testlib",


@@ -568,13 +568,11 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
     x = constant_op.constant(1, dtype=dtypes.float32, shape=[0, 1, 1, 1])
     scale = constant_op.constant([1], dtype=dtypes.float32)
     offset = constant_op.constant([1], dtype=dtypes.float32)
-    mean = constant_op.constant([1], dtype=dtypes.float32)
-    variance = constant_op.constant([1], dtype=dtypes.float32)
     # Calling fused_batch_norm with an empty input should output a NaN in the
     # latter four outputs without triggering the check_numerics callback
     batch_norm_res = gen_nn_ops._fused_batch_norm(
-        x=x, scale=scale, offset=offset, mean=mean, variance=variance)
+        x=x, scale=scale, offset=offset, mean=[], variance=[])
     _, batch_mean, batch_variance, _, _ = self.evaluate(batch_norm_res)


@@ -20,7 +20,6 @@ from __future__ import print_function
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
@@ -60,7 +59,6 @@ class BatchNormalizationTest(test.TestCase):
                       scale_shape,
                       scale_dtype,
                       use_gpu=True,
-                      exponential_avg_factor=1.0,
                       data_format='NHWC'):
     np.random.seed(1)
     x_val = np.random.random_sample(x_shape).astype(x_dtype)
@@ -83,7 +81,6 @@ class BatchNormalizationTest(test.TestCase):
           mean=mean,
           variance=var,
           epsilon=epsilon,
-          exponential_avg_factor=exponential_avg_factor,
           data_format=data_format,
           is_training=False)
       y_val = self.evaluate(y)
@@ -95,37 +92,17 @@ class BatchNormalizationTest(test.TestCase):
       atol = 2e-3 if x_dtype == np.float16 else 1e-3
       self.assertAllClose(y_ref, y_val, atol=atol)
 
-  def _running_mean(self, old_mean, new_val, factor):
-    if factor == 1.0:
-      return new_val
-    else:
-      return (1.0 - factor) * old_mean + factor * new_val
-
-  def _training_ref(self, x, scale, offset, old_mean, old_var,
-                    exponential_avg_factor, epsilon, data_format):
+  def _training_ref(self, x, scale, offset, epsilon, data_format):
     if data_format not in ['NHWC', 'NCHW']:
       raise ValueError('data_format must be NCHW or NHWC, '
                        'got %s.' % data_format)
     if data_format == 'NCHW':
       x = array_ops.transpose(x, [0, 2, 3, 1])
-    batch_mean, batch_var = nn_impl.moments(
+    mean, var = nn_impl.moments(
         math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
-    y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
+    y = self._batch_norm(x, mean, var, offset, scale, epsilon)
     if data_format == 'NCHW':
       y = array_ops.transpose(y, [0, 3, 1, 2])
-    # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
-    # the denominator in the formula to calculate variance, while
-    # tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
-    sample_size = math_ops.cast(
-        array_ops.size(x) / array_ops.size(scale), scale.dtype)
-    batch_var_corrected = batch_var * sample_size / (
-        math_ops.maximum(sample_size - 1.0, 1.0))
-    mean = self._running_mean(old_mean, batch_mean, exponential_avg_factor)
-    var = self._running_mean(old_var, batch_var_corrected,
-                             exponential_avg_factor)
     return self.evaluate(y), self.evaluate(mean), self.evaluate(var)
 
   def _test_training(self,
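_training_ref in the hunk above computes a reference result with nn_impl.moments and a _batch_norm helper that is defined elsewhere in this test file. As a point of reference only, a minimal NumPy stand-in under the assumption that the helper applies the standard batch-normalization formula to NHWC input:

# Minimal NumPy stand-in for the test's _batch_norm helper (an assumption,
# not the file's actual implementation): per-channel normalization of NHWC x.
import numpy as np


def batch_norm_reference(x, scale, offset, epsilon=0.001):
  mean = x.mean(axis=(0, 1, 2))            # per-channel batch mean
  var = x.var(axis=(0, 1, 2))              # per-channel variance, denominator n
  y = (x - mean) / np.sqrt(var + epsilon)  # normalize
  return y * scale + offset, mean, var


x = np.random.random_sample((2, 3, 3, 4)).astype(np.float32)
y, mean, var = batch_norm_reference(x, scale=np.ones(4), offset=np.zeros(4))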
@@ -134,15 +111,11 @@ class BatchNormalizationTest(test.TestCase):
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
-                     exponential_avg_factor=1.0,
                      data_format='NHWC'):
     np.random.seed(1)
     x_val = np.random.random_sample(x_shape).astype(x_dtype)
     scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
     offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
-    old_mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
-    old_var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
     with self.cached_session(use_gpu=use_gpu) as sess:
       x = constant_op.constant(x_val, name='x')
       scale = constant_op.constant(scale_val, name='scale')
@@ -152,20 +125,20 @@ class BatchNormalizationTest(test.TestCase):
           x,
           scale,
           offset,
-          mean=old_mean_val,
-          variance=old_var_val,
           epsilon=epsilon,
-          exponential_avg_factor=exponential_avg_factor,
           data_format=data_format,
           is_training=True)
       y_val, mean_val, var_val = self.evaluate([y, mean, var])
-      y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset,
-                                                    old_mean_val, old_var_val,
-                                                    exponential_avg_factor,
-                                                    epsilon, data_format)
+      y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
+                                                    data_format)
       y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
       self.assertAllClose(y_ref, y_val, atol=y_atol)
       self.assertAllClose(mean_ref, mean_val, atol=1e-3)
+      # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
+      # the denominator in the formula to calculate variance, while
+      # tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
+      sample_size = x_val.size / scale_val.size
+      var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
       self.assertAllClose(var_ref, var_val, atol=1e-3)
 
   def _compute_gradient_error_float16(self, x, x32, x_shape, y, y32, y_shape):
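The comment carried in the hunk above is the reason the test rescales the reference variance: tf.nn.moments divides by n while the fused op applies Bessel's correction and divides by n - 1. A small NumPy illustration of that n / (n - 1) rescaling, with illustrative shapes:

# Illustration of the Bessel's-correction adjustment used by the test.
import numpy as np

x = np.random.random_sample((2, 3, 3, 4)).astype(np.float32)
sample_size = x.size / x.shape[-1]                 # elements per channel, n
biased_var = x.var(axis=(0, 1, 2))                 # denominator n
corrected_var = biased_var * sample_size / max(sample_size - 1.0, 1.0)
# Matches the unbiased (n - 1) estimator up to float rounding.
np.testing.assert_allclose(corrected_var, x.var(axis=(0, 1, 2), ddof=1),
                           rtol=1e-4)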
@@ -211,7 +184,6 @@ class BatchNormalizationTest(test.TestCase):
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
-                     exponential_avg_factor=1.0,
                      data_format='NHWC',
                      is_training=True):
     np.random.seed(1)
@@ -223,6 +195,10 @@ class BatchNormalizationTest(test.TestCase):
       x = constant_op.constant(x_val, name='x')
       scale = constant_op.constant(scale_val, name='scale')
       offset = constant_op.constant(offset_val, name='offset')
+      if is_training:
+        pop_mean = None
+        pop_var = None
+      else:
         pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
         pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
       y, _, _ = nn_impl.fused_batch_norm(
@@ -231,7 +207,6 @@ class BatchNormalizationTest(test.TestCase):
           offset,
           mean=pop_mean,
           variance=pop_var,
-          exponential_avg_factor=exponential_avg_factor,
           data_format=data_format,
           is_training=is_training)
       if x_dtype != np.float16:
@@ -249,7 +224,6 @@ class BatchNormalizationTest(test.TestCase):
             mean=pop_mean,
             variance=pop_var,
             data_format=data_format,
-            exponential_avg_factor=exponential_avg_factor,
             is_training=is_training)
         err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
                                                      x_shape)
@@ -270,7 +244,6 @@ class BatchNormalizationTest(test.TestCase):
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
-                     exponential_avg_factor=1.0,
                      data_format='NHWC',
                      is_training=True,
                      err_tolerance=1e-3):
@@ -285,6 +258,10 @@ class BatchNormalizationTest(test.TestCase):
       grad_y = constant_op.constant(grad_y_val, name='grad_y')
       scale = constant_op.constant(scale_val, name='scale')
       offset = constant_op.constant(offset_val, name='offset')
+      if is_training:
+        pop_mean = None
+        pop_var = None
+      else:
         pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
         pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
       y, _, _ = nn_impl.fused_batch_norm(
@@ -293,7 +270,6 @@ class BatchNormalizationTest(test.TestCase):
           offset,
           mean=pop_mean,
           variance=pop_var,
-          exponential_avg_factor=exponential_avg_factor,
           data_format=data_format,
           is_training=is_training)
       grad_x, grad_scale, grad_offset = gradients_impl.gradients(
@@ -335,7 +311,6 @@ class BatchNormalizationTest(test.TestCase):
           offset,
           mean=pop_mean,
           variance=pop_var,
-          exponential_avg_factor=exponential_avg_factor,
           data_format=data_format,
           is_training=is_training)
       grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(
@@ -364,142 +339,282 @@ class BatchNormalizationTest(test.TestCase):
     self.assertLess(err_grad_x_2, err_tolerance)
     self.assertLess(err_grad_scale, err_tolerance)
 
-  def _runtests(self, x_shape, is_training, gradient_test=False):
-    use_gpu_vals = [False]
-    if test.is_gpu_available(cuda_only=True):
-      use_gpu_vals += [True]
-    factors = [1.0,]
-    if compat.forward_compatible(2020, 2, 29):
-      factors += [0.6,]
-    for dtype in [np.float16, np.float32]:
-      for use_gpu in use_gpu_vals:
-        for data_format in ['NHWC', 'NCHW']:
-          if data_format == 'NHWC':
-            scale_shape = x_shape[-1:]
-          else:
-            scale_shape = x_shape[1:2]
-          for exponential_avg_factor in factors:
-            if gradient_test:
-              self._test_gradient(
-                  x_shape,
-                  dtype,
-                  scale_shape,
-                  np.float32,
-                  use_gpu=use_gpu,
-                  data_format=data_format,
-                  is_training=is_training,
-                  exponential_avg_factor=exponential_avg_factor)
-            else:
-              if is_training:
-                self._test_training(
-                    x_shape,
-                    dtype,
-                    scale_shape,
-                    np.float32,
-                    use_gpu=use_gpu,
-                    data_format=data_format,
-                    exponential_avg_factor=exponential_avg_factor)
-              else:
-                self._test_inference(
-                    x_shape,
-                    dtype,
-                    scale_shape,
-                    np.float32,
-                    use_gpu=use_gpu,
-                    data_format=data_format,
-                    exponential_avg_factor=exponential_avg_factor)
-
   def testInferenceShape1(self):
     x_shape = [1, 1, 6, 1]
-    self._runtests(x_shape, False)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_inference(
+            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
+        self._test_inference(
+            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
+      self._test_inference(
+          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
+      self._test_inference(
+          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
 
   def testInferenceShape2(self):
     x_shape = [1, 1, 6, 2]
-    self._runtests(x_shape, False)
+    if test.is_gpu_available(cuda_only=True):
+      for dtype in [np.float16, np.float32]:
+        self._test_inference(
+            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
+        self._test_inference(
+            x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
 
   def testInferenceShape3(self):
     x_shape = [1, 2, 1, 6]
-    self._runtests(x_shape, False)
+    if test.is_gpu_available(cuda_only=True):
+      for dtype in [np.float16, np.float32]:
+        self._test_inference(
+            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
 
   def testInferenceShape4(self):
     x_shape = [27, 131, 127, 6]
-    self._runtests(x_shape, False)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_inference(
+            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
+        self._test_inference(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_inference(
+          x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
+      self._test_inference(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   def testInferenceShape5(self):
     x_shape = [0, 131, 127, 6]
-    self._runtests(x_shape, False)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_inference(
+            x_shape,
+            dtype, [131],
+            np.float32,
+            use_gpu=True,
+            data_format='NCHW')
+        self._test_inference(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_inference(
+          x_shape,
+          dtype, [131],
+          np.float32,
+          use_gpu=False,
+          data_format='NCHW')
+      self._test_inference(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
+        self._test_training(
+            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
+      self._test_training(
+          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
 
   def testTrainingShape2(self):
     x_shape = [1, 1, 6, 2]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_training(
+          x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
 
   def testTrainingShape3(self):
     x_shape = [1, 2, 1, 6]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NCHW')
 
   def testTrainingShape4(self):
     x_shape = [27, 131, 127, 6]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
+        self._test_training(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_training(
+          x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   @test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
   def testTrainingShape5(self):
     x_shape = [0, 131, 127, 6]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape,
+            dtype, [131],
+            np.float32,
+            use_gpu=True,
+            data_format='NCHW')
+        self._test_training(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_training(
+          x_shape,
+          dtype, [131],
+          np.float32,
+          use_gpu=False,
+          data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
-  def testBatchNormGradInferenceShape1(self):
+  @test_util.run_deprecated_v1
+  def testBatchNormGradShape1(self):
+    for is_training in [True, False]:
       x_shape = [1, 1, 6, 1]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [1],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+          self._test_gradient(
+              x_shape,
+              dtype, [1],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [1],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [1],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   @test_util.run_deprecated_v1
-  def testBatchNormGradInferenceShape2(self):
+  def testBatchNormGradShape2(self):
+    for is_training in [True, False]:
       x_shape = [1, 1, 6, 2]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [2],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [2],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
 
   @test_util.run_deprecated_v1
-  def testBatchNormGradInferenceShape3(self):
+  def testBatchNormGradShape3(self):
+    for is_training in [True, False]:
       x_shape = [1, 2, 1, 6]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [2],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [2],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   @test_util.run_deprecated_v1
-  def testBatchNormGradInferenceShape4(self):
+  def testBatchNormGradShape4(self):
+    for is_training in [True, False]:
       x_shape = [5, 7, 11, 4]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [7],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+          self._test_gradient(
+              x_shape,
+              dtype, [4],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [4],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [7],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test never passed for XLA')
-  def testBatchNormGradInferenceShape5(self):
+  def testBatchNormGradShape5(self):
+    for is_training in [True, False]:
       x_shape = [0, 7, 11, 4]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  def testBatchNormGradTrainingShape1(self):
-    x_shape = [1, 1, 6, 1]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  def testBatchNormGradTrainingShape2(self):
-    x_shape = [1, 1, 6, 2]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  def testBatchNormGradTrainingShape3(self):
-    x_shape = [1, 2, 1, 6]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  def testBatchNormGradTrainingShape4(self):
-    x_shape = [5, 7, 11, 4]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  @test_util.disable_xla('This test never passed for XLA')
-  def testBatchNormGradTrainingShape5(self):
-    x_shape = [0, 7, 11, 4]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [7],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+          self._test_gradient(
+              x_shape,
+              dtype, [4],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [4],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [7],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   def _testBatchNormGradGrad(self, config):
     shape = config['shape']


@@ -1432,8 +1432,7 @@ def fused_batch_norm(
     epsilon=0.001,
     data_format="NHWC",
     is_training=True,
-    name=None,
-    exponential_avg_factor=1.0):
+    name=None):
   r"""Batch normalization.
@@ -1453,22 +1452,14 @@ def fused_batch_norm(
     is_training: A bool value to specify if the operation is used for
       training or inference.
     name: A name for this operation (optional).
-    exponential_avg_factor: A float number (usually between 0 and 1) used
-                            for controlling the decay of the running
-                            population average of mean and variance.
-                            If set to 1.0, the current batch average is
-                            returned.
 
   Returns:
     y: A 4D Tensor for the normalized, scaled, offsetted x.
-    running_mean: A 1D Tensor for the exponential running mean of x.
-                  The output value is (1 - exponential_avg_factor) * mean +
-                  exponential_avg_factor * batch_mean), where batch_mean
-                  is the mean of the current batch in x.
-    running_var: A 1D Tensor for the exponential running
-                 The output value is (1 - exponential_avg_factor) * variance +
-                 exponential_avg_factor * batch_variance), where batch_variance
-                 is the variance within of the current batch in x.
+    batch_mean: A 1D Tensor for the mean of x.
+    batch_var: A 1D Tensor for the variance of x.
+
+  Raises:
+    ValueError: If mean or variance is not None when is_training is True.
 
   References:
     Batch Normalization - Accelerating Deep Network Training by Reducing
@@ -1479,31 +1470,37 @@ def fused_batch_norm(
   x = ops.convert_to_tensor(x, name="input")
   scale = ops.convert_to_tensor(scale, name="scale")
   offset = ops.convert_to_tensor(offset, name="offset")
+  if is_training:
+    if (mean is not None) or (variance is not None):
+      raise ValueError("Both 'mean' and 'variance' must be None "
+                       "if is_training is True.")
   if mean is None:
-    mean = array_ops.zeros_like(scale)
+    mean = constant_op.constant([])
   if variance is None:
-    variance = array_ops.zeros_like(scale)
+    variance = constant_op.constant([])
   # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
   # prevent exception (see cudnn.h).
   min_epsilon = 1.001e-5
   epsilon = epsilon if epsilon > min_epsilon else min_epsilon
-  if compat.forward_compatible(2020, 2, 29):
-    y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
+  if compat.forward_compatible(2019, 6, 6):
+    y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
         x,
         scale,
         offset,
         mean,
         variance,
         epsilon=epsilon,
-        exponential_avg_factor=exponential_avg_factor,
         data_format=data_format,
         is_training=is_training,
         name=name)
-    return y, running_mean, running_var
+    return y, batch_mean, batch_var
+  if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16:
+    fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2
   else:
-    y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
+    fused_batch_norm_func = gen_nn_ops._fused_batch_norm  # pylint: disable=protected-access
+  y, batch_mean, batch_var, _, _ = fused_batch_norm_func(
       x,
       scale,
       offset,
@@ -1513,7 +1510,7 @@ def fused_batch_norm(
       data_format=data_format,
       is_training=is_training,
       name=name)
-  return y, running_mean, running_var
+  return y, batch_mean, batch_var
 
 
 @tf_export(v1=["nn.batch_norm_with_global_normalization"])
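For reference, a short usage sketch of the Python wrapper shown in the hunks above, via the tf.compat.v1 endpoint listed in the API golden file below; shapes and values are illustrative. In training mode the op computes batch statistics from x, while in inference mode the caller supplies mean and variance (the exponential_avg_factor argument discussed in the commit message belongs to the signature variant that carries it):

# Illustrative usage of tf.compat.v1.nn.fused_batch_norm; not part of the diff.
import numpy as np
import tensorflow.compat.v1 as tf

x = tf.constant(np.random.random_sample((2, 4, 4, 3)).astype(np.float32))
scale = tf.ones([3])
offset = tf.zeros([3])

# Training mode: batch statistics are computed from x itself.
y_train, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, scale, offset, is_training=True)

# Inference mode: the caller supplies the population mean and variance.
y_infer, _, _ = tf.nn.fused_batch_norm(
    x, scale, offset, mean=tf.zeros([3]), variance=tf.ones([3]),
    is_training=False)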


@@ -210,7 +210,7 @@ tf_module {
   }
   member_method {
     name: "fused_batch_norm"
-    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\', \'exponential_avg_factor\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\', \'1.0\'], "
+    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], "
   }
   member_method {
     name: "in_top_k"