fix an initialization problem in MklFusedBatchNorm
This commit is contained in:
parent
c5ff38aebf
commit
920f7e6264
@ -1008,7 +1008,7 @@ class MklFusedBatchNormOp : public OpKernel {
|
|||||||
tf_shape_scale, mkl_shape_saved_mean);
|
tf_shape_scale, mkl_shape_saved_mean);
|
||||||
DCHECK(*saved_mean_tensor);
|
DCHECK(*saved_mean_tensor);
|
||||||
|
|
||||||
// Set NAN mean value in case of empty input tensor
|
// Set 0 mean value in case of empty input tensor
|
||||||
auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data();
|
auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data();
|
||||||
std::fill_n(saved_mean_data, num_elements, static_cast<U>(0));
|
std::fill_n(saved_mean_data, num_elements, static_cast<U>(0));
|
||||||
|
|
||||||
@ -1019,7 +1019,7 @@ class MklFusedBatchNormOp : public OpKernel {
|
|||||||
mkl_shape_saved_variance);
|
mkl_shape_saved_variance);
|
||||||
DCHECK(*saved_variance_tensor);
|
DCHECK(*saved_variance_tensor);
|
||||||
|
|
||||||
// Set NAN variance value in case of empty input tensor
|
// Set 0 variance value in case of empty input tensor
|
||||||
auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data();
|
auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data();
|
||||||
std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));
|
std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));
|
||||||
|
|
||||||
@ -1346,16 +1346,12 @@ class MklFusedBatchNormGradOp : public OpKernel {
|
|||||||
mkl_shape_p.SetMklTensor(false);
|
mkl_shape_p.SetMklTensor(false);
|
||||||
AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
|
AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
|
||||||
mkl_shape_p);
|
mkl_shape_p);
|
||||||
#ifndef ENABLE_MKLDNN_V1
|
|
||||||
std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
|
std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
|
||||||
static_cast<U>(0));
|
static_cast<U>(0));
|
||||||
#endif // !ENABLE_MKLDNN_V1
|
|
||||||
AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
|
AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
|
||||||
mkl_shape_p);
|
mkl_shape_p);
|
||||||
#ifndef ENABLE_MKLDNN_V1
|
|
||||||
std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
|
std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
|
||||||
static_cast<U>(0));
|
static_cast<U>(0));
|
||||||
#endif // !ENABLE_MKLDNN_V1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
|
memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
|
||||||
|
Loading…
Reference in New Issue
Block a user