Switch ConcatOp, PackOp, SelectOp, and SelectV2Op to access inputs by index (not name).

These kernels are heavily used, and resolving input names to indices on every invocation adds non-trivial overhead.

PiperOrigin-RevId: 297267975
Change-Id: Id69de0e2cff3622e992389c16e020a5da3141462
This commit is contained in:
Derek Murray 2020-02-25 20:17:46 -08:00 committed by TensorFlower Gardener
parent 8c97290ba3
commit 26a24de29b
3 changed files with 61 additions and 58 deletions

View File

@ -48,53 +48,60 @@ class ConcatBaseOp : public OpKernel {
typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
ConstMatrixVector; ConstMatrixVector;
explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {} explicit ConcatBaseOp(OpKernelConstruction* c)
: OpKernel(c),
axis_attribute_name_(AxisArgName == NAME_IS_AXIS
? "axis"
: AxisArgName == NAME_IS_CONCAT_DIM
? "concat_dim"
: "<invalid>") {
int unused;
OP_REQUIRES_OK(
c, InputRange(axis_attribute_name_, &axis_input_index_, &unused));
OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_,
&values_input_end_index_));
}
void Compute(OpKernelContext* c) override { void Compute(OpKernelContext* c) override {
const Tensor* concat_dim_tensor; const Tensor& concat_dim_tensor = c->input(axis_input_index_);
const char* axis_attribute_name =
AxisArgName == NAME_IS_AXIS
? "axis"
: AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
// TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars. // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
OP_REQUIRES(c, OP_REQUIRES(c,
(TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) || (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) ||
(TensorShapeUtils::IsVector(concat_dim_tensor->shape()) && (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) &&
concat_dim_tensor->shape().dim_size(0) == 1)), concat_dim_tensor.shape().dim_size(0) == 1)),
errors::InvalidArgument( errors::InvalidArgument(
axis_attribute_name, axis_attribute_name_,
" tensor should be a scalar integer, but got shape ", " tensor should be a scalar integer, but got shape ",
concat_dim_tensor->shape().DebugString())); concat_dim_tensor.shape().DebugString()));
int64 concat_dim; int64 concat_dim;
// In case of ConcatV2, "axis" could be int32 or int64 // In case of ConcatV2, "axis" could be int32 or int64
if (AxisArgName == NAME_IS_AXIS) { if (AxisArgName == NAME_IS_AXIS) {
OP_REQUIRES( OP_REQUIRES(
c, c,
(concat_dim_tensor->dtype() == DT_INT32 || (concat_dim_tensor.dtype() == DT_INT32 ||
concat_dim_tensor->dtype() == DT_INT64), concat_dim_tensor.dtype() == DT_INT64),
errors::InvalidArgument(axis_attribute_name, errors::InvalidArgument(axis_attribute_name_,
" tensor should be int32 or int64, but got ", " tensor should be int32 or int64, but got ",
DataTypeString(concat_dim_tensor->dtype()))); DataTypeString(concat_dim_tensor.dtype())));
} else { } else {
OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32), OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32),
errors::InvalidArgument( errors::InvalidArgument(
axis_attribute_name, " tensor should be int32, but got ", axis_attribute_name_, " tensor should be int32, but got ",
DataTypeString(concat_dim_tensor->dtype()))); DataTypeString(concat_dim_tensor.dtype())));
} }
if (concat_dim_tensor->dtype() == DT_INT32) { if (concat_dim_tensor.dtype() == DT_INT32) {
concat_dim = concat_dim =
internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()()); internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
} else { } else {
concat_dim = concat_dim =
internal::SubtleMustCopy(concat_dim_tensor->scalar<int64>()()); internal::SubtleMustCopy(concat_dim_tensor.scalar<int64>()());
} }
OpInputList values; const int N = values_input_end_index_ - values_input_start_index_;
OP_REQUIRES_OK(c, c->input_list("values", &values)); const Tensor& first_input = c->input(values_input_start_index_);
const int N = values.size(); const int input_dims = first_input.dims();
const int input_dims = values[0].dims(); const TensorShape& input_shape = first_input.shape();
const TensorShape& input_shape = values[0].shape();
int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
// concat_dim==0 allows concatenating a list of scalars into a vector. // concat_dim==0 allows concatenating a list of scalars into a vector.
@ -116,7 +123,7 @@ class ConcatBaseOp : public OpKernel {
} }
int64 output_concat_dim = 0; int64 output_concat_dim = 0;
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
const auto& in = values[i]; const auto& in = c->input(values_input_start_index_ + i);
OP_REQUIRES( OP_REQUIRES(
c, in.dims() == input_dims, c, in.dims() == input_dims,
errors::InvalidArgument( errors::InvalidArgument(
@ -137,7 +144,7 @@ class ConcatBaseOp : public OpKernel {
if (in.NumElements() > 0) { if (in.NumElements() > 0) {
int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0; int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1}))); in.template shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
} }
// TODO(rmlarsen): Remove check once !allow_legacy_scalars()? // TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1; output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
@ -170,6 +177,12 @@ class ConcatBaseOp : public OpKernel {
ConcatCPU<T>(c->device(), inputs_flat, &output_flat); ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
} }
} }
private:
const char* const axis_attribute_name_;
int axis_input_index_;
int values_input_start_index_;
int values_input_end_index_;
}; };
template <typename Device, typename T> template <typename Device, typename T>

View File

@ -44,12 +44,9 @@ class SelectOp : public OpKernel {
explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {} explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
const Tensor* cond; const Tensor* cond = &ctx->input(0);
const Tensor* then; const Tensor* then = &ctx->input(1);
const Tensor* else_; const Tensor* else_ = &ctx->input(2);
OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
if (TensorShapeUtils::IsScalar(cond->shape())) { if (TensorShapeUtils::IsScalar(cond->shape())) {
ComputeScalar(ctx, cond, then, else_); ComputeScalar(ctx, cond, then, else_);
@ -149,12 +146,9 @@ class SelectV2Op : public OpKernel {
explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {} explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
const Tensor* cond; const Tensor* cond = &ctx->input(0);
const Tensor* then; const Tensor* then = &ctx->input(1);
const Tensor* else_; const Tensor* else_ = &ctx->input(2);
OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
OP_REQUIRES_OK(ctx, ctx->input("t", &then));
OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
// The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()), // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()),
// This matches the behavior of numpy. // This matches the behavior of numpy.
@ -260,7 +254,6 @@ class SelectV2Op : public OpKernel {
ctx->input(1).shape().DebugString(), " is not supported yet.")); ctx->input(1).shape().DebugString(), " is not supported yet."));
break; break;
} }
return;
} }
private: private:

View File

@ -50,20 +50,10 @@ class PackOp : public OpKernel {
} }
void Compute(OpKernelContext* c) override { void Compute(OpKernelContext* c) override {
OpInputList values; const int num = num_inputs();
OP_REQUIRES_OK(c, c->input_list("values", &values)); const Tensor& first_input = c->input(0);
const int num = values.size();
// Verify that all input shapes match int expanded_num_dims = first_input.dims() + 1;
for (int i = 1; i < num; i++) {
OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
errors::InvalidArgument(
"Shapes of all inputs must match: values[0].shape = ",
values[0].shape().DebugString(), " != values[", i,
"].shape = ", values[i].shape().DebugString()));
}
int expanded_num_dims = values[0].dims() + 1;
int axis = axis_; int axis = axis_;
if (axis < 0) axis += expanded_num_dims; if (axis < 0) axis += expanded_num_dims;
@ -72,13 +62,13 @@ class PackOp : public OpKernel {
-expanded_num_dims, ", ", -expanded_num_dims, ", ",
expanded_num_dims, ")")); expanded_num_dims, ")"));
TensorShape output_shape(values[0].shape()); TensorShape output_shape(first_input.shape());
output_shape.InsertDim(axis, num); output_shape.InsertDim(axis, num);
// In the num = 1 case, just reshape the input // In the num = 1 case, just reshape the input
if (num == 1) { if (num == 1) {
Tensor output; Tensor output;
CHECK(output.CopyFrom(values[0], output_shape)); CHECK(output.CopyFrom(first_input, output_shape));
c->set_output(0, output); c->set_output(0, output);
return; return;
} }
@ -109,8 +99,15 @@ class PackOp : public OpKernel {
ConstMatrixVector inputs_flat; ConstMatrixVector inputs_flat;
inputs_flat.reserve(num); inputs_flat.reserve(num);
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
const Tensor& input = c->input(i);
OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()),
errors::InvalidArgument(
"Shapes of all inputs must match: values[0].shape = ",
first_input.shape().DebugString(), " != values[", i,
"].shape = ", input.shape().DebugString()));
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
values[i].shaped<T, 2>({before_dim, after_dim}))); input.shaped<T, 2>({before_dim, after_dim})));
} }
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (std::is_same<Device, GPUDevice>::value) { if (std::is_same<Device, GPUDevice>::value) {