Switch ConcatOp, PackOp, SelectOp, and SelectV2Op to access inputs by index (not name).

These kernels are heavily used, and the name resolution on each invocation causes non-trivial overhead.

PiperOrigin-RevId: 297267975
Change-Id: Id69de0e2cff3622e992389c16e020a5da3141462
Author: Derek Murray (2020-02-25 20:17:46 -08:00), committed by TensorFlower Gardener
parent 8c97290ba3
commit 26a24de29b
3 changed files with 61 additions and 58 deletions
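
The win comes from replacing a per-call string lookup with positional access. A minimal sketch of the two access paths inside an OpKernel's Compute; `ExampleOp` is hypothetical and not part of this commit:

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical single-input kernel, for illustration only.
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override {
    // Name-based access: resolves the string "x" against the op's
    // name-to-index map on every invocation. This is the overhead the
    // commit removes from Concat/Pack/Select.
    const Tensor* by_name;
    OP_REQUIRES_OK(ctx, ctx->input("x", &by_name));

    // Index-based access: a direct positional lookup, no string resolution.
    const Tensor& by_index = ctx->input(0);
    ctx->set_output(0, by_index);
  }
};

}  // namespace tensorflow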


@@ -48,53 +48,60 @@ class ConcatBaseOp : public OpKernel {
   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
       ConstMatrixVector;
 
-  explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit ConcatBaseOp(OpKernelConstruction* c)
+      : OpKernel(c),
+        axis_attribute_name_(AxisArgName == NAME_IS_AXIS
+                                 ? "axis"
+                                 : AxisArgName == NAME_IS_CONCAT_DIM
+                                       ? "concat_dim"
+                                       : "<invalid>") {
+    int unused;
+    OP_REQUIRES_OK(
+        c, InputRange(axis_attribute_name_, &axis_input_index_, &unused));
+    OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_,
+                                 &values_input_end_index_));
+  }
 
   void Compute(OpKernelContext* c) override {
-    const Tensor* concat_dim_tensor;
-    const char* axis_attribute_name =
-        AxisArgName == NAME_IS_AXIS
-            ? "axis"
-            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
-    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
+    const Tensor& concat_dim_tensor = c->input(axis_input_index_);
     // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
     OP_REQUIRES(c,
-                (TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) ||
-                 (TensorShapeUtils::IsVector(concat_dim_tensor->shape()) &&
-                  concat_dim_tensor->shape().dim_size(0) == 1)),
+                (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) ||
+                 (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) &&
+                  concat_dim_tensor.shape().dim_size(0) == 1)),
                 errors::InvalidArgument(
-                    axis_attribute_name,
+                    axis_attribute_name_,
                     " tensor should be a scalar integer, but got shape ",
-                    concat_dim_tensor->shape().DebugString()));
+                    concat_dim_tensor.shape().DebugString()));
     int64 concat_dim;
     // In case of ConcatV2, "axis" could be int32 or int64
     if (AxisArgName == NAME_IS_AXIS) {
       OP_REQUIRES(
           c,
-          (concat_dim_tensor->dtype() == DT_INT32 ||
-           concat_dim_tensor->dtype() == DT_INT64),
-          errors::InvalidArgument(axis_attribute_name,
+          (concat_dim_tensor.dtype() == DT_INT32 ||
+           concat_dim_tensor.dtype() == DT_INT64),
+          errors::InvalidArgument(axis_attribute_name_,
                                   " tensor should be int32 or int64, but got ",
-                                  DataTypeString(concat_dim_tensor->dtype())));
+                                  DataTypeString(concat_dim_tensor.dtype())));
     } else {
-      OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32),
+      OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32),
                   errors::InvalidArgument(
-                      axis_attribute_name, " tensor should be int32, but got ",
-                      DataTypeString(concat_dim_tensor->dtype())));
+                      axis_attribute_name_, " tensor should be int32, but got ",
+                      DataTypeString(concat_dim_tensor.dtype())));
     }
-    if (concat_dim_tensor->dtype() == DT_INT32) {
+    if (concat_dim_tensor.dtype() == DT_INT32) {
       concat_dim =
-          internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
+          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
     } else {
       concat_dim =
-          internal::SubtleMustCopy(concat_dim_tensor->scalar<int64>()());
+          internal::SubtleMustCopy(concat_dim_tensor.scalar<int64>()());
     }
 
-    OpInputList values;
-    OP_REQUIRES_OK(c, c->input_list("values", &values));
-    const int N = values.size();
-    const int input_dims = values[0].dims();
-    const TensorShape& input_shape = values[0].shape();
+    const int N = values_input_end_index_ - values_input_start_index_;
+    const Tensor& first_input = c->input(values_input_start_index_);
+    const int input_dims = first_input.dims();
+    const TensorShape& input_shape = first_input.shape();
 
     int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
     // concat_dim==0 allows concatenating a list of scalars into a vector.
@@ -116,7 +123,7 @@ class ConcatBaseOp : public OpKernel {
     }
     int64 output_concat_dim = 0;
     for (int i = 0; i < N; ++i) {
-      const auto& in = values[i];
+      const auto& in = c->input(values_input_start_index_ + i);
       OP_REQUIRES(
           c, in.dims() == input_dims,
           errors::InvalidArgument(
@@ -137,7 +144,7 @@ class ConcatBaseOp : public OpKernel {
       if (in.NumElements() > 0) {
         int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
-            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
+            in.template shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
       }
       // TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
       output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
@@ -170,6 +177,12 @@ class ConcatBaseOp : public OpKernel {
       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
     }
   }
+
+ private:
+  const char* const axis_attribute_name_;
+  int axis_input_index_;
+  int values_input_start_index_;
+  int values_input_end_index_;
 };
 
 template <typename Device, typename T>
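
Note the pattern above: the string-to-index resolution still happens, but only once, in the kernel constructor, via OpKernel::InputRange, which maps an input name to its [start, stop) positional range. A condensed sketch of the same idea for a hypothetical variadic kernel (`SumListOp` is illustrative, not from this commit):

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical kernel with a variadic "values" input, showing the
// resolve-once-in-constructor pattern used by ConcatBaseOp above.
class SumListOp : public OpKernel {
 public:
  explicit SumListOp(OpKernelConstruction* c) : OpKernel(c) {
    // One-time name resolution at kernel construction.
    OP_REQUIRES_OK(c, InputRange("values", &values_start_, &values_stop_));
  }

  void Compute(OpKernelContext* ctx) override {
    // Per-call accesses are now cheap positional lookups.
    for (int i = values_start_; i < values_stop_; ++i) {
      const Tensor& t = ctx->input(i);
      // ... process t ...
    }
  }

 private:
  int values_start_;
  int values_stop_;
};

}  // namespace tensorflow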


@@ -44,12 +44,9 @@ class SelectOp : public OpKernel {
   explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* ctx) override {
-    const Tensor* cond;
-    const Tensor* then;
-    const Tensor* else_;
-    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
-    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
-    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+    const Tensor* cond = &ctx->input(0);
+    const Tensor* then = &ctx->input(1);
+    const Tensor* else_ = &ctx->input(2);
 
     if (TensorShapeUtils::IsScalar(cond->shape())) {
       ComputeScalar(ctx, cond, then, else_);
@@ -149,12 +146,9 @@ class SelectV2Op : public OpKernel {
   explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* ctx) override {
-    const Tensor* cond;
-    const Tensor* then;
-    const Tensor* else_;
-    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
-    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
-    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+    const Tensor* cond = &ctx->input(0);
+    const Tensor* then = &ctx->input(1);
+    const Tensor* else_ = &ctx->input(2);
 
     // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()),
     // This matches the behavior of numpy.
@@ -260,7 +254,6 @@ class SelectV2Op : public OpKernel {
             ctx->input(1).shape().DebugString(), " is not supported yet."));
         break;
     }
-    return;
   }
 
  private:
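
Hard-coding indices 0 through 2 is safe here because an op's runtime input order is fixed by its registration, and Select declares condition, t, e in exactly that order (SelectV2 declares the same order). Abridged registration, shape function omitted:

REGISTER_OP("Select")
    .Input("condition: bool")
    .Input("t: T")
    .Input("e: T")
    .Output("output: T")
    .Attr("T: type");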


@@ -50,20 +50,10 @@ class PackOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* c) override {
-    OpInputList values;
-    OP_REQUIRES_OK(c, c->input_list("values", &values));
-    const int num = values.size();
-
-    // Verify that all input shapes match
-    for (int i = 1; i < num; i++) {
-      OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
-                  errors::InvalidArgument(
-                      "Shapes of all inputs must match: values[0].shape = ",
-                      values[0].shape().DebugString(), " != values[", i,
-                      "].shape = ", values[i].shape().DebugString()));
-    }
-
-    int expanded_num_dims = values[0].dims() + 1;
+    const int num = num_inputs();
+    const Tensor& first_input = c->input(0);
+
+    int expanded_num_dims = first_input.dims() + 1;
     int axis = axis_;
     if (axis < 0) axis += expanded_num_dims;
@@ -72,13 +62,13 @@
                                     -expanded_num_dims, ", ",
                                     expanded_num_dims, ")"));
 
-    TensorShape output_shape(values[0].shape());
+    TensorShape output_shape(first_input.shape());
     output_shape.InsertDim(axis, num);
 
     // In the num = 1 case, just reshape the input
     if (num == 1) {
       Tensor output;
-      CHECK(output.CopyFrom(values[0], output_shape));
+      CHECK(output.CopyFrom(first_input, output_shape));
       c->set_output(0, output);
       return;
     }
@@ -109,8 +99,15 @@
     ConstMatrixVector inputs_flat;
     inputs_flat.reserve(num);
     for (int i = 0; i < num; ++i) {
+      const Tensor& input = c->input(i);
+      OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()),
+                  errors::InvalidArgument(
+                      "Shapes of all inputs must match: values[0].shape = ",
+                      first_input.shape().DebugString(), " != values[", i,
+                      "].shape = ", input.shape().DebugString()));
       inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
-          values[i].shaped<T, 2>({before_dim, after_dim})));
+          input.shaped<T, 2>({before_dim, after_dim})));
     }
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     if (std::is_same<Device, GPUDevice>::value) {
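
Besides switching to num_inputs() and c->input(i), the shape check moves out of a dedicated pre-pass and into the packing loop, so each input is validated as it is consumed. A reduced sketch of the resulting structure; `PackLikeOp` is hypothetical and the error text is shortened:

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical kernel mirroring PackOp's new validate-as-you-go loop.
class PackLikeOp : public OpKernel {
 public:
  explicit PackLikeOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override {
    const int num = num_inputs();  // static arity of this kernel; no name lookup
    const Tensor& first_input = ctx->input(0);
    for (int i = 0; i < num; ++i) {
      const Tensor& input = ctx->input(i);
      // Validate lazily, inside the same loop that consumes the input.
      OP_REQUIRES(ctx, first_input.shape().IsSameSize(input.shape()),
                  errors::InvalidArgument("Shapes of all inputs must match"));
      // ... flatten/pack `input` ...
    }
  }
};

}  // namespace tensorflow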