Fix SplitV support.

PiperOrigin-RevId: 253155146
Benoit Jacob 2019-06-13 20:00:06 -07:00 committed by TensorFlower Gardener
parent 81e6aec1bc
commit 517381ecd9
3 changed files with 18 additions and 17 deletions

@@ -1145,7 +1145,8 @@ void ConvertSplitOperator(const Model& model,
   for (const auto& input : src_op.inputs) {
     *split_op->add_input() = input;
   }
-  (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
+  (*split_op->mutable_attr())["T"].set_type(
+      GetTensorFlowDataType(model, src_op.outputs[0]));
   (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
   const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
   CHECK(split_dim_array.buffer);
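
The change above stops hard-coding the Split node's "T" attribute to DT_FLOAT and instead derives it from the data type of the op's first output, so a non-float graph (for example a quantized uint8 model) exports with the correct element type. A minimal sketch of the idea, assuming only the TensorFlow NodeDef/AttrValue protos; MakeSplitNodeDef is a hypothetical helper for illustration, not toco code:

// Sketch: build a Split NodeDef whose "T" attribute comes from a
// caller-supplied dtype rather than a hard-coded DT_FLOAT.
#include <string>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"

tensorflow::NodeDef MakeSplitNodeDef(const std::string& name,
                                     tensorflow::DataType dtype,  // e.g. DT_UINT8 in a quantized graph
                                     int num_split) {
  tensorflow::NodeDef node;
  node.set_op("Split");
  node.set_name(name);
  (*node.mutable_attr())["T"].set_type(dtype);           // previously always DT_FLOAT
  (*node.mutable_attr())["num_split"].set_i(num_split);
  return node;
}
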
@@ -1168,17 +1169,11 @@ void ConvertSplitVOperator(const Model& model,
     *split_v_op->add_input() = input;
   }
   (*split_v_op->mutable_attr())["T"].set_type(
-      GetTensorFlowDataType(model, src_op.inputs[0]));
+      GetTensorFlowDataType(model, src_op.outputs[0]));
+  (*split_v_op->mutable_attr())["Tlen"].set_type(
+      GetTensorFlowDataType(model, src_op.inputs[1]));
   (*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
-  const auto& split_dim_array = model.GetArray(src_op.inputs[1]);
-  CHECK(split_dim_array.buffer);
-  CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
-  const auto& split_dim_data =
-      split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
-  CHECK_EQ(split_dim_data.size(), 1);
-  const int split_dim = split_dim_data[0];
-  CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
-                                  tensorflow_graph);
+  ConvertIntTensorConst(model, src_op.inputs[1], tensorflow_graph);
 }
 
 void ConvertCastOperator(const Model& model, const CastOperator& src_op,
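
The deleted block mirrored ConvertSplitOperator, but the two ops order their inputs differently: Split takes (axis, value), while SplitV takes (value, size_splits, axis). Reading a scalar split dimension out of inputs[1] (which for SplitV is the size_splits vector) and emitting a dummy axis constant for inputs[0] (the value tensor) was therefore wrong; the fixed code derives "T" from the output, sets "Tlen" for the size_splits input, and exports the size_splits constant as-is. A short sketch of the two input layouts, assuming the TensorFlow C++ ops API (the values are illustrative only, not toco code):

// Sketch: argument order of Split vs. SplitV in the TF C++ ops API.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"

void SplitVersusSplitV() {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto value = tensorflow::ops::Const(scope, {{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}});  // shape [1, 6]
  auto axis = tensorflow::ops::Const(scope, 1);

  // Split: (axis, value); all output pieces have equal size.
  auto split = tensorflow::ops::Split(scope, axis, value, /*num_split=*/3);

  // SplitV: (value, size_splits, axis); pieces may have different sizes.
  auto size_splits = tensorflow::ops::Const(scope, {2, 4});
  auto split_v =
      tensorflow::ops::SplitV(scope, value, size_splits, axis, /*num_split=*/2);
}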

@@ -46,12 +46,12 @@ bool SupportsQuantization(const Operator& op) {
          type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
          type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
          type == OperatorType::kResizeBilinear ||
-         type == OperatorType::kSplit || type == OperatorType::kSub ||
-         type == OperatorType::kSqueeze || type == OperatorType::kPad ||
-         type == OperatorType::kPadV2 || type == OperatorType::kReshape ||
-         type == OperatorType::kTanh || type == OperatorType::kMul ||
-         type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum ||
-         type == OperatorType::kSpaceToBatchND ||
+         type == OperatorType::kSplit || type == OperatorType::kSplitV ||
+         type == OperatorType::kSub || type == OperatorType::kSqueeze ||
+         type == OperatorType::kPad || type == OperatorType::kPadV2 ||
+         type == OperatorType::kReshape || type == OperatorType::kTanh ||
+         type == OperatorType::kMul || type == OperatorType::kBatchToSpaceND ||
+         type == OperatorType::kSum || type == OperatorType::kSpaceToBatchND ||
          type == OperatorType::kSpaceToDepth ||
          type == OperatorType::kStridedSlice ||
          type == OperatorType::kDepthToSpace ||
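
The only functional change in this hunk is that OperatorType::kSplitV joins the whitelist of operators SupportsQuantization accepts; the surrounding lines are merely re-wrapped by the formatter. As a side note on the pattern (not something the commit changes), a long chain of equality tests like this can also be written as a set lookup so that a new entry touches a single line. A sketch with a stand-in enum, not toco's actual OperatorType:

// Sketch only: the whitelist as a set lookup. OpKind stands in for toco's
// OperatorType and the list is deliberately abbreviated.
#include <set>

enum class OpKind { kSplit, kSplitV, kSub, kSqueeze, kPad, kPadV2, kReshape, kTanh, kMul };

bool SupportsQuantizationSketch(OpKind type) {
  static const std::set<OpKind>* const kQuantizableOps = new std::set<OpKind>{
      OpKind::kSplit, OpKind::kSplitV, OpKind::kSub,     OpKind::kSqueeze,
      OpKind::kPad,   OpKind::kPadV2,  OpKind::kReshape, OpKind::kTanh,
      OpKind::kMul,
  };
  return kQuantizableOps->count(type) > 0;
}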

@@ -2135,6 +2135,12 @@ void AddExtraOutputs(Model* model) {
   // Now add operator outputs so that all arrays that are consumed,
   // are produced.
   for (const string& consumed_array : consumed_arrays) {
+    // Test if consumed_array is already the output of some op.
+    // This has occurred in a model where separate nodes had names of the form
+    // foo:$i with the same base name foo.
+    if (GetOpWithOutput(*model, consumed_array)) {
+      continue;
+    }
     // Split the consumed array name into the form name:output_index.
     const std::vector<string>& split = absl::StrSplit(consumed_array, ':');
     // If not of the form name:output_index, then this is not an additional