diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 4e1a99acd68..9f61d3d47e6 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -33,14 +33,13 @@ static void AppendTo(const TensorShape& s, gtl::InlinedVector* vals) { } void TensorShape::CheckDimsEqual(int NDIMS) const { - CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << "dimensions" - << " from a tensor of " << dims() << " dimensions"; + CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS + << " for a tensor of " << dims() << " dimensions"; } void TensorShape::CheckDimsAtLeast(int NDIMS) const { CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS - << " dimensions from a tensor of " << dims() - << " dimensions"; + << " for a tensor of " << dims() << " dimensions"; } bool TensorShape::IsValid(const TensorShapeProto& proto) { diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 3dd8ffe9628..4332a0facca 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -273,10 +273,12 @@ class StridedSliceChecker(object): self.x_np = np.array(x) def __getitem__(self, spec): - op = self.x.__getitem__(spec) + # TODO(aselle): When NewSliceHelper is installed, we can switch this back + # op = self.x[spec] + op = array_ops._NewSliceHelper(self.x, spec) tensor = op.eval() - self.test.assertAllEqual(self.x_np.__getitem__(spec), tensor) + self.test.assertAllEqual(self.x_np[spec], tensor) self.test.assertAllEqual(tensor.shape, op.get_shape()) return tensor @@ -393,7 +395,9 @@ class StridedSliceShapeChecker(object): self.x = x def __getitem__(self, spec): - op = self.x.__getitem__(spec) + # TODO(aselle): When NewSliceHelper is installed, we can switch this back + # op = self.x[spec] + op = array_ops._NewSliceHelper(self.x, spec) return op.get_shape() @@ -447,28 +451,22 @@ class GradSliceChecker(object): self.varnp = varnp def __getitem__(self, spec): - slice_var = self.var[spec] - slice_val = self.val[spec] - - # compute analytic 2nd derivative - analytic_grad2 = 2 * slice_val - - dy = tf.Variable(tf.ones(shape=slice_var.get_shape(), dtype=tf.int32)) - assign = dy.assign(slice_var) - slice_val_grad, = tf.gradients(slice_val, self.var, grad_ys=dy) - slice_val_grad2, = tf.gradients(slice_val_grad, dy, grad_ys=self.var) - self.sess.run(assign) - slice_val_grad_evaled, slice_val_grad2_evaled = ( - self.sess.run([slice_val_grad, slice_val_grad2])) - analytic_grad2_evaled = analytic_grad2.eval() - self.test.assertAllEqual(slice_val_grad2_evaled, analytic_grad2_evaled) - - # compute analytic gradient for slice - np_val_grad = (2 * self.varnp * self.varnp) + val_grad_op = tf.gradients(self.val, self.var) + sliceval_grad_op = tf.gradients( + array_ops._NewSliceHelper(self.val, spec), self.var) + slice1_op = array_ops._NewSliceHelper(val_grad_op, spec) + slice2_op = array_ops._NewSliceHelper(sliceval_grad_op, spec) + val_grad, sliceval_grad, slice1, slice2 = self.sess.run( + [val_grad_op, sliceval_grad_op, slice1_op, slice2_op]) + np_val_grad = (2 * self.varnp) np_sliceval_grad = np.zeros(self.var.get_shape()) - np_sliceval_grad[spec] = np_val_grad[spec] - # verify gradient - self.test.assertAllEqual(slice_val_grad_evaled, np_sliceval_grad) + np_sliceval_grad[spec] = np.array(val_grad[0])[spec] + # make sure np val grad is correct + self.test.assertAllEqual(np_val_grad, val_grad[0]) + # make sure slice gradient is correct + self.test.assertAllEqual(np_sliceval_grad, sliceval_grad[0]) + # make sure val grad and sliceval grad are the same in sliced area + self.test.assertAllEqual(slice1, slice2) class StridedSliceGradTest(test_util.TensorFlowTestCase): @@ -495,7 +493,7 @@ class BenchmarkSlice(object): self.tensor = tensor def __getitem__(self, x): - return self.tensor[x] + return array_ops._NewSliceHelper(self.tensor, x) class StridedSliceBenchmark(tf.test.Benchmark): diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 43a62aae1a9..b663c56e56a 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -151,7 +151,7 @@ def _SliceGrad(op, grad): @ops.RegisterGradient("StridedSlice") def _StridedSliceGrad(op, grad): - """Gradient for StridedSlice op.""" + """Gradient for unpack op.""" x = array_ops.shape(op.inputs[0]) begin = op.inputs[1] end = op.inputs[2] @@ -170,25 +170,6 @@ def _StridedSliceGrad(op, grad): shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None -@ops.RegisterGradient("StridedSliceGrad") -def _StridedSliceGradGrad(op, grad): - """Gradient for StridedSliceGrad op.""" - begin = op.inputs[1] - end = op.inputs[2] - strides = op.inputs[3] - - return None, None, None, None, array_ops.strided_slice( - grad, - begin, - end, - strides, - begin_mask=op.get_attr("begin_mask"), - end_mask=op.get_attr("end_mask"), - ellipsis_mask=op.get_attr("ellipsis_mask"), - new_axis_mask=op.get_attr("new_axis_mask"), - shrink_axis_mask=op.get_attr("shrink_axis_mask")) - - @ops.RegisterGradient("Split") def _SplitGrad(op, *grads): return None, array_ops.concat(op.inputs[0], list(grads)) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 66aa3b0679e..1435279f549 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -196,7 +196,7 @@ def zeros_initializer(shape, dtype=dtypes.float32): return zeros(shape, dtype) -def _SliceHelper(tensor, slice_spec): +def _NewSliceHelper(tensor, slice_spec): """Overload for Tensor.__getitem__. This operation extracts the specified region from the tensor. @@ -275,6 +275,73 @@ def _SliceHelper(tensor, slice_spec): # pylint: disable=undefined-variable,protected-access +def _SliceHelper(tensor, slice_spec): + """Overload for Tensor.__getitem__. + + Currently the size of the slice must be statically known in each dimension, + i.e. the "stop" of the slice must not be omitted. + + TODO(mrry): Support slices where the sizes are not specified. + TODO(mrry): Support negative indices in slices with numpy/Python semantics. + + Args: + tensor: An ops.Tensor object. + slice_spec: The arguments to Tensor.__getitem__. + + Returns: + The appropriate slice of "tensor", based on "slice_spec". + + Raises: + ValueError: If a slice range is negative size. + TypeError: If the slice indices aren't int, slice, or Ellipsis. + """ + if not isinstance(slice_spec, (list, tuple)): + slice_spec = [slice_spec] + indices = [] + sizes = [] + squeeze_dims = [] + for dim, s in enumerate(slice_spec): + if isinstance(s, _baseslice): + if s.step not in (None, 1): + raise NotImplementedError( + "Steps other than 1 are not currently supported") + start = s.start if s.start is not None else 0 + if start < 0: + raise NotImplementedError( + "Negative start indices are not currently supported") + indices.append(start) + if s.stop is not None and s.stop < 0: + raise NotImplementedError( + "Negative stop indices are not currently supported") + # NOTE(mrry): If the stop is not specified, Python substitutes + # sys.maxsize, which is typically (2 ** 63) - 1. Since Slice currently + # supports signed DT_INT32 arguments, we use -1 to specify that all + # elements should be captured. + if s.stop is None or s.stop == sys.maxsize: + sizes.append(-1) + else: + if start > s.stop: + raise ValueError("Stop must be at least start") + sizes.append(s.stop - start) + elif s is Ellipsis: + raise NotImplementedError("Ellipsis is not currently supported") + else: + try: + s = int(s) + except TypeError: + raise TypeError("Bad slice index %s of type %s" % (s, type(s))) + if s < 0: + raise NotImplementedError("Negative indices are currently unsupported") + indices.append(s) + sizes.append(1) + squeeze_dims.append(dim) + sliced = slice(tensor, indices, sizes) + if squeeze_dims: + return squeeze(sliced, squeeze_dims=squeeze_dims) + else: + return sliced + + def slice(input_, begin, size, name=None): """Extracts a slice from a tensor. @@ -423,6 +490,8 @@ def strided_slice(input_, new_axis_mask=new_axis_mask, shrink_axis_mask=shrink_axis_mask) +# TODO(aselle): When gradient is added and performance verified switch +# ops.Tensor._override_operator("__getitem__", _NewSliceHelper) ops.Tensor._override_operator("__getitem__", _SliceHelper) @@ -1526,9 +1595,8 @@ def _StridedSliceShape(op): sparse_dims = begin_shape.merge_with(end_shape).merge_with(strides_shape)[ 0].value - if (sparse_dims is None or begin_value is None or end_value is None or - strides_value is None): - return [tensor_shape.unknown_shape()] + if sparse_dims is None: + return [input_shape.unknown_shape()] begin_mask = op.get_attr("begin_mask") end_mask = op.get_attr("end_mask")