Merge pull request #37901 from Intel-tensorflow:fusedBN_fix

PiperOrigin-RevId: 304433358
Change-Id: I398a935331f46d0d9c39a9a249a2ce4a93822bad
This commit is contained in:
TensorFlower Gardener 2020-04-02 10:58:12 -07:00
commit f2bde78c4c

View File

@ -1008,7 +1008,7 @@ class MklFusedBatchNormOp : public OpKernel {
tf_shape_scale, mkl_shape_saved_mean); tf_shape_scale, mkl_shape_saved_mean);
DCHECK(*saved_mean_tensor); DCHECK(*saved_mean_tensor);
// Set NAN mean value in case of empty input tensor // Set 0 mean value in case of empty input tensor
auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data(); auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data();
std::fill_n(saved_mean_data, num_elements, static_cast<U>(0)); std::fill_n(saved_mean_data, num_elements, static_cast<U>(0));
@ -1019,7 +1019,7 @@ class MklFusedBatchNormOp : public OpKernel {
mkl_shape_saved_variance); mkl_shape_saved_variance);
DCHECK(*saved_variance_tensor); DCHECK(*saved_variance_tensor);
// Set NAN variance value in case of empty input tensor // Set 0 variance value in case of empty input tensor
auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data(); auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data();
std::fill_n(saved_variance_data, num_elements, static_cast<U>(0)); std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));
@ -1346,16 +1346,12 @@ class MklFusedBatchNormGradOp : public OpKernel {
mkl_shape_p.SetMklTensor(false); mkl_shape_p.SetMklTensor(false);
AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}), AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
mkl_shape_p); mkl_shape_p);
#ifndef ENABLE_MKLDNN_V1
std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(), std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
static_cast<U>(0)); static_cast<U>(0));
#endif // !ENABLE_MKLDNN_V1
AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}), AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
mkl_shape_p); mkl_shape_p);
#ifndef ENABLE_MKLDNN_V1
std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(), std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
static_cast<U>(0)); static_cast<U>(0));
#endif // !ENABLE_MKLDNN_V1
} }
memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); } memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }