parent
7109ac9ac6
commit
0e4117f671
@ -716,48 +716,6 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(([1] == output_details[0]['shape']).all())
|
||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||
|
||||
def testRankAndShapeWithResizeInput(self):
|
||||
# This is a regression test to ensure `Rank` and `Shape` ops work well with
|
||||
# the reize input tensor API.
|
||||
input_tensor = array_ops.placeholder(
|
||||
shape=[None, 4, 4, 3], dtype=dtypes.float32)
|
||||
reshaped_tensor = array_ops.reshape(input_tensor, [1, -1])
|
||||
output_rank = array_ops.rank(reshaped_tensor, name='output_rank')
|
||||
output_shape = array_ops.shape(reshaped_tensor, name='output_shape')
|
||||
|
||||
sess = session.Session()
|
||||
converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
|
||||
[output_rank, output_shape])
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
|
||||
def verify_rank_and_shape(expected_shape):
|
||||
expected_rank = len(expected_shape)
|
||||
output_details = interpreter.get_output_details()
|
||||
self.assertEqual(2, len(output_details))
|
||||
self.assertEqual('output_rank', output_details[0]['name'])
|
||||
self.assertEqual(np.int32, output_details[0]['dtype'])
|
||||
self.assertEqual('output_shape', output_details[1]['name'])
|
||||
self.assertEqual(np.int32, output_details[1]['dtype'])
|
||||
self.assertEqual(
|
||||
expected_shape,
|
||||
interpreter.get_tensor(output_details[1]['index']).tolist())
|
||||
self.assertEqual(
|
||||
expected_rank,
|
||||
interpreter.get_tensor(output_details[0]['index']).tolist())
|
||||
|
||||
interpreter.allocate_tensors()
|
||||
interpreter.invoke()
|
||||
verify_rank_and_shape([1, 48])
|
||||
|
||||
input_details = interpreter.get_input_details()
|
||||
interpreter.resize_tensor_input(input_details[0]['index'], [1, 4, 4, 4])
|
||||
interpreter.allocate_tensors()
|
||||
interpreter.invoke()
|
||||
verify_rank_and_shape([1, 64])
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||
|
@ -199,7 +199,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRankOnlyForConstantInput)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
|
||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
|
||||
|
@ -1867,9 +1867,7 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
|
||||
CHECK_EQ(op->inputs.size(), 4);
|
||||
|
||||
const Array& output_shape_array = model->GetArray(op->inputs[1]);
|
||||
// Return if the shape array has unknown shape or has no data.
|
||||
if (!output_shape_array.has_shape()) return;
|
||||
if (!output_shape_array.buffer) return;
|
||||
CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
|
||||
|
||||
// Output should not go over four dimensions.
|
||||
|
@ -19,11 +19,9 @@ limitations under the License.
|
||||
|
||||
namespace toco {
|
||||
|
||||
namespace {
|
||||
|
||||
::tensorflow::Status ResolveConstantShapeOrRankImpl(
|
||||
Model* model, std::size_t op_index, bool only_for_constant_input,
|
||||
bool* modified) {
|
||||
::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
|
||||
std::size_t op_index,
|
||||
bool* modified) {
|
||||
*modified = false;
|
||||
const auto it = model->operators.begin() + op_index;
|
||||
const auto* op = it->get();
|
||||
@ -44,10 +42,6 @@ namespace {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
if (only_for_constant_input && !input_array.buffer) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
if (!output_array.has_shape()) {
|
||||
// Yield until the output shape has been resolved.
|
||||
return ::tensorflow::Status::OK();
|
||||
@ -78,23 +72,4 @@ namespace {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
|
||||
std::size_t op_index,
|
||||
bool* modified) {
|
||||
return ResolveConstantShapeOrRankImpl(model, op_index,
|
||||
|
||||
/*only_for_constant_input=*/false,
|
||||
modified
|
||||
|
||||
);
|
||||
}
|
||||
|
||||
::tensorflow::Status ResolveConstantShapeOrRankOnlyForConstantInput::Run(
|
||||
Model* model, std::size_t op_index, bool* modified) {
|
||||
return ResolveConstantShapeOrRankImpl(
|
||||
model, op_index, /*only_for_constant_input=*/true, modified);
|
||||
}
|
||||
|
||||
} // namespace toco
|
||||
|
@ -115,6 +115,7 @@ void MakeGeneralGraphTransformationsSet(
|
||||
transformations->Add(new ResolveStridedSliceAttributes);
|
||||
transformations->Add(new ResolveSliceAttributes);
|
||||
transformations->Add(new ResolveReduceAttributes);
|
||||
transformations->Add(new ResolveConstantShapeOrRank);
|
||||
transformations->Add(new MakeInitialDequantizeOperator);
|
||||
transformations->Add(new UnpartitionEmbeddingLookup);
|
||||
transformations->Add(new ResolveGatherAttributes);
|
||||
@ -263,12 +264,6 @@ tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
|
||||
|
||||
GraphTransformationsSet transformations;
|
||||
MakeGeneralGraphTransformationsSet(&transformations);
|
||||
|
||||
if (output_format == TFLITE) {
|
||||
transformations.Add(new ResolveConstantShapeOrRankOnlyForConstantInput);
|
||||
} else {
|
||||
transformations.Add(new ResolveConstantShapeOrRank);
|
||||
}
|
||||
auto* remove_trivial_reshape = new RemoveTrivialReshape;
|
||||
transformations.Add(remove_trivial_reshape);
|
||||
auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
|
||||
|
Loading…
Reference in New Issue
Block a user