diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 4cf37a67585..e22eca36962 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1144,6 +1144,13 @@ tf_gen_op_libs(
     deps = [":protos_all_cc"],
 )
 
+tf_gen_op_libs(
+    op_lib_names = [
+        "mkl_array_ops",
+    ],
+    deps = [":protos_all_cc"],
+)
+
 tf_gen_op_libs(
     op_lib_names = [
         "audio_ops",
@@ -1283,7 +1290,10 @@ cc_library(
         ":training_ops_op_lib",
         ":user_ops_op_lib",
         ":word2vec_ops",
-    ] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(),
+    ] + if_mkl([
+        ":mkl_array_ops_op_lib",
+        ":mkl_nn_ops_op_lib",
+    ]) + tf_additional_cloud_op_deps(),
     alwayslink = 1,
 )
 
@@ -4527,7 +4537,7 @@ tf_cc_test(
         "//tensorflow/cc:scope",
         "//tensorflow/core/kernels:cwise_op",
         "//third_party/eigen3",
-    ],
+    ] + if_mkl([":mkl_array_ops_op_lib"]),
 )
 
 tf_cc_test(
diff --git a/tensorflow/core/api_def/excluded_ops.cc b/tensorflow/core/api_def/excluded_ops.cc
index 02026e94abc..65d2102ac80 100644
--- a/tensorflow/core/api_def/excluded_ops.cc
+++ b/tensorflow/core/api_def/excluded_ops.cc
@@ -24,9 +24,9 @@ const std::unordered_set<string>* GetExcludedOps() {
       "GcsConfigureBlockCache", "GcsConfigureCredentials",
 #ifdef INTEL_MKL
       // QuantizedFusedOps for Intel CPU
-      "QuantizedConv2DAndRequantize", "QuantizedConv2DWithBias",
-      "QuantizedConv2DWithBiasAndRequantize", "QuantizedConv2DAndRelu",
-      "QuantizedConv2DAndReluAndRequantize",
+      "QuantizedConcatV2", "QuantizedConv2DAndRequantize",
+      "QuantizedConv2DWithBias", "QuantizedConv2DWithBiasAndRequantize",
+      "QuantizedConv2DAndRelu", "QuantizedConv2DAndReluAndRequantize",
       "QuantizedConv2DWithBiasAndRelu",
       "QuantizedConv2DWithBiasAndReluAndRequantize",
       "QuantizedConv2DWithBiasSumAndRelu",
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 83bc95065e0..5c974a76aca 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1303,6 +1303,12 @@ Status ConcatV2Shape(InferenceContext* c) {
                            c->num_inputs() - 1 /* dim_index */);
 }
 
+Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
+  return ConcatShapeHelper(c, 0 /* start_value_index */,
+                           num_inputs_to_concat /* end_value_index */,
+                           num_inputs_to_concat /* dim_index */);
+}
+
 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                             ShapeHandle shape_x,
                                             ShapeHandle shape_y,
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 14b9688bdc5..d421844ee60 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -279,6 +279,8 @@ Status ConcatShape(shape_inference::InferenceContext* c,
 // Shape function for concat operations.
 Status ConcatV2Shape(shape_inference::InferenceContext* c);
 
+Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
+
 // Shape function for binary operators that broadcast their inputs
 // and with output to output_index.
 // Note: out cannot be NULL.
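A note on the index arithmetic behind `QuantizedConcatV2Shape` above: `QuantizedConcatV2` packs its inputs as N value tensors, then the scalar `axis`, then N `input_mins` and N `input_maxes`, so the values occupy input indices 0..N-1 and the axis sits at index N. That is why `num_inputs_to_concat` is passed as both `end_value_index` and `dim_index`. A minimal standalone sketch of that layout (the helper `NumInputsToConcat` is hypothetical, not part of this patch):

```cpp
#include <cassert>

// Mirrors the shape-function arithmetic: QuantizedConcatV2 has N value
// tensors, one axis scalar, N mins, and N maxes, so num_inputs = 3 * N + 1.
int NumInputsToConcat(int num_inputs) { return (num_inputs - 1) / 3; }

int main() {
  // For N = 2 the inputs are [v0, v1, axis, min0, min1, max0, max1].
  assert(NumInputsToConcat(7) == 2);
  // ConcatShapeHelper(c, /*start_value_index=*/0, /*end_value_index=*/N,
  //                   /*dim_index=*/N) then reads v0..v1 and the axis at N.
  return 0;
}
```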
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ae3bb94e9e1..be4f29bf2e4 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6554,6 +6554,30 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test_mkl(
+    name = "mkl_quantized_concat_op_test",
+    size = "small",
+    srcs = ["mkl_quantized_concat_op_test.cc"],
+    deps = [
+        ":mkl_concat_op",
+        ":ops_testutil",
+        ":ops_util",
+        ":quantization_utils",
+        ":quantized_ops",
+        "//tensorflow/core:array_ops_op_lib",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:math_ops_op_lib",
+        "//tensorflow/core:mkl_array_ops_op_lib",
+        "//tensorflow/core:nn_ops_op_lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_cc_test(
     name = "quantized_batch_norm_op_test",
     size = "small",
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 3a5c87485cc..b95bbca6b59 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -25,12 +25,14 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/concat_lib_cpu.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/mkl_util.h"
 
 using mkldnn::concat;
 using mkldnn::stream;
-#include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -226,8 +228,50 @@ class MklConcatOp : public OpKernel {
     // format and avoid calling eigen version.
     if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;
 
+    OpInputList input_mins, input_maxes;
+    if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
+      // MKL-DNN concat does not support input tensors that have different
+      // ranges. Check if the ranges of all the input tensors are the same.
+      // If not, forward it to the Eigen implementation.
+
+      OP_REQUIRES_OK(context, context->input_list("input_mins", &input_mins));
+      OP_REQUIRES(context, (input_mins.size() == N),
+                  errors::InvalidArgument(
+                      "QuantizedConcatOp : Expected mins input list length ",
+                      input_mins.size(), " to equal values length ", N));
+
+      OP_REQUIRES_OK(context,
+                     context->input_list("input_maxes", &input_maxes));
+      OP_REQUIRES(context, (input_maxes.size() == N),
+                  errors::InvalidArgument(
+                      "QuantizedConcatOp : Expected maxes input list length ",
+                      input_maxes.size(), " to equal values length ", N));
+      float input_min = input_mins[0].flat<float>()(0);
+      float input_max = input_maxes[0].flat<float>()(0);
+      const float eps = 1.0e-6;
+      for (int i = 1; i < N; ++i) {
+        float min = input_mins[i].flat<float>()(0);
+        float max = input_maxes[i].flat<float>()(0);
+
+        if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
+          invoke_eigen = true;
+          break;
+        }
+      }
+    }
+
     // Call Eigen library
     if (invoke_eigen) {
+      // MKL-DNN quantized concat does not support input tensors with
+      // different ranges.
+      // TODO(mabuzain): Add a quantized version of CallEigen() to support
+      // this case.
+      OP_REQUIRES(
+          context,
+          (!std::is_same<T, qint8>::value && !std::is_same<T, quint8>::value),
+          errors::Unimplemented("MKL DNN quantized concat does not "
+                                "support input tensors that have "
+                                "different ranges"));
       CallEigenVersion(context, input_tensors, mkl_input_shapes);
       return;
     }
@@ -374,6 +418,23 @@ class MklConcatOp : public OpKernel {
       std::vector<primitive> net;
       net.push_back(concat_op);
       stream(stream::kind::eager).submit(net).wait();
+
+      // For quantized concat, min and max outputs are also computed.
+      if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
+        Tensor* output_min = nullptr;
+        Tensor* output_max = nullptr;
+        MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
+        output_min_mkl_shape.SetMklTensor(false);
+        output_max_mkl_shape.SetMklTensor(false);
+        AllocateOutputSetMklShape(context, 1, &output_min, {},
+                                  output_min_mkl_shape);
+        AllocateOutputSetMklShape(context, 2, &output_max, {},
+                                  output_max_mkl_shape);
+        // All input tensors should have the same range; just use the
+        // first one.
+        output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
+        output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
+      }
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
@@ -490,6 +551,20 @@ class MklConcatOp : public OpKernel {
 
 TF_CALL_float(REGISTER_MKL_CPU);
 
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("T")
+                            .HostMemory("axis")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)
+
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("T")
+                            .HostMemory("axis")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)
+
 #undef REGISTER_CONCAT_MKL
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc b/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc
new file mode 100644
index 00000000000..fc68480bbe8
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc
@@ -0,0 +1,234 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
+
+#define EIGEN_USE_THREADS
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+using test::graph::Constant;
+
+static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
+static const TensorShape dummy_shape({8});
+
+// Helper class for converting MKL tensors to TF tensors and comparing to
+// expected values.
+class ConvMklToTF : public OpsTestBase {
+ public:
+  template <typename T>
+  void ConvertMKL2TF(DataType dtype, const Tensor& first, const Tensor& second,
+                     Tensor& output) {
+    // Create an MKL-to-TF conversion node and execute it.
+    TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
+                     .Input(FakeInput(dtype))     // Input
+                     .Input(FakeInput(DT_UINT8))  // MKL second tensor
+                     .Attr("T", dtype)
+                     .Attr("_kernel", "MklOp")
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+    AddInputFromArray<T>(first.shape(), first.flat<T>());
+    AddInputFromArray<uint8>(second.shape(), second.flat<uint8>());
+    TF_ASSERT_OK(RunOpKernel());
+
+    output = *GetOutput(0);
+  }
+  void TestBody() {}
+};
+
+class QuantizedConcatTest : public OpsTestBase {
+ protected:
+  QuantizedConcatTest() {}
+
+  void TestSmall8Bit(float first_min, float first_max, float second_min,
+                     float second_max);
+  void TestSecondDim8Bit(float first_min, float first_max, float second_min,
+                         float second_max);
+};
+
+TEST_F(QuantizedConcatTest, Small8BitSameRange) {
+  // Range for both inputs is the same, so the implementation can use memcpy.
+  TestSmall8Bit(0.0f, 255.0f, 0.0f, 255.0f);
+}
+
+void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max,
+                                        float second_min, float second_max) {
+  TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
+                   .Input(FakeInput(2, DT_QUINT8))
+                   .Input(FakeInput(DT_INT32))
+                   .Input(FakeInput(2, DT_FLOAT))
+                   .Input(FakeInput(2, DT_FLOAT))
+                   .Input(FakeInput(2, DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))     // MKL second tensor
+                   .Input(FakeInput(2, DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(2, DT_UINT8))  // MKL second tensor
+                   .Attr("N", 2)
+                   .Attr("T", DataTypeToEnum<quint8>::v())
+                   .Attr("Tidx", DT_INT32)
+                   .Attr("_kernel", "QuantizedMklOp")
+                   .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+  const int first_batch = 2;
+  const int first_height = 2;
+  const int first_width = 3;
+  const int first_depth = 1;
+  Tensor first_float(DT_FLOAT,
+                     {first_batch, first_height, first_width, first_depth});
+  test::FillValues<float>(&first_float,
+                          {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  Tensor first_quantized =
+      FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
+
+  const int second_batch = 2;
+  const int second_height = 2;
+  const int second_width = 3;
+  const int second_depth = 1;
+  Tensor second_float(
+      DT_FLOAT, {second_batch, second_height, second_width, second_depth});
+  test::FillValues<float>(&second_float,
+                          {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
+  Tensor second_quantized =
+      FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
+
+  const int expected_batch = first_batch + second_batch;
+  Tensor expected_float(
+      DT_FLOAT, {expected_batch, first_height, first_width, first_depth});
+  test::FillValues<float>(&expected_float,
+                          {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
+
+  AddInputFromArray<quint8>(first_quantized.shape(),
+                            first_quantized.flat<quint8>());
+  AddInputFromArray<quint8>(second_quantized.shape(),
+                            second_quantized.flat<quint8>());
+  AddInputFromArray<int32>(TensorShape({}), {0});
+  AddInputFromArray<float>(TensorShape({}), {first_min});
+  AddInputFromArray<float>(TensorShape({}), {second_min});
+  AddInputFromArray<float>(TensorShape({}), {first_max});
+  AddInputFromArray<float>(TensorShape({}), {second_max});
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  TF_ASSERT_OK(RunOpKernel());
+  const Tensor& output_quantized = *GetOutput(0);
+  const float output_min = GetOutput(1)->flat<float>()(0);
+  const float output_max = GetOutput(2)->flat<float>()(0);
+  Tensor output_float =
+      QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
+  test::ExpectTensorNear<float>(expected_float, output_float, 0.2);
+}
+
+TEST_F(QuantizedConcatTest, SecondDim8BitSameRange) {
+  TestSecondDim8Bit(-10.0f, 150.0f, -10.0f, 150.0f);
+}
+
+void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
+                                            float second_min,
+                                            float second_max) {
+  TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
+                   .Input(FakeInput(2, DT_QUINT8))
+                   .Input(FakeInput(DT_INT32))
+                   .Input(FakeInput(2, DT_FLOAT))
+                   .Input(FakeInput(2, DT_FLOAT))
+                   .Input(FakeInput(2, DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(DT_UINT8))     // MKL second tensor
+                   .Input(FakeInput(2, DT_UINT8))  // MKL second tensor
+                   .Input(FakeInput(2, DT_UINT8))  // MKL second tensor
+                   .Attr("N", 2)
+                   .Attr("T", DataTypeToEnum<quint8>::v())
+                   .Attr("Tidx", DT_INT32)
+                   .Attr("_kernel", "QuantizedMklOp")
+                   .Finalize(node_def()));
+
+  TF_ASSERT_OK(InitOp());
+  const int first_batch = 2;
+  const int first_height = 2;
+  const int first_width = 3;
+  const int first_depth = 1;
+  Tensor first_float(DT_FLOAT,
+                     {first_batch, first_height, first_width, first_depth});
+  test::FillValues<float>(&first_float,
+                          {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  Tensor first_quantized =
+      FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
+
+  const int second_batch = 2;
+  const int second_height = 2;
+  const int second_width = 3;
+  const int second_depth = 1;
+
+  Tensor second_float(
+      DT_FLOAT, {second_batch, second_height, second_width, second_depth});
+  test::FillValues<float>(&second_float,
+                          {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
+  Tensor second_quantized =
+      FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
+
+  const int expected_height = first_height + second_height;
+  Tensor expected_float(
+      DT_FLOAT, {first_batch, expected_height, first_width, first_depth});
+  test::FillValues<float>(&expected_float,
+                          {1,  2,  3,  4,  5,  6,  13, 14, 15, 16, 17, 18,
+                           7,  8,  9,  10, 11, 12, 19, 20, 21, 22, 23, 24});
+
+  AddInputFromArray<quint8>(first_quantized.shape(),
+                            first_quantized.flat<quint8>());
+  AddInputFromArray<quint8>(second_quantized.shape(),
+                            second_quantized.flat<quint8>());
+  AddInputFromArray<int32>(TensorShape({}), {1});
+  AddInputFromArray<float>(TensorShape({}), {first_min});
+  AddInputFromArray<float>(TensorShape({}), {second_min});
+  AddInputFromArray<float>(TensorShape({}), {first_max});
+  AddInputFromArray<float>(TensorShape({}), {second_max});
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  TF_ASSERT_OK(RunOpKernel());
+  const Tensor& output_quantized = *GetOutput(0);
+  const float output_min = GetOutput(1)->flat<float>()(0);
+  const float output_max = GetOutput(2)->flat<float>()(0);
+  Tensor output_float =
+      QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
+  // Using the same error tolerance as in the Eigen QuantizedConcat test.
+  test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
+}
+
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/kernels/quantized_concat_op.cc b/tensorflow/core/kernels/quantized_concat_op.cc
index b03ac8e87da..ff4e7be1622 100644
--- a/tensorflow/core/kernels/quantized_concat_op.cc
+++ b/tensorflow/core/kernels/quantized_concat_op.cc
@@ -246,4 +246,16 @@ REGISTER_QUANTIZED_CONCAT(qint32);
 
 #undef REGISTER_QUANTIZED_CONCAT
 
+#ifdef INTEL_MKL
+#define REGISTER_QUANTIZED_CONCATV2(type)                \
+  REGISTER_KERNEL_BUILDER(Name("QuantizedConcatV2")      \
+                              .Device(DEVICE_CPU)        \
+                              .TypeConstraint<type>("T") \
+                              .HostMemory("axis"),       \
+                          QuantizedConcatOp<type>)
+
+REGISTER_QUANTIZED_CONCATV2(quint8);
+REGISTER_QUANTIZED_CONCATV2(qint32);
+#endif
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/mkl_array_ops.cc b/tensorflow/core/ops/mkl_array_ops.cc
new file mode 100644
index 00000000000..e7ad3be6112
--- /dev/null
+++ b/tensorflow/core/ops/mkl_array_ops.cc
@@ -0,0 +1,92 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+// This file contains the registration of MKL-DNN array ops.
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/strided_slice_op.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+// QuantizedConcatV2 is registered here so that the graph-rewrite pass can
+// replace it with _MklQuantizedConcatV2.
+REGISTER_OP("QuantizedConcatV2")
+    .Input("values: N * T")
+    .Input("axis: Tidx")
+    .Input("input_mins: N * float32")
+    .Input("input_maxes: N * float32")
+    .Output("output: T")
+    .Output("output_min: float")
+    .Output("output_max: float")
+    .Attr("N: int >= 2")
+    .Attr("T: type")
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .SetShapeFn([](InferenceContext* c) {
+      const int n = (c->num_inputs() - 1) / 3;
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
+      ShapeHandle unused;
+      for (int i = n + 1; i < c->num_inputs(); ++i) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
+      }
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("_MklQuantizedConcatV2")
+    .Input("values: N * T")
+    .Input("axis: Tidx")
+    .Input("input_mins: N * float32")
+    .Input("input_maxes: N * float32")
+    .Input("mkl_values: N * uint8")
+    .Input("mkl_axis: uint8")
+    .Input("mkl_input_mins: N * uint8")
+    .Input("mkl_input_maxes: N * uint8")
+    .Output("output: T")
+    .Output("output_min: float")
+    .Output("output_max: float")
+    .Output("mkl_output: uint8")
+    .Output("mkl_output_min: uint8")
+    .Output("mkl_output_max: uint8")
+    .Attr("N: int >= 2")
+    .Attr("T: type")
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .SetShapeFn([](InferenceContext* c) {
+      const int n = (c->num_inputs() / 2 - 1) / 3;
+      TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
+      ShapeHandle unused;
+      for (int i = n + 1; i < c->num_inputs() / 2; ++i) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
+      }
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+      return Status::OK();
+    });
+
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
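Why the kernel insists on identical input ranges before taking the MKL-DNN path (and why the tests above only exercise the same-range case): a quantized byte is meaningful only relative to its (min, max) range, so input buffers can be concatenated byte-for-byte only when every input shares one range; with differing ranges each element would need requantization, which this kernel does not implement. A standalone sketch of the affine mapping, simplified from the `QuantizedToFloat` logic in quantization_utils.h (the real code also handles rounding and signed types):

```cpp
#include <cstdint>
#include <iostream>

// Dequantize a quint8-style value: byte 0 maps to min, byte 255 maps to max.
float Dequantize(uint8_t q, float min, float max) {
  return min + q * (max - min) / 255.0f;
}

int main() {
  const uint8_t q = 128;
  // Same byte, same range: same float, so a memcpy-style concat is exact.
  std::cout << Dequantize(q, 0.0f, 255.0f) << "\n";    // 128
  // Same byte, different range: different float, so a byte-for-byte concat
  // would silently change values; the op instead falls back to Eigen or
  // reports errors::Unimplemented.
  std::cout << Dequantize(q, -10.0f, 150.0f) << "\n";  // ~70.3
  return 0;
}
```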