Fix SplitV support.
PiperOrigin-RevId: 253155146
This commit is contained in:
parent
81e6aec1bc
commit
517381ecd9
@ -1145,7 +1145,8 @@ void ConvertSplitOperator(const Model& model,
|
||||
for (const auto& input : src_op.inputs) {
|
||||
*split_op->add_input() = input;
|
||||
}
|
||||
(*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
|
||||
(*split_op->mutable_attr())["T"].set_type(
|
||||
GetTensorFlowDataType(model, src_op.outputs[0]));
|
||||
(*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
|
||||
const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
|
||||
CHECK(split_dim_array.buffer);
|
||||
@ -1168,17 +1169,11 @@ void ConvertSplitVOperator(const Model& model,
|
||||
*split_v_op->add_input() = input;
|
||||
}
|
||||
(*split_v_op->mutable_attr())["T"].set_type(
|
||||
GetTensorFlowDataType(model, src_op.inputs[0]));
|
||||
GetTensorFlowDataType(model, src_op.outputs[0]));
|
||||
(*split_v_op->mutable_attr())["Tlen"].set_type(
|
||||
GetTensorFlowDataType(model, src_op.inputs[1]));
|
||||
(*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
|
||||
const auto& split_dim_array = model.GetArray(src_op.inputs[1]);
|
||||
CHECK(split_dim_array.buffer);
|
||||
CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
|
||||
const auto& split_dim_data =
|
||||
split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
|
||||
CHECK_EQ(split_dim_data.size(), 1);
|
||||
const int split_dim = split_dim_data[0];
|
||||
CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
|
||||
tensorflow_graph);
|
||||
ConvertIntTensorConst(model, src_op.inputs[1], tensorflow_graph);
|
||||
}
|
||||
|
||||
void ConvertCastOperator(const Model& model, const CastOperator& src_op,
|
||||
|
@ -46,12 +46,12 @@ bool SupportsQuantization(const Operator& op) {
|
||||
type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
|
||||
type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
|
||||
type == OperatorType::kResizeBilinear ||
|
||||
type == OperatorType::kSplit || type == OperatorType::kSub ||
|
||||
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
|
||||
type == OperatorType::kPadV2 || type == OperatorType::kReshape ||
|
||||
type == OperatorType::kTanh || type == OperatorType::kMul ||
|
||||
type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum ||
|
||||
type == OperatorType::kSpaceToBatchND ||
|
||||
type == OperatorType::kSplit || type == OperatorType::kSplitV ||
|
||||
type == OperatorType::kSub || type == OperatorType::kSqueeze ||
|
||||
type == OperatorType::kPad || type == OperatorType::kPadV2 ||
|
||||
type == OperatorType::kReshape || type == OperatorType::kTanh ||
|
||||
type == OperatorType::kMul || type == OperatorType::kBatchToSpaceND ||
|
||||
type == OperatorType::kSum || type == OperatorType::kSpaceToBatchND ||
|
||||
type == OperatorType::kSpaceToDepth ||
|
||||
type == OperatorType::kStridedSlice ||
|
||||
type == OperatorType::kDepthToSpace ||
|
||||
|
@ -2135,6 +2135,12 @@ void AddExtraOutputs(Model* model) {
|
||||
// Now add operator outputs so that all arrays that are consumed,
|
||||
// are produced.
|
||||
for (const string& consumed_array : consumed_arrays) {
|
||||
// Test if consumed_array is already the output of some op.
|
||||
// This has occurred in a model where separate nodes had names of the form
|
||||
// foo:$i with the same base name foo.
|
||||
if (GetOpWithOutput(*model, consumed_array)) {
|
||||
continue;
|
||||
}
|
||||
// Split the consumed array name into the form name:output_index.
|
||||
const std::vector<string>& split = absl::StrSplit(consumed_array, ':');
|
||||
// If not of the form name:output_index, then this is not an additional
|
||||
|
Loading…
Reference in New Issue
Block a user