diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index 9d7f37d2be6..19ed267b441 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ -48,53 +48,60 @@ class ConcatBaseOp : public OpKernel { typedef std::vector::ConstMatrix>> ConstMatrixVector; - explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {} + explicit ConcatBaseOp(OpKernelConstruction* c) + : OpKernel(c), + axis_attribute_name_(AxisArgName == NAME_IS_AXIS + ? "axis" + : AxisArgName == NAME_IS_CONCAT_DIM + ? "concat_dim" + : "") { + int unused; + OP_REQUIRES_OK( + c, InputRange(axis_attribute_name_, &axis_input_index_, &unused)); + OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_, + &values_input_end_index_)); + } void Compute(OpKernelContext* c) override { - const Tensor* concat_dim_tensor; - const char* axis_attribute_name = - AxisArgName == NAME_IS_AXIS - ? "axis" - : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : ""; - OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor)); + const Tensor& concat_dim_tensor = c->input(axis_input_index_); + // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars. OP_REQUIRES(c, - (TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) || - (TensorShapeUtils::IsVector(concat_dim_tensor->shape()) && - concat_dim_tensor->shape().dim_size(0) == 1)), + (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) || + (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) && + concat_dim_tensor.shape().dim_size(0) == 1)), errors::InvalidArgument( - axis_attribute_name, + axis_attribute_name_, " tensor should be a scalar integer, but got shape ", - concat_dim_tensor->shape().DebugString())); + concat_dim_tensor.shape().DebugString())); int64 concat_dim; // In case of ConcatV2, "axis" could be int32 or int64 if (AxisArgName == NAME_IS_AXIS) { OP_REQUIRES( c, - (concat_dim_tensor->dtype() == DT_INT32 || - concat_dim_tensor->dtype() == DT_INT64), - errors::InvalidArgument(axis_attribute_name, + (concat_dim_tensor.dtype() == DT_INT32 || + concat_dim_tensor.dtype() == DT_INT64), + errors::InvalidArgument(axis_attribute_name_, " tensor should be int32 or int64, but got ", - DataTypeString(concat_dim_tensor->dtype()))); + DataTypeString(concat_dim_tensor.dtype()))); } else { - OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32), + OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32), errors::InvalidArgument( - axis_attribute_name, " tensor should be int32, but got ", - DataTypeString(concat_dim_tensor->dtype()))); + axis_attribute_name_, " tensor should be int32, but got ", + DataTypeString(concat_dim_tensor.dtype()))); } - if (concat_dim_tensor->dtype() == DT_INT32) { + if (concat_dim_tensor.dtype() == DT_INT32) { concat_dim = - internal::SubtleMustCopy(concat_dim_tensor->scalar()()); + internal::SubtleMustCopy(concat_dim_tensor.scalar()()); } else { concat_dim = - internal::SubtleMustCopy(concat_dim_tensor->scalar()()); + internal::SubtleMustCopy(concat_dim_tensor.scalar()()); } - OpInputList values; - OP_REQUIRES_OK(c, c->input_list("values", &values)); - const int N = values.size(); - const int input_dims = values[0].dims(); - const TensorShape& input_shape = values[0].shape(); + const int N = values_input_end_index_ - values_input_start_index_; + const Tensor& first_input = c->input(values_input_start_index_); + const int input_dims = first_input.dims(); + const TensorShape& input_shape = first_input.shape(); int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; // concat_dim==0 allows concatenating a list of scalars into a vector. @@ -116,7 +123,7 @@ class ConcatBaseOp : public OpKernel { } int64 output_concat_dim = 0; for (int i = 0; i < N; ++i) { - const auto& in = values[i]; + const auto& in = c->input(values_input_start_index_ + i); OP_REQUIRES( c, in.dims() == input_dims, errors::InvalidArgument( @@ -137,7 +144,7 @@ class ConcatBaseOp : public OpKernel { if (in.NumElements() > 0) { int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; inputs_flat.emplace_back(new typename TTypes::ConstMatrix( - in.shaped({inputs_flat_dim0, inputs_flat_dim1}))); + in.template shaped({inputs_flat_dim0, inputs_flat_dim1}))); } // TODO(rmlarsen): Remove check once !allow_legacy_scalars()? output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1; @@ -170,6 +177,12 @@ class ConcatBaseOp : public OpKernel { ConcatCPU(c->device(), inputs_flat, &output_flat); } } + + private: + const char* const axis_attribute_name_; + int axis_input_index_; + int values_input_start_index_; + int values_input_end_index_; }; template diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 911462c8eff..54bde28ad62 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -44,12 +44,9 @@ class SelectOp : public OpKernel { explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* ctx) override { - const Tensor* cond; - const Tensor* then; - const Tensor* else_; - OP_REQUIRES_OK(ctx, ctx->input("condition", &cond)); - OP_REQUIRES_OK(ctx, ctx->input("t", &then)); - OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); + const Tensor* cond = &ctx->input(0); + const Tensor* then = &ctx->input(1); + const Tensor* else_ = &ctx->input(2); if (TensorShapeUtils::IsScalar(cond->shape())) { ComputeScalar(ctx, cond, then, else_); @@ -149,12 +146,9 @@ class SelectV2Op : public OpKernel { explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* ctx) override { - const Tensor* cond; - const Tensor* then; - const Tensor* else_; - OP_REQUIRES_OK(ctx, ctx->input("condition", &cond)); - OP_REQUIRES_OK(ctx, ctx->input("t", &then)); - OP_REQUIRES_OK(ctx, ctx->input("e", &else_)); + const Tensor* cond = &ctx->input(0); + const Tensor* then = &ctx->input(1); + const Tensor* else_ = &ctx->input(2); // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()), // This matches the behavior of numpy. @@ -260,7 +254,6 @@ class SelectV2Op : public OpKernel { ctx->input(1).shape().DebugString(), " is not supported yet.")); break; } - return; } private: diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc index 4b4705150a6..cf2b6bb1100 100644 --- a/tensorflow/core/kernels/pack_op.cc +++ b/tensorflow/core/kernels/pack_op.cc @@ -50,20 +50,10 @@ class PackOp : public OpKernel { } void Compute(OpKernelContext* c) override { - OpInputList values; - OP_REQUIRES_OK(c, c->input_list("values", &values)); - const int num = values.size(); + const int num = num_inputs(); + const Tensor& first_input = c->input(0); - // Verify that all input shapes match - for (int i = 1; i < num; i++) { - OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()), - errors::InvalidArgument( - "Shapes of all inputs must match: values[0].shape = ", - values[0].shape().DebugString(), " != values[", i, - "].shape = ", values[i].shape().DebugString())); - } - - int expanded_num_dims = values[0].dims() + 1; + int expanded_num_dims = first_input.dims() + 1; int axis = axis_; if (axis < 0) axis += expanded_num_dims; @@ -72,13 +62,13 @@ class PackOp : public OpKernel { -expanded_num_dims, ", ", expanded_num_dims, ")")); - TensorShape output_shape(values[0].shape()); + TensorShape output_shape(first_input.shape()); output_shape.InsertDim(axis, num); // In the num = 1 case, just reshape the input if (num == 1) { Tensor output; - CHECK(output.CopyFrom(values[0], output_shape)); + CHECK(output.CopyFrom(first_input, output_shape)); c->set_output(0, output); return; } @@ -109,8 +99,15 @@ class PackOp : public OpKernel { ConstMatrixVector inputs_flat; inputs_flat.reserve(num); for (int i = 0; i < num; ++i) { + const Tensor& input = c->input(i); + OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()), + errors::InvalidArgument( + "Shapes of all inputs must match: values[0].shape = ", + first_input.shape().DebugString(), " != values[", i, + "].shape = ", input.shape().DebugString())); + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( - values[i].shaped({before_dim, after_dim}))); + input.shaped({before_dim, after_dim}))); } #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (std::is_same::value) {