Fix the mean the description of the output shape, which should be [batch_size, 1, max_time] (#16642)

This commit is contained in:
ImSheridan 2018-02-06 08:15:33 +08:00 committed by Jonathan Hseu
parent fe7bbd8565
commit 283f03c825

View File

@ -331,7 +331,7 @@ def _luong_score(query, keys, scale):
# batched matmul on: # batched matmul on:
# [batch_size, 1, depth] . [batch_size, depth, max_time] # [batch_size, 1, depth] . [batch_size, depth, max_time]
# resulting in an output shape of: # resulting in an output shape of:
# [batch_time, 1, max_time]. # [batch_size, 1, max_time].
# we then squeeze out the center singleton dimension. # we then squeeze out the center singleton dimension.
score = math_ops.matmul(query, keys, transpose_b=True) score = math_ops.matmul(query, keys, transpose_b=True)
score = array_ops.squeeze(score, [1]) score = array_ops.squeeze(score, [1])