Fix to tensor forwarding.

mdfaijul 2020-09-25 00:48:40 -07:00
parent 7893e4bcc1
commit f6490af3f1
3 changed files with 197 additions and 17 deletions

View File

@@ -24,8 +24,8 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
-#include "mkldnn.hpp"
 #include "absl/strings/str_join.h"
+#include "mkldnn.hpp"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -942,10 +942,11 @@ class MklConvOp : public OpKernel {
       GetMklShape(context, kInputIndex_Add, &add_mkl_shape);
       // Check if reorder is needed
-      if (add_mkl_shape == *output_mkl_shape) {
-        ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
-                                            kOutputIndex_Dst, add_mkl_shape);
-        *output_tensor = context->mutable_output(kOutputIndex_Dst);
+      if (add_mkl_shape == *output_mkl_shape &&
+          ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
+                                              kOutputIndex_Dst, output_tensor,
+                                              add_mkl_shape, false)) {
         return;
       } else {
         AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                                   output_tf_shape, *output_mkl_shape,
@@ -1870,7 +1871,7 @@ class MklQuantizedConv2DSumReluOp
       GetMklShape(context, summand_idx, &summand_mkl_shape);
       auto dst_md = summand_mkl_shape.GetMklLayout();
-      // TODO(mdfaijul): handle both non-MKL and MKL tensors
+      // TODO(intel-tf): Handle both non-MKL and MKL tensors
       if (summand_type == DT_QINT8) {
         OP_REQUIRES_OK(
             context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape()));
@@ -1879,9 +1880,13 @@ class MklQuantizedConv2DSumReluOp
         summand_mkl_shape.SetMklLayout(&dst_md);
         summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
       }
-      ForwardMklTensorInToOutWithMklShape(context, summand_idx, 0,
-                                          summand_mkl_shape);
-      *output_tensor = const_cast<Tensor*>(&summand);
+      // TODO(intel-tf): Support cases when summand cannot be forwarded.
+      OP_REQUIRES(
+          context,
+          ForwardMklTensorInToOutWithMklShape(
+              context, summand_idx, 0, output_tensor, summand_mkl_shape, false),
+          errors::InvalidArgument(
+              "Summand cannot be forwarded in the current fusion."));
       return;
     }
     MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
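Both call sites above now treat buffer forwarding as conditional rather than unconditional. For context, a minimal sketch of the underlying TensorFlow mechanism follows; the kernel name PassThroughOp is hypothetical and the sketch omits kernel registration, so it illustrates the pattern only and is not part of this commit.

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Illustrative-only kernel: try to reuse the input buffer for output 0 and
// fall back to a fresh allocation when the runtime refuses (for example,
// when the input buffer is shared with another consumer).
class PassThroughOp : public OpKernel {
 public:
  explicit PassThroughOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    Tensor* output = nullptr;
    // Returns true only when aliasing the input buffer is safe to do.
    if (!ctx->forward_input_to_output_with_shape(0, 0, input.shape(),
                                                 &output)) {
      // Forwarding was not possible; allocate a separate output buffer.
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
      // A real kernel would now compute into `output` instead of aliasing.
    }
  }
};

}  // namespace tensorflow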

View File

@@ -893,20 +893,36 @@ inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
 }
 
-inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
-                                                int idx_in, int idx_out,
-                                                const MklDnnShape& mkl_shape) {
+// If the input tensor has a reference count of 1, it is forwarded to the
+// desired output port and the function returns true. In that case, it also
+// allocates the serialized MklDnnShape object. Otherwise, it returns false.
+inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
+                                                int idx_in, int idx_out,
+                                                Tensor** output,
+                                                const MklDnnShape& mkl_shape,
+                                                bool always_forward = true) {
   int num_inputs = context->num_inputs();
   int num_outputs = context->num_outputs();
   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+  bool is_forwarded = false;
 
-  AllocateOutputSetMklShape(context, idx_out, mkl_shape);
-
-  if (IsRefType(context->input_dtype(idx_data_in))) {
-    context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
+  const Tensor& input_tensor = context->input(idx_data_in);
+  const auto output_shape = input_tensor.shape();
+  if (always_forward) {
+    if (IsRefType(context->input_dtype(idx_data_in))) {
+      context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
+    } else {
+      context->set_output(idx_data_out, input_tensor);
+    }
   } else {
-    context->set_output(idx_data_out, context->input(idx_data_in));
+    is_forwarded = context->forward_input_to_output_with_shape(
+        idx_data_in, idx_data_out, output_shape, output);
+  }
+
+  if (is_forwarded || always_forward) {
+    AllocateOutputSetMklShape(context, idx_out, mkl_shape);
+    return true;
+  } else {
+    return false;
   }
 }
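With the new signature, a caller that passes always_forward = false must check the return value and allocate on failure. A minimal caller sketch, assuming hypothetical index constants kInputIdx/kOutputIdx, already-computed tf_shape and mkl_shape values, and the five-argument AllocateOutputSetMklShape overload (none of these names come from this commit):

// Try to hand the input buffer to the output; this only succeeds when the
// input tensor's buffer is not shared (reference count 1).
Tensor* output_tensor = nullptr;
if (!ForwardMklTensorInToOutWithMklShape(context, kInputIdx, kOutputIdx,
                                         &output_tensor, mkl_shape,
                                         /*always_forward=*/false)) {
  // Forwarding failed: allocate a fresh output carrying the same serialized
  // MklDnnShape metadata and compute into it instead of aliasing the input.
  AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor, tf_shape,
                            mkl_shape);
}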

View File

@@ -39,6 +39,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
@@ -3266,6 +3267,164 @@ def GetInceptionBackFilterTest(input_size, filter_size, output_size, strides,
   return Test
 
 
+class FusedConv2DTest(test.TestCase):
+
+  def _CreateNumpyTensor(self, shape):
+    total_size = 1
+    for s in shape:
+      total_size *= s
+    return np.arange(1, total_size + 1, dtype=np.float32).reshape(shape)
+
+  def _CreateConv2D(self, input_values, filters,
+                    strides=[1, 1], padding="SAME"):
+    return nn_ops.convolution(
+        input_values,
+        filters,
+        strides=strides,
+        padding=padding)
+
+  @test_util.deprecated_graph_mode_only
+  def testAddWithRefCountOne(self):
+    expected_output = [
+        113377, 125570, 77305, 86738, 19433, 22226, 60681,
+        70722, 36291, 43718, 7143, 9206, 9785, 12098,
+        4783, 6366, 779, 1134]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+    # To get different weights for filter
+    ofs = 1
+
+    conv1 = self._CreateConv2D(x, filter_in)
+    conv2 = self._CreateConv2D(conv1, filter_in + ofs)
+
+    conv = self._CreateConv2D(conv1, filter_in - ofs)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv2])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(add).reshape(-1))
+
+  @test_util.deprecated_graph_mode_only
+  def testAddWithRefCountTwoAndRunAddLast(self):
+    expected_output = [
+        1.907175e+06, 2.253505e+06, 7.809210e+05, 9.537180e+05,
+        1.184170e+05, 1.523070e+05, 5.367010e+05, 6.803700e+05,
+        1.867090e+05, 2.529460e+05, 2.362300e+04, 3.522600e+04,
+        5.121700e+04, 7.168300e+04, 1.494300e+04, 2.347400e+04,
+        1.558000e+03, 2.903000e+03]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+    # To get different weights for filter
+    ofs = 1
+
+    conv1 = self._CreateConv2D(x, filter_in)
+    conv2 = self._CreateConv2D(conv1, filter_in + ofs)
+
+    conv = self._CreateConv2D(conv2, filter_in - ofs)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv1])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(add).reshape(-1))
+
+  @test_util.deprecated_graph_mode_only
+  def testAddWithRefCountTwoAndRunAddFirst(self):
+    expected_output = [
+        176161, 194450, 120673, 134822, 30545, 34734, 96041,
+        111102, 58149, 69289, 11745, 14839, 15833, 19302,
+        7965, 10339, 1345, 1877]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+    # To get different weights for filter
+    ofs = 1
+
+    conv1 = self._CreateConv2D(x, filter_in)
+    conv2 = self._CreateConv2D(conv1, filter_in + ofs)
+
+    conv = self._CreateConv2D(conv1, filter_in - ofs)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv2])
+
+    relu = nn_ops.relu(add)
+    output = math_ops.add_n([relu, conv2])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(output).reshape(-1))
+
+  @test_util.deprecated_graph_mode_only
+  def testAddWithRefCountTwoAndNoDependence(self):
+    expected_output = [
+        176161, 194450, 120673, 134822, 30545, 34734, 96041,
+        111102, 58149, 69289, 11745, 14839, 15833, 19302,
+        7965, 10339, 1345, 1877]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+    # To get different weights for filter
+    ofs = 1
+
+    conv1 = self._CreateConv2D(x, filter_in)
+    conv2 = self._CreateConv2D(conv1, filter_in + ofs)
+
+    conv = self._CreateConv2D(conv1, filter_in - ofs)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv2])
+
+    relu1 = nn_ops.relu(add)
+    relu2 = nn_ops.relu(conv2)
+    output = math_ops.add_n([relu1, relu2])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(output).reshape(-1))
+
+  @test_util.deprecated_graph_mode_only
+  def testAddWithSameSrcAndAddTensorBuffer(self):
+    expected_output = [
+        57157, 63298, 39249, 44026, 9971, 11402, 31193, 36306,
+        19126, 22948, 3970, 5060, 5135, 6350, 2666, 3524,
+        461, 674]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+
+    conv1 = self._CreateConv2D(x, filter_in)
+
+    conv = self._CreateConv2D(conv1, filter_in)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv1])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(add).reshape(-1))
+
+
 if __name__ == "__main__":
   for index, (input_size_, filter_size_, output_size_, stride_,
               padding_) in enumerate(GetShrunkInceptionShapes()):