Added support of runtime weights for convolution transposed in model_builder.

PiperOrigin-RevId: 342977892
Change-Id: I9f07336c5dc9adc431626d54bdeb64a07cf99761
This commit is contained in:
Raman Sarokin 2020-11-17 17:10:21 -08:00 committed by TensorFlower Gardener
parent 28c5e97db9
commit 8117e74787
3 changed files with 24 additions and 5 deletions

View File

@ -2161,18 +2161,20 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
const int runtime_inputs =
GetNumberOfRuntimeInputsForNode(context, tflite_node);
if (runtime_inputs != 1) {
if (runtime_inputs > 2) {
return absl::InternalError(
absl::StrCat("Expected 1 runtime input tensor, but node has ",
absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
runtime_inputs, " runtime inputs."));
}
const int runtime_outputs = NumOutputs(tflite_node);
if (runtime_outputs != 1) {
return absl::InternalError(
absl::StrCat("Expected 1 runtime output tensor, but node has ",
absl::StrCat("Expected 1 output tensor(s), but node has ",
runtime_outputs, " runtime outputs."));
}
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
if (runtime_inputs == 1) {
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
}
const TfLiteTransposeConvParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(
@ -2200,7 +2202,15 @@ class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
attr.stride = tf_options
? HW(tf_options->stride_height, tf_options->stride_width)
: HW(1, 1);
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
if (runtime_inputs == 2) {
RETURN_IF_ERROR(reader->AddInput(node, 1));
auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
weights_shape.w, weights_shape.c);
} else { // runtime_inputs == 1;
RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
}
reader->ReadTensor(3, &attr.bias).IgnoreError(); // bias is optional
UpdatePadding(tf_options->padding,

View File

@ -37,6 +37,10 @@ class ConvolutionTransposedBuffers : public NodeShader {
public:
absl::Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
if (ctx.input_shapes.size() != 1) {
return absl::UnimplementedError(
"Convolution Transposed does not support more than 1 runtime tensor");
}
const auto& attr =
absl::any_cast<const ConvolutionTransposedAttributes&>(ctx.op_attr);
auto weights = attr.weights.shape;

View File

@ -260,6 +260,11 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
break;
}
case OperationType::CONVOLUTION_TRANSPOSED:
if (graph.FindInputs(node->id).size() != 1) {
return absl::UnimplementedError(
"Convolution Transposed does not support more than 1 runtime "
"tensor");
}
*tasks = SelectConvolutionTransposed(
node_id, inputs[0], outputs[0],
absl::any_cast<ConvolutionTransposedAttributes>(