Merge pull request #34052 from Intel-tensorflow:memory-fix
PiperOrigin-RevId: 279039959 Change-Id: Ic4a51086582cb18fe555079d5517887a5590eabe
This commit is contained in:
commit
6eac303dd8
@ -155,10 +155,10 @@ class MklAddNOp : public OpKernel {
|
|||||||
auto cpu_engine = engine(engine::cpu, 0);
|
auto cpu_engine = engine(engine::cpu, 0);
|
||||||
std::vector<float> coeff(num_inputs, 1.0);
|
std::vector<float> coeff(num_inputs, 1.0);
|
||||||
std::vector<memory::primitive_desc> srcs_pd;
|
std::vector<memory::primitive_desc> srcs_pd;
|
||||||
|
std::vector<MklDnnData<T>> srcs(num_inputs, MklDnnData<T>(&cpu_engine));
|
||||||
std::vector<primitive::at> inputs;
|
std::vector<primitive::at> inputs;
|
||||||
|
|
||||||
MklDnnData<T> dst(&cpu_engine);
|
MklDnnData<T> dst(&cpu_engine);
|
||||||
MklDnnData<T> src(&cpu_engine);
|
|
||||||
bool has_mkl_input = false;
|
bool has_mkl_input = false;
|
||||||
int mkl_input_index = FindMKLInputIndex(ctx);
|
int mkl_input_index = FindMKLInputIndex(ctx);
|
||||||
memory::format mkl_data_format;
|
memory::format mkl_data_format;
|
||||||
@ -178,7 +178,6 @@ class MklAddNOp : public OpKernel {
|
|||||||
MklDnnShape src_mkl_shape;
|
MklDnnShape src_mkl_shape;
|
||||||
GetMklShape(ctx, src_idx, &src_mkl_shape);
|
GetMklShape(ctx, src_idx, &src_mkl_shape);
|
||||||
memory::desc md({}, memory::data_undef, memory::format_undef);
|
memory::desc md({}, memory::data_undef, memory::format_undef);
|
||||||
src = MklDnnData<T>(&cpu_engine);
|
|
||||||
const Tensor& src_tensor = MklGetInput(ctx, src_idx);
|
const Tensor& src_tensor = MklGetInput(ctx, src_idx);
|
||||||
|
|
||||||
if (src_mkl_shape.IsMklTensor()) {
|
if (src_mkl_shape.IsMklTensor()) {
|
||||||
@ -203,8 +202,8 @@ class MklAddNOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
|
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
|
||||||
src.SetUsrMem(md, &src_tensor);
|
srcs[src_idx].SetUsrMem(md, &src_tensor);
|
||||||
inputs.push_back(src.GetOpMem());
|
inputs.push_back(srcs[src_idx].GetOpMem());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto sum_pd = sum::primitive_desc(coeff, srcs_pd);
|
auto sum_pd = sum::primitive_desc(coeff, srcs_pd);
|
||||||
|
@ -1620,10 +1620,23 @@ class MklQuantizedConv2DOp
|
|||||||
Tbias, x, this->cpu_engine_);
|
Tbias, x, this->cpu_engine_);
|
||||||
void* bias_buf = static_cast<void*>(
|
void* bias_buf = static_cast<void*>(
|
||||||
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
|
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
|
||||||
|
if (!input_bias_) {
|
||||||
input_bias_ =
|
input_bias_ =
|
||||||
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
|
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
|
||||||
scaled_bias_ = new MEMORY_CONSTRUCTOR_WITHOUT_DATA(
|
} else {
|
||||||
conv_fwd_pd->PRIMITIVE_DESC_BIAS, this->cpu_engine_);
|
input_bias_->set_data_handle(bias_buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!scaled_bias_buf_)
|
||||||
|
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
|
||||||
|
conv_fwd_pd->bias_primitive_desc(),
|
||||||
|
&scaled_bias_buf_);
|
||||||
|
if (!scaled_bias_) {
|
||||||
|
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,
|
||||||
|
scaled_bias_buf_);
|
||||||
|
} else {
|
||||||
|
scaled_bias_->set_data_handle(scaled_bias_buf_);
|
||||||
|
}
|
||||||
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
||||||
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
|
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
|
||||||
bias_attr);
|
bias_attr);
|
||||||
@ -1646,6 +1659,9 @@ class MklQuantizedConv2DOp
|
|||||||
memory* input_bias_ = nullptr;
|
memory* input_bias_ = nullptr;
|
||||||
memory* scaled_bias_ = nullptr;
|
memory* scaled_bias_ = nullptr;
|
||||||
|
|
||||||
|
Tensor scaled_bias_tensor_;
|
||||||
|
void* scaled_bias_buf_ = nullptr;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<float> scales_;
|
std::vector<float> scales_;
|
||||||
mutex bias_cache_mu_;
|
mutex bias_cache_mu_;
|
||||||
|
Loading…
Reference in New Issue
Block a user