Automated rollback of commit 80acd88cd43f09a1a2980792e3955f2ce5147bfd

PiperOrigin-RevId: 295811620
Change-Id: I39a1f7f7dadee2c7ea231df16da1ab5516c8f1fa
A. Unique TensorFlower 2020-02-18 13:41:56 -08:00 committed by TensorFlower Gardener
parent ac2c05a1d5
commit 0488a18af4
5 changed files with 264 additions and 105 deletions

View File

@@ -81,7 +81,7 @@ Status ParseActivationMode(OpKernelConstruction* context,
}
// Functor used by FusedBatchNormOp to do the computations.
-template <typename Device, typename T, typename U>
+template <typename Device, typename T, typename U, bool is_training>
struct FusedBatchNorm;
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=True.
@@ -89,17 +89,155 @@ template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;
template <typename T, typename U>
-struct FusedBatchNorm<CPUDevice, T, U> {
+struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ true> {
void operator()(OpKernelContext* context, const Tensor& x_input,
const Tensor& scale_input, const Tensor& offset_input,
const Tensor& running_mean_input,
const Tensor& running_variance_input,
const Tensor* side_input, U epsilon, U exponential_avg_factor,
FusedBatchNormActivationMode activation_mode,
Tensor* y_output, Tensor* running_mean_output,
Tensor* running_var_output, Tensor* saved_batch_mean_output,
Tensor* saved_batch_var_output, TensorFormat tensor_format,
bool use_reserved_space) {
OP_REQUIRES(context, side_input == nullptr,
errors::Internal(
"The CPU implementation of FusedBatchNorm does not support "
"side input."));
OP_REQUIRES(context,
activation_mode == FusedBatchNormActivationMode::kIdentity,
errors::Internal("The CPU implementation of FusedBatchNorm "
"does not support activations."));
if (use_reserved_space) {
Tensor* dummy_reserve_space = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(5, {}, &dummy_reserve_space));
// Initialize the memory, to avoid sanitizer alerts.
dummy_reserve_space->flat<U>()(0) = U();
}
Tensor transformed_x;
Tensor transformed_y;
if (tensor_format == FORMAT_NCHW) {
const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NHWC, in_batch,
in_rows, in_cols, in_depths),
&transformed_x));
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NHWC, in_batch,
in_rows, in_cols, in_depths),
&transformed_y));
// Perform NCHW to NHWC
std::vector<int32> perm = {0, 2, 3, 1};
OP_REQUIRES_OK(
context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
x_input, perm, &transformed_x));
} else {
transformed_x = x_input;
transformed_y = *y_output;
}
typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
typename TTypes<U>::ConstVec old_mean(running_mean_input.vec<U>());
typename TTypes<U>::ConstVec old_variance(running_variance_input.vec<U>());
typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
typename TTypes<U>::Vec new_mean(running_mean_output->vec<U>());
typename TTypes<U>::Vec new_variance(running_var_output->vec<U>());
typename TTypes<U>::Vec saved_batch_mean(saved_batch_mean_output->vec<U>());
typename TTypes<U>::Vec saved_batch_var(saved_batch_var_output->vec<U>());
const CPUDevice& d = context->eigen_device<CPUDevice>();
const int depth = x.dimension(3);
const int size = x.size();
const int rest_size = size / depth;
Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
Eigen::array<int, 1> reduce_dims({0});
Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
one_by_depth.set(1, depth);
Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
bcast_spec.set(0, rest_size);
#endif
auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
// This adjustment is for Bessel's correction
U rest_size_adjust =
static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
Eigen::Tensor<U, 1, Eigen::RowMajor> batch_mean(depth);
Eigen::Tensor<U, 1, Eigen::RowMajor> batch_variance(depth);
batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
auto x_centered = x_rest_by_depth -
batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
batch_variance.device(d) =
x_centered.square().sum(reduce_dims) * rest_size_inv;
auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale)
.eval()
.reshape(one_by_depth)
.broadcast(bcast_spec);
auto x_scaled = x_centered * scaling_factor;
auto x_shifted =
(x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
.template cast<T>();
y.reshape(rest_by_depth).device(d) = x_shifted;
if (exponential_avg_factor == U(1.0)) {
saved_batch_var.device(d) = batch_variance;
saved_batch_mean.device(d) = batch_mean;
new_variance.device(d) = batch_variance * rest_size_adjust;
new_mean.device(d) = batch_mean;
} else {
U one_minus_factor = U(1) - exponential_avg_factor;
saved_batch_var.device(d) = batch_variance;
saved_batch_mean.device(d) = batch_mean;
new_variance.device(d) =
one_minus_factor * old_variance +
(exponential_avg_factor * rest_size_adjust) * batch_variance;
new_mean.device(d) =
one_minus_factor * old_mean + exponential_avg_factor * batch_mean;
}
if (tensor_format == FORMAT_NCHW) {
// Perform NHWC to NCHW
const std::vector<int32> perm = {0, 3, 1, 2};
const Status s = ::tensorflow::DoTranspose(
context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
if (!s.ok()) {
context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
}
}
}
};
template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
void operator()(OpKernelContext* context, const Tensor& x_input,
const Tensor& scale_input, const Tensor& offset_input,
const Tensor& estimated_mean_input,
const Tensor& estimated_variance_input,
-const Tensor* side_input, U epsilon,
+const Tensor* side_input, U epsilon, U exponential_avg_factor,
FusedBatchNormActivationMode activation_mode,
Tensor* y_output, Tensor* batch_mean_output,
Tensor* batch_var_output, Tensor* saved_mean_output,
Tensor* saved_var_output, TensorFormat tensor_format,
-bool use_reserved_space, bool is_training) {
+bool use_reserved_space) {
OP_REQUIRES(context, side_input == nullptr,
errors::Internal(
"The CPU implementation of FusedBatchNorm does not support "
@@ -150,9 +288,7 @@ struct FusedBatchNorm<CPUDevice, T, U> {
estimated_variance_input.vec<U>());
typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
-typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>());
+typename TTypes<U>::Vec batch_variance(batch_var_output->vec<U>());
-typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>());
-typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>());
const CPUDevice& d = context->eigen_device<CPUDevice>();
@@ -168,80 +304,36 @@ struct FusedBatchNorm<CPUDevice, T, U> {
#else
Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
one_by_depth.set(1, depth);
-Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
bcast_spec.set(0, rest_size);
#endif
auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
-const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
-U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
-// This adjustment is for Bessel's correction
-U rest_size_adjust =
-static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
-Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth);
-Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth);
-BlockingCounter barrier(1);
-std::atomic<uint8> task_counter;
-auto on_done = [&]() {
-uint8 count = --task_counter;
-if (count == 0) {
-if (tensor_format == FORMAT_NCHW) {
-// Perform NHWC to NCHW
-const std::vector<int32> perm = {0, 3, 1, 2};
-const Status s =
-::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
-transformed_y, perm, y_output);
-if (!s.ok()) {
-context->SetStatus(
-errors::InvalidArgument("Transpose failed: ", s));
-}
-}
-barrier.DecrementCount();
+auto x_centered =
+x_rest_by_depth -
+estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
+auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
+.eval()
+.reshape(one_by_depth)
+.broadcast(bcast_spec);
+auto x_scaled = x_centered * scaling_factor;
+auto x_shifted =
+(x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
+.template cast<T>();
+y.reshape(rest_by_depth).device(d) = x_shifted;
+batch_mean.device(d) = estimated_mean;
+batch_variance.device(d) = estimated_variance;
+if (tensor_format == FORMAT_NCHW) {
+// Perform NHWC to NCHW
+const std::vector<int32> perm = {0, 3, 1, 2};
+const Status s = ::tensorflow::DoTranspose(
+context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
+if (!s.ok()) {
+context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
}
-};
-if (is_training) {
-// TODO(b/137108598): Extend kernel to allow use of exponential averaging.
-mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
-auto x_centered =
-x_rest_by_depth - mean.reshape(one_by_depth).broadcast(bcast_spec);
-variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
-auto scaling_factor = ((variance + epsilon).rsqrt() * scale)
-.eval()
-.reshape(one_by_depth)
-.broadcast(bcast_spec);
-auto x_scaled = x_centered * scaling_factor;
-auto x_shifted =
-(x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
-.template cast<T>();
-task_counter = 5;
-y.reshape(rest_by_depth).device(d, on_done) = x_shifted;
-batch_var.device(d, on_done) = variance * rest_size_adjust;
-saved_var.device(d, on_done) = variance;
-batch_mean.device(d, on_done) = mean;
-saved_mean.device(d, on_done) = mean;
-} else { // is_training == false
-auto x_centered =
-x_rest_by_depth -
-estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
-auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
-.eval()
-.reshape(one_by_depth)
-.broadcast(bcast_spec);
-auto x_scaled = x_centered * scaling_factor;
-auto x_shifted =
-(x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
-.template cast<T>();
-task_counter = 3;
-y.reshape(rest_by_depth).device(d, on_done) = x_shifted;
-mean.device(d, on_done) = estimated_mean;
-variance.device(d, on_done) = estimated_variance;
}
-barrier.Wait();
}
};
@@ -662,17 +754,17 @@ class CudnnBatchNormAllocatorInOutput : public ScratchAllocator {
bool output_allocated = false;
};
-template <typename T, typename U>
-struct FusedBatchNorm<GPUDevice, T, U> {
+template <typename T, typename U, bool is_training>
+struct FusedBatchNorm<GPUDevice, T, U, is_training> {
void operator()(OpKernelContext* context, const Tensor& x,
const Tensor& scale, const Tensor& offset,
const Tensor& estimated_mean,
const Tensor& estimated_variance, const Tensor* side_input,
-U epsilon, FusedBatchNormActivationMode activation_mode,
-Tensor* y, Tensor* batch_mean, Tensor* batch_var,
-Tensor* saved_mean, Tensor* saved_inv_var,
-TensorFormat tensor_format, bool use_reserved_space,
-bool is_training) {
+U epsilon, U exponential_avg_factor,
+FusedBatchNormActivationMode activation_mode, Tensor* y,
+Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
+Tensor* saved_inv_var, TensorFormat tensor_format,
+bool use_reserved_space) {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
@@ -837,15 +929,13 @@ struct FusedBatchNorm<GPUDevice, T, U> {
workspace_allocator.reset(
new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
}
-// TODO(b/137108598): Extend kernel to allow use of exponential averaging.
-const double exponential_average_factor = 1.0;
bool cudnn_launch_status =
stream
->ThenBatchNormalizationForward(
x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
estimated_variance_ptr, side_input_ptr, x_desc,
scale_offset_desc, static_cast<double>(epsilon),
-exponential_average_factor,
+static_cast<double>(exponential_avg_factor),
AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr,
&batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
is_training, reserve_space_allocator.get(),
@@ -1075,6 +1165,10 @@ class FusedBatchNormOpBase : public OpKernel {
float epsilon;
OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
epsilon_ = U(epsilon);
+float exponential_avg_factor;
+OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
+&exponential_avg_factor));
+exponential_avg_factor_ = U(exponential_avg_factor);
string tensor_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
@@ -1165,17 +1259,6 @@ class FusedBatchNormOpBase : public OpKernel {
"channel dimension to be a multiple of 4."));
}
-if (is_training_) {
-OP_REQUIRES(
-context, estimated_mean.dim_size(0) == 0,
-errors::InvalidArgument("estimated_mean must be empty for training",
-estimated_mean.shape().DebugString()));
-OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
-errors::InvalidArgument(
-"estimated_variance must be empty for training",
-estimated_variance.shape().DebugString()));
-}
Tensor* y = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, x.shape(), &y));
@@ -1192,15 +1275,24 @@ class FusedBatchNormOpBase : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
&saved_maybe_inv_var));
-functor::FusedBatchNorm<Device, T, U>()(
-context, x, scale, offset, estimated_mean, estimated_variance,
-side_input, epsilon_, activation_mode_, y, batch_mean, batch_var,
-saved_mean, saved_maybe_inv_var, tensor_format_, use_reserved_space,
-is_training_);
+if (is_training_) {
+functor::FusedBatchNorm<Device, T, U, true>()(
+context, x, scale, offset, estimated_mean, estimated_variance,
+side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
+batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
+tensor_format_, use_reserved_space);
+} else {
+functor::FusedBatchNorm<Device, T, U, false>()(
+context, x, scale, offset, estimated_mean, estimated_variance,
+side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
+batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
+tensor_format_, use_reserved_space);
+}
}
private:
U epsilon_;
+U exponential_avg_factor_;
TensorFormat tensor_format_;
bool is_training_;
bool has_side_input_;
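
For reference, the running-statistics update performed per channel by the new is_training=true CPU path can be stated outside of Eigen: with exponential_avg_factor == 1.0 the outputs are exactly the batch statistics (the previous behavior), while a smaller factor blends the batch mean and the Bessel-corrected batch variance into the incoming running values. A minimal standalone C++ sketch of just that rule (plain scalars, identifiers chosen here for illustration; not TensorFlow code):

#include <cstdio>

// Running-statistics update as performed by the training path:
// factor == 1.0 keeps the old behavior (outputs are the batch statistics);
// factor < 1.0 is an exponential moving average with the incoming values.
void UpdateRunningStats(float batch_mean, float batch_var, int n, float factor,
                        float* running_mean, float* running_var) {
  // Bessel's correction: only the running variance uses the unbiased estimate.
  const float unbiased_var =
      batch_var * static_cast<float>(n) / static_cast<float>(n - 1);
  *running_mean = (1.0f - factor) * (*running_mean) + factor * batch_mean;
  *running_var = (1.0f - factor) * (*running_var) + factor * unbiased_var;
}

int main() {
  float mean = 6.0f, var = 16.0f;  // incoming running statistics
  // Batch statistics of the TrainingRunningMean test below: mean 10,
  // population variance 11.67 over n = 6 elements, factor 0.5.
  UpdateRunningStats(10.0f, 11.67f, 6, 0.5f, &mean, &var);
  std::printf("mean=%.2f var=%.2f\n", mean, var);  // mean=8.00 var=15.00
  return 0;
}

With factor 1.0 the same call would return mean 10 and variance 14, i.e. the Bessel-corrected batch statistics alone.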

View File

@@ -40,6 +40,7 @@ TEST_F(FusedBatchNormOpTest, Training) {
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
+.Attr("exponential_avg_factor", 1.0)
.Attr("epsilon", 0.001)
.Attr("is_training", true)
.Finalize(node_def()));
@@ -67,6 +68,41 @@ TEST_F(FusedBatchNormOpTest, Training) {
test::ExpectTensorNear<float>(expected_variance, *GetOutput(2), 0.01);
}
TEST_F(FusedBatchNormOpTest, TrainingRunningMean) {
TF_EXPECT_OK(NodeDefBuilder("batch_norm_op", "FusedBatchNorm")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("exponential_avg_factor", 0.5)
.Attr("epsilon", 0.001)
.Attr("is_training", true)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
AddInputFromArray<float>(TensorShape({1, 1, 6, 2}),
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
AddInputFromArray<float>(TensorShape({2}), {4.0, 4.0});
AddInputFromArray<float>(TensorShape({2}), {2.0, 2.0});
AddInputFromArray<float>(TensorShape({2}), {6.0, 6.0});
AddInputFromArray<float>(TensorShape({2}), {16.0, 16.0});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2}));
test::FillValues<float>(&expected, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
3.17, 3.17, 5.51, 5.51, 7.86, 7.86});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
Tensor expected_mean(allocator(), DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&expected_mean, {8, 8});
test::ExpectTensorNear<float>(expected_mean, *GetOutput(1), 0.01);
Tensor expected_variance(allocator(), DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&expected_variance, {15.00, 15.00});
test::ExpectTensorNear<float>(expected_variance, *GetOutput(2), 0.01);
}
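
The expected y values above come from normalizing with the batch statistics (mean 10, population variance 11.67), not with the blended running values; a quick scalar check of those constants (a sketch under that assumption, not part of the test):

#include <cmath>
#include <cstdio>

int main() {
  // Per-channel inputs of TrainingRunningMean: scale = 4, offset = 2,
  // epsilon = 0.001; both channels carry the same data.
  const float x[] = {5, 7, 9, 11, 13, 15};
  const float scale = 4.0f, offset = 2.0f, epsilon = 0.001f;

  float mean = 0.0f, var = 0.0f;
  for (float v : x) mean += v;
  mean /= 6.0f;  // 10
  for (float v : x) var += (v - mean) * (v - mean);
  var /= 6.0f;  // 11.67, the population variance used for normalization

  for (float v : x) {
    std::printf("%.2f ", (v - mean) / std::sqrt(var + epsilon) * scale + offset);
  }
  std::printf("\n");  // -3.85 -1.51 0.83 3.17 5.51 7.85, within the 0.01 tolerance
  return 0;
}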
TEST_F(FusedBatchNormOpTest, Inference) {
TF_EXPECT_OK(NodeDefBuilder("batch_norm_op", "FusedBatchNorm")
.Input(FakeInput(DT_FLOAT))
@@ -93,6 +129,33 @@ TEST_F(FusedBatchNormOpTest, Inference) {
test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
}
TEST_F(FusedBatchNormOpTest, InferenceIgnoreAvgFactor) {
TF_EXPECT_OK(NodeDefBuilder("batch_norm_op", "FusedBatchNorm")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("exponential_avg_factor", 0.5)
.Attr("epsilon", 0.001)
.Attr("is_training", false)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
AddInputFromArray<float>(TensorShape({1, 1, 6, 2}),
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
AddInputFromArray<float>(TensorShape({2}), {4.0, 4.0});
AddInputFromArray<float>(TensorShape({2}), {2.0, 2.0});
AddInputFromArray<float>(TensorShape({2}), {10, 10});
AddInputFromArray<float>(TensorShape({2}), {11.67f, 11.67f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2}));
test::FillValues<float>(&expected, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
3.17, 3.17, 5.51, 5.51, 7.86, 7.86});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.01);
}
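
In inference mode the new attribute is ignored: y is computed directly from the supplied mean and variance inputs (10 and 11.67 here), and the batch_mean/batch_variance outputs are simply those inputs passed through, which is why this test expects the same y values as the TrainingRunningMean test above. A scalar sketch of the inference formula (illustrative only):

#include <cmath>
#include <cstdio>

// Inference-path output for one element, given the supplied running
// statistics; note that exponential_avg_factor does not appear at all.
float InferenceOutput(float x, float mean, float var, float scale,
                      float offset, float epsilon) {
  return (x - mean) / std::sqrt(var + epsilon) * scale + offset;
}

int main() {
  // First element of the test above: x = 5, mean = 10, variance = 11.67,
  // scale = 4, offset = 2, epsilon = 0.001.
  std::printf("%.2f\n", InferenceOutput(5.0f, 10.0f, 11.67f, 4.0f, 2.0f, 0.001f));
  // Prints -3.85, matching the expected -3.86 within the test's 0.01 tolerance.
  return 0;
}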
class FusedBatchNormGradOpTest : public OpsTestBase {};
TEST_F(FusedBatchNormGradOpTest, Simple) {

View File

@@ -179,6 +179,7 @@ REGISTER_OP("FusedBatchNorm")
.Output("reserve_space_2: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
+.Attr("exponential_avg_factor: float = 1.0")
.Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -197,6 +198,7 @@ REGISTER_OP("FusedBatchNormV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
+.Attr("exponential_avg_factor: float = 1.0")
.Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -216,6 +218,7 @@ REGISTER_OP("FusedBatchNormV3")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
+.Attr("exponential_avg_factor: float = 1.0")
.Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormV3Shape);
@@ -236,6 +239,7 @@ REGISTER_OP("_FusedBatchNormEx")
.Attr("T: {half, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
+.Attr("exponential_avg_factor: float = 1.0")
.Attr("num_side_inputs: int >= 0 = 0")
.Attr("activation_mode: string = \"Identity\"")
.Attr(GetConvnetDataFormatAttrString())

View File

@@ -1590,7 +1590,7 @@ tf_module {
}
member_method {
name: "FusedBatchNorm"
-argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], "
+argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
}
member_method {
name: "FusedBatchNormGrad"
@@ -1606,11 +1606,11 @@ tf_module {
}
member_method {
name: "FusedBatchNormV2"
-argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], "
+argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
}
member_method {
name: "FusedBatchNormV3"
-argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], "
+argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
}
member_method {
name: "FusedPadConv2D"

View File

@@ -1590,7 +1590,7 @@ tf_module {
}
member_method {
name: "FusedBatchNorm"
-argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], "
+argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
}
member_method {
name: "FusedBatchNormGrad"
@@ -1606,11 +1606,11 @@ tf_module {
}
member_method {
name: "FusedBatchNormV2"
-argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], "
+argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
}
member_method {
name: "FusedBatchNormV3"
-argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'NHWC\', \'True\', \'None\'], "
+argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'exponential_avg_factor\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1\', \'NHWC\', \'True\', \'None\'], "
}
member_method {
name: "FusedPadConv2D"