Merge pull request #43577 from Intel-tensorflow:amin/inplace-fix
PiperOrigin-RevId: 336362945
Change-Id: I6d466209337df92b6146bfb90c91f88faef50f5d
commit ead4d024f8
@@ -957,10 +957,11 @@ class MklConvOp : public OpKernel {
         return;
       }
       // Check if reorder is needed
-      if (add_mkl_shape == *output_mkl_shape) {
-        ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
-                                            kOutputIndex_Dst, add_mkl_shape);
-        *output_tensor = context->mutable_output(kOutputIndex_Dst);
+      if (add_mkl_shape == *output_mkl_shape &&
+          ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
+                                              kOutputIndex_Dst, output_tensor,
+                                              add_mkl_shape, false)) {
         return;
       } else {
         AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                                   output_tf_shape, *output_mkl_shape,
@@ -1888,7 +1889,7 @@ class MklQuantizedConv2DSumReluOp
       GetMklShape(context, summand_idx, &summand_mkl_shape);
       auto dst_md = summand_mkl_shape.GetMklLayout();
 
-      // TODO(mdfaijul): handle both non-MKL and MKL tensors
+      // TODO(intel-tf): Handle both non-MKL and MKL tensors
       if (summand_type == DT_QINT8) {
         OP_REQUIRES_OK(
             context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape()));
@@ -1897,9 +1898,13 @@ class MklQuantizedConv2DSumReluOp
         summand_mkl_shape.SetMklLayout(&dst_md);
         summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
       }
-      ForwardMklTensorInToOutWithMklShape(context, summand_idx, 0,
-                                          summand_mkl_shape);
-      *output_tensor = const_cast<Tensor*>(&summand);
+      // TODO(intel-tf): Support cases when summand cannot be forwarded.
+      OP_REQUIRES(
+          context,
+          ForwardMklTensorInToOutWithMklShape(
+              context, summand_idx, 0, output_tensor, summand_mkl_shape, false),
+          errors::InvalidArgument(
+              "Summand cannot be forwarded in the current fusion."));
       return;
     }
     MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
@@ -893,21 +893,36 @@ inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
 }
 
-inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
-                                                int idx_in, int idx_out,
-                                                const MklDnnShape& mkl_shape) {
+// If the input tensor has ref count as 1, it is forwarded to the desired
+// output port and the function returns true. In that case, it also allocates
+// the serialized MklDnnShape object. Otherwise, the function returns false.
+inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
+                                                int idx_in, int idx_out,
+                                                Tensor** output,
+                                                const MklDnnShape& mkl_shape,
+                                                bool always_forward = true) {
   int num_inputs = context->num_inputs();
   int num_outputs = context->num_outputs();
   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
 
-  AllocateOutputSetMklShape(context, idx_out, mkl_shape);
-
-  if (IsRefType(context->input_dtype(idx_data_in))) {
-    context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
-  } else {
-    context->set_output(idx_data_out, context->input(idx_data_in));
-  }
+  bool is_forwarded = false;
+  const Tensor& input_tensor = context->input(idx_data_in);
+  const auto output_shape = input_tensor.shape();
+  if (always_forward) {
+    if (IsRefType(context->input_dtype(idx_data_in))) {
+      context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
+    } else {
+      context->set_output(idx_data_out, input_tensor);
+    }
+  } else {
+    is_forwarded = context->forward_input_to_output_with_shape(
+        idx_data_in, idx_data_out, output_shape, output);
+  }
+  if (is_forwarded || always_forward) {
+    AllocateOutputSetMklShape(context, idx_out, mkl_shape);
+    return true;
+  }
+  return false;
 }
 
 // Forward the MKL shape ONLY (used in elementwise and other ops where
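The comment block above is the crux of the fix: forwarding is now fallible, and each caller picks a fallback. MklConvOp reallocates the output when forwarding fails, while the quantized sum fusion raises InvalidArgument. Below is a minimal, self-contained sketch of the sole-owner decision. It is a toy model, not TensorFlow code: std::shared_ptr::use_count() stands in for the tensor buffer's refcount, and Buffer and ForwardIfSoleOwner are hypothetical names used only for illustration.

// Toy model of the forwarding contract above -- NOT TensorFlow code.
#include <iostream>
#include <memory>
#include <vector>

using Buffer = std::shared_ptr<std::vector<float>>;

// Mirrors the always_forward=false path: reuse the input buffer as the
// output only when no other consumer still holds a reference to it.
bool ForwardIfSoleOwner(Buffer& in, Buffer& out) {
  if (in.use_count() != 1) return false;  // another op may still read `in`
  out = std::move(in);                    // forward in place, no copy
  return true;
}

int main() {
  Buffer a = std::make_shared<std::vector<float>>(4, 1.0f);
  Buffer out;
  std::cout << ForwardIfSoleOwner(a, out) << "\n";  // 1: sole owner, forwarded

  Buffer b = std::make_shared<std::vector<float>>(4, 2.0f);
  Buffer alias = b;  // a second consumer, like conv2 also reading conv1
  std::cout << ForwardIfSoleOwner(b, out) << "\n";  // 0: caller must allocate
  return 0;
}

In the real kernels this decision is delegated to OpKernelContext::forward_input_to_output_with_shape, which only succeeds when the input buffer is exclusively owned; that refcount-1 condition is exactly what the new Python tests below exercise.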
@@ -39,6 +39,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
@@ -3266,6 +3267,173 @@ def GetInceptionBackFilterTest(input_size, filter_size, output_size, strides,
   return Test
 
 
+class FusedConv2DTest(test.TestCase):
+
+  def _CreateNumpyTensor(self, shape):
+    total_size = np.prod(shape)
+    return np.arange(1, total_size + 1, dtype=np.float32).reshape(shape)
+
+  def _CreateConv2D(self,
+                    input_values,
+                    filters,
+                    strides=[1, 1],
+                    padding="SAME"):
+    return nn_ops.convolution(
+        input_values, filters, strides=strides, padding=padding)
+
+  # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to
+  # Add has refcount 1.
+  @test_util.run_in_graph_and_eager_modes(use_gpu=False)
+  def testAddWithRefCountOne(self):
+    expected_output = [
+        113377, 125570, 77305, 86738, 19433, 22226, 60681, 70722, 36291, 43718,
+        7143, 9206, 9785, 12098, 4783, 6366, 779, 1134
+    ]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+    # To get different weights for filter
+    offset = 1
+
+    conv1 = self._CreateConv2D(x, filter_in)
+    conv2 = self._CreateConv2D(conv1, filter_in + offset)
+
+    conv = self._CreateConv2D(conv1, filter_in - offset)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv2])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(add).reshape(-1))
+
+  # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to
+  # Add has a total refcount of 2, and Add is its last consumer.
+  @test_util.run_in_graph_and_eager_modes(use_gpu=False)
+  def testAddWithRefCountTwoAndRunAddLast(self):
+    expected_output = [
+        1.907175e+06, 2.253505e+06, 7.809210e+05, 9.537180e+05, 1.184170e+05,
+        1.523070e+05, 5.367010e+05, 6.803700e+05, 1.867090e+05, 2.529460e+05,
+        2.362300e+04, 3.522600e+04, 5.121700e+04, 7.168300e+04, 1.494300e+04,
+        2.347400e+04, 1.558000e+03, 2.903000e+03
+    ]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+    # To get different weights for filter
+    offset = 1
+
+    conv1 = self._CreateConv2D(x, filter_in)
+    conv2 = self._CreateConv2D(conv1, filter_in + offset)
+
+    conv = self._CreateConv2D(conv2, filter_in - offset)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv1])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(add).reshape(-1))
+
+  # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to
+  # Add has refcount 2 and Add (in the fused Conv2D op) is its first consumer.
+  @test_util.run_in_graph_and_eager_modes(use_gpu=False)
+  def testAddWithRefCountTwoAndRunAddFirst(self):
+    expected_output = [
+        176161, 194450, 120673, 134822, 30545, 34734, 96041, 111102, 58149,
+        69289, 11745, 14839, 15833, 19302, 7965, 10339, 1345, 1877
+    ]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+    # To get different weights for filter
+    offset = 1
+
+    conv1 = self._CreateConv2D(x, filter_in)
+    conv2 = self._CreateConv2D(conv1, filter_in + offset)
+
+    conv = self._CreateConv2D(conv1, filter_in - offset)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv2])
+
+    relu = nn_ops.relu(add)
+    output = math_ops.add_n([relu, conv2])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(output).reshape(-1))
+
+  # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to
+  # Add has refcount 2, and there is no dependency between its two consumers.
+  @test_util.run_in_graph_and_eager_modes(use_gpu=False)
+  def testAddWithRefCountTwoAndNoDependence(self):
+    expected_output = [
+        176161, 194450, 120673, 134822, 30545, 34734, 96041, 111102, 58149,
+        69289, 11745, 14839, 15833, 19302, 7965, 10339, 1345, 1877
+    ]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+    # To get different weights for filter
+    offset = 1
+
+    conv1 = self._CreateConv2D(x, filter_in)
+    conv2 = self._CreateConv2D(conv1, filter_in + offset)
+
+    conv = self._CreateConv2D(conv1, filter_in - offset)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv2])
+
+    relu1 = nn_ops.relu(add)
+    relu2 = nn_ops.relu(conv2)
+    output = math_ops.add_n([relu1, relu2])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(output).reshape(-1))
+
+  # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to
+  # Add is the same as the input to the fused Conv2D op and needs a tensor
+  # buffer.
+  @test_util.run_in_graph_and_eager_modes(use_gpu=False)
+  def testAddWithSameSrcAndAddTensorBuffer(self):
+    expected_output = [
+        57157, 63298, 39249, 44026, 9971, 11402, 31193, 36306, 19126, 22948,
+        3970, 5060, 5135, 6350, 2666, 3524, 461, 674
+    ]
+    tensor_in_sizes = [1, 3, 3, 2]
+    filter_in_sizes = [2, 2, 2, 2]
+    bias_in_sizes = [2]
+
+    x = self._CreateNumpyTensor(tensor_in_sizes)
+    filter_in = self._CreateNumpyTensor(filter_in_sizes)
+    bias_in = self._CreateNumpyTensor(bias_in_sizes)
+
+    conv1 = self._CreateConv2D(x, filter_in)
+
+    conv = self._CreateConv2D(conv1, filter_in)
+    bias_add = nn_ops.bias_add(conv, bias_in)
+    add = math_ops.add_n([bias_add, conv1])
+
+    self.assertAllEqual(
+        np.rint(expected_output),
+        self.evaluate(add).reshape(-1))
+
+
 if __name__ == "__main__":
   for index, (input_size_, filter_size_, output_size_, stride_,
               padding_) in enumerate(GetShrunkInceptionShapes()):