Merge pull request #41688 from Intel-tensorflow:yang/fix_concat
PiperOrigin-RevId: 323612153 Change-Id: I72c92a6af1c5a908e599185ce6a03b6f42d04533
This commit is contained in:
commit
482d273416
@ -376,19 +376,20 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
context_.data_mem_shdptr.push_back(src_mem);
|
||||
context_.data_mem.push_back(*context_.data_mem_shdptr[i]);
|
||||
}
|
||||
// Store the expected memory format
|
||||
context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
|
||||
MklDnnType<T>(),
|
||||
concat_fwd_dims.mkl_common_format));
|
||||
// Create a concat primitive descriptor
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
context_.fwd_pd.reset(new concat::primitive_desc(
|
||||
concat_fwd_dims.concat_dims, context_.src_md, cpu_engine_));
|
||||
*context_.dst_md, concat_fwd_dims.concat_dims, context_.src_md,
|
||||
cpu_engine_));
|
||||
#else
|
||||
context_.fwd_pd.reset(new concat::primitive_desc(
|
||||
concat_fwd_dims.concat_dims, context_.src_pd));
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
// Store the expected memory format
|
||||
context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
|
||||
MklDnnType<T>(),
|
||||
concat_fwd_dims.mkl_common_format));
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// Create memory primitive based on dummy data
|
||||
context_.dst_mem.reset(
|
||||
@ -404,8 +405,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
context_.concat_fwd.reset(new concat(*context_.fwd_pd));
|
||||
std::unordered_map<int, memory> net_args = {
|
||||
{ MKLDNN_ARG_DST,
|
||||
*context_.dst_mem }};
|
||||
{MKLDNN_ARG_DST, *context_.dst_mem}};
|
||||
for (int i = 0; i < concat_fwd_dims.num_inputs; ++i) {
|
||||
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, context_.data_mem[i]});
|
||||
}
|
||||
|
@ -68,6 +68,22 @@ class ConcatOpTest(test.TestCase):
|
||||
self.assertAllEqual(result[:, :4], params[p1])
|
||||
self.assertAllEqual(result[:, 4:], params[p2])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test4DStack(self):
|
||||
with self.session(use_gpu=True):
|
||||
p1 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 1, 1])
|
||||
p2 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 4, 1])
|
||||
c = array_ops.concat([p1, p2], 2)
|
||||
params = {
|
||||
p1: np.random.rand(2, 3, 1, 1).astype("f"),
|
||||
p2: np.random.rand(2, 3, 4, 1).astype("f")
|
||||
}
|
||||
result = c.eval(feed_dict=params)
|
||||
|
||||
self.assertEqual(result.shape, c.get_shape())
|
||||
self.assertAllEqual(result[:, :, :1, :], params[p1])
|
||||
self.assertAllEqual(result[:, :, 1:, :], params[p2])
|
||||
|
||||
def testInt32GPU(self):
|
||||
with test_util.use_gpu():
|
||||
p1 = np.random.rand(2, 3).astype("i")
|
||||
|
Loading…
x
Reference in New Issue
Block a user