minor change per Penporn's code review suggestion
This commit is contained in:
parent
f681dfe5d7
commit
3b7fd3013f
@ -581,7 +581,7 @@ class MklConvOp : public OpKernel {
|
||||
: TFDataFormatToMklDnn3DDataFormat(data_format_);
|
||||
|
||||
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
|
||||
// NOTE: 'mkl_fmt_tag` will be `format_tag::undef` for ReLU
|
||||
// NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
|
||||
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
|
||||
@ -873,8 +873,7 @@ class MklConvOp : public OpKernel {
|
||||
if (!native_format && add_mkl_shape == *output_mkl_shape &&
|
||||
ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
|
||||
kOutputIndex_Dst, output_tensor,
|
||||
add_mkl_shape, false) &&
|
||||
!native_format) {
|
||||
add_mkl_shape, false)) {
|
||||
return;
|
||||
} else {
|
||||
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
|
||||
@ -1127,7 +1126,7 @@ class MklConvOp : public OpKernel {
|
||||
const Tensor& cached_filter_md =
|
||||
*cached_filter_md_ptensor_.AccessTensor(context);
|
||||
|
||||
// Check if the memory descriptor of the cached weights is same as
|
||||
// Check if the memory descriptor of the cached weights is the same as
|
||||
// filter_md. If so, we can use the cached weights; otherwise
|
||||
// return nullptr.
|
||||
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
|
||||
@ -1701,17 +1700,18 @@ class MklQuantizedConv2DSumReluOp
|
||||
// if summand_type is also DT_QUINT8 as the scale_output,
|
||||
// the scaling factor of 255.0f cancels each other and thus is avoided.
|
||||
// If it is not then it is DT_INT8 and is scaled appropriately.
|
||||
if (summand_type == DT_QUINT8)
|
||||
if (summand_type == DT_QUINT8) {
|
||||
params.post_op_params.push_back({"sum",
|
||||
mkldnn::algorithm::undef,
|
||||
{scale_summand / scale_output},
|
||||
""});
|
||||
else
|
||||
} else {
|
||||
params.post_op_params.push_back(
|
||||
{"sum",
|
||||
mkldnn::algorithm::undef,
|
||||
{255.0f * scale_summand / (scale_output * 127.0f)},
|
||||
""});
|
||||
}
|
||||
} else {
|
||||
params.post_op_params.push_back(
|
||||
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
|
||||
|
Loading…
x
Reference in New Issue
Block a user