Merge pull request #34052 from Intel-tensorflow:memory-fix

PiperOrigin-RevId: 279039959
Change-Id: Ic4a51086582cb18fe555079d5517887a5590eabe
This commit is contained in:
TensorFlower Gardener 2019-11-07 02:15:25 -08:00
commit 6eac303dd8
2 changed files with 23 additions and 8 deletions

View File

@ -155,10 +155,10 @@ class MklAddNOp : public OpKernel {
auto cpu_engine = engine(engine::cpu, 0); auto cpu_engine = engine(engine::cpu, 0);
std::vector<float> coeff(num_inputs, 1.0); std::vector<float> coeff(num_inputs, 1.0);
std::vector<memory::primitive_desc> srcs_pd; std::vector<memory::primitive_desc> srcs_pd;
std::vector<MklDnnData<T>> srcs(num_inputs, MklDnnData<T>(&cpu_engine));
std::vector<primitive::at> inputs; std::vector<primitive::at> inputs;
MklDnnData<T> dst(&cpu_engine); MklDnnData<T> dst(&cpu_engine);
MklDnnData<T> src(&cpu_engine);
bool has_mkl_input = false; bool has_mkl_input = false;
int mkl_input_index = FindMKLInputIndex(ctx); int mkl_input_index = FindMKLInputIndex(ctx);
memory::format mkl_data_format; memory::format mkl_data_format;
@ -178,7 +178,6 @@ class MklAddNOp : public OpKernel {
MklDnnShape src_mkl_shape; MklDnnShape src_mkl_shape;
GetMklShape(ctx, src_idx, &src_mkl_shape); GetMklShape(ctx, src_idx, &src_mkl_shape);
memory::desc md({}, memory::data_undef, memory::format_undef); memory::desc md({}, memory::data_undef, memory::format_undef);
src = MklDnnData<T>(&cpu_engine);
const Tensor& src_tensor = MklGetInput(ctx, src_idx); const Tensor& src_tensor = MklGetInput(ctx, src_idx);
if (src_mkl_shape.IsMklTensor()) { if (src_mkl_shape.IsMklTensor()) {
@ -203,8 +202,8 @@ class MklAddNOp : public OpKernel {
} }
} }
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine)); srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
src.SetUsrMem(md, &src_tensor); srcs[src_idx].SetUsrMem(md, &src_tensor);
inputs.push_back(src.GetOpMem()); inputs.push_back(srcs[src_idx].GetOpMem());
} }
auto sum_pd = sum::primitive_desc(coeff, srcs_pd); auto sum_pd = sum::primitive_desc(coeff, srcs_pd);

View File

@ -1620,10 +1620,23 @@ class MklQuantizedConv2DOp
Tbias, x, this->cpu_engine_); Tbias, x, this->cpu_engine_);
void* bias_buf = static_cast<void*>( void* bias_buf = static_cast<void*>(
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data())); const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
if (!input_bias_) {
input_bias_ = input_bias_ =
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf); new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
scaled_bias_ = new MEMORY_CONSTRUCTOR_WITHOUT_DATA( } else {
conv_fwd_pd->PRIMITIVE_DESC_BIAS, this->cpu_engine_); input_bias_->set_data_handle(bias_buf);
}
if (!scaled_bias_buf_)
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
conv_fwd_pd->bias_primitive_desc(),
&scaled_bias_buf_);
if (!scaled_bias_) {
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,
scaled_bias_buf_);
} else {
scaled_bias_->set_data_handle(scaled_bias_buf_);
}
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR( auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_, input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
bias_attr); bias_attr);
@ -1646,6 +1659,9 @@ class MklQuantizedConv2DOp
memory* input_bias_ = nullptr; memory* input_bias_ = nullptr;
memory* scaled_bias_ = nullptr; memory* scaled_bias_ = nullptr;
Tensor scaled_bias_tensor_;
void* scaled_bias_buf_ = nullptr;
private: private:
std::vector<float> scales_; std::vector<float> scales_;
mutex bias_cache_mu_; mutex bias_cache_mu_;