[INTEL MKL] Enabled MIN_FIRST support and primitive caching for MKL-DNN Quantize OP

R Gomathi 2019-09-13 10:55:53 +05:30
parent 5f55cc82a9
commit c5a0825115
4 changed files with 376 additions and 37 deletions

View File

@@ -1580,6 +1580,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
TryGetNodeAttr(n->def(), "axis", &axis);
TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
if (narrow_range) {
VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization."
<< "This case is not optimized by Intel MKL, "
@@ -1593,8 +1594,9 @@ rinfo_.push_back({csinfo_.tanh_grad,
<< "thus using Eigen op for Quantize op ";
return false;
}
if (mode_string != "SCALED" || round_mode_string != "HALF_TO_EVEN") {
VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED and/or"
if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
(mode_string == "MIN_FIRST"))) {
VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST and/or"
<< "rounding mode is not HALF_TO_EVEN. "
<< "This case is not optimized by Intel MKL, thus using Eigen op"
<< "for Quantize op ";
@@ -1849,6 +1851,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
// NOTE: names are alphabetically sorted.
static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsConv2DDepthwiseCheckConstFilter(

View File

@@ -59,6 +59,196 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklReorderWithScaleFwdParams {
memory::dims src_dims;
memory::desc src_md;
memory::desc dst_md;
string dtypes = string("");
struct PostOpParam {
string name;
std::vector<float> param;
};
PostOpParam post_op_params;
MklReorderWithScaleFwdParams(memory::dims src_dims, memory::desc src_md,
memory::desc dst_md)
: src_dims(src_dims), src_md(src_md), dst_md(dst_md) {}
};
class MklReorderWithScalePrimitive : public MklPrimitive {
public:
explicit MklReorderWithScalePrimitive(
const memory* from, const memory* to,
const MklReorderWithScaleFwdParams& fwdParams) {
// create reorder primitive
Setup(from, to, fwdParams);
}
~MklReorderWithScalePrimitive() {}
std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
// set data handles
void SetMemory(const memory* from, const memory* to) {
context_.src_mem->set_data_handle(from->get_data_handle());
context_.dst_mem->set_data_handle(to->get_data_handle());
}
private:
// Primitive reuse context for reorder
struct ReorderContext {
// MKLDNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
// Memory desc
std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> dst_md;
// Memory primitive desc
std::shared_ptr<mkldnn::memory::primitive_desc> src_mpd;
std::shared_ptr<mkldnn::memory::primitive_desc> dst_mpd;
// Reorder primitive descriptor and primitive
std::shared_ptr<reorder::primitive_desc> reorder_pd;
std::shared_ptr<primitive> reorder_prim;
ReorderContext()
: src_mem(nullptr),
dst_mem(nullptr),
src_md(nullptr),
dst_md(nullptr),
src_mpd(nullptr),
dst_mpd(nullptr),
reorder_pd(nullptr),
reorder_prim(nullptr) {}
} context_;
engine cpu_engine_ = engine(engine::cpu, 0);
// Reorder primitive setup
void Setup(const memory* from, const memory* to,
const MklReorderWithScaleFwdParams& fwdParams) {
// Create memory descriptors for reorder data with specified format
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
context_.dst_md.reset(new memory::desc(fwdParams.dst_md.data));
context_.src_mpd.reset(
new memory::primitive_desc(*context_.src_md, cpu_engine_));
context_.dst_mpd.reset(
new memory::primitive_desc(*context_.dst_md, cpu_engine_));
// Check if there is any fusion as post-ops
auto const& post_op_params = fwdParams.post_op_params;
mkldnn::primitive_attr post_ops_attr;
if (post_op_params.name == "scale") {
DCHECK_EQ(post_op_params.param.size(), 1);
std::vector<float> scales;
scales.push_back(post_op_params.param[0]);
post_ops_attr.set_output_scales(0, scales);
} else {
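// Only the "scale" post-op is supported; reaching this branch trips the
// debug check below.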
DCHECK(post_op_params.name == "scale");
}
// Create a reorder
context_.reorder_pd =
std::make_shared<reorder::primitive_desc>(reorder::primitive_desc(
*context_.src_mpd, *context_.dst_mpd, post_ops_attr));
// Create memory primitive based on dummy data
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
context_.dst_mem.reset(new memory(*context_.dst_mpd, DummyData));
// Create reorder primitive
context_.reorder_prim = std::make_shared<reorder>(
reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem));
}
};
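Note the reuse pattern here: the reorder primitive is built once against placeholder DummyData buffers, and SetMemory() simply repoints the cached src/dst memory at the caller's real data handles on each invocation, so a single cached primitive can serve every subsequent call with the same shapes, formats, and scale.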
template <typename T>
class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklReorderWithScalePrimitive* Get(
const memory* from, const memory* to,
const MklReorderWithScaleFwdParams& fwdParams) {
// Try to find a suitable primitive from the cached pool
auto reorderPrim = static_cast<MklReorderWithScalePrimitive*>(
MklReorderWithScalePrimitiveFactory<T>::GetInstance().GetReorder(
from, to, fwdParams));
if (reorderPrim == nullptr) {
reorderPrim = new MklReorderWithScalePrimitive(from, to, fwdParams);
MklReorderWithScalePrimitiveFactory<T>::GetInstance().SetReorder(
from, to, reorderPrim, fwdParams);
}
reorderPrim->SetMemory(from, to);
return reorderPrim;
}
static MklReorderWithScalePrimitiveFactory& GetInstance() {
static MklReorderWithScalePrimitiveFactory instance_;
return instance_;
}
private:
MklReorderWithScalePrimitiveFactory() {}
~MklReorderWithScalePrimitiveFactory() {}
static string CreateKey(const memory* from, const memory* to,
const MklReorderWithScaleFwdParams& fwdParams) {
string prefix = "reorder";
FactoryKeyCreator key_creator;
auto const& from_desc = from->get_primitive_desc().desc().data;
auto const& to_desc = to->get_primitive_desc().desc().data;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(static_cast<int>(from_desc.format));
key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
key_creator.AddAsKey(fwdParams.src_dims);
key_creator.AddAsKey(static_cast<int>(to_desc.format));
key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
key_creator.AddAsKey(fwdParams.dtypes);
// Generate key for post-op scale
if (fwdParams.post_op_params.name == "scale") {
DCHECK_EQ(fwdParams.post_op_params.param.size(), 1);
key_creator.AddAsKey(fwdParams.post_op_params.name);
key_creator.AddAsKey(fwdParams.post_op_params.param[0]);
} else {
return string("not_a_key");
}
return key_creator.GetKey();
}
MklPrimitive* GetReorder(const memory* from, const memory* to,
const MklReorderWithScaleFwdParams& fwdParams) {
string key = CreateKey(from, to, fwdParams);
return this->GetOp(key);
}
void SetReorder(const memory* from, const memory* to, MklPrimitive* op,
const MklReorderWithScaleFwdParams& fwdParams) {
string key = CreateKey(from, to, fwdParams);
this->SetOp(key, op);
}
};
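One caveat: CreateKey() only yields a real cache key when the post-op is "scale" and otherwise returns the sentinel "not_a_key", so the factory should only be handed reorders that carry a scale post-op, which is the only way FindOrCreateReorder() below uses it.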
/// Function to find (or create) a reorder from the memory pointed to by
/// `from` to the memory pointed to by `to`. The primitive is fetched from
/// the cache pool if present, otherwise a new one is created.
/// Returns the primitive.
template <typename T>
inline primitive FindOrCreateReorder(
const memory* from, const memory* to,
const MklReorderWithScaleFwdParams& fwdParams) {
CHECK_NOTNULL(from);
CHECK_NOTNULL(to);
MklReorderWithScalePrimitive* reorder_prim =
MklReorderWithScalePrimitiveFactory<T>::Get(from, to, fwdParams);
return *reorder_prim->GetPrimitive();
}
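For reference, a minimal sketch of the intended call pattern (it mirrors the MIN_FIRST path in Compute() below; src, dst, src_dims, src_md, dst_md, and scale_factor stand in for the kernel's real objects):

    // Describe the reorder and attach a single output-scale post-op.
    MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
    fwdParams.dtypes.append(typeid(T).name());
    fwdParams.post_op_params.name = "scale";
    fwdParams.post_op_params.param.push_back(scale_factor);
    // Fetch a cached primitive (or build and cache one) and execute it.
    std::vector<primitive> net;
    net.push_back(
        FindOrCreateReorder<T>(src.GetUsrMem(), dst.GetUsrMem(), fwdParams));
    stream(stream::kind::eager).submit(net).wait();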
// Quantizes a tensor from float to T, with user-specified min_range and
// max_range.
template <typename Device, typename T>
@@ -67,6 +257,7 @@ class MklQuantizeV2Op : public OpKernel {
explicit MklQuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
string mode_string;
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
<<<<<<< HEAD
OP_REQUIRES(ctx,
(mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" ||
mode_string == "SCALED"),
@@ -93,6 +284,7 @@ class MklQuantizeV2Op : public OpKernel {
if (round_mode_string == "HALF_AWAY_FROM_ZERO") {
round_mode_ = ROUND_HALF_AWAY_FROM_ZERO;
} else if (round_mode_string == "HALF_TO_EVEN") {
if (round_mode_string == "HALF_TO_EVEN") {
OP_REQUIRES(ctx, mode_string == "SCALED",
errors::InvalidArgument("Round mode 'HALF_TO_EVEN' "
"only supported for mode 'SCALED', "
@@ -106,9 +298,22 @@ class MklQuantizeV2Op : public OpKernel {
ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_));
}
~MklQuantizeV2Op() {}
~MklQuantizeV2Op() {
if (this->minfirst_input_ != nullptr) {
// The buffer was allocated with new[], so it must be freed with delete[].
delete[] this->minfirst_input_;
minfirst_input_ = nullptr;
}
}
float* GetMinfirstInputBuf(int size) {
// Reuse the scratch buffer across calls; (re)allocate when it is missing
// or too small for the current input.
if (!minfirst_input_ || size > minfirst_input_size_) {
delete[] minfirst_input_;
minfirst_input_ = new float[size];
minfirst_input_size_ = size;
}
return minfirst_input_;
}
void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(0);
const float input_min_range = ctx->input(1).flat<float>()(0);
const float input_max_range = ctx->input(2).flat<float>()(0);
float min_range = std::min(0.0f, input_min_range);
@@ -174,7 +379,20 @@ class MklQuantizeV2Op : public OpKernel {
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<float>(), dst_layout_type);
src.SetUsrMem(src_md, &src_tensor);
// In MIN_FIRST mode, min_range has to be subtracted from the input
// data before it is scaled.
auto flat_input = input.flat<float>().data();
if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
float* minfirst_input = GetMinfirstInputBuf(input.NumElements());
#pragma omp parallel for schedule(static)
for (int i = 0; i < input.NumElements(); i++) {
minfirst_input[i] = flat_input[i] - min_range;
}
src.SetUsrMem(src_md, minfirst_input);
} else {
src.SetUsrMem(src_md, &src_tensor);
}
memory::desc dst_md =
memory::desc(src_dims, MklDnnType<T>(), dst_layout_type);
@@ -212,39 +430,65 @@ class MklQuantizeV2Op : public OpKernel {
max_mkl_shape);
dst.SetUsrMem(dst_md, output_tensor);
// Estimating scales for quantization.
const int num_bits = sizeof(T) * 8;
const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
const bool is_signed = std::is_signed<T>::value;
float target_range;
if (is_signed) {
max_range = max_abs;
min_range = -max_abs;
// If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
// example, if it is 8 bits, we have the range [-127, 127]. So for input
// range of [-x, x], the scale should be 254/(2*x).
target_range = static_cast<float>((uint64_t{1} << (num_bits - 1)) - 1);
} else {
max_range = max_abs;
min_range = 0.0;
// If it is unsigned and num_bits == 8, the range with 8 bits is [0,
// 255]. If the input range is [0, x], then the scale is 255/x instead
// of 254 as in the case above.
target_range = static_cast<float>((uint64_t{1} << num_bits) - 1);
float scale_factor = 0;
if (mode_ == QUANTIZE_MODE_SCALED) {
// Estimating scales for quantization.
const int num_bits = sizeof(T) * 8;
const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
const bool is_signed = std::is_signed<T>::value;
float target_range;
if (is_signed) {
max_range = max_abs;
min_range = -max_abs;
// If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
// example, if it is 8 bits, we have the range [-127, 127]. So for input
// range of [-x, x], the scale should be 254/(2*x).
target_range = static_cast<float>((uint64_t{1} << (num_bits - 1)) - 1);
} else {
max_range = max_abs;
min_range = 0.0;
// If it is unsigned and num_bits == 8, the range with 8 bits is [0,
// 255]. If the input range is [0, x], then the scale is 255/x instead
// of 254 as in the case above.
target_range = static_cast<float>((uint64_t{1} << num_bits) - 1);
}
scale_factor = target_range / max_abs;
output_min_tensor->flat<float>()(0) = min_range;
output_max_tensor->flat<float>()(0) = max_range;
// Primitive creation and stream submit
std::vector<float> scales{scale_factor};
mkldnn::primitive_attr attr;
attr.set_output_scales(0, scales);
auto reorder_desc = reorder::primitive_desc(
src.GetUsrMemPrimDesc(), dst.GetUsrMemPrimDesc(), attr);
reorder my_reorder = reorder(
reorder_desc, primitive::at(*src.GetUsrMem()), *dst.GetUsrMem());
std::vector<primitive> net{my_reorder};
stream(stream::kind::eager).submit(net).wait();
} else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
// Estimate the scale for quantization.
const int number_of_bits = sizeof(T) * 8;
const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
scale_factor = (number_of_steps - 1.0) / (max_range - min_range);
output_min_tensor->flat<float>()(0) = min_range;
output_max_tensor->flat<float>()(0) = max_range;
MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
fwdParams.dtypes.append(typeid(T).name());
fwdParams.post_op_params.name = "scale";
fwdParams.post_op_params.param.push_back(scale_factor);
// Get primitive from pool or create one and submit
std::vector<primitive> net;
net.push_back(
FindOrCreateReorder<T>(src.GetUsrMem(), dst.GetUsrMem(), fwdParams));
stream(stream::kind::eager).submit(net).wait();
}
output_min_tensor->flat<float>()(0) = min_range;
output_max_tensor->flat<float>()(0) = max_range;
const float scale_factor = target_range / max_abs;
// Primitive creation and stream submit
std::vector<float> scales{scale_factor};
mkldnn::primitive_attr attr;
attr.set_output_scales(0, scales);
auto reorder_desc = reorder::primitive_desc(src.GetUsrMemPrimDesc(),
dst.GetUsrMemPrimDesc(), attr);
reorder my_reorder = reorder(reorder_desc, primitive::at(*src.GetUsrMem()),
*dst.GetUsrMem());
std::vector<primitive> net{my_reorder};
stream(stream::kind::eager).submit(net).wait();
}
private:
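To make the two scale computations concrete (illustrative numbers, not taken from the tests): under SCALED with T = qint8 and an input range of [-2.0, 2.0], max_abs = 2.0 and target_range = 2^7 - 1 = 127, so scale_factor = 127 / 2.0 = 63.5. Under MIN_FIRST with T = quint8 and a range of [0, 255], number_of_steps = 2^8 = 256 and scale_factor = 255 / (255 - 0) = 1.0, so after the min_range shift each element is effectively just rounded and saturated.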
@@ -253,6 +497,7 @@ class MklQuantizeV2Op : public OpKernel {
int round_mode_;
int axis_;
bool narrow_range_;
float* minfirst_input_ = nullptr;
int minfirst_input_size_ = 0;
};
REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2")

View File

@@ -95,4 +95,95 @@ TEST_F(MklQuantizeV2OpTest, small_int8) {
test::ExpectTensorEqual<float>(expected_min, *GetOutput(1));
test::ExpectTensorEqual<float>(expected_max, *GetOutput(2));
}
TEST_F(MklQuantizeV2OpTest, small_minfirst) {
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("mode", "MIN_FIRST")
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}),
{1.0, 1.25, 1.75, 2, 3.15, 127.0, 255.0, 500.0});
AddInputFromArray<float>(TensorShape({1}), {0});
AddInputFromArray<float>(TensorShape({1}), {255.0f});
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
test::FillValues<quint8>(&expected, {1, 1, 2, 2, 3, 127, 255, 255});
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
const float output_min = GetOutput(1)->flat<float>()(0);
const float output_max = GetOutput(2)->flat<float>()(0);
EXPECT_NEAR(0.0f, output_min, 1e-5f);
EXPECT_NEAR(255.0f, output_max, 1e-5f);
}
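The expected values follow from the MIN_FIRST math above: with min_range = 0 and max_range = 255, scale_factor = 255 / 255 = 1, so each input is simply rounded and saturated to the quint8 range, e.g. 1.25 -> 1, 3.15 -> 3, and 500.0 saturates to 255.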
TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) {
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("mode", "MIN_FIRST")
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}),
{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
AddInputFromArray<float>(TensorShape({1}), {0.1});
AddInputFromArray<float>(TensorShape({1}), {0.8});
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
test::FillValues<quint8>(&expected, {32, 64, 96, 128, 159, 191, 223, 255});
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
const float output_min = GetOutput(1)->flat<float>()(0);
const float output_max = GetOutput(2)->flat<float>()(0);
EXPECT_NEAR(0.0f, output_min, 1e-5f);
EXPECT_NEAR(0.8f, output_max, 1e-5f);
}
TEST_F(MklQuantizeV2OpTest, small_minfirst_int) {
TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("mode", "MIN_FIRST")
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<float>(TensorShape({8}),
{-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8});
AddInputFromArray<float>(TensorShape({1}), {-0.8});
AddInputFromArray<float>(TensorShape({1}), {-0.1});
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
test::FillValues<quint8>(&expected, {223, 191, 159, 128, 96, 64, 32, 0});
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
const float output_min = GetOutput(1)->flat<float>()(0);
const float output_max = GetOutput(2)->flat<float>()(0);
EXPECT_NEAR(-0.8f, output_min, 1e-5f);
EXPECT_NEAR(0.0f, output_max, 1e-5f);
}
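As a check on the negative-range case: min_range = std::min(0.0f, -0.8f) = -0.8, and the expected maximum of 0.0 implies the symmetric clamp max_range = std::max(0.0f, -0.1f) = 0.0 (that clamp sits outside the hunks shown here). This gives scale_factor = 255 / 0.8 = 318.75, so -0.1 maps to (-0.1 - (-0.8)) * 318.75 = 223.125 -> 223 and -0.8 maps to 0, matching the expected tensor.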
} // end namespace tensorflow

View File

@@ -105,7 +105,7 @@ REGISTER_OP("_MklQuantizeV2")
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'SCALED'")
.Attr(
"round_mode: {'HALF_AWAY_FROM_ZERO', 'HALF_TO_EVEN'} = "
"'HALF_TO_EVEN'")
"'HALF_AWAY_FROM_ZERO'")
.Attr("narrow_range: bool = false")
.Attr("axis: int = -1")
.Attr("ensure_minimum_range: float = 0.01")