Merge pull request #41688 from Intel-tensorflow:yang/fix_concat
PiperOrigin-RevId: 323612153 Change-Id: I72c92a6af1c5a908e599185ce6a03b6f42d04533
This commit is contained in:
commit
482d273416
@ -376,19 +376,20 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
|||||||
context_.data_mem_shdptr.push_back(src_mem);
|
context_.data_mem_shdptr.push_back(src_mem);
|
||||||
context_.data_mem.push_back(*context_.data_mem_shdptr[i]);
|
context_.data_mem.push_back(*context_.data_mem_shdptr[i]);
|
||||||
}
|
}
|
||||||
|
// Store the expected memory format
|
||||||
|
context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
|
||||||
|
MklDnnType<T>(),
|
||||||
|
concat_fwd_dims.mkl_common_format));
|
||||||
// Create a concat primitive descriptor
|
// Create a concat primitive descriptor
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
context_.fwd_pd.reset(new concat::primitive_desc(
|
context_.fwd_pd.reset(new concat::primitive_desc(
|
||||||
concat_fwd_dims.concat_dims, context_.src_md, cpu_engine_));
|
*context_.dst_md, concat_fwd_dims.concat_dims, context_.src_md,
|
||||||
|
cpu_engine_));
|
||||||
#else
|
#else
|
||||||
context_.fwd_pd.reset(new concat::primitive_desc(
|
context_.fwd_pd.reset(new concat::primitive_desc(
|
||||||
concat_fwd_dims.concat_dims, context_.src_pd));
|
concat_fwd_dims.concat_dims, context_.src_pd));
|
||||||
#endif // ENABLE_MKLDNN_V1
|
#endif // ENABLE_MKLDNN_V1
|
||||||
|
|
||||||
// Store the expected memory format
|
|
||||||
context_.dst_md.reset(new memory::desc({concat_fwd_dims.dst_dims},
|
|
||||||
MklDnnType<T>(),
|
|
||||||
concat_fwd_dims.mkl_common_format));
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
// Create memory primitive based on dummy data
|
// Create memory primitive based on dummy data
|
||||||
context_.dst_mem.reset(
|
context_.dst_mem.reset(
|
||||||
@ -404,8 +405,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
|||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
context_.concat_fwd.reset(new concat(*context_.fwd_pd));
|
context_.concat_fwd.reset(new concat(*context_.fwd_pd));
|
||||||
std::unordered_map<int, memory> net_args = {
|
std::unordered_map<int, memory> net_args = {
|
||||||
{ MKLDNN_ARG_DST,
|
{MKLDNN_ARG_DST, *context_.dst_mem}};
|
||||||
*context_.dst_mem }};
|
|
||||||
for (int i = 0; i < concat_fwd_dims.num_inputs; ++i) {
|
for (int i = 0; i < concat_fwd_dims.num_inputs; ++i) {
|
||||||
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, context_.data_mem[i]});
|
net_args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, context_.data_mem[i]});
|
||||||
}
|
}
|
||||||
|
@ -68,6 +68,22 @@ class ConcatOpTest(test.TestCase):
|
|||||||
self.assertAllEqual(result[:, :4], params[p1])
|
self.assertAllEqual(result[:, :4], params[p1])
|
||||||
self.assertAllEqual(result[:, 4:], params[p2])
|
self.assertAllEqual(result[:, 4:], params[p2])
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def test4DStack(self):
|
||||||
|
with self.session(use_gpu=True):
|
||||||
|
p1 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 1, 1])
|
||||||
|
p2 = array_ops.placeholder(dtypes.float32, shape=[2, 3, 4, 1])
|
||||||
|
c = array_ops.concat([p1, p2], 2)
|
||||||
|
params = {
|
||||||
|
p1: np.random.rand(2, 3, 1, 1).astype("f"),
|
||||||
|
p2: np.random.rand(2, 3, 4, 1).astype("f")
|
||||||
|
}
|
||||||
|
result = c.eval(feed_dict=params)
|
||||||
|
|
||||||
|
self.assertEqual(result.shape, c.get_shape())
|
||||||
|
self.assertAllEqual(result[:, :, :1, :], params[p1])
|
||||||
|
self.assertAllEqual(result[:, :, 1:, :], params[p2])
|
||||||
|
|
||||||
def testInt32GPU(self):
|
def testInt32GPU(self):
|
||||||
with test_util.use_gpu():
|
with test_util.use_gpu():
|
||||||
p1 = np.random.rand(2, 3).astype("i")
|
p1 = np.random.rand(2, 3).astype("i")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user