diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc index 2d7bb65430e..0a37620700d 100644 --- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc @@ -76,11 +76,10 @@ Status FoldBatchNorms(const GraphDef& input_graph_def, int64 weights_cols; if (conv_node.op() == "Conv2D") { weights_cols = weights.shape().dim_size(3); - } - else if (conv_node.op() == "DepthwiseConv2dNative") { - weights_cols = weights.shape().dim_size(2) * weights.shape().dim_size(3); - } - else { + } else if (conv_node.op() == "DepthwiseConv2dNative") { + weights_cols = + weights.shape().dim_size(2) * weights.shape().dim_size(3); + } else { weights_cols = weights.shape().dim_size(1); } if ((mul_values.shape().dims() != 1) || @@ -96,7 +95,8 @@ Status FoldBatchNorms(const GraphDef& input_graph_def, auto scaled_weights_vector = scaled_weights.flat<float>(); for (int64 row = 0; row < weights_vector.dimension(0); ++row) { scaled_weights_vector(row) = - weights_vector(row) * mul_values.flat<float>()(row % weights_cols); + weights_vector(row) * + mul_values.flat<float>()(row % weights_cols); } // Construct the new nodes. 
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc index 2b5326799e6..885fbd59b77 100644 --- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc @@ -104,10 +104,10 @@ class FoldBatchNormsTest : public ::testing::Test { Output weights_op = Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); - Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op, weights_op, - {1, 1, 1, 1}, "VALID"); + Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op, + weights_op, {1, 1, 1, 1}, "VALID"); - Tensor mul_values_data(DT_FLOAT, TensorShape({4})); + Tensor mul_values_data(DT_FLOAT, TensorShape({4})); test::FillValues<float>(&mul_values_data, {2.0f, 3.0f, 4.0f, 5.0f}); Output mul_values_op = Const(root.WithOpName("mul_values"), Input::Initializer(mul_values_data)); @@ -136,7 +136,7 @@ class FoldBatchNormsTest : public ::testing::Test { for (const NodeDef& node : fused_graph_def.node()) { EXPECT_NE("Mul", node.op()); } - } + } void TestFoldBatchNormsConv2DShared() { auto root = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc index 413361b616c..8c67bd23b56 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc @@ -32,9 +32,9 @@ Status ErrorIfNotVector(const Tensor& input, const string& input_name, int expected_width) { if ((input.shape().dims() != 1) || (input.shape().dim_size(0) != expected_width)) { - return errors::InvalidArgument( - input_name, - " input to batch norm has bad shape: ", input.shape().DebugString()); + return errors::InvalidArgument(input_name, + " input to batch norm has bad shape: ", + input.shape().DebugString()); } return Status::OK(); } @@ -119,11 +119,9 @@ 
Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values, int64 weights_cols; if (conv_node.op() == "Conv2D") { weights_cols = weights.shape().dim_size(3); - } - else if (conv_node.op() == "DepthwiseConv2dNative") { + } else if (conv_node.op() == "DepthwiseConv2dNative") { weights_cols = weights.shape().dim_size(2) * weights.shape().dim_size(3); - } - else { + } else { weights_cols = weights.shape().dim_size(1); } CHECK_EQ(weights_cols, scale_values.size()); @@ -134,7 +132,7 @@ Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values, auto scaled_weights_vector = scaled_weights.flat<float>(); for (int64 row = 0; row < weights_vector.dimension(0); ++row) { scaled_weights_vector(row) = - weights_vector(row) * scale_values[row % weights_cols]; + weights_vector(row) * scale_values[row % weights_cols]; } // Figure out the remaining bias to add on. Tensor bias_offset(DT_FLOAT, {weights_cols}); @@ -193,7 +191,7 @@ Status FuseBatchNormWithConv(const NodeMatch& match, } Status FuseBatchNormWithBatchToSpace(const NodeMatch& match, - std::vector<NodeDef>* new_nodes) { + std::vector<NodeDef>* new_nodes) { // Calculate the scale and offset values to apply. 
std::vector<float> scale_values; std::vector<float> offset_values; @@ -208,9 +206,8 @@ Status FuseBatchNormWithBatchToSpace(const NodeMatch& match, const NodeDef& conv_node = conv_node_match.node; string biasadd_name = conv_node.name() + "/biasadd"; - TF_RETURN_IF_ERROR( - FuseScaleOffsetToConvWeights(scale_values, offset_values, conv_node_match, - biasadd_name , new_nodes)); + TF_RETURN_IF_ERROR(FuseScaleOffsetToConvWeights( + scale_values, offset_values, conv_node_match, biasadd_name, new_nodes)); NodeDef new_batch_to_space_node = batch_to_space_node; // reuse batch_norm node name diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc index 45637cf9d1d..925f37745c8 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -138,8 +138,8 @@ class FoldOldBatchNormsTest : public ::testing::Test { Output weights_op = Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); - Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), - input_op, weights_op, {1, 1, 1, 1}, "VALID"); + Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op, + weights_op, {1, 1, 1, 1}, "VALID"); Tensor mean_data(DT_FLOAT, TensorShape({4})); test::FillValues<float>(&mean_data, {10.0f, 20.0f, 30.0f, 40.0f}); @@ -164,7 +164,6 @@ class FoldOldBatchNormsTest : public ::testing::Test { GraphDef original_graph_def; TF_ASSERT_OK(root.ToGraphDef(&original_graph_def)); - NodeDef batch_norm_node; batch_norm_node.set_op("BatchNormWithGlobalNormalization"); batch_norm_node.set_name("output"); @@ -198,7 +197,7 @@ class FoldOldBatchNormsTest : public ::testing::Test { for (const NodeDef& node : fused_graph_def.node()) { EXPECT_NE("BatchNormWithGlobalNormalization", node.op()); } - } + } void TestFoldFusedBatchNorms() { auto root = tensorflow::Scope::NewRootScope(); @@ -294,8 +293,8 @@ class 
FoldOldBatchNormsTest : public ::testing::Test { Output weights_op = Const(root.WithOpName("weights_op"), Input::Initializer(weights_data)); - Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), - input_op, weights_op, {1, 1, 1, 1}, "VALID"); + Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op, + weights_op, {1, 1, 1, 1}, "VALID"); Tensor mean_data(DT_FLOAT, TensorShape({4})); test::FillValues<float>(&mean_data, {10.0f, 20.0f, 30.0f, 40.0f}); @@ -477,16 +476,17 @@ void TestFoldFusedBatchNormsWithBatchToSpace() { Tensor block_shape_data(DT_INT32, TensorShape({2})); test::FillValues<int32>(&block_shape_data, {1, 2}); - Output block_shape_op = - Const(root.WithOpName("block_shape_op"), Input::Initializer(block_shape_data)); + Output block_shape_op = Const(root.WithOpName("block_shape_op"), + Input::Initializer(block_shape_data)); Tensor crops_data(DT_INT32, TensorShape({2, 2})); test::FillValues<int32>(&crops_data, {0, 0, 0, 1}); Output crops_op = Const(root.WithOpName("crops_op"), Input::Initializer(crops_data)); - Output batch_to_space_op = BatchToSpaceND(root.WithOpName("batch_to_space_op"), - conv_op, block_shape_op, crops_data); + Output batch_to_space_op = + BatchToSpaceND(root.WithOpName("batch_to_space_op"), conv_op, + block_shape_op, crops_data); Tensor mean_data(DT_FLOAT, TensorShape({2})); test::FillValues<float>(&mean_data, {10.0f, 20.0f}); @@ -495,8 +495,8 @@ void TestFoldFusedBatchNormsWithBatchToSpace() { Tensor variance_data(DT_FLOAT, TensorShape({2})); test::FillValues<float>(&variance_data, {0.25f, 0.5f}); - Output variance_op = Const(root.WithOpName("variance_op"), - Input::Initializer(variance_data)); + Output variance_op = + Const(root.WithOpName("variance_op"), Input::Initializer(variance_data)); Tensor beta_data(DT_FLOAT, TensorShape({2})); test::FillValues<float>(&beta_data, {0.1f, 0.6f}); @@ -570,7 +570,8 @@ TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNormsAfterDepthwiseConv2dNative) { TestFoldOldBatchNormsAfterDepthwiseConv2dNative(); } 
-TEST_F(FoldOldBatchNormsTest, TestFoldFusedBatchNormsAfterDepthwiseConv2dNative) { +TEST_F(FoldOldBatchNormsTest, + TestFoldFusedBatchNormsAfterDepthwiseConv2dNative) { TestFoldFusedBatchNormsAfterDepthwiseConv2dNative(); }