Added support of runtime weights for convolution transposed in model_builder.
PiperOrigin-RevId: 342977892 Change-Id: I9f07336c5dc9adc431626d54bdeb64a07cf99761
This commit is contained in:
parent
28c5e97db9
commit
8117e74787
@ -2161,18 +2161,20 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
|
||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
|
||||
const int runtime_inputs =
|
||||
GetNumberOfRuntimeInputsForNode(context, tflite_node);
|
||||
if (runtime_inputs != 1) {
|
||||
if (runtime_inputs > 2) {
|
||||
return absl::InternalError(
|
||||
absl::StrCat("Expected 1 runtime input tensor, but node has ",
|
||||
absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
|
||||
runtime_inputs, " runtime inputs."));
|
||||
}
|
||||
const int runtime_outputs = NumOutputs(tflite_node);
|
||||
if (runtime_outputs != 1) {
|
||||
return absl::InternalError(
|
||||
absl::StrCat("Expected 1 runtime output tensor, but node has ",
|
||||
absl::StrCat("Expected 1 output tensor(s), but node has ",
|
||||
runtime_outputs, " runtime outputs."));
|
||||
}
|
||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||
if (runtime_inputs == 1) {
|
||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||
}
|
||||
const TfLiteTransposeConvParams* tf_options;
|
||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||
RETURN_IF_ERROR(
|
||||
@ -2200,7 +2202,15 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
|
||||
attr.stride = tf_options
|
||||
? HW(tf_options->stride_height, tf_options->stride_width)
|
||||
: HW(1, 1);
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
|
||||
const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
|
||||
if (runtime_inputs == 2) {
|
||||
RETURN_IF_ERROR(reader->AddInput(node, 1));
|
||||
auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
|
||||
attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
|
||||
weights_shape.w, weights_shape.c);
|
||||
} else { // runtime_inputs == 1;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
|
||||
}
|
||||
reader->ReadTensor(3, &attr.bias).IgnoreError(); // bias is optional
|
||||
|
||||
UpdatePadding(tf_options->padding,
|
||||
|
||||
@ -37,6 +37,10 @@ class ConvolutionTransposedBuffers : public NodeShader {
|
||||
public:
|
||||
absl::Status GenerateCode(const GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) const final {
|
||||
if (ctx.input_shapes.size() != 1) {
|
||||
return absl::UnimplementedError(
|
||||
"Convolution Transposed does not support more than 1 runtime tensor");
|
||||
}
|
||||
const auto& attr =
|
||||
absl::any_cast<const ConvolutionTransposedAttributes&>(ctx.op_attr);
|
||||
auto weights = attr.weights.shape;
|
||||
|
||||
@ -260,6 +260,11 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
break;
|
||||
}
|
||||
case OperationType::CONVOLUTION_TRANSPOSED:
|
||||
if (graph.FindInputs(node->id).size() != 1) {
|
||||
return absl::UnimplementedError(
|
||||
"Convolution Transposed does not support more than 1 runtime "
|
||||
"tensor");
|
||||
}
|
||||
*tasks = SelectConvolutionTransposed(
|
||||
node_id, inputs[0], outputs[0],
|
||||
absl::any_cast<ConvolutionTransposedAttributes>(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user