[Intel MKL] Fix dequantize accuracy issue and re-enable this OP
commit 60da3fbda7
parent 16d4b320f0
@@ -359,9 +359,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.mul = "Mul";
     csinfo_.squared_difference = "SquaredDifference";
     csinfo_.sub = "Sub";
     // End - element-wise ops. See note above.

     // NOTE: names are alphabetically sorted.
     rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                       CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
@@ -671,18 +671,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back(
         {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
     // Disable these two MKL operators for now due to some test failures caused
     // by these two ops
     /*
     rinfo_.push_back({csinfo_.tanh,
                       mkl_op_registry::GetMklOpName(csinfo_.tanh),
                       CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.tanh_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
                       CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     */
     rinfo_.push_back(
         {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -1478,9 +1478,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
                "Eigen op for Dequantize op.";
     return false;
   }
-  // TODO(sriniva2/mabuzain) Enable the op after verifying support for
-  // object detection models
-  return false;
+  return true;
 }

 // Rewrite rule for _FusedMatMul.
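This hunk is the change that actually re-enables the op: the rewrite predicate for Dequantize previously always returned false behind a TODO about object-detection models, so the layout pass never substituted the MKL kernel. A minimal sketch of the resulting predicate, assuming the function name DequantizeRewrite and the SCALED-mode guard that the visible "Eigen op for Dequantize op." message belongs to (neither appears in this diff):

// Sketch only: the function name and the mode guard are assumptions
// inferred from the surrounding context, not shown in this commit.
static bool DequantizeRewrite(const Node* n) {
  string mode_string;
  TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
  if (mode_string != "SCALED") {
    VLOG(1) << "DequantizeRewrite: mode is not SCALED. Using "
               "Eigen op for Dequantize op.";
    return false;
  }
  return true;  // Was `return false;` behind the TODO removed above.
}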
@@ -7976,6 +7976,7 @@ tf_cc_test_mkl(
     srcs = ["mkl_dequantize_op_test.cc"],
     deps = [
         ":mkl_dequantize_op",
+        ":mkl_tfconv_op",
         ":ops_testutil",
         ":ops_util",
         "//tensorflow/core:array_ops_op_lib",
@@ -7984,6 +7985,7 @@ tf_cc_test_mkl(
         "//tensorflow/core:mkl_array_ops_op_lib",
         "//tensorflow/core:nn_ops_op_lib",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
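Both new BUILD dependencies serve the new test added further down: :mkl_tfconv_op provides the _MklToTf conversion kernel the test uses to turn the op's MKL-layout output back into a standard TensorFlow tensor, and //tensorflow/core:tensorflow appears to link in the runtime pieces needed by the mkl_util.h helpers the test now includes.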
@@ -92,10 +92,12 @@ class MklDequantizeOp : public OpKernel {

     memory::primitive_desc src_pd =
         memory::primitive_desc(src_md, cpu_engine);
-    memory::desc dst_md = src_mkl_shape.IsMklTensor()
-                              ? src_md
-                              : memory::desc(src_dims, MklDnnType<float>(),
-                                             memory::format::nhwc);
+    memory::desc dst_md =
+        src_mkl_shape.IsMklTensor()
+            ? memory::desc(src_dims, MklDnnType<float>(),
+                           static_cast<memory::format>(src_md.data.format))
+            : memory::desc(src_dims, MklDnnType<float>(),
+                           memory::format::nhwc);
     memory::primitive_desc dst_pd =
         memory::primitive_desc(dst_md, cpu_engine);

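This is the accuracy fix named in the commit title. When the input arrived in MKL layout, the old code reused src_md wholesale as the destination descriptor, so the dequantizing reorder kept the source's quantized u8 data type instead of producing float output. The new code always sets the destination data type to float and borrows only the memory format from the source. A commented restatement of the new construction (same code as in the hunk, annotated):

// Destination of the dequantizing reorder must be float; only the memory
// *format* follows the input. Reusing src_md (the old code) kept the u8
// data type, which appears to be the accuracy bug this commit fixes.
memory::desc dst_md =
    src_mkl_shape.IsMklTensor()
        // MKL-layout input: same format as the source, float data type.
        ? memory::desc(src_dims, MklDnnType<float>(),
                       static_cast<memory::format>(src_md.data.format))
        // Plain TF-layout input: default to float NHWC.
        : memory::desc(src_dims, MklDnnType<float>(), memory::format::nhwc);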
@@ -150,9 +152,9 @@ class MklDequantizeOp : public OpKernel {
           mkldnn::reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem()));
       stream(stream::kind::eager).submit(net).wait();
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           ctx, errors::Aborted("Operation received an exception:", error_msg));
     }
@@ -22,6 +22,8 @@ limitations under the License.
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"

+#include "tensorflow/core/util/mkl_util.h"
+
 namespace tensorflow {

 class MklDequantizeOpTest : public OpsTestBase {};
@@ -59,4 +61,83 @@ TEST_F(MklDequantizeOpTest, small) {
   test::ExpectTensorNear<float>(expected, output, 0.1);
 }

+Tensor CreateMklInput() {
+  MklDnnShape mkl_shape;
+  memory::desc md =
+      memory::desc({1, 2, 2, 2}, MklDnnType<uint8>(), memory::format::nhwc);
+  mkl_shape.SetMklTensor(true);
+  mkl_shape.SetMklLayout(&md);
+  mkl_shape.SetElemType(MklDnnType<uint8>());
+  mkl_shape.SetTfLayout(4, {1, 2, 2, 2}, memory::format::nhwc);
+
+  DataType dtype = DataTypeToEnum<uint8>::v();
+  Tensor mkl_tensor(dtype, {mkl_shape.GetSerializeBufferSize()});
+  mkl_shape.SerializeMklDnnShape(
+      mkl_tensor.flat<uint8>().data(),
+      mkl_tensor.flat<uint8>().size() * sizeof(uint8));
+  return mkl_tensor;
+}
+
+template <typename T>
+class CommonTestUtilities : public OpsTestBase {
+ public:
+  void MklToTF(const Tensor& tensor, const Tensor& mkl_meta_tensor,
+               Tensor* output) {
+    // Create an MKL to TF conversion node and execute it
+    TF_ASSERT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
+                     .Input(FakeInput(DataTypeToEnum<T>::v()))
+                     .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                     .Attr("T", DataTypeToEnum<T>::v())
+                     .Attr("_kernel", "MklLayoutDependentOp")
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+    AddInputFromArray<T>(tensor.shape(), tensor.flat<T>());
+    AddInputFromArray<uint8>(mkl_meta_tensor.shape(),
+                             mkl_meta_tensor.flat<uint8>());
+    TF_ASSERT_OK(RunOpKernel());
+
+    *output = *GetOutput(0);
+  }
+
+  void ConvertAndCompare(const Tensor& tensor, const Tensor& mkl_meta_tensor,
+                         const Tensor& expected) {
+    Tensor output;
+    MklToTF(tensor, mkl_meta_tensor, &output);
+    test::ExpectTensorNear<T>(expected, output, 0.1);
+  }
+
+  void TestBody() {}
+};
+
+TEST_F(MklDequantizeOpTest, MKLInput) {
+  TF_ASSERT_OK(NodeDefBuilder("dequantize_op", "_MklDequantize")
+                   .Input(FakeInput(DT_QUINT8))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                   .Attr("T", DataTypeToEnum<quint8>::v())
+                   .Attr("mode", "SCALED")
+                   .Attr("_kernel", "QuantizedMklOp")
+                   .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+  AddInputFromArray<quint8>(TensorShape({1, 2, 2, 2}),
+                            {0, 10, 50, 40, 25, 115, 190, 255});
+  // min_range = 0
+  AddInputFromArray<float>(TensorShape({1}), {0});
+  // max_range = 200
+  AddInputFromArray<float>(TensorShape({1}), {200.0f});
+  auto mkl_tensor = CreateMklInput();
+  AddInputFromArray<uint8>(mkl_tensor.shape(), mkl_tensor.flat<uint8>());
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 2, 2}));
+  test::FillValues<float>(&expected,
+                          {0.0, 7.84, 39.21, 31.37, 19.6, 90.2, 149.0, 200});
+  CommonTestUtilities<float> test_util;
+  test_util.ConvertAndCompare(*GetOutput(0), *GetOutput(1), expected);
+}
+
 } // namespace tensorflow
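The expected values follow from SCALED-mode dequantization of quint8: with min_range = 0 and max_range = 200 the scale is 200 / 255 ≈ 0.7843, so 10 → 7.84, 50 → 39.21, 115 → 90.2, and 255 → 200, matching the FillValues call above. The test drives the MKL path end to end: CreateMklInput serializes an MklDnnShape into the metadata tensor that tags the quint8 input as an MKL-layout tensor, and CommonTestUtilities::ConvertAndCompare then runs the _MklToTf conversion op on the kernel's two outputs (data plus metadata) before comparing against the expected float tensor.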
@@ -132,7 +132,7 @@ class MklReshapeOp : public OpKernel {
                         " values, but the requested shape has ",
                         shape.num_elements()));

-    if (input_in_mkl_format) {
+    if (input_in_mkl_format && !SkipReorder(mkl_shape_input, shape)) {
       TensorShape& shape_to = shape;
       TensorShape shape_from = mkl_shape_input.GetTfShape();
       if (shape_from == shape_to) {
@@ -152,65 +152,36 @@ class MklReshapeOp : public OpKernel {
         // Tensorflow, we don't need to reorder tensor contents, we just
         // need to update MklDnnShape object associated with the input
         // tensor to reflect the shape change expected by reshape.
-        if (!SkipReorder(mkl_shape_input, shape_to)) {
-          // If dimensions that are being expanded or collapsed are not
-          // maintained contiguously by MKLDNN, then we use reorder.
-
-          // Get Mkl layout of input tensor.
-          auto input_mkl_md = mkl_shape_input.GetMklLayout();
-          // Set input Mkl layout as the user layout.
-          dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor);
-          // Get expected Tensorflow layout of input tensor.
-          auto output_tf_md = mkl_shape_input.GetTfLayout();
-          auto output_tf_pd =
-              memory::primitive_desc(output_tf_md, cpu_engine);
-
-          Tensor* output_tensor = nullptr;
-          MklDnnShape mkl_shape_output;
-          mkl_shape_output.SetMklTensor(false);
-          // We allocate output tensor in the shape expected by Reshape.
-          AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor,
-                                    shape_to, mkl_shape_output);
-
-          // Insert reorder between Mkl layout and TensorFlow layout if
-          // needed. If reorder is not needed but reshape is needed (since
-          // shape_from != shape_to), then we just copy input tensor to
-          // output tensor with target shape (we cannot forward Mkl layout
-          // in such case because shape has changed.)
-          if (dnn_data_input.CheckReorderToOpMem(output_tf_pd,
-                                                 output_tensor)) {
-          } else {
-            OP_REQUIRES(
-                context, output_tensor->CopyFrom(input_tensor, shape_to),
-                errors::InvalidArgument("invalid input tensor shape"));
-          }
-          return;
-        } else {
-          // If dimensions that are being expanded or collapsed are
-          // maintained contiguously by MKLDNN, then we skip reorder, just
-          // update MklDnnShape object for the tensorflow tensor, and forward
-          // Tensorflow tensor as it is to the output.
-          auto output_dims = TFShapeToMklDnnDims(shape_to);
-          auto output_strides = CalculateTFStrides(output_dims);
-          auto output_tf_md = MklDnnData<T>::CreateBlockedMemDesc(
-              output_dims, output_strides);
-          auto output_tf_pd =
-              memory::primitive_desc(output_tf_md, cpu_engine);
-
-          // Set MklDnnShape
-          MklDnnShape mkl_shape_output;
-          mkl_shape_output.SetMklTensor(true);
-          mkl_shape_output.SetMklLayout(&output_tf_pd);
-          mkl_shape_output.SetElemType(MklDnnType<T>());
-          mkl_shape_output.SetTfLayout(output_dims.size(), output_dims,
-                                       memory::format::blocked);
-
-          // We now simply forward input Mkl tensor to output and change its
-          // output MklDnnShape object.
-          ForwardMklTensorInToOutWithMklShape(
-              context, kInputSlotIdx, kOutputSlotIdx, mkl_shape_output);
-          return;
-        }
+        // If dimensions that are being expanded or collapsed are not
+        // maintained contiguously by MKLDNN, then we use reorder.
+
+        // Get Mkl layout of input tensor.
+        auto input_mkl_md = mkl_shape_input.GetMklLayout();
+        // Set input Mkl layout as the user layout.
+        dnn_data_input.SetUsrMem(input_mkl_md, &input_tensor);
+        // Get expected Tensorflow layout of input tensor.
+        auto output_tf_md = mkl_shape_input.GetTfLayout();
+        auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
+
+        Tensor* output_tensor = nullptr;
+        MklDnnShape mkl_shape_output;
+        mkl_shape_output.SetMklTensor(false);
+        // We allocate output tensor in the shape expected by Reshape.
+        AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor,
+                                  shape_to, mkl_shape_output);
+
+        // Insert reorder between Mkl layout and TensorFlow layout if
+        // needed. If reorder is not needed but reshape is needed (since
+        // shape_from != shape_to), then we just copy input tensor to
+        // output tensor with target shape (we cannot forward Mkl layout
+        // in such case because shape has changed.)
+        if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, output_tensor)) {
+        } else {
+          OP_REQUIRES(context,
+                      output_tensor->CopyFrom(input_tensor, shape_to),
+                      errors::InvalidArgument("invalid input tensor shape"));
+        }
+        return;
       } catch (mkldnn::error& e) {
         string error_msg = "Status: " + std::to_string(e.status) +
                            ", message: " + string(e.message) + ", in file " +
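The reshape changes ride along with this commit as a simplification: SkipReorder is now evaluated once in the outer condition (the -132,7 hunk above), so the inner if/else disappears and the reorder-or-copy body is de-indented one level. The deleted else branch used to forward the MKL tensor with a rebuilt blocked MklDnnShape when the reorder could be skipped; after this change a skippable input fails the outer condition and presumably falls through to the op's plain non-MKL reshape path. The resulting control flow, as a sketch assembled from the two hunks:

// Sketch of the new control flow (names taken from the hunks above).
if (input_in_mkl_format && !SkipReorder(mkl_shape_input, shape)) {
  // Reorder MKL layout to TF layout into a freshly allocated output,
  // or CopyFrom() when no reorder turns out to be needed.
  ...
  return;
}
// Otherwise handled like an ordinary TensorFlow reshape (previously the
// skippable case forwarded the MKL tensor with a blocked MklDnnShape).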
@@ -142,7 +142,10 @@ REGISTER_OP("_MklDequantize")
     .Output("output: float")
     .Output("mkl_output: uint8")
     .Attr("T: quantizedtype")
+    .Attr("narrow_range: bool = false")
+    .Attr("axis: int = -1")
     .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST', 'SCALED'} = 'SCALED'")
+    .Attr("dtype: {bfloat16, float} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) {
       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
       ShapeHandle unused;
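The three added attributes mirror the standard Dequantize op's signature, presumably so that attribute copying during the layout rewrite (the same CopyAttrsAll mechanism used for the other rewrites in this commit) produces a valid _MklDequantize node now that the op is re-enabled; an attribute present on Dequantize but undeclared on _MklDequantize would fail node validation. Whether the MKL kernel honors narrow_range, axis, or a bfloat16 dtype is not shown in this diff; the defaults match the non-MKL op.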
|
@ -728,9 +728,9 @@ inline Status ConvertMklToTF(OpKernelContext* context,
|
|||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} catch (mkldnn::error& e) {
|
} catch (mkldnn::error& e) {
|
||||||
string error_msg = "Status: " + std::to_string(e.status) +
|
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
|
||||||
", message: " + string(e.message) + ", in file " +
|
string(e.message) + ", in file " + string(__FILE__) +
|
||||||
string(__FILE__) + ":" + std::to_string(__LINE__);
|
":" + std::to_string(__LINE__);
|
||||||
LOG(FATAL) << "Operation received an exception: " << error_msg;
|
LOG(FATAL) << "Operation received an exception: " << error_msg;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1011,6 +1011,11 @@ memory::data_type MklDnnType<quint8>() {
   return memory::data_type::u8;
 }

+template <>
+memory::data_type MklDnnType<uint8>() {
+  return memory::data_type::u8;
+}
+
 template <>
 memory::data_type MklDnnType<qint8>() {
   return memory::data_type::s8;
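The new specialization lets plain (non-quantized) uint8 data be described to MKL-DNN, mapping to u8 just like quint8. The test added above depends on it, for example in CreateMklInput:

// From mkl_dequantize_op_test.cc above: this call needs MklDnnType<uint8>()
// to have a definition, which the specialization in this hunk provides.
memory::desc md =
    memory::desc({1, 2, 2, 2}, MklDnnType<uint8>(), memory::format::nhwc);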
@@ -1250,8 +1255,8 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
   } catch (mkldnn::error& e) {
     return Status(error::Code::INTERNAL,
                   tensorflow::strings::StrCat(
-                      "Failed to create blocked memory descriptor.",
-                      "Status: ", e.status, ", message: ", e.message));
+                      "Failed to create blocked memory descriptor.", "Status: ",
+                      e.status, ", message: ", e.message));
   }
 #else
   // We have to construct memory descriptor in a C style. This is not at all