Merge pull request #41688 from Intel-tensorflow:yang/fix_concat

PiperOrigin-RevId: 323612153
Change-Id: I72c92a6af1c5a908e599185ce6a03b6f42d04533
This commit is contained in:
TensorFlower Gardener 2020-07-28 11:20:11 -07:00
commit 482d273416
2 changed files with 23 additions and 7 deletions

View File

@ -376,19 +376,20 @@ class MklConcatFwdPrimitive : public MklPrimitive {
context_.data_mem_shdptr.push_back(src_mem);
context_.data_mem.push_back(*context_.data_mem_shdptr[i]);
}
// Store the expected memory format
context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
MklDnnType<T>(),
concat_fwd_dims.mkl_common_format));
// Create a concat primitive descriptor
#ifdef ENABLE_MKLDNN_V1
context_.fwd_pd.reset(new concat::primitive_desc(
concat_fwd_dims.concat_dims, context_.src_md, cpu_engine_));
*context_.dst_md, concat_fwd_dims.concat_dims, context_.src_md,
cpu_engine_));
#else
context_.fwd_pd.reset(new concat::primitive_desc(
concat_fwd_dims.concat_dims, context_.src_pd));
#endif // ENABLE_MKLDNN_V1
// Store the expected memory format
context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
MklDnnType<T>(),
concat_fwd_dims.mkl_common_format));
#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.dst_mem.reset(
@ -404,8 +405,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
#ifdef ENABLE_MKLDNN_V1
context_.concat_fwd.reset(new concat(*context_.fwd_pd));
std::unordered_map<int, memory> net_args = {
{ MKLDNN_ARG_DST,
*context_.dst_mem }};
{MKLDNN_ARG_DST, *context_.dst_mem}};
for (int i = 0; i < concat_fwd_dims.num_inputs; ++i) {
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, context_.data_mem[i]});
}

View File

@ -68,6 +68,22 @@ class ConcatOpTest(test.TestCase):
self.assertAllEqual(result[:, :4], params[p1])
self.assertAllEqual(result[:, 4:], params[p2])
@test_util.run_deprecated_v1
def test4DStack(self):
with self.session(use_gpu=True):
p1 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 1, 1])
p2 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 4, 1])
c = array_ops.concat([p1, p2], 2)
params = {
p1: np.random.rand(2, 3, 1, 1).astype("f"),
p2: np.random.rand(2, 3, 4, 1).astype("f")
}
result = c.eval(feed_dict=params)
self.assertEqual(result.shape, c.get_shape())
self.assertAllEqual(result[:, :, :1, :], params[p1])
self.assertAllEqual(result[:, :, 1:, :], params[p2])
def testInt32GPU(self):
with test_util.use_gpu():
p1 = np.random.rand(2, 3).astype("i")