Update shape checking logic in einsum (#7387)

* Update shape checking logic in einsum

* Fix typo

* Make the modifications to einsum more structured and simpler

* Remove unnecessary parts

* Fix indentation
Jihun Choi 2017-02-18 02:39:07 +09:00 committed by Vijay Vasudevan
parent 0be5c4a395
commit 807f449bf5
2 changed files with 39 additions and 29 deletions

tensorflow/python/ops/special_math_ops.py

@@ -318,44 +318,28 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
   # into a single axis, and combine multiple summed axes into a
   # single axis.
-  t0_shape = tuple(x.value for x in t0.get_shape())
+  t0_shape = _get_shape(t0)
   num_broadcast_elements_t0 = _total_size(
       t0_shape[len(preserved_axes):-len(axes_to_sum)])
   num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
-  new_shape = t0_shape[:len(preserved_axes)] + (num_broadcast_elements_t0,
-                                                num_summed_elements)
+  new_shape = (t0_shape[:len(preserved_axes)]
+               + [num_broadcast_elements_t0, num_summed_elements])
   t0 = _reshape_if_necessary(t0, new_shape)
 
-  t1_shape = tuple(x.value for x in t1.get_shape())
+  t1_shape = _get_shape(t1)
   num_broadcast_elements_t1 = _total_size(
       t1_shape[len(preserved_axes)+len(axes_to_sum):])
-  new_shape = t1_shape[:len(preserved_axes)] + (num_summed_elements,
-                                                num_broadcast_elements_t1)
+  new_shape = (t1_shape[:len(preserved_axes)]
+               + [num_summed_elements, num_broadcast_elements_t1])
   t1 = _reshape_if_necessary(t1, new_shape)
 
   product = math_ops.matmul(t0, t1)
 
   # Undo compaction of broadcast axes
   uncompacted_shape = (
-      t0_shape[:len(preserved_axes)+len(broadcast_axes[0])] +
-      t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
+      t0_shape[:len(preserved_axes)+len(broadcast_axes[0])]
+      + t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
   )
-
-  # Check the number of None values and replace them with Tensors containing
-  # corresponding dimensions if there exist two or more None values
-  num_none_dims = sum(1 for d in uncompacted_shape if d is None)
-  if num_none_dims > 1:
-    uncompacted_shape = list(uncompacted_shape)
-    for i in xrange(len(uncompacted_shape)):
-      if uncompacted_shape[i] is None:
-        if i < len(preserved_axes) + len(broadcast_axes[0]):
-          uncompacted_shape[i] = array_ops.shape(inputs[0])[i]
-        else:
-          idx = (i - len(preserved_axes) - len(broadcast_axes[0])
-                 + len(t1_shape) - len(broadcast_axes[1]))
-          uncompacted_shape[i] = array_ops.shape(inputs[1])[idx]
-    uncompacted_shape = tuple(uncompacted_shape)
-
   product = _reshape_if_necessary(product, uncompacted_shape)
 
   product_axes = (
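For context on what this hunk is reshaping: _einsum_reduction flattens each operand into a (preserved, broadcast, summed) layout so that a single matmul performs the whole contraction, and the switch from tuples to lists lets the compacted shapes mix Python ints with scalar shape tensors. A minimal NumPy sketch of the static-shape case (illustrative only, not part of this commit):

import numpy as np

def reduction_sketch(t0, t1):
  # Contract 'bxyj,bjk->bxyk' the way _einsum_reduction does: keep the
  # preserved axis b, compact the broadcast axes (x, y) into one axis,
  # run a single matmul over the summed axis j, then undo the compaction.
  b, x, y, j = t0.shape
  _, _, k = t1.shape
  lhs = t0.reshape(b, x * y, j)       # (preserved, broadcast, summed)
  product = np.matmul(lhs, t1)        # (b, x*y, k)
  return product.reshape(b, x, y, k)  # undo compaction of broadcast axes

out = reduction_sketch(np.ones((2, 3, 4, 5)), np.ones((2, 5, 6)))
assert out.shape == (2, 3, 4, 6)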
@@ -386,13 +370,27 @@ def _reshape_if_necessary(tensor, new_shape):
   return array_ops.reshape(tensor, new_shape)
 
 
+def _get_shape(tensor):
+  """Like get_shape().as_list(), but explicitly queries the shape of a tensor
+  if necessary to ensure that the returned value contains no unknown value."""
+
+  shape = tensor.get_shape().as_list()
+  none_indices = [i for i, d in enumerate(shape) if d is None]
+  if none_indices:
+    # Query the shape if shape contains None values
+    shape_tensor = array_ops.shape(tensor)
+    for i in none_indices:
+      shape[i] = shape_tensor[i]
+  return shape
+
+
 def _total_size(shape_values):
-  """Given list of tensor shape values, returns total size or -1 if unknown."""
+  """Given list of tensor shape values, returns total size.
+  If shape_values contains tensor values (which are results of
+  array_ops.shape), then it returns a scalar tensor.
+  If not, it returns an integer."""
+
   result = 1
   for val in shape_values:
-    if val is None:
-      return -1
-    assert isinstance(val, int)
     result *= val
   return result
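The point of the new helpers is that _total_size stays symbolic whenever _get_shape has substituted scalar shape tensors for unknown dimensions, so the reshape targets above remain computable even with multiple None dims. A standalone sketch of the same pattern, assuming the TF 1.x graph-mode API (tf.placeholder / tf.Session):

import tensorflow as tf  # assumes TF 1.x graph mode

def get_shape(tensor):
  # Static dims where known; scalar shape tensors where unknown (None).
  shape = tensor.get_shape().as_list()
  dynamic = tf.shape(tensor)
  return [dynamic[i] if d is None else d for i, d in enumerate(shape)]

def total_size(shape_values):
  # A Python int if every dim is static, else a scalar int32 tensor.
  result = 1
  for val in shape_values:
    result *= val
  return result

x = tf.placeholder(tf.float32, shape=(None, 2, None))
size = total_size(get_shape(x))  # a scalar tensor: two dims are unknown
with tf.Session() as sess:
  print(sess.run(size, feed_dict={x: [[[1., 2.], [3., 4.]]]}))  # prints 4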

tensorflow/python/ops/special_math_ops_test.py

@@ -318,7 +318,19 @@ class EinsumTest(test.TestCase):
             m1: [3, 2],
         }
         np.testing.assert_almost_equal(
-          [[7]], sess.run(out, feed_dict=feed_dict))
+            [[7]], sess.run(out, feed_dict=feed_dict))
+
+    with ops.Graph().as_default():
+      m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2))
+      m1 = array_ops.placeholder(dtypes.int32, shape=(None, 2))
+      out = special_math_ops.einsum('ijkl,ij->ikl', m0, m1)
+      with session.Session() as sess:
+        feed_dict = {
+            m0: [[[[1, 2]], [[2, 1]]]],
+            m1: [[3, 2]],
+        }
+        np.testing.assert_almost_equal(
+            [[[7, 8]]], sess.run(out, feed_dict=feed_dict))
 
 if __name__ == '__main__':
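With this change, einsum no longer requires fully static shapes; the new test above exercises two unknown dimensions at once. A self-contained version of the same call, assuming the TF 1.x graph-mode API:

import numpy as np
import tensorflow as tf  # assumes TF 1.x graph mode

m0 = tf.placeholder(tf.int32, shape=(None, 2, None, 2))
m1 = tf.placeholder(tf.int32, shape=(None, 2))
out = tf.einsum('ijkl,ij->ikl', m0, m1)  # two unknown dims are now fine

with tf.Session() as sess:
  result = sess.run(out, feed_dict={
      m0: [[[[1, 2]], [[2, 1]]]],
      m1: [[3, 2]],
  })
  np.testing.assert_almost_equal([[[7, 8]]], result)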