Support computing exponential running mean and variance in fused_batch_norm.

Simplify unit test code for fused_batch_norm and its gradient.

PiperOrigin-RevId: 297451723
Change-Id: I3ad1db022848cec3cf45a5a560a8f25af75781d4
parent f0ffc49ac2, commit 84f2ec1d60
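The running-average semantics added here can be summarized with a small NumPy sketch (illustrative only; the helper name and sample values are ours, following the formula stated in the updated docstring below: new = (1 - exponential_avg_factor) * old + exponential_avg_factor * batch):

import numpy as np


def running_update(old_stat, batch_stat, exponential_avg_factor):
  """Blend a batch statistic into a running statistic."""
  if exponential_avg_factor == 1.0:
    # Previous behaviour: the op returns the batch statistic unchanged.
    return batch_stat
  return ((1.0 - exponential_avg_factor) * old_stat +
          exponential_avg_factor * batch_stat)


old_mean = np.array([0.0, 1.0], dtype=np.float32)
batch_mean = np.array([2.0, 3.0], dtype=np.float32)
print(running_update(old_mean, batch_mean, 0.6))  # -> [1.2 2.2]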
@@ -23,6 +23,7 @@ import numpy as np
 
 from tensorflow.compiler.tests import test_utils
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.compat import compat
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
@@ -34,10 +35,18 @@ DATA_FORMATS = (
     ("_data_format_NCHW", "NCHW"),
 )
 
+DATA_FORMATS_AND_AVG_FACTORS = (
+    ("_data_format_NHWC_no_averaging", "NHWC", 1.0),
+    ("_data_format_NHWC_averaging", "NHWC", 0.6),
+    ("_data_format_NCHW_no_averaging", "NCHW", 1.0),
+    ("_data_format_NCHW_averaging", "NCHW", 0.6),
+)
+
 
 class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
 
-  def _reference_training(self, x, scale, offset, epsilon, data_format):
+  def _reference_training(self, x, scale, offset, old_mean, old_var, epsilon,
+                          exponential_avg_factor, data_format):
     if data_format != "NHWC":
       raise ValueError("data_format must be NHWC, got %s." % data_format)
     x_square = x * x
@@ -49,6 +58,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
     factor = element_count / max(element_count - 1, 1)
     corrected_var = var * factor
     normalized = (x - mean) / np.sqrt(var + epsilon)
+    if exponential_avg_factor != 1.0:
+      mean = (1.0 -
+              exponential_avg_factor) * old_mean + exponential_avg_factor * mean
+      corrected_var = (1.0 - exponential_avg_factor
+                      ) * old_var + exponential_avg_factor * corrected_var
     return (normalized * scale + offset), mean, var, corrected_var
 
   def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
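The corrected_var above exists because fused_batch_norm reports the Bessel-corrected (unbiased) batch variance, whereas a plain per-channel moment uses the biased n denominator. A quick NumPy check of that relationship (shapes chosen to match the test; names are ours):

import numpy as np

x = np.random.random_sample([2, 2, 6, 3]).astype(np.float32)
element_count = x.size / x.shape[-1]  # samples per channel: N * H * W

biased_var = x.var(axis=(0, 1, 2))                  # denominator n
factor = element_count / max(element_count - 1, 1)  # Bessel's correction
corrected_var = biased_var * factor                 # denominator n - 1

# Identical (up to rounding) to asking NumPy for the unbiased estimate.
np.testing.assert_allclose(
    corrected_var, x.var(axis=(0, 1, 2), ddof=1), rtol=1e-4)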
|
||||
@ -81,9 +95,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
|
||||
offset_val = np.random.random_sample(scale_shape).astype(np.float32)
|
||||
epsilon = 0.001
|
||||
exponential_avg_factor = 1.0
|
||||
data_format_src = "NHWC"
|
||||
y_ref, mean_ref, var_ref, _ = self._reference_training(
|
||||
x_val, scale_val, offset_val, epsilon, data_format_src)
|
||||
x_val, scale_val, offset_val, None, None, epsilon,
|
||||
exponential_avg_factor, data_format_src)
|
||||
|
||||
with self.session() as sess, self.test_scope():
|
||||
# To avoid constant folding
|
||||
@ -114,7 +130,11 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
})
|
||||
self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
|
||||
|
||||
def _testLearning(self, use_gradient_checker, data_format):
|
||||
def _testLearning(self, use_gradient_checker, data_format,
|
||||
exponential_avg_factor):
|
||||
if not compat.forward_compatible(2020, 3,
|
||||
6) and exponential_avg_factor != 1.0:
|
||||
self.skipTest("running average not available.")
|
||||
channel = 3
|
||||
x_shape = [2, 2, 6, channel]
|
||||
scale_shape = [channel]
|
||||
@ -122,13 +142,14 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
|
||||
offset_val = np.random.random_sample(scale_shape).astype(np.float32)
|
||||
mean_val = np.random.random_sample(scale_shape).astype(np.float32)
|
||||
var_val = np.random.random_sample(scale_shape).astype(np.float32)
|
||||
var_val_corr = np.random.random_sample(scale_shape).astype(np.float32)
|
||||
epsilon = 0.001
|
||||
data_format_src = "NHWC"
|
||||
# When in training mode, fused_batchnorm applies an implicit Bessel's
|
||||
# correction. So we have to use the corrected variance here, as well.
|
||||
y_ref, mean_ref, _, var_ref_corr = self._reference_training(
|
||||
x_val, scale_val, offset_val, epsilon, data_format_src)
|
||||
x_val, scale_val, offset_val, mean_val, var_val_corr, epsilon,
|
||||
exponential_avg_factor, data_format_src)
|
||||
|
||||
with self.session() as sess, self.test_scope():
|
||||
# To avoid constant folding
|
||||
@ -142,15 +163,38 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
|
||||
offset = array_ops.placeholder(
|
||||
np.float32, shape=scale_shape, name="offset")
|
||||
if exponential_avg_factor == 1.0:
|
||||
old_mean = None
|
||||
old_var = None
|
||||
else:
|
||||
old_mean = array_ops.placeholder(
|
||||
np.float32, shape=scale_shape, name="old_mean")
|
||||
old_var = array_ops.placeholder(
|
||||
np.float32, shape=scale_shape, name="old_var")
|
||||
y, mean, var = nn.fused_batch_norm(
|
||||
t_val,
|
||||
scale,
|
||||
offset,
|
||||
mean=None,
|
||||
variance=None,
|
||||
mean=old_mean,
|
||||
variance=old_var,
|
||||
epsilon=epsilon,
|
||||
exponential_avg_factor=exponential_avg_factor,
|
||||
data_format=data_format,
|
||||
is_training=True)
|
||||
if exponential_avg_factor == 1.0:
|
||||
feed_dict = {
|
||||
t_val: x_val_converted,
|
||||
scale: scale_val,
|
||||
offset: offset_val,
|
||||
}
|
||||
else:
|
||||
feed_dict = {
|
||||
t_val: x_val_converted,
|
||||
scale: scale_val,
|
||||
offset: offset_val,
|
||||
old_mean: mean_val,
|
||||
old_var: var_val_corr
|
||||
}
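Stripped of the placeholder and session plumbing, the call pattern exercised by this test reduces to the following eager-mode sketch (hypothetical shapes; assumes a TensorFlow build that already includes this change):

import tensorflow as tf

x = tf.random.normal([2, 2, 6, 3])
scale = tf.ones([3])
offset = tf.zeros([3])
running_mean = tf.zeros([3])
running_var = tf.ones([3])

# With exponential_avg_factor != 1.0 the current running statistics are
# passed in even though is_training=True, and the op returns the blended
# running mean/variance instead of the raw batch statistics.
y, new_running_mean, new_running_var = tf.compat.v1.nn.fused_batch_norm(
    x,
    scale,
    offset,
    mean=running_mean,
    variance=running_var,
    epsilon=0.001,
    exponential_avg_factor=0.6,
    data_format="NHWC",
    is_training=True)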
|
||||
# Check gradient.
|
||||
if use_gradient_checker:
|
||||
err = gradient_checker.compute_gradient_error(
|
||||
@ -158,29 +202,22 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
x_val_converted.shape,
|
||||
y,
|
||||
x_val_converted.shape,
|
||||
extra_feed_dict={
|
||||
t_val: x_val_converted,
|
||||
scale: scale_val,
|
||||
offset: offset_val
|
||||
})
|
||||
extra_feed_dict=feed_dict)
|
||||
self.assertLess(err, 1e-3)
|
||||
|
||||
y_val, mean_val, var_val = sess.run([y, mean, var], {
|
||||
t_val: x_val_converted,
|
||||
scale: scale_val,
|
||||
offset: offset_val
|
||||
})
|
||||
self.assertAllClose(mean_val, mean_ref, atol=1e-3)
|
||||
self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
|
||||
self.assertAllClose(var_val, var_ref_corr, atol=1e-3)
|
||||
y_tf, mean_tf, var_tf = sess.run([y, mean, var], feed_dict)
|
||||
self.assertAllClose(y_tf, y_ref_converted, atol=1e-3)
|
||||
self.assertAllClose(mean_tf, mean_ref, atol=1e-3)
|
||||
self.assertAllClose(var_tf, var_ref_corr, atol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(*DATA_FORMATS)
|
||||
def testLearning(self, data_format):
|
||||
self._testLearning(False, data_format)
|
||||
@parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS)
|
||||
def testLearning(self, data_format, exponential_avg_factor):
|
||||
self._testLearning(False, data_format, exponential_avg_factor)
|
||||
|
||||
@parameterized.named_parameters(*DATA_FORMATS)
|
||||
def testLearningWithGradientChecker(self, data_format):
|
||||
self._testLearning(True, data_format)
|
||||
@parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS)
|
||||
def testLearningWithGradientChecker(self, data_format,
|
||||
exponential_avg_factor):
|
||||
self._testLearning(True, data_format, exponential_avg_factor)
|
||||
|
||||
@parameterized.named_parameters(*DATA_FORMATS)
|
||||
def testGradientTraining(self, data_format):
|
||||
|
@ -37,6 +37,8 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
: XlaOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr("exponential_avg_factor", &exponential_avg_factor_));
|
||||
string data_format_str;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
|
||||
OP_REQUIRES(
|
||||
@ -105,7 +107,6 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
ctx->SetOutput(0, converted);
|
||||
}
|
||||
|
||||
ctx->SetOutput(1, xla::GetTupleElement(output, 1));
|
||||
xla::XlaOp variance = xla::GetTupleElement(output, 2);
|
||||
// Apply Bessel's correction.
|
||||
int total_input_size = ctx->InputShape(0).num_elements();
|
||||
@ -121,7 +122,7 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
if (input_shape.num_elements() == 0) {
|
||||
auto status_or_output_shape = b->GetShape(corrected);
|
||||
OP_REQUIRES_OK(ctx, status_or_output_shape.status());
|
||||
|
||||
ctx->SetOutput(1, xla::GetTupleElement(output, 1));
|
||||
ctx->SetOutput(
|
||||
kVarianceOutputIndex,
|
||||
xla::Broadcast(
|
||||
@@ -130,7 +131,27 @@ class FusedBatchNormOp : public XlaOpKernel {
               status_or_output_shape.ValueOrDie().dimensions())));
 
     } else {
-      ctx->SetOutput(2, corrected);
+      if (exponential_avg_factor_ == 1.0f) {
+        ctx->SetOutput(1, xla::GetTupleElement(output, 1));
+        ctx->SetOutput(2, corrected);
+      } else {
+        xla::XlaOp old_mean = ctx->Input(3);
+        xla::XlaOp alpha =
+            xla::ScalarLike(old_mean, 1.0f - exponential_avg_factor_);
+        xla::XlaOp beta = xla::ScalarLike(old_mean, exponential_avg_factor_);
+        // new_running_mean = alpha * old_mean + beta * batch_mean.
+        xla::XlaOp new_running_mean =
+            xla::Add(xla::Mul(old_mean, alpha),
+                     xla::Mul(xla::GetTupleElement(output, 1), beta));
+        ctx->SetOutput(1, new_running_mean);
+
+        xla::XlaOp old_variance = ctx->Input(4);
+        xla::XlaOp new_running_variance = xla::Add(
+            xla::Mul(old_variance, alpha), xla::Mul(corrected, beta));
+        // new_running_variance = alpha * old_variance + beta *
+        // batch_variance.
+        ctx->SetOutput(2, new_running_variance);
+      }
     }
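In scalar terms the lowering above applies the same blend as the Python reference, but to the Bessel-corrected batch variance. A NumPy restatement (names are ours):

import numpy as np


def xla_style_blend(old_mean, old_var, batch_mean, corrected_var, factor):
  # alpha/beta mirror the ScalarLike constants built in the kernel above.
  alpha = 1.0 - factor
  beta = factor
  new_running_mean = alpha * old_mean + beta * batch_mean
  new_running_var = alpha * old_var + beta * corrected_var
  return new_running_mean, new_running_var


print(xla_style_blend(np.float32(0.0), np.float32(1.0),
                      np.float32(2.0), np.float32(4.0), 0.6))  # (1.2, 2.8)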
|
||||
|
||||
// Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
|
||||
@ -175,6 +196,7 @@ class FusedBatchNormOp : public XlaOpKernel {
|
||||
float epsilon_;
|
||||
TensorFormat data_format_;
|
||||
bool is_training_;
|
||||
float exponential_avg_factor_;
|
||||
bool add_side_input_;
|
||||
bool apply_relu_;
|
||||
bool is_on_gpu_;
|
||||
|
@ -787,6 +787,7 @@ cc_library(
|
||||
"//tensorflow/core/util:padding",
|
||||
"//tensorflow/core/util:tensor_format",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
@@ -1083,7 +1083,11 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
 
   bool is_training;
   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
-  int number_inputs = (is_training) ? 3 : 5;
+  float exponential_avg_factor;
+  if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
+    exponential_avg_factor = 1.0f;  // default value
+  }
+  int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
   string data_format_str;
   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   TensorFormat data_format;
@@ -1176,16 +1180,8 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
   c->set_output(0, x_backprop);
   c->set_output(1, c->Vector(channel_dim));
   c->set_output(2, c->Vector(channel_dim));
-  // Set the correct shapes for reserve_spaces
-  // so that gradients can be performed when
-  // the op is in a symbolic condition.
-  if (is_training) {
-    c->set_output(3, c->Vector(0));
-    c->set_output(4, c->Vector(0));
-  } else {
-    c->set_output(3, c->Vector(channel_dim));
-    c->set_output(4, c->Vector(channel_dim));
-  }
+  c->set_output(3, c->Vector(0));
+  c->set_output(4, c->Vector(0));
   return Status::OK();
 }
@ -2326,7 +2322,7 @@ Status SparseReduceShapeFn(InferenceContext* c) {
|
||||
auto axes_vec = axes_tensor->flat<int32>();
|
||||
|
||||
int64 ndims = shape_vec.size();
|
||||
std::unordered_set<int64> axes;
|
||||
absl::flat_hash_set<int64> axes;
|
||||
for (int i = 0; i < axes_vec.size(); i++) {
|
||||
axes.insert((axes_vec(i) + ndims) % ndims);
|
||||
}
|
||||
|
@ -528,9 +528,9 @@ TEST(CommonShapeFnsTest, FusedBatchNormExTest) {
|
||||
.Finalize(&op.node_def));
|
||||
|
||||
// Channels are not multiple of 4.
|
||||
INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[0];[0]");
|
||||
INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[2];[2]");
|
||||
|
||||
INFER_OK(op, "[2,2,2,4];[4];[4];[0];[0]",
|
||||
INFER_OK(op, "[2,2,2,4];[4];[4];[4];[4]",
|
||||
"[d0_0,d0_1,d0_2,d0_3];[d0_3];[d0_3];[d0_3];[d0_3];?");
|
||||
}
|
||||
|
||||
|
@ -1923,7 +1923,7 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
tf_cuda_cc_test(
|
||||
name = "fused_batch_norm_op_test",
|
||||
size = "small",
|
||||
srcs = ["fused_batch_norm_op_test.cc"],
|
||||
|
@@ -929,6 +929,28 @@ struct FusedBatchNorm<GPUDevice, T, U, is_training> {
       workspace_allocator.reset(
           new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
     }
+    // For a running-average update (exponential_avg_factor != 1.0), the
+    // previous running mean/variance must first be copied into the output
+    // buffers unless those buffers are already shared with the inputs.
+    if (!batch_mean->SharesBufferWith(estimated_mean) &&
+        exponential_avg_factor != 1.0f) {
+      OP_REQUIRES(
+          context,
+          stream
+              ->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr,
+                              estimated_mean.NumElements() * sizeof(U))
+              .ok(),
+          errors::Internal("FusedBatchNorm: failed to copy the previous "
+                           "running mean from device"));
+    }
+    if (!batch_var->SharesBufferWith(estimated_variance) &&
+        exponential_avg_factor != 1.0f) {
+      OP_REQUIRES(
+          context,
+          stream
+              ->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr,
+                              estimated_variance.NumElements() * sizeof(U))
+              .ok(),
+          errors::Internal("FusedBatchNorm: failed to copy the previous "
+                           "running variance from device"));
+    }
     bool cudnn_launch_status =
         stream
             ->ThenBatchNormalizationForward(
@@ -1263,11 +1285,11 @@ class FusedBatchNormOpBase : public OpKernel {
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {0}, 0, x.shape(), &y));
     Tensor* batch_mean = nullptr;
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(1, scale.shape(), &batch_mean));
+    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+                                {3}, 1, scale.shape(), &batch_mean));
     Tensor* batch_var = nullptr;
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(2, scale.shape(), &batch_var));
+    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+                                {4}, 2, scale.shape(), &batch_var));
     Tensor* saved_mean = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(3, scale.shape(), &saved_mean));
@ -1400,12 +1422,9 @@ class FusedBatchNormGradOpBase : public OpKernel {
|
||||
Tensor* placeholder_1 = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output(3, TensorShape({0}), &placeholder_1));
|
||||
functor::SetZeroFunctor<Device, float> f;
|
||||
f(context->eigen_device<Device>(), placeholder_1->flat<U>());
|
||||
Tensor* placeholder_2 = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output(4, TensorShape({0}), &placeholder_2));
|
||||
f(context->eigen_device<Device>(), placeholder_2->flat<U>());
|
||||
|
||||
// If input is empty, set gradients w.r.t scale/offset to zero.
|
||||
if (x.shape().num_elements() == 0) {
|
||||
|
@ -181,7 +181,8 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
|
||||
TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
|
||||
ShapeInferenceTestOp op("FusedBatchNorm");
|
||||
|
||||
auto set_op = [&op](bool is_training, string data_format) {
|
||||
auto set_op = [&op](bool is_training, float exponential_avg_factor,
|
||||
string data_format) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
@ -190,10 +191,11 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("data_format", data_format)
|
||||
.Attr("is_training", is_training)
|
||||
.Attr("exponential_avg_factor", exponential_avg_factor)
|
||||
.Finalize(&op.node_def));
|
||||
};
|
||||
|
||||
set_op(true, "NHWC");
|
||||
set_op(true, 1.0, "NHWC");
|
||||
// Test rank errors.
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
|
||||
@ -207,7 +209,21 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
|
||||
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
|
||||
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
|
||||
|
||||
set_op(true, "NCHW");
|
||||
set_op(true, 0.5, "NHWC");
|
||||
// Test rank errors.
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
|
||||
// Channel dim of first input is merged with the single dim in other 4 inputs.
|
||||
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
|
||||
INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
|
||||
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
|
||||
INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
|
||||
"[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
|
||||
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
|
||||
"[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
|
||||
|
||||
set_op(true, 1.0, "NCHW");
|
||||
// Test rank errors.
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
|
||||
@ -221,7 +237,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
|
||||
"[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];"
|
||||
"[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]");
|
||||
|
||||
set_op(false, "NHWC");
|
||||
set_op(false, 1.0, "NHWC");
|
||||
// Test rank errors.
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
|
||||
@ -239,7 +255,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
|
||||
"[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];"
|
||||
"[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]");
|
||||
|
||||
set_op(false, "NCHW");
|
||||
set_op(false, 1.0, "NCHW");
|
||||
// Test rank errors.
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
|
||||
@ -271,22 +287,6 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
|
||||
.Finalize(&op.node_def));
|
||||
};
|
||||
|
||||
set_op("NHWC");
|
||||
// Test rank errors.
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
|
||||
// Channel dim of first input is merged with the single dim in other 4 inputs.
|
||||
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
|
||||
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
|
||||
INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
|
||||
INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
|
||||
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
|
||||
"[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
|
||||
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
|
||||
|
||||
set_op("NCHW");
|
||||
// Test rank errors.
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
|
||||
@ -302,6 +302,22 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
|
||||
INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]",
|
||||
"[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];"
|
||||
"[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]");
|
||||
|
||||
set_op("NHWC");
|
||||
// Test rank errors.
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
|
||||
INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
|
||||
// Channel dim of first input is merged with the single dim in other 4 inputs.
|
||||
INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
|
||||
INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
|
||||
INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
|
||||
INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
|
||||
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
|
||||
"[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
|
||||
"[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
|
||||
}
|
||||
|
||||
TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {
|
||||
|
@ -5015,7 +5015,7 @@ cuda_py_test(
|
||||
size = "large",
|
||||
srcs = ["ops/nn_fused_batchnorm_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 16,
|
||||
shard_count = 24,
|
||||
deps = [
|
||||
":array_ops",
|
||||
":client_testlib",
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -59,6 +60,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
scale_shape,
|
||||
scale_dtype,
|
||||
use_gpu=True,
|
||||
exponential_avg_factor=1.0,
|
||||
data_format='NHWC'):
|
||||
np.random.seed(1)
|
||||
x_val = np.random.random_sample(x_shape).astype(x_dtype)
|
||||
@ -81,6 +83,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
mean=mean,
|
||||
variance=var,
|
||||
epsilon=epsilon,
|
||||
exponential_avg_factor=exponential_avg_factor,
|
||||
data_format=data_format,
|
||||
is_training=False)
|
||||
y_val = self.evaluate(y)
|
||||
@ -92,17 +95,37 @@ class BatchNormalizationTest(test.TestCase):
|
||||
atol = 2e-3 if x_dtype == np.float16 else 1e-3
|
||||
self.assertAllClose(y_ref, y_val, atol=atol)
|
||||
|
||||
def _training_ref(self, x, scale, offset, epsilon, data_format):
|
||||
def _running_mean(self, old_mean, new_val, factor):
|
||||
if factor == 1.0:
|
||||
return new_val
|
||||
else:
|
||||
return (1.0 - factor) * old_mean + factor * new_val
|
||||
|
||||
def _training_ref(self, x, scale, offset, old_mean, old_var,
|
||||
exponential_avg_factor, epsilon, data_format):
|
||||
if data_format not in ['NHWC', 'NCHW']:
|
||||
raise ValueError('data_format must be NCHW or NHWC, '
|
||||
'got %s.' % data_format)
|
||||
if data_format == 'NCHW':
|
||||
x = array_ops.transpose(x, [0, 2, 3, 1])
|
||||
mean, var = nn_impl.moments(
|
||||
batch_mean, batch_var = nn_impl.moments(
|
||||
math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
|
||||
y = self._batch_norm(x, mean, var, offset, scale, epsilon)
|
||||
|
||||
y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
|
||||
if data_format == 'NCHW':
|
||||
y = array_ops.transpose(y, [0, 3, 1, 2])
|
||||
|
||||
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
|
||||
# the denominator in the formula to calculate variance, while
|
||||
# tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
|
||||
sample_size = math_ops.cast(
|
||||
array_ops.size(x) / array_ops.size(scale), scale.dtype)
|
||||
batch_var_corrected = batch_var * sample_size / (
|
||||
math_ops.maximum(sample_size - 1.0, 1.0))
|
||||
|
||||
mean = self._running_mean(old_mean, batch_mean, exponential_avg_factor)
|
||||
var = self._running_mean(old_var, batch_var_corrected,
|
||||
exponential_avg_factor)
|
||||
return self.evaluate(y), self.evaluate(mean), self.evaluate(var)
|
||||
|
||||
def _test_training(self,
|
||||
@ -111,11 +134,19 @@ class BatchNormalizationTest(test.TestCase):
|
||||
scale_shape,
|
||||
scale_dtype,
|
||||
use_gpu=True,
|
||||
exponential_avg_factor=1.0,
|
||||
data_format='NHWC'):
|
||||
np.random.seed(1)
|
||||
x_val = np.random.random_sample(x_shape).astype(x_dtype)
|
||||
scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
|
||||
offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
|
||||
if exponential_avg_factor == 1.0:
|
||||
old_mean_val = None
|
||||
old_var_val = None
|
||||
else:
|
||||
old_mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
|
||||
old_var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
|
||||
|
||||
with self.cached_session(use_gpu=use_gpu) as sess:
|
||||
x = constant_op.constant(x_val, name='x')
|
||||
scale = constant_op.constant(scale_val, name='scale')
|
||||
@ -125,20 +156,20 @@ class BatchNormalizationTest(test.TestCase):
|
||||
x,
|
||||
scale,
|
||||
offset,
|
||||
mean=old_mean_val,
|
||||
variance=old_var_val,
|
||||
epsilon=epsilon,
|
||||
exponential_avg_factor=exponential_avg_factor,
|
||||
data_format=data_format,
|
||||
is_training=True)
|
||||
y_val, mean_val, var_val = self.evaluate([y, mean, var])
|
||||
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
|
||||
data_format)
|
||||
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset,
|
||||
old_mean_val, old_var_val,
|
||||
exponential_avg_factor,
|
||||
epsilon, data_format)
|
||||
y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
|
||||
self.assertAllClose(y_ref, y_val, atol=y_atol)
|
||||
self.assertAllClose(mean_ref, mean_val, atol=1e-3)
|
||||
# This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
|
||||
# the denominator in the formula to calculate variance, while
|
||||
# tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
|
||||
sample_size = x_val.size / scale_val.size
|
||||
var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
|
||||
self.assertAllClose(var_ref, var_val, atol=1e-3)
|
||||
|
||||
def _compute_gradient_error_float16(self, x, x32, x_shape, y, y32, y_shape):
|
||||
@ -184,6 +215,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
scale_shape,
|
||||
scale_dtype,
|
||||
use_gpu=True,
|
||||
exponential_avg_factor=1.0,
|
||||
data_format='NHWC',
|
||||
is_training=True):
|
||||
np.random.seed(1)
|
||||
@ -195,7 +227,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
x = constant_op.constant(x_val, name='x')
|
||||
scale = constant_op.constant(scale_val, name='scale')
|
||||
offset = constant_op.constant(offset_val, name='offset')
|
||||
if is_training:
|
||||
if is_training and exponential_avg_factor == 1.0:
|
||||
pop_mean = None
|
||||
pop_var = None
|
||||
else:
|
||||
@ -207,6 +239,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
offset,
|
||||
mean=pop_mean,
|
||||
variance=pop_var,
|
||||
exponential_avg_factor=exponential_avg_factor,
|
||||
data_format=data_format,
|
||||
is_training=is_training)
|
||||
if x_dtype != np.float16:
|
||||
@ -224,6 +257,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
mean=pop_mean,
|
||||
variance=pop_var,
|
||||
data_format=data_format,
|
||||
exponential_avg_factor=exponential_avg_factor,
|
||||
is_training=is_training)
|
||||
err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
|
||||
x_shape)
|
||||
@ -244,6 +278,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
scale_shape,
|
||||
scale_dtype,
|
||||
use_gpu=True,
|
||||
exponential_avg_factor=1.0,
|
||||
data_format='NHWC',
|
||||
is_training=True,
|
||||
err_tolerance=1e-3):
|
||||
@ -258,7 +293,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
grad_y = constant_op.constant(grad_y_val, name='grad_y')
|
||||
scale = constant_op.constant(scale_val, name='scale')
|
||||
offset = constant_op.constant(offset_val, name='offset')
|
||||
if is_training:
|
||||
if is_training and exponential_avg_factor == 1.0:
|
||||
pop_mean = None
|
||||
pop_var = None
|
||||
else:
|
||||
@ -270,6 +305,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
offset,
|
||||
mean=pop_mean,
|
||||
variance=pop_var,
|
||||
exponential_avg_factor=exponential_avg_factor,
|
||||
data_format=data_format,
|
||||
is_training=is_training)
|
||||
grad_x, grad_scale, grad_offset = gradients_impl.gradients(
|
||||
@ -311,6 +347,7 @@ class BatchNormalizationTest(test.TestCase):
|
||||
offset,
|
||||
mean=pop_mean,
|
||||
variance=pop_var,
|
||||
exponential_avg_factor=exponential_avg_factor,
|
||||
data_format=data_format,
|
||||
is_training=is_training)
|
||||
grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(
|
||||
@ -339,282 +376,146 @@ class BatchNormalizationTest(test.TestCase):
|
||||
self.assertLess(err_grad_x_2, err_tolerance)
|
||||
self.assertLess(err_grad_scale, err_tolerance)
|
||||
|
||||
def _runtests(self, x_shape, is_training, gradient_test=False):
|
||||
use_gpu_vals = [False]
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
use_gpu_vals += [True]
|
||||
factors = [
|
||||
1.0,
|
||||
]
|
||||
if compat.forward_compatible(2020, 3, 6):
|
||||
factors += [
|
||||
0.6,
|
||||
]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
for use_gpu in use_gpu_vals:
|
||||
for data_format in ['NHWC', 'NCHW']:
|
||||
if data_format == 'NHWC':
|
||||
scale_shape = x_shape[-1:]
|
||||
else:
|
||||
scale_shape = x_shape[1:2]
|
||||
for exponential_avg_factor in factors:
|
||||
if gradient_test:
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype,
|
||||
scale_shape,
|
||||
np.float32,
|
||||
use_gpu=use_gpu,
|
||||
data_format=data_format,
|
||||
is_training=is_training,
|
||||
exponential_avg_factor=exponential_avg_factor)
|
||||
else:
|
||||
if is_training:
|
||||
self._test_training(
|
||||
x_shape,
|
||||
dtype,
|
||||
scale_shape,
|
||||
np.float32,
|
||||
use_gpu=use_gpu,
|
||||
data_format=data_format,
|
||||
exponential_avg_factor=exponential_avg_factor)
|
||||
else:
|
||||
self._test_inference(
|
||||
x_shape,
|
||||
dtype,
|
||||
scale_shape,
|
||||
np.float32,
|
||||
use_gpu=use_gpu,
|
||||
data_format=data_format,
|
||||
exponential_avg_factor=exponential_avg_factor)
|
||||
|
||||
def testInferenceShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(
|
||||
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
|
||||
self._runtests(x_shape, False)
|
||||
|
||||
def testInferenceShape2(self):
|
||||
x_shape = [1, 1, 6, 2]
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
for dtype in [np.float16, np.float32]:
|
||||
self._test_inference(
|
||||
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
|
||||
self._runtests(x_shape, False)
|
||||
|
||||
def testInferenceShape3(self):
|
||||
x_shape = [1, 2, 1, 6]
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
for dtype in [np.float16, np.float32]:
|
||||
self._test_inference(
|
||||
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
|
||||
self._runtests(x_shape, False)
|
||||
|
||||
def testInferenceShape4(self):
|
||||
x_shape = [27, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(
|
||||
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
self._runtests(x_shape, False)
|
||||
|
||||
def testInferenceShape5(self):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
self._runtests(x_shape, False)
|
||||
|
||||
def testTrainingShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(
|
||||
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_training(
|
||||
x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
|
||||
self._test_training(
|
||||
x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
|
||||
self._runtests(x_shape, True)
|
||||
|
||||
def testTrainingShape2(self):
|
||||
x_shape = [1, 1, 6, 2]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(
|
||||
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_training(
|
||||
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
|
||||
self._runtests(x_shape, True)
|
||||
|
||||
def testTrainingShape3(self):
|
||||
x_shape = [1, 2, 1, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(
|
||||
x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NCHW')
|
||||
self._runtests(x_shape, True)
|
||||
|
||||
def testTrainingShape4(self):
|
||||
x_shape = [27, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(
|
||||
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_training(
|
||||
x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
self._runtests(x_shape, True)
|
||||
|
||||
@test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
|
||||
def testTrainingShape5(self):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_training(
|
||||
x_shape,
|
||||
dtype, [131],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
self._runtests(x_shape, True)
|
||||
|
||||
def testBatchNormGradInferenceShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradShape1(self):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [1, 1, 6, 1]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [1],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [1],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [1],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [1],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
def testBatchNormGradInferenceShape2(self):
|
||||
x_shape = [1, 1, 6, 2]
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradShape2(self):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [1, 1, 6, 2]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [2],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [2],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
def testBatchNormGradInferenceShape3(self):
|
||||
x_shape = [1, 2, 1, 6]
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradShape3(self):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [1, 2, 1, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [2],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [2],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradShape4(self):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [5, 7, 11, 4]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [7],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [7],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
def testBatchNormGradInferenceShape4(self):
|
||||
x_shape = [5, 7, 11, 4]
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test never passed for XLA')
|
||||
def testBatchNormGradShape5(self):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [0, 7, 11, 4]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [7],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [7],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
def testBatchNormGradInferenceShape5(self):
|
||||
x_shape = [0, 7, 11, 4]
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradTrainingShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradTrainingShape2(self):
|
||||
x_shape = [1, 1, 6, 2]
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradTrainingShape3(self):
|
||||
x_shape = [1, 2, 1, 6]
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradTrainingShape4(self):
|
||||
x_shape = [5, 7, 11, 4]
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.disable_xla('This test never passed for XLA')
|
||||
def testBatchNormGradTrainingShape5(self):
|
||||
x_shape = [0, 7, 11, 4]
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True)
|
||||
|
||||
def _testBatchNormGradGrad(self, config):
|
||||
shape = config['shape']
|
||||
|
@ -883,7 +883,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
|
||||
}
|
||||
if version == 2:
|
||||
args["reserve_space_3"] = op.outputs[5]
|
||||
return grad_fun(**args)
|
||||
dx, dscale, doffset, _, _ = grad_fun(**args)
|
||||
else:
|
||||
pop_mean = op.inputs[3]
|
||||
pop_var = op.inputs[4]
|
||||
@ -905,7 +905,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad):
|
||||
dx, dscale, doffset, _, _ = grad_fun(**args)
|
||||
if data_format == b"NCHW":
|
||||
dx = array_ops.transpose(dx, [0, 3, 1, 2])
|
||||
return dx, dscale, doffset, None, None
|
||||
return dx, dscale, doffset, None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("FusedBatchNorm")
|
||||
def _FusedBatchNormGrad(op, *grad):
|
||||
|
@ -1432,7 +1432,8 @@ def fused_batch_norm(
|
||||
epsilon=0.001,
|
||||
data_format="NHWC",
|
||||
is_training=True,
|
||||
name=None):
|
||||
name=None,
|
||||
exponential_avg_factor=1.0):
|
||||
r"""Batch normalization.
|
||||
|
||||
|
||||
@ -1444,22 +1445,49 @@ def fused_batch_norm(
|
||||
x: Input `Tensor` of 4 dimensions.
|
||||
scale: A `Tensor` of 1 dimension for scaling.
|
||||
offset: A `Tensor` of 1 dimension for bias.
|
||||
mean: A `Tensor` of 1 dimension for population mean used for inference.
|
||||
variance: A `Tensor` of 1 dimension for population variance
|
||||
used for inference.
|
||||
mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
|
||||
of this argument depends on the value of is_training and
|
||||
exponential_avg_factor as follows:
|
||||
is_training==False (inference):
|
||||
Mean must be a `Tensor` of the same shape as scale containing the
|
||||
estimated population mean computed during training.
|
||||
is_training==True and exponential_avg_factor == 1.0:
|
||||
Mean must be None.
|
||||
is_training==True and exponential_avg_factor != 1.0:
|
||||
Mean must be a `Tensor` of the same shape as scale containing the
|
||||
exponential running mean.
|
||||
variance: A `Tensor` of 1 dimension for population variance. The shape and
|
||||
meaning of this argument depends on the value of is_training and
|
||||
exponential_avg_factor as follows:
|
||||
is_training==False (inference):
|
||||
Variance must be a `Tensor` of the same shape as scale containing
|
||||
the estimated population variance computed during training.
|
||||
is_training==True and exponential_avg_factor == 1.0:
|
||||
Variance must be None.
|
||||
is_training==True and exponential_avg_factor != 1.0:
|
||||
Variance must be a `Tensor` of the same shape as scale containing
|
||||
the exponential running variance.
|
||||
epsilon: A small float number added to the variance of x.
|
||||
data_format: The data format for x. Either "NHWC" (default) or "NCHW".
|
||||
is_training: A bool value to specify if the operation is used for
|
||||
training or inference.
|
||||
name: A name for this operation (optional).
|
||||
exponential_avg_factor: A float number (usually between 0 and 1) used
|
||||
for controlling the decay of the running
|
||||
population average of mean and variance.
|
||||
If set to 1.0, the current batch average is
|
||||
returned.
|
||||
|
||||
Returns:
|
||||
y: A 4D Tensor for the normalized, scaled, offsetted x.
|
||||
batch_mean: A 1D Tensor for the mean of x.
|
||||
batch_var: A 1D Tensor for the variance of x.
|
||||
|
||||
Raises:
|
||||
ValueError: If mean or variance is not None when is_training is True.
|
||||
running_mean: A 1D Tensor for the exponential running mean of x.
|
||||
The output value is (1 - exponential_avg_factor) * mean +
|
||||
exponential_avg_factor * batch_mean), where batch_mean
|
||||
is the mean of the current batch in x.
|
||||
running_var: A 1D Tensor for the exponential running variance
|
||||
The output value is (1 - exponential_avg_factor) * variance +
|
||||
exponential_avg_factor * batch_variance), where batch_variance
|
||||
is the variance of the current batch in x.
|
||||
|
||||
References:
|
||||
Batch Normalization - Accelerating Deep Network Training by Reducing
|
||||
@@ -1467,24 +1495,44 @@ def fused_batch_norm(
     [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
     ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
   """
+  if is_training and exponential_avg_factor == 1.0:
+    if (mean is not None) or (variance is not None):
+      raise ValueError("Both 'mean' and 'variance' must be None when "
+                       "is_training is True and "
+                       "exponential_avg_factor == 1.0.")
+  else:
+    if (mean is None) or (variance is None):
+      raise ValueError("Both 'mean' and 'variance' must be a 1D tensor when "
+                       "is_training is False or "
+                       "exponential_avg_factor != 1.0.")
   x = ops.convert_to_tensor(x, name="input")
   scale = ops.convert_to_tensor(scale, name="scale")
   offset = ops.convert_to_tensor(offset, name="offset")
-  if is_training:
-    if (mean is not None) or (variance is not None):
-      raise ValueError("Both 'mean' and 'variance' must be None "
-                       "if is_training is True.")
   if mean is None:
     mean = constant_op.constant([])
   if variance is None:
     variance = constant_op.constant([])
 
   # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
   # prevent exception (see cudnn.h).
   min_epsilon = 1.001e-5
   epsilon = epsilon if epsilon > min_epsilon else min_epsilon
 
-  if compat.forward_compatible(2019, 6, 6):
-    y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
+  if compat.forward_compatible(2020, 3, 6):
+    y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
         x,
         scale,
         offset,
         mean,
         variance,
         epsilon=epsilon,
+        exponential_avg_factor=exponential_avg_factor,
         data_format=data_format,
         is_training=is_training,
         name=name)
+    return y, running_mean, running_var
+  else:
+    y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
+        x,
+        scale,
+        offset,
@@ -1494,23 +1542,7 @@ def fused_batch_norm(
         data_format=data_format,
         is_training=is_training,
         name=name)
-    return y, batch_mean, batch_var
-
-  if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16:
-    fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2
-  else:
-    fused_batch_norm_func = gen_nn_ops._fused_batch_norm  # pylint: disable=protected-access
-  y, batch_mean, batch_var, _, _ = fused_batch_norm_func(
-      x,
-      scale,
-      offset,
-      mean,
-      variance,
-      epsilon=epsilon,
-      data_format=data_format,
-      is_training=is_training,
-      name=name)
-  return y, batch_mean, batch_var
+    return y, running_mean, running_var
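A sketch of how a training loop might consume the new return values (illustrative; variable names are ours, and it assumes a TensorFlow release that carries exponential_avg_factor). Note the validation above: with is_training=True, the running statistics must be supplied exactly when exponential_avg_factor != 1.0.

import tensorflow as tf

moving_mean = tf.Variable(tf.zeros([3]), trainable=False)
moving_var = tf.Variable(tf.ones([3]), trainable=False)
scale = tf.Variable(tf.ones([3]))
offset = tf.Variable(tf.zeros([3]))


def train_step(x, momentum=0.99):
  # exponential_avg_factor = 1 - momentum folds the usual moving-average
  # update into the fused op, so no separate assign_sub is needed.
  y, new_mean, new_var = tf.compat.v1.nn.fused_batch_norm(
      x,
      scale,
      offset,
      mean=moving_mean,
      variance=moving_var,
      exponential_avg_factor=1.0 - momentum,
      is_training=True)
  moving_mean.assign(new_mean)
  moving_var.assign(new_var)
  return y


y = train_step(tf.random.normal([8, 4, 4, 3]))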
|
||||
|
||||
|
||||
@tf_export(v1=["nn.batch_norm_with_global_normalization"])
|
||||
|
@@ -3607,8 +3607,10 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
   void* batch_mean_opaque;
   void* batch_var_opaque;
   if (!batch_mean->is_null() && !batch_var->is_null()) {
-    stream->ThenMemZero(batch_mean, batch_mean->size());
-    stream->ThenMemZero(batch_var, batch_var->size());
+    if (exponential_average_factor == 1.0) {
+      stream->ThenMemZero(batch_mean, batch_mean->size());
+      stream->ThenMemZero(batch_var, batch_var->size());
+    }
     batch_mean_opaque = batch_mean->opaque();
     batch_var_opaque = batch_var->opaque();
   } else {
@@ -210,7 +210,7 @@ tf_module {
   }
   member_method {
     name: "fused_batch_norm"
-    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], "
+    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\', \'exponential_avg_factor\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\', \'1.0\'], "
  }
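For readability, the updated v1 argspec corresponds to this Python signature (reconstructed from the golden entry above; the body is elided):

def fused_batch_norm(x,
                     scale,
                     offset,
                     mean=None,
                     variance=None,
                     epsilon=0.001,
                     data_format="NHWC",
                     is_training=True,
                     name=None,
                     exponential_avg_factor=1.0):
  ...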
|
||||
member_method {
|
||||
name: "in_top_k"
|
||||
|