diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc index c6b783f2ad6..e4c21b2e6a9 100644 --- a/tensorflow/core/kernels/mkl_aggregate_ops.cc +++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc @@ -155,10 +155,10 @@ class MklAddNOp : public OpKernel { auto cpu_engine = engine(engine::cpu, 0); std::vector coeff(num_inputs, 1.0); std::vector srcs_pd; + std::vector> srcs(num_inputs, MklDnnData(&cpu_engine)); std::vector inputs; MklDnnData dst(&cpu_engine); - MklDnnData src(&cpu_engine); bool has_mkl_input = false; int mkl_input_index = FindMKLInputIndex(ctx); memory::format mkl_data_format; @@ -178,7 +178,6 @@ class MklAddNOp : public OpKernel { MklDnnShape src_mkl_shape; GetMklShape(ctx, src_idx, &src_mkl_shape); memory::desc md({}, memory::data_undef, memory::format_undef); - src = MklDnnData(&cpu_engine); const Tensor& src_tensor = MklGetInput(ctx, src_idx); if (src_mkl_shape.IsMklTensor()) { @@ -203,8 +202,8 @@ class MklAddNOp : public OpKernel { } } srcs_pd.push_back(memory::primitive_desc(md, cpu_engine)); - src.SetUsrMem(md, &src_tensor); - inputs.push_back(src.GetOpMem()); + srcs[src_idx].SetUsrMem(md, &src_tensor); + inputs.push_back(srcs[src_idx].GetOpMem()); } auto sum_pd = sum::primitive_desc(coeff, srcs_pd); diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 4a7f5140ffd..f5a037f8f29 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -1620,10 +1620,23 @@ class MklQuantizedConv2DOp Tbias, x, this->cpu_engine_); void* bias_buf = static_cast( const_cast(bias_tensor.flat().data())); - input_bias_ = - new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf); - scaled_bias_ = new MEMORY_CONSTRUCTOR_WITHOUT_DATA( - conv_fwd_pd->PRIMITIVE_DESC_BIAS, this->cpu_engine_); + if (!input_bias_) { + input_bias_ = + new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf); + } else { + input_bias_->set_data_handle(bias_buf); + } + + if (!scaled_bias_buf_) + AllocTmpBuffer(context, &scaled_bias_tensor_, + conv_fwd_pd->bias_primitive_desc(), + &scaled_bias_buf_); + if (!scaled_bias_) { + scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, + scaled_bias_buf_); + } else { + scaled_bias_->set_data_handle(scaled_bias_buf_); + } auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR( input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_, bias_attr); @@ -1646,6 +1659,9 @@ class MklQuantizedConv2DOp memory* input_bias_ = nullptr; memory* scaled_bias_ = nullptr; + Tensor scaled_bias_tensor_; + void* scaled_bias_buf_ = nullptr; + private: std::vector scales_; mutex bias_cache_mu_;