minor change per Penporn's code review suggestion

This commit is contained in:
xiaohong1031 2020-11-24 21:08:11 -08:00
parent f681dfe5d7
commit 3b7fd3013f

View File

@@ -581,7 +581,7 @@ class MklConvOp : public OpKernel {
: TFDataFormatToMklDnn3DDataFormat(data_format_);
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
// NOTE: 'mkl_fmt_tag` will be `format_tag::undef` for ReLU
// NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
errors::InvalidArgument("Invalid data format"));
@@ -873,8 +873,7 @@ class MklConvOp : public OpKernel {
if (!native_format && add_mkl_shape == *output_mkl_shape &&
ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
kOutputIndex_Dst, output_tensor,
add_mkl_shape, false) &&
!native_format) {
add_mkl_shape, false)) {
return;
} else {
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
@@ -1127,7 +1126,7 @@ class MklConvOp : public OpKernel {
const Tensor& cached_filter_md =
*cached_filter_md_ptensor_.AccessTensor(context);
// Check if the memory descriptor of the cached weights is same as
// Check if the memory descriptor of the cached weights is the same as
// filter_md. If so, we can use the cached weights; otherwise
// return nullptr.
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
@@ -1701,17 +1700,18 @@ class MklQuantizedConv2DSumReluOp
// if summand_type is also DT_QUINT8 as the scale_output,
// the scaling factor of 255.0f cancels each other and thus is avoided.
// If it is not then it is DT_INT8 and is scaled appropriately.
if (summand_type == DT_QUINT8)
if (summand_type == DT_QUINT8) {
params.post_op_params.push_back({"sum",
mkldnn::algorithm::undef,
{scale_summand / scale_output},
""});
else
} else {
params.post_op_params.push_back(
{"sum",
mkldnn::algorithm::undef,
{255.0f * scale_summand / (scale_output * 127.0f)},
""});
}
} else {
params.post_op_params.push_back(
{"sum", mkldnn::algorithm::undef, {1.0}, ""});