parent
7109ac9ac6
commit
0e4117f671
@ -716,48 +716,6 @@ class FromSessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(([1] == output_details[0]['shape']).all())
|
self.assertTrue(([1] == output_details[0]['shape']).all())
|
||||||
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
self.assertEqual((0., 0.), output_details[0]['quantization'])
|
||||||
|
|
||||||
def testRankAndShapeWithResizeInput(self):
|
|
||||||
# This is a regression test to ensure `Rank` and `Shape` ops work well with
|
|
||||||
# the reize input tensor API.
|
|
||||||
input_tensor = array_ops.placeholder(
|
|
||||||
shape=[None, 4, 4, 3], dtype=dtypes.float32)
|
|
||||||
reshaped_tensor = array_ops.reshape(input_tensor, [1, -1])
|
|
||||||
output_rank = array_ops.rank(reshaped_tensor, name='output_rank')
|
|
||||||
output_shape = array_ops.shape(reshaped_tensor, name='output_shape')
|
|
||||||
|
|
||||||
sess = session.Session()
|
|
||||||
converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
|
|
||||||
[output_rank, output_shape])
|
|
||||||
tflite_model = converter.convert()
|
|
||||||
self.assertTrue(tflite_model)
|
|
||||||
|
|
||||||
interpreter = Interpreter(model_content=tflite_model)
|
|
||||||
|
|
||||||
def verify_rank_and_shape(expected_shape):
|
|
||||||
expected_rank = len(expected_shape)
|
|
||||||
output_details = interpreter.get_output_details()
|
|
||||||
self.assertEqual(2, len(output_details))
|
|
||||||
self.assertEqual('output_rank', output_details[0]['name'])
|
|
||||||
self.assertEqual(np.int32, output_details[0]['dtype'])
|
|
||||||
self.assertEqual('output_shape', output_details[1]['name'])
|
|
||||||
self.assertEqual(np.int32, output_details[1]['dtype'])
|
|
||||||
self.assertEqual(
|
|
||||||
expected_shape,
|
|
||||||
interpreter.get_tensor(output_details[1]['index']).tolist())
|
|
||||||
self.assertEqual(
|
|
||||||
expected_rank,
|
|
||||||
interpreter.get_tensor(output_details[0]['index']).tolist())
|
|
||||||
|
|
||||||
interpreter.allocate_tensors()
|
|
||||||
interpreter.invoke()
|
|
||||||
verify_rank_and_shape([1, 48])
|
|
||||||
|
|
||||||
input_details = interpreter.get_input_details()
|
|
||||||
interpreter.resize_tensor_input(input_details[0]['index'], [1, 4, 4, 4])
|
|
||||||
interpreter.allocate_tensors()
|
|
||||||
interpreter.invoke()
|
|
||||||
verify_rank_and_shape([1, 64])
|
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_v1_only('b/120545219')
|
||||||
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
|
||||||
|
@ -199,7 +199,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack)
|
|||||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
|
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
|
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
|
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRankOnlyForConstantInput)
|
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice)
|
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice)
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
|
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
|
||||||
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
|
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
|
||||||
|
@ -1867,9 +1867,7 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
|
|||||||
CHECK_EQ(op->inputs.size(), 4);
|
CHECK_EQ(op->inputs.size(), 4);
|
||||||
|
|
||||||
const Array& output_shape_array = model->GetArray(op->inputs[1]);
|
const Array& output_shape_array = model->GetArray(op->inputs[1]);
|
||||||
// Return if the shape array has unknown shape or has no data.
|
|
||||||
if (!output_shape_array.has_shape()) return;
|
if (!output_shape_array.has_shape()) return;
|
||||||
if (!output_shape_array.buffer) return;
|
|
||||||
CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
|
CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
|
||||||
|
|
||||||
// Output should not go over four dimensions.
|
// Output should not go over four dimensions.
|
||||||
|
@ -19,11 +19,9 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace toco {
|
namespace toco {
|
||||||
|
|
||||||
namespace {
|
::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
|
||||||
|
std::size_t op_index,
|
||||||
::tensorflow::Status ResolveConstantShapeOrRankImpl(
|
bool* modified) {
|
||||||
Model* model, std::size_t op_index, bool only_for_constant_input,
|
|
||||||
bool* modified) {
|
|
||||||
*modified = false;
|
*modified = false;
|
||||||
const auto it = model->operators.begin() + op_index;
|
const auto it = model->operators.begin() + op_index;
|
||||||
const auto* op = it->get();
|
const auto* op = it->get();
|
||||||
@ -44,10 +42,6 @@ namespace {
|
|||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (only_for_constant_input && !input_array.buffer) {
|
|
||||||
return ::tensorflow::Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!output_array.has_shape()) {
|
if (!output_array.has_shape()) {
|
||||||
// Yield until the output shape has been resolved.
|
// Yield until the output shape has been resolved.
|
||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
@ -78,23 +72,4 @@ namespace {
|
|||||||
return ::tensorflow::Status::OK();
|
return ::tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
|
|
||||||
std::size_t op_index,
|
|
||||||
bool* modified) {
|
|
||||||
return ResolveConstantShapeOrRankImpl(model, op_index,
|
|
||||||
|
|
||||||
/*only_for_constant_input=*/false,
|
|
||||||
modified
|
|
||||||
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
::tensorflow::Status ResolveConstantShapeOrRankOnlyForConstantInput::Run(
|
|
||||||
Model* model, std::size_t op_index, bool* modified) {
|
|
||||||
return ResolveConstantShapeOrRankImpl(
|
|
||||||
model, op_index, /*only_for_constant_input=*/true, modified);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace toco
|
} // namespace toco
|
||||||
|
@ -115,6 +115,7 @@ void MakeGeneralGraphTransformationsSet(
|
|||||||
transformations->Add(new ResolveStridedSliceAttributes);
|
transformations->Add(new ResolveStridedSliceAttributes);
|
||||||
transformations->Add(new ResolveSliceAttributes);
|
transformations->Add(new ResolveSliceAttributes);
|
||||||
transformations->Add(new ResolveReduceAttributes);
|
transformations->Add(new ResolveReduceAttributes);
|
||||||
|
transformations->Add(new ResolveConstantShapeOrRank);
|
||||||
transformations->Add(new MakeInitialDequantizeOperator);
|
transformations->Add(new MakeInitialDequantizeOperator);
|
||||||
transformations->Add(new UnpartitionEmbeddingLookup);
|
transformations->Add(new UnpartitionEmbeddingLookup);
|
||||||
transformations->Add(new ResolveGatherAttributes);
|
transformations->Add(new ResolveGatherAttributes);
|
||||||
@ -263,12 +264,6 @@ tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
|
|||||||
|
|
||||||
GraphTransformationsSet transformations;
|
GraphTransformationsSet transformations;
|
||||||
MakeGeneralGraphTransformationsSet(&transformations);
|
MakeGeneralGraphTransformationsSet(&transformations);
|
||||||
|
|
||||||
if (output_format == TFLITE) {
|
|
||||||
transformations.Add(new ResolveConstantShapeOrRankOnlyForConstantInput);
|
|
||||||
} else {
|
|
||||||
transformations.Add(new ResolveConstantShapeOrRank);
|
|
||||||
}
|
|
||||||
auto* remove_trivial_reshape = new RemoveTrivialReshape;
|
auto* remove_trivial_reshape = new RemoveTrivialReshape;
|
||||||
transformations.Add(remove_trivial_reshape);
|
transformations.Add(remove_trivial_reshape);
|
||||||
auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
|
auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
|
||||||
|
Loading…
Reference in New Issue
Block a user