Merge pull request #25533 from Intel-tensorflow:quantized_concat_part_1_new
PiperOrigin-RevId: 232718138
This commit is contained in:
commit
86950c2c44
@ -1144,6 +1144,13 @@ tf_gen_op_libs(
|
||||
deps = [":protos_all_cc"],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"mkl_array_ops",
|
||||
],
|
||||
deps = [":protos_all_cc"],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"audio_ops",
|
||||
@ -1283,7 +1290,10 @@ cc_library(
|
||||
":training_ops_op_lib",
|
||||
":user_ops_op_lib",
|
||||
":word2vec_ops",
|
||||
] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(),
|
||||
] + if_mkl([
|
||||
":mkl_array_ops_op_lib",
|
||||
":mkl_nn_ops_op_lib",
|
||||
]) + tf_additional_cloud_op_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -4527,7 +4537,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
] + if_mkl([":mkl_array_ops_op_lib"]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -24,9 +24,9 @@ const std::unordered_set<std::string>* GetExcludedOps() {
|
||||
"GcsConfigureBlockCache", "GcsConfigureCredentials",
|
||||
#ifdef INTEL_MKL
|
||||
// QuantizedFusedOps for Intel CPU
|
||||
"QuantizedConv2DAndRequantize", "QuantizedConv2DWithBias",
|
||||
"QuantizedConv2DWithBiasAndRequantize", "QuantizedConv2DAndRelu",
|
||||
"QuantizedConv2DAndReluAndRequantize",
|
||||
"QuantizedConcatV2", "QuantizedConv2DAndRequantize",
|
||||
"QuantizedConv2DWithBias", "QuantizedConv2DWithBiasAndRequantize",
|
||||
"QuantizedConv2DAndRelu", "QuantizedConv2DAndReluAndRequantize",
|
||||
"QuantizedConv2DWithBiasAndRelu",
|
||||
"QuantizedConv2DWithBiasAndReluAndRequantize",
|
||||
"QuantizedConv2DWithBiasSumAndRelu",
|
||||
|
@ -1303,6 +1303,12 @@ Status ConcatV2Shape(InferenceContext* c) {
|
||||
c->num_inputs() - 1 /* dim_index */);
|
||||
}
|
||||
|
||||
// Shape function helper for quantized ConcatV2-style ops.
//
// The op's inputs are expected to start with `num_inputs_to_concat` value
// tensors, immediately followed by the concat axis tensor. Any inputs after
// the axis (e.g. per-input min/max ranges) are not examined here.
//
//   c                    - inference context of the op being shaped.
//   num_inputs_to_concat - number of leading value inputs to concatenate.
//
// Returns whatever status ConcatShapeHelper reports.
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
  const int start_value_index = 0;
  // The axis input sits right after the last value input, so the same index
  // serves both as the end of the value range and as the axis position.
  const int end_value_index = num_inputs_to_concat;
  const int dim_index = num_inputs_to_concat;
  return ConcatShapeHelper(c, start_value_index, end_value_index, dim_index);
}
|
||||
|
||||
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
ShapeHandle shape_x,
|
||||
ShapeHandle shape_y,
|
||||
|
@ -279,6 +279,8 @@ Status ConcatShape(shape_inference::InferenceContext* c,
|
||||
// Shape function for concat operations.
|
||||
Status ConcatV2Shape(shape_inference::InferenceContext* c);
|
||||
|
||||
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
|
||||
|
||||
// Shape function for binary operators that broadcast their inputs
|
||||
// and with output to output_index.
|
||||
// Note: out cannot be NULL.
|
||||
|
@ -6554,6 +6554,30 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test_mkl(
|
||||
name = "mkl_quantized_concat_op_test",
|
||||
size = "small",
|
||||
srcs = ["mkl_quantized_concat_op_test.cc"],
|
||||
deps = [
|
||||
":mkl_concat_op",
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
":quantization_utils",
|
||||
":quantized_ops",
|
||||
"//tensorflow/core:array_ops_op_lib",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:math_ops_op_lib",
|
||||
"//tensorflow/core:mkl_array_ops_op_lib",
|
||||
"//tensorflow/core:nn_ops_op_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "quantized_batch_norm_op_test",
|
||||
size = "small",
|
||||
|
@ -25,12 +25,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/concat_lib.h"
|
||||
#include "tensorflow/core/kernels/concat_lib_cpu.h"
|
||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
using mkldnn::concat;
|
||||
using mkldnn::stream;
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
@ -226,8 +228,50 @@ class MklConcatOp : public OpKernel {
|
||||
// format and avoid calling eigen version.
|
||||
if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;
|
||||
|
||||
OpInputList input_mins, input_maxes;
|
||||
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
|
||||
// MKL-DNN concat does not support input tensors that have different
|
||||
// ranges. Check if the ranges of the all input tensors are the same.
|
||||
// If not, forward it to Eigen implementation.
|
||||
|
||||
OP_REQUIRES_OK(context, context->input_list("input_mins", &input_mins));
|
||||
OP_REQUIRES(context, (input_mins.size() == N),
|
||||
errors::InvalidArgument(
|
||||
"QuantizedConcatOp : Expected mins input list length ",
|
||||
input_mins.size(), " to equal values length ", N));
|
||||
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input_list("input_maxes", &input_maxes));
|
||||
OP_REQUIRES(context, (input_maxes.size() == N),
|
||||
errors::InvalidArgument(
|
||||
"QuantizedConcatOp : Expected maxes input list length ",
|
||||
input_maxes.size(), " to equal values length ", N));
|
||||
float input_min = input_mins[0].flat<float>()(0);
|
||||
float input_max = input_maxes[0].flat<float>()(0);
|
||||
const float eps = 1.0e-6;
|
||||
for (int i = 1; i < N; ++i) {
|
||||
float min = input_mins[i].flat<float>()(0);
|
||||
float max = input_maxes[i].flat<float>()(0);
|
||||
|
||||
if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
|
||||
invoke_eigen = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Call Eigen library
|
||||
if (invoke_eigen) {
|
||||
// MKL-DNN quantized concat does not support input tensors with
|
||||
// different ranges.
|
||||
// TODO (mabuzain): Add quantized version of CallEigen() to support
|
||||
// this case.
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
(!std::is_same<T, qint8>::value && !std::is_same<T, quint8>::value),
|
||||
errors::Unimplemented("MKL DNN quantized concat does not "
|
||||
"support input tensors that have "
|
||||
"different ranges"));
|
||||
CallEigenVersion(context, input_tensors, mkl_input_shapes);
|
||||
return;
|
||||
}
|
||||
@ -374,6 +418,23 @@ class MklConcatOp : public OpKernel {
|
||||
std::vector<primitive> net;
|
||||
net.push_back(concat_op);
|
||||
stream(stream::kind::eager).submit(net).wait();
|
||||
|
||||
// For quantized concat, min and max outputs are also computed.
|
||||
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
|
||||
Tensor* output_min = nullptr;
|
||||
Tensor* output_max = nullptr;
|
||||
MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
|
||||
output_min_mkl_shape.SetMklTensor(false);
|
||||
output_max_mkl_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(context, 1, &output_min, {},
|
||||
output_min_mkl_shape);
|
||||
AllocateOutputSetMklShape(context, 2, &output_max, {},
|
||||
output_max_mkl_shape);
|
||||
// All input tensors should have the same range, just use the
|
||||
// first one
|
||||
output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
|
||||
output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
|
||||
}
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
", message: " + string(e.message) + ", in file " +
|
||||
@ -490,6 +551,20 @@ class MklConcatOp : public OpKernel {
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU);
|
||||
|
||||
// Quantized MKL concat kernels. Both registrations share the same op name and
// differ only in the element type bound to "T" (quint8 and qint8). The axis
// input is kept in host memory, and the kernels are registered under the
// quantized-MKL kernel label.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T")
                            .HostMemory("axis")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("T")
                            .HostMemory("axis")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>);
|
||||
|
||||
#undef REGISTER_CONCAT_MKL
|
||||
} // namespace tensorflow
|
||||
|
||||
|
234
tensorflow/core/kernels/mkl_quantized_concat_op_test.cc
Normal file
234
tensorflow/core/kernels/mkl_quantized_concat_op_test.cc
Normal file
@ -0,0 +1,234 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if defined(INTEL_MKL) && defined(ENABLE_MKL)
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using test::graph::Constant;
|
||||
|
||||
static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
static const TensorShape dummy_shape({8});
|
||||
|
||||
// Helper class for converting MKL tensors to TF tensors and comparing to
|
||||
// expected values
|
||||
|
||||
class ConvMklToTF : public OpsTestBase {
|
||||
public:
|
||||
template <typename T>
|
||||
void ConvertMKL2TF(DataType dtype, const Tensor& first, const Tensor& second,
|
||||
Tensor& output) {
|
||||
// Create an MKL to TF conversion node and execute it
|
||||
TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
|
||||
.Input(FakeInput(dtype)) // Input
|
||||
.Input(FakeInput(DT_UINT8)) // MKL second tensor
|
||||
.Attr("T", dtype)
|
||||
.Attr("_kernel", "MklOp")
|
||||
.Finalize(node_def()));
|
||||
TF_EXPECT_OK(InitOp());
|
||||
AddInputFromArray<T>(first.shape(), first.flat<T>());
|
||||
AddInputFromArray<uint8>(second.shape(), second.flat<uint8>());
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
output = *GetOutput(0);
|
||||
}
|
||||
void TestBody(){};
|
||||
};
|
||||
|
||||
// Fixture for the _MklQuantizedConcatV2 kernel tests. Each helper builds the
// op with two quint8 inputs quantized with the given ranges, runs it, and
// compares the dequantized output against the expected concatenation.
class QuantizedConcatTest : public OpsTestBase {
 protected:
  QuantizedConcatTest() = default;

  // Concatenates the two inputs along axis 0.
  void TestSmall8Bit(float first_min, float first_max, float second_min,
                     float second_max);

  // Concatenates the two inputs along axis 1.
  void TestSecondDim8Bit(float first_min, float first_max, float second_min,
                         float second_max);
};
|
||||
|
||||
TEST_F(QuantizedConcatTest, Small8BitSameRange) {
  // Both inputs share one quantization range, so the concatenation needs no
  // requantization of either input.
  const float range_min = 0.0f;
  const float range_max = 255.0f;
  TestSmall8Bit(range_min, range_max, range_min, range_max);
}
|
||||
|
||||
void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max,
|
||||
float second_min, float second_max) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
|
||||
.Input(FakeInput(2, DT_QUINT8))
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Input(FakeInput(2, DT_FLOAT))
|
||||
.Input(FakeInput(2, DT_FLOAT))
|
||||
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
|
||||
.Input(FakeInput(DT_UINT8)) // MKL second tensor
|
||||
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
|
||||
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
|
||||
.Attr("N", 2)
|
||||
.Attr("T", DataTypeToEnum<quint8>::v())
|
||||
.Attr("Tidx", DT_INT32)
|
||||
.Attr("_kernel", "QuantizedMklOp")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
const int first_batch = 2;
|
||||
const int first_height = 2;
|
||||
const int first_width = 3;
|
||||
const int first_depth = 1;
|
||||
Tensor first_float(DT_FLOAT,
|
||||
{first_batch, first_height, first_width, first_depth});
|
||||
test::FillValues<float>(&first_float,
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
Tensor first_quantized =
|
||||
FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
|
||||
|
||||
const int second_batch = 2;
|
||||
const int second_height = 2;
|
||||
const int second_width = 3;
|
||||
const int second_depth = 1;
|
||||
Tensor second_float(
|
||||
DT_FLOAT, {second_batch, second_height, second_width, second_depth});
|
||||
test::FillValues<float>(&second_float,
|
||||
{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
Tensor second_quantized =
|
||||
FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
|
||||
|
||||
const int expected_batch = first_batch + second_batch;
|
||||
Tensor expected_float(
|
||||
DT_FLOAT, {expected_batch, first_height, first_width, first_depth});
|
||||
test::FillValues<float>(&expected_float,
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
|
||||
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
|
||||
AddInputFromArray<quint8>(first_quantized.shape(),
|
||||
first_quantized.flat<quint8>());
|
||||
AddInputFromArray<quint8>(second_quantized.shape(),
|
||||
second_quantized.flat<quint8>());
|
||||
AddInputFromArray<int32>(TensorShape({}), {0});
|
||||
AddInputFromArray<float>(TensorShape({}), {first_min});
|
||||
AddInputFromArray<float>(TensorShape({}), {second_min});
|
||||
AddInputFromArray<float>(TensorShape({}), {first_max});
|
||||
AddInputFromArray<float>(TensorShape({}), {second_max});
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
const Tensor& output_quantized = *GetOutput(0);
|
||||
const float output_min = GetOutput(1)->flat<float>()(0);
|
||||
const float output_max = GetOutput(2)->flat<float>()(0);
|
||||
Tensor output_float =
|
||||
QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
|
||||
test::ExpectTensorNear<float>(expected_float, output_float, 0.2);
|
||||
}
|
||||
|
||||
TEST_F(QuantizedConcatTest, SecondDim8BitSameRange) {
  // Both inputs share one quantization range; concatenation is along axis 1.
  const float range_min = -10.0f;
  const float range_max = 150.0f;
  TestSecondDim8Bit(range_min, range_max, range_min, range_max);
}
|
||||
|
||||
void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
|
||||
float second_min,
|
||||
float second_max) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
|
||||
.Input(FakeInput(2, DT_QUINT8))
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Input(FakeInput(2, DT_FLOAT))
|
||||
.Input(FakeInput(2, DT_FLOAT))
|
||||
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
|
||||
.Input(FakeInput(DT_UINT8)) // MKL second tensor
|
||||
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
|
||||
.Input(FakeInput(2, DT_UINT8)) // MKL second tensor
|
||||
.Attr("N", 2)
|
||||
.Attr("T", DataTypeToEnum<quint8>::v())
|
||||
.Attr("Tidx", DT_INT32)
|
||||
.Attr("_kernel", "QuantizedMklOp")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
const int first_batch = 2;
|
||||
const int first_height = 2;
|
||||
const int first_width = 3;
|
||||
const int first_depth = 1;
|
||||
Tensor first_float(DT_FLOAT,
|
||||
{first_batch, first_height, first_width, first_depth});
|
||||
test::FillValues<float>(&first_float,
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
Tensor first_quantized =
|
||||
FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
|
||||
|
||||
const int second_batch = 2;
|
||||
const int second_height = 2;
|
||||
const int second_width = 3;
|
||||
const int second_depth = 1;
|
||||
|
||||
Tensor second_float(
|
||||
DT_FLOAT, {second_batch, second_height, second_width, second_depth});
|
||||
test::FillValues<float>(&second_float,
|
||||
{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||
Tensor second_quantized =
|
||||
FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
|
||||
|
||||
const int expected_height = first_height + second_height;
|
||||
Tensor expected_float(
|
||||
DT_FLOAT, {first_batch, expected_height, first_width, first_depth});
|
||||
test::FillValues<float>(&expected_float,
|
||||
{1, 2, 3, 4, 5, 6, 13, 14, 15, 16, 17, 18,
|
||||
7, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24});
|
||||
|
||||
AddInputFromArray<quint8>(first_quantized.shape(),
|
||||
first_quantized.flat<quint8>());
|
||||
AddInputFromArray<quint8>(second_quantized.shape(),
|
||||
second_quantized.flat<quint8>());
|
||||
AddInputFromArray<int32>(TensorShape({}), {1});
|
||||
AddInputFromArray<float>(TensorShape({}), {first_min});
|
||||
AddInputFromArray<float>(TensorShape({}), {second_min});
|
||||
AddInputFromArray<float>(TensorShape({}), {first_max});
|
||||
AddInputFromArray<float>(TensorShape({}), {second_max});
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
const Tensor& output_quantized = *GetOutput(0);
|
||||
const float output_min = GetOutput(1)->flat<float>()(0);
|
||||
const float output_max = GetOutput(2)->flat<float>()(0);
|
||||
Tensor output_float =
|
||||
QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
|
||||
// Using the same error tolerance as in Eigen QuantizedConcat test
|
||||
test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // INTEL_MKL && ENABLE_MKL
|
@ -246,4 +246,16 @@ REGISTER_QUANTIZED_CONCAT(qint32);
|
||||
|
||||
#undef REGISTER_QUANTIZED_CONCAT
|
||||
|
||||
#ifdef INTEL_MKL
// In INTEL_MKL builds, also register QuantizedConcatOp as the CPU kernel for
// the "QuantizedConcatV2" op. The "axis" input is kept in host memory.
#define REGISTER_QUANTIZED_CONCATV2(type)                \
  REGISTER_KERNEL_BUILDER(Name("QuantizedConcatV2")      \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("axis"),       \
                          QuantizedConcatOp<type>)

REGISTER_QUANTIZED_CONCATV2(quint8);
REGISTER_QUANTIZED_CONCATV2(qint32);

// Keep the helper macro local to this registration block, matching how
// REGISTER_QUANTIZED_CONCAT is handled above.
#undef REGISTER_QUANTIZED_CONCATV2
#endif
|
||||
|
||||
} // namespace tensorflow
|
||||
|
92
tensorflow/core/ops/mkl_array_ops.cc
Normal file
92
tensorflow/core/ops/mkl_array_ops.cc
Normal file
@ -0,0 +1,92 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
// This file contains the registration of MKL-DNN array ops.
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/strided_slice_op.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::DimensionHandle;
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
using shape_inference::UnchangedShape;
|
||||
|
||||
// Adding QuantizedConcatV2 op to be able to replace it by
|
||||
// _MklQuantizedConcatV2 in the graph rewrite.
|
||||
REGISTER_OP("QuantizedConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Input("input_mins: N * float32")
    .Input("input_maxes: N * float32")
    .Output("output: T")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      // Inputs are N values, 1 axis, N mins and N maxes: 3 * N + 1 in total.
      const int total_inputs = c->num_inputs();
      const int num_values = (total_inputs - 1) / 3;
      TF_RETURN_IF_ERROR(
          shape_inference::QuantizedConcatV2Shape(c, num_values));

      // Every min/max input (everything after the axis) must be a scalar.
      ShapeHandle ignored;
      for (int input_idx = num_values + 1; input_idx < total_inputs;
           ++input_idx) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(input_idx), 0, &ignored));
      }

      // output_min and output_max are scalars.
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });
|
||||
|
||||
REGISTER_OP("_MklQuantizedConcatV2")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Input("input_mins: N * float32")
    .Input("input_maxes: N * float32")
    .Input("mkl_values: N * uint8")
    .Input("mkl_axis: uint8")
    .Input("mkl_input_mins: N * uint8")
    .Input("mkl_input_maxes: N * uint8")
    .Output("output: T")
    .Output("output_min: float")
    .Output("output_max: float")
    .Output("mkl_output: uint8")
    .Output("mkl_output_min: uint8")
    .Output("mkl_output_max: uint8")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      // The first half of the inputs are the data inputs (N values, 1 axis,
      // N mins, N maxes); the second half are their uint8 MKL counterparts.
      const int num_data_inputs = c->num_inputs() / 2;
      const int num_values = (num_data_inputs - 1) / 3;
      TF_RETURN_IF_ERROR(
          shape_inference::QuantizedConcatV2Shape(c, num_values));

      // Every min/max data input (after the axis) must be a scalar.
      ShapeHandle ignored;
      for (int input_idx = num_values + 1; input_idx < num_data_inputs;
           ++input_idx) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(input_idx), 0, &ignored));
      }

      // output_min and output_max are scalars.
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif
|
Loading…
x
Reference in New Issue
Block a user