Support computing exponential running mean and variance in fused_batch_norm.

Simplify unit test code for fused_batch_norm and its gradient.

PiperOrigin-RevId: 297451723
Change-Id: I3ad1db022848cec3cf45a5a560a8f25af75781d4
A. Unique TensorFlower 2020-02-26 14:27:38 -08:00 committed by TensorFlower Gardener
parent f0ffc49ac2
commit 84f2ec1d60
14 changed files with 388 additions and 361 deletions
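In short, with f = exponential_avg_factor, the statistics returned in training mode become

running_mean = (1 - f) * mean + f * batch_mean
running_var  = (1 - f) * variance + f * batch_var

where batch_mean and batch_var are the statistics of the current batch (batch_var with Bessel's correction applied), and mean/variance are the previous running values passed in. The default f = 1.0 keeps the old behavior of returning the batch statistics unchanged.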


@ -23,6 +23,7 @@ import numpy as np
from tensorflow.compiler.tests import test_utils
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compat import compat
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
@ -34,10 +35,18 @@ DATA_FORMATS = (
("_data_format_NCHW", "NCHW"),
)
DATA_FORMATS_AND_AVG_FACTORS = (
("_data_format_NHWC_no_averaging", "NHWC", 1.0),
("_data_format_NHWC_averaging", "NHWC", 0.6),
("_data_format_NCHW_no_averaging", "NCHW", 1.0),
("_data_format_NCHW_averaging", "NCHW", 0.6),
)
class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
def _reference_training(self, x, scale, offset, epsilon, data_format):
def _reference_training(self, x, scale, offset, old_mean, old_var, epsilon,
exponential_avg_factor, data_format):
if data_format != "NHWC":
raise ValueError("data_format must be NHWC, got %s." % data_format)
x_square = x * x
@ -49,6 +58,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
factor = element_count / max(element_count - 1, 1)
corrected_var = var * factor
normalized = (x - mean) / np.sqrt(var + epsilon)
if exponential_avg_factor != 1.0:
mean = (1.0 -
exponential_avg_factor) * old_mean + exponential_avg_factor * mean
corrected_var = (1.0 - exponential_avg_factor
) * old_var + exponential_avg_factor * corrected_var
return (normalized * scale + offset), mean, var, corrected_var
def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
@ -81,9 +95,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
offset_val = np.random.random_sample(scale_shape).astype(np.float32)
epsilon = 0.001
exponential_avg_factor = 1.0
data_format_src = "NHWC"
y_ref, mean_ref, var_ref, _ = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format_src)
x_val, scale_val, offset_val, None, None, epsilon,
exponential_avg_factor, data_format_src)
with self.session() as sess, self.test_scope():
# To avoid constant folding
@ -114,7 +130,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
})
self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
def _testLearning(self, use_gradient_checker, data_format):
def _testLearning(self, use_gradient_checker, data_format,
exponential_avg_factor):
if not compat.forward_compatible(2020, 3,
6) and exponential_avg_factor != 1.0:
self.skipTest("running average not available.")
channel = 3
x_shape = [2, 2, 6, channel]
scale_shape = [channel]
@ -122,13 +142,14 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
offset_val = np.random.random_sample(scale_shape).astype(np.float32)
mean_val = np.random.random_sample(scale_shape).astype(np.float32)
var_val = np.random.random_sample(scale_shape).astype(np.float32)
var_val_corr = np.random.random_sample(scale_shape).astype(np.float32)
epsilon = 0.001
data_format_src = "NHWC"
# When in training mode, fused_batchnorm applies an implicit Bessel's
# correction. So we have to use the corrected variance here, as well.
y_ref, mean_ref, _, var_ref_corr = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format_src)
x_val, scale_val, offset_val, mean_val, var_val_corr, epsilon,
exponential_avg_factor, data_format_src)
with self.session() as sess, self.test_scope():
# To avoid constant folding
@ -142,15 +163,38 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
offset = array_ops.placeholder(
np.float32, shape=scale_shape, name="offset")
if exponential_avg_factor == 1.0:
old_mean = None
old_var = None
else:
old_mean = array_ops.placeholder(
np.float32, shape=scale_shape, name="old_mean")
old_var = array_ops.placeholder(
np.float32, shape=scale_shape, name="old_var")
y, mean, var = nn.fused_batch_norm(
t_val,
scale,
offset,
mean=None,
variance=None,
mean=old_mean,
variance=old_var,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=True)
if exponential_avg_factor == 1.0:
feed_dict = {
t_val: x_val_converted,
scale: scale_val,
offset: offset_val,
}
else:
feed_dict = {
t_val: x_val_converted,
scale: scale_val,
offset: offset_val,
old_mean: mean_val,
old_var: var_val_corr
}
# Check gradient.
if use_gradient_checker:
err = gradient_checker.compute_gradient_error(
@ -158,29 +202,22 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
x_val_converted.shape,
y,
x_val_converted.shape,
extra_feed_dict={
t_val: x_val_converted,
scale: scale_val,
offset: offset_val
})
extra_feed_dict=feed_dict)
self.assertLess(err, 1e-3)
y_val, mean_val, var_val = sess.run([y, mean, var], {
t_val: x_val_converted,
scale: scale_val,
offset: offset_val
})
self.assertAllClose(mean_val, mean_ref, atol=1e-3)
self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
self.assertAllClose(var_val, var_ref_corr, atol=1e-3)
y_tf, mean_tf, var_tf = sess.run([y, mean, var], feed_dict)
self.assertAllClose(y_tf, y_ref_converted, atol=1e-3)
self.assertAllClose(mean_tf, mean_ref, atol=1e-3)
self.assertAllClose(var_tf, var_ref_corr, atol=1e-3)
@parameterized.named_parameters(*DATA_FORMATS)
def testLearning(self, data_format):
self._testLearning(False, data_format)
@parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS)
def testLearning(self, data_format, exponential_avg_factor):
self._testLearning(False, data_format, exponential_avg_factor)
@parameterized.named_parameters(*DATA_FORMATS)
def testLearningWithGradientChecker(self, data_format):
self._testLearning(True, data_format)
@parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS)
def testLearningWithGradientChecker(self, data_format,
exponential_avg_factor):
self._testLearning(True, data_format, exponential_avg_factor)
@parameterized.named_parameters(*DATA_FORMATS)
def testGradientTraining(self, data_format):

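A compact NumPy restatement of what the updated _reference_training above computes (a sketch; the helper name and signature are illustrative, not part of the change):

import numpy as np

def reference_training(x, scale, offset, old_mean, old_var, epsilon, factor):
  # Batch statistics over N, H, W of an NHWC input.
  batch_mean = x.mean(axis=(0, 1, 2))
  batch_var = x.var(axis=(0, 1, 2))
  y = (x - batch_mean) / np.sqrt(batch_var + epsilon) * scale + offset
  # Bessel's correction, matching the variance the fused op returns.
  n = x.size / scale.size
  corrected_var = batch_var * n / max(n - 1.0, 1.0)
  if factor != 1.0:
    batch_mean = (1.0 - factor) * old_mean + factor * batch_mean
    corrected_var = (1.0 - factor) * old_var + factor * corrected_var
  return y, batch_mean, corrected_var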

@ -37,6 +37,8 @@ class FusedBatchNormOp : public XlaOpKernel {
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr("exponential_avg_factor", &exponential_avg_factor_));
string data_format_str;
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
OP_REQUIRES(
@ -105,7 +107,6 @@ class FusedBatchNormOp : public XlaOpKernel {
ctx->SetOutput(0, converted);
}
ctx->SetOutput(1, xla::GetTupleElement(output, 1));
xla::XlaOp variance = xla::GetTupleElement(output, 2);
// Apply Bessel's correction.
int total_input_size = ctx->InputShape(0).num_elements();
@ -121,7 +122,7 @@ class FusedBatchNormOp : public XlaOpKernel {
if (input_shape.num_elements() == 0) {
auto status_or_output_shape = b->GetShape(corrected);
OP_REQUIRES_OK(ctx, status_or_output_shape.status());
ctx->SetOutput(1, xla::GetTupleElement(output, 1));
ctx->SetOutput(
kVarianceOutputIndex,
xla::Broadcast(
@ -130,7 +131,27 @@ class FusedBatchNormOp : public XlaOpKernel {
status_or_output_shape.ValueOrDie().dimensions())));
} else {
ctx->SetOutput(2, corrected);
if (exponential_avg_factor_ == 1.0f) {
ctx->SetOutput(1, xla::GetTupleElement(output, 1));
ctx->SetOutput(2, corrected);
} else {
xla::XlaOp old_mean = ctx->Input(3);
xla::XlaOp alpha =
xla::ScalarLike(old_mean, 1.0f - exponential_avg_factor_);
xla::XlaOp beta = xla::ScalarLike(old_mean, exponential_avg_factor_);
// new_running_mean = alpha * old_mean + beta * batch_mean.
xla::XlaOp new_running_mean =
xla::Add(xla::Mul(old_mean, alpha),
xla::Mul(xla::GetTupleElement(output, 1), beta));
ctx->SetOutput(1, new_running_mean);
xla::XlaOp old_variance = ctx->Input(4);
// new_running_variance = alpha * old_variance + beta * batch_variance.
xla::XlaOp new_running_variance = xla::Add(
    xla::Mul(old_variance, alpha), xla::Mul(corrected, beta));
ctx->SetOutput(2, new_running_variance);
}
}
// Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
@ -175,6 +196,7 @@ class FusedBatchNormOp : public XlaOpKernel {
float epsilon_;
TensorFormat data_format_;
bool is_training_;
float exponential_avg_factor_;
bool add_side_input_;
bool apply_relu_;
bool is_on_gpu_;

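In plain numbers, the alpha/beta blend emitted above does the following (NumPy stand-in for the XLA ops, with made-up values):

import numpy as np

f = 0.6                                  # exponential_avg_factor_
alpha, beta = 1.0 - f, f
old_mean = np.array([2.0])
batch_mean = np.array([5.0])             # GetTupleElement(output, 1)
old_variance = np.array([1.0])
corrected = np.array([4.0])              # Bessel-corrected batch variance
new_running_mean = alpha * old_mean + beta * batch_mean          # -> [3.8]
new_running_variance = alpha * old_variance + beta * corrected   # -> [2.8]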

@ -787,6 +787,7 @@ cc_library(
"//tensorflow/core/util:padding",
"//tensorflow/core/util:tensor_format",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)


@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <unordered_set>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@ -1083,7 +1083,11 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
int number_inputs = (is_training) ? 3 : 5;
float exponential_avg_factor;
if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
exponential_avg_factor = 1.0f; // default value
}
int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
string data_format_str;
TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
TensorFormat data_format;
@ -1176,16 +1180,8 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
c->set_output(0, x_backprop);
c->set_output(1, c->Vector(channel_dim));
c->set_output(2, c->Vector(channel_dim));
// Set the correct shapes for reserve_spaces
// so that gradients can be performed when
// the op is in a symbolic condition.
if (is_training) {
c->set_output(3, c->Vector(0));
c->set_output(4, c->Vector(0));
} else {
c->set_output(3, c->Vector(channel_dim));
c->set_output(4, c->Vector(channel_dim));
}
c->set_output(3, c->Vector(0));
c->set_output(4, c->Vector(0));
return Status::OK();
}
@ -2326,7 +2322,7 @@ Status SparseReduceShapeFn(InferenceContext* c) {
auto axes_vec = axes_tensor->flat<int32>();
int64 ndims = shape_vec.size();
std::unordered_set<int64> axes;
absl::flat_hash_set<int64> axes;
for (int i = 0; i < axes_vec.size(); i++) {
axes.insert((axes_vec(i) + ndims) % ndims);
}


@ -528,9 +528,9 @@ TEST(CommonShapeFnsTest, FusedBatchNormExTest) {
.Finalize(&op.node_def));
// Channels are not multiple of 4.
INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[0];[0]");
INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[2];[2]");
INFER_OK(op, "[2,2,2,4];[4];[4];[0];[0]",
INFER_OK(op, "[2,2,2,4];[4];[4];[4];[4]",
"[d0_0,d0_1,d0_2,d0_3];[d0_3];[d0_3];[d0_3];[d0_3];?");
}


@ -1923,7 +1923,7 @@ tf_cc_test(
],
)
tf_cc_test(
tf_cuda_cc_test(
name = "fused_batch_norm_op_test",
size = "small",
srcs = ["fused_batch_norm_op_test.cc"],


@ -929,6 +929,28 @@ struct FusedBatchNorm<GPUDevice, T, U, is_training> {
workspace_allocator.reset(
new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
}
if (!batch_mean->SharesBufferWith(estimated_mean) &&
exponential_avg_factor != 1.0f) {
OP_REQUIRES(
context,
stream
->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr,
estimated_mean.NumElements() * sizeof(U))
.ok(),
errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
"from device"));
}
if (!batch_var->SharesBufferWith(estimated_variance) &&
exponential_avg_factor != 1.0f) {
OP_REQUIRES(
context,
stream
->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr,
estimated_variance.NumElements() * sizeof(U))
.ok(),
errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
"from device"));
}
bool cudnn_launch_status =
stream
->ThenBatchNormalizationForward(
@ -1263,11 +1285,11 @@ class FusedBatchNormOpBase : public OpKernel {
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, x.shape(), &y));
Tensor* batch_mean = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(1, scale.shape(), &batch_mean));
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{3}, 1, scale.shape(), &batch_mean));
Tensor* batch_var = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(2, scale.shape(), &batch_var));
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{4}, 2, scale.shape(), &batch_var));
Tensor* saved_mean = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(3, scale.shape(), &saved_mean));
@ -1400,12 +1422,9 @@ class FusedBatchNormGradOpBase : public OpKernel {
Tensor* placeholder_1 = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(3, TensorShape({0}), &placeholder_1));
functor::SetZeroFunctor<Device, float> f;
f(context->eigen_device<Device>(), placeholder_1->flat<U>());
Tensor* placeholder_2 = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(4, TensorShape({0}), &placeholder_2));
f(context->eigen_device<Device>(), placeholder_2->flat<U>());
// If input is empty, set gradients w.r.t scale/offset to zero.
if (x.shape().num_elements() == 0) {

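The two device-to-device copies added above exist because the forward call blends the running statistics in place, reading the previous values from the same buffers it writes the new ones to (see the stream_executor change further down, which stops zeroing those buffers when a factor is in play). When the output buffers do not already alias estimated_mean/estimated_variance, the old values have to be staged there first. Roughly, in NumPy terms (a sketch of the buffer handling, not the launched kernel):

import numpy as np

factor = 0.6
estimated_mean = np.array([2.0])                 # previous running mean (input 3)
batch_mean_buf = np.empty_like(estimated_mean)   # batch_mean output buffer

batch_mean_buf[:] = estimated_mean               # the ThenMemcpyD2D above
new_batch_mean = np.array([5.0])                 # statistics of the current batch
# In-place blend performed by the forward pass:
batch_mean_buf[:] = (1.0 - factor) * batch_mean_buf + factor * new_batch_mean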

@ -181,7 +181,8 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
ShapeInferenceTestOp op("FusedBatchNorm");
auto set_op = [&op](bool is_training, string data_format) {
auto set_op = [&op](bool is_training, float exponential_avg_factor,
string data_format) {
TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
@ -190,10 +191,11 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
.Input(FakeInput(DT_FLOAT))
.Attr("data_format", data_format)
.Attr("is_training", is_training)
.Attr("exponential_avg_factor", exponential_avg_factor)
.Finalize(&op.node_def));
};
set_op(true, "NHWC");
set_op(true, 1.0, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@ -207,7 +209,21 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
set_op(true, "NCHW");
set_op(true, 0.5, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
"[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
set_op(true, 1.0, "NCHW");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@ -221,7 +237,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];"
"[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]");
set_op(false, "NHWC");
set_op(false, 1.0, "NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@ -239,7 +255,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
"[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];"
"[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]");
set_op(false, "NCHW");
set_op(false, 1.0, "NCHW");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@ -271,22 +287,6 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
.Finalize(&op.node_def));
};
set_op("NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
"[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
set_op("NCHW");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
@ -302,6 +302,22 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]",
"[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];"
"[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]");
set_op("NHWC");
// Test rank errors.
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
// Channel dim of first input is merged with the single dim in other 4 inputs.
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
"[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
}
TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {


@ -5015,7 +5015,7 @@ cuda_py_test(
size = "large",
srcs = ["ops/nn_fused_batchnorm_test.py"],
python_version = "PY3",
shard_count = 16,
shard_count = 24,
deps = [
":array_ops",
":client_testlib",


@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@ -59,6 +60,7 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC'):
np.random.seed(1)
x_val = np.random.random_sample(x_shape).astype(x_dtype)
@ -81,6 +83,7 @@ class BatchNormalizationTest(test.TestCase):
mean=mean,
variance=var,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=False)
y_val = self.evaluate(y)
@ -92,17 +95,37 @@ class BatchNormalizationTest(test.TestCase):
atol = 2e-3 if x_dtype == np.float16 else 1e-3
self.assertAllClose(y_ref, y_val, atol=atol)
def _training_ref(self, x, scale, offset, epsilon, data_format):
def _running_mean(self, old_mean, new_val, factor):
if factor == 1.0:
return new_val
else:
return (1.0 - factor) * old_mean + factor * new_val
def _training_ref(self, x, scale, offset, old_mean, old_var,
exponential_avg_factor, epsilon, data_format):
if data_format not in ['NHWC', 'NCHW']:
raise ValueError('data_format must be NCHW or NHWC, '
'got %s.' % data_format)
if data_format == 'NCHW':
x = array_ops.transpose(x, [0, 2, 3, 1])
mean, var = nn_impl.moments(
batch_mean, batch_var = nn_impl.moments(
math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
y = self._batch_norm(x, mean, var, offset, scale, epsilon)
y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
if data_format == 'NCHW':
y = array_ops.transpose(y, [0, 3, 1, 2])
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
# the denominator in the formula to calculate variance, while
# tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
sample_size = math_ops.cast(
array_ops.size(x) / array_ops.size(scale), scale.dtype)
batch_var_corrected = batch_var * sample_size / (
math_ops.maximum(sample_size - 1.0, 1.0))
mean = self._running_mean(old_mean, batch_mean, exponential_avg_factor)
var = self._running_mean(old_var, batch_var_corrected,
exponential_avg_factor)
return self.evaluate(y), self.evaluate(mean), self.evaluate(var)
def _test_training(self,
@ -111,11 +134,19 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC'):
np.random.seed(1)
x_val = np.random.random_sample(x_shape).astype(x_dtype)
scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
if exponential_avg_factor == 1.0:
old_mean_val = None
old_var_val = None
else:
old_mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
old_var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
with self.cached_session(use_gpu=use_gpu) as sess:
x = constant_op.constant(x_val, name='x')
scale = constant_op.constant(scale_val, name='scale')
@ -125,20 +156,20 @@ class BatchNormalizationTest(test.TestCase):
x,
scale,
offset,
mean=old_mean_val,
variance=old_var_val,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=True)
y_val, mean_val, var_val = self.evaluate([y, mean, var])
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
data_format)
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset,
old_mean_val, old_var_val,
exponential_avg_factor,
epsilon, data_format)
y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
self.assertAllClose(y_ref, y_val, atol=y_atol)
self.assertAllClose(mean_ref, mean_val, atol=1e-3)
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
# the denominator in the formula to calculate variance, while
# tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
sample_size = x_val.size / scale_val.size
var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
self.assertAllClose(var_ref, var_val, atol=1e-3)
def _compute_gradient_error_float16(self, x, x32, x_shape, y, y32, y_shape):
@ -184,6 +215,7 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC',
is_training=True):
np.random.seed(1)
@ -195,7 +227,7 @@ class BatchNormalizationTest(test.TestCase):
x = constant_op.constant(x_val, name='x')
scale = constant_op.constant(scale_val, name='scale')
offset = constant_op.constant(offset_val, name='offset')
if is_training:
if is_training and exponential_avg_factor == 1.0:
pop_mean = None
pop_var = None
else:
@ -207,6 +239,7 @@ class BatchNormalizationTest(test.TestCase):
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
if x_dtype != np.float16:
@ -224,6 +257,7 @@ class BatchNormalizationTest(test.TestCase):
mean=pop_mean,
variance=pop_var,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor,
is_training=is_training)
err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
x_shape)
@ -244,6 +278,7 @@ class BatchNormalizationTest(test.TestCase):
scale_shape,
scale_dtype,
use_gpu=True,
exponential_avg_factor=1.0,
data_format='NHWC',
is_training=True,
err_tolerance=1e-3):
@ -258,7 +293,7 @@ class BatchNormalizationTest(test.TestCase):
grad_y = constant_op.constant(grad_y_val, name='grad_y')
scale = constant_op.constant(scale_val, name='scale')
offset = constant_op.constant(offset_val, name='offset')
if is_training:
if is_training and exponential_avg_factor == 1.0:
pop_mean = None
pop_var = None
else:
@ -270,6 +305,7 @@ class BatchNormalizationTest(test.TestCase):
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
grad_x, grad_scale, grad_offset = gradients_impl.gradients(
@ -311,6 +347,7 @@ class BatchNormalizationTest(test.TestCase):
offset,
mean=pop_mean,
variance=pop_var,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training)
grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(
@ -339,282 +376,146 @@ class BatchNormalizationTest(test.TestCase):
self.assertLess(err_grad_x_2, err_tolerance)
self.assertLess(err_grad_scale, err_tolerance)
def _runtests(self, x_shape, is_training, gradient_test=False):
use_gpu_vals = [False]
if test.is_gpu_available(cuda_only=True):
use_gpu_vals += [True]
factors = [
1.0,
]
if compat.forward_compatible(2020, 3, 6):
factors += [
0.6,
]
for dtype in [np.float16, np.float32]:
for use_gpu in use_gpu_vals:
for data_format in ['NHWC', 'NCHW']:
if data_format == 'NHWC':
scale_shape = x_shape[-1:]
else:
scale_shape = x_shape[1:2]
for exponential_avg_factor in factors:
if gradient_test:
self._test_gradient(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
is_training=is_training,
exponential_avg_factor=exponential_avg_factor)
else:
if is_training:
self._test_training(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor)
else:
self._test_inference(
x_shape,
dtype,
scale_shape,
np.float32,
use_gpu=use_gpu,
data_format=data_format,
exponential_avg_factor=exponential_avg_factor)
def testInferenceShape1(self):
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
self._test_inference(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
self._runtests(x_shape, False)
def testInferenceShape2(self):
x_shape = [1, 1, 6, 2]
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
self._runtests(x_shape, False)
def testInferenceShape3(self):
x_shape = [1, 2, 1, 6]
if test.is_gpu_available(cuda_only=True):
for dtype in [np.float16, np.float32]:
self._test_inference(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
self._runtests(x_shape, False)
def testInferenceShape4(self):
x_shape = [27, 131, 127, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
self._runtests(x_shape, False)
def testInferenceShape5(self):
x_shape = [0, 131, 127, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape,
dtype, [131],
np.float32,
use_gpu=True,
data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape,
dtype, [131],
np.float32,
use_gpu=False,
data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
self._runtests(x_shape, False)
def testTrainingShape1(self):
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
self._test_training(
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
self._runtests(x_shape, True)
def testTrainingShape2(self):
x_shape = [1, 1, 6, 2]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
self._runtests(x_shape, True)
def testTrainingShape3(self):
x_shape = [1, 2, 1, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NCHW')
self._runtests(x_shape, True)
def testTrainingShape4(self):
x_shape = [27, 131, 127, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
self._runtests(x_shape, True)
@test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
def testTrainingShape5(self):
x_shape = [0, 131, 127, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape,
dtype, [131],
np.float32,
use_gpu=True,
data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape,
dtype, [131],
np.float32,
use_gpu=False,
data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
self._runtests(x_shape, True)
def testBatchNormGradInferenceShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, is_training=False, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradShape1(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [1],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)
def testBatchNormGradInferenceShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, is_training=False, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradShape2(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 2]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
def testBatchNormGradInferenceShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, is_training=False, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradShape3(self):
for is_training in [True, False]:
x_shape = [1, 2, 1, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [2],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)
@test_util.run_deprecated_v1
def testBatchNormGradShape4(self):
for is_training in [True, False]:
x_shape = [5, 7, 11, 4]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)
def testBatchNormGradInferenceShape4(self):
x_shape = [5, 7, 11, 4]
self._runtests(x_shape, is_training=False, gradient_test=True)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test never passed for XLA')
def testBatchNormGradShape5(self):
for is_training in [True, False]:
x_shape = [0, 7, 11, 4]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=False,
data_format='NCHW',
is_training=is_training)
def testBatchNormGradInferenceShape5(self):
x_shape = [0, 7, 11, 4]
self._runtests(x_shape, is_training=False, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape1(self):
x_shape = [1, 1, 6, 1]
self._runtests(x_shape, is_training=True, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape2(self):
x_shape = [1, 1, 6, 2]
self._runtests(x_shape, is_training=True, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape3(self):
x_shape = [1, 2, 1, 6]
self._runtests(x_shape, is_training=True, gradient_test=True)
@test_util.run_deprecated_v1
def testBatchNormGradTrainingShape4(self):
x_shape = [5, 7, 11, 4]
self._runtests(x_shape, is_training=True, gradient_test=True)
@test_util.run_deprecated_v1
@test_util.disable_xla('This test never passed for XLA')
def testBatchNormGradTrainingShape5(self):
x_shape = [0, 7, 11, 4]
self._runtests(x_shape, is_training=True, gradient_test=True)
def _testBatchNormGradGrad(self, config):
shape = config['shape']

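The Bessel's-correction step spelled out in _training_ref above can be sanity-checked numerically (a throwaway NumPy sketch, not part of the test):

import numpy as np

x = np.random.random_sample([2, 2, 6, 3]).astype(np.float32)
n = x.size / 3                                  # elements per channel
batch_var = x.reshape(-1, 3).var(axis=0)        # tf.nn.moments-style (divide by n)
corrected = batch_var * n / max(n - 1.0, 1.0)   # what fused_batch_norm reports
print(np.allclose(corrected, x.reshape(-1, 3).var(axis=0, ddof=1)))  # True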

@ -883,7 +883,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
}
if version == 2:
args["reserve_space_3"] = op.outputs[5]
return grad_fun(**args)
dx, dscale, doffset, _, _ = grad_fun(**args)
else:
pop_mean = op.inputs[3]
pop_var = op.inputs[4]
@ -905,7 +905,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
dx, dscale, doffset, _, _ = grad_fun(**args)
if data_format == b"NCHW":
dx = array_ops.transpose(dx, [0, 3, 1, 2])
return dx, dscale, doffset, None, None
return dx, dscale, doffset, None, None
@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):

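With the branch updated above now unpacking the gradient tuple like the older code path, the mean and variance inputs consistently receive no gradient. A quick way to observe this (a sketch; assumes a TensorFlow build that includes this change):

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
x = tf.constant(np.random.random_sample([2, 4, 4, 3]).astype(np.float32))
scale = tf.ones([3])
offset = tf.zeros([3])
old_mean = tf.zeros([3])
old_var = tf.ones([3])
y, _, _ = tf.nn.fused_batch_norm(
    x, scale, offset, mean=old_mean, variance=old_var,
    exponential_avg_factor=0.6, is_training=True)
# Gradients flow to x, scale and offset; the running-statistics inputs get None.
grads = tf.gradients(y, [x, scale, offset, old_mean, old_var])
print([g is None for g in grads])   # [False, False, False, True, True]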

@ -1432,7 +1432,8 @@ def fused_batch_norm(
epsilon=0.001,
data_format="NHWC",
is_training=True,
name=None):
name=None,
exponential_avg_factor=1.0):
r"""Batch normalization.
@ -1444,22 +1445,49 @@ def fused_batch_norm(
x: Input `Tensor` of 4 dimensions.
scale: A `Tensor` of 1 dimension for scaling.
offset: A `Tensor` of 1 dimension for bias.
mean: A `Tensor` of 1 dimension for population mean used for inference.
variance: A `Tensor` of 1 dimension for population variance
used for inference.
mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
of this argument depends on the value of is_training and
exponential_avg_factor as follows:
is_training==False (inference):
Mean must be a `Tensor` of the same shape as scale containing the
estimated population mean computed during training.
is_training==True and exponential_avg_factor == 1.0:
Mean must be None.
is_training==True and exponential_avg_factor != 1.0:
Mean must be a `Tensor` of the same shape as scale containing the
exponential running mean.
variance: A `Tensor` of 1 dimension for population variance. The shape and
meaning of this argument depends on the value of is_training and
exponential_avg_factor as follows:
is_training==False (inference):
Variance must be a `Tensor` of the same shape as scale containing
the estimated population variance computed during training.
is_training==True and exponential_avg_factor == 1.0:
Variance must be None.
is_training==True and exponential_avg_factor != 1.0:
Variance must be a `Tensor` of the same shape as scale containing
the exponential running variance.
epsilon: A small float number added to the variance of x.
data_format: The data format for x. Either "NHWC" (default) or "NCHW".
is_training: A bool value to specify if the operation is used for
training or inference.
name: A name for this operation (optional).
exponential_avg_factor: A float number (usually between 0 and 1) used
for controlling the decay of the running
population average of mean and variance.
If set to 1.0, the current batch average is
returned.
Returns:
y: A 4D Tensor for the normalized, scaled, offsetted x.
batch_mean: A 1D Tensor for the mean of x.
batch_var: A 1D Tensor for the variance of x.
Raises:
ValueError: If mean or variance is not None when is_training is True.
running_mean: A 1D Tensor for the exponential running mean of x.
The output value is (1 - exponential_avg_factor) * mean +
exponential_avg_factor * batch_mean, where batch_mean
is the mean of the current batch in x.
running_var: A 1D Tensor for the exponential running variance of x.
The output value is (1 - exponential_avg_factor) * variance +
exponential_avg_factor * batch_variance, where batch_variance
is the variance of the current batch in x.
References:
Batch Normalization - Accelerating Deep Network Training by Reducing
@ -1467,24 +1495,44 @@ def fused_batch_norm(
[Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
"""
if is_training and exponential_avg_factor == 1.0:
if (mean is not None) or (variance is not None):
raise ValueError("Both 'mean' and 'variance' must be None when "
"is_training is True and "
"exponential_avg_factor == 1.0.")
else:
if (mean is None) or (variance is None):
raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when "
"is_training is False or "
"exponential_avg_factor != 1.0.")
x = ops.convert_to_tensor(x, name="input")
scale = ops.convert_to_tensor(scale, name="scale")
offset = ops.convert_to_tensor(offset, name="offset")
if is_training:
if (mean is not None) or (variance is not None):
raise ValueError("Both 'mean' and 'variance' must be None "
"if is_training is True.")
if mean is None:
mean = constant_op.constant([])
if variance is None:
variance = constant_op.constant([])
# Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
# prevent exception (see cudnn.h).
min_epsilon = 1.001e-5
epsilon = epsilon if epsilon > min_epsilon else min_epsilon
if compat.forward_compatible(2019, 6, 6):
y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
if compat.forward_compatible(2020, 3, 6):
y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
exponential_avg_factor=exponential_avg_factor,
data_format=data_format,
is_training=is_training,
name=name)
return y, running_mean, running_var
else:
y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
x,
scale,
offset,
@ -1494,23 +1542,7 @@ def fused_batch_norm(
data_format=data_format,
is_training=is_training,
name=name)
return y, batch_mean, batch_var
if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16:
fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2
else:
fused_batch_norm_func = gen_nn_ops._fused_batch_norm # pylint: disable=protected-access
y, batch_mean, batch_var, _, _ = fused_batch_norm_func(
x,
scale,
offset,
mean,
variance,
epsilon=epsilon,
data_format=data_format,
is_training=is_training,
name=name)
return y, batch_mean, batch_var
return y, running_mean, running_var
@tf_export(v1=["nn.batch_norm_with_global_normalization"])

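Putting the new argument contract together, a minimal usage sketch against the tf.compat.v1.nn.fused_batch_norm wrapper updated above (assumes a TensorFlow build that includes this change; shapes and values are illustrative):

import numpy as np
import tensorflow.compat.v1 as tf

x = tf.constant(np.random.random_sample([2, 4, 4, 3]).astype(np.float32))
scale = tf.ones([3])
offset = tf.zeros([3])

# Training with plain batch statistics (old behavior): mean/variance stay None.
y, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, scale, offset, is_training=True)

# Training with a running average: the previous statistics are required and
# the second and third outputs are the blended running statistics.
running_mean = tf.zeros([3])
running_var = tf.ones([3])
y, new_mean, new_var = tf.nn.fused_batch_norm(
    x, scale, offset, mean=running_mean, variance=running_var,
    exponential_avg_factor=0.6, is_training=True)

# Omitting the previous statistics in this mode raises the ValueError added above.
try:
  tf.nn.fused_batch_norm(x, scale, offset,
                         exponential_avg_factor=0.6, is_training=True)
except ValueError as e:
  print(e)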

@ -3607,8 +3607,10 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
void* batch_mean_opaque;
void* batch_var_opaque;
if (!batch_mean->is_null() && !batch_var->is_null()) {
stream->ThenMemZero(batch_mean, batch_mean->size());
stream->ThenMemZero(batch_var, batch_var->size());
if (exponential_average_factor == 1.0) {
stream->ThenMemZero(batch_mean, batch_mean->size());
stream->ThenMemZero(batch_var, batch_var->size());
}
batch_mean_opaque = batch_mean->opaque();
batch_var_opaque = batch_var->opaque();
} else {


@ -210,7 +210,7 @@ tf_module {
}
member_method {
name: "fused_batch_norm"
argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], "
argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\', \'exponential_avg_factor\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\', \'1.0\'], "
}
member_method {
name: "in_top_k"