Merge pull request #25533 from Intel-tensorflow:quantized_concat_part_1_new

PiperOrigin-RevId: 232718138
TensorFlower Gardener 2019-02-06 12:10:36 -08:00
commit 86950c2c44
9 changed files with 461 additions and 6 deletions

View File

@@ -1144,6 +1144,13 @@ tf_gen_op_libs(
deps = [":protos_all_cc"],
)
tf_gen_op_libs(
op_lib_names = [
"mkl_array_ops",
],
deps = [":protos_all_cc"],
)
tf_gen_op_libs(
op_lib_names = [
"audio_ops",
@@ -1283,7 +1290,10 @@ cc_library(
":training_ops_op_lib",
":user_ops_op_lib",
":word2vec_ops",
] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(),
] + if_mkl([
":mkl_array_ops_op_lib",
":mkl_nn_ops_op_lib",
]) + tf_additional_cloud_op_deps(),
alwayslink = 1,
)
@@ -4527,7 +4537,7 @@ tf_cc_test(
"//tensorflow/cc:scope",
"//tensorflow/core/kernels:cwise_op",
"//third_party/eigen3",
],
] + if_mkl([":mkl_array_ops_op_lib"]),
)
tf_cc_test(

View File

@@ -24,9 +24,9 @@ const std::unordered_set<std::string>* GetExcludedOps() {
"GcsConfigureBlockCache", "GcsConfigureCredentials",
#ifdef INTEL_MKL
// QuantizedFusedOps for Intel CPU
"QuantizedConv2DAndRequantize", "QuantizedConv2DWithBias",
"QuantizedConv2DWithBiasAndRequantize", "QuantizedConv2DAndRelu",
"QuantizedConv2DAndReluAndRequantize",
"QuantizedConcatV2", "QuantizedConv2DAndRequantize",
"QuantizedConv2DWithBias", "QuantizedConv2DWithBiasAndRequantize",
"QuantizedConv2DAndRelu", "QuantizedConv2DAndReluAndRequantize",
"QuantizedConv2DWithBiasAndRelu",
"QuantizedConv2DWithBiasAndReluAndRequantize",
"QuantizedConv2DWithBiasSumAndRelu",

View File

@@ -1303,6 +1303,12 @@ Status ConcatV2Shape(InferenceContext* c) {
c->num_inputs() - 1 /* dim_index */);
}
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
return ConcatShapeHelper(c, 0 /* start_value_index */,
num_inputs_to_concat /* end_value_index */,
num_inputs_to_concat /* dim_index */);
}
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
ShapeHandle shape_x,
ShapeHandle shape_y,
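For orientation, here is a minimal standalone sketch (plain C++, hypothetical values, no TensorFlow API) of how the three indices handed to ConcatShapeHelper line up with the QuantizedConcatV2 input order: the N value tensors occupy inputs 0..N-1 and the axis scalar immediately follows at index N, which is why both end_value_index and dim_index are num_inputs_to_concat.

#include <cassert>
#include <cstdio>

// Illustration only: the input index layout QuantizedConcatV2Shape relies on.
int main() {
  const int num_inputs_to_concat = 2;                // N
  const int start_value_index = 0;                   // first value tensor
  const int end_value_index = num_inputs_to_concat;  // one past the last value
  const int dim_index = num_inputs_to_concat;        // "axis" follows the values
  // Input order: values[0..N-1], axis, input_mins[0..N-1], input_maxes[0..N-1],
  // so for N = 2 the axis scalar is input 2.
  assert(start_value_index == 0 && end_value_index == 2 && dim_index == 2);
  std::printf("values in [%d, %d), axis at %d\n", start_value_index,
              end_value_index, dim_index);
  return 0;
}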

View File

@@ -279,6 +279,8 @@ Status ConcatShape(shape_inference::InferenceContext* c,
// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
// Shape function for binary operators that broadcast their inputs
// and with output to output_index.
// Note: out cannot be NULL.

View File

@@ -6554,6 +6554,30 @@ tf_cc_test(
],
)
tf_cc_test_mkl(
name = "mkl_quantized_concat_op_test",
size = "small",
srcs = ["mkl_quantized_concat_op_test.cc"],
deps = [
":mkl_concat_op",
":ops_testutil",
":ops_util",
":quantization_utils",
":quantized_ops",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:mkl_array_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test(
name = "quantized_batch_norm_op_test",
size = "small",

View File

@@ -25,12 +25,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/concat_lib_cpu.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/mkl_util.h"
using mkldnn::concat;
using mkldnn::stream;
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -226,8 +228,50 @@ class MklConcatOp : public OpKernel {
// format and avoid calling eigen version.
if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;
OpInputList input_mins, input_maxes;
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
// MKL-DNN concat does not support input tensors that have different
// ranges. Check if the ranges of all the input tensors are the same.
// If not, forward it to the Eigen implementation.
OP_REQUIRES_OK(context, context->input_list("input_mins", &input_mins));
OP_REQUIRES(context, (input_mins.size() == N),
errors::InvalidArgument(
"QuantizedConcatOp : Expected mins input list length ",
input_mins.size(), " to equal values length ", N));
OP_REQUIRES_OK(context,
context->input_list("input_maxes", &input_maxes));
OP_REQUIRES(context, (input_maxes.size() == N),
errors::InvalidArgument(
"QuantizedConcatOp : Expected maxes input list length ",
input_maxes.size(), " to equal values length ", N));
float input_min = input_mins[0].flat<float>()(0);
float input_max = input_maxes[0].flat<float>()(0);
const float eps = 1.0e-6;
for (int i = 1; i < N; ++i) {
float min = input_mins[i].flat<float>()(0);
float max = input_maxes[i].flat<float>()(0);
if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
invoke_eigen = true;
break;
}
}
}
// Call Eigen library
if (invoke_eigen) {
// MKL-DNN quantized concat does not support input tensors with
// different ranges.
// TODO (mabuzain): Add quantized version of CallEigen() to support
// this case.
OP_REQUIRES(
context,
(!std::is_same<T, qint8>::value && !std::is_same<T, quint8>::value),
errors::Unimplemented("MKL DNN quantized concat does not "
"support input tensors that have "
"different ranges"));
CallEigenVersion(context, input_tensors, mkl_input_shapes);
return;
}
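As an aside, a minimal standalone sketch of the fallback decision above (hypothetical helper name; the real kernel reads the scalars from the OpInputList objects): every input's [min, max] pair must match the first pair within a small epsilon, otherwise invoke_eigen is set and, for qint8/quint8 inputs, the kernel currently reports Unimplemented rather than concatenating mismatched ranges.

#include <cmath>
#include <vector>

// Hypothetical helper mirroring the range check in MklConcatOp::Compute.
bool AllRangesEqual(const std::vector<float>& mins,
                    const std::vector<float>& maxes, float eps = 1.0e-6f) {
  for (size_t i = 1; i < mins.size(); ++i) {
    if (std::fabs(mins[i] - mins[0]) > eps ||
        std::fabs(maxes[i] - maxes[0]) > eps) {
      return false;
    }
  }
  return true;
}
// Example: AllRangesEqual({0.f, 0.f}, {255.f, 255.f}) is true, so the MKL-DNN
// path can be used; AllRangesEqual({0.f, -1.f}, {255.f, 255.f}) is false.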
@@ -374,6 +418,23 @@ class MklConcatOp : public OpKernel {
std::vector<primitive> net;
net.push_back(concat_op);
stream(stream::kind::eager).submit(net).wait();
// For quantized concat, min and max outputs are also computed.
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
Tensor* output_min = nullptr;
Tensor* output_max = nullptr;
MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
output_min_mkl_shape.SetMklTensor(false);
output_max_mkl_shape.SetMklTensor(false);
AllocateOutputSetMklShape(context, 1, &output_min, {},
output_min_mkl_shape);
AllocateOutputSetMklShape(context, 2, &output_max, {},
output_max_mkl_shape);
// All input tensors should have the same range; just use the
// first one's min and max.
output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
}
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -490,6 +551,20 @@ class MklConcatOp : public OpKernel {
TF_CALL_float(REGISTER_MKL_CPU);
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
.Device(DEVICE_CPU)
.TypeConstraint<quint8>("T")
.HostMemory("axis")
.Label(mkl_op_registry::kMklQuantizedOpLabel),
MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
.Device(DEVICE_CPU)
.TypeConstraint<qint8>("T")
.HostMemory("axis")
.Label(mkl_op_registry::kMklQuantizedOpLabel),
MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)
#undef REGISTER_CONCAT_MKL
} // namespace tensorflow

View File

@@ -0,0 +1,234 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define EIGEN_USE_THREADS
#include <functional>
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
using test::graph::Constant;
static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
static const TensorShape dummy_shape({8});
// Helper class for converting MKL tensors to TF tensors and comparing to
// expected values
class ConvMklToTF : public OpsTestBase {
public:
template <typename T>
void ConvertMKL2TF(DataType dtype, const Tensor& first, const Tensor& second,
Tensor& output) {
// Create an MKL to TF conversion node and execute it
TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
.Input(FakeInput(dtype)) // Input
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Attr("T", dtype)
.Attr("_kernel", "MklOp")
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(first.shape(), first.flat<T>());
AddInputFromArray<uint8>(second.shape(), second.flat<uint8>());
TF_ASSERT_OK(RunOpKernel());
output = *GetOutput(0);
}
void TestBody(){};
};
class QuantizedConcatTest : public OpsTestBase {
protected:
QuantizedConcatTest() {}
void TestSmall8Bit(float first_min, float first_max, float second_min,
float second_max);
void TestSecondDim8Bit(float first_min, float first_max, float second_min,
float second_max);
};
TEST_F(QuantizedConcatTest, Small8BitSameRange) {
// Range for both is the same, so impl can use memcpy.
TestSmall8Bit(0.0f, 255.0f, 0.0f, 255.0f);
}
void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max,
float second_min, float second_max) {
TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
.Input(FakeInput(2, DT_QUINT8))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(2, DT_FLOAT))
.Input(FakeInput(2, DT_FLOAT))
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Attr("N", 2)
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("Tidx", DT_INT32)
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
const int first_batch = 2;
const int first_height = 2;
const int first_width = 3;
const int first_depth = 1;
Tensor first_float(DT_FLOAT,
{first_batch, first_height, first_width, first_depth});
test::FillValues<float>(&first_float,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
Tensor first_quantized =
FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
const int second_batch = 2;
const int second_height = 2;
const int second_width = 3;
const int second_depth = 1;
Tensor second_float(
DT_FLOAT, {second_batch, second_height, second_width, second_depth});
test::FillValues<float>(&second_float,
{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
Tensor second_quantized =
FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
const int expected_batch = first_batch + second_batch;
Tensor expected_float(
DT_FLOAT, {expected_batch, first_height, first_width, first_depth});
test::FillValues<float>(&expected_float,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
AddInputFromArray<quint8>(first_quantized.shape(),
first_quantized.flat<quint8>());
AddInputFromArray<quint8>(second_quantized.shape(),
second_quantized.flat<quint8>());
AddInputFromArray<int32>(TensorShape({}), {0});
AddInputFromArray<float>(TensorShape({}), {first_min});
AddInputFromArray<float>(TensorShape({}), {second_min});
AddInputFromArray<float>(TensorShape({}), {first_max});
AddInputFromArray<float>(TensorShape({}), {second_max});
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_quantized = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
const float output_max = GetOutput(2)->flat<float>()(0);
Tensor output_float =
QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
test::ExpectTensorNear<float>(expected_float, output_float, 0.2);
}
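The error tolerances in these tests follow from the quantization step size; a small standalone arithmetic sketch (assuming the usual affine quint8 scheme): the step is (max - min) / 255, i.e. exactly 1.0 for the [0, 255] range above, so the integer test values quantize exactly and 0.2 is a loose bound, and about 0.63 for the [-10, 150] range used below, where the worst round-trip error is roughly half a step and well within the 1.0 tolerance.

#include <cstdio>

// Standalone arithmetic only: quint8 step size for an affine [min, max] range
// and the worst-case round-trip error (half a step).
int main() {
  const float ranges[][2] = {{0.0f, 255.0f}, {-10.0f, 150.0f}};
  for (const auto& r : ranges) {
    const float step = (r[1] - r[0]) / 255.0f;
    std::printf("range [%g, %g]: step %.4f, worst round-trip error %.4f\n",
                r[0], r[1], step, step / 2.0f);
  }
  return 0;
}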
TEST_F(QuantizedConcatTest, SecondDim8BitSameRange) {
TestSecondDim8Bit(-10.0f, 150.0f, -10.0f, 150.0f);
}
void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
float second_min,
float second_max) {
TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
.Input(FakeInput(2, DT_QUINT8))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(2, DT_FLOAT))
.Input(FakeInput(2, DT_FLOAT))
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Attr("N", 2)
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("Tidx", DT_INT32)
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
const int first_batch = 2;
const int first_height = 2;
const int first_width = 3;
const int first_depth = 1;
Tensor first_float(DT_FLOAT,
{first_batch, first_height, first_width, first_depth});
test::FillValues<float>(&first_float,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
Tensor first_quantized =
FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
const int second_batch = 2;
const int second_height = 2;
const int second_width = 3;
const int second_depth = 1;
Tensor second_float(
DT_FLOAT, {second_batch, second_height, second_width, second_depth});
test::FillValues<float>(&second_float,
{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
Tensor second_quantized =
FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
const int expected_height = first_height + second_height;
Tensor expected_float(
DT_FLOAT, {first_batch, expected_height, first_width, first_depth});
test::FillValues<float>(&expected_float,
{1, 2, 3, 4, 5, 6, 13, 14, 15, 16, 17, 18,
7, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24});
AddInputFromArray<quint8>(first_quantized.shape(),
first_quantized.flat<quint8>());
AddInputFromArray<quint8>(second_quantized.shape(),
second_quantized.flat<quint8>());
AddInputFromArray<int32>(TensorShape({}), {1});
AddInputFromArray<float>(TensorShape({}), {first_min});
AddInputFromArray<float>(TensorShape({}), {second_min});
AddInputFromArray<float>(TensorShape({}), {first_max});
AddInputFromArray<float>(TensorShape({}), {second_max});
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_quantized = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
const float output_max = GetOutput(2)->flat<float>()(0);
Tensor output_float =
QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
// Using the same error tolerance as in the Eigen QuantizedConcat test.
test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
}
} // namespace tensorflow
#endif // INTEL_MKL && ENABLE_MKL
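The interleaved ordering expected by TestSecondDim8Bit can be reproduced with a short standalone sketch (plain C++, no TensorFlow): concatenating two row-major [2, 2, 3, 1] tensors along axis 1 yields a [2, 4, 3, 1] result in which, for each batch entry, the height slab of the first input is followed by the height slab of the second.

#include <cstdio>
#include <vector>

// Standalone check of the axis-1 concat layout used in TestSecondDim8Bit.
int main() {
  const int batch = 2, height = 2, width = 3;  // depth is 1 and elided
  std::vector<int> first(12), second(12), out;
  for (int i = 0; i < 12; ++i) {
    first[i] = i + 1;    // 1 .. 12
    second[i] = i + 13;  // 13 .. 24
  }
  for (int b = 0; b < batch; ++b) {
    for (int h = 0; h < 2 * height; ++h) {  // concatenated height is 4
      const std::vector<int>& src = (h < height) ? first : second;
      const int sh = (h < height) ? h : h - height;
      for (int w = 0; w < width; ++w) {
        out.push_back(src[(b * height + sh) * width + w]);
      }
    }
  }
  for (int v : out) std::printf("%d ", v);  // 1..6 13..18 7..12 19..24
  std::printf("\n");
  return 0;
}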

View File

@@ -246,4 +246,16 @@ REGISTER_QUANTIZED_CONCAT(qint32);
#undef REGISTER_QUANTIZED_CONCAT
#ifdef INTEL_MKL
#define REGISTER_QUANTIZED_CONCATV2(type) \
REGISTER_KERNEL_BUILDER(Name("QuantizedConcatV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("axis"), \
QuantizedConcatOp<type>)
REGISTER_QUANTIZED_CONCATV2(quint8);
REGISTER_QUANTIZED_CONCATV2(qint32);
#endif
} // namespace tensorflow

View File

@@ -0,0 +1,92 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
// This file contains the registration of MKL-DNN array ops.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/strided_slice_op.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using shape_inference::UnchangedShape;
// Register the QuantizedConcatV2 op here so that it can be replaced by
// _MklQuantizedConcatV2 in the graph rewrite.
REGISTER_OP("QuantizedConcatV2")
.Input("values: N * T")
.Input("axis: Tidx")
.Input("input_mins: N * float32")
.Input("input_maxes: N * float32")
.Output("output: T")
.Output("output_min: float")
.Output("output_max: float")
.Attr("N: int >= 2")
.Attr("T: type")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
const int n = (c->num_inputs() - 1) / 3;
TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
ShapeHandle unused;
for (int i = n + 1; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
}
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
return Status::OK();
});
REGISTER_OP("_MklQuantizedConcatV2")
.Input("values: N * T")
.Input("axis: Tidx")
.Input("input_mins: N * float32")
.Input("input_maxes: N * float32")
.Input("mkl_values: N * uint8")
.Input("mkl_axis: uint8")
.Input("mkl_input_mins: N * uint8")
.Input("mkl_input_maxes: N * uint8")
.Output("output: T")
.Output("output_min: float")
.Output("output_max: float")
.Output("mkl_output: uint8")
.Output("mkl_output_min: uint8")
.Output("mkl_output_max: uint8")
.Attr("N: int >= 2")
.Attr("T: type")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
const int n = (c->num_inputs() / 2 - 1) / 3;
TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
ShapeHandle unused;
for (int i = n + 1; i < c->num_inputs() / 2; ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
}
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
return Status::OK();
});
} // namespace tensorflow
#endif
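Both SetShapeFn lambdas recover N purely from the input count; a small standalone sketch of that arithmetic (outside TensorFlow): the plain op has 3N + 1 inputs (N values, one axis, N mins, N maxes), while the _Mkl variant doubles that because every data input carries a uint8 MKL metadata companion, so its lambda first halves the count.

#include <cassert>

// Standalone check of the N-recovery arithmetic in the shape functions above.
int main() {
  const int N = 2;
  const int num_inputs = 3 * N + 1;  // QuantizedConcatV2: 3N + 1 inputs
  assert((num_inputs - 1) / 3 == N);
  const int mkl_num_inputs = 2 * num_inputs;  // _MklQuantizedConcatV2 doubles it
  assert((mkl_num_inputs / 2 - 1) / 3 == N);
  return 0;
}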