minor clang fix
This commit is contained in:
parent
3f3b46333b
commit
264c259dca
@ -291,10 +291,9 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
|
|
||||||
// Create convolution primitive and add it to net
|
// Create convolution primitive and add it to net
|
||||||
if (!convFwdDims.bias_dims.empty()) {
|
if (!convFwdDims.bias_dims.empty()) {
|
||||||
context_.bias_mem.reset(
|
context_.bias_mem.reset(new memory(
|
||||||
new memory({{convFwdDims.bias_dims}, MklDnnType<Tbias>(),
|
{{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
|
||||||
memory::format_tag::x},
|
cpu_engine_, DummyData));
|
||||||
cpu_engine_, DummyData));
|
|
||||||
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
|
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
|
||||||
context_.fwd_primitives_args.push_back(
|
context_.fwd_primitives_args.push_back(
|
||||||
{{MKLDNN_ARG_SRC, *context_.src_mem},
|
{{MKLDNN_ARG_SRC, *context_.src_mem},
|
||||||
@ -465,15 +464,17 @@ class MklConvOp : public OpKernel {
|
|||||||
OP_REQUIRES(context, dilations_.size() == 5,
|
OP_REQUIRES(context, dilations_.size() == 5,
|
||||||
errors::InvalidArgument("Dilation rates field must "
|
errors::InvalidArgument("Dilation rates field must "
|
||||||
"specify 5 dimensions"));
|
"specify 5 dimensions"));
|
||||||
OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
|
OP_REQUIRES(context,
|
||||||
GetTensorDim(dilations_, data_format_, 'C') == 1),
|
(GetTensorDim(dilations_, data_format_, 'N') == 1 &&
|
||||||
|
GetTensorDim(dilations_, data_format_, 'C') == 1),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Current implementation does not yet support "
|
"Current implementation does not yet support "
|
||||||
"dilations rates in the batch and depth dimensions."));
|
"dilations rates in the batch and depth dimensions."));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, (GetTensorDim(dilations_, data_format_, '0') > 0 &&
|
context,
|
||||||
GetTensorDim(dilations_, data_format_, '1') > 0 &&
|
(GetTensorDim(dilations_, data_format_, '0') > 0 &&
|
||||||
GetTensorDim(dilations_, data_format_, '2') > 0),
|
GetTensorDim(dilations_, data_format_, '1') > 0 &&
|
||||||
|
GetTensorDim(dilations_, data_format_, '2') > 0),
|
||||||
errors::InvalidArgument("Dilated rates should be larger than 0."));
|
errors::InvalidArgument("Dilated rates should be larger than 0."));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -868,14 +869,14 @@ class MklConvOp : public OpKernel {
|
|||||||
kInputIndex_Add, kOutputIndex_Dst, output_tf_shape,
|
kInputIndex_Add, kOutputIndex_Dst, output_tf_shape,
|
||||||
output_tensor)) {
|
output_tensor)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Check if reorder is needed
|
// Check if reorder is needed
|
||||||
if (add_mkl_shape == *output_mkl_shape &&
|
if (add_mkl_shape == *output_mkl_shape &&
|
||||||
ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
|
ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
|
||||||
kOutputIndex_Dst, output_tensor,
|
kOutputIndex_Dst, output_tensor,
|
||||||
add_mkl_shape, false) &&
|
add_mkl_shape, false) &&
|
||||||
!native_format) {
|
!native_format) {
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
|
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
|
||||||
@ -1535,7 +1536,7 @@ class MklQuantizedConv2DOp
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
|
auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
|
||||||
MklDnnType<Tbias>(), memory::format_tag::x);
|
MklDnnType<Tbias>(), memory::format_tag::x);
|
||||||
void* bias_buf = static_cast<void*>(
|
void* bias_buf = static_cast<void*>(
|
||||||
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
|
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
|
||||||
if (!input_bias_) {
|
if (!input_bias_) {
|
||||||
@ -1749,11 +1750,12 @@ class MklQuantizedConv2DSumReluOp
|
|||||||
summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
|
summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
|
||||||
}
|
}
|
||||||
// TODO(intel-tf): Support cases when summand cannot be forwarded.
|
// TODO(intel-tf): Support cases when summand cannot be forwarded.
|
||||||
OP_REQUIRES(context, ForwardMklTensorInToOutWithMklShape(
|
OP_REQUIRES(
|
||||||
context, summand_idx, 0, output_tensor,
|
context,
|
||||||
summand_mkl_shape, false),
|
ForwardMklTensorInToOutWithMklShape(
|
||||||
errors::InvalidArgument(
|
context, summand_idx, 0, output_tensor, summand_mkl_shape, false),
|
||||||
"Summand cannot be forwarded in the current fusion."));
|
errors::InvalidArgument(
|
||||||
|
"Summand cannot be forwarded in the current fusion."));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
|
MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user