Merge pull request #36761 from Intel-tensorflow:amin/quantize-dequantize
PiperOrigin-RevId: 295787750 Change-Id: If87db944c3a67e2803f04b2278c85a12e394e47c
This commit is contained in: f079f59af2
@@ -17,18 +17,18 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "mkldnn.hpp"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/type_traits.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/mkl_graph_util.h"
 #include "tensorflow/core/kernels/meta_support.h"
 #include "tensorflow/core/kernels/quantization_utils.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 #include "tensorflow/core/graph/mkl_graph_util.h"
+#include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 
 #include "mkldnn.hpp"
 using mkldnn::primitive_attr;
 using mkldnn::stream;
 
@@ -51,7 +51,7 @@ class MklDequantizeOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     try {
       // Using CPU device
-      auto cpu_engine = engine(engine::cpu, 0);
+      auto cpu_engine = engine(ENGINE_CPU, 0);
 
       // Get the inputs
       const Tensor& src_tensor = MklGetInput(ctx, kSrcIndex);
@@ -82,33 +82,28 @@ class MklDequantizeOp : public OpKernel {
       auto src_md =
           src_mkl_shape.IsMklTensor()
              ? src_mkl_shape.GetMklLayout()
-             : memory::desc(src_dims, MklDnnType<T>(), memory::format::nhwc);
+             : memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);
 
       src.SetUsrMem(src_md, &src_tensor);
 
       Tensor* output_tensor = nullptr;
       MklDnnShape output_mkl_shape;
       TensorShape output_tf_shape;
 
-      memory::primitive_desc src_pd =
-          memory::primitive_desc(src_md, cpu_engine);
       memory::desc dst_md = src_mkl_shape.IsMklTensor()
                                 ? src_md
                                 : memory::desc(src_dims, MklDnnType<float>(),
-                                               memory::format::nhwc);
-      memory::primitive_desc dst_pd =
-          memory::primitive_desc(dst_md, cpu_engine);
+                                               MEMORY_FORMAT::nhwc);
 
       // If input is MKL shape, output is also MKL shape.
       // If input is TF shape, output is also TF shape.
       if (src_mkl_shape.IsMklTensor()) {
         output_mkl_shape.SetMklTensor(true);
-        output_mkl_shape.SetMklLayout(&dst_pd);
+        output_mkl_shape.SetMklLayout(&dst_md);
         output_mkl_shape.SetElemType(MklDnnType<float>());
         output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(),
                                      src_mkl_shape.GetSizesAsMklDnnDims(),
                                      src_mkl_shape.GetTfDataFormat());
-        output_tf_shape.AddDim((dst_pd.get_size() / sizeof(float)));
+        output_tf_shape.AddDim(GET_MEMORY_SIZE_FROM_MD(dst_md, cpu_engine) /
+                               sizeof(float));
       } else {
         output_mkl_shape.SetMklTensor(false);
         output_tf_shape = MklDnnDimsToTFShape(output_dims);
@@ -135,20 +130,35 @@ class MklDequantizeOp : public OpKernel {
       const float target_range =
           static_cast<float>((uint64_t{1} << target_bits) - 1);
       const float scale_factor = max_abs / target_range;
 
       std::vector<float> scales;
       scales.push_back(scale_factor);
       primitive_attr attr;
       attr.set_output_scales(0, scales);
+#ifndef ENABLE_MKLDNN_V1
+      // MKL-DNN 1.0 does not provide set_int_output_round_mode() API.
+      // Also it does not define round_nearest (enum).
       attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
-      mkldnn::reorder::primitive_desc reorder_pd =
-          mkldnn::reorder::primitive_desc(src_pd, dst_pd, attr);
-
-      // Execute MKL-DNN primitive
+#endif  // !ENABLE_MKLDNN_V1
+      stream reorder_stream = CPU_STREAM(cpu_engine);
       std::vector<primitive> net;
-      net.push_back(
-          mkldnn::reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem()));
-      stream(stream::kind::eager).submit(net).wait();
 
+      // Create reorder primitive and then execute.
+      auto reorder_pd = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
+          GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(src.GetUsrMem()),
+          GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(dst.GetUsrMem()), cpu_engine,
+          attr);
+#ifdef ENABLE_MKLDNN_V1
+      net.push_back(reorder(reorder_pd));
+      std::vector<std::unordered_map<int, memory>> reorder_net_args;
+      reorder_net_args.push_back({{MKLDNN_ARG_FROM, *src.GetUsrMem()},
+                                  { MKLDNN_ARG_TO,
+                                    *dst.GetUsrMem() }});
+      execute_primitives(net, std::make_shared<stream>(reorder_stream),
+                         reorder_net_args);
+#else
+      net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem()));
+      reorder_stream.submit(net);
+#endif  // ENABLE_MKLDNN_V1
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
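The hunk above replaces the eager 0.x submission (stream(stream::kind::eager).submit(net).wait()) with a macro-based path that also compiles against MKL-DNN 1.0. For reference, here is a minimal, self-contained sketch of the same scaled-reorder pattern written directly against the 1.x API; the shape, scale value, and buffers are invented for illustration and are not taken from the patch.

#include <cstdint>
#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine cpu_engine(engine::kind::cpu, 0);
  stream cpu_stream(cpu_engine);

  // Arbitrary 4D NHWC example: 1x2x4x4 float input, quantized to u8.
  memory::dims dims = {1, 2, 4, 4};
  memory::desc src_md(dims, memory::data_type::f32, memory::format_tag::nhwc);
  memory::desc dst_md(dims, memory::data_type::u8, memory::format_tag::nhwc);

  std::vector<float> src_buf(32, 1.0f);
  std::vector<uint8_t> dst_buf(32, 0);
  memory src_mem(src_md, cpu_engine, src_buf.data());
  memory dst_mem(dst_md, cpu_engine, dst_buf.data());

  // The output scale plays the role of scale_factor in the kernel above.
  primitive_attr attr;
  attr.set_output_scales(0, {127.0f});

  // In 1.x the reorder is executed with an argument map on an explicit
  // stream instead of being pushed into a net and submitted eagerly.
  reorder::primitive_desc reorder_pd(src_mem, dst_mem, attr);
  reorder reorder_prim(reorder_pd);
  reorder_prim.execute(cpu_stream,
                       {{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
  cpu_stream.wait();
  return 0;
}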
@@ -17,9 +17,7 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
-#include "mkldnn.h"
 #include "mkldnn.hpp"
-#include "mkldnn_types.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/type_traits.h"
@@ -27,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/graph/mkl_graph_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 
 using mkldnn::primitive_attr;
@@ -56,7 +55,6 @@ enum {
 }  // namespace
 
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
 struct MklReorderWithScaleFwdParams {
@@ -78,20 +76,28 @@ struct MklReorderWithScaleFwdParams {
 class MklReorderWithScalePrimitive : public MklPrimitive {
  public:
   explicit MklReorderWithScalePrimitive(
-      const memory* from, const memory* to,
-      const MklReorderWithScaleFwdParams& fwdParams) {
+      const MklReorderWithScaleFwdParams& fwdParams)
+      : cpu_engine_(ENGINE_CPU, 0) {
     // Create reorder primitive
-    Setup(from, to, fwdParams);
+    Setup(fwdParams);
   }
 
   ~MklReorderWithScalePrimitive() {}
 
   std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
 
-  // set data handles
-  void SetMemory(const memory* from, const memory* to) {
-    context_.src_mem->set_data_handle(from->get_data_handle());
-    context_.dst_mem->set_data_handle(to->get_data_handle());
+  void Execute(void* src_data, void* dst_data) {
+    context_.src_mem->set_data_handle(src_data);
+    context_.dst_mem->set_data_handle(dst_data);
+#ifndef ENABLE_MKLDNN_V1
+    context_.reorder_stream->submit(context_.net);
+#else
+    context_.reorder_prim->execute(*context_.reorder_stream,
+                                   context_.prim_args);
+#endif  // !ENABLE_MKLDNN_V1
+    // After execution, set data handle back.
+    context_.src_mem->set_data_handle(DummyData);
+    context_.dst_mem->set_data_handle(DummyData);
   }
 
  private:
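The new Execute(void*, void*) follows the usual cached-primitive pattern: the memory objects are created once against placeholder handles, rebound to the caller's buffers on every call, and detached again afterwards so the cached object never owns user data. Below is a standalone sketch of that pattern against the raw MKL-DNN 1.x API; the class and names are invented for illustration and are not the ones in the patch.

#include <vector>
#include "mkldnn.hpp"

struct CachedReorder {
  mkldnn::engine eng{mkldnn::engine::kind::cpu, 0};
  mkldnn::stream strm{eng};
  mkldnn::memory src_mem, dst_mem;
  mkldnn::reorder prim;

  explicit CachedReorder(const mkldnn::memory::desc& md)
      : src_mem(md, eng, nullptr),  // placeholder handle, analogous to DummyData
        dst_mem(md, eng, nullptr),
        prim(mkldnn::reorder::primitive_desc(src_mem, dst_mem)) {}

  void Execute(void* src_data, void* dst_data) {
    src_mem.set_data_handle(src_data);
    dst_mem.set_data_handle(dst_data);
    prim.execute(strm, {{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
    strm.wait();
    // Detach the caller's buffers so the cached primitive holds no stale handles.
    src_mem.set_data_handle(nullptr);
    dst_mem.set_data_handle(nullptr);
  }
};

int main() {
  mkldnn::memory::desc md({8}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::x);
  std::vector<float> in(8, 2.0f), out(8, 0.0f);
  CachedReorder r(md);
  r.Execute(in.data(), out.data());  // out now holds a copy of in
  return 0;
}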
@@ -101,41 +107,36 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
     std::shared_ptr<mkldnn::memory> src_mem;
     std::shared_ptr<mkldnn::memory> dst_mem;
 
-    // Memory desc
-    std::shared_ptr<mkldnn::memory::desc> src_md;
-    std::shared_ptr<mkldnn::memory::desc> dst_md;
-
-    // Memory primitive desc
-    std::shared_ptr<mkldnn::memory::primitive_desc> src_mpd;
-    std::shared_ptr<mkldnn::memory::primitive_desc> dst_mpd;
-
     // Reorder primitive descriptor and primitive
     std::shared_ptr<reorder::primitive_desc> reorder_pd;
     std::shared_ptr<primitive> reorder_prim;
 
+    // Stream and primitive vector
+    std::shared_ptr<mkldnn::stream> reorder_stream;
+
+#ifndef ENABLE_MKLDNN_V1
+    std::vector<mkldnn::primitive> net;
+#else
+    std::unordered_map<int, mkldnn::memory> prim_args;
+#endif  // !ENABLE_MKLDNN_V1
 
     ReorderContext()
         : src_mem(nullptr),
           dst_mem(nullptr),
-          src_md(nullptr),
-          dst_md(nullptr),
-          src_mpd(nullptr),
-          dst_mpd(nullptr),
           reorder_pd(nullptr),
-          reorder_prim(nullptr) {}
+          reorder_prim(nullptr),
+          reorder_stream(nullptr) {}
   } context_;
 
-  engine cpu_engine_ = engine(engine::cpu, 0);
+  engine cpu_engine_;
 
   // Reorder primitive setup
-  void Setup(const memory* from, const memory* to,
-             const MklReorderWithScaleFwdParams& fwdParams) {
+  void Setup(const MklReorderWithScaleFwdParams& fwdParams) {
     // Create memory descriptors for reorder data with specified format
-    context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
-    context_.dst_md.reset(new memory::desc(fwdParams.dst_md.data));
-    context_.src_mpd.reset(
-        new memory::primitive_desc(*context_.src_md, cpu_engine_));
-    context_.dst_mpd.reset(
-        new memory::primitive_desc(*context_.dst_md, cpu_engine_));
+    context_.src_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD(
+        fwdParams.src_md, cpu_engine_, DummyData));
+    context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD(
+        fwdParams.dst_md, cpu_engine_, DummyData));
 
     // Check if there is any fusion as post-ops
     auto const& post_op_params = fwdParams.post_op_params;
@@ -147,18 +148,22 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
     scales.push_back(post_op_params.param[0]);
     post_ops_attr.set_output_scales(0, scales);
 
     // Create a reorder
-    context_.reorder_pd =
-        std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
-            *context_.src_mpd, *context_.dst_mpd, post_ops_attr));
+    context_.reorder_pd.reset(new REORDER_PD_CONSTRUCTOR_WITH_ATTR(
+        GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.src_mem),
+        GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.dst_mem), cpu_engine_,
+        post_ops_attr));
 
-    // Create memory primitive based on dummy data
-    context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
-    context_.dst_mem.reset(new memory(*context_.dst_mpd, DummyData));
-
-    // Create reorder primitive
-    context_.reorder_prim = std::make_shared<reorder>(
-        reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
+    // Create reorder primitive
+#ifndef ENABLE_MKLDNN_V1
+    context_.reorder_prim.reset(new reorder(
+        *context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
+    context_.net.push_back(*context_.reorder_prim);
+#else
+    context_.reorder_prim.reset(new reorder(*context_.reorder_pd));
+    context_.prim_args.insert({MKLDNN_ARG_FROM, *context_.src_mem});
+    context_.prim_args.insert({MKLDNN_ARG_TO, *context_.dst_mem});
+#endif  // !ENABLE_MKLDNN_V1
+    context_.reorder_stream.reset(new CPU_STREAM(cpu_engine_));
   }
 };
 
@@ -173,11 +178,10 @@ class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory<T> {
         MklReorderWithScalePrimitiveFactory<T>::GetInstance().GetReorder(
             from, to, fwdParams));
     if (reorderPrim == nullptr) {
-      reorderPrim = new MklReorderWithScalePrimitive(from, to, fwdParams);
+      reorderPrim = new MklReorderWithScalePrimitive(fwdParams);
       MklReorderWithScalePrimitiveFactory<T>::GetInstance().SetReorder(
           from, to, reorderPrim, fwdParams);
     }
-    reorderPrim->SetMemory(from, to);
     return reorderPrim;
   }
 
@@ -192,20 +196,8 @@ class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory<T> {
   static string CreateKey(const memory* from, const memory* to,
                           const MklReorderWithScaleFwdParams& fwdParams) {
-    string dtypes = string("");
-    string prefix = "reorder";
     FactoryKeyCreator key_creator;
-    auto const& from_desc = from->get_primitive_desc().desc().data;
-    auto const& to_desc = to->get_primitive_desc().desc().data;
-
-    key_creator.AddAsKey(prefix);
-    key_creator.AddAsKey(static_cast<int>(from_desc.format));
-    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
-    key_creator.AddAsKey(fwdParams.src_dims);
-    key_creator.AddAsKey(static_cast<int>(to_desc.format));
-    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
-    key_creator.AddAsKey(fwdParams.dtypes);
-
+    key_creator.AddAsKey(MklReorderPrimitiveFactory<T>::CreateKey(from, to));
     // Generate key for post-op scale
     if (fwdParams.post_op_params.name == "scale") {
       DCHECK_EQ(fwdParams.post_op_params.param.size(), 1);
@@ -231,21 +223,6 @@ class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory<T> {
   }
 };
 
-// Fuction to find (or create) a reorder from memory pointed by
-// 'from' to memory pointed by 'to', it will create primitive or
-// get primitive from pool if it is cached.
-// Returns the primitive.
-template <typename T>
-inline primitive FindOrCreateReorder(
-    const memory* from, const memory* to,
-    const MklReorderWithScaleFwdParams& fwdParams) {
-  DCHECK(from);
-  DCHECK(to);
-  MklReorderWithScalePrimitive* reorder_prim =
-      MklReorderWithScalePrimitiveFactory<T>::Get(from, to, fwdParams);
-  return *reorder_prim->GetPrimitive();
-}
-
 // Quantizes a tensor from float to T, with user-specified min_range and
 // max_range.
 template <typename Device, typename T>
@@ -300,7 +277,7 @@ class MklQuantizeV2Op : public OpKernel {
                 "Scalar calculation in MKL is supported only for"
                 "MIN_FIRST mode for now."));
 
-    auto cpu_engine = engine(engine::cpu, 0);
+    auto cpu_engine = engine(ENGINE_CPU, 0);
     const Tensor& input = ctx->input(0);
     const unsigned int src_idx = 0;
     const Tensor& src_tensor = MklGetInput(ctx, src_idx);
@@ -366,7 +343,7 @@ class MklQuantizeV2Op : public OpKernel {
     max_range = std::max(input_max_range, min_range + epsilon);
     // Clamping the max_range to zero since max_range can also be negative.
     max_range = std::max(0.0f, max_range);
-    auto cpu_engine = engine(engine::cpu, 0);
+    auto cpu_engine = engine(ENGINE_CPU, 0);
     const Tensor& src_tensor = MklGetInput(ctx, src_idx);
     MklDnnShape src_mkl_shape;
     GetMklShape(ctx, src_idx, &src_mkl_shape);
@@ -377,25 +354,25 @@ class MklQuantizeV2Op : public OpKernel {
                        : TFShapeToMklDnnDims(src_tensor.shape());
     auto output_dims = src_dims;
     // Set the dst layout to be the best mkl layout based on dims and type.
-    memory::format dst_layout_type;
+    MEMORY_FORMAT dst_layout_type;
     switch (src_tf_shape.dims()) {
       case 0:
         ComputeScalar(ctx, min_range, max_range);
         return;
       case 1:
-        dst_layout_type = memory::format::x;
+        dst_layout_type = MEMORY_FORMAT::x;
         break;
       case 2:
-        dst_layout_type = memory::format::nc;
+        dst_layout_type = MEMORY_FORMAT::nc;
         break;
       case 3:
-        dst_layout_type = memory::format::tnc;
+        dst_layout_type = MEMORY_FORMAT::tnc;
         break;
       case 4:
-        dst_layout_type = memory::format::nhwc;
+        dst_layout_type = MEMORY_FORMAT::nhwc;
         break;
       case 5:
-        dst_layout_type = memory::format::ndhwc;
+        dst_layout_type = MEMORY_FORMAT::ndhwc;
         break;
       default:
         OP_REQUIRES_OK(ctx,
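For readers unfamiliar with the macro: MEMORY_FORMAT lives in mkl_types.h (not part of this diff) and, by analogy with the other macros there, presumably resolves to mkldnn::memory::format_tag under ENABLE_MKLDNN_V1 and to mkldnn::memory::format otherwise, so the switch simply picks a plain dense layout tag per rank. A minimal 1.x-only illustration of the same rank-to-tag mapping:

#include "mkldnn.hpp"

// Illustration only: map a tensor rank to a default dense layout tag,
// as the switch statement above does through MEMORY_FORMAT.
static mkldnn::memory::format_tag DefaultTagForRank(int rank) {
  using tag = mkldnn::memory::format_tag;
  switch (rank) {
    case 1: return tag::x;
    case 2: return tag::nc;
    case 3: return tag::tnc;
    case 4: return tag::nhwc;
    case 5: return tag::ndhwc;
    default: return tag::undef;
  }
}

int main() {
  mkldnn::memory::desc md({2, 3, 16, 16}, mkldnn::memory::data_type::u8,
                          DefaultTagForRank(4));
  // The expected byte size shows the tag was accepted for this rank.
  return md.get_size() == 2 * 3 * 16 * 16 ? 0 : 1;
}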
@@ -414,11 +391,11 @@ class MklQuantizeV2Op : public OpKernel {
     // If the mode is min_first, input data has to be subtracted from
     // min_range, before being scaled
     auto flat_input = input.flat<float>().data();
-    Tensor minfirst_tmpinput;
-    OP_REQUIRES_OK(
-        ctx, ctx->allocate_temp(DT_FLOAT, input.shape(), &minfirst_tmpinput));
+    Tensor min_shifted_input_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, input.shape(),
+                                           &min_shifted_input_tensor));
     if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
-      auto minfirst_input = minfirst_tmpinput.flat<float>().data();
+      auto minfirst_input = min_shifted_input_tensor.flat<float>().data();
       const Eigen::TensorOpCost cost(
           sizeof(float), /*load bytes*/
           sizeof(float), /*saved bytes*/
@@ -432,25 +409,27 @@ class MklQuantizeV2Op : public OpKernel {
       };
       d.parallelFor(input.NumElements(), cost, ParallelSub);
 
-      src.SetUsrMem(src_md, minfirst_input);
+      src.SetUsrMem(src_md, &min_shifted_input_tensor);
     } else {
       src.SetUsrMem(src_md, &src_tensor);
     }
 
     memory::desc dst_md =
         memory::desc(src_dims, MklDnnType<T>(), dst_layout_type);
-    auto dst_pd = src.GetUsrMemPrimDesc();
+#ifndef ENABLE_MKLDNN_V1
+    auto dst_pd = memory::primitive_desc(dst_md, cpu_engine);
+#endif  // !ENABLE_MKLDNN_V1
     // Standard shape assignments for layout pass
     MklDnnShape output_mkl_shape;
     TensorShape output_tf_shape;
     if (src_mkl_shape.IsMklTensor()) {
       output_mkl_shape.SetMklTensor(true);
-      output_mkl_shape.SetMklLayout(&dst_md);
+      output_mkl_shape.SetMklLayout(&DST_MD);
       output_mkl_shape.SetElemType(MklDnnType<T>());
       output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(),
                                    src_mkl_shape.GetSizesAsMklDnnDims(),
                                    src_mkl_shape.GetTfDataFormat());
-      output_tf_shape.AddDim(dst_pd.get_size() / sizeof(T));
+      output_tf_shape.AddDim(DST_MD.get_size() / sizeof(T));
     } else {
       output_mkl_shape.SetMklTensor(false);
       output_tf_shape = MklDnnDimsToTFShape(output_dims);
@@ -459,6 +438,8 @@ class MklQuantizeV2Op : public OpKernel {
     Tensor* output_tensor = nullptr;
     AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
                               output_mkl_shape);
+    dst.SetUsrMem(dst_md, output_tensor);
+
     TensorShape min_tf_shape = {};
     MklDnnShape min_mkl_shape;
     min_mkl_shape.SetMklTensor(false);
@@ -472,8 +453,6 @@ class MklQuantizeV2Op : public OpKernel {
     AllocateOutputSetMklShape(ctx, 2, &output_max_tensor, max_tf_shape,
                               max_mkl_shape);
 
-    dst.SetUsrMem(dst_md, output_tensor);
-
     float scale_factor = 0;
     if (mode_ == QUANTIZE_MODE_SCALED) {
       // Estimating scales for quantization.
@@ -497,41 +476,25 @@ class MklQuantizeV2Op : public OpKernel {
         target_range = static_cast<float>((uint64_t{1} << num_bits) - 1);
       }
       scale_factor = target_range / max_abs;
-
-      output_min_tensor->flat<float>()(0) = min_range;
-      output_max_tensor->flat<float>()(0) = max_range;
-
-      // Primitive creation and stream submit
-      std::vector<float> scales{scale_factor};
-      mkldnn::primitive_attr attr;
-      attr.set_output_scales(0, scales);
-      auto reorder_desc = reorder::primitive_desc(
-          src.GetUsrMemPrimDesc(), dst.GetUsrMemPrimDesc(), attr);
-      reorder my_reorder = reorder(
-          reorder_desc, primitive::at(*src.GetUsrMem()), *dst.GetUsrMem());
-      std::vector<primitive> net{my_reorder};
-      stream(stream::kind::eager).submit(net).wait();
     } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
       // Estimate scale for qunatization
       const int number_of_bits = sizeof(T) * 8;
       const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
       scale_factor = (number_of_steps - 1.0) / (max_range - min_range);
-
-      output_min_tensor->flat<float>()(0) = min_range;
-      output_max_tensor->flat<float>()(0) = max_range;
-
-      MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
-      fwdParams.dtypes.append(typeid(T).name());
-
-      fwdParams.post_op_params.name = "scale";
-      fwdParams.post_op_params.param.push_back(scale_factor);
-
-      // Get primitive from pool or create one and submit
-      std::vector<primitive> net;
-      net.push_back(
-          FindOrCreateReorder<T>(src.GetUsrMem(), dst.GetUsrMem(), fwdParams));
-      stream(stream::kind::eager).submit(net).wait();
     }
+
+    MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
+    fwdParams.dtypes.append(typeid(T).name());
+    fwdParams.post_op_params.name = "scale";
+    fwdParams.post_op_params.param.push_back(scale_factor);
+
+    MklReorderWithScalePrimitive* reorder_prim =
+        MklReorderWithScalePrimitiveFactory<T>::Get(src.GetUsrMem(),
+                                                    dst.GetUsrMem(), fwdParams);
+    reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle());
+
+    output_min_tensor->flat<float>()(0) = min_range;
+    output_max_tensor->flat<float>()(0) = max_range;
   }
 
  private:
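To make the two scale computations concrete, here is a small worked example with invented ranges; T is assumed to be an 8-bit type, and num_bits is assumed to be 7 for the signed SCALED case (its definition lies outside this hunk).

#include <cstdint>
#include <cstdio>

int main() {
  // QUANTIZE_MODE_MIN_FIRST with an 8-bit T, min_range = -1.0, max_range = 3.0:
  const int number_of_bits = 8;  // sizeof(T) * 8
  const int64_t number_of_steps = static_cast<int64_t>(1) << number_of_bits;
  const float min_first_scale = (number_of_steps - 1.0f) / (3.0f - (-1.0f));
  std::printf("MIN_FIRST scale_factor = %.2f\n", min_first_scale);  // 63.75

  // QUANTIZE_MODE_SCALED with num_bits = 7 (assumed) and max_abs = 2.0:
  const float target_range = static_cast<float>((uint64_t{1} << 7) - 1);  // 127
  const float scaled_scale = target_range / 2.0f;
  std::printf("SCALED scale_factor = %.2f\n", scaled_scale);  // 63.50
  return 0;
}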
@@ -39,6 +39,7 @@ namespace tensorflow {
 #define GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr) mem_ptr->get_desc()
 #define GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(mem_ptr) \
   GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr)
+#define GET_MEMORY_SIZE_FROM_MD(md, engine) md.get_size()
 #define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd->src_desc()
 #define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) op_pd->diff_dst_desc()
 #define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) op_pd->workspace_desc()
@@ -131,6 +132,8 @@ namespace tensorflow {
 #define GET_BLOCK_STRIDES(strides, idx) strides[(idx)]
 #define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
   { {dims}, MklDnnType<type>(), fm }
+#define GET_MEMORY_SIZE_FROM_MD(md, engine) \
+  memory::primitive_desc(md, engine).get_size()
 #define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd.get()->src_primitive_desc()
 #define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) \
   op_pd.get()->diff_dst_primitive_desc()
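The second definition is the 0.x branch, where a memory::primitive_desc has to be constructed just to query the allocation size; under ENABLE_MKLDNN_V1 the first definition queries the memory::desc directly. A tiny 1.x-only check of that size query, with arbitrary dimensions:

#include <cassert>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  memory::desc md({2, 3, 4, 5}, memory::data_type::f32, memory::format_tag::nhwc);
  // 2 * 3 * 4 * 5 floats = 120 elements of 4 bytes each.
  assert(md.get_size() == 120 * sizeof(float));
  return 0;
}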
@@ -2078,10 +2078,6 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
     return instance_;
   }
 
- private:
-  MklReorderPrimitiveFactory() {}
-  ~MklReorderPrimitiveFactory() {}
-
   static string CreateKey(const memory* from, const memory* to) {
     string prefix = "reorder";
     FactoryKeyCreator key_creator;
@@ -2117,6 +2113,10 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
     return key_creator.GetKey();
   }
 
+ private:
+  MklReorderPrimitiveFactory() {}
+  ~MklReorderPrimitiveFactory() {}
+
   MklPrimitive* GetReorder(const memory* from, const memory* to) {
     string key = CreateKey(from, to);
     return this->GetOp(key);
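Both factories cache primitives behind a string key built by CreateKey(); GetReorder() turns that key into a GetOp() lookup, and SetReorder() (seen in the earlier hunk) stores a freshly built primitive under the same key. Below is a generic sketch of that lookup-or-create scheme, independent of the TensorFlow classes; every name in it is invented for illustration.

#include <string>
#include <unordered_map>

struct Primitive {};  // stand-in for a cached MKL primitive

class PrimitiveCache {
 public:
  // Returns nullptr on a cache miss, mirroring a GetOp()-style lookup.
  Primitive* GetOp(const std::string& key) {
    auto it = cache_.find(key);
    return it == cache_.end() ? nullptr : it->second;
  }
  void SetOp(const std::string& key, Primitive* op) { cache_[key] = op; }

 private:
  std::unordered_map<std::string, Primitive*> cache_;
};

int main() {
  PrimitiveCache cache;
  const std::string key = "reorder:f32:nhwc:u8:nhwc";  // invented key format
  Primitive* prim = cache.GetOp(key);
  if (prim == nullptr) {  // miss: build once, then cache
    prim = new Primitive();
    cache.SetOp(key, prim);
  }
  // Subsequent lookups with the same key reuse the cached primitive.
  return cache.GetOp(key) == prim ? 0 : 1;
}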