Clean up OpKernel.

1. Clean up old code related to legacy scalars and vectors in TensorFlow. This CL does not bump the GraphDef version.
Legacy behavior is still allowed for a small number of frequently used kernels, and a TODO to fix this in a future CL has been added.
2. Move OpKernel::MakeShape into tensor_util.{h,cc}. Ideally it would go in TensorShapeUtils, but that would create a cyclic dependency.

The oldest branch of TensorFlow on GitHub is r0.7, which is already at GraphDef version 8, so this should have no impact on GraphDefs in the wild.

We also changed the scalar test to be strict in open source on 2017-04-03.
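
For orientation, a minimal sketch of the call-site migration (illustrative only, not part of the diff; it mirrors the broadcast_to_op.cc hunk below and assumes the usual OpKernelContext* ctx inside a kernel's Compute):

    // Before this CL, the conversion went through the OpKernel member:
    //   OP_REQUIRES_OK(ctx, ctx->op_kernel().MakeShape(shape_tensor, &output_shape));
    // After this CL, include "tensorflow/core/framework/tensor_util.h" and call the
    // free function, which accepts a 1-D DT_INT32 or DT_INT64 shape tensor:
    const Tensor& shape_tensor = ctx->input(1);
    TensorShape output_shape;
    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &output_shape));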

PiperOrigin-RevId: 292234389
Change-Id: Iee372607d9b9139d33ba7be5be8b792d9471e0f6
Author: A. Unique TensorFlower
Committed by: TensorFlower Gardener
Date: 2020-01-29 15:45:56 -08:00
Commit: 271f6bb49d (parent: a33b83cf74)
37 changed files with 191 additions and 169 deletions

@ -162,23 +162,6 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start,
}
}
Status OpKernel::MakeShape(const Tensor& shape, TensorShape* out) const {
if (!IsLegacyVector(shape.shape())) {
return errors::InvalidArgument(
"shape must be a vector of {int32,int64}, got shape ",
shape.shape().DebugString());
}
if (shape.dtype() == DataType::DT_INT32) {
auto vec = shape.flat<int32>();
return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
} else if (shape.dtype() == DataType::DT_INT64) {
auto vec = shape.flat<int64>();
return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
} else {
return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
}
}
string OpKernel::TraceString(OpKernelContext* ctx, bool verbose) {
string trace_string = strings::StrCat(name_view(), ":", type_string_view());
if (!verbose) return trace_string;

@ -195,31 +195,6 @@ class OpKernel {
Status InputRange(StringPiece input_name, int* start, int* stop) const;
Status OutputRange(StringPiece output_name, int* start, int* stop) const;
// We allow legacy scalars within Google up until GraphDef version 6.
// TODO(irving): Remove when we can drop support for GraphDef version 5.
bool allow_legacy_scalars() const {
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
return graph_def_version_ < 6;
#else
return false;
#endif
}
// Allow either scalars or (if allowing legacy scalars) shape (1,).
bool IsLegacyScalar(const TensorShape& shape) const {
return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
shape.dim_size(0) == 1);
}
// Allow rank 1 or (if allowing legacy scalars) rank 0.
bool IsLegacyVector(const TensorShape& shape) const {
return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
}
// Turn a shape Tensor into a TensorShape
// TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
Status MakeShape(const Tensor& shape, TensorShape* out) const;
// Returns `true` if and only if this kernel uses deferred execution.
bool is_deferred() const { return is_deferred_; }

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@ -366,5 +367,22 @@ bool CompressTensorProtoInPlace(int64 min_num_elements,
#undef HANDLE_COMPRESS_CASE
Status MakeShape(const Tensor& shape, TensorShape* out) {
if (!TensorShapeUtils::IsVector(shape.shape())) {
return errors::InvalidArgument(
"shape must be a vector of {int32,int64}, got shape ",
shape.shape().DebugString());
}
if (shape.dtype() == DataType::DT_INT32) {
auto vec = shape.flat<int32>();
return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
} else if (shape.dtype() == DataType::DT_INT64) {
auto vec = shape.flat<int64>();
return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
} else {
return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
}
}
} // namespace tensor
} // namespace tensorflow

@ -325,6 +325,10 @@ inline bool CompressTensorProtoInPlace(TensorProto* tensor) {
kDefaultMinCompressionRatio, tensor);
}
// Make a TensorShape from the contents of shape_t. Shape_t must be a
// 1-dimensional tensor of type int32 or int64.
Status MakeShape(const Tensor& shape_t, TensorShape* out);
} // namespace tensor
} // namespace tensorflow

@ -21,10 +21,12 @@ limitations under the License.
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/broadcast_to_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/bcast.h"
@ -45,8 +47,7 @@ class BroadcastToOp : public OpKernel {
const Tensor& shape_tensor = ctx->input(1);
TensorShape output_shape;
OP_REQUIRES_OK(ctx,
ctx->op_kernel().MakeShape(shape_tensor, &output_shape));
OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &output_shape));
// Handle copy.
if (output_shape == input_shape) {
@ -91,7 +92,7 @@ class BroadcastToOp : public OpKernel {
}
};
// As MakeShape is able to handle both DT_INT32 and DT_INT64,
// As tensor::MakeShape is able to handle both DT_INT32 and DT_INT64,
// no need to have TypeConstraint for `Tidx`
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \

@ -53,11 +53,15 @@ class ConcatBaseOp : public OpKernel {
void Compute(OpKernelContext* c) override {
const Tensor* concat_dim_tensor;
const char* axis_attribute_name =
AxisArgName == NAME_IS_AXIS ? "axis" : AxisArgName == NAME_IS_CONCAT_DIM
? "concat_dim"
: "<invalid>";
AxisArgName == NAME_IS_AXIS
? "axis"
: AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
// TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
OP_REQUIRES(c,
(TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) ||
(TensorShapeUtils::IsVector(concat_dim_tensor->shape()) &&
concat_dim_tensor->shape().dim_size(0) == 1)),
errors::InvalidArgument(
axis_attribute_name,
" tensor should be a scalar integer, but got shape ",
@ -93,9 +97,8 @@ class ConcatBaseOp : public OpKernel {
const TensorShape& input_shape = values[0].shape();
int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
OP_REQUIRES(c,
(0 <= axis && axis < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
// concat_dim==0 allows concatenating a list of scalars into a vector.
OP_REQUIRES(c, (0 <= axis && axis < input_dims) || concat_dim == 0,
errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range "
"[",
@ -112,12 +115,10 @@ class ConcatBaseOp : public OpKernel {
inputs_flat_dim0 *= input_shape.dim_size(d);
}
int64 output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
const auto& in = values[i];
const bool in_is_scalar = IsLegacyScalar(in.shape());
OP_REQUIRES(
c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
c, in.dims() == input_dims,
errors::InvalidArgument(
"ConcatOp : Ranks of all input tensors should match: shape[0] = ",
input_shape.DebugString(), " vs. shape[", i,
@ -138,12 +139,12 @@ class ConcatBaseOp : public OpKernel {
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
}
// TODO(irving): Remove check once !allow_legacy_scalars().
// TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
}
TensorShape output_shape(input_shape);
// TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
// TODO(rmlarsen): Remove rank 0 case once !allow_legacy_scalars()?
if (output_shape.dims() == 0) {
output_shape.AddDim(output_concat_dim);
} else {
@ -282,7 +283,7 @@ class ConcatOffsetOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& concat_dim = ctx->input(0);
OP_REQUIRES(
ctx, IsLegacyScalar(concat_dim.shape()),
ctx, TensorShapeUtils::IsScalar(concat_dim.shape()),
errors::InvalidArgument(
"Concat dim tensor should be a scalar integer, but got shape ",
concat_dim.shape().DebugString()));

@ -159,13 +159,23 @@ class FillOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& Tdims = context->input(0);
OP_REQUIRES(context, IsLegacyVector(Tdims.shape()),
errors::InvalidArgument("dims must be a vector, got shape ",
Tdims.shape().DebugString()));
OP_REQUIRES(
context,
// TODO(rmlarsen): Disallow legacy use of scalars to represent shape.
(TensorShapeUtils::IsVector(Tdims.shape()) ||
TensorShapeUtils::IsScalar(Tdims.shape())),
errors::InvalidArgument("dims must represent a vector, got shape ",
Tdims.shape().DebugString()));
const Tensor& Tvalue = context->input(1);
OP_REQUIRES(context, IsLegacyScalar(Tvalue.shape()),
errors::InvalidArgument("value must be a scalar, got shape ",
Tvalue.shape().DebugString()));
OP_REQUIRES(
context,
// TODO(rmlarsen): Disallow legacy use of length-1 vector to represent
// scalar.
TensorShapeUtils::IsScalar(Tvalue.shape()) ||
(TensorShapeUtils::IsVector(Tvalue.shape()) &&
Tvalue.shape().dim_size(0) == 1),
errors::InvalidArgument("value must represent a scalar, got shape ",
Tvalue.shape().DebugString()));
auto dims = Tdims.flat<Index>();
TensorShape shape;
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_3d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
@ -230,8 +231,9 @@ class Conv3DBackpropInputOp : public OpKernel {
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
// MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
// tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
// input_sizes.
OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
} else {
input_shape = context->input(0).shape();
}
@ -336,8 +338,9 @@ class Conv3DCustomBackpropInputOp : public OpKernel {
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
// MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
// tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for
// input_sizes.
OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
} else {
input_shape = context->input(0).shape();
}
@ -1153,7 +1156,7 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape));
} else {
input_shape = context->input(0).shape();
}
@ -1621,7 +1624,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
OP_REQUIRES_OK(context, MakeShape(filter_sizes, &filter_shape));
OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape));
} else {
filter_shape = context->input(1).shape();
}

@ -685,15 +685,11 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel {
paddings.dim_size(1) == 2,
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
paddings.shape().DebugString()));
const int fixed_dims =
(allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
? 1
: dims;
OP_REQUIRES(
context, fixed_dims == paddings.dim_size(0),
context, dims == paddings.dim_size(0),
errors::InvalidArgument(
"The first dimension of paddings must be the rank of inputs: ",
fixed_dims, " ", paddings.shape().DebugString(), " ",
dims, " ", paddings.shape().DebugString(), " ",
resized_shape.DebugString()));
OP_REQUIRES(
context, dims == paddings.dim_size(0),

@ -56,7 +56,7 @@ AssertOp::AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
void AssertOp::Compute(OpKernelContext* ctx) {
const Tensor& cond = ctx->input(0);
OP_REQUIRES(ctx, IsLegacyScalar(cond.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cond.shape()),
errors::InvalidArgument("In[0] should be a scalar: ",
cond.shape().DebugString()));

@ -139,7 +139,7 @@ class EigenConcatBaseOp : public OpKernel {
? "axis"
: AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
OP_REQUIRES(c, TensorShapeUtils::IsScalar(concat_dim_tensor->shape()),
errors::InvalidArgument(
axis_attribute_name,
" tensor should be a scalar integer, but got shape ",
@ -153,9 +153,7 @@ class EigenConcatBaseOp : public OpKernel {
int32 axis = (concat_dim < 0) ? (concat_dim + input_dims) : concat_dim;
OP_REQUIRES(
c,
(0 <= axis && axis < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
c, (0 <= axis && axis < input_dims),
errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range [",
-input_dims, ", ", input_dims, "), but got ", concat_dim));
@ -180,10 +178,10 @@ class EigenConcatBaseOp : public OpKernel {
inputs_flat_dim0 *= input_shape.dim_size(d);
}
int64 output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
const bool input_is_scalar = TensorShapeUtils::IsScalar(input_shape);
for (int i = 0; i < N; ++i) {
const auto in = values[i];
const bool in_is_scalar = IsLegacyScalar(input_shapes[i]);
const bool in_is_scalar = TensorShapeUtils::IsScalar(input_shapes[i]);
OP_REQUIRES(
c,
(input_shapes[i].dims() == input_dims) ||
@ -471,7 +469,7 @@ class MklConcatOp : public OpKernel {
: MklGetInput(context, N);
// Sanity checks
OP_REQUIRES(
context, IsLegacyScalar(concat_dim_tensor.shape()),
context, TensorShapeUtils::IsScalar(concat_dim_tensor.shape()),
errors::InvalidArgument(
"Concat dim tensor should be a scalar integer, but got shape ",
concat_dim_tensor.shape().DebugString()));

@ -521,8 +521,8 @@ class MklConvCustomBackpropInputOp
TensorShape input_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
// Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64
// output_shape MakeShape is able to handle both DT_INT32 and DT_INT64 for
// input_tensor.
// output_shape tensor::MakeShape is able to handle both DT_INT32 and
// DT_INT64 for input_tensor.
CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true);
return input_tf_shape;
}

@ -74,7 +74,7 @@ class MklReshapeOp : public OpKernel {
: input_tensor.NumElements();
// Preliminary validation of sizes.
OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
OP_REQUIRES(context, TensorShapeUtils::IsVector(sizes.shape()),
errors::InvalidArgument("sizes input must be 1-D, not shape ",
sizes.shape().DebugString()));

@ -86,10 +86,11 @@ static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
const int input_dims = input_tf_shape.dims();
OP_REQUIRES(
context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
begin_tensor.NumElements() == input_dims &&
size_tensor.NumElements() == input_dims,
context,
TensorShapeUtils::IsVector(begin_tensor.shape()) &&
TensorShapeUtils::IsVector(size_tensor.shape()) &&
begin_tensor.NumElements() == input_dims &&
size_tensor.NumElements() == input_dims,
errors::InvalidArgument(
"Expected begin and size arguments to be 1-D tensors of size ",
input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),

@ -61,11 +61,8 @@ class PadOp : public OpKernel {
TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
in1.shape().DebugString()));
const int fixed_dims =
(allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
: dims;
OP_REQUIRES(
context, fixed_dims == in1.dim_size(0),
context, dims == in1.dim_size(0),
errors::InvalidArgument(
"The first dimension of paddings must be the rank of inputs",
in1.shape().DebugString(), " ", in0.shape().DebugString()));
@ -83,15 +80,14 @@ class PadOp : public OpKernel {
// Compute the shape of the output tensor, and allocate it.
TensorShape output_shape;
typename TTypes<Tpadding>::ConstMatrix paddings = in1.matrix<Tpadding>();
for (int d = 0; d < fixed_dims; ++d) {
for (int d = 0; d < dims; ++d) {
const Tpadding before_d =
paddings(d, 0); // Pad before existing elements.
const Tpadding after_d = paddings(d, 1); // Pad after existing elements.
OP_REQUIRES(context, before_d >= 0 && after_d >= 0,
errors::InvalidArgument("Paddings must be non-negative: ",
before_d, " ", after_d));
const int64 size_d =
(allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
const int64 size_d = in0.dim_size(d);
output_shape.AddDim(before_d + size_d + after_d);
}
@ -107,10 +103,9 @@ class PadOp : public OpKernel {
TensorShape collapsed_input_shape;
TensorShape collapsed_output_shape;
Tensor collapsed_paddings;
if (fixed_dims > 1 &&
CollapseAdjacentNonPaddedDimensions(
in0.shape(), in1, output_shape, &collapsed_input_shape,
&collapsed_paddings, &collapsed_output_shape)) {
if (dims > 1 && CollapseAdjacentNonPaddedDimensions(
in0.shape(), in1, output_shape, &collapsed_input_shape,
&collapsed_paddings, &collapsed_output_shape)) {
Tensor collapsed_input;
CHECK(collapsed_input.CopyFrom(in0, collapsed_input_shape));
Tensor collapsed_output;
@ -135,8 +130,7 @@ class PadOp : public OpKernel {
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, output_shape, &output));
OperateWithVariableRank(context, fixed_dims, in0, paddings, pad_value,
output);
OperateWithVariableRank(context, dims, in0, paddings, pad_value, output);
}
}

@ -127,10 +127,10 @@ class QuantizedConcatOp : public OpKernel {
// Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
inputs_flat->reserve(N);
*output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
const bool input_is_scalar = TensorShapeUtils::IsScalar(input_shape);
for (int i = 0; i < N; ++i) {
const auto in = values[i];
const bool in_is_scalar = IsLegacyScalar(in.shape());
const bool in_is_scalar = TensorShapeUtils::IsScalar(in.shape());
OP_REQUIRES(
context, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
errors::InvalidArgument(
@ -161,7 +161,7 @@ class QuantizedConcatOp : public OpKernel {
const Tensor* concat_dim_tensor = nullptr;
OP_REQUIRES_OK(context, context->input("concat_dim", &concat_dim_tensor));
OP_REQUIRES(
context, IsLegacyScalar(concat_dim_tensor->shape()),
context, TensorShapeUtils::IsScalar(concat_dim_tensor->shape()),
errors::InvalidArgument(
"Concat dim tensor should be a scalar integer, but got shape ",
concat_dim_tensor->shape().DebugString()));
@ -184,9 +184,7 @@ class QuantizedConcatOp : public OpKernel {
const int input_dims = values[0].dims();
const TensorShape& input_shape = values[0].shape();
OP_REQUIRES(
context,
(0 <= concat_dim && concat_dim < input_dims) ||
(allow_legacy_scalars() && concat_dim == 0),
context, (0 <= concat_dim && concat_dim < input_dims),
errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range [", 0,
", ", input_dims, "), but got ", concat_dim));

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h"
@ -56,7 +57,7 @@ namespace {
static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
int index, Tensor** output) {
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(ctx->op_kernel().MakeShape(shape, &tensor_shape));
TF_RETURN_IF_ERROR(tensor::MakeShape(shape, &tensor_shape));
return ctx->allocate_output(index, tensor_shape, output);
}

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"
@ -290,7 +291,7 @@ class RandomPoissonOp : public OpKernel {
const Tensor& rate_t = ctx->input(1);
TensorShape samples_shape;
OP_REQUIRES_OK(ctx, MakeShape(shape_t, &samples_shape));
OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &samples_shape));
const int64 num_samples = samples_shape.num_elements();
samples_shape.AppendShape(rate_t.shape());

@ -36,9 +36,13 @@ class ReshapeOp : public OpKernel {
const Tensor& input = context->input(0);
const Tensor& sizes = context->input(1);
// Preliminary validation of sizes.
OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
errors::InvalidArgument("sizes input must be 1-D, not ",
sizes.shape().DebugString()));
OP_REQUIRES(
context,
(TensorShapeUtils::IsVector(sizes.shape()) ||
// TODO(rmlarsen): Disallow legacy use of scalars to represent shape.
TensorShapeUtils::IsScalar(sizes.shape())),
errors::InvalidArgument("sizes input must be 1-D, not ",
sizes.shape().DebugString()));
// Compute the output shape. Determine product of specified
// dimensions, and find the index of the unspecified one.

@ -55,7 +55,7 @@ class ShardedFilenameOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
static const char* input_names[3] = {"basename", "shard", "num_shards"};
for (int i = 0; i < ctx->num_inputs(); ++i) {
OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(i).shape()),
errors::InvalidArgument(input_names[i],
" must be a scalar, got shape ",
ctx->input(i).shape().DebugString()));
@ -78,7 +78,7 @@ class ShardedFilespecOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
static const char* input_names[2] = {"basename", "num_shards"};
for (int i = 0; i < ctx->num_inputs(); ++i) {
OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(i).shape()),
errors::InvalidArgument(input_names[i],
" must be a scalar, got shape ",
ctx->input(i).shape().DebugString()));

@ -125,7 +125,7 @@ class ScatterUpdateOp : public OpKernel {
auto params_flat = params.flat_outer_dims<T>();
if (TensorShapeUtils::IsScalar(updates.shape()) ||
IsLegacyScalar(updates.shape())) {
TensorShapeUtils::IsScalar(updates.shape())) {
const auto update = updates.scalar<T>();
functor::ScatterScalarFunctor<Device, T, Index, op> functor;
const Index bad_i = functor(c, c->template eigen_device<Device>(),

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/lib/core/status.h"
@ -780,7 +781,7 @@ class SparseSegmentGradOpBase : public OpKernel {
errors::InvalidArgument("indices should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
errors::InvalidArgument("segment_ids should be a vector."));
OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()),
errors::InvalidArgument("output_dim0 should be a scalar."));
const int64 N = indices.NumElements();

@ -44,7 +44,7 @@ void UnsortedSegmentReductionValidation(OpKernel* op_kernel,
const Tensor& segment_ids,
const Tensor& num_segments) {
OP_REQUIRES(
context, op_kernel->IsLegacyScalar(num_segments.shape()),
context, TensorShapeUtils::IsScalar(num_segments.shape()),
errors::InvalidArgument("num_segments should be a scalar, not shape ",
num_segments.shape().DebugString()));
OP_REQUIRES(

@ -36,13 +36,23 @@ class RangeOp : public OpKernel {
const Tensor& start_in = context->input(0);
const Tensor& limit_in = context->input(1);
const Tensor& delta_in = context->input(2);
OP_REQUIRES(context, IsLegacyScalar(start_in.shape()),
// TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(start_in.shape()) ||
(TensorShapeUtils::IsVector(start_in.shape()) &&
start_in.shape().dim_size(0) == 1),
errors::InvalidArgument("start must be a scalar, not shape ",
start_in.shape().DebugString()));
OP_REQUIRES(context, IsLegacyScalar(limit_in.shape()),
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(limit_in.shape()) ||
(TensorShapeUtils::IsVector(limit_in.shape()) &&
limit_in.shape().dim_size(0) == 1),
errors::InvalidArgument("limit must be a scalar, not shape ",
limit_in.shape().DebugString()));
OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
OP_REQUIRES(context,
TensorShapeUtils::IsScalar(delta_in.shape()) ||
(TensorShapeUtils::IsVector(delta_in.shape()) &&
delta_in.shape().dim_size(0) == 1),
errors::InvalidArgument("delta must be a scalar, not shape ",
delta_in.shape().DebugString()));
const T start = start_in.scalar<T>()();

@ -73,8 +73,8 @@ static void SharedValidation(OpKernelContext* context,
OP_REQUIRES(
context,
context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
TensorShapeUtils::IsVector(begin_tensor.shape()) &&
TensorShapeUtils::IsVector(size_tensor.shape()) &&
begin_tensor.NumElements() == input.dims() &&
size_tensor.NumElements() == input.dims(),
errors::InvalidArgument(

@ -62,7 +62,7 @@ class SparseToDense : public OpKernel {
// output_shape
const Tensor& output_shape = c->input(1);
OP_REQUIRES(
c, IsLegacyVector(output_shape.shape()),
c, TensorShapeUtils::IsVector(output_shape.shape()),
errors::InvalidArgument("output_shape should be a vector, got shape ",
output_shape.shape().DebugString()));
OP_REQUIRES(c, output_shape.NumElements() == num_dims,

@ -15,6 +15,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
@ -113,7 +114,7 @@ void StatefulRandomCompute(OpKernelContext* ctx, Distribution dist,
using T = typename Distribution::ResultElementType;
const Tensor& shape_t = ctx->input(shape_input_idx);
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->op_kernel().MakeShape(shape_t, &shape));
OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
auto output_flat = output->flat<T>();
@ -265,7 +266,7 @@ class NonDeterministicIntsOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& shape_t = ctx->input(0);
TensorShape shape;
OP_REQUIRES_OK(ctx, ctx->op_kernel().MakeShape(shape_t, &shape));
OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &shape));
Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output));
if (shape.num_elements() == 0) return;

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
@ -74,7 +75,7 @@ class StatelessRandomOpBase : public OpKernel {
const Tensor& shape_t = context->input(0);
const Tensor& seed_t = context->input(1);
TensorShape shape;
OP_REQUIRES_OK(context, MakeShape(shape_t, &shape));
OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape));
OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
errors::InvalidArgument("seed must have shape [2], not ",
seed_t.shape().DebugString()));

@ -39,7 +39,7 @@ class SummaryAudioOp : public OpKernel {
void Compute(OpKernelContext* c) override {
const Tensor& tag = c->input(0);
const Tensor& tensor = c->input(1);
OP_REQUIRES(c, IsLegacyScalar(tag.shape()),
OP_REQUIRES(c, TensorShapeUtils::IsScalar(tag.shape()),
errors::InvalidArgument("Tag must be a scalar"));
OP_REQUIRES(c, tensor.dims() >= 2 && tensor.dims() <= 3,
errors::InvalidArgument("Tensor must be 3-D or 2-D, got: ",

@ -52,7 +52,7 @@ class SummaryImageOp : public OpKernel {
void Compute(OpKernelContext* c) override {
const Tensor& tags = c->input(0);
const Tensor& tensor = c->input(1);
OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
OP_REQUIRES(c, TensorShapeUtils::IsScalar(tags.shape()),
errors::InvalidArgument("Tags must be a scalar"));
OP_REQUIRES(c,
tensor.dims() == 4 &&

@ -42,8 +42,8 @@ class SummaryScalarOp : public OpKernel {
OP_REQUIRES(
c,
tags.IsSameSize(values) ||
(IsLegacyScalar(tags.shape()) && IsLegacyScalar(values.shape())),
tags.IsSameSize(values) || (TensorShapeUtils::IsScalar(tags.shape()) &&
TensorShapeUtils::IsScalar(values.shape())),
errors::InvalidArgument(
"tags and values not the same shape: ", tags.shape().DebugString(),
" != ", values.shape().DebugString(), SingleTag(tags)));
@ -82,7 +82,7 @@ class SummaryHistoOp : public OpKernel {
const Tensor& tags = c->input(0);
const Tensor& values = c->input(1);
const auto flat = values.flat<T>();
OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
OP_REQUIRES(c, TensorShapeUtils::IsScalar(tags.shape()),
errors::InvalidArgument("tags must be scalar"));
// Build histogram of values in "values" tensor
histogram::Histogram histo;

@ -32,7 +32,7 @@ class SummaryTensorOpV2 : public OpKernel {
void Compute(OpKernelContext* c) override {
const Tensor& tag = c->input(0);
OP_REQUIRES(c, IsLegacyScalar(tag.shape()),
OP_REQUIRES(c, TensorShapeUtils::IsScalar(tag.shape()),
errors::InvalidArgument("tag must be scalar"));
const Tensor& tensor = c->input(1);
const Tensor& serialized_summary_metadata_tensor = c->input(2);

@ -18,6 +18,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/aggregate_ops_cpu.h"
namespace tensorflow {

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/split_lib.h"
@ -331,8 +332,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
TensorShape shape_to_prepend;
auto element_shape = PartialTensorShape();
if (ctx->num_inputs() > 2) {
TF_RETURN_IF_ERROR(
ctx->op_kernel().MakeShape(ctx->input(2), &shape_to_prepend));
TF_RETURN_IF_ERROR(tensor::MakeShape(ctx->input(2), &shape_to_prepend));
auto ta_element_shape = tensor_array->ElemShape();
if (!ta_element_shape.unknown_rank()) {
std::vector<int64> dims;

@ -187,7 +187,7 @@ class TileOp : public OpKernel {
const Tensor& multiples = context->input(1);
OP_REQUIRES(
context, IsLegacyVector(multiples.shape()),
context, TensorShapeUtils::IsVector(multiples.shape()),
errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
multiples.shape().DebugString()));
OP_REQUIRES(context, input.dims() == multiples.NumElements(),
@ -361,7 +361,7 @@ class TileGradientOp : public OpKernel {
const Tensor& input = context->input(0);
const Tensor& multiples = context->input(1);
OP_REQUIRES(
context, IsLegacyVector(multiples.shape()),
context, TensorShapeUtils::IsVector(multiples.shape()),
errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
multiples.shape().DebugString()));
OP_REQUIRES(context, input.dims() == multiples.NumElements(),

@ -568,7 +568,7 @@ class ApplyGradientDescentOp : public OpKernel {
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", requested_input(0)));
const Tensor& alpha = ctx->input(1);
OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()),
errors::InvalidArgument("alpha is not a scalar: ",
alpha.shape().DebugString()));
const Tensor& delta = ctx->input(2);
@ -610,7 +610,7 @@ class ApplyGradientDescentOp<SYCLDevice, T> : public OpKernel {
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", requested_input(0)));
const Tensor& alpha_dev = ctx->input(1);
OP_REQUIRES(ctx, IsLegacyScalar(alpha_dev.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_dev.shape()),
errors::InvalidArgument("alpha is not a scalar: ",
alpha_dev.shape().DebugString()));
const Tensor& delta = ctx->input(2);
@ -1064,7 +1064,7 @@ class ApplyProximalGradientDescentOp : public OpKernel {
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", requested_input(0)));
const Tensor& alpha = ctx->input(1);
OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha.shape()),
errors::InvalidArgument("alpha is not a scalar: ",
alpha.shape().DebugString()));
const Tensor& l1 = ctx->input(2);
@ -1132,7 +1132,7 @@ class SparseApplyProximalGradientDescentOp : public OpKernel {
errors::InvalidArgument("var must be at least 1 dimensional"));
const Tensor& lr = ctx->input(1);
OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
const Tensor& l1 = ctx->input(2);
@ -1286,7 +1286,7 @@ class ApplyAdagradOp : public OpKernel {
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", requested_input(1)));
const Tensor& lr = ctx->input(2);
OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
const Tensor& grad = ctx->input(3);
@ -1401,7 +1401,7 @@ class ApplyAdagradV2Op : public OpKernel {
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", requested_input(1)));
const Tensor& lr = ctx->input(2);
OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
const Tensor& epsilon = ctx->input(3);
@ -1631,7 +1631,7 @@ class SparseApplyAdagradOp : public OpKernel {
errors::InvalidArgument("var must be at least 1 dimensional"));
const Tensor& lr = ctx->input(2);
OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
const Tensor& grad = ctx->input(3);
@ -1800,7 +1800,7 @@ class SparseApplyAdagradV2Op : public OpKernel {
errors::InvalidArgument("var must be at least 1 dimensional"));
const Tensor& lr = ctx->input(2);
OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
const Tensor& epsilon = ctx->input(3);
@ -2169,7 +2169,7 @@ class ApplyAdagradDAOp : public OpKernel {
grad.shape().DebugString()));
const Tensor& lr = ctx->input(4);
OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
const Tensor& l1 = ctx->input(5);
@ -2183,7 +2183,7 @@ class ApplyAdagradDAOp : public OpKernel {
errors::InvalidArgument("l2 regularization strength is not a scalar: ",
l2.shape().DebugString()));
const Tensor& global_step = ctx->input(7);
OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()),
errors::InvalidArgument("global_step is not a scalar: ",
global_step.shape().DebugString()));
@ -2272,7 +2272,7 @@ class SparseApplyAdagradDAOp : public OpKernel {
errors::InvalidArgument("indices must be one-dimensional"));
const Tensor& lr = ctx->input(5);
OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
@ -2289,7 +2289,7 @@ class SparseApplyAdagradDAOp : public OpKernel {
l2.shape().DebugString()));
const Tensor& global_step = ctx->input(8);
OP_REQUIRES(ctx, IsLegacyScalar(global_step.shape()),
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step.shape()),
errors::InvalidArgument("global_step is not a scalar: ",
global_step.shape().DebugString()));

@ -31,15 +31,16 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
# TODO(rmlarsen) : Remove this test completely after we stop supporting GraphDef
# version 5 and remove support of legacy scalars from Concat, Fill, Range,
# and Reshape.
class ScalarTest(test.TestCase):
def check(self, op, args, error, correct=None):
# Within Google, the switch to scalar strict occurred at version 6.
lenient = []
strict = [5, 6]
def check(self, op, args, error, correct=None, lenient=None, strict=[5, 6]):
if lenient is None:
lenient = []
# Use placeholders to bypass shape inference, since only the C++
# GraphDef level is ever scalar lenient.
# GraphDef level is ever scalar lenient.
def placeholders(args, feed):
if isinstance(args, tuple):
return [placeholders(x, feed) for x in args]
@ -66,18 +67,21 @@ class ScalarTest(test.TestCase):
self.assertAllEqual(r, correct)
def testConcat(self):
self.check(array_ops.concat, (([2], [3], [7]), [0]),
'axis tensor should be a scalar integer', [2, 3, 7])
for data in (2, 3, 7), (2, [3], 7), (2, 3, [7]):
self.check(array_ops.concat, (data, 0),
r'Expected \w+ dimensions in the range \[0, 0\)', [2, 3, 7])
for data in ([2], 3, 7), ([2], [3], 7):
for data in (2, [3], 7), ([2], 3, 7), ([2], [3], 7):
self.check(array_ops.concat, (data, 0),
r'Ranks of all input tensors should match', [2, 3, 7])
def testFill(self):
self.check(array_ops.fill, (2, 3), 'dims must be a vector', [3, 3])
self.check(array_ops.fill, ([2], [3]), 'value must be a scalar', [3, 3])
self.check(
array_ops.fill, (2, 3),
'dims must be a vector', [3, 3],
lenient=[5, 6],
strict=[])
self.check(
array_ops.fill, ([2], [3]),
'value must be a scalar', [3, 3],
lenient=[5, 6],
strict=[])
def testPad(self):
self.check(array_ops.pad, (7, [[1, 2]]),
@ -88,7 +92,11 @@ class ScalarTest(test.TestCase):
self.check(random_ops.random_uniform, (3,), 'shape must be a vector')
def testReshape(self):
self.check(array_ops.reshape, (7, 1), 'sizes input must be 1-D', [7])
self.check(
array_ops.reshape, (7, 1),
'sizes input must be 1-D', [7],
lenient=[5, 6],
strict=[])
def testShardedFilename(self):
self.check(gen_io_ops.sharded_filename, ('foo', 4, [100]),
@ -103,9 +111,21 @@ class ScalarTest(test.TestCase):
'num_segments should be a scalar', [0, 7, 0, 0])
def testRange(self):
self.check(math_ops.range, ([0], 3, 2), 'start must be a scalar', [0, 2])
self.check(math_ops.range, (0, [3], 2), 'limit must be a scalar', [0, 2])
self.check(math_ops.range, (0, 3, [2]), 'delta must be a scalar', [0, 2])
self.check(
math_ops.range, ([0], 3, 2),
'start must be a scalar', [0, 2],
lenient=[5, 6],
strict=[])
self.check(
math_ops.range, (0, [3], 2),
'limit must be a scalar', [0, 2],
lenient=[5, 6],
strict=[])
self.check(
math_ops.range, (0, 3, [2]),
'delta must be a scalar', [0, 2],
lenient=[5, 6],
strict=[])
def testSlice(self):
data = np.arange(10)