Support computing exponential running mean and variance in fused_batch_norm.
Simplify unit test code for fused_batch_norm and its gradient.

PiperOrigin-RevId: 296546311
Change-Id: Ieb78b4038a39dd2bcde16302541e733f4e8604bd
parent bdf0ea41b9
commit ce9564d430
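The running-average semantics this change is concerned with can be summarized in a short sketch (plain NumPy, not the fused kernel itself; the update rule follows the exponential_avg_factor docstring and the _running_mean test helper that appear further down in this diff):

import numpy as np

def exponential_running_update(old_stat, batch_stat, exponential_avg_factor):
  # A factor of 1.0 passes the batch statistic through unchanged, matching the
  # pre-existing fused_batch_norm behavior; otherwise the stored population
  # statistic decays toward the statistic of the current batch.
  if exponential_avg_factor == 1.0:
    return batch_stat
  return (1.0 - exponential_avg_factor) * old_stat + exponential_avg_factor * batch_stat

# Example: blend a stored population mean with the mean of the current batch.
pop_mean = np.array([0.0, 10.0])
batch_mean = np.array([1.0, 12.0])
print(exponential_running_update(pop_mean, batch_mean, 0.6))  # -> [ 0.6 11.2]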
@@ -100,10 +100,10 @@ TEST(FusedBatchnormReserveSpaceTest, Test) {
Output offset =
Const(root.WithOpName("offset"), Input::Initializer(offset_data));

Tensor mean_data(DT_FLOAT, TensorShape({10}));
Output mean = Const(root.WithOpName("mean"), Input::Initializer(mean_data));
Tensor mean_data(DT_FLOAT, TensorShape({0}));
Output mean = Const(root.WithOpName("offset"), Input::Initializer(mean_data));

Tensor variance_data(DT_FLOAT, TensorShape({10}));
Tensor variance_data(DT_FLOAT, TensorShape({0}));
Output variance =
Const(root.WithOpName("variance"), Input::Initializer(variance_data));

@@ -787,7 +787,6 @@ cc_library(
"//tensorflow/core/util:padding",
"//tensorflow/core/util:tensor_format",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)

@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include <unordered_set>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

@@ -1083,11 +1083,7 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {

bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
float exponential_avg_factor;
if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
exponential_avg_factor = 1.0f; // default value
}
int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
int number_inputs = (is_training) ? 3 : 5;
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;

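For reference, the rule encoded by the two number_inputs variants above can be paraphrased in a few lines of Python (a hypothetical helper, not TensorFlow code; the attribute name mirrors the C++ above):

def fused_batch_norm_required_inputs(is_training, exponential_avg_factor=1.0):
  # With the exponential_avg_factor attribute, the mean/variance inputs are
  # only optional when training with the default factor of 1.0; in the
  # variant without the attribute, training mode alone decides (3 vs. 5).
  if is_training and exponential_avg_factor == 1.0:
    return 3  # x, scale, offset
  return 5  # x, scale, offset, mean, variance
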
@@ -1183,8 +1179,13 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
// Set the correct shapes for reserve_spaces
// so that gradients can be performed when
// the op is in a symbolic condition.
c->set_output(3, c->Vector(channel_dim));
c->set_output(4, c->Vector(channel_dim));
if (is_training) {
c->set_output(3, c->Vector(0));
c->set_output(4, c->Vector(0));
} else {
c->set_output(3, c->Vector(channel_dim));
c->set_output(4, c->Vector(channel_dim));
}
return Status::OK();
}

@@ -2325,7 +2326,7 @@ Status SparseReduceShapeFn(InferenceContext* c) {
auto axes_vec = axes_tensor->flat<int32>();

int64 ndims = shape_vec.size();
absl::flat_hash_set<int64> axes;
std::unordered_set<int64> axes;
for (int i = 0; i < axes_vec.size(); i++) {
axes.insert((axes_vec(i) + ndims) % ndims);
}

@@ -528,9 +528,9 @@ TEST(CommonShapeFnsTest, FusedBatchNormExTest) {
.Finalize(&op.node_def));

// Channels are not multiple of 4.
INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[2];[2]");
INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[0];[0]");

INFER_OK(op, "[2,2,2,4];[4];[4];[4];[4]",
INFER_OK(op, "[2,2,2,4];[4];[4];[0];[0]",
"[d0_0,d0_1,d0_2,d0_3];[d0_3];[d0_3];[d0_3];[d0_3];?");
}

@@ -859,8 +859,8 @@ class VirtualSchedulerTest : public ::testing::Test {
ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
auto offset =
ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
auto mean = ops::RandomUniform(s.WithOpName("mean"), {depth_in_}, DT_FLOAT);
auto var = ops::RandomUniform(s.WithOpName("var"), {depth_in_}, DT_FLOAT);
auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);

auto batch_norm = ops::FusedBatchNorm(
s.WithOpName("bn"), x, scale, offset, mean, var,

@@ -2146,8 +2146,8 @@ versions {
ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
auto offset =
ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
auto mean = ops::RandomUniform(s.WithOpName("mean"), {depth_in_}, DT_FLOAT);
auto var = ops::RandomUniform(s.WithOpName("var"), {depth_in_}, DT_FLOAT);
auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);

auto batch_norm = ops::FusedBatchNorm(
s.WithOpName("bn"), x, scale, offset, mean, var,

@@ -497,8 +497,8 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
Output mean = ops::Const(s.WithOpName("mean"), 5.f, {16});
Output variance = ops::Const(s.WithOpName("variance"), 6.f, {16});
Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
Output wht1 = ops::Conv2D(s.WithOpName("wht1"), input, weight, {1, 1, 1, 1},
"SAME", ops::Conv2D::DataFormat("NHWC"));
auto fbn1_op =

@@ -181,8 +181,7 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
ShapeInferenceTestOp op("FusedBatchNorm");

auto set_op = [&op](bool is_training, float exponential_avg_factor,
string data_format) {
auto set_op = [&op](bool is_training, string data_format) {
TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))

@@ -191,11 +190,10 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
.Input(FakeInput(DT_FLOAT))
.Attr("data_format", data_format)
.Attr("is_training", is_training)
.Attr("exponential_avg_factor", exponential_avg_factor)
.Finalize(&op.node_def));
};

set_op(true, 1.0, "NHWC");
set_op(true, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");

@@ -209,21 +207,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");

set_op(true, 0.5, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
"[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");

set_op(true, 1.0, "NCHW");
set_op(true, "NCHW");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");

@@ -237,7 +221,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];"
"[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]");

set_op(false, 1.0, "NHWC");
set_op(false, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");

@@ -255,7 +239,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];"
"[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]");

set_op(false, 1.0, "NCHW");
set_op(false, "NCHW");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");

@@ -295,14 +279,13 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[d4_0];[d4_0]");
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
"[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];["
"d0_3|d2_0|d3_0|d4_0]");
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");

set_op("NCHW");
// Test rank errors.

@@ -312,14 +295,13 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[d2_0];[d2_0]");
INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[d3_0];[d3_0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[d4_0];[d4_0]");
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[0];[0]");
INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[0];[0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[0];[0]");
INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]",
"[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];"
"[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];["
"d0_1|d2_0|d3_0|d4_0]");
"[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]");
}

TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {

@@ -4999,7 +4999,7 @@ cuda_py_test(
size = "large",
srcs = ["ops/nn_fused_batchnorm_test.py"],
python_version = "PY3",
shard_count = 24,
shard_count = 16,
deps = [
":array_ops",
":client_testlib",

@@ -568,13 +568,11 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
x = constant_op.constant(1, dtype=dtypes.float32, shape=[0, 1, 1, 1])
scale = constant_op.constant([1], dtype=dtypes.float32)
offset = constant_op.constant([1], dtype=dtypes.float32)
mean = constant_op.constant([1], dtype=dtypes.float32)
variance = constant_op.constant([1], dtype=dtypes.float32)

# Calling fused_batch_norm with an empty input should output a NaN in the
# latter four outputs without triggering the check_numerics callback
batch_norm_res = gen_nn_ops._fused_batch_norm(
x=x, scale=scale, offset=offset, mean=mean, variance=variance)
x=x, scale=scale, offset=offset, mean=[], variance=[])

_, batch_mean, batch_variance, _, _ = self.evaluate(batch_norm_res)

@@ -20,7 +20,6 @@ from __future__ import print_function

import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util

@@ -60,7 +59,6 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC'):
np.random.seed(1)
x_val = np.random.random_sample(x_shape).astype(x_dtype)

@@ -83,7 +81,6 @@ class BatchNormalizationTest(test.TestCase):
mean=mean,
variance=var,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=False)
y_val = self.evaluate(y)

@@ -95,37 +92,17 @@ class BatchNormalizationTest(test.TestCase):
atol = 2e-3 if x_dtype == np.float16 else 1e-3
self.assertAllClose(y_ref, y_val, atol=atol)

def _running_mean(self, old_mean, new_val, factor):
if factor == 1.0:
return new_val
else:
return (1.0 - factor) * old_mean + factor * new_val

def _training_ref(self, x, scale, offset, old_mean, old_var,
exponential_avg_factor, epsilon, data_format):
def _training_ref(self, x, scale, offset, epsilon, data_format):
if data_format not in ['NHWC', 'NCHW']:
raise ValueError('data_format must be NCHW or NHWC, '
'got %s.' % data_format)
if data_format == 'NCHW':
x = array_ops.transpose(x, [0, 2, 3, 1])
batch_mean, batch_var = nn_impl.moments(
mean, var = nn_impl.moments(
math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)

y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
y = self._batch_norm(x, mean, var, offset, scale, epsilon)
if data_format == 'NCHW':
y = array_ops.transpose(y, [0, 3, 1, 2])

# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
# the denominator in the formula to calculate variance, while
# tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
sample_size = math_ops.cast(
array_ops.size(x) / array_ops.size(scale), scale.dtype)
batch_var_corrected = batch_var * sample_size / (
math_ops.maximum(sample_size - 1.0, 1.0))

mean = self._running_mean(old_mean, batch_mean, exponential_avg_factor)
var = self._running_mean(old_var, batch_var_corrected,
exponential_avg_factor)
return self.evaluate(y), self.evaluate(mean), self.evaluate(var)

def _test_training(self,

@@ -134,15 +111,11 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC'):
np.random.seed(1)
x_val = np.random.random_sample(x_shape).astype(x_dtype)
scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
old_mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
old_var_val = np.random.random_sample(scale_shape).astype(scale_dtype)

with self.cached_session(use_gpu=use_gpu) as sess:
x = constant_op.constant(x_val, name='x')
scale = constant_op.constant(scale_val, name='scale')

@@ -152,20 +125,20 @@ class BatchNormalizationTest(test.TestCase):
x,
scale,
offset,
mean=old_mean_val,
variance=old_var_val,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=True)
y_val, mean_val, var_val = self.evaluate([y, mean, var])
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset,
old_mean_val, old_var_val,
exponential_avg_factor,
epsilon, data_format)
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
data_format)
y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
self.assertAllClose(y_ref, y_val, atol=y_atol)
self.assertAllClose(mean_ref, mean_val, atol=1e-3)
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
# the denominator in the formula to calculate variance, while
# tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
sample_size = x_val.size / scale_val.size
var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
self.assertAllClose(var_ref, var_val, atol=1e-3)

def _compute_gradient_error_float16(self, x, x32, x_shape, y, y32, y_shape):

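The Bessel's correction comment above can be sanity-checked with a few lines of NumPy (illustrative only): tf.nn.moments divides by n, while the fused op reports the unbiased estimate that divides by n - 1, so the reference variance has to be rescaled by n / (n - 1) before the comparison.

import numpy as np

x = np.array([1.0, 2.0, 4.0])
n = x.size
biased_var = x.var()          # denominator n, like tf.nn.moments
unbiased_var = x.var(ddof=1)  # denominator n - 1, Bessel-corrected
# Rescaling the biased estimate reproduces the unbiased one.
assert np.isclose(biased_var * n / max(n - 1.0, 1.0), unbiased_var)
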
@@ -211,7 +184,6 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC',
is_training=True):
np.random.seed(1)

@@ -223,15 +195,18 @@ class BatchNormalizationTest(test.TestCase):
x = constant_op.constant(x_val, name='x')
scale = constant_op.constant(scale_val, name='scale')
offset = constant_op.constant(offset_val, name='offset')
pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
if is_training:
pop_mean = None
pop_var = None
else:
pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
y, _, _ = nn_impl.fused_batch_norm(
x,
scale,
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
if x_dtype != np.float16:

@@ -249,7 +224,6 @@ class BatchNormalizationTest(test.TestCase):
mean=pop_mean,
variance=pop_var,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor,
is_training=is_training)
err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
x_shape)

@@ -270,7 +244,6 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC',
is_training=True,
err_tolerance=1e-3):

@@ -285,15 +258,18 @@ class BatchNormalizationTest(test.TestCase):
grad_y = constant_op.constant(grad_y_val, name='grad_y')
scale = constant_op.constant(scale_val, name='scale')
offset = constant_op.constant(offset_val, name='offset')
pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
if is_training:
pop_mean = None
pop_var = None
else:
pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
y, _, _ = nn_impl.fused_batch_norm(
x,
scale,
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
grad_x, grad_scale, grad_offset = gradients_impl.gradients(

@@ -335,7 +311,6 @@ class BatchNormalizationTest(test.TestCase):
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(

@@ -364,142 +339,282 @@ class BatchNormalizationTest(test.TestCase):
self.assertLess(err_grad_x_2, err_tolerance)
self.assertLess(err_grad_scale, err_tolerance)

def _runtests(self, x_shape, is_training, gradient_test=False):
use_gpu_vals = [False]
if test.is_gpu_available(cuda_only=True):
use_gpu_vals += [True]
factors = [1.0,]
if compat.forward_compatible(2020, 2, 29):
factors += [0.6,]
for dtype in [np.float16, np.float32]:
for use_gpu in use_gpu_vals:
for data_format in ['NHWC', 'NCHW']:
if data_format == 'NHWC':
scale_shape = x_shape[-1:]
else:
scale_shape = x_shape[1:2]
for exponential_avg_factor in factors:
if gradient_test:
self._test_gradient(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
is_training=is_training,
exponential_avg_factor=exponential_avg_factor)
else:
if is_training:
self._test_training(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor)
else:
self._test_inference(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor)

def testInferenceShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, False)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')

def testInferenceShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, False)
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')

def testInferenceShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, False)
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')

def testInferenceShape4(self):
x_shape = [27, 131, 127, 6]
self._runtests(x_shape, False)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')

def testInferenceShape5(self):
x_shape = [0, 131, 127, 6]
self._runtests(x_shape, False)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape,
dtype, [131],
np.float32,
use_gpu=True,
data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape,
dtype, [131],
np.float32,
use_gpu=False,
data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')

def testTrainingShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')

def testTrainingShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')

def testTrainingShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NCHW')

def testTrainingShape4(self):
x_shape = [27, 131, 127, 6]
self._runtests(x_shape, True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')

@test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
def testTrainingShape5(self):
x_shape = [0, 131, 127, 6]
self._runtests(x_shape, True)

def testBatchNormGradInferenceShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, is_training=False, gradient_test=True)
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape,
dtype, [131],
np.float32,
use_gpu=True,
data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape,
dtype, [131],
np.float32,
use_gpu=False,
data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')

@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, is_training=False, gradient_test=True)
def testBatchNormGradShape1(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)

@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, is_training=False, gradient_test=True)
def testBatchNormGradShape2(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 2]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)

@test_util.run_deprecated_v1
def testBatchNormGradInferenceShape4(self):
x_shape = [5, 7, 11, 4]
self._runtests(x_shape, is_training=False, gradient_test=True)
def testBatchNormGradShape3(self):
for is_training in [True, False]:
x_shape = [1, 2, 1, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)

@test_util.run_deprecated_v1
def testBatchNormGradShape4(self):
for is_training in [True, False]:
x_shape = [5, 7, 11, 4]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)

@test_util.run_deprecated_v1
@test_util.disable_xla('This test never passed for XLA')
def testBatchNormGradInferenceShape5(self):
x_shape = [0, 7, 11, 4]
self._runtests(x_shape, is_training=False, gradient_test=True)

@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, is_training=True, gradient_test=True)

@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, is_training=True, gradient_test=True)

@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, is_training=True, gradient_test=True)

@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape4(self):
x_shape = [5, 7, 11, 4]
self._runtests(x_shape, is_training=True, gradient_test=True)

@test_util.run_deprecated_v1
@test_util.disable_xla('This test never passed for XLA')
def testBatchNormGradTrainingShape5(self):
x_shape = [0, 7, 11, 4]
self._runtests(x_shape, is_training=True, gradient_test=True)
def testBatchNormGradShape5(self):
for is_training in [True, False]:
x_shape = [0, 7, 11, 4]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)

def _testBatchNormGradGrad(self, config):
shape = config['shape']

@@ -1432,8 +1432,7 @@ def fused_batch_norm(
epsilon=0.001,
data_format="NHWC",
is_training=True,
name=None,
exponential_avg_factor=1.0):
name=None):
r"""Batch normalization.

@@ -1453,22 +1452,14 @@ def fused_batch_norm(
is_training: A bool value to specify if the operation is used for
training or inference.
name: A name for this operation (optional).
exponential_avg_factor: A float number (usually between 0 and 1) used
for controlling the decay of the running population average of mean
and variance. If set to 1.0, the current batch average is returned.

Returns:
y: A 4D Tensor for the normalized, scaled, offsetted x.
running_mean: A 1D Tensor for the exponential running mean of x.
The output value is (1 - exponential_avg_factor) * mean +
exponential_avg_factor * batch_mean, where batch_mean
is the mean of the current batch in x.
running_var: A 1D Tensor for the exponential running variance of x.
The output value is (1 - exponential_avg_factor) * variance +
exponential_avg_factor * batch_variance, where batch_variance
is the variance of the current batch in x.
batch_mean: A 1D Tensor for the mean of x.
batch_var: A 1D Tensor for the variance of x.

Raises:
ValueError: If mean or variance is not None when is_training is True.

References:
Batch Normalization - Accelerating Deep Network Training by Reducing

@@ -1479,41 +1470,47 @@ def fused_batch_norm(
x = ops.convert_to_tensor(x, name="input")
scale = ops.convert_to_tensor(scale, name="scale")
offset = ops.convert_to_tensor(offset, name="offset")
if is_training:
if (mean is not None) or (variance is not None):
raise ValueError("Both 'mean' and 'variance' must be None "
"if is_training is True.")
if mean is None:
mean = array_ops.zeros_like(scale)
mean = constant_op.constant([])
if variance is None:
variance = array_ops.zeros_like(scale)

variance = constant_op.constant([])
# Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
# prevent exception (see cudnn.h).
min_epsilon = 1.001e-5
epsilon = epsilon if epsilon > min_epsilon else min_epsilon

if compat.forward_compatible(2020, 2, 29):
y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
if compat.forward_compatible(2019, 6, 6):
y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training,
name=name)
return y, running_mean, running_var
return y, batch_mean, batch_var

if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16:
fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2
else:
y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
data_format=data_format,
is_training=is_training,
name=name)
return y, running_mean, running_var
fused_batch_norm_func = gen_nn_ops._fused_batch_norm  # pylint: disable=protected-access
y, batch_mean, batch_var, _, _ = fused_batch_norm_func(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
data_format=data_format,
is_training=is_training,
name=name)
return y, batch_mean, batch_var

@tf_export(v1=["nn.batch_norm_with_global_normalization"])

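For orientation, a minimal inference-mode call to the wrapper above could look as follows (a sketch with arbitrary shapes; inference always receives explicit population statistics, so the call is valid under both signatures shown in this hunk):

import numpy as np
import tensorflow.compat.v1 as tf

x = tf.constant(np.random.random_sample([2, 4, 4, 3]).astype(np.float32))
scale = tf.ones([3])
offset = tf.zeros([3])
pop_mean = tf.zeros([3])
pop_var = tf.ones([3])
# Inference mode: the supplied population mean/variance are used to
# normalize x instead of statistics computed from the current batch.
y, mean_out, var_out = tf.nn.fused_batch_norm(
    x, scale, offset, mean=pop_mean, variance=pop_var, is_training=False)
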
@@ -210,7 +210,7 @@ tf_module {
}
member_method {
name: "fused_batch_norm"
argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\', \'exponential_avg_factor\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\', \'1.0\'], "
argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], "
}
member_method {
name: "in_top_k"