Refactored Quantized Concat
parent b5d67b7c69
commit 5a30ce41e4
@@ -17,7 +17,6 @@ limitations under the License.
 #include <vector>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -30,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::concat;
 using mkldnn::stream;
@@ -47,6 +47,45 @@ enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
 // --------------------------------------------------------------------------
 // Eigen Concat Op
 // --------------------------------------------------------------------------
+namespace {
+template <typename T>
+struct RequantizeCopier {
+  RequantizeCopier(
+      const std::vector<std::pair<float, float>>* input_min_and_max,
+      float output_min, float output_max)
+      : output_min(output_min), output_max(output_max) {
+    DCHECK(input_min_and_max);
+    this->input_min_and_max = input_min_and_max;
+  }
+
+  inline void Copy(T* dst, const T* src, int input_index, size_t n) {
+    const float input_min = (*input_min_and_max)[input_index].first;
+    const float input_max = (*input_min_and_max)[input_index].second;
+    if (input_min == output_min && input_max == output_max) {
+      DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));
+      memcpy(dst, src, n * sizeof(T));
+    } else {
+      Eigen::array<Eigen::DenseIndex, 1> dims;
+      dims[0] = n;
+      typename TTypes<T, 1>::UnalignedConstTensor input_array(src, dims);
+      typename TTypes<T, 1>::UnalignedTensor output_array(dst, dims);
+
+      QuantizedToFloatStruct<T> q2f(input_min, input_max);
+      auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
+      FloatToQuantizedStruct<T> f2q(output_min, output_max);
+      // RequantizeCopier::Copy is called from within a shard of computation,
+      // so don't use the threadpool device here; simply assign with the
+      // default CPU device.
+      output_array = QUANTIZE_WITH_EIGEN(input_float, f2q, T);
+    }
+  }
+
+  float output_min;
+  float output_max;
+  const std::vector<std::pair<float, float>>* input_min_and_max;
+};
+}  // namespace
+
 template <typename Device, typename T, AxisArgumentName AxisArgName>
 class EigenConcatBaseOp : public OpKernel {
  public:
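Note: RequantizeCopier rescales each input's values from that input's own quantization range into the shared output range. A minimal scalar sketch of the same mapping, simplified to quint8 with an affine [min, max] <-> [0, 255] encoding (RequantizeOne is a hypothetical illustration, not part of this commit):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch only: maps a quantized code point q from [input_min, input_max]
// into [output_min, output_max], clamping to the representable interval.
uint8_t RequantizeOne(uint8_t q, float input_min, float input_max,
                      float output_min, float output_max) {
  // Dequantize: recover the real value that code point q represents.
  const float real =
      input_min + (input_max - input_min) * (static_cast<float>(q) / 255.0f);
  // Requantize into the output range.
  const float scaled =
      (real - output_min) / (output_max - output_min) * 255.0f;
  return static_cast<uint8_t>(
      std::lround(std::min(255.0f, std::max(0.0f, scaled))));
}

When the input and output ranges coincide, this mapping is the identity, which is why Copy can fall back to a plain memcpy in that case.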
@@ -55,12 +94,45 @@ class EigenConcatBaseOp : public OpKernel {
 
   explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
 
+  void CalculateInputAndOutputRange(
+      const OpInputList& input_mins, const OpInputList& input_maxes,
+      const size_t N,
+      std::vector<std::pair<float, float>>* input_mins_and_maxes,
+      float* output_min, float* output_max) {
+    input_mins_and_maxes->reserve(N);
+    float overall_min = std::numeric_limits<float>::max();
+    float overall_max = std::numeric_limits<float>::lowest();
+    for (int i = 0; i < N; ++i) {
+      const float input_min = input_mins[i].flat<float>()(0);
+      const float input_max = input_maxes[i].flat<float>()(0);
+      input_mins_and_maxes->emplace_back(input_min, input_max);
+      overall_min = std::min(overall_min, input_min);
+      overall_max = std::max(overall_max, input_max);
+    }
+    if (std::is_signed<T>::value) {
+      // For signed data, we want a symmetrical distribution including zero
+      // for the output, so pick a range that meets that need.
+      const float largest_value =
+          std::max(std::abs(overall_min), std::abs(overall_max));
+      *output_min = -largest_value;
+      *output_max = largest_value;
+    } else {
+      // For MKL quantization, only scaled mode is supported, so the range is
+      // [0, m] for unsigned data.
+      overall_min = std::min(0.0f, overall_min);
+      *output_min = overall_min;
+      *output_max = overall_max;
+    }
+  }
+
+  // Although Compute is overloaded below to accept extra parameters, an
+  // empty Compute is still needed here because Compute is a pure virtual
+  // function of OpKernel.
+  void Compute(OpKernelContext* c) {}
+
   void Compute(OpKernelContext* c, const std::vector<Tensor>& values,
-               const TensorShapeList& input_shapes) {
+               const TensorShapeList& input_shapes,
+               const OpInputList& input_mins, const OpInputList& input_maxes,
+               bool quantized_input) {
    const Tensor* concat_dim_tensor;
    const char* axis_attribute_name =
        AxisArgName == NAME_IS_AXIS
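A quick worked example of the range selection above: for signed inputs with ranges [-1.0, 2.0] and [-3.0, 1.0], overall_min = -3.0 and overall_max = 2.0, so largest_value = 3.0 and the output range becomes the symmetric [-3.0, 3.0]. For unsigned inputs with ranges [0.0, 255.0] and [0.0, 25.0], the output range is simply [0.0, 255.0].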
@@ -80,13 +152,21 @@ class EigenConcatBaseOp : public OpKernel {
    const TensorShape& input_shape = input_shapes[0];
 
    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
-   OP_REQUIRES(c,
-               (0 <= axis && axis < input_dims) ||
-                   (allow_legacy_scalars() && concat_dim == 0),
-               errors::InvalidArgument(
-                   "ConcatOp : Expected concatenating dimensions in the range "
-                   "[",
-                   -input_dims, ", ", input_dims, "), but got ", concat_dim));
+   OP_REQUIRES(
+       c, (0 <= axis && axis < input_dims) ||
+              (allow_legacy_scalars() && concat_dim == 0),
+       errors::InvalidArgument(
+           "ConcatOp : Expected concatenating dimensions in the range [",
+           -input_dims, ", ", input_dims, "), but got ", concat_dim));
+
+   float output_min = std::numeric_limits<float>::max();
+   float output_max = std::numeric_limits<float>::lowest();
+   std::vector<std::pair<float, float>> input_mins_and_maxes;
+   if (quantized_input) {
+     CalculateInputAndOutputRange(input_mins, input_maxes, N,
+                                  &input_mins_and_maxes, &output_min,
+                                  &output_max);
+   }
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
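As a concrete instance of that reduction: concatenating two [2, 3, 4] tensors along axis 1 flattens each input to a [2, 12] matrix (the leading dimension is the product of the dims before the axis), concatenates them into a [2, 24] matrix, and the result is then viewed as the final [2, 6, 4] output.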
@@ -104,13 +184,12 @@ class EigenConcatBaseOp : public OpKernel {
      const auto in = values[i];
      const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
      OP_REQUIRES(
-         c,
-         (input_shapes[i].dims() == input_dims) ||
-             (input_is_scalar && in_is_scalar),
+         c, (input_shapes[i].dims() == input_dims) ||
+                (input_is_scalar && in_is_scalar),
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
-             input_shape.DebugString(), " vs. shape[", i,
-             "] = ", input_shapes[i].DebugString()));
+             input_shape.DebugString(), " vs. shape[", i, "] = ",
+             input_shapes[i].DebugString()));
      if (in.NumElements() > 0) {
        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
@@ -131,7 +210,24 @@ class EigenConcatBaseOp : public OpKernel {
    if (output->NumElements() > 0) {
      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
-     ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+     if (!quantized_input) {
+       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+     } else {
+       ConcatCPUImpl<T>(
+           c->device(), inputs_flat, sizeof(T) /* cost_per_unit */,
+           RequantizeCopier<T>(&input_mins_and_maxes, output_min, output_max),
+           &output_flat);
+     }
    }
+
+   if (quantized_input) {
+     Tensor* output_min_tensor = nullptr;
+     OP_REQUIRES_OK(c, c->allocate_output(1, {}, &output_min_tensor));
+     output_min_tensor->flat<float>()(0) = output_min;
+
+     Tensor* output_max_tensor = nullptr;
+     OP_REQUIRES_OK(c, c->allocate_output(2, {}, &output_max_tensor));
+     output_max_tensor->flat<float>()(0) = output_max;
+   }
  }
 };
@@ -229,7 +325,9 @@ class MklConcatOp : public OpKernel {
     if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;
 
     OpInputList input_mins, input_maxes;
-    if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
+    bool quantized_input =
+        std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
+    if (quantized_input) {
       // MKL-DNN concat does not support input tensors that have different
       // ranges. Check if the ranges of all the input tensors are the same.
       // If not, forward it to the Eigen implementation.
@@ -262,17 +360,8 @@ class MklConcatOp : public OpKernel {
 
     // Call Eigen library
     if (invoke_eigen) {
-      // MKL-DNN quantized concat does not support input tensors with
-      // different ranges.
-      // TODO (mabuzain): Add quantized version of CallEigen() to support
-      // this case.
-      OP_REQUIRES(
-          context,
-          (!std::is_same<T, qint8>::value && !std::is_same<T, quint8>::value),
-          errors::Unimplemented("MKL DNN quantized concat does not "
-                                "support input tensors that have "
-                                "different ranges"));
-      CallEigenVersion(context, input_tensors, mkl_input_shapes);
+      CallEigenVersion(context, input_tensors, input_mins, input_maxes,
+                       mkl_input_shapes, quantized_input);
       return;
     }
 
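For example, two quint8 inputs with ranges [0.0, 255.0] and [0.0, 25.0] fail the equal-range check and set invoke_eigen; with this change the op forwards them to the Eigen path, which requantizes both inputs into a common output range, instead of failing with the Unimplemented error removed above.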
@@ -421,7 +510,7 @@ class MklConcatOp : public OpKernel {
         }
         AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                   dnn_shape_dst);
-        DCHECK(dst_tensor == nullptr) << "Output tensor pointer is NULL";
+        DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
 
         if (dnn_shape_dst.IsMklTensor()) dst_md = dnn_shape_dst.GetMklLayout();
         dst.SetUsrMem(dst_md, dst_tensor);
@@ -432,7 +521,7 @@ class MklConcatOp : public OpKernel {
       stream(stream::kind::eager).submit(net).wait();
 
       // For quantized concat, min and max outputs are also computed.
-      if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
+      if (quantized_input) {
         Tensor* output_min = nullptr;
         Tensor* output_max = nullptr;
         MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
@@ -456,13 +545,13 @@ class MklConcatOp : public OpKernel {
 
         AllocateOutputSetMklShape(context, 0, &dst_tensor, tf_shape_dst,
                                   dnn_shape_dst);
-        DCHECK(dst_tensor == nullptr) << "Output tensor pointer is NULL";
+        DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
       }
 
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           context,
           errors::Aborted("Operation received an exception:", error_msg));
@@ -470,42 +559,45 @@ class MklConcatOp : public OpKernel {
   }
 
   void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
-                        const MklDnnShapeList& mkl_input_shapes) {
-    CHECK_EQ(values.size(), mkl_input_shapes.size());
-
-    std::vector<Tensor> converted_values;
+                        const OpInputList& input_mins,
+                        const OpInputList& input_maxes,
+                        const MklDnnShapeList& mkl_input_shapes,
+                        bool quantized_input) {
+    size_t num_mkl_input_shapes = mkl_input_shapes.size();
+    CHECK_EQ(values.size(), num_mkl_input_shapes);
+    std::vector<Tensor> converted_values(num_mkl_input_shapes);
     TensorShapeList tf_input_shapes;
-    for (int i = 0; i < mkl_input_shapes.size(); i++) {
+    for (size_t i = 0; i < num_mkl_input_shapes; ++i) {
      if (mkl_input_shapes[i].IsMklTensor()) {
        // Do conversion from MKL to TF.
        Tensor tmp_tensor =
            ConvertMklToTF<T>(context, values[i], mkl_input_shapes[i]);
-       converted_values.push_back(tmp_tensor);
+       converted_values[i] = tmp_tensor;
        tf_input_shapes.push_back(mkl_input_shapes[i].GetTfShape());
      } else {
        // No conversion needed since it is already a TF tensor.
-       converted_values.push_back(values[i]);
+       converted_values[i] = values[i];
        tf_input_shapes.push_back(values[i].shape());
      }
    }
 
    // Call Eigen concat.
-   eigen_concat_op_.Compute(context, converted_values, tf_input_shapes);
+   eigen_concat_op_.Compute(context, converted_values, tf_input_shapes,
+                            input_mins, input_maxes, quantized_input);
 
-   // Set output Mkl tensor for this op.
-   MklDnnShape dnn_shape_output;
-   dnn_shape_output.SetMklTensor(false);
-   dnn_shape_output.SetDimensions(4);
-   Tensor* output_tensor = nullptr;
-   TensorShape tf_shape_output;
-   tf_shape_output.AddDim(dnn_shape_output.GetSerializeBufferSize());
-   OP_REQUIRES_OK(context,
-                  context->allocate_output(
-                      GetTensorMetaDataIndex(0, context->num_outputs()),
-                      tf_shape_output, &output_tensor));
-   dnn_shape_output.SerializeMklDnnShape(
-       output_tensor->flat<uint8>().data(),
-       output_tensor->flat<uint8>().size() * sizeof(uint8));
+   // Get the number of dims from the first input, since all input tensors
+   // should have the same rank.
+   size_t dims = values[0].shape().dims();
+   MklDnnShape output_data_mkl_shape;
+   output_data_mkl_shape.SetMklTensor(false);
+   output_data_mkl_shape.SetDimensions(dims);
+   AllocateOutputSetMklShape(context, 0, output_data_mkl_shape);
+   if (quantized_input) {
+     MklDnnShape output_min_max_mkl_shape;
+     output_min_max_mkl_shape.SetMklTensor(false);
+     AllocateOutputSetMklShape(context, 1, output_min_max_mkl_shape);
+     AllocateOutputSetMklShape(context, 2, output_min_max_mkl_shape);
+   }
  }
 
  // This method finds the most common format across all MKL inputs
@@ -86,6 +86,10 @@ TEST_F(QuantizedConcatTest, Small8BitSameRange) {
   TestSmall8Bit(0.0f, 255.0f, 0.0f, 255.0f);
 }
 
+TEST_F(QuantizedConcatTest, Small8BitDifferentRange) {
+  TestSmall8Bit(0.0f, 255.0f, 0.0f, 25.0f);
+}
+
 void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max,
                                         float second_min, float second_max) {
   TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")