Minor follow-up fixes for PR #37176.

PiperOrigin-RevId: 298389754
Change-Id: I7a32fea4fd8eb6d10e3c5dd32e341b1e9ad2f808
This commit is contained in:
Penporn Koanantakool 2020-03-02 10:41:52 -08:00 committed by TensorFlower Gardener
parent edcafbcefc
commit c877732daf
2 changed files with 11 additions and 5 deletions

View File

@ -881,6 +881,9 @@ class MklFusedBatchNormOp : public OpKernel {
auto batch_variance_data = batch_variance_tensor->flat<U>().data();
auto est_mean_data = est_mean_tensor.flat<U>().data();
auto est_variance_data = est_variance_tensor.flat<U>().data();
// TODO(intel-tf): Merge the `is_training && exponential_avg_factor == 1`
// case with the `else` (`!is_training`) case if possible.
if (is_training_) {
if (exponential_avg_factor_ == U(1.0)) {
for (int k = 0; k < depth_; k++) {

View File

@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
@ -222,7 +224,7 @@ class FusedBatchNormOpTest : public OpsTestBase {
}
};
TYPED_TEST_CASE_P(FusedBatchNormOpTest);
TYPED_TEST_SUITE_P(FusedBatchNormOpTest);
TYPED_TEST_P(FusedBatchNormOpTest, Training) {
const float exponential_avg_factor = 1.0;
@ -248,12 +250,13 @@ TYPED_TEST_P(FusedBatchNormOpTest, InferenceIgnoreAvgFactor) {
this->VerifyFusedBatchNorm(exponential_avg_factor, is_training);
}
REGISTER_TYPED_TEST_CASE_P(FusedBatchNormOpTest, Training, TrainingRunningMean,
Inference, InferenceIgnoreAvgFactor);
REGISTER_TYPED_TEST_SUITE_P(FusedBatchNormOpTest, Training, TrainingRunningMean,
Inference, InferenceIgnoreAvgFactor);
using FusedBatchNormDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_CASE_P(Test, FusedBatchNormOpTest,
FusedBatchNormDataTypes);
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedBatchNormOpTest,
FusedBatchNormDataTypes);
} // namespace tensorflow
#endif // INTEL_MKL