Quantized Concat with added #ifdef INTEL_MKL

Mahmoud Abuzaina 2019-02-05 16:51:28 -08:00
parent cf549e5f36
commit cbcbfe0267
9 changed files with 479 additions and 34 deletions

View File

@@ -1144,6 +1144,13 @@ tf_gen_op_libs(
deps = [":protos_all_cc"],
)
tf_gen_op_libs(
op_lib_names = [
"mkl_array_ops",
],
deps = [":protos_all_cc"],
)
tf_gen_op_libs(
op_lib_names = [
"audio_ops",
@@ -1283,7 +1290,10 @@ cc_library(
":training_ops_op_lib",
":user_ops_op_lib",
":word2vec_ops",
] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(),
] + if_mkl([
":mkl_array_ops_op_lib",
":mkl_nn_ops_op_lib",
]) + tf_additional_cloud_op_deps(),
alwayslink = 1,
)
@@ -4523,7 +4533,7 @@ tf_cc_test(
"//tensorflow/cc:scope",
"//tensorflow/core/kernels:cwise_op",
"//third_party/eigen3",
],
] + if_mkl([":mkl_array_ops_op_lib"]),
)
tf_cc_test(

View File

@@ -24,9 +24,9 @@ const std::unordered_set<std::string>* GetExcludedOps() {
"GcsConfigureBlockCache", "GcsConfigureCredentials",
#ifdef INTEL_MKL
// QuantizedFusedOps for Intel CPU
"QuantizedConv2DAndRequantize", "QuantizedConv2DWithBias",
"QuantizedConv2DWithBiasAndRequantize", "QuantizedConv2DAndRelu",
"QuantizedConv2DAndReluAndRequantize",
"QuantizedConcatV2", "QuantizedConv2DAndRequantize",
"QuantizedConv2DWithBias", "QuantizedConv2DWithBiasAndRequantize",
"QuantizedConv2DAndRelu", "QuantizedConv2DAndReluAndRequantize",
"QuantizedConv2DWithBiasAndRelu",
"QuantizedConv2DWithBiasAndReluAndRequantize",
"QuantizedConv2DWithBiasSumAndRelu",

View File

@@ -57,9 +57,8 @@ Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
if (*output_size < 0) {
return errors::InvalidArgument(
"Computed output size would be negative: ", *output_size,
" [input_size: ", input_size,
", effective_filter_size: ", effective_filter_size,
", stride: ", stride, "]");
" [input_size: ", input_size, ", effective_filter_size: ",
effective_filter_size, ", stride: ", stride, "]");
}
return Status::OK();
}
@@ -1303,6 +1302,12 @@ Status ConcatV2Shape(InferenceContext* c) {
c->num_inputs() - 1 /* dim_index */);
}
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
return ConcatShapeHelper(c, 0 /* start_value_index */,
num_inputs_to_concat /* end_value_index */,
num_inputs_to_concat /* dim_index */);
}
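Aside: the index arguments above follow QuantizedConcatV2's input layout, which is values (N tensors), then axis, then input_mins (N scalars), then input_maxes (N scalars). A rough sketch of the resulting indices, assuming N = 2 (illustrative only, not taken from the diff):
// Input layout for QuantizedConcatV2 with N = 2:
//   0: values[0]       1: values[1]
//   2: axis
//   3: input_mins[0]   4: input_mins[1]
//   5: input_maxes[0]  6: input_maxes[1]
// Hence start_value_index = 0, end_value_index = N, and dim_index = N: the
// concatenated shapes come from inputs [0, N) and the axis scalar sits at
// index N, unlike ConcatV2Shape, where the axis is the last input.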
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
ShapeHandle shape_x,
ShapeHandle shape_y,
@@ -1566,11 +1571,10 @@ Status ScatterNdUpdateShape(InferenceContext* c) {
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
"The outer ", num_outer_dims,
" dimensions of indices.shape=", c->DebugString(indices_shape),
" must match the outer ", num_outer_dims,
" dimensions of updates.shape=", c->DebugString(updates_shape),
": ", s.error_message());
"The outer ", num_outer_dims, " dimensions of indices.shape=",
c->DebugString(indices_shape), " must match the outer ",
num_outer_dims, " dimensions of updates.shape=",
c->DebugString(updates_shape), ": ", s.error_message());
}
ShapeHandle input_suffix;

View File

@@ -279,6 +279,8 @@ Status ConcatShape(shape_inference::InferenceContext* c,
// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
// Shape function for binary operators that broadcast their inputs
// and with output to output_index.
// Note: out cannot be NULL.

View File

@@ -6561,6 +6561,30 @@ tf_cc_test(
],
)
tf_cc_test_mkl(
name = "mkl_quantized_concat_op_test",
size = "small",
srcs = ["mkl_quantized_concat_op_test.cc"],
deps = [
":mkl_concat_op",
":ops_testutil",
":ops_util",
":quantization_utils",
":quantized_ops",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:mkl_array_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
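A usage note (hedged; the config flag and exact target path are assumptions based on the rule above, not part of the diff): with an MKL-enabled build the new test should be runnable directly by target name, e.g. something like

  bazel test --config=mkl //tensorflow/core/kernels:mkl_quantized_concat_op_test

since tf_cc_test_mkl targets are presumably skipped or empty when MKL is not enabled.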
tf_cc_test(
name = "quantized_batch_norm_op_test",
size = "small",

View File

@@ -17,7 +17,6 @@ limitations under the License.
#include <vector>
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -25,12 +24,15 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/concat_lib_cpu.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
using mkldnn::concat;
using mkldnn::stream;
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -78,9 +80,8 @@ class EigenConcatBaseOp : public OpKernel {
const TensorShape& input_shape = input_shapes[0];
int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
OP_REQUIRES(c,
(0 <= axis && axis < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
OP_REQUIRES(c, (0 <= axis && axis < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range "
"[",
@@ -102,13 +103,12 @@ class EigenConcatBaseOp : public OpKernel {
const auto in = values[i];
const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
OP_REQUIRES(
c,
(input_shapes[i].dims() == input_dims) ||
(input_is_scalar && in_is_scalar),
c, (input_shapes[i].dims() == input_dims) ||
(input_is_scalar && in_is_scalar),
errors::InvalidArgument(
"ConcatOp : Ranks of all input tensors should match: shape[0] = ",
input_shape.DebugString(), " vs. shape[", i,
"] = ", input_shapes[i].DebugString()));
input_shape.DebugString(), " vs. shape[", i, "] = ",
input_shapes[i].DebugString()));
if (in.NumElements() > 0) {
int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
@@ -226,8 +226,49 @@ class MklConcatOp : public OpKernel {
// format and avoid calling eigen version.
if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;
OpInputList input_mins, input_maxes;
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
// MKL-DNN concat does not support input tensors that have different
// ranges. Check if the ranges of all the input tensors are the same.
// If not, forward it to Eigen implementation.
OP_REQUIRES_OK(context, context->input_list("input_mins", &input_mins));
OP_REQUIRES(context, (input_mins.size() == N),
errors::InvalidArgument(
"QuantizedConcatOp : Expected mins input list length ",
input_mins.size(), " to equal values length ", N));
OP_REQUIRES_OK(context,
context->input_list("input_maxes", &input_maxes));
OP_REQUIRES(context, (input_maxes.size() == N),
errors::InvalidArgument(
"QuantizedConcatOp : Expected maxes input list length ",
input_maxes.size(), " to equal values length ", N));
float input_min = input_mins[0].flat<float>()(0);
float input_max = input_maxes[0].flat<float>()(0);
const float eps = 1.0e-6;
for (int i = 1; i < N; ++i) {
float min = input_mins[i].flat<float>()(0);
float max = input_maxes[i].flat<float>()(0);
if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
invoke_eigen = true;
break;
}
}
}
// Call Eigen library
if (invoke_eigen) {
// MKL-DNN quantized concat does not support input tensors with
// different ranges.
// TODO (mabuzain): Add quantized version of CallEigen() to support
// this case.
OP_REQUIRES(context, (!std::is_same<T, qint8>::value &&
!std::is_same<T, quint8>::value),
errors::Unimplemented("MKL DNN quantized concat does not "
"support input tensors that have "
"different ranges"));
CallEigenVersion(context, input_tensors, mkl_input_shapes);
return;
}
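The reasoning behind the same-range requirement: TF's quint8 encoding is an affine mapping from the stored byte to a real value, so two tensors are byte-compatible only when their (min, max) pairs match. A minimal sketch of that mapping, simplified relative to the helpers in quantization_utils.h (the function name below is illustrative, not from the diff):
#include <cstdint>
// Simplified quint8 dequantization: real = min + q * (max - min) / 255.
static inline float DequantizeQuint8Approx(std::uint8_t q, float min_range,
                                           float max_range) {
  return min_range + static_cast<float>(q) * (max_range - min_range) / 255.0f;
}
// Inputs that share (min, max) share this mapping, so their raw bytes can be
// concatenated as-is; if the ranges differ, the same byte decodes to different
// real values, which is why the kernel above falls back to Eigen (or reports
// Unimplemented for the quantized types).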
@@ -374,10 +415,27 @@ class MklConcatOp : public OpKernel {
std::vector<primitive> net;
net.push_back(concat_op);
stream(stream::kind::eager).submit(net).wait();
// For quantized concat, min and max outputs are also computed.
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
Tensor* output_min = nullptr;
Tensor* output_max = nullptr;
MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
output_min_mkl_shape.SetMklTensor(false);
output_max_mkl_shape.SetMklTensor(false);
AllocateOutputSetMklShape(context, 1, &output_min, {},
output_min_mkl_shape);
AllocateOutputSetMklShape(context, 2, &output_max, {},
output_max_mkl_shape);
// All input tensors should have the same range, so just use the
// first one.
output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
}
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@@ -490,6 +548,20 @@ class MklConcatOp : public OpKernel {
TF_CALL_float(REGISTER_MKL_CPU);
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
.Device(DEVICE_CPU)
.TypeConstraint<quint8>("T")
.HostMemory("axis")
.Label(mkl_op_registry::kMklQuantizedOpLabel),
MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
.Device(DEVICE_CPU)
.TypeConstraint<qint8>("T")
.HostMemory("axis")
.Label(mkl_op_registry::kMklQuantizedOpLabel),
MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)
#undef REGISTER_CONCAT_MKL
} // namespace tensorflow

View File

@@ -0,0 +1,230 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include <functional>
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
using test::graph::Constant;
static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
static const TensorShape dummy_shape({8});
// Helper class for converting MKL tensors to TF tensors and comparing to
// expected values
class ConvMklToTF : public OpsTestBase {
public:
template <typename T>
void ConvertMKL2TF(DataType dtype, const Tensor& first, const Tensor& second,
Tensor& output) {
// Create an MKL to TF conversion node and execute it
TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
.Input(FakeInput(dtype)) // Input
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Attr("T", dtype)
.Attr("_kernel", "MklOp")
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(first.shape(), first.flat<T>());
AddInputFromArray<uint8>(second.shape(), second.flat<uint8>());
TF_ASSERT_OK(RunOpKernel());
output = *GetOutput(0);
}
void TestBody(){};
};
class QuantizedConcatTest : public OpsTestBase {
protected:
QuantizedConcatTest() {}
void TestSmall8Bit(float first_min, float first_max, float second_min,
float second_max);
void TestSecondDim8Bit(float first_min, float first_max, float second_min,
float second_max);
};
TEST_F(QuantizedConcatTest, Small8BitSameRange) {
// Range for both is the same, so impl can use memcpy.
TestSmall8Bit(0.0f, 255.0f, 0.0f, 255.0f);
}
void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max,
float second_min, float second_max) {
TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
.Input(FakeInput(2, DT_QUINT8))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(2, DT_FLOAT))
.Input(FakeInput(2, DT_FLOAT))
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Attr("N", 2)
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("Tidx", DT_INT32)
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
const int first_batch = 2;
const int first_height = 2;
const int first_width = 3;
const int first_depth = 1;
Tensor first_float(DT_FLOAT,
{first_batch, first_height, first_width, first_depth});
test::FillValues<float>(&first_float,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
Tensor first_quantized =
FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
const int second_batch = 2;
const int second_height = 2;
const int second_width = 3;
const int second_depth = 1;
Tensor second_float(
DT_FLOAT, {second_batch, second_height, second_width, second_depth});
test::FillValues<float>(&second_float,
{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
Tensor second_quantized =
FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
const int expected_batch = first_batch + second_batch;
Tensor expected_float(
DT_FLOAT, {expected_batch, first_height, first_width, first_depth});
test::FillValues<float>(&expected_float,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
AddInputFromArray<quint8>(first_quantized.shape(),
first_quantized.flat<quint8>());
AddInputFromArray<quint8>(second_quantized.shape(),
second_quantized.flat<quint8>());
AddInputFromArray<int32>(TensorShape({}), {0});
AddInputFromArray<float>(TensorShape({}), {first_min});
AddInputFromArray<float>(TensorShape({}), {second_min});
AddInputFromArray<float>(TensorShape({}), {first_max});
AddInputFromArray<float>(TensorShape({}), {second_max});
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_quantized = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
const float output_max = GetOutput(2)->flat<float>()(0);
Tensor output_float =
QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
test::ExpectTensorNear<float>(expected_float, output_float, 0.2);
}
TEST_F(QuantizedConcatTest, SecondDim8BitSameRange) {
TestSecondDim8Bit(-10.0f, 150.0f, -10.0f, 150.0f);
}
void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
float second_min,
float second_max) {
TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
.Input(FakeInput(2, DT_QUINT8))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(2, DT_FLOAT))
.Input(FakeInput(2, DT_FLOAT))
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Input(FakeInput(DT_UINT8)) // MKL second tensor
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
.Attr("N", 2)
.Attr("T", DataTypeToEnum<quint8>::v())
.Attr("Tidx", DT_INT32)
.Attr("_kernel", "QuantizedMklOp")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
const int first_batch = 2;
const int first_height = 2;
const int first_width = 3;
const int first_depth = 1;
Tensor first_float(DT_FLOAT,
{first_batch, first_height, first_width, first_depth});
test::FillValues<float>(&first_float,
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
Tensor first_quantized =
FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
const int second_batch = 2;
const int second_height = 2;
const int second_width = 3;
const int second_depth = 1;
Tensor second_float(
DT_FLOAT, {second_batch, second_height, second_width, second_depth});
test::FillValues<float>(&second_float,
{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
Tensor second_quantized =
FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
const int expected_height = first_height + second_height;
Tensor expected_float(
DT_FLOAT, {first_batch, expected_height, first_width, first_depth});
test::FillValues<float>(&expected_float,
{1, 2, 3, 4, 5, 6, 13, 14, 15, 16, 17, 18,
7, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24});
AddInputFromArray<quint8>(first_quantized.shape(),
first_quantized.flat<quint8>());
AddInputFromArray<quint8>(second_quantized.shape(),
second_quantized.flat<quint8>());
AddInputFromArray<int32>(TensorShape({}), {1});
AddInputFromArray<float>(TensorShape({}), {first_min});
AddInputFromArray<float>(TensorShape({}), {second_min});
AddInputFromArray<float>(TensorShape({}), {first_max});
AddInputFromArray<float>(TensorShape({}), {second_max});
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
const Tensor& output_quantized = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
const float output_max = GetOutput(2)->flat<float>()(0);
Tensor output_float =
QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
// Using the same error tolerance as in Eigen QuantizedConcat test
test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
}
} // namespace tensorflow
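Two hedged notes on the test above. First, the Attr("_kernel", "QuantizedMklOp") setting is presumably what matches the .Label(mkl_op_registry::kMklQuantizedOpLabel) on the kernel registration in mkl_concat_op.cc, so the test exercises the MKL kernel rather than the Eigen fallback. Second, the error tolerances track the quantization step size; using the usual 8-bit affine mapping step = (max - min) / 255 (an assumption, not stated in the diff):

  range [0, 255]   -> step = 1.0;   tolerance 0.2 is well under one step
  range [-10, 150] -> step ~= 0.63; tolerance 1.0 allows roughly one step of
                      round-trip quantization error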

View File

@@ -17,13 +17,13 @@ limitations under the License.
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib_cpu.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
@@ -135,8 +135,8 @@ class QuantizedConcatOp : public OpKernel {
context, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
errors::InvalidArgument(
"ConcatOp : Ranks of all input tensors should match: shape[0] = ",
input_shape.DebugString(), " vs. shape[", i,
"] = ", in.shape().DebugString()));
input_shape.DebugString(), " vs. shape[", i, "] = ",
in.shape().DebugString()));
for (int j = 0; j < input_dims; ++j) {
if (j == concat_dim) {
continue;
@@ -145,8 +145,8 @@ class QuantizedConcatOp : public OpKernel {
context, in.dim_size(j) == input_shape.dim_size(j),
errors::InvalidArgument(
"ConcatOp : Dimensions of inputs should match: shape[0] = ",
input_shape.DebugString(), " vs. shape[", i,
"] = ", in.shape().DebugString()));
input_shape.DebugString(), " vs. shape[", i, "] = ",
in.shape().DebugString()));
}
if (in.NumElements() > 0) {
int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
@@ -184,9 +184,8 @@ class QuantizedConcatOp : public OpKernel {
const int input_dims = values[0].dims();
const TensorShape& input_shape = values[0].shape();
OP_REQUIRES(
context,
(0 <= concat_dim && concat_dim < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
context, (0 <= concat_dim && concat_dim < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range [", 0,
", ", input_dims, "), but got ", concat_dim));
@@ -246,4 +245,16 @@ REGISTER_QUANTIZED_CONCAT(qint32);
#undef REGISTER_QUANTIZED_CONCAT
#ifdef INTEL_MKL
#define REGISTER_QUANTIZED_CONCATV2(type) \
REGISTER_KERNEL_BUILDER(Name("QuantizedConcatV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("axis"), \
QuantizedConcatOp<type>)
REGISTER_QUANTIZED_CONCATV2(quint8);
REGISTER_QUANTIZED_CONCATV2(qint32);
#endif
} // namespace tensorflow

View File

@@ -0,0 +1,92 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
// This file contains the registration of MKL-DNN array ops.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/strided_slice_op.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using shape_inference::UnchangedShape;
// QuantizedConcatV2 is added here so that the graph rewrite pass can
// replace it with _MklQuantizedConcatV2.
REGISTER_OP("QuantizedConcatV2")
.Input("values: N * T")
.Input("axis: Tidx")
.Input("input_mins: N * float32")
.Input("input_maxes: N * float32")
.Output("output: T")
.Output("output_min: float")
.Output("output_max: float")
.Attr("N: int >= 2")
.Attr("T: type")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
const int n = (c->num_inputs() - 1) / 3;
TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
ShapeHandle unused;
for (int i = n + 1; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
}
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
return Status::OK();
});
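Worked through for N = 2 (as in the test above): num_inputs = N values + 1 axis + N mins + N maxes = 3N + 1 = 7, so n = (7 - 1) / 3 = 2 recovers N. The loop then starts at i = n + 1 = 3, skipping the axis at index n and requiring each of the four min/max inputs (indices 3..6) to be a rank-0 scalar.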
REGISTER_OP("_MklQuantizedConcatV2")
.Input("values: N * T")
.Input("axis: Tidx")
.Input("input_mins: N * float32")
.Input("input_maxes: N * float32")
.Input("mkl_values: N * uint8")
.Input("mkl_axis: uint8")
.Input("mkl_input_mins: N * uint8")
.Input("mkl_input_maxes: N * uint8")
.Output("output: T")
.Output("output_min: float")
.Output("output_max: float")
.Output("mkl_output: uint8")
.Output("mkl_output_min: uint8")
.Output("mkl_output_max: uint8")
.Attr("N: int >= 2")
.Attr("T: type")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
const int n = (c->num_inputs() / 2 - 1) / 3;
TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
ShapeHandle unused;
for (int i = n + 1; i < c->num_inputs() / 2; ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
}
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
return Status::OK();
});
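The MKL variant doubles the input list, pairing every logical input with a uint8 metadata tensor in the second half, so the same bookkeeping applies to the first half only: with N = 2, num_inputs = 2 * (3N + 1) = 14, n = (14 / 2 - 1) / 3 = 2, and the scalar-rank check covers indices 3..6 (the mins and maxes) while leaving the metadata inputs unchecked.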
}
#endif