Update shape checking logic in einsum (#7387)
* Update shape checking logic in einsum * Fix typo * Make the einsum modifications more structured and simpler * Remove unnecessary parts * Fix indentation
This commit is contained in:
parent
0be5c4a395
commit
807f449bf5
@ -318,44 +318,28 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
|
||||
# into a single axis, and combine multiple summed axes into a
|
||||
# single axis.
|
||||
|
||||
t0_shape = tuple(x.value for x in t0.get_shape())
|
||||
t0_shape = _get_shape(t0)
|
||||
num_broadcast_elements_t0 = _total_size(
|
||||
t0_shape[len(preserved_axes):-len(axes_to_sum)])
|
||||
num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
|
||||
new_shape = t0_shape[:len(preserved_axes)] + (num_broadcast_elements_t0,
|
||||
num_summed_elements)
|
||||
new_shape = (t0_shape[:len(preserved_axes)]
|
||||
+ [num_broadcast_elements_t0, num_summed_elements])
|
||||
t0 = _reshape_if_necessary(t0, new_shape)
|
||||
|
||||
t1_shape = tuple(x.value for x in t1.get_shape())
|
||||
t1_shape = _get_shape(t1)
|
||||
num_broadcast_elements_t1 = _total_size(
|
||||
t1_shape[len(preserved_axes)+len(axes_to_sum):])
|
||||
new_shape = t1_shape[:len(preserved_axes)] + (num_summed_elements,
|
||||
num_broadcast_elements_t1)
|
||||
new_shape = (t1_shape[:len(preserved_axes)]
|
||||
+ [num_summed_elements, num_broadcast_elements_t1])
|
||||
t1 = _reshape_if_necessary(t1, new_shape)
|
||||
|
||||
product = math_ops.matmul(t0, t1)
|
||||
|
||||
# Undo compaction of broadcast axes
|
||||
uncompacted_shape = (
|
||||
t0_shape[:len(preserved_axes)+len(broadcast_axes[0])] +
|
||||
t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
|
||||
t0_shape[:len(preserved_axes)+len(broadcast_axes[0])]
|
||||
+ t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
|
||||
)
|
||||
|
||||
# Check the number of None values and replace them with Tensors containing
|
||||
# corresponding dimensions if there exist two or more None values
|
||||
num_none_dims = sum(1 for d in uncompacted_shape if d is None)
|
||||
if num_none_dims > 1:
|
||||
uncompacted_shape = list(uncompacted_shape)
|
||||
for i in xrange(len(uncompacted_shape)):
|
||||
if uncompacted_shape[i] is None:
|
||||
if i < len(preserved_axes) + len(broadcast_axes[0]):
|
||||
uncompacted_shape[i] = array_ops.shape(inputs[0])[i]
|
||||
else:
|
||||
idx = (i - len(preserved_axes) - len(broadcast_axes[0])
|
||||
+ len(t1_shape) - len(broadcast_axes[1]))
|
||||
uncompacted_shape[i] = array_ops.shape(inputs[1])[idx]
|
||||
uncompacted_shape = tuple(uncompacted_shape)
|
||||
|
||||
product = _reshape_if_necessary(product, uncompacted_shape)
|
||||
|
||||
product_axes = (
|
||||
@ -386,13 +370,27 @@ def _reshape_if_necessary(tensor, new_shape):
|
||||
return array_ops.reshape(tensor, new_shape)
|
||||
|
||||
|
||||
def _get_shape(tensor):
  """Returns the shape of `tensor` as a list containing no `None` entries.

  Starts from `tensor.get_shape().as_list()`. Entries that are not `None` are
  kept unchanged. If at least one entry is `None`, the shape is queried once
  via `array_ops.shape(tensor)` and each `None` entry is replaced by the
  element of that shape tensor at the same index.

  Args:
    tensor: An object exposing `get_shape().as_list()` and accepted by
      `array_ops.shape`.

  Returns:
    A list with one entry per dimension of `tensor`, none of which is `None`.
  """
  dims = tensor.get_shape().as_list()

  # Only add a shape op to the graph when some dimension is actually unknown.
  if any(dim is None for dim in dims):
    dynamic_shape = array_ops.shape(tensor)
    dims = [dynamic_shape[axis] if dim is None else dim
            for axis, dim in enumerate(dims)]

  return dims
|
||||
|
||||
def _total_size(shape_values):
  """Returns the product of the given tensor shape values.

  The values are multiplied together as-is, so the result type follows the
  inputs: if `shape_values` contains tensor values (which are results of
  `array_ops.shape`), a scalar tensor is returned; otherwise an integer is
  returned. An empty `shape_values` yields 1.

  Args:
    shape_values: An iterable of dimension sizes. Entries must support
      multiplication; `None` entries are not handled here.

  Returns:
    The total size, i.e. the product of all entries of `shape_values`.
  """
  result = 1
  for val in shape_values:
    # No `isinstance(val, int)` check: dimensions may be tensor values.
    result *= val
  return result
|
||||
|
||||
|
@ -318,7 +318,19 @@ class EinsumTest(test.TestCase):
|
||||
m1: [3, 2],
|
||||
}
|
||||
np.testing.assert_almost_equal(
|
||||
[[7]], sess.run(out, feed_dict=feed_dict))
|
||||
[[7]], sess.run(out, feed_dict=feed_dict))
|
||||
|
||||
with ops.Graph().as_default():
|
||||
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2))
|
||||
m1 = array_ops.placeholder(dtypes.int32, shape=(None, 2))
|
||||
out = special_math_ops.einsum('ijkl,ij->ikl', m0, m1)
|
||||
with session.Session() as sess:
|
||||
feed_dict = {
|
||||
m0: [[[[1, 2]], [[2, 1]]]],
|
||||
m1: [[3, 2]],
|
||||
}
|
||||
np.testing.assert_almost_equal(
|
||||
[[[7, 8]]], sess.run(out, feed_dict=feed_dict))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user