diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index bf4d1982091..5a8eb432d18 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -318,44 +318,28 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
     # into a single axis, and combine multiple summed axes into a
     # single axis.
 
-    t0_shape = tuple(x.value for x in t0.get_shape())
+    t0_shape = _get_shape(t0)
     num_broadcast_elements_t0 = _total_size(
         t0_shape[len(preserved_axes):-len(axes_to_sum)])
     num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
-    new_shape = t0_shape[:len(preserved_axes)] + (num_broadcast_elements_t0,
-                                                  num_summed_elements)
+    new_shape = (t0_shape[:len(preserved_axes)]
+                 + [num_broadcast_elements_t0, num_summed_elements])
     t0 = _reshape_if_necessary(t0, new_shape)
 
-    t1_shape = tuple(x.value for x in t1.get_shape())
+    t1_shape = _get_shape(t1)
     num_broadcast_elements_t1 = _total_size(
         t1_shape[len(preserved_axes)+len(axes_to_sum):])
-    new_shape = t1_shape[:len(preserved_axes)] + (num_summed_elements,
-                                                  num_broadcast_elements_t1)
+    new_shape = (t1_shape[:len(preserved_axes)]
+                 + [num_summed_elements, num_broadcast_elements_t1])
     t1 = _reshape_if_necessary(t1, new_shape)
 
     product = math_ops.matmul(t0, t1)
 
     # Undo compaction of broadcast axes
     uncompacted_shape = (
-        t0_shape[:len(preserved_axes)+len(broadcast_axes[0])] +
-        t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
+        t0_shape[:len(preserved_axes)+len(broadcast_axes[0])]
+        + t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
     )
-
-    # Check the number of None values and replace them with Tensors containing
-    # corresponding dimensions if there exist two or more None values
-    num_none_dims = sum(1 for d in uncompacted_shape if d is None)
-    if num_none_dims > 1:
-      uncompacted_shape = list(uncompacted_shape)
-      for i in xrange(len(uncompacted_shape)):
-        if uncompacted_shape[i] is None:
-          if i < len(preserved_axes) + len(broadcast_axes[0]):
-            uncompacted_shape[i] = array_ops.shape(inputs[0])[i]
-          else:
-            idx = (i - len(preserved_axes) - len(broadcast_axes[0])
-                   + len(t1_shape) - len(broadcast_axes[1]))
-            uncompacted_shape[i] = array_ops.shape(inputs[1])[idx]
-      uncompacted_shape = tuple(uncompacted_shape)
-
     product = _reshape_if_necessary(product, uncompacted_shape)
 
     product_axes = (
@@ -386,13 +370,27 @@ def _reshape_if_necessary(tensor, new_shape):
     return array_ops.reshape(tensor, new_shape)
 
 
+def _get_shape(tensor):
+  """Like get_shape().as_list(), but explicitly queries the shape of a tensor
+  if necessary to ensure that the returned value contains no unknown value."""
+
+  shape = tensor.get_shape().as_list()
+  none_indices = [i for i, d in enumerate(shape) if d is None]
+  if none_indices:
+    # Query the shape if shape contains None values
+    shape_tensor = array_ops.shape(tensor)
+    for i in none_indices:
+      shape[i] = shape_tensor[i]
+  return shape
+
 def _total_size(shape_values):
-  """Given list of tensor shape values, returns total size or -1 if unknown."""
+  """Given list of tensor shape values, returns total size.
+  If shape_values contains tensor values (which are results of
+  array_ops.shape), then it returns a scalar tensor.
+  If not, it returns an integer."""
+
   result = 1
   for val in shape_values:
-    if val is None:
-      return -1
-    assert isinstance(val, int)
     result *= val
   return result
 
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index 3d289bcc9a5..c792d322770 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -318,7 +318,19 @@ class EinsumTest(test.TestCase):
             m1: [3, 2],
         }
         np.testing.assert_almost_equal(
-          [[7]], sess.run(out, feed_dict=feed_dict))
+            [[7]], sess.run(out, feed_dict=feed_dict))
+
+    with ops.Graph().as_default():
+      m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2))
+      m1 = array_ops.placeholder(dtypes.int32, shape=(None, 2))
+      out = special_math_ops.einsum('ijkl,ij->ikl', m0, m1)
+      with session.Session() as sess:
+        feed_dict = {
+            m0: [[[[1, 2]], [[2, 1]]]],
+            m1: [[3, 2]],
+        }
+        np.testing.assert_almost_equal(
+            [[[7, 8]]], sess.run(out, feed_dict=feed_dict))
 
 
 if __name__ == '__main__':