Fix SplitV support.

PiperOrigin-RevId: 253155146
This commit is contained in:
Benoit Jacob 2019-06-13 20:00:06 -07:00 committed by TensorFlower Gardener
parent 81e6aec1bc
commit 517381ecd9
3 changed files with 18 additions and 17 deletions

View File

@ -1145,7 +1145,8 @@ void ConvertSplitOperator(const Model& model,
for (const auto& input : src_op.inputs) { for (const auto& input : src_op.inputs) {
*split_op->add_input() = input; *split_op->add_input() = input;
} }
(*split_op->mutable_attr())["T"].set_type(DT_FLOAT); (*split_op->mutable_attr())["T"].set_type(
GetTensorFlowDataType(model, src_op.outputs[0]));
(*split_op->mutable_attr())["num_split"].set_i(src_op.num_split); (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
const auto& split_dim_array = model.GetArray(src_op.inputs[0]); const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
CHECK(split_dim_array.buffer); CHECK(split_dim_array.buffer);
@ -1168,17 +1169,11 @@ void ConvertSplitVOperator(const Model& model,
*split_v_op->add_input() = input; *split_v_op->add_input() = input;
} }
(*split_v_op->mutable_attr())["T"].set_type( (*split_v_op->mutable_attr())["T"].set_type(
GetTensorFlowDataType(model, src_op.inputs[0])); GetTensorFlowDataType(model, src_op.outputs[0]));
(*split_v_op->mutable_attr())["Tlen"].set_type(
GetTensorFlowDataType(model, src_op.inputs[1]));
(*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split); (*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
const auto& split_dim_array = model.GetArray(src_op.inputs[1]); ConvertIntTensorConst(model, src_op.inputs[1], tensorflow_graph);
CHECK(split_dim_array.buffer);
CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
const auto& split_dim_data =
split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
CHECK_EQ(split_dim_data.size(), 1);
const int split_dim = split_dim_data[0];
CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
tensorflow_graph);
} }
void ConvertCastOperator(const Model& model, const CastOperator& src_op, void ConvertCastOperator(const Model& model, const CastOperator& src_op,

View File

@ -46,12 +46,12 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kLogistic || type == OperatorType::kSoftmax || type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
type == OperatorType::kLogSoftmax || type == OperatorType::kSlice || type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
type == OperatorType::kResizeBilinear || type == OperatorType::kResizeBilinear ||
type == OperatorType::kSplit || type == OperatorType::kSub || type == OperatorType::kSplit || type == OperatorType::kSplitV ||
type == OperatorType::kSqueeze || type == OperatorType::kPad || type == OperatorType::kSub || type == OperatorType::kSqueeze ||
type == OperatorType::kPadV2 || type == OperatorType::kReshape || type == OperatorType::kPad || type == OperatorType::kPadV2 ||
type == OperatorType::kTanh || type == OperatorType::kMul || type == OperatorType::kReshape || type == OperatorType::kTanh ||
type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum || type == OperatorType::kMul || type == OperatorType::kBatchToSpaceND ||
type == OperatorType::kSpaceToBatchND || type == OperatorType::kSum || type == OperatorType::kSpaceToBatchND ||
type == OperatorType::kSpaceToDepth || type == OperatorType::kSpaceToDepth ||
type == OperatorType::kStridedSlice || type == OperatorType::kStridedSlice ||
type == OperatorType::kDepthToSpace || type == OperatorType::kDepthToSpace ||

View File

@ -2135,6 +2135,12 @@ void AddExtraOutputs(Model* model) {
// Now add operator outputs so that all arrays that are consumed, // Now add operator outputs so that all arrays that are consumed,
// are produced. // are produced.
for (const string& consumed_array : consumed_arrays) { for (const string& consumed_array : consumed_arrays) {
// Test if consumed_array is already the output of some op.
// This has occurred in a model where separate nodes had names of the form
// foo:$i with the same base name foo.
if (GetOpWithOutput(*model, consumed_array)) {
continue;
}
// Split the consumed array name into the form name:output_index. // Split the consumed array name into the form name:output_index.
const std::vector<string>& split = absl::StrSplit(consumed_array, ':'); const std::vector<string>& split = absl::StrSplit(consumed_array, ':');
// If not of the form name:output_index, then this is not an additional // If not of the form name:output_index, then this is not an additional