Switch ConcatOp, PackOp, SelectOp, and SelectV2Op to access inputs by index (not name).
These kernels are heavily used, and the name resolution on each invocation causes non-trivial overhead.

PiperOrigin-RevId: 297267975
Change-Id: Id69de0e2cff3622e992389c16e020a5da3141462
commit 26a24de29b
parent 8c97290ba3
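The pattern behind the change, as a minimal standalone sketch (the ExampleOp kernel and its "values" input are hypothetical; InputRange() and the two input() overloads are real OpKernel/OpKernelContext APIs used in the diff below): resolve each input name to an index once in the constructor, then index directly on the hot path.

// Minimal sketch only, not code from this commit: ExampleOp and its single
// "values" input are hypothetical.
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* c) : OpKernel(c) {
    // Pay for the name lookup once, at kernel construction: map the
    // "values" input name to its position in the flat input list.
    int unused;
    OP_REQUIRES_OK(c, InputRange("values", &values_index_, &unused));
  }

  void Compute(OpKernelContext* c) override {
    // Old pattern (per-invocation name lookup):
    //   const Tensor* t;
    //   OP_REQUIRES_OK(c, c->input("values", &t));
    // New pattern (direct index, no lookup on the hot path):
    const Tensor& t = c->input(values_index_);
    c->set_output(0, t);  // Trivial pass-through body, for the sketch only.
  }

 private:
  int values_index_;  // Resolved once in the constructor.
};

}  // namespace tensorflow

ConcatBaseOp below follows exactly this two-step pattern; SelectOp, SelectV2Op, and PackOp can hardcode indices (0, 1, 2, ...) because their input positions are fixed by the op definition.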
@@ -48,53 +48,60 @@ class ConcatBaseOp : public OpKernel {
   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
       ConstMatrixVector;
 
-  explicit ConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit ConcatBaseOp(OpKernelConstruction* c)
+      : OpKernel(c),
+        axis_attribute_name_(AxisArgName == NAME_IS_AXIS
+                                 ? "axis"
+                                 : AxisArgName == NAME_IS_CONCAT_DIM
+                                       ? "concat_dim"
+                                       : "<invalid>") {
+    int unused;
+    OP_REQUIRES_OK(
+        c, InputRange(axis_attribute_name_, &axis_input_index_, &unused));
+    OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_,
+                                 &values_input_end_index_));
+  }
 
   void Compute(OpKernelContext* c) override {
-    const Tensor* concat_dim_tensor;
-    const char* axis_attribute_name =
-        AxisArgName == NAME_IS_AXIS
-            ? "axis"
-            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
-    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
+    const Tensor& concat_dim_tensor = c->input(axis_input_index_);
     // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
     OP_REQUIRES(c,
-                (TensorShapeUtils::IsScalar(concat_dim_tensor->shape()) ||
-                 (TensorShapeUtils::IsVector(concat_dim_tensor->shape()) &&
-                  concat_dim_tensor->shape().dim_size(0) == 1)),
+                (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) ||
+                 (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) &&
+                  concat_dim_tensor.shape().dim_size(0) == 1)),
                 errors::InvalidArgument(
-                    axis_attribute_name,
+                    axis_attribute_name_,
                     " tensor should be a scalar integer, but got shape ",
-                    concat_dim_tensor->shape().DebugString()));
+                    concat_dim_tensor.shape().DebugString()));
     int64 concat_dim;
     // In case of ConcatV2, "axis" could be int32 or int64
     if (AxisArgName == NAME_IS_AXIS) {
       OP_REQUIRES(
           c,
-          (concat_dim_tensor->dtype() == DT_INT32 ||
-           concat_dim_tensor->dtype() == DT_INT64),
-          errors::InvalidArgument(axis_attribute_name,
+          (concat_dim_tensor.dtype() == DT_INT32 ||
+           concat_dim_tensor.dtype() == DT_INT64),
+          errors::InvalidArgument(axis_attribute_name_,
                                   " tensor should be int32 or int64, but got ",
-                                  DataTypeString(concat_dim_tensor->dtype())));
+                                  DataTypeString(concat_dim_tensor.dtype())));
     } else {
-      OP_REQUIRES(c, (concat_dim_tensor->dtype() == DT_INT32),
+      OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32),
                   errors::InvalidArgument(
-                      axis_attribute_name, " tensor should be int32, but got ",
-                      DataTypeString(concat_dim_tensor->dtype())));
+                      axis_attribute_name_, " tensor should be int32, but got ",
+                      DataTypeString(concat_dim_tensor.dtype())));
     }
-    if (concat_dim_tensor->dtype() == DT_INT32) {
+    if (concat_dim_tensor.dtype() == DT_INT32) {
       concat_dim =
-          internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
+          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
     } else {
       concat_dim =
-          internal::SubtleMustCopy(concat_dim_tensor->scalar<int64>()());
+          internal::SubtleMustCopy(concat_dim_tensor.scalar<int64>()());
     }
 
-    OpInputList values;
-    OP_REQUIRES_OK(c, c->input_list("values", &values));
-    const int N = values.size();
-    const int input_dims = values[0].dims();
-    const TensorShape& input_shape = values[0].shape();
+    const int N = values_input_end_index_ - values_input_start_index_;
+    const Tensor& first_input = c->input(values_input_start_index_);
+    const int input_dims = first_input.dims();
+    const TensorShape& input_shape = first_input.shape();
 
     int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
     // concat_dim==0 allows concatenating a list of scalars into a vector.
@@ -116,7 +123,7 @@ class ConcatBaseOp : public OpKernel {
     }
     int64 output_concat_dim = 0;
     for (int i = 0; i < N; ++i) {
-      const auto& in = values[i];
+      const auto& in = c->input(values_input_start_index_ + i);
       OP_REQUIRES(
           c, in.dims() == input_dims,
           errors::InvalidArgument(
@@ -137,7 +144,7 @@ class ConcatBaseOp : public OpKernel {
       if (in.NumElements() > 0) {
         int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
-            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
+            in.template shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
       }
       // TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
       output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
@@ -170,6 +177,12 @@ class ConcatBaseOp : public OpKernel {
       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
     }
   }
+
+ private:
+  const char* const axis_attribute_name_;
+  int axis_input_index_;
+  int values_input_start_index_;
+  int values_input_end_index_;
 };
 
 template <typename Device, typename T>
@@ -44,12 +44,9 @@ class SelectOp : public OpKernel {
   explicit SelectOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* ctx) override {
-    const Tensor* cond;
-    const Tensor* then;
-    const Tensor* else_;
-    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
-    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
-    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+    const Tensor* cond = &ctx->input(0);
+    const Tensor* then = &ctx->input(1);
+    const Tensor* else_ = &ctx->input(2);
 
     if (TensorShapeUtils::IsScalar(cond->shape())) {
       ComputeScalar(ctx, cond, then, else_);
@@ -149,12 +146,9 @@ class SelectV2Op : public OpKernel {
   explicit SelectV2Op(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* ctx) override {
-    const Tensor* cond;
-    const Tensor* then;
-    const Tensor* else_;
-    OP_REQUIRES_OK(ctx, ctx->input("condition", &cond));
-    OP_REQUIRES_OK(ctx, ctx->input("t", &then));
-    OP_REQUIRES_OK(ctx, ctx->input("e", &else_));
+    const Tensor* cond = &ctx->input(0);
+    const Tensor* then = &ctx->input(1);
+    const Tensor* else_ = &ctx->input(2);
 
     // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()),
     // This matches the behavior of numpy.
@@ -260,7 +254,6 @@ class SelectV2Op : public OpKernel {
             ctx->input(1).shape().DebugString(), " is not supported yet."));
         break;
     }
-    return;
   }
 
  private:
@@ -50,20 +50,10 @@ class PackOp : public OpKernel {
   }
 
   void Compute(OpKernelContext* c) override {
-    OpInputList values;
-    OP_REQUIRES_OK(c, c->input_list("values", &values));
-    const int num = values.size();
-
-    // Verify that all input shapes match
-    for (int i = 1; i < num; i++) {
-      OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
-                  errors::InvalidArgument(
-                      "Shapes of all inputs must match: values[0].shape = ",
-                      values[0].shape().DebugString(), " != values[", i,
-                      "].shape = ", values[i].shape().DebugString()));
-    }
-
-    int expanded_num_dims = values[0].dims() + 1;
+    const int num = num_inputs();
+    const Tensor& first_input = c->input(0);
+
+    int expanded_num_dims = first_input.dims() + 1;
     int axis = axis_;
     if (axis < 0) axis += expanded_num_dims;
 
@@ -72,13 +62,13 @@ class PackOp : public OpKernel {
                                         -expanded_num_dims, ", ",
                                         expanded_num_dims, ")"));
 
-    TensorShape output_shape(values[0].shape());
+    TensorShape output_shape(first_input.shape());
     output_shape.InsertDim(axis, num);
 
     // In the num = 1 case, just reshape the input
     if (num == 1) {
       Tensor output;
-      CHECK(output.CopyFrom(values[0], output_shape));
+      CHECK(output.CopyFrom(first_input, output_shape));
       c->set_output(0, output);
       return;
     }
@@ -109,8 +99,15 @@ class PackOp : public OpKernel {
     ConstMatrixVector inputs_flat;
     inputs_flat.reserve(num);
     for (int i = 0; i < num; ++i) {
+      const Tensor& input = c->input(i);
+      OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()),
+                  errors::InvalidArgument(
+                      "Shapes of all inputs must match: values[0].shape = ",
+                      first_input.shape().DebugString(), " != values[", i,
+                      "].shape = ", input.shape().DebugString()));
+
       inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
-          values[i].shaped<T, 2>({before_dim, after_dim})));
+          input.shaped<T, 2>({before_dim, after_dim})));
     }
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
     if (std::is_same<Device, GPUDevice>::value) {