minor clang fix

This commit is contained in:
xiaohong1031 2020-11-12 23:19:54 -08:00
parent 3f3b46333b
commit 264c259dca

View File

@ -291,10 +291,9 @@ class MklConvFwdPrimitive : public MklPrimitive {
// Create convolution primitive and add it to net // Create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) { if (!convFwdDims.bias_dims.empty()) {
context_.bias_mem.reset( context_.bias_mem.reset(new memory(
new memory({{convFwdDims.bias_dims}, MklDnnType<Tbias>(), {{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
memory::format_tag::x}, cpu_engine_, DummyData));
cpu_engine_, DummyData));
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd)); context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back( context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem}, {{MKLDNN_ARG_SRC, *context_.src_mem},
@ -465,15 +464,17 @@ class MklConvOp : public OpKernel {
OP_REQUIRES(context, dilations_.size() == 5, OP_REQUIRES(context, dilations_.size() == 5,
errors::InvalidArgument("Dilation rates field must " errors::InvalidArgument("Dilation rates field must "
"specify 5 dimensions")); "specify 5 dimensions"));
OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 && OP_REQUIRES(context,
GetTensorDim(dilations_, data_format_, 'C') == 1), (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
GetTensorDim(dilations_, data_format_, 'C') == 1),
errors::InvalidArgument( errors::InvalidArgument(
"Current implementation does not yet support " "Current implementation does not yet support "
"dilations rates in the batch and depth dimensions.")); "dilations rates in the batch and depth dimensions."));
OP_REQUIRES( OP_REQUIRES(
context, (GetTensorDim(dilations_, data_format_, '0') > 0 && context,
GetTensorDim(dilations_, data_format_, '1') > 0 && (GetTensorDim(dilations_, data_format_, '0') > 0 &&
GetTensorDim(dilations_, data_format_, '2') > 0), GetTensorDim(dilations_, data_format_, '1') > 0 &&
GetTensorDim(dilations_, data_format_, '2') > 0),
errors::InvalidArgument("Dilated rates should be larger than 0.")); errors::InvalidArgument("Dilated rates should be larger than 0."));
} }
} }
@ -868,14 +869,14 @@ class MklConvOp : public OpKernel {
kInputIndex_Add, kOutputIndex_Dst, output_tf_shape, kInputIndex_Add, kOutputIndex_Dst, output_tf_shape,
output_tensor)) { output_tensor)) {
return; return;
} }
} }
// Check if reorder is needed // Check if reorder is needed
if (add_mkl_shape == *output_mkl_shape && if (add_mkl_shape == *output_mkl_shape &&
ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add, ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
kOutputIndex_Dst, output_tensor, kOutputIndex_Dst, output_tensor,
add_mkl_shape, false) && add_mkl_shape, false) &&
!native_format) { !native_format) {
return; return;
} else { } else {
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
@ -1535,7 +1536,7 @@ class MklQuantizedConv2DOp
} }
auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())}, auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
MklDnnType<Tbias>(), memory::format_tag::x); MklDnnType<Tbias>(), memory::format_tag::x);
void* bias_buf = static_cast<void*>( void* bias_buf = static_cast<void*>(
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data())); const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
if (!input_bias_) { if (!input_bias_) {
@ -1749,11 +1750,12 @@ class MklQuantizedConv2DSumReluOp
summand_mkl_shape.SetElemType(MklDnnType<Toutput>()); summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
} }
// TODO(intel-tf): Support cases when summand cannot be forwarded. // TODO(intel-tf): Support cases when summand cannot be forwarded.
OP_REQUIRES(context, ForwardMklTensorInToOutWithMklShape( OP_REQUIRES(
context, summand_idx, 0, output_tensor, context,
summand_mkl_shape, false), ForwardMklTensorInToOutWithMklShape(
errors::InvalidArgument( context, summand_idx, 0, output_tensor, summand_mkl_shape, false),
"Summand cannot be forwarded in the current fusion.")); errors::InvalidArgument(
"Summand cannot be forwarded in the current fusion."));
return; return;
} }
MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32, MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,