Support computing exponential running mean and variance in fused_batch_norm.

Simplify unit test code for fused_batch_norm and its gradient.

PiperOrigin-RevId: 296546311
Change-Id: Ieb78b4038a39dd2bcde16302541e733f4e8604bd
parent bdf0ea41b9
commit ce9564d430
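A short usage sketch of the behaviour named in the commit title. It is hedged: it assumes a TensorFlow build whose tf.compat.v1.nn.fused_batch_norm accepts the exponential_avg_factor argument (the signature that appears in the API golden hunk at the end of this diff); the tensor values are purely illustrative.

import numpy as np
import tensorflow.compat.v1 as tf

x = tf.constant(np.random.random_sample([2, 4, 4, 3]), dtype=tf.float32)
scale = tf.ones([3])
offset = tf.zeros([3])

# With exponential_avg_factor < 1.0 the second and third outputs are running
# statistics that blend the incoming mean/variance with the current batch
# statistics, rather than the batch statistics themselves.
y, running_mean, running_var = tf.nn.fused_batch_norm(
    x,
    scale,
    offset,
    mean=tf.zeros([3]),
    variance=tf.ones([3]),
    exponential_avg_factor=0.6,
    is_training=True)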
@@ -100,10 +100,10 @@ TEST(FusedBatchnormReserveSpaceTest, Test) {
   Output offset =
       Const(root.WithOpName("offset"), Input::Initializer(offset_data));
 
-  Tensor mean_data(DT_FLOAT, TensorShape({10}));
-  Output mean = Const(root.WithOpName("mean"), Input::Initializer(mean_data));
+  Tensor mean_data(DT_FLOAT, TensorShape({0}));
+  Output mean = Const(root.WithOpName("offset"), Input::Initializer(mean_data));
 
-  Tensor variance_data(DT_FLOAT, TensorShape({10}));
+  Tensor variance_data(DT_FLOAT, TensorShape({0}));
   Output variance =
       Const(root.WithOpName("variance"), Input::Initializer(variance_data));
 
@@ -787,7 +787,6 @@ cc_library(
         "//tensorflow/core/util:padding",
         "//tensorflow/core/util:tensor_format",
         "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
     ],
 )
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
+#include <unordered_set>
 
 #include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_split.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -1083,11 +1083,7 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
 
   bool is_training;
   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
-  float exponential_avg_factor;
-  if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
-    exponential_avg_factor = 1.0f;  // default value
-  }
-  int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
+  int number_inputs = (is_training) ? 3 : 5;
   string data_format_str;
   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
   TensorFormat data_format;
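A minimal standalone sketch (not the TensorFlow shape-inference API) of the rule that the exponential_avg_factor variant of FusedBatchNormShape encodes in the hunk above: when training with a full update (exponential_avg_factor == 1.0) the incoming mean/variance are ignored, so only x, scale and offset must carry the channel dimension; otherwise all five inputs do. Names here are illustrative only.

def num_channel_checked_inputs(is_training, exponential_avg_factor=1.0):
    # Mirrors `number_inputs` in the shape function quoted above.
    return 3 if (is_training and exponential_avg_factor == 1.0) else 5

assert num_channel_checked_inputs(True) == 3
assert num_channel_checked_inputs(True, 0.6) == 5
assert num_channel_checked_inputs(False) == 5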
@@ -1183,8 +1179,13 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
   // Set the correct shapes for reserve_spaces
   // so that gradients can be performed when
   // the op is in a symbolic condition.
+  if (is_training) {
+    c->set_output(3, c->Vector(0));
+    c->set_output(4, c->Vector(0));
+  } else {
   c->set_output(3, c->Vector(channel_dim));
   c->set_output(4, c->Vector(channel_dim));
+  }
   return Status::OK();
 }
 
@@ -2325,7 +2326,7 @@ Status SparseReduceShapeFn(InferenceContext* c) {
   auto axes_vec = axes_tensor->flat<int32>();
 
   int64 ndims = shape_vec.size();
-  absl::flat_hash_set<int64> axes;
+  std::unordered_set<int64> axes;
   for (int i = 0; i < axes_vec.size(); i++) {
     axes.insert((axes_vec(i) + ndims) % ndims);
   }
@@ -528,9 +528,9 @@ TEST(CommonShapeFnsTest, FusedBatchNormExTest) {
                   .Finalize(&op.node_def));
 
   // Channels are not multiple of 4.
-  INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[2];[2]");
+  INFER_ERROR("must be divisible by 4", op, "[2,2,2,2];[2];[2];[0];[0]");
 
-  INFER_OK(op, "[2,2,2,4];[4];[4];[4];[4]",
+  INFER_OK(op, "[2,2,2,4];[4];[4];[0];[0]",
           "[d0_0,d0_1,d0_2,d0_3];[d0_3];[d0_3];[d0_3];[d0_3];?");
 }
 
@@ -859,8 +859,8 @@ class VirtualSchedulerTest : public ::testing::Test {
         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
     auto offset =
         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
-    auto mean = ops::RandomUniform(s.WithOpName("mean"), {depth_in_}, DT_FLOAT);
-    auto var = ops::RandomUniform(s.WithOpName("var"), {depth_in_}, DT_FLOAT);
+    auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+    auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
 
     auto batch_norm = ops::FusedBatchNorm(
         s.WithOpName("bn"), x, scale, offset, mean, var,
@@ -2146,8 +2146,8 @@ versions {
         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
     auto offset =
         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
-    auto mean = ops::RandomUniform(s.WithOpName("mean"), {depth_in_}, DT_FLOAT);
-    auto var = ops::RandomUniform(s.WithOpName("var"), {depth_in_}, DT_FLOAT);
+    auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+    auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
 
     auto batch_norm = ops::FusedBatchNorm(
         s.WithOpName("bn"), x, scale, offset, mean, var,
@@ -497,8 +497,8 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
   Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
   Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
   Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
-  Output mean = ops::Const(s.WithOpName("mean"), 5.f, {16});
-  Output variance = ops::Const(s.WithOpName("variance"), 6.f, {16});
+  Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
+  Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
   Output wht1 = ops::Conv2D(s.WithOpName("wht1"), input, weight, {1, 1, 1, 1},
                             "SAME", ops::Conv2D::DataFormat("NHWC"));
   auto fbn1_op =
@@ -181,8 +181,7 @@ TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
 TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
   ShapeInferenceTestOp op("FusedBatchNorm");
 
-  auto set_op = [&op](bool is_training, float exponential_avg_factor,
-                      string data_format) {
+  auto set_op = [&op](bool is_training, string data_format) {
     TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm")
                      .Input(FakeInput(DT_FLOAT))
                      .Input(FakeInput(DT_FLOAT))
@@ -191,11 +190,10 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
                      .Input(FakeInput(DT_FLOAT))
                      .Attr("data_format", data_format)
                      .Attr("is_training", is_training)
-                     .Attr("exponential_avg_factor", exponential_avg_factor)
                      .Finalize(&op.node_def));
   };
 
-  set_op(true, 1.0, "NHWC");
+  set_op(true, "NHWC");
   // Test rank errors.
   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -209,21 +207,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
           "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
           "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
 
-  set_op(true, 0.5, "NHWC");
+  set_op(true, "NCHW");
-  // Test rank errors.
-  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
-  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
-  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
-  // Channel dim of first input is merged with the single dim in other 4 inputs.
-  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
-  INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
-  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
-  INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
-           "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
-           "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
-           "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
-
-  set_op(true, 1.0, "NCHW");
   // Test rank errors.
   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -237,7 +221,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
           "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];"
           "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]");
 
-  set_op(false, 1.0, "NHWC");
+  set_op(false, "NHWC");
   // Test rank errors.
   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -255,7 +239,7 @@ TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
           "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];"
           "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]");
 
-  set_op(false, 1.0, "NCHW");
+  set_op(false, "NCHW");
   // Test rank errors.
   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
@@ -295,14 +279,13 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
   // Channel dim of first input is merged with the single dim in other 4 inputs.
-  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
-  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
-  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]");
-  INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[d4_0];[d4_0]");
+  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
+  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
+  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
+  INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
   INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
           "[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
-           "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];["
-           "d0_3|d2_0|d3_0|d4_0]");
+           "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
 
   set_op("NCHW");
   // Test rank errors.
@@ -312,14 +295,13 @@ TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
   // Channel dim of first input is merged with the single dim in other 4 inputs.
-  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
-  INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[d2_0];[d2_0]");
-  INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[d3_0];[d3_0]");
-  INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[d4_0];[d4_0]");
+  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
+  INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[0];[0]");
+  INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[0];[0]");
+  INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[0];[0]");
   INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]",
           "[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];"
-           "[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];["
-           "d0_1|d2_0|d3_0|d4_0]");
+           "[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]");
 }
 
 TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {
@@ -4999,7 +4999,7 @@ cuda_py_test(
     size = "large",
     srcs = ["ops/nn_fused_batchnorm_test.py"],
     python_version = "PY3",
-    shard_count = 24,
+    shard_count = 16,
    deps = [
        ":array_ops",
        ":client_testlib",
@@ -568,13 +568,11 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
     x = constant_op.constant(1, dtype=dtypes.float32, shape=[0, 1, 1, 1])
     scale = constant_op.constant([1], dtype=dtypes.float32)
     offset = constant_op.constant([1], dtype=dtypes.float32)
-    mean = constant_op.constant([1], dtype=dtypes.float32)
-    variance = constant_op.constant([1], dtype=dtypes.float32)
 
     # Calling fused_batch_norm with an empty input should output a NaN in the
     # latter four outputs without triggering the check_numerics callback
     batch_norm_res = gen_nn_ops._fused_batch_norm(
-        x=x, scale=scale, offset=offset, mean=mean, variance=variance)
+        x=x, scale=scale, offset=offset, mean=[], variance=[])
 
     _, batch_mean, batch_variance, _, _ = self.evaluate(batch_norm_res)
 
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
@@ -60,7 +59,6 @@ class BatchNormalizationTest(test.TestCase):
                       scale_shape,
                       scale_dtype,
                       use_gpu=True,
-                      exponential_avg_factor=1.0,
                       data_format='NHWC'):
     np.random.seed(1)
     x_val = np.random.random_sample(x_shape).astype(x_dtype)
@@ -83,7 +81,6 @@ class BatchNormalizationTest(test.TestCase):
          mean=mean,
          variance=var,
          epsilon=epsilon,
-          exponential_avg_factor=exponential_avg_factor,
          data_format=data_format,
          is_training=False)
      y_val = self.evaluate(y)
@@ -95,37 +92,17 @@ class BatchNormalizationTest(test.TestCase):
     atol = 2e-3 if x_dtype == np.float16 else 1e-3
     self.assertAllClose(y_ref, y_val, atol=atol)
 
-  def _running_mean(self, old_mean, new_val, factor):
-    if factor == 1.0:
-      return new_val
-    else:
-      return (1.0 - factor) * old_mean + factor * new_val
-
-  def _training_ref(self, x, scale, offset, old_mean, old_var,
-                    exponential_avg_factor, epsilon, data_format):
+  def _training_ref(self, x, scale, offset, epsilon, data_format):
     if data_format not in ['NHWC', 'NCHW']:
       raise ValueError('data_format must be NCHW or NHWC, '
                        'got %s.' % data_format)
     if data_format == 'NCHW':
       x = array_ops.transpose(x, [0, 2, 3, 1])
-    batch_mean, batch_var = nn_impl.moments(
+    mean, var = nn_impl.moments(
         math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False)
-    y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon)
+    y = self._batch_norm(x, mean, var, offset, scale, epsilon)
     if data_format == 'NCHW':
       y = array_ops.transpose(y, [0, 3, 1, 2])
-
-    # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
-    # the denominator in the formula to calculate variance, while
-    # tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
-    sample_size = math_ops.cast(
-        array_ops.size(x) / array_ops.size(scale), scale.dtype)
-    batch_var_corrected = batch_var * sample_size / (
-        math_ops.maximum(sample_size - 1.0, 1.0))
-
-    mean = self._running_mean(old_mean, batch_mean, exponential_avg_factor)
-    var = self._running_mean(old_var, batch_var_corrected,
-                             exponential_avg_factor)
     return self.evaluate(y), self.evaluate(mean), self.evaluate(var)
 
   def _test_training(self,
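A NumPy sketch of the arithmetic that appears in the reference helper above: the exponential running average and the Bessel's-correction rescaling of the batch variance. It depends only on NumPy; variable names are illustrative and do not come from the TensorFlow sources.

import numpy as np

def running_average(old_value, batch_value, factor):
    # factor == 1.0 simply returns the batch statistic.
    if factor == 1.0:
        return batch_value
    return (1.0 - factor) * old_value + factor * batch_value

x = np.random.rand(2, 3, 3, 4).astype(np.float32)
batch_mean = x.mean(axis=(0, 1, 2))
batch_var = x.var(axis=(0, 1, 2))  # n-denominator, like tf.nn.moments

# Bessel's correction: fused_batch_norm reports variance with an n-1
# denominator, so the n-based value is rescaled by n / max(n - 1, 1).
sample_size = x.size / batch_mean.size
batch_var_corrected = batch_var * sample_size / max(sample_size - 1.0, 1.0)

old_mean = np.zeros(4, np.float32)
old_var = np.ones(4, np.float32)
running_mean = running_average(old_mean, batch_mean, 0.6)
running_var = running_average(old_var, batch_var_corrected, 0.6)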
@@ -134,15 +111,11 @@ class BatchNormalizationTest(test.TestCase):
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
-                      exponential_avg_factor=1.0,
                      data_format='NHWC'):
     np.random.seed(1)
     x_val = np.random.random_sample(x_shape).astype(x_dtype)
     scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
     offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
-    old_mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
-    old_var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
 
     with self.cached_session(use_gpu=use_gpu) as sess:
       x = constant_op.constant(x_val, name='x')
       scale = constant_op.constant(scale_val, name='scale')
@@ -152,20 +125,20 @@ class BatchNormalizationTest(test.TestCase):
          x,
          scale,
          offset,
-          mean=old_mean_val,
-          variance=old_var_val,
          epsilon=epsilon,
-          exponential_avg_factor=exponential_avg_factor,
          data_format=data_format,
          is_training=True)
      y_val, mean_val, var_val = self.evaluate([y, mean, var])
-      y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset,
-                                                    old_mean_val, old_var_val,
-                                                    exponential_avg_factor,
-                                                    epsilon, data_format)
+      y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
+                                                    data_format)
     y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
     self.assertAllClose(y_ref, y_val, atol=y_atol)
     self.assertAllClose(mean_ref, mean_val, atol=1e-3)
+    # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as
+    # the denominator in the formula to calculate variance, while
+    # tf.compat.v1.nn.fused_batch_norm has Bessel's correction built in.
+    sample_size = x_val.size / scale_val.size
+    var_ref = var_ref * sample_size / (max(sample_size - 1.0, 1.0))
     self.assertAllClose(var_ref, var_val, atol=1e-3)
 
   def _compute_gradient_error_float16(self, x, x32, x_shape, y, y32, y_shape):
@@ -211,7 +184,6 @@ class BatchNormalizationTest(test.TestCase):
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
-                      exponential_avg_factor=1.0,
                      data_format='NHWC',
                      is_training=True):
     np.random.seed(1)
@@ -223,6 +195,10 @@ class BatchNormalizationTest(test.TestCase):
      x = constant_op.constant(x_val, name='x')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
+      if is_training:
+        pop_mean = None
+        pop_var = None
+      else:
      pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
      pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
      y, _, _ = nn_impl.fused_batch_norm(
@@ -231,7 +207,6 @@ class BatchNormalizationTest(test.TestCase):
          offset,
          mean=pop_mean,
          variance=pop_var,
-          exponential_avg_factor=exponential_avg_factor,
          data_format=data_format,
          is_training=is_training)
      if x_dtype != np.float16:
@@ -249,7 +224,6 @@ class BatchNormalizationTest(test.TestCase):
            mean=pop_mean,
            variance=pop_var,
            data_format=data_format,
-            exponential_avg_factor=exponential_avg_factor,
            is_training=is_training)
        err_x = self._compute_gradient_error_float16(x, x32, x_shape, y, y32,
                                                     x_shape)
@@ -270,7 +244,6 @@ class BatchNormalizationTest(test.TestCase):
                      scale_shape,
                      scale_dtype,
                      use_gpu=True,
-                      exponential_avg_factor=1.0,
                      data_format='NHWC',
                      is_training=True,
                      err_tolerance=1e-3):
@@ -285,6 +258,10 @@ class BatchNormalizationTest(test.TestCase):
      grad_y = constant_op.constant(grad_y_val, name='grad_y')
      scale = constant_op.constant(scale_val, name='scale')
      offset = constant_op.constant(offset_val, name='offset')
+      if is_training:
+        pop_mean = None
+        pop_var = None
+      else:
      pop_mean = np.random.random_sample(scale_shape).astype(scale_dtype)
      pop_var = np.random.random_sample(scale_shape).astype(scale_dtype)
      y, _, _ = nn_impl.fused_batch_norm(
@@ -293,7 +270,6 @@ class BatchNormalizationTest(test.TestCase):
          offset,
          mean=pop_mean,
          variance=pop_var,
-          exponential_avg_factor=exponential_avg_factor,
          data_format=data_format,
          is_training=is_training)
      grad_x, grad_scale, grad_offset = gradients_impl.gradients(
@@ -335,7 +311,6 @@ class BatchNormalizationTest(test.TestCase):
          offset,
          mean=pop_mean,
          variance=pop_var,
-          exponential_avg_factor=exponential_avg_factor,
          data_format=data_format,
          is_training=is_training)
      grad_x32, grad_scale32, grad_offset32 = gradients_impl.gradients(
@@ -364,142 +339,282 @@ class BatchNormalizationTest(test.TestCase):
     self.assertLess(err_grad_x_2, err_tolerance)
     self.assertLess(err_grad_scale, err_tolerance)
 
-  def _runtests(self, x_shape, is_training, gradient_test=False):
-    use_gpu_vals = [False]
-    if test.is_gpu_available(cuda_only=True):
-      use_gpu_vals += [True]
-    factors = [1.0,]
-    if compat.forward_compatible(2020, 2, 29):
-      factors += [0.6,]
-    for dtype in [np.float16, np.float32]:
-      for use_gpu in use_gpu_vals:
-        for data_format in ['NHWC', 'NCHW']:
-          if data_format == 'NHWC':
-            scale_shape = x_shape[-1:]
-          else:
-            scale_shape = x_shape[1:2]
-          for exponential_avg_factor in factors:
-            if gradient_test:
-              self._test_gradient(
-                  x_shape,
-                  dtype,
-                  scale_shape,
-                  np.float32,
-                  use_gpu=use_gpu,
-                  data_format=data_format,
-                  is_training=is_training,
-                  exponential_avg_factor=exponential_avg_factor)
-            else:
-              if is_training:
-                self._test_training(
-                    x_shape,
-                    dtype,
-                    scale_shape,
-                    np.float32,
-                    use_gpu=use_gpu,
-                    data_format=data_format,
-                    exponential_avg_factor=exponential_avg_factor)
-              else:
-                self._test_inference(
-                    x_shape,
-                    dtype,
-                    scale_shape,
-                    np.float32,
-                    use_gpu=use_gpu,
-                    data_format=data_format,
-                    exponential_avg_factor=exponential_avg_factor)
 
   def testInferenceShape1(self):
     x_shape = [1, 1, 6, 1]
-    self._runtests(x_shape, False)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_inference(
+            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
+        self._test_inference(
+            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
+      self._test_inference(
+          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
+      self._test_inference(
+          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
 
   def testInferenceShape2(self):
     x_shape = [1, 1, 6, 2]
-    self._runtests(x_shape, False)
+    if test.is_gpu_available(cuda_only=True):
+      for dtype in [np.float16, np.float32]:
+        self._test_inference(
+            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
+        self._test_inference(
+            x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
 
   def testInferenceShape3(self):
     x_shape = [1, 2, 1, 6]
-    self._runtests(x_shape, False)
+    if test.is_gpu_available(cuda_only=True):
+      for dtype in [np.float16, np.float32]:
+        self._test_inference(
+            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
 
   def testInferenceShape4(self):
     x_shape = [27, 131, 127, 6]
-    self._runtests(x_shape, False)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_inference(
+            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
+        self._test_inference(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_inference(
+          x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
+      self._test_inference(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   def testInferenceShape5(self):
     x_shape = [0, 131, 127, 6]
-    self._runtests(x_shape, False)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_inference(
+            x_shape,
+            dtype, [131],
+            np.float32,
+            use_gpu=True,
+            data_format='NCHW')
+        self._test_inference(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_inference(
+          x_shape,
+          dtype, [131],
+          np.float32,
+          use_gpu=False,
+          data_format='NCHW')
+      self._test_inference(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NHWC')
+        self._test_training(
+            x_shape, dtype, [1], np.float32, use_gpu=True, data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NHWC')
+      self._test_training(
+          x_shape, dtype, [1], np.float32, use_gpu=False, data_format='NCHW')
 
   def testTrainingShape2(self):
     x_shape = [1, 1, 6, 2]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_training(
+          x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NHWC')
 
   def testTrainingShape3(self):
     x_shape = [1, 2, 1, 6]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [2], np.float32, use_gpu=True, data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [2], np.float32, use_gpu=False, data_format='NCHW')
 
   def testTrainingShape4(self):
     x_shape = [27, 131, 127, 6]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
+        self._test_training(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_training(
+          x_shape, dtype, [131], np.float32, use_gpu=False, data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
   @test_util.disable_xla('b/141236973: Empty inputs wrong on CPU.')
   def testTrainingShape5(self):
     x_shape = [0, 131, 127, 6]
-    self._runtests(x_shape, True)
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape,
+            dtype, [131],
+            np.float32,
+            use_gpu=True,
+            data_format='NCHW')
+        self._test_training(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_training(
+          x_shape,
+          dtype, [131],
+          np.float32,
+          use_gpu=False,
+          data_format='NCHW')
+      self._test_training(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
-  def testBatchNormGradInferenceShape1(self):
+  @test_util.run_deprecated_v1
+  def testBatchNormGradShape1(self):
+    for is_training in [True, False]:
     x_shape = [1, 1, 6, 1]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [1],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+          self._test_gradient(
+              x_shape,
+              dtype, [1],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [1],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [1],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   @test_util.run_deprecated_v1
-  def testBatchNormGradInferenceShape2(self):
+  def testBatchNormGradShape2(self):
+    for is_training in [True, False]:
     x_shape = [1, 1, 6, 2]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [2],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [2],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
 
   @test_util.run_deprecated_v1
-  def testBatchNormGradInferenceShape3(self):
+  def testBatchNormGradShape3(self):
+    for is_training in [True, False]:
     x_shape = [1, 2, 1, 6]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [2],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [2],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   @test_util.run_deprecated_v1
-  def testBatchNormGradInferenceShape4(self):
+  def testBatchNormGradShape4(self):
+    for is_training in [True, False]:
     x_shape = [5, 7, 11, 4]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [7],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+          self._test_gradient(
+              x_shape,
+              dtype, [4],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [4],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [7],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test never passed for XLA')
-  def testBatchNormGradInferenceShape5(self):
+  def testBatchNormGradShape5(self):
+    for is_training in [True, False]:
     x_shape = [0, 7, 11, 4]
-    self._runtests(x_shape, is_training=False, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  def testBatchNormGradTrainingShape1(self):
-    x_shape = [1, 1, 6, 1]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  def testBatchNormGradTrainingShape2(self):
-    x_shape = [1, 1, 6, 2]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  def testBatchNormGradTrainingShape3(self):
-    x_shape = [1, 2, 1, 6]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  def testBatchNormGradTrainingShape4(self):
-    x_shape = [5, 7, 11, 4]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
-
-  @test_util.run_deprecated_v1
-  @test_util.disable_xla('This test never passed for XLA')
-  def testBatchNormGradTrainingShape5(self):
-    x_shape = [0, 7, 11, 4]
-    self._runtests(x_shape, is_training=True, gradient_test=True)
+    self._runtests(x_shape, is_training=False, gradient_test=True)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradTrainingShape1(self):
+    x_shape = [1, 1, 6, 1]
+    self._runtests(x_shape, is_training=True, gradient_test=True)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradTrainingShape2(self):
+    x_shape = [1, 1, 6, 2]
+    self._runtests(x_shape, is_training=True, gradient_test=True)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradTrainingShape3(self):
+    x_shape = [1, 2, 1, 6]
+    self._runtests(x_shape, is_training=True, gradient_test=True)
+
+  @test_util.run_deprecated_v1
+  def testBatchNormGradTrainingShape4(self):
+    x_shape = [5, 7, 11, 4]
+    self._runtests(x_shape, is_training=True, gradient_test=True)
+
+  @test_util.run_deprecated_v1
+  @test_util.disable_xla('This test never passed for XLA')
+  def testBatchNormGradTrainingShape5(self):
+    x_shape = [0, 7, 11, 4]
+    self._runtests(x_shape, is_training=True, gradient_test=True)
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [7],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+          self._test_gradient(
+              x_shape,
+              dtype, [4],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [4],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [7],
+            np.float32,
+            use_gpu=False,
+            data_format='NCHW',
+            is_training=is_training)
 
   def _testBatchNormGradGrad(self, config):
     shape = config['shape']
@@ -1432,8 +1432,7 @@ def fused_batch_norm(
     epsilon=0.001,
     data_format="NHWC",
     is_training=True,
-    name=None,
-    exponential_avg_factor=1.0):
+    name=None):
  r"""Batch normalization.
 
 
@@ -1453,22 +1452,14 @@ def fused_batch_norm(
     is_training: A bool value to specify if the operation is used for
       training or inference.
     name: A name for this operation (optional).
-    exponential_avg_factor: A float number (usually between 0 and 1) used
-                            for controlling the decay of the running
-                            population average of mean and variance.
-                            If set to 1.0, the current batch average is
-                            returned.
 
   Returns:
     y: A 4D Tensor for the normalized, scaled, offsetted x.
-    running_mean: A 1D Tensor for the exponential running mean of x.
-                  The output value is (1 - exponential_avg_factor) * mean +
-                  exponential_avg_factor * batch_mean), where batch_mean
-                  is the mean of the current batch in x.
-    running_var: A 1D Tensor for the exponential running
-                 The output value is (1 - exponential_avg_factor) * variance +
-                 exponential_avg_factor * batch_variance), where batch_variance
-                 is the variance within of the current batch in x.
+    batch_mean: A 1D Tensor for the mean of x.
+    batch_var: A 1D Tensor for the variance of x.
+
+  Raises:
+    ValueError: If mean or variance is not None when is_training is True.
 
   References:
     Batch Normalization - Accelerating Deep Network Training by Reducing
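The docstring text quoted on the '-' side above spells out the running-average rule for the op's second and third outputs. A quick, self-contained check of that formula with made-up numbers (nothing TensorFlow-specific):

# running_mean = (1 - exponential_avg_factor) * mean
#                + exponential_avg_factor * batch_mean
mean, batch_mean, factor = 10.0, 4.0, 0.25
running_mean = (1 - factor) * mean + factor * batch_mean
assert running_mean == 8.5  # 0.75 * 10.0 + 0.25 * 4.0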
@@ -1479,31 +1470,37 @@ def fused_batch_norm(
   x = ops.convert_to_tensor(x, name="input")
   scale = ops.convert_to_tensor(scale, name="scale")
   offset = ops.convert_to_tensor(offset, name="offset")
+  if is_training:
+    if (mean is not None) or (variance is not None):
+      raise ValueError("Both 'mean' and 'variance' must be None "
+                       "if is_training is True.")
   if mean is None:
-    mean = array_ops.zeros_like(scale)
+    mean = constant_op.constant([])
   if variance is None:
-    variance = array_ops.zeros_like(scale)
+    variance = constant_op.constant([])
 
   # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
   # prevent exception (see cudnn.h).
   min_epsilon = 1.001e-5
   epsilon = epsilon if epsilon > min_epsilon else min_epsilon
 
-  if compat.forward_compatible(2020, 2, 29):
-    y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
+  if compat.forward_compatible(2019, 6, 6):
+    y, batch_mean, batch_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
        x,
        scale,
        offset,
        mean,
        variance,
        epsilon=epsilon,
-        exponential_avg_factor=exponential_avg_factor,
        data_format=data_format,
        is_training=is_training,
        name=name)
-    return y, running_mean, running_var
+    return y, batch_mean, batch_var
 
+  if x.dtype == dtypes.float16 or x.dtype == dtypes.bfloat16:
+    fused_batch_norm_func = gen_nn_ops.fused_batch_norm_v2
   else:
-    y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
+    fused_batch_norm_func = gen_nn_ops._fused_batch_norm  # pylint: disable=protected-access
+  y, batch_mean, batch_var, _, _ = fused_batch_norm_func(
      x,
      scale,
      offset,
@@ -1513,7 +1510,7 @@ def fused_batch_norm(
      data_format=data_format,
      is_training=is_training,
      name=name)
-  return y, running_mean, running_var
+  return y, batch_mean, batch_var
 
 
 @tf_export(v1=["nn.batch_norm_with_global_normalization"])
@@ -210,7 +210,7 @@ tf_module {
  }
  member_method {
    name: "fused_batch_norm"
-    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\', \'exponential_avg_factor\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\', \'1.0\'], "
+    argspec: "args=[\'x\', \'scale\', \'offset\', \'mean\', \'variance\', \'epsilon\', \'data_format\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.001\', \'NHWC\', \'True\', \'None\'], "
  }
  member_method {
    name: "in_top_k"