Merge pull request from Intel-tensorflow:lesliefang/fix_concat_failed_create_memory_descriptor

PiperOrigin-RevId: 264461803
This commit is contained in:
TensorFlower Gardener 2019-08-20 14:41:26 -07:00
commit 15bd4863bd
2 changed files with 36 additions and 2 deletions
tensorflow
core/kernels
python/kernel_tests

View File

@ -461,8 +461,14 @@ class MklConcatOp : public OpKernel {
dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
// Set the output format same as the most common format of inputs
// to avoid layout conversions.
dst_md = memory::desc(dst_dims_in_nchw, MklDnnType<T>(),
mkl_common_format);
if (mkl_common_format == memory::format::blocked) {
VLOG(1) << "mkl_common_format == memory::format::blocked";
dst_md = MklDnnData<T>::CreateBlockedMemDesc(
dst_dims_in_nchw, CalculateTFStrides(dst_dims_in_nchw));
} else {
dst_md = memory::desc(dst_dims_in_nchw, MklDnnType<T>(),
mkl_common_format);
}
} else if (dst_dims.size() == 2 &&
mkl_common_format == memory::format::nc) {
// When memory::format::nc, dst_dims are already in MKL-DNN order

View File

@ -29,6 +29,7 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -714,6 +715,33 @@ class ConcatOffsetTest(test.TestCase):
ans = self.evaluate(off)
self.assertAllEqual(ans, [[0, 0, 0], [2, 0, 0], [3, 0, 0]])
def testCreateMemDecBlockedFormat(self):
  """Smoke-test the MKL concat op with a blocked-format input.

  Builds a slice -> conv2d -> slice pipeline twice and concatenates the two
  results along the last axis, the scenario in which one of the concat
  inputs carries a memory descriptor in blocked format. The test makes no
  assertion on the values; it only checks that creating and running the
  concat does not crash. Without MKL the test is a no-op.
  """
  if not test_util.IsMklEnabled():
    return

  ones = np.ones((1, 8188, 4092, 1), dtype=np.uint8).astype(np.float32)
  # Mask value 9 (0b1001) sets bits 0 and 3, so the begin/end entries for
  # the first and last dimensions are ignored in favour of the full range.
  trimmed = array_ops.strided_slice(
      ones,
      [0, 1, 1, 0],
      [0, -1, -1, 0],
      [1, 1, 1, 1],
      begin_mask=9,
      end_mask=9)
  single_channel = array_ops.slice(trimmed, [0, 0, 0, 0], [-1, -1, -1, 1])

  # Both branches crop the very same window out of the shared input.
  crop_begin = [0, 4, 4, 0]
  crop_size = [-1, 8178, 4082, 1]
  cropped_a = array_ops.slice(single_channel, crop_begin, crop_size)
  cropped_b = array_ops.slice(single_channel, crop_begin, crop_size)

  unit_strides = [1, 1, 1, 1]
  kernel_a = constant_op.constant([[[[1.18, -0.51]]]])
  conv_a = nn_ops.conv2d(
      cropped_a, kernel_a, strides=unit_strides, padding="VALID")
  kernel_b = constant_op.constant([[[[1.38, -0.11]]]])
  conv_b = nn_ops.conv2d(
      cropped_b, kernel_b, strides=unit_strides, padding="VALID")

  # Keep a single spatial position (all channels) from each conv output.
  pick_begin = [0, 6, 6, 0]
  pick_size = [-1, 1, 1, -1]
  picked_a = array_ops.slice(conv_a, pick_begin, pick_size)
  picked_b = array_ops.slice(conv_b, pick_begin, pick_size)

  x_concat = array_ops.concat([picked_a, picked_b], 3)
  # This test is only meant to check the creation is not crashed.
  self.evaluate(x_concat)
# When this file is executed as a script, hand control to the `main()` entry
# point of the `test` module imported from tensorflow.python.platform.
if __name__ == "__main__":
test.main()