Merge pull request #42849 from Intel-tensorflow:dnn0x_clean_quantize
PiperOrigin-RevId: 331178405 Change-Id: Id163829724709e5e9aa6eead0625287b40079860
This commit is contained in:
commit
7d35f9e0ee
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/meta_support.h"
|
||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
using mkldnn::primitive_attr;
|
||||
@ -51,7 +50,7 @@ class MklDequantizeOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
try {
|
||||
// Using CPU device
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
auto cpu_engine = engine(engine::kind::cpu, 0);
|
||||
|
||||
// Get the inputs
|
||||
const Tensor& src_tensor = MklGetInput(ctx, kSrcIndex);
|
||||
@ -82,10 +81,10 @@ class MklDequantizeOp : public OpKernel {
|
||||
// construct input TF layout. For TF layout, although input shape
|
||||
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
|
||||
// layout
|
||||
auto src_md =
|
||||
src_mkl_shape.IsMklTensor()
|
||||
? src_mkl_shape.GetMklLayout()
|
||||
: memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);
|
||||
auto src_md = src_mkl_shape.IsMklTensor()
|
||||
? src_mkl_shape.GetMklLayout()
|
||||
: memory::desc(src_dims, MklDnnType<T>(),
|
||||
memory::format_tag::nhwc);
|
||||
|
||||
src.SetUsrMem(src_md, &src_tensor);
|
||||
src.SetUsrMemDataHandle(&src_tensor, reorder_stream);
|
||||
@ -93,14 +92,6 @@ class MklDequantizeOp : public OpKernel {
|
||||
Tensor* output_tensor = nullptr;
|
||||
MklDnnShape output_mkl_shape;
|
||||
TensorShape output_tf_shape;
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
memory::desc dst_md =
|
||||
src_mkl_shape.IsMklTensor()
|
||||
? memory::desc(src_dims, MklDnnType<float>(),
|
||||
static_cast<MEMORY_FORMAT>(src_md.data.format))
|
||||
: memory::desc(src_dims, MklDnnType<float>(),
|
||||
MEMORY_FORMAT::nhwc);
|
||||
#else
|
||||
memory::desc dst_md = memory::desc();
|
||||
if (src_mkl_shape.IsMklTensor()) {
|
||||
dst_md = memory::desc(src_mkl_shape.GetMklLayout().data);
|
||||
@ -108,10 +99,9 @@ class MklDequantizeOp : public OpKernel {
|
||||
// same .data field but different type.
|
||||
dst_md.data.data_type = memory::convert_to_c(MklDnnType<float>());
|
||||
} else {
|
||||
dst_md =
|
||||
memory::desc(src_dims, MklDnnType<float>(), MEMORY_FORMAT::nhwc);
|
||||
dst_md = memory::desc(src_dims, MklDnnType<float>(),
|
||||
memory::format_tag::nhwc);
|
||||
}
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// If input is MKL shape, output is also MKL shape.
|
||||
// If input is TF shape, output is also TF shape.
|
||||
@ -122,8 +112,7 @@ class MklDequantizeOp : public OpKernel {
|
||||
output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(),
|
||||
src_mkl_shape.GetSizesAsMklDnnDims(),
|
||||
src_mkl_shape.GetTfDataFormat());
|
||||
output_tf_shape.AddDim(GET_MEMORY_SIZE_FROM_MD(dst_md, cpu_engine) /
|
||||
sizeof(float));
|
||||
output_tf_shape.AddDim(dst_md.get_size() / sizeof(float));
|
||||
} else {
|
||||
output_mkl_shape.SetMklTensor(false);
|
||||
output_tf_shape = MklDnnDimsToTFShape(output_dims);
|
||||
@ -155,29 +144,17 @@ class MklDequantizeOp : public OpKernel {
|
||||
scales.push_back(scale_factor);
|
||||
primitive_attr attr;
|
||||
attr.set_output_scales(0, scales);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
// MKL-DNN 1.0 does not provide set_int_output_round_mode() API.
|
||||
// Also it does not define round_nearest (enum).
|
||||
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
std::vector<primitive> net;
|
||||
|
||||
// Create reorder primitive and then execute.
|
||||
auto reorder_pd = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
||||
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(src.GetUsrMem()),
|
||||
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(dst.GetUsrMem()), cpu_engine,
|
||||
attr);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto reorder_pd =
|
||||
ReorderPd(cpu_engine, src.GetUsrMem()->get_desc(), cpu_engine,
|
||||
dst.GetUsrMem()->get_desc(), attr);
|
||||
net.push_back(reorder(reorder_pd));
|
||||
std::vector<std::unordered_map<int, memory>> reorder_net_args;
|
||||
reorder_net_args.push_back({{MKLDNN_ARG_FROM, *src.GetUsrMem()},
|
||||
{ MKLDNN_ARG_TO,
|
||||
*dst.GetUsrMem() }});
|
||||
{MKLDNN_ARG_TO, *dst.GetUsrMem()}});
|
||||
execute_primitives(net, reorder_stream, reorder_net_args);
|
||||
#else
|
||||
net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem()));
|
||||
reorder_stream->submit(net);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
", message: " + string(e.message) + ", in file " +
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/mkl_graph_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
using mkldnn::primitive_attr;
|
||||
@ -77,7 +76,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklReorderWithScalePrimitive(
|
||||
const MklReorderWithScaleFwdParams& fwdParams)
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
: MklPrimitive(engine(engine::kind::cpu, 0)) {
|
||||
// Create reorder primitive
|
||||
Setup(fwdParams);
|
||||
}
|
||||
@ -95,11 +94,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
|
||||
context_.src_mem->set_data_handle(src_data);
|
||||
context_.dst_mem->set_data_handle(dst_data);
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
reorder_stream->submit(context_.net);
|
||||
#else
|
||||
context_.reorder_prim->execute(*reorder_stream, context_.prim_args);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
// After execution, set data handle back.
|
||||
context_.src_mem->set_data_handle(DummyData);
|
||||
context_.dst_mem->set_data_handle(DummyData);
|
||||
@ -119,11 +114,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
|
||||
// Stream and primitive vector
|
||||
std::shared_ptr<mkldnn::stream> reorder_stream;
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
std::vector<mkldnn::primitive> net;
|
||||
#else
|
||||
std::unordered_map<int, mkldnn::memory> prim_args;
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
ReorderContext()
|
||||
: src_mem(nullptr),
|
||||
@ -135,10 +126,10 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
|
||||
// Reorder primitive setup
|
||||
void Setup(const MklReorderWithScaleFwdParams& fwdParams) {
|
||||
// Create memory descriptors for reorder data with specified format
|
||||
context_.src_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD(
|
||||
fwdParams.src_md, cpu_engine_, DummyData));
|
||||
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD(
|
||||
fwdParams.dst_md, cpu_engine_, DummyData));
|
||||
context_.src_mem.reset(
|
||||
new memory(fwdParams.src_md, cpu_engine_, DummyData));
|
||||
context_.dst_mem.reset(
|
||||
new memory(fwdParams.dst_md, cpu_engine_, DummyData));
|
||||
|
||||
// Check if there is any fusion as post-ops
|
||||
auto const& post_op_params = fwdParams.post_op_params;
|
||||
@ -150,21 +141,14 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
|
||||
scales.push_back(post_op_params.param[0]);
|
||||
post_ops_attr.set_output_scales(0, scales);
|
||||
|
||||
context_.reorder_pd.reset(new REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
||||
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.src_mem),
|
||||
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.dst_mem), cpu_engine_,
|
||||
post_ops_attr));
|
||||
context_.reorder_pd.reset(
|
||||
new ReorderPd(cpu_engine_, context_.src_mem->get_desc(), cpu_engine_,
|
||||
context_.dst_mem->get_desc(), post_ops_attr));
|
||||
|
||||
// Create reorder primitive
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
context_.reorder_prim.reset(new reorder(
|
||||
*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
|
||||
context_.net.push_back(*context_.reorder_prim);
|
||||
#else
|
||||
// Create reorder primitive
|
||||
context_.reorder_prim.reset(new reorder(*context_.reorder_pd));
|
||||
context_.prim_args.insert({MKLDNN_ARG_FROM, *context_.src_mem});
|
||||
context_.prim_args.insert({MKLDNN_ARG_TO, *context_.dst_mem});
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
}
|
||||
};
|
||||
|
||||
@ -278,7 +262,7 @@ class MklQuantizeV2Op : public OpKernel {
|
||||
"Scalar calculation in MKL is supported only for"
|
||||
"MIN_FIRST mode for now."));
|
||||
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
auto cpu_engine = engine(engine::kind::cpu, 0);
|
||||
const Tensor& input = ctx->input(0);
|
||||
const unsigned int src_idx = 0;
|
||||
const Tensor& src_tensor = MklGetInput(ctx, src_idx);
|
||||
@ -344,7 +328,7 @@ class MklQuantizeV2Op : public OpKernel {
|
||||
max_range = std::max(input_max_range, min_range + epsilon);
|
||||
// Clamping the max_range to zero since max_range can also be negative.
|
||||
max_range = std::max(0.0f, max_range);
|
||||
auto cpu_engine = engine(ENGINE_CPU, 0);
|
||||
auto cpu_engine = engine(engine::kind::cpu, 0);
|
||||
const Tensor& src_tensor = MklGetInput(ctx, src_idx);
|
||||
MklDnnShape src_mkl_shape;
|
||||
GetMklShape(ctx, src_idx, &src_mkl_shape);
|
||||
@ -355,25 +339,25 @@ class MklQuantizeV2Op : public OpKernel {
|
||||
: TFShapeToMklDnnDims(src_tensor.shape());
|
||||
auto output_dims = src_dims;
|
||||
// Set the dst layout to be the best mkl layout based on dims and type.
|
||||
MEMORY_FORMAT dst_layout_type;
|
||||
memory::format_tag dst_layout_type;
|
||||
switch (src_tf_shape.dims()) {
|
||||
case 0:
|
||||
ComputeScalar(ctx, min_range, max_range);
|
||||
return;
|
||||
case 1:
|
||||
dst_layout_type = MEMORY_FORMAT::x;
|
||||
dst_layout_type = memory::format_tag::x;
|
||||
break;
|
||||
case 2:
|
||||
dst_layout_type = MEMORY_FORMAT::nc;
|
||||
dst_layout_type = memory::format_tag::nc;
|
||||
break;
|
||||
case 3:
|
||||
dst_layout_type = MEMORY_FORMAT::tnc;
|
||||
dst_layout_type = memory::format_tag::tnc;
|
||||
break;
|
||||
case 4:
|
||||
dst_layout_type = MEMORY_FORMAT::nhwc;
|
||||
dst_layout_type = memory::format_tag::nhwc;
|
||||
break;
|
||||
case 5:
|
||||
dst_layout_type = MEMORY_FORMAT::ndhwc;
|
||||
dst_layout_type = memory::format_tag::ndhwc;
|
||||
break;
|
||||
default:
|
||||
OP_REQUIRES_OK(ctx,
|
||||
@ -417,9 +401,7 @@ class MklQuantizeV2Op : public OpKernel {
|
||||
|
||||
memory::desc dst_md =
|
||||
memory::desc(src_dims, MklDnnType<T>(), dst_layout_type);
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
auto dst_pd = memory::primitive_desc(dst_md, cpu_engine);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
// Standard shape assignments for layout pass
|
||||
MklDnnShape output_mkl_shape;
|
||||
TensorShape output_tf_shape;
|
||||
|
Loading…
x
Reference in New Issue
Block a user