Switch ConcatOp, PackOp, SelectOp, and SelectV2Op to access inputs by index (not name).
These kernels are heavily used, and the name resolution on each invocation causes non-trivial overhead.

PiperOrigin-RevId: 297267975
Change-Id: Id69de0e2cff3622e992389c16e020a5da3141462
parent 8c97290ba3
commit 26a24de29b
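The pattern applied throughout this change, in miniature: resolve each named input to its integer index once at kernel-construction time via OpKernel::InputRange, then fetch tensors in Compute with OpKernelContext::input(int). Below is a minimal sketch against the real OpKernel API; the op (ExampleOp) and its input name ("x") are hypothetical, and the code only illustrates the before/after access pattern (it builds only inside the TensorFlow tree):

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical kernel, for illustration only.
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* c) : OpKernel(c) {
    // Resolve the name -> index mapping once, at construction time.
    int unused;
    OP_REQUIRES_OK(c, InputRange("x", &x_input_index_, &unused));
  }

  void Compute(OpKernelContext* c) override {
    // Before this change the kernels did, on every invocation:
    //   const Tensor* x;
    //   OP_REQUIRES_OK(c, c->input("x", &x));  // string-keyed name lookup
    // After, input access is a direct index into the input vector:
    const Tensor& x = c->input(x_input_index_);
    c->set_output(0, x);  // identity, just to keep the sketch complete
  }

 private:
  int x_input_index_;
};

}  // namespace tensorflow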
tensorflow/core/kernels/concat_op.cc

@@ -48,53 +48,60 @@ class ConcatBaseOp : public OpKernel {
   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
       ConstMatrixVector;

-  explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit ConcatBaseOp(OpKernelConstruction* c)
+      : OpKernel(c),
+        axis_attribute_name_(AxisArgName == NAME_IS_AXIS
+                                 ? "axis"
+                                 : AxisArgName == NAME_IS_CONCAT_DIM
+                                       ? "concat_dim"
+                                       : "<invalid>") {
+    int unused;
+    OP_REQUIRES_OK(
+        c, InputRange(axis_attribute_name_, &axis_input_index_, &unused));
+    OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_,
+                                 &values_input_end_index_));
+  }

   void Compute(OpKernelContext* c) override {
-    const Tensor* concat_dim_tensor;
-    const char* axis_attribute_name =
-        AxisArgName == NAME_IS_AXIS
-            ? "axis"
-            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
-    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
+    const Tensor& concat_dim_tensor = c->input(axis_input_index_);
+
     // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
     OP_REQUIRES(c,
-                (TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) ||
-                 (TensorShapeUtils::IsVector(concat_dim_tensor->shape()) &&
-                  concat_dim_tensor->shape().dim_size(0) == 1)),
+                (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) ||
+                 (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) &&
+                  concat_dim_tensor.shape().dim_size(0) == 1)),
                 errors::InvalidArgument(
-                    axis_attribute_name,
+                    axis_attribute_name_,
                     " tensor should be a scalar integer, but got shape ",
-                    concat_dim_tensor->shape().DebugString()));
+                    concat_dim_tensor.shape().DebugString()));
     int64 concat_dim;
     // In case of ConcatV2, "axis" could be int32 or int64
     if (AxisArgName == NAME_IS_AXIS) {
       OP_REQUIRES(
           c,
-          (concat_dim_tensor->dtype() == DT_INT32 ||
-           concat_dim_tensor->dtype() == DT_INT64),
-          errors::InvalidArgument(axis_attribute_name,
+          (concat_dim_tensor.dtype() == DT_INT32 ||
+           concat_dim_tensor.dtype() == DT_INT64),
+          errors::InvalidArgument(axis_attribute_name_,
                                   " tensor should be int32 or int64, but got ",
-                                  DataTypeString(concat_dim_tensor->dtype())));
+                                  DataTypeString(concat_dim_tensor.dtype())));
     } else {
-      OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32),
+      OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32),
                   errors::InvalidArgument(
-                      axis_attribute_name, " tensor should be int32, but got ",
-                      DataTypeString(concat_dim_tensor->dtype())));
+                      axis_attribute_name_, " tensor should be int32, but got ",
+                      DataTypeString(concat_dim_tensor.dtype())));
     }
-    if (concat_dim_tensor->dtype() == DT_INT32) {
+    if (concat_dim_tensor.dtype() == DT_INT32) {
       concat_dim =
-          internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
+          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
     } else {
       concat_dim =
-          internal::SubtleMustCopy(concat_dim_tensor->scalar<int64>()());
+          internal::SubtleMustCopy(concat_dim_tensor.scalar<int64>()());
     }

-    OpInputList values;
-    OP_REQUIRES_OK(c, c->input_list("values", &values));
-    const int N = values.size();
-    const int input_dims = values[0].dims();
-    const TensorShape& input_shape = values[0].shape();
+    const int N = values_input_end_index_ - values_input_start_index_;
+    const Tensor& first_input = c->input(values_input_start_index_);
+    const int input_dims = first_input.dims();
+    const TensorShape& input_shape = first_input.shape();

     int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
     // concat_dim==0 allows concatenating a list of scalars into a vector.
@@ -116,7 +123,7 @@ class ConcatBaseOp : public OpKernel {
     }
     int64 output_concat_dim = 0;
     for (int i = 0; i < N; ++i) {
-      const auto& in = values[i];
+      const auto& in = c->input(values_input_start_index_ + i);
       OP_REQUIRES(
           c, in.dims() == input_dims,
           errors::InvalidArgument(
@@ -137,7 +144,7 @@ class ConcatBaseOp : public OpKernel {
       if (in.NumElements() > 0) {
         int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
-            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
+            in.template shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
       }
       // TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
       output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
@@ -170,6 +177,12 @@ class ConcatBaseOp : public OpKernel {
       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
     }
   }
+
+ private:
+  const char* const axis_attribute_name_;
+  int axis_input_index_;
+  int values_input_start_index_;
+  int values_input_end_index_;
 };

 template <typename Device, typename T>
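One detail worth noting from the ConcatBaseOp hunks above: for a list-valued input such as "values", InputRange fills a half-open [start, end) range, which is why N is recovered by subtraction in Compute. A minimal sketch of the resulting iteration pattern; the helper ForEachValue is hypothetical and stands in for the loop body in the diff:

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical helper mirroring the loop in ConcatBaseOp::Compute above.
// Assumes `start` and `end` were filled at construction time by
// InputRange("values", &start, &end), which yields a half-open [start, end).
void ForEachValue(OpKernelContext* c, int start, int end) {
  const int n = end - start;  // number of tensors in the "values" list
  for (int i = 0; i < n; ++i) {
    const Tensor& in = c->input(start + i);  // constant-time, no name lookup
    (void)in;  // ... process `in` ...
  }
}

}  // namespace tensorflow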
tensorflow/core/kernels/cwise_op_select.cc

@@ -44,12 +44,9 @@ class SelectOp : public OpKernel {
   explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}

   void Compute(OpKernelContext* ctx) override {
-    const Tensor* cond;
-    const Tensor* then;
-    const Tensor* else_;
-    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
-    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
-    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+    const Tensor* cond = &ctx->input(0);
+    const Tensor* then = &ctx->input(1);
+    const Tensor* else_ = &ctx->input(2);

     if (TensorShapeUtils::IsScalar(cond->shape())) {
       ComputeScalar(ctx, cond, then, else_);
@@ -149,12 +146,9 @@ class SelectV2Op : public OpKernel {
   explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}

   void Compute(OpKernelContext* ctx) override {
-    const Tensor* cond;
-    const Tensor* then;
-    const Tensor* else_;
-    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
-    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
-    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+    const Tensor* cond = &ctx->input(0);
+    const Tensor* then = &ctx->input(1);
+    const Tensor* else_ = &ctx->input(2);

     // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()),
     // This matches the behavior of numpy.
@@ -260,7 +254,6 @@ class SelectV2Op : public OpKernel {
               ctx->input(1).shape().DebugString(), " is not supported yet."));
           break;
       }
     return;
   }
-
  private:
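A side note on the Select kernels above: OpKernelContext::input(int) returns a const Tensor&, so the rewrite simply takes the address of that reference, keeping the pointer-typed locals and downstream helpers such as ComputeScalar unchanged.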
tensorflow/core/kernels/pack_op.cc

@@ -50,20 +50,10 @@ class PackOp : public OpKernel {
   }

   void Compute(OpKernelContext* c) override {
-    OpInputList values;
-    OP_REQUIRES_OK(c, c->input_list("values", &values));
-    const int num = values.size();
+    const int num = num_inputs();
+    const Tensor& first_input = c->input(0);

-    // Verify that all input shapes match
-    for (int i = 1; i < num; i++) {
-      OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
-                  errors::InvalidArgument(
-                      "Shapes of all inputs must match: values[0].shape = ",
-                      values[0].shape().DebugString(), " != values[", i,
-                      "].shape = ", values[i].shape().DebugString()));
-    }
-
-    int expanded_num_dims = values[0].dims() + 1;
+    int expanded_num_dims = first_input.dims() + 1;
     int axis = axis_;
     if (axis < 0) axis += expanded_num_dims;

@@ -72,13 +62,13 @@ class PackOp : public OpKernel {
                                             -expanded_num_dims, ", ",
                                             expanded_num_dims, ")"));

-    TensorShape output_shape(values[0].shape());
+    TensorShape output_shape(first_input.shape());
     output_shape.InsertDim(axis, num);

     // In the num = 1 case, just reshape the input
     if (num == 1) {
       Tensor output;
-      CHECK(output.CopyFrom(values[0], output_shape));
+      CHECK(output.CopyFrom(first_input, output_shape));
       c->set_output(0, output);
       return;
     }
@@ -109,8 +99,15 @@ class PackOp : public OpKernel {
     ConstMatrixVector inputs_flat;
     inputs_flat.reserve(num);
     for (int i = 0; i < num; ++i) {
+      const Tensor& input = c->input(i);
+      OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()),
+                  errors::InvalidArgument(
+                      "Shapes of all inputs must match: values[0].shape = ",
+                      first_input.shape().DebugString(), " != values[", i,
+                      "].shape = ", input.shape().DebugString()));
+
       inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
-          values[i].shaped<T, 2>({before_dim, after_dim})));
+          input.shaped<T, 2>({before_dim, after_dim})));
     }
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     if (std::is_same<Device, GPUDevice>::value) {