Automated rollback of commit a6aa65015e

PiperOrigin-RevId: 243752707
This commit is contained in:
A. Unique TensorFlower 2019-04-15 23:26:25 -07:00 committed by TensorFlower Gardener
parent 7109ac9ac6
commit 0e4117f671
5 changed files with 4 additions and 79 deletions

View File

@ -716,48 +716,6 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
def testRankAndShapeWithResizeInput(self):
# This is a regression test to ensure `Rank` and `Shape` ops work well with
# the reize input tensor API.
input_tensor = array_ops.placeholder(
shape=[None, 4, 4, 3], dtype=dtypes.float32)
reshaped_tensor = array_ops.reshape(input_tensor, [1, -1])
output_rank = array_ops.rank(reshaped_tensor, name='output_rank')
output_shape = array_ops.shape(reshaped_tensor, name='output_shape')
sess = session.Session()
converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
[output_rank, output_shape])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
interpreter = Interpreter(model_content=tflite_model)
def verify_rank_and_shape(expected_shape):
expected_rank = len(expected_shape)
output_details = interpreter.get_output_details()
self.assertEqual(2, len(output_details))
self.assertEqual('output_rank', output_details[0]['name'])
self.assertEqual(np.int32, output_details[0]['dtype'])
self.assertEqual('output_shape', output_details[1]['name'])
self.assertEqual(np.int32, output_details[1]['dtype'])
self.assertEqual(
expected_shape,
interpreter.get_tensor(output_details[1]['index']).tolist())
self.assertEqual(
expected_rank,
interpreter.get_tensor(output_details[0]['index']).tolist())
interpreter.allocate_tensors()
interpreter.invoke()
verify_rank_and_shape([1, 48])
input_details = interpreter.get_input_details()
interpreter.resize_tensor_input(input_details[0]['index'], [1, 4, 4, 4])
interpreter.allocate_tensors()
interpreter.invoke()
verify_rank_and_shape([1, 64])
@test_util.run_v1_only('b/120545219')
class FromFrozenGraphFile(test_util.TensorFlowTestCase):

View File

@ -199,7 +199,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRankOnlyForConstantInput)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)

View File

@ -1867,9 +1867,7 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
CHECK_EQ(op->inputs.size(), 4);
const Array& output_shape_array = model->GetArray(op->inputs[1]);
// Return if the shape array has unknown shape or has no data.
if (!output_shape_array.has_shape()) return;
if (!output_shape_array.buffer) return;
CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
// Output should not go over four dimensions.

View File

@ -19,11 +19,9 @@ limitations under the License.
namespace toco {
namespace {
::tensorflow::Status ResolveConstantShapeOrRankImpl(
Model* model, std::size_t op_index, bool only_for_constant_input,
bool* modified) {
::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
std::size_t op_index,
bool* modified) {
*modified = false;
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
@ -44,10 +42,6 @@ namespace {
return ::tensorflow::Status::OK();
}
if (only_for_constant_input && !input_array.buffer) {
return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been resolved.
return ::tensorflow::Status::OK();
@ -78,23 +72,4 @@ namespace {
return ::tensorflow::Status::OK();
}
} // namespace
::tensorflow::Status ResolveConstantShapeOrRank::Run(Model* model,
std::size_t op_index,
bool* modified) {
return ResolveConstantShapeOrRankImpl(model, op_index,
/*only_for_constant_input=*/false,
modified
);
}
::tensorflow::Status ResolveConstantShapeOrRankOnlyForConstantInput::Run(
Model* model, std::size_t op_index, bool* modified) {
return ResolveConstantShapeOrRankImpl(
model, op_index, /*only_for_constant_input=*/true, modified);
}
} // namespace toco

View File

@ -115,6 +115,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveStridedSliceAttributes);
transformations->Add(new ResolveSliceAttributes);
transformations->Add(new ResolveReduceAttributes);
transformations->Add(new ResolveConstantShapeOrRank);
transformations->Add(new MakeInitialDequantizeOperator);
transformations->Add(new UnpartitionEmbeddingLookup);
transformations->Add(new ResolveGatherAttributes);
@ -263,12 +264,6 @@ tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
GraphTransformationsSet transformations;
MakeGeneralGraphTransformationsSet(&transformations);
if (output_format == TFLITE) {
transformations.Add(new ResolveConstantShapeOrRankOnlyForConstantInput);
} else {
transformations.Add(new ResolveConstantShapeOrRank);
}
auto* remove_trivial_reshape = new RemoveTrivialReshape;
transformations.Add(remove_trivial_reshape);
auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;