Support computing exponential running mean and variance in fused_batch_norm.

Simplify unit test code for fused_batch_norm and its gradient.

PiperOrigin-RevId: 296546311
Change-Id: Ieb78b4038a39dd2bcde16302541e733f4e8604bd
A. Unique TensorFlower 2020-02-21 17:52:10 -08:00 committed by TensorFlower Gardener
parent bdf0ea41b9
commit ce9564d430
12 changed files with 329 additions and 237 deletions
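For reference, the exponential running-mean update named in the description follows the same formula as the _running_mean helper in the nn_fused_batchnorm_test.py hunks below. A minimal Python sketch (the function name and example values here are illustrative only, not part of the change):

def running_update(old_value, batch_value, exponential_avg_factor):
  # factor == 1.0 returns the current batch statistic unchanged;
  # otherwise the old running value is blended with the batch value.
  if exponential_avg_factor == 1.0:
    return batch_value
  return (1.0 - exponential_avg_factor) * old_value + exponential_avg_factor * batch_value

# Example: running_update(0.5, 1.0, 0.6) == (1.0 - 0.6) * 0.5 + 0.6 * 1.0 == 0.8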


@@ -100,10 +100,10 @@ TEST(FusedBatchnormReserveSpaceTest, Test) {
Output offset =
Const(root.WithOpName("offset"), Input::Initializer(offset_data));
Tensor mean_data(DT_FLOAT, TensorShape({10}));
Output mean = Const(root.WithOpName("mean"), Input::Initializer(mean_data));
Tensor mean_data(DT_FLOAT, TensorShape({0}));
Output mean = Const(root.WithOpName("offset"), Input::Initializer(mean_data));
Tensor variance_data(DT_FLOAT, TensorShape({10}));
Tensor variance_data(DT_FLOAT, TensorShape({0}));
Output variance =
Const(root.WithOpName("variance"), Input::Initializer(variance_data));


@@ -787,7 +787,6 @@ cc_library(
"//tensorflow/core/util:padding",
"//tensorflow/core/util:tensor_format",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)


@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include <unordered_set>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -1083,11 +1083,7 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
float exponential_avg_factor;
if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
exponential_avg_factor = 1.0f; // default value
}
int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
int number_inputs = (is_training) ? 3 : 5;
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
@@ -1183,8 +1179,13 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
// Set the correct shapes for reserve_spaces
// so that gradients can be performed when
// the op is in a symbolic condition.
c->set_output(3, c->Vector(channel_dim));
c->set_output(4, c->Vector(channel_dim));
if (is_training) {
c->set_output(3, c->Vector(0));
c->set_output(4, c->Vector(0));
} else {
c->set_output(3, c->Vector(channel_dim));
c->set_output(4, c->Vector(channel_dim));
}
return Status::OK();
}
@@ -2325,7 +2326,7 @@ Status SparseReduceShapeFn(InferenceContext* c) {
auto axes_vec = axes_tensor->flat<int32>();
int64 ndims = shape_vec.size();
absl::flat_hash_set<int64> axes;
std::unordered_set<int64> axes;
for (int i = 0; i < axes_vec.size(); i++) {
axes.insert((axes_vec(i) + ndims) % ndims);
}


@@ -528,9 +528,9 @@ TEST(CommonShapeFnsTest, FusedBatchNormExTest) {
.Finalize(&op.node_def));
// Channels are not multiple of 4.
INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[2];[2]");
INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[0];[0]");
INFER_OK(op, "[2,2,2,4];[4];[4];[4];[4]",
INFER_OK(op, "[2,2,2,4];[4];[4];[0];[0]",
"[d0_0,d0_1,d0_2,d0_3];[d0_3];[d0_3];[d0_3];[d0_3];?");
}


@@ -859,8 +859,8 @@ class VirtualSchedulerTest : public ::testing::Test {
ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
auto offset =
ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
auto mean = ops::RandomUniform(s.WithOpName("mean"), {depth_in_}, DT_FLOAT);
auto var = ops::RandomUniform(s.WithOpName("var"), {depth_in_}, DT_FLOAT);
auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
auto batch_norm = ops::FusedBatchNorm(
s.WithOpName("bn"), x, scale, offset, mean, var,
@@ -2146,8 +2146,8 @@ versions {
ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
auto offset =
ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
auto mean = ops::RandomUniform(s.WithOpName("mean"), {depth_in_}, DT_FLOAT);
auto var = ops::RandomUniform(s.WithOpName("var"), {depth_in_}, DT_FLOAT);
auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
auto batch_norm = ops::FusedBatchNorm(
s.WithOpName("bn"), x, scale, offset, mean, var,


@@ -497,8 +497,8 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
Output mean = ops::Const(s.WithOpName("mean"), 5.f, {16});
Output variance = ops::Const(s.WithOpName("variance"), 6.f, {16});
Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
Output wht1 = ops::Conv2D(s.WithOpName("wht1"), input, weight, {1, 1, 1, 1},
"SAME", ops::Conv2D::DataFormat("NHWC"));
auto fbn1_op =


@@ -181,8 +181,7 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
ShapeInferenceTestOp op("FusedBatchNorm");
auto set_op = [&op](bool is_training, float exponential_avg_factor,
string data_format) {
auto set_op = [&op](bool is_training, string data_format) {
TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
@@ -191,11 +190,10 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
.Input(FakeInput(DT_FLOAT))
.Attr("data_format", data_format)
.Attr("is_training", is_training)
.Attr("exponential_avg_factor", exponential_avg_factor)
.Finalize(&op.node_def));
};
set_op(true, 1.0, "NHWC");
set_op(true, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -209,21 +207,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
set_op(true, 0.5, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
"[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
set_op(true, 1.0, "NCHW");
set_op(true, "NCHW");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -237,7 +221,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];"
"[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]");
set_op(false, 1.0, "NHWC");
set_op(false, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -255,7 +239,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];"
"[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]");
set_op(false, 1.0, "NCHW");
set_op(false, "NCHW");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -295,14 +279,13 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[d4_0];[d4_0]");
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
"[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];["
"d0_3|d2_0|d3_0|d4_0]");
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
set_op("NCHW");
// Test rank errors.
@@ -312,14 +295,13 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[d2_0];[d2_0]");
INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[d3_0];[d3_0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[d4_0];[d4_0]");
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[0];[0]");
INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[0];[0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[0];[0]");
INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]",
"[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];"
"[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];["
"d0_1|d2_0|d3_0|d4_0]");
"[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]");
}
TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {


@@ -4999,7 +4999,7 @@ cuda_py_test(
size = "large",
srcs = ["ops/nn_fused_batchnorm_test.py"],
python_version = "PY3",
shard_count = 24,
shard_count = 16,
deps = [
":array_ops",
":client_testlib",


@@ -568,13 +568,11 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
x = constant_op.constant(1, dtype=dtypes.float32, shape=[0, 1, 1, 1])
scale = constant_op.constant([1], dtype=dtypes.float32)
offset = constant_op.constant([1], dtype=dtypes.float32)
mean = constant_op.constant([1], dtype=dtypes.float32)
variance = constant_op.constant([1], dtype=dtypes.float32)
# Calling fused_batch_norm with an empty input should output a NaN in the
# latter four outputs without triggering the check_numerics callback
batch_norm_res = gen_nn_ops._fused_batch_norm(
x=x, scale=scale, offset=offset, mean=mean, variance=variance)
x=x, scale=scale, offset=offset, mean=[], variance=[])
_, batch_mean, batch_variance, _, _ = self.evaluate(batch_norm_res)


@@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -60,7 +59,6 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC'):
np.random.seed(1)
x_val = np.random.random_sample(x_shape).astype(x_dtype)
@@ -83,7 +81,6 @@ class BatchNormalizationTest(test.TestCase):
mean=mean,
variance=var,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=False)
y_val = self.evaluate(y)
@@ -95,37 +92,17 @@ class BatchNormalizationTest(test.TestCase):
atol = 2e-3 if x_dtype == np.float16 else 1e-3
self.assertAllClose(y_ref, y_val, atol=atol)
def _running_mean(self, old_mean, new_val, factor):
if factor == 1.0:
return new_val
else:
return (1.0 - factor) * old_mean + factor * new_val
def _training_ref(self, x, scale, offset, old_mean, old_var,
exponential_avg_factor, epsilon, data_format):
def _training_ref(self, x, scale, offset, epsilon, data_format):
if data_format not in ['NHWC', 'NCHW']:
raise ValueError('data_format must be NCHW or NHWC, '
'got %s.' % data_format)
if data_format == 'NCHW':
x = array_ops.transpose(x, [0, 2, 3, 1])
batch_mean, batch_var = nn_impl.moments(
mean, var = nn_impl.moments(
math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
y = self._batch_norm(x, mean, var, offset, scale, epsilon)
if data_format == 'NCHW':
y = array_ops.transpose(y, [0, 3, 1, 2])
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
# the denominator in the formula to calculate variance, while
# tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
sample_size = math_ops.cast(
array_ops.size(x) / array_ops.size(scale), scale.dtype)
batch_var_corrected = batch_var * sample_size / (
math_ops.maximum(sample_size - 1.0, 1.0))
mean = self._running_mean(old_mean, batch_mean, exponential_avg_factor)
var = self._running_mean(old_var, batch_var_corrected,
exponential_avg_factor)
return self.evaluate(y), self.evaluate(mean), self.evaluate(var)
def _test_training(self,
@@ -134,15 +111,11 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC'):
np.random.seed(1)
x_val = np.random.random_sample(x_shape).astype(x_dtype)
scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
old_mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
old_var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
with self.cached_session(use_gpu=use_gpu) as sess:
x = constant_op.constant(x_val, name='x')
scale = constant_op.constant(scale_val, name='scale')
@@ -152,20 +125,20 @@ class BatchNormalizationTest(test.TestCase):
x,
scale,
offset,
mean=old_mean_val,
variance=old_var_val,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=True)
y_val, mean_val, var_val = self.evaluate([y, mean, var])
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset,
old_mean_val, old_var_val,
exponential_avg_factor,
epsilon, data_format)
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
data_format)
y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
self.assertAllClose(y_ref, y_val, atol=y_atol)
self.assertAllClose(mean_ref, mean_val, atol=1e-3)
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
# the denominator in the formula to calculate variance, while
# tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
sample_size = x_val.size / scale_val.size
var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
self.assertAllClose(var_ref, var_val, atol=1e-3)
def _compute_gradient_error_float16(self, x, x32, x_shape, y, y32, y_shape):
@@ -211,7 +184,6 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC',
is_training=True):
np.random.seed(1)
@@ -223,15 +195,18 @@ class BatchNormalizationTest(test.TestCase):
x = constant_op.constant(x_val, name='x')
scale = constant_op.constant(scale_val, name='scale')
offset = constant_op.constant(offset_val, name='offset')
pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
if is_training:
pop_mean = None
pop_var = None
else:
pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
y, _, _ = nn_impl.fused_batch_norm(
x,
scale,
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
if x_dtype != np.float16:
@@ -249,7 +224,6 @@ class BatchNormalizationTest(test.TestCase):
mean=pop_mean,
variance=pop_var,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor,
is_training=is_training)
err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
x_shape)
@@ -270,7 +244,6 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC',
is_training=True,
err_tolerance=1e-3):
@@ -285,15 +258,18 @@ class BatchNormalizationTest(test.TestCase):
grad_y = constant_op.constant(grad_y_val, name='grad_y')
scale = constant_op.constant(scale_val, name='scale')
offset = constant_op.constant(offset_val, name='offset')
pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
if is_training:
pop_mean = None
pop_var = None
else:
pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
y, _, _ = nn_impl.fused_batch_norm(
x,
scale,
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
grad_x, grad_scale, grad_offset = gradients_impl.gradients(
@@ -335,7 +311,6 @@ class BatchNormalizationTest(test.TestCase):
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(
@@ -364,142 +339,282 @@ class BatchNormalizationTest(test.TestCase):
self.assertLess(err_grad_x_2, err_tolerance)
self.assertLess(err_grad_scale, err_tolerance)
def _runtests(self, x_shape, is_training, gradient_test=False):
use_gpu_vals = [False]
if test.is_gpu_available(cuda_only=True):
use_gpu_vals += [True]
factors = [1.0,]
if compat.forward_compatible(2020, 2, 29):
factors += [0.6,]
for dtype in [np.float16, np.float32]:
for use_gpu in use_gpu_vals:
for data_format in ['NHWC', 'NCHW']:
if data_format == 'NHWC':
scale_shape = x_shape[-1:]
else:
scale_shape = x_shape[1:2]
for exponential_avg_factor in factors:
if gradient_test:
self._test_gradient(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
is_training=is_training,
exponential_avg_factor=exponential_avg_factor)
else:
if is_training:
self._test_training(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor)
else:
self._test_inference(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor)
def testInferenceShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, False)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
def testInferenceShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, False)
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
def testInferenceShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, False)
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
def testInferenceShape4(self):
x_shape = [27, 131, 127, 6]
self._runtests(x_shape, False)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
def testInferenceShape5(self):
x_shape = [0, 131, 127, 6]
self._runtests(x_shape, False)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape,
dtype, [131],
np.float32,
use_gpu=True,
data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape,
dtype, [131],
np.float32,
use_gpu=False,
data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
def testTrainingShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
def testTrainingShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
def testTrainingShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NCHW')
def testTrainingShape4(self):
x_shape = [27, 131, 127, 6]
self._runtests(x_shape, True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
@test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
def testTrainingShape5(self):
x_shape = [0, 131, 127, 6]
self._runtests(x_shape, True)
def testBatchNormGradInferenceShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, is_training=False, gradient_test=True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape,
dtype, [131],
np.float32,
use_gpu=True,
data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape,
dtype, [131],
np.float32,
use_gpu=False,
data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, is_training=False, gradient_test=True)
def testBatchNormGradShape1(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)
@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, is_training=False, gradient_test=True)
def testBatchNormGradShape2(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 2]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape4(self):
x_shape = [5, 7, 11, 4]
self._runtests(x_shape, is_training=False, gradient_test=True)
def testBatchNormGradShape3(self):
for is_training in [True, False]:
x_shape = [1, 2, 1, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)
@test_util.run_deprecated_v1
def testBatchNormGradShape4(self):
for is_training in [True, False]:
x_shape = [5, 7, 11, 4]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test never passed for XLA')
def testBatchNormGradInferenceShape5(self):
x_shape = [0, 7, 11, 4]
self._runtests(x_shape, is_training=False, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, is_training=True, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, is_training=True, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, is_training=True, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape4(self):
x_shape = [5, 7, 11, 4]
self._runtests(x_shape, is_training=True, gradient_test=True)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test never passed for XLA')
def testBatchNormGradTrainingShape5(self):
x_shape = [0, 7, 11, 4]
self._runtests(x_shape, is_training=True, gradient_test=True)
def testBatchNormGradShape5(self):
for is_training in [True, False]:
x_shape = [0, 7, 11, 4]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)
def _testBatchNormGradGrad(self, config):
shape = config['shape']
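The Bessel's correction comments in the test hunks above can be checked with a small NumPy sketch (array shape and names here are illustrative, not part of the change): np.var with the default ddof=0 divides by n, as tf.nn.moments does, while the batch variance reported by the fused op divides by n - 1, so the reference variance is rescaled by n / max(n - 1, 1) before comparison.

import numpy as np

x = np.random.random_sample((2, 3, 4, 5)).astype(np.float32)
biased_var = x.var(axis=(0, 1, 2))        # divides by n, as tf.nn.moments does
sample_size = x.size / biased_var.size    # elements per channel; n = 24 here
corrected_var = biased_var * sample_size / max(sample_size - 1.0, 1.0)
# corrected_var matches the unbiased (n - 1) estimate:
np.testing.assert_allclose(corrected_var, x.var(axis=(0, 1, 2), ddof=1), rtol=1e-4)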


@@ -1432,8 +1432,7 @@ def fused_batch_norm(
epsilon=0.001,
data_format="NHWC",
is_training=True,
name=None,
exponential_avg_factor=1.0):
name=None):
r"""Batch normalization.
@@ -1453,22 +1452,14 @@ def fused_batch_norm(
is_training: A bool value to specify if the operation is used for
training or inference.
name: A name for this operation (optional).
exponential_avg_factor: A float number (usually between 0 and 1) used
for controlling the decay of the running
population average of mean and variance.
If set to 1.0, the current batch average is
returned.
Returns:
y: A 4D Tensor for the normalized, scaled, offsetted x.
running_mean: A 1D Tensor for the exponential running mean of x.
The output value is (1 - exponential_avg_factor) * mean +
exponential_avg_factor * batch_mean), where batch_mean
is the mean of the current batch in x.
running_var: A 1D Tensor for the exponential running
The output value is (1 - exponential_avg_factor) * variance +
exponential_avg_factor * batch_variance), where batch_variance
is the variance within of the current batch in x.
batch_mean: A 1D Tensor for the mean of x.
batch_var: A 1D Tensor for the variance of x.
Raises:
ValueError: If mean or variance is not None when is_training is True.
References:
Batch Normalization - Accelerating Deep Network Training by Reducing
@@ -1479,41 +1470,47 @@ def fused_batch_norm(
x = ops.convert_to_tensor(x, name="input")
scale = ops.convert_to_tensor(scale, name="scale")
offset = ops.convert_to_tensor(offset, name="offset")
if is_training:
if (mean is not None) or (variance is not None):
raise ValueError("Both 'mean' and 'variance' must be None "
"if is_training is True.")
if mean is None:
mean = array_ops.zeros_like(scale)
mean = constant_op.constant([])
if variance is None:
variance = array_ops.zeros_like(scale)
variance = constant_op.constant([])
# Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
# prevent exception (see cudnn.h).
min_epsilon = 1.001e-5
epsilon = epsilon if epsilon > min_epsilon else min_epsilon
if compat.forward_compatible(2020, 2, 29):
y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
if compat.forward_compatible(2019, 6, 6):
y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training,
name=name)
return y, running_mean, running_var
return y, batch_mean, batch_var
if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16:
fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2
else:
y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
data_format=data_format,
is_training=is_training,
name=name)
return y, running_mean, running_var
fused_batch_norm_func = gen_nn_ops._fused_batch_norm # pylint: disable=protected-access
y, batch_mean, batch_var, _, _ = fused_batch_norm_func(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
data_format=data_format,
is_training=is_training,
name=name)
return y, batch_mean, batch_var
@tf_export(v1=["nn.batch_norm_with_global_normalization"])
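A minimal usage sketch of the Python API as it reads in the fused_batch_norm hunk above (assuming the tf.compat.v1 surface; tensor values are illustrative): in training mode mean and variance are left as None and the op returns the batch statistics, while in inference mode the population statistics are passed in explicitly.

import numpy as np
import tensorflow.compat.v1 as tf

x = tf.constant(np.random.random_sample((2, 4, 4, 3)).astype(np.float32))
scale = tf.ones([3])
offset = tf.zeros([3])

# Training: mean/variance must be None, otherwise a ValueError is raised.
y_train, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, scale, offset, is_training=True)

# Inference: pass population statistics for the normalization.
y_infer, _, _ = tf.nn.fused_batch_norm(
    x, scale, offset, mean=batch_mean, variance=batch_var, is_training=False)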


@@ -210,7 +210,7 @@ tf_module {
}
member_method {
name: "fused_batch_norm"
argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\', \'exponential_avg_factor\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\', \'1.0\'], "
argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], "
}
member_method {
name: "in_top_k"