diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index f223a5c0128..9b2f77e40ad 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -2161,18 +2161,20 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
     const int runtime_inputs =
         GetNumberOfRuntimeInputsForNode(context, tflite_node);
-    if (runtime_inputs != 1) {
+    if (runtime_inputs > 2) {
       return absl::InternalError(
-          absl::StrCat("Expected 1 runtime input tensor, but node has ",
+          absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
                        runtime_inputs, " runtime inputs."));
     }
     const int runtime_outputs = NumOutputs(tflite_node);
     if (runtime_outputs != 1) {
       return absl::InternalError(
-          absl::StrCat("Expected 1 runtime output tensor, but node has ",
+          absl::StrCat("Expected 1 output tensor(s), but node has ",
                        runtime_outputs, " runtime outputs."));
     }
-    RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
+    if (runtime_inputs == 1) {
+      RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
+    }
     const TfLiteTransposeConvParams* tf_options;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     RETURN_IF_ERROR(
@@ -2200,7 +2202,15 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
     attr.stride = tf_options
                       ? HW(tf_options->stride_height, tf_options->stride_width)
                       : HW(1, 1);
-    RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
+    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
+    if (runtime_inputs == 2) {
+      RETURN_IF_ERROR(reader->AddInput(node, 1));
+      auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
+      attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
+                                weights_shape.w, weights_shape.c);
+    } else {  // runtime_inputs == 1;
+      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
+    }
     reader->ReadTensor(3, &attr.bias).IgnoreError();  // bias is optional
     UpdatePadding(tf_options->padding,
                   graph->FindInputs(node->id)[0]->tensor.shape, &attr);
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc
index 4b848435df6..91603bd5b3f 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.cc
@@ -37,6 +37,10 @@ class ConvolutionTransposedBuffers : public NodeShader {
  public:
   absl::Status GenerateCode(const GenerationContext& ctx,
                             GeneratedCode* generated_code) const final {
+    if (ctx.input_shapes.size() != 1) {
+      return absl::UnimplementedError(
+          "Convolution Transposed does not support more than 1 runtime tensor");
+    }
     const auto& attr =
         absl::any_cast<const ConvolutionTransposedAttributes&>(ctx.op_attr);
     auto weights = attr.weights.shape;
diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc
index 561b0828cd3..8062fb16d7c 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -260,6 +260,11 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
       break;
     }
     case OperationType::CONVOLUTION_TRANSPOSED:
+      if (graph.FindInputs(node->id).size() != 1) {
+        return absl::UnimplementedError(
+            "Convolution Transposed does not support more than 1 runtime "
+            "tensor");
+      }
      *tasks = SelectConvolutionTransposed(
           node_id, inputs[0], outputs[0],
           absl::any_cast<ConvolutionTransposedAttributes>(