Merge pull request #31777 from Intel-tensorflow:lesliefang/fix_concat_failed_create_memory_descriptor
PiperOrigin-RevId: 264461803
This commit is contained in:
commit
15bd4863bd
tensorflow
@ -461,8 +461,14 @@ class MklConcatOp : public OpKernel {
|
||||
dst_dims, MklDnnDataFormatToTFDataFormat(orig_tf_format));
|
||||
// Set the output format to the most common format among the inputs
|
||||
// to avoid layout conversions.
|
||||
dst_md = memory::desc(dst_dims_in_nchw, MklDnnType<T>(),
|
||||
mkl_common_format);
|
||||
if (mkl_common_format == memory::format::blocked) {
|
||||
VLOG(1) << "mkl_common_format == memory::format::blocked";
|
||||
dst_md = MklDnnData<T>::CreateBlockedMemDesc(
|
||||
dst_dims_in_nchw, CalculateTFStrides(dst_dims_in_nchw));
|
||||
} else {
|
||||
dst_md = memory::desc(dst_dims_in_nchw, MklDnnType<T>(),
|
||||
mkl_common_format);
|
||||
}
|
||||
} else if (dst_dims.size() == 2 &&
|
||||
mkl_common_format == memory::format::nc) {
|
||||
// When memory::format::nc, dst_dims are already in MKL-DNN order
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -714,6 +715,33 @@ class ConcatOffsetTest(test.TestCase):
|
||||
ans = self.evaluate(off)
|
||||
self.assertAllEqual(ans, [[0, 0, 0], [2, 0, 0], [3, 0, 0]])
|
||||
|
||||
def testCreateMemDecBlockedFormat(self):
  """Create an MKL concat op with an input in blocked memory format.

  Exercises the case where one of the concat inputs has a memory descriptor
  in blocked format. The graph is only evaluated; no output values are
  checked. Nothing is built or run unless `test_util.IsMklEnabled()` is true.
  """
  if not test_util.IsMklEnabled():
    return

  ones_input = np.ones((1, 8188, 4092, 1), dtype=np.uint8).astype(np.float32)
  strided = array_ops.strided_slice(
      ones_input,
      [0, 1, 1, 0],
      [0, -1, -1, 0],
      [1, 1, 1, 1],
      begin_mask=9,
      end_mask=9)
  single_channel = array_ops.slice(strided, [0, 0, 0, 0], [-1, -1, -1, 1])

  # Two separate (identical) crops, one feeding each convolution branch.
  crop_begin = [0, 4, 4, 0]
  crop_size = [-1, 8178, 4082, 1]
  branch_a_in = array_ops.slice(single_channel, crop_begin, crop_size)
  branch_b_in = array_ops.slice(single_channel, crop_begin, crop_size)

  kernel_a = constant_op.constant([[[[1.18, -0.51]]]])
  branch_a_conv = nn_ops.conv2d(
      branch_a_in, kernel_a, strides=[1, 1, 1, 1], padding="VALID")
  kernel_b = constant_op.constant([[[[1.38, -0.11]]]])
  branch_b_conv = nn_ops.conv2d(
      branch_b_in, kernel_b, strides=[1, 1, 1, 1], padding="VALID")

  # Take a single spatial position from each convolution output, keeping all
  # of its channels.
  pick_begin = [0, 6, 6, 0]
  pick_size = [-1, 1, 1, -1]
  branch_a_out = array_ops.slice(branch_a_conv, pick_begin, pick_size)
  branch_b_out = array_ops.slice(branch_b_conv, pick_begin, pick_size)

  concatenated = array_ops.concat([branch_a_out, branch_b_out], 3)
  # This test only checks that creating and running the op does not crash.
  self.evaluate(concatenated)
|
||||
|
||||
|
||||
# Script entry point: hand control to test.main() when run directly.
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user