From 271f6bb49d2140b4c1bca88391caedd1791561cf Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 29 Jan 2020 15:45:56 -0800
Subject: [PATCH] Clean up OpKernel 1.

1. Clean up old code related to legacy scalars and vectors in TensorFlow.
   This CL does not bump the GraphDef version. Legacy behavior is still
   allowed for a small number of frequently used kernels, and a TODO to fix
   this in a future CL has been added.
2. Move OpKernel::MakeShape into tensor_util.{h,cc}. Ideally it would live in
   TensorShapeUtils, but that would create a cyclic dependency.

The oldest branch of TensorFlow on GitHub is r0.7, which is already at
GraphDef version 8, so this should have no impact on GraphDefs in the wild.
We also changed the scalar test to be strict in open source on 2017-04-03.

PiperOrigin-RevId: 292234389
Change-Id: Iee372607d9b9139d33ba7be5be8b792d9471e0f6
---
 tensorflow/core/framework/op_kernel.cc        | 17 ------
 tensorflow/core/framework/op_kernel.h         | 25 ---------
 tensorflow/core/framework/tensor_util.cc      | 18 ++++++
 tensorflow/core/framework/tensor_util.h       |  4 ++
 tensorflow/core/kernels/broadcast_to_op.cc    |  7 ++-
 tensorflow/core/kernels/concat_op.cc          | 27 ++++-----
 tensorflow/core/kernels/constant_op.cc        | 22 ++++++--
 tensorflow/core/kernels/conv_grad_ops_3d.cc   | 15 +++--
 .../kernels/conv_ops_fused_image_transform.cc |  8 +--
 tensorflow/core/kernels/logging_ops.cc        |  2 +-
 tensorflow/core/kernels/mkl_concat_op.cc      | 12 ++--
 .../core/kernels/mkl_conv_grad_input_ops.cc   |  4 +-
 tensorflow/core/kernels/mkl_reshape_op.cc     |  2 +-
 tensorflow/core/kernels/mkl_slice_op.cc       |  9 +--
 tensorflow/core/kernels/pad_op.cc             | 20 +++----
 .../core/kernels/quantized_concat_op.cc       | 10 ++--
 tensorflow/core/kernels/random_op.cc          |  3 +-
 tensorflow/core/kernels/random_poisson_op.cc  |  3 +-
 tensorflow/core/kernels/reshape_op.h          | 10 +++-
 tensorflow/core/kernels/save_op.cc            |  4 +-
 tensorflow/core/kernels/scatter_op.cc         |  2 +-
 .../core/kernels/segment_reduction_ops_impl.h |  3 +-
 .../kernels/segment_reduction_ops_impl_1.cc   |  2 +-
 tensorflow/core/kernels/sequence_ops.cc       | 16 +++++-
 tensorflow/core/kernels/slice_op.cc           |  4 +-
 tensorflow/core/kernels/sparse_to_dense_op.cc |  2 +-
 .../core/kernels/stateful_random_ops.cc       |  5 +-
 .../core/kernels/stateless_random_ops.cc      |  3 +-
 tensorflow/core/kernels/summary_audio_op.cc   |  2 +-
 tensorflow/core/kernels/summary_image_op.cc   |  2 +-
 tensorflow/core/kernels/summary_op.cc         |  6 +-
 tensorflow/core/kernels/summary_tensor_op.cc  |  2 +-
 tensorflow/core/kernels/tensor_array.cc       |  1 +
 tensorflow/core/kernels/tensor_array_ops.cc   |  4 +-
 tensorflow/core/kernels/tile_ops.cc           |  4 +-
 tensorflow/core/kernels/training_ops.cc       | 24 ++++----
 tensorflow/python/kernel_tests/scalar_test.py | 56 +++++++++++++------
 37 files changed, 191 insertions(+), 169 deletions(-)

diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 9426c75b882..1fe8c19608d 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -162,23 +162,6 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start,
   }
 }
 
-Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
-  if (!IsLegacyVector(shape.shape())) {
-    return errors::InvalidArgument(
-        "shape must be a vector of {int32,int64}, got shape ",
-        shape.shape().DebugString());
-  }
-  if (shape.dtype() == DataType::DT_INT32) {
-    auto vec = shape.flat<int32>();
-    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
-  } else if (shape.dtype() == DataType::DT_INT64) {
-    auto vec = shape.flat<int64>();
-    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
-  } else {
-    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
-  }
-}
-
 string OpKernel::TraceString(OpKernelContext* ctx, bool verbose) {
   string trace_string = strings::StrCat(name_view(), ":", type_string_view());
   if (!verbose) return trace_string;
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index bec65ada4ca..594a3c5142b 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -195,31 +195,6 @@ class OpKernel {
   Status InputRange(StringPiece input_name, int* start, int* stop) const;
   Status OutputRange(StringPiece output_name, int* start, int* stop) const;
 
-  // We allow legacy scalars within Google up until GraphDef version 6.
-  // TODO(irving): Remove when we can drop support for GraphDef version 5.
-  bool allow_legacy_scalars() const {
-#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
-    return graph_def_version_ < 6;
-#else
-    return false;
-#endif
-  }
-
-  // Allow either scalars or (if allowing legacy scalars) shape (1,).
-  bool IsLegacyScalar(const TensorShape& shape) const {
-    return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
-                                 shape.dim_size(0) == 1);
-  }
-
-  // Allow rank 1 or (if allowing legacy scalars) rank 0.
-  bool IsLegacyVector(const TensorShape& shape) const {
-    return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
-  }
-
-  // Turn a shape Tensor into a TensorShape
-  // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
-  Status MakeShape(const Tensor& shape, TensorShape* out) const;
-
   // Returns `true` if and only if this kernel uses deferred execution.
   bool is_deferred() const { return is_deferred_; }
 
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc
index 896d83ffa2c..e6b2bd50b8a 100644
--- a/tensorflow/core/framework/tensor_util.cc
+++ b/tensorflow/core/framework/tensor_util.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/type_traits.h"
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -366,5 +367,22 @@ bool CompressTensorProtoInPlace(int64 min_num_elements,
 
 #undef HANDLE_COMPRESS_CASE
 
+Status MakeShape(const Tensor& shape, TensorShape* out) {
+  if (!TensorShapeUtils::IsVector(shape.shape())) {
+    return errors::InvalidArgument(
+        "shape must be a vector of {int32,int64}, got shape ",
+        shape.shape().DebugString());
+  }
+  if (shape.dtype() == DataType::DT_INT32) {
+    auto vec = shape.flat<int32>();
+    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
+  } else if (shape.dtype() == DataType::DT_INT64) {
+    auto vec = shape.flat<int64>();
+    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
+  } else {
+    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
+  }
+}
+
 }  // namespace tensor
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index fb8216d2859..50ecbb1ecd3 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -325,6 +325,10 @@ inline bool CompressTensorProtoInPlace(TensorProto* tensor) {
                                      kDefaultMinCompressionRatio, tensor);
 }
 
+// Make a TensorShape from the contents of shape_t. Shape_t must be a
+// 1-dimensional tensor of type int32 or int64.
+Status MakeShape(const Tensor& shape_t, TensorShape* out);
+
 }  // namespace tensor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/broadcast_to_op.cc b/tensorflow/core/kernels/broadcast_to_op.cc
index 51caca50ebd..a3844b8b769 100644
--- a/tensorflow/core/kernels/broadcast_to_op.cc
+++ b/tensorflow/core/kernels/broadcast_to_op.cc
@@ -21,10 +21,12 @@ limitations under the License.
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/kernels/broadcast_to_op.h"
+
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/util/bcast.h"
 
@@ -45,8 +47,7 @@ class BroadcastToOp : public OpKernel {
     const Tensor& shape_tensor = ctx->input(1);
 
     TensorShape output_shape;
-    OP_REQUIRES_OK(ctx,
-                   ctx->op_kernel().MakeShape(shape_tensor, &output_shape));
+    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &output_shape));
 
     // Handle copy.
     if (output_shape == input_shape) {
@@ -91,7 +92,7 @@ class BroadcastToOp : public OpKernel {
   }
 };
 
-// As MakeShape is able to handle both DT_INT32 and DT_INT64,
+// As tensor::MakeShape is able to handle both DT_INT32 and DT_INT64,
 // no need to have TypeConstraint for `Tidx`
 #define REGISTER_KERNEL(type)                                           \
   REGISTER_KERNEL_BUILDER(                                              \
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index 350f5e71725..9d7f37d2be6 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -53,11 +53,15 @@ class ConcatBaseOp : public OpKernel {
   void Compute(OpKernelContext* c) override {
     const Tensor* concat_dim_tensor;
     const char* axis_attribute_name =
-        AxisArgName == NAME_IS_AXIS ? "axis" : AxisArgName == NAME_IS_CONCAT_DIM
-                                                   ? "concat_dim"
-                                                   : "";
+        AxisArgName == NAME_IS_AXIS
+            ? "axis"
+            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "";
     OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
-    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
+    // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
+    OP_REQUIRES(c,
+                (TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) ||
+                 (TensorShapeUtils::IsVector(concat_dim_tensor->shape()) &&
+                  concat_dim_tensor->shape().dim_size(0) == 1)),
                 errors::InvalidArgument(
                     axis_attribute_name,
                     " tensor should be a scalar integer, but got shape ",
@@ -93,9 +97,8 @@ class ConcatBaseOp : public OpKernel {
     const TensorShape& input_shape = values[0].shape();
 
     int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
-    OP_REQUIRES(c,
-                (0 <= axis && axis < input_dims) ||
-                    (allow_legacy_scalars() && concat_dim == 0),
+    // concat_dim==0 allows concatenating a list of scalars into a vector.
+    OP_REQUIRES(c, (0 <= axis && axis < input_dims) || concat_dim == 0,
                 errors::InvalidArgument(
                     "ConcatOp : Expected concatenating dimensions in the range "
                     "[",
@@ -112,12 +115,10 @@ class ConcatBaseOp : public OpKernel {
       inputs_flat_dim0 *= input_shape.dim_size(d);
     }
     int64 output_concat_dim = 0;
-    const bool input_is_scalar = IsLegacyScalar(input_shape);
     for (int i = 0; i < N; ++i) {
       const auto& in = values[i];
-      const bool in_is_scalar = IsLegacyScalar(in.shape());
       OP_REQUIRES(
-          c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
+          c, in.dims() == input_dims,
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
@@ -138,12 +139,12 @@ class ConcatBaseOp : public OpKernel {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
       }
-      // TODO(irving): Remove check once !allow_legacy_scalars().
+      // TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
       output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
     }
 
     TensorShape output_shape(input_shape);
-    // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
+    // TODO(rmlarsen): Remove rank 0 case once !allow_legacy_scalars()?
     if (output_shape.dims() == 0) {
       output_shape.AddDim(output_concat_dim);
     } else {
@@ -282,7 +283,7 @@ class ConcatOffsetOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     const Tensor& concat_dim = ctx->input(0);
     OP_REQUIRES(
-        ctx, IsLegacyScalar(concat_dim.shape()),
+        ctx, TensorShapeUtils::IsScalar(concat_dim.shape()),
         errors::InvalidArgument(
             "Concat dim tensor should be a scalar integer, but got shape ",
             concat_dim.shape().DebugString()));
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 5c62e51fdbb..5931599c6e2 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -159,13 +159,23 @@ class FillOp : public OpKernel {
   void Compute(OpKernelContext* context) override {
     const Tensor& Tdims = context->input(0);
-    OP_REQUIRES(context, IsLegacyVector(Tdims.shape()),
-                errors::InvalidArgument("dims must be a vector, got shape ",
-                                        Tdims.shape().DebugString()));
+    OP_REQUIRES(
+        context,
+        // TODO(rmlarsen): Disallow legacy use of scalars to represent shape.
+        (TensorShapeUtils::IsVector(Tdims.shape()) ||
+         TensorShapeUtils::IsScalar(Tdims.shape())),
+        errors::InvalidArgument("dims must represent a vector, got shape ",
+                                Tdims.shape().DebugString()));
     const Tensor& Tvalue = context->input(1);
-    OP_REQUIRES(context, IsLegacyScalar(Tvalue.shape()),
-                errors::InvalidArgument("value must be a scalar, got shape ",
-                                        Tvalue.shape().DebugString()));
+    OP_REQUIRES(
+        context,
+        // TODO(rmlarsen): Disallow legacy use of length-1 vector to represent
+        // scalar.
+        TensorShapeUtils::IsScalar(Tvalue.shape()) ||
+            (TensorShapeUtils::IsVector(Tvalue.shape()) &&
+             Tvalue.shape().dim_size(0) == 1),
+        errors::InvalidArgument("value must represent a scalar, got shape ",
+                                Tvalue.shape().DebugString()));
     auto dims = Tdims.flat<Index>();
     TensorShape shape;
     OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index f4d447fbd0e..0314da7c4cc 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/conv_3d.h" #include "tensorflow/core/kernels/conv_grad_ops.h" @@ -230,8 +231,9 @@ class Conv3DBackpropInputOp : public OpKernel { TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); - // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes. - OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape)); + // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for + // input_sizes. + OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape)); } else { input_shape = context->input(0).shape(); } @@ -336,8 +338,9 @@ class Conv3DCustomBackpropInputOp : public OpKernel { TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); - // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes. - OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape)); + // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for + // input_sizes. + OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape)); } else { input_shape = context->input(0).shape(); } @@ -1153,7 +1156,7 @@ class Conv3DBackpropInputOp : public OpKernel { TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); - OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape)); + OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape)); } else { input_shape = context->input(0).shape(); } @@ -1621,7 +1624,7 @@ class Conv3DBackpropFilterOp : public OpKernel { TensorShape filter_shape; if (takes_shape_) { const Tensor& filter_sizes = context->input(1); - OP_REQUIRES_OK(context, MakeShape(filter_sizes, &filter_shape)); + OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape)); } else { filter_shape = context->input(1).shape(); } diff --git a/tensorflow/core/kernels/conv_ops_fused_image_transform.cc b/tensorflow/core/kernels/conv_ops_fused_image_transform.cc index c1c3b555d64..21c151d3b67 100644 --- a/tensorflow/core/kernels/conv_ops_fused_image_transform.cc +++ b/tensorflow/core/kernels/conv_ops_fused_image_transform.cc @@ -685,15 +685,11 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { paddings.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", paddings.shape().DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1) - ? 
-            : dims;
     OP_REQUIRES(
-        context, fixed_dims == paddings.dim_size(0),
+        context, dims == paddings.dim_size(0),
         errors::InvalidArgument(
             "The first dimension of paddings must be the rank of inputs: ",
-            fixed_dims, " ", paddings.shape().DebugString(), " ",
+            dims, " ", paddings.shape().DebugString(), " ",
             resized_shape.DebugString()));
     OP_REQUIRES(
         context, dims == paddings.dim_size(0),
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index ead59fb0309..c01bd0aa93b 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -56,7 +56,7 @@ AssertOp::AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
 
 void AssertOp::Compute(OpKernelContext* ctx) {
   const Tensor& cond = ctx->input(0);
-  OP_REQUIRES(ctx, IsLegacyScalar(cond.shape()),
+  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cond.shape()),
               errors::InvalidArgument("In[0] should be a scalar: ",
                                       cond.shape().DebugString()));
 
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 435d158422f..8470a7e2728 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -139,7 +139,7 @@ class EigenConcatBaseOp : public OpKernel {
             ? "axis"
             : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "";
     OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
-    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
+    OP_REQUIRES(c, TensorShapeUtils::IsScalar(concat_dim_tensor->shape()),
                 errors::InvalidArgument(
                     axis_attribute_name,
                     " tensor should be a scalar integer, but got shape ",
@@ -153,9 +153,7 @@ class EigenConcatBaseOp : public OpKernel {
 
     int32 axis = (concat_dim < 0) ? (concat_dim + input_dims) : concat_dim;
     OP_REQUIRES(
-        c,
-        (0 <= axis && axis < input_dims) ||
-            (allow_legacy_scalars() && concat_dim == 0),
+        c, (0 <= axis && axis < input_dims),
         errors::InvalidArgument(
             "ConcatOp : Expected concatenating dimensions in the range [",
             -input_dims, ", ", input_dims, "), but got ", concat_dim));
@@ -180,10 +178,10 @@ class EigenConcatBaseOp : public OpKernel {
       inputs_flat_dim0 *= input_shape.dim_size(d);
     }
     int64 output_concat_dim = 0;
-    const bool input_is_scalar = IsLegacyScalar(input_shape);
+    const bool input_is_scalar = TensorShapeUtils::IsScalar(input_shape);
     for (int i = 0; i < N; ++i) {
       const auto in = values[i];
-      const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
+      const bool in_is_scalar = TensorShapeUtils::IsScalar(input_shapes[i]);
       OP_REQUIRES(
           c,
           (input_shapes[i].dims() == input_dims) ||
@@ -471,7 +469,7 @@ class MklConcatOp : public OpKernel {
                                         : MklGetInput(context, N);
     // Sanity checks
     OP_REQUIRES(
-        context, IsLegacyScalar(concat_dim_tensor.shape()),
+        context, TensorShapeUtils::IsScalar(concat_dim_tensor.shape()),
         errors::InvalidArgument(
             "Concat dim tensor should be a scalar integer, but got shape ",
             concat_dim_tensor.shape().DebugString()));
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 943f4989f54..a262a409858 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -521,8 +521,8 @@ class MklConvCustomBackpropInputOp
     TensorShape input_tf_shape;
     CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
     // Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64
-    // output_shape MakeShape is able to handle both DT_INT32 and DT_INT64 for
-    // input_tensor.
+    // output_shape tensor::MakeShape is able to handle both DT_INT32 and
+    // DT_INT64 for input_tensor.
     CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true);
     return input_tf_shape;
   }
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index d2bc2394f9b..3c95a37ecfd 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -74,7 +74,7 @@ class MklReshapeOp : public OpKernel {
                                  : input_tensor.NumElements();
 
     // Preliminary validation of sizes.
-    OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(sizes.shape()),
                 errors::InvalidArgument("sizes input must be 1-D, not shape ",
                                         sizes.shape().DebugString()));
 
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
index 35b8d87ec38..a4b1009af76 100644
--- a/tensorflow/core/kernels/mkl_slice_op.cc
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -86,10 +86,11 @@ static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
   const int input_dims = input_tf_shape.dims();
 
   OP_REQUIRES(
-      context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
-                   context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
-                   begin_tensor.NumElements() == input_dims &&
-                   size_tensor.NumElements() == input_dims,
+      context,
+      TensorShapeUtils::IsVector(begin_tensor.shape()) &&
+          TensorShapeUtils::IsVector(size_tensor.shape()) &&
+          begin_tensor.NumElements() == input_dims &&
+          size_tensor.NumElements() == input_dims,
       errors::InvalidArgument(
          "Expected begin and size arguments to be 1-D tensors of size ",
          input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
index a9d8e591e14..0b404238a14 100644
--- a/tensorflow/core/kernels/pad_op.cc
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -61,11 +61,8 @@ class PadOp : public OpKernel {
                     TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
                 errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                         in1.shape().DebugString()));
-    const int fixed_dims =
-        (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1 : dims;
     OP_REQUIRES(
-        context, fixed_dims == in1.dim_size(0),
+        context, dims == in1.dim_size(0),
         errors::InvalidArgument(
             "The first dimension of paddings must be the rank of inputs",
             in1.shape().DebugString(), " ", in0.shape().DebugString()));
@@ -83,15 +80,14 @@ class PadOp : public OpKernel {
     // Compute the shape of the output tensor, and allocate it.
     TensorShape output_shape;
     typename TTypes<Tpadding>::ConstMatrix paddings = in1.matrix<Tpadding>();
-    for (int d = 0; d < fixed_dims; ++d) {
+    for (int d = 0; d < dims; ++d) {
       const Tpadding before_d = paddings(d, 0);  // Pad before existing elements.
       const Tpadding after_d = paddings(d, 1);   // Pad after existing elements.
       OP_REQUIRES(context, before_d >= 0 && after_d >= 0,
                   errors::InvalidArgument("Paddings must be non-negative: ",
                                           before_d, " ", after_d));
-      const int64 size_d =
-          (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
+      const int64 size_d = in0.dim_size(d);
       output_shape.AddDim(before_d + size_d + after_d);
     }
 
@@ -107,10 +103,9 @@ class PadOp : public OpKernel {
     TensorShape collapsed_input_shape;
     TensorShape collapsed_output_shape;
     Tensor collapsed_paddings;
-    if (fixed_dims > 1 &&
-        CollapseAdjacentNonPaddedDimensions(
-            in0.shape(), in1, output_shape, &collapsed_input_shape,
-            &collapsed_paddings, &collapsed_output_shape)) {
+    if (dims > 1 && CollapseAdjacentNonPaddedDimensions(
+                        in0.shape(), in1, output_shape, &collapsed_input_shape,
+                        &collapsed_paddings, &collapsed_output_shape)) {
       Tensor collapsed_input;
       CHECK(collapsed_input.CopyFrom(in0, collapsed_input_shape));
       Tensor collapsed_output;
@@ -135,8 +130,7 @@ class PadOp : public OpKernel {
       Tensor* output = nullptr;
       OP_REQUIRES_OK(context,
                      context->allocate_output(0, output_shape, &output));
-      OperateWithVariableRank(context, fixed_dims, in0, paddings, pad_value,
-                              output);
+      OperateWithVariableRank(context, dims, in0, paddings, pad_value, output);
     }
   }
 
diff --git a/tensorflow/core/kernels/quantized_concat_op.cc b/tensorflow/core/kernels/quantized_concat_op.cc
index ff4e7be1622..965da273213 100644
--- a/tensorflow/core/kernels/quantized_concat_op.cc
+++ b/tensorflow/core/kernels/quantized_concat_op.cc
@@ -127,10 +127,10 @@ class QuantizedConcatOp : public OpKernel {
     // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
     inputs_flat->reserve(N);
     *output_concat_dim = 0;
-    const bool input_is_scalar = IsLegacyScalar(input_shape);
+    const bool input_is_scalar = TensorShapeUtils::IsScalar(input_shape);
     for (int i = 0; i < N; ++i) {
       const auto in = values[i];
-      const bool in_is_scalar = IsLegacyScalar(in.shape());
+      const bool in_is_scalar = TensorShapeUtils::IsScalar(in.shape());
       OP_REQUIRES(
           context, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
           errors::InvalidArgument(
@@ -161,7 +161,7 @@ class QuantizedConcatOp : public OpKernel {
     const Tensor* concat_dim_tensor = nullptr;
     OP_REQUIRES_OK(context, context->input("concat_dim", &concat_dim_tensor));
     OP_REQUIRES(
-        context, IsLegacyScalar(concat_dim_tensor->shape()),
+        context, TensorShapeUtils::IsScalar(concat_dim_tensor->shape()),
         errors::InvalidArgument(
             "Concat dim tensor should be a scalar integer, but got shape ",
             concat_dim_tensor->shape().DebugString()));
@@ -184,9 +184,7 @@ class QuantizedConcatOp : public OpKernel {
     const int input_dims = values[0].dims();
     const TensorShape& input_shape = values[0].shape();
     OP_REQUIRES(
-        context,
-        (0 <= concat_dim && concat_dim < input_dims) ||
-            (allow_legacy_scalars() && concat_dim == 0),
+        context, (0 <= concat_dim && concat_dim < input_dims),
        errors::InvalidArgument(
            "ConcatOp : Expected concatenating dimensions in the range [", 0,
            ", ", input_dims, "), but got ", concat_dim));
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 2fa93fb529c..2fe3a15a3cf 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/kernels/random_op_cpu.h" #include "tensorflow/core/lib/hash/crc32c.h" #include "tensorflow/core/lib/random/random_distributions.h" @@ -56,7 +57,7 @@ namespace { static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape, int index, Tensor** output) { TensorShape tensor_shape; - TF_RETURN_IF_ERROR(ctx->op_kernel().MakeShape(shape, &tensor_shape)); + TF_RETURN_IF_ERROR(tensor::MakeShape(shape, &tensor_shape)); return ctx->allocate_output(index, tensor_shape, output); } diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc index 64fb4a5c228..7069f896f07 100644 --- a/tensorflow/core/kernels/random_poisson_op.cc +++ b/tensorflow/core/kernels/random_poisson_op.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/util/guarded_philox_random.h" @@ -290,7 +291,7 @@ class RandomPoissonOp : public OpKernel { const Tensor& rate_t = ctx->input(1); TensorShape samples_shape; - OP_REQUIRES_OK(ctx, MakeShape(shape_t, &samples_shape)); + OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &samples_shape)); const int64 num_samples = samples_shape.num_elements(); samples_shape.AppendShape(rate_t.shape()); diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h index 47cd219d8cf..155f8dafc9c 100644 --- a/tensorflow/core/kernels/reshape_op.h +++ b/tensorflow/core/kernels/reshape_op.h @@ -36,9 +36,13 @@ class ReshapeOp : public OpKernel { const Tensor& input = context->input(0); const Tensor& sizes = context->input(1); // Preliminary validation of sizes. - OP_REQUIRES(context, IsLegacyVector(sizes.shape()), - errors::InvalidArgument("sizes input must be 1-D, not ", - sizes.shape().DebugString())); + OP_REQUIRES( + context, + (TensorShapeUtils::IsVector(sizes.shape()) || + // TODO(rmlarsen): Disallow legacy use of scalars to represent shape. + TensorShapeUtils::IsScalar(sizes.shape())), + errors::InvalidArgument("sizes input must be 1-D, not ", + sizes.shape().DebugString())); // Compute the output shape. Determine product of specified // dimensions, and find the index of the unspecified one. 
diff --git a/tensorflow/core/kernels/save_op.cc b/tensorflow/core/kernels/save_op.cc
index f53976cae28..0f6da91abd6 100644
--- a/tensorflow/core/kernels/save_op.cc
+++ b/tensorflow/core/kernels/save_op.cc
@@ -55,7 +55,7 @@ class ShardedFilenameOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     static const char* input_names[3] = {"basename", "shard", "num_shards"};
     for (int i = 0; i < ctx->num_inputs(); ++i) {
-      OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
+      OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(i).shape()),
                   errors::InvalidArgument(input_names[i],
                                           " must be a scalar, got shape ",
                                           ctx->input(i).shape().DebugString()));
@@ -78,7 +78,7 @@ class ShardedFilespecOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     static const char* input_names[2] = {"basename", "num_shards"};
     for (int i = 0; i < ctx->num_inputs(); ++i) {
-      OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
+      OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(i).shape()),
                   errors::InvalidArgument(input_names[i],
                                           " must be a scalar, got shape ",
                                           ctx->input(i).shape().DebugString()));
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 81deaad5c95..6eae1b7e217 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -125,7 +125,7 @@ class ScatterUpdateOp : public OpKernel {
     auto params_flat = params.flat_outer_dims<T>();
 
     if (TensorShapeUtils::IsScalar(updates.shape()) ||
-        IsLegacyScalar(updates.shape())) {
+        TensorShapeUtils::IsScalar(updates.shape())) {
       const auto update = updates.scalar<T>();
       functor::ScatterScalarFunctor<Device, T, Index, op> functor;
       const Index bad_i = functor(c, c->template eigen_device<Device>(),
diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h
index a472655d3e0..ba75150c517 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_impl.h
+++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/segment_reduction_ops.h" #include "tensorflow/core/lib/core/status.h" @@ -780,7 +781,7 @@ class SparseSegmentGradOpBase : public OpKernel { errors::InvalidArgument("indices should be a vector.")); OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), errors::InvalidArgument("segment_ids should be a vector.")); - OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()), + OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()), errors::InvalidArgument("output_dim0 should be a scalar.")); const int64 N = indices.NumElements(); diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc index 494983bff78..ae71ac31f2c 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc @@ -44,7 +44,7 @@ void UnsortedSegmentReductionValidation(OpKernel* op_kernel, const Tensor& segment_ids, const Tensor& num_segments) { OP_REQUIRES( - context, op_kernel->IsLegacyScalar(num_segments.shape()), + context, TensorShapeUtils::IsScalar(num_segments.shape()), errors::InvalidArgument("num_segments should be a scalar, not shape ", num_segments.shape().DebugString())); OP_REQUIRES( diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index 02dcc1e4dec..7ce2016a2f7 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -36,13 +36,23 @@ class RangeOp : public OpKernel { const Tensor& start_in = context->input(0); const Tensor& limit_in = context->input(1); const Tensor& delta_in = context->input(2); - OP_REQUIRES(context, IsLegacyScalar(start_in.shape()), + // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars. 
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsScalar(start_in.shape()) ||
+                    (TensorShapeUtils::IsVector(start_in.shape()) &&
+                     start_in.shape().dim_size(0) == 1),
                 errors::InvalidArgument("start must be a scalar, not shape ",
                                         start_in.shape().DebugString()));
-    OP_REQUIRES(context, IsLegacyScalar(limit_in.shape()),
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsScalar(limit_in.shape()) ||
+                    (TensorShapeUtils::IsVector(limit_in.shape()) &&
+                     limit_in.shape().dim_size(0) == 1),
                 errors::InvalidArgument("limit must be a scalar, not shape ",
                                         limit_in.shape().DebugString()));
-    OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsScalar(delta_in.shape()) ||
+                    (TensorShapeUtils::IsVector(delta_in.shape()) &&
+                     delta_in.shape().dim_size(0) == 1),
                 errors::InvalidArgument("delta must be a scalar, not shape ",
                                         delta_in.shape().DebugString()));
     const T start = start_in.scalar<T>()();
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 15f7157db07..110440c28c8 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -73,8 +73,8 @@ static void SharedValidation(OpKernelContext* context,
 
   OP_REQUIRES(
       context,
-      context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
-          context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
+      TensorShapeUtils::IsVector(begin_tensor.shape()) &&
+          TensorShapeUtils::IsVector(size_tensor.shape()) &&
          begin_tensor.NumElements() == input.dims() &&
          size_tensor.NumElements() == input.dims(),
      errors::InvalidArgument(
diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc
index f0cddc88fbf..d9626052b0c 100644
--- a/tensorflow/core/kernels/sparse_to_dense_op.cc
+++ b/tensorflow/core/kernels/sparse_to_dense_op.cc
@@ -62,7 +62,7 @@ class SparseToDense : public OpKernel {
     // output_shape
     const Tensor& output_shape = c->input(1);
     OP_REQUIRES(
-        c, IsLegacyVector(output_shape.shape()),
+        c, TensorShapeUtils::IsVector(output_shape.shape()),
         errors::InvalidArgument("output_shape should be a vector, got shape ",
                                 output_shape.shape().DebugString()));
     OP_REQUIRES(c, output_shape.NumElements() == num_dims,
diff --git a/tensorflow/core/kernels/stateful_random_ops.cc b/tensorflow/core/kernels/stateful_random_ops.cc
index cbbce249a66..041b28b734e 100644
--- a/tensorflow/core/kernels/stateful_random_ops.cc
+++ b/tensorflow/core/kernels/stateful_random_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
+#include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/kernels/random_op_cpu.h"
 #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
 #include "tensorflow/core/kernels/training_op_helpers.h"
@@ -113,7 +114,7 @@ void StatefulRandomCompute(OpKernelContext* ctx, Distribution dist,
     using T = typename Distribution::ResultElementType;
     const Tensor& shape_t = ctx->input(shape_input_idx);
     TensorShape shape;
-    OP_REQUIRES_OK(ctx, ctx->op_kernel().MakeShape(shape_t, &shape));
+    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
     Tensor* output;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
     auto output_flat = output->flat<T>();
@@ -265,7 +266,7 @@ class NonDeterministicIntsOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     const Tensor& shape_t = ctx->input(0);
     TensorShape shape;
-    OP_REQUIRES_OK(ctx, ctx->op_kernel().MakeShape(shape_t, &shape));
+    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
     Tensor* output;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
     if (shape.num_elements() == 0) return;
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index 3c4dd60433e..50efee57588 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/kernels/random_op.h"
 #include "tensorflow/core/lib/random/random_distributions.h"
 #include "tensorflow/core/platform/logging.h"
@@ -74,7 +75,7 @@ class StatelessRandomOpBase : public OpKernel {
     const Tensor& shape_t = context->input(0);
     const Tensor& seed_t = context->input(1);
     TensorShape shape;
-    OP_REQUIRES_OK(context, MakeShape(shape_t, &shape));
+    OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape));
     OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                 errors::InvalidArgument("seed must have shape [2], not ",
                                         seed_t.shape().DebugString()));
diff --git a/tensorflow/core/kernels/summary_audio_op.cc b/tensorflow/core/kernels/summary_audio_op.cc
index 26be2680b4a..8de2f9248c5 100644
--- a/tensorflow/core/kernels/summary_audio_op.cc
+++ b/tensorflow/core/kernels/summary_audio_op.cc
@@ -39,7 +39,7 @@ class SummaryAudioOp : public OpKernel {
   void Compute(OpKernelContext* c) override {
     const Tensor& tag = c->input(0);
     const Tensor& tensor = c->input(1);
-    OP_REQUIRES(c, IsLegacyScalar(tag.shape()),
+    OP_REQUIRES(c, TensorShapeUtils::IsScalar(tag.shape()),
                 errors::InvalidArgument("Tag must be a scalar"));
     OP_REQUIRES(c, tensor.dims() >= 2 && tensor.dims() <= 3,
                 errors::InvalidArgument("Tensor must be 3-D or 2-D, got: ",
diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc
index 025e22c958d..7c91768c6af 100644
--- a/tensorflow/core/kernels/summary_image_op.cc
+++ b/tensorflow/core/kernels/summary_image_op.cc
@@ -52,7 +52,7 @@ class SummaryImageOp : public OpKernel {
   void Compute(OpKernelContext* c) override {
     const Tensor& tags = c->input(0);
     const Tensor& tensor = c->input(1);
-    OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
+    OP_REQUIRES(c, TensorShapeUtils::IsScalar(tags.shape()),
                 errors::InvalidArgument("Tags must be a scalar"));
     OP_REQUIRES(c,
                 tensor.dims() == 4 &&
diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc
index 07ebb5e0000..386a8964dba 100644
--- a/tensorflow/core/kernels/summary_op.cc
+++ b/tensorflow/core/kernels/summary_op.cc
@@ -42,8 +42,8 @@ class SummaryScalarOp : public OpKernel {
 
     OP_REQUIRES(
         c,
-        tags.IsSameSize(values) ||
-            (IsLegacyScalar(tags.shape()) && IsLegacyScalar(values.shape())),
+        tags.IsSameSize(values) || (TensorShapeUtils::IsScalar(tags.shape()) &&
+                                    TensorShapeUtils::IsScalar(values.shape())),
         errors::InvalidArgument(
             "tags and values not the same shape: ", tags.shape().DebugString(),
             " != ", values.shape().DebugString(), SingleTag(tags)));
@@ -82,7 +82,7 @@ class SummaryHistoOp : public OpKernel {
     const Tensor& tags = c->input(0);
     const Tensor& values = c->input(1);
     const auto flat = values.flat<T>();
-    OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
+    OP_REQUIRES(c, TensorShapeUtils::IsScalar(tags.shape()),
                 errors::InvalidArgument("tags must be scalar"));
     // Build histogram of values in "values" tensor
     histogram::Histogram histo;
diff --git a/tensorflow/core/kernels/summary_tensor_op.cc b/tensorflow/core/kernels/summary_tensor_op.cc
index 9cbc812ffa9..4141c4238d3 100644
--- a/tensorflow/core/kernels/summary_tensor_op.cc
+++ b/tensorflow/core/kernels/summary_tensor_op.cc
@@ -32,7 +32,7 @@ class SummaryTensorOpV2 : public OpKernel {
   void Compute(OpKernelContext* c) override {
     const Tensor& tag = c->input(0);
-    OP_REQUIRES(c, IsLegacyScalar(tag.shape()),
+    OP_REQUIRES(c, TensorShapeUtils::IsScalar(tag.shape()),
                 errors::InvalidArgument("tag must be scalar"));
     const Tensor& tensor = c->input(1);
     const Tensor& serialized_summary_metadata_tensor = c->input(2);
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 2bd6ac0b08d..69efc016a1f 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/kernels/aggregate_ops_cpu.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 62d03f9fb7f..ea8e04a33f4 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/concat_lib.h" #include "tensorflow/core/kernels/split_lib.h" @@ -331,8 +332,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp { TensorShape shape_to_prepend; auto element_shape = PartialTensorShape(); if (ctx->num_inputs() > 2) { - TF_RETURN_IF_ERROR( - ctx->op_kernel().MakeShape(ctx->input(2), &shape_to_prepend)); + TF_RETURN_IF_ERROR(tensor::MakeShape(ctx->input(2), &shape_to_prepend)); auto ta_element_shape = tensor_array->ElemShape(); if (!ta_element_shape.unknown_rank()) { std::vector dims; diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc index e1080acb700..cd047ed9d4a 100644 --- a/tensorflow/core/kernels/tile_ops.cc +++ b/tensorflow/core/kernels/tile_ops.cc @@ -187,7 +187,7 @@ class TileOp : public OpKernel { const Tensor& multiples = context->input(1); OP_REQUIRES( - context, IsLegacyVector(multiples.shape()), + context, TensorShapeUtils::IsVector(multiples.shape()), errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", multiples.shape().DebugString())); OP_REQUIRES(context, input.dims() == multiples.NumElements(), @@ -361,7 +361,7 @@ class TileGradientOp : public OpKernel { const Tensor& input = context->input(0); const Tensor& multiples = context->input(1); OP_REQUIRES( - context, IsLegacyVector(multiples.shape()), + context, TensorShapeUtils::IsVector(multiples.shape()), errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", multiples.shape().DebugString())); OP_REQUIRES(context, input.dims() == multiples.NumElements(), diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 467087b7864..52266e273fe 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -568,7 +568,7 @@ class ApplyGradientDescentOp : public OpKernel { errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(0))); const Tensor& alpha = ctx->input(1); - OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()), errors::InvalidArgument("alpha is not a scalar: ", alpha.shape().DebugString())); const Tensor& delta = ctx->input(2); @@ -610,7 +610,7 @@ class ApplyGradientDescentOp : public OpKernel { errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(0))); const Tensor& alpha_dev = ctx->input(1); - OP_REQUIRES(ctx, IsLegacyScalar(alpha_dev.shape()), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_dev.shape()), errors::InvalidArgument("alpha is not a scalar: ", alpha_dev.shape().DebugString())); const Tensor& delta = ctx->input(2); @@ -1064,7 +1064,7 @@ class ApplyProximalGradientDescentOp : public OpKernel { errors::FailedPrecondition( "Attempting to use uninitialized variables: ", requested_input(0))); const Tensor& alpha = ctx->input(1); - OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()), errors::InvalidArgument("alpha is not a scalar: ", alpha.shape().DebugString())); const Tensor& l1 = ctx->input(2); @@ -1132,7 +1132,7 @@ class SparseApplyProximalGradientDescentOp : public OpKernel { errors::InvalidArgument("var must be at least 1 dimensional")); const Tensor& lr = ctx->input(1); - 
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& l1 = ctx->input(2);
@@ -1286,7 +1286,7 @@ class ApplyAdagradOp : public OpKernel {
         errors::FailedPrecondition(
             "Attempting to use uninitialized variables: ", requested_input(1)));
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& grad = ctx->input(3);
@@ -1401,7 +1401,7 @@ class ApplyAdagradV2Op : public OpKernel {
         errors::FailedPrecondition(
             "Attempting to use uninitialized variables: ", requested_input(1)));
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& epsilon = ctx->input(3);
@@ -1631,7 +1631,7 @@ class SparseApplyAdagradOp : public OpKernel {
                 errors::InvalidArgument("var must be at least 1 dimensional"));
 
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& grad = ctx->input(3);
@@ -1800,7 +1800,7 @@ class SparseApplyAdagradV2Op : public OpKernel {
                 errors::InvalidArgument("var must be at least 1 dimensional"));
 
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& epsilon = ctx->input(3);
@@ -2169,7 +2169,7 @@ class ApplyAdagradDAOp : public OpKernel {
                                         grad.shape().DebugString()));
 
     const Tensor& lr = ctx->input(4);
-    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& l1 = ctx->input(5);
@@ -2183,7 +2183,7 @@ class ApplyAdagradDAOp : public OpKernel {
        errors::InvalidArgument("l2 regularization strength is not a scalar: ",
                                l2.shape().DebugString()));
     const Tensor& global_step = ctx->input(7);
-    OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()),
                 errors::InvalidArgument("global_step is not a scalar: ",
                                         global_step.shape().DebugString()));
 
@@ -2272,7 +2272,7 @@ class SparseApplyAdagradDAOp : public OpKernel {
                 errors::InvalidArgument("indices must be one-dimensional"));
 
     const Tensor& lr = ctx->input(5);
-    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
 
@@ -2289,7 +2289,7 @@ class SparseApplyAdagradDAOp : public OpKernel {
                                l2.shape().DebugString()));
 
     const Tensor& global_step = ctx->input(8);
-    OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()),
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()),
                 errors::InvalidArgument("global_step is not a scalar: ",
                                         global_step.shape().DebugString()));
 
diff --git a/tensorflow/python/kernel_tests/scalar_test.py b/tensorflow/python/kernel_tests/scalar_test.py
index d15f2c7b500..6a7ddd79c4d 100644
--- a/tensorflow/python/kernel_tests/scalar_test.py
+++ b/tensorflow/python/kernel_tests/scalar_test.py
@@ -31,15 +31,16 @@ import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
 
 
+# TODO(rmlarsen): Remove this test completely after we stop supporting GraphDef
+# version 5 and remove support of legacy scalars from Concat, Fill, Range,
+# and Reshape.
 class ScalarTest(test.TestCase):
 
-  def check(self, op, args, error, correct=None):
-    # Within Google, the switch to scalar strict occurred at version 6.
-    lenient = []
-    strict = [5, 6]
-
+  def check(self, op, args, error, correct=None, lenient=None, strict=[5, 6]):
+    if lenient is None:
+      lenient = []
     # Use placeholders to bypass shape inference, since only the C++
     # GraphDef level is ever scalar lenient.
     def placeholders(args, feed):
       if isinstance(args, tuple):
         return [placeholders(x, feed) for x in args]
@@ -66,18 +67,21 @@ class ScalarTest(test.TestCase):
       self.assertAllEqual(r, correct)
 
   def testConcat(self):
-    self.check(array_ops.concat, (([2], [3], [7]), [0]),
-               'axis tensor should be a scalar integer', [2, 3, 7])
-    for data in (2, 3, 7), (2, [3], 7), (2, 3, [7]):
-      self.check(array_ops.concat, (data, 0),
-                 r'Expected \w+ dimensions in the range \[0, 0\)', [2, 3, 7])
-    for data in ([2], 3, 7), ([2], [3], 7):
+    for data in (2, [3], 7), ([2], 3, 7), ([2], [3], 7):
       self.check(array_ops.concat, (data, 0),
                  r'Ranks of all input tensors should match', [2, 3, 7])
 
   def testFill(self):
-    self.check(array_ops.fill, (2, 3), 'dims must be a vector', [3, 3])
-    self.check(array_ops.fill, ([2], [3]), 'value must be a scalar', [3, 3])
+    self.check(
+        array_ops.fill, (2, 3),
+        'dims must be a vector', [3, 3],
+        lenient=[5, 6],
+        strict=[])
+    self.check(
+        array_ops.fill, ([2], [3]),
+        'value must be a scalar', [3, 3],
+        lenient=[5, 6],
+        strict=[])
 
   def testPad(self):
     self.check(array_ops.pad, (7, [[1, 2]]),
@@ -88,7 +92,11 @@ class ScalarTest(test.TestCase):
     self.check(random_ops.random_uniform, (3,), 'shape must be a vector')
 
   def testReshape(self):
-    self.check(array_ops.reshape, (7, 1), 'sizes input must be 1-D', [7])
+    self.check(
+        array_ops.reshape, (7, 1),
+        'sizes input must be 1-D', [7],
+        lenient=[5, 6],
+        strict=[])
 
   def testShardedFilename(self):
     self.check(gen_io_ops.sharded_filename, ('foo', 4, [100]),
@@ -103,9 +111,21 @@ class ScalarTest(test.TestCase):
                'num_segments should be a scalar', [0, 7, 0, 0])
 
   def testRange(self):
-    self.check(math_ops.range, ([0], 3, 2), 'start must be a scalar', [0, 2])
-    self.check(math_ops.range, (0, [3], 2), 'limit must be a scalar', [0, 2])
-    self.check(math_ops.range, (0, 3, [2]), 'delta must be a scalar', [0, 2])
+    self.check(
+        math_ops.range, ([0], 3, 2),
+        'start must be a scalar', [0, 2],
+        lenient=[5, 6],
+        strict=[])
+    self.check(
+        math_ops.range, (0, [3], 2),
+        'limit must be a scalar', [0, 2],
+        lenient=[5, 6],
+        strict=[])
+    self.check(
+        math_ops.range, (0, 3, [2]),
+        'delta must be a scalar', [0, 2],
+        lenient=[5, 6],
+        strict=[])
 
   def testSlice(self):
     data = np.arange(10)
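[Editorial aside, not part of the patch: the legacy behavior the CL still tolerates reduces to two shape predicates, which the patch deliberately inlines at each call site behind TODO(rmlarsen) comments rather than factoring out. Below is a sketch of the two checks as standalone helpers, for clarity only; the helper names are hypothetical.]

#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

// Hypothetical helper mirroring the inlined checks above. A "scalar" argument
// may still arrive encoded as a length-1 vector (Concat's axis, Fill's value,
// Range's start/limit/delta).
inline bool IsScalarOrLegacyLength1Vector(const TensorShape& shape) {
  return TensorShapeUtils::IsScalar(shape) ||
         (TensorShapeUtils::IsVector(shape) && shape.dim_size(0) == 1);
}

// Hypothetical helper for the second tolerated encoding: a shape "vector" may
// still arrive as a scalar (Fill's dims, Reshape's sizes).
inline bool IsVectorOrLegacyScalar(const TensorShape& shape) {
  return TensorShapeUtils::IsVector(shape) || TensorShapeUtils::IsScalar(shape);
}

}  // namespace tensorflow

Once GraphDef version 5 support is dropped, both helpers would collapse to the strict TensorShapeUtils::IsScalar and TensorShapeUtils::IsVector checks already used by the other kernels in this patch, and the corresponding lenient cases in scalar_test.py above would move back to strict.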