Propagate unknown rank of tensor shape information to MLIR

PiperOrigin-RevId: 342808569
Change-Id: Idbe02af18ba2b8ee7dbc6d65d66cc1f647d62525
This commit is contained in:
Jaesung Chung 2020-11-17 00:34:56 -08:00 committed by TensorFlower Gardener
parent b858de3779
commit 63d973a730
14 changed files with 136 additions and 34 deletions

View File

@ -55,7 +55,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;

View File

@ -128,7 +128,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;

View File

@ -185,7 +185,7 @@ Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs) {
quant_specs->inference_input_type =
@ -210,8 +210,12 @@ Status PopulateQuantizationSpecs(
node_dtypes->push_back(
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
}
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
if (flag.shape().unknown_rank()) {
node_shapes->push_back(llvm::None);
} else {
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
}
// Currently, only UINT8 and INT8 require inputs stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
if (flag.has_mean_value() && flag.has_std_value()) {

View File

@ -41,7 +41,7 @@ Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs);

View File

@ -1,6 +1,8 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck --check-prefix=NONE %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=SOME %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=*:* -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=UNKNOWN %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=?,1,?:1,?,1 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=DYNAMIC %s
node {
name: "Add"
@ -61,3 +63,19 @@ versions {
# NONE-SAME: outputs = "Add"
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# NONE: fetch %[[add]]
# UNKNOWN-LABEL: func @main
# UNKNOWN-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<*xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<*xi32>) -> tensor<*xi32>
# UNKNOWN-SAME: control_outputs = ""
# UNKNOWN-SAME: inputs = "input0,input1"
# UNKNOWN-SAME: outputs = "Add"
# UNKNOWN: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# UNKNOWN: fetch %[[add]]
# DYNAMIC-LABEL: func @main
# DYNAMIC-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<?x1x?xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<1x?x1xi32>) -> tensor<*xi32>
# DYNAMIC-SAME: control_outputs = ""
# DYNAMIC-SAME: inputs = "input0,input1"
# DYNAMIC-SAME: outputs = "Add"
# DYNAMIC: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# DYNAMIC: fetch %[[add]]

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "llvm/ADT/Optional.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
@ -54,17 +55,18 @@ Status ParseInputArrayInfo(absl::string_view array_names,
GraphImportConfig::InputArrays* inputs) {
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names));
TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes));
TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes));
return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs);
}
Status ParseInputArrayInfo(const std::vector<string>& node_names,
const std::vector<string>& node_dtypes,
const std::vector<std::vector<int>>& node_shapes,
GraphImportConfig::InputArrays* inputs) {
Status ParseInputArrayInfo(
const std::vector<string>& node_names,
const std::vector<string>& node_dtypes,
const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
GraphImportConfig::InputArrays* inputs) {
std::vector<std::string> used_node_dtypes;
if (node_dtypes.empty()) {
// Mark all the node dtypes Invalid, so the importer can handle them by
@ -110,7 +112,11 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
}
if (!node_shapes.empty()) {
for (auto& dim : node_shapes[i]) {
if (!node_shapes[i].hasValue()) {
info.shape.set_unknown_rank(true);
continue;
}
for (auto& dim : node_shapes[i].getValue()) {
info.shape.add_dim()->set_size(dim);
}
}
@ -118,17 +124,26 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
return Status::OK();
}
Status ParseNodeShapes(absl::string_view shapes_str,
std::vector<std::vector<int>>& shapes_vector) {
Status ParseNodeShapes(
absl::string_view shapes_str,
std::vector<llvm::Optional<std::vector<int>>>& shapes_vector) {
shapes_vector.clear();
if (!shapes_str.empty()) {
std::vector<string> node_shapes_str = absl::StrSplit(shapes_str, ':');
for (int i = 0; i < node_shapes_str.size(); i++) {
if (node_shapes_str[i] == "*") {
shapes_vector.push_back(llvm::None);
continue;
}
std::vector<int> dims;
for (const absl::string_view dim_str :
absl::StrSplit(node_shapes_str[i], ',')) {
// Treats empty input shape as scalar
if (dim_str.empty()) continue;
if (dim_str == "?") {
dims.push_back(-1);
continue;
}
int size;
TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
dims.push_back(size);

View File

@ -92,16 +92,18 @@ Status ParseInputArrayInfo(absl::string_view array_names,
absl::string_view shapes,
GraphImportConfig::InputArrays* inputs);
Status ParseInputArrayInfo(const std::vector<string>& node_names,
const std::vector<string>& node_dtypes,
const std::vector<std::vector<int>>& node_shapes,
GraphImportConfig::InputArrays* inputs);
Status ParseInputArrayInfo(
const std::vector<string>& node_names,
const std::vector<string>& node_dtypes,
const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
GraphImportConfig::InputArrays* inputs);
// Parses shapes from the given string into shapes_vector which is a structured
// format.
// NOTE: If shapes_str is empty, shapes_vector will also be empty.
Status ParseNodeShapes(absl::string_view shapes_str,
std::vector<std::vector<int>>& shapes_vector);
Status ParseNodeShapes(
absl::string_view shapes_str,
std::vector<llvm::Optional<std::vector<int>>>& shapes_vector);
// Parses names from the given string into the names_vector.
// NOTE: If names_str is empty, names_vector will also be empty.

View File

@ -46,7 +46,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
llvm::StringRef input, absl::string_view debug_info_file,
const std::vector<std::string>& input_arrays,
const std::vector<std::string>& input_dtypes,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
const std::vector<std::string>& output_arrays,
const std::vector<std::string>& control_output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
@ -104,7 +104,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
llvm::StringRef input, absl::string_view debug_info_file,
const std::vector<std::string>& input_arrays,
const std::vector<std::string>& input_dtypes,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
const std::vector<std::string>& output_arrays,
const std::vector<std::string>& control_output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
@ -130,7 +130,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
bool enable_shape_inference, mlir::MLIRContext* context) {
std::vector<std::string> input_array_vector;
std::vector<std::string> input_dtype_vector;
std::vector<std::vector<int>> input_shapes_vector;
std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
std::vector<std::string> output_array_vector;
std::vector<std::string> control_output_array_vector;
TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
@ -219,7 +219,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
llvm::StringRef input, absl::string_view debug_info_file,
const std::vector<std::string>& input_arrays,
const std::vector<std::string>& input_dtypes,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
const std::vector<std::string>& output_arrays,
const std::vector<std::string>& control_output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
@ -275,7 +275,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
bool enable_shape_inference, mlir::MLIRContext* context) {
std::vector<std::string> input_array_vector;
std::vector<std::string> input_dtype_vector;
std::vector<std::vector<int>> input_shapes_vector;
std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
std::vector<std::string> output_array_vector;
std::vector<std::string> control_output_array_vector;
TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));

View File

@ -40,7 +40,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
llvm::StringRef input, absl::string_view debug_info_file,
const std::vector<std::string>& input_arrays,
const std::vector<std::string>& input_dtypes,
const std::vector<std::vector<int>>& input_shapes,
const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
const std::vector<std::string>& output_arrays,
const std::vector<std::string>& control_output_arrays,
bool prune_unused_nodes, bool convert_legacy_fed_inputs,

View File

@ -117,12 +117,18 @@ Status ParseArgumentShapes(
absl::string_view input_shapes_str,
llvm::SmallVectorImpl<TensorOrResourceShape>& arg_shapes) {
arg_shapes.clear();
std::vector<std::vector<int>> input_shapes_vector;
std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes_str, input_shapes_vector));
arg_shapes.resize(input_shapes_vector.size());
for (const auto& shape : llvm::enumerate(input_shapes_vector))
for (const auto& shape : llvm::enumerate(input_shapes_vector)) {
if (!shape.value().hasValue()) {
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
static_cast<int*>(nullptr), 0, &arg_shapes[shape.index()].shape));
continue;
}
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
shape.value(), &arg_shapes[shape.index()].shape));
shape.value().getValue(), &arg_shapes[shape.index()].shape));
}
return Status::OK();
}
@ -180,7 +186,7 @@ Status ParseXlaArguments(absl::string_view input_shapes_str,
absl::string_view arg_kinds_str,
llvm::SmallVectorImpl<XlaArgument>& xla_arguments) {
xla_arguments.clear();
std::vector<std::vector<int>> input_shapes_vector;
std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
TF_RETURN_IF_ERROR(
tensorflow::ParseNodeShapes(input_shapes_str, input_shapes_vector));
llvm::SmallVector<DataType, 4> dtypes_vector;
@ -209,8 +215,14 @@ Status ParseXlaArguments(absl::string_view input_shapes_str,
arg_kinds_vector)) {
XlaArgument& arg = std::get<0>(arg_components);
TensorShape shape;
TF_RETURN_IF_ERROR(
TensorShapeUtils::MakeShape(std::get<1>(arg_components), &shape));
auto input_shapes = std::get<1>(arg_components);
if (input_shapes.hasValue()) {
TF_RETURN_IF_ERROR(
TensorShapeUtils::MakeShape(input_shapes.getValue(), &shape));
} else {
TF_RETURN_IF_ERROR(
TensorShapeUtils::MakeShape(static_cast<int*>(nullptr), 0, &shape));
}
arg.shape = std::move(shape);
arg.type = std::get<2>(arg_components);
arg.kind = std::get<3>(arg_components);

View File

@ -60,13 +60,17 @@ Status ConvertInputInfo(
GraphImportConfig* specs) {
std::vector<std::string> array_names;
std::vector<std::string> data_types;
std::vector<std::vector<int>> shapes;
std::vector<llvm::Optional<std::vector<int>>> shapes;
for (const tf2xla::Feed& feed : config.feed()) {
std::string place_holder_name =
feed_name_remap.at(TensorIdToString(feed.id()));
array_names.push_back(place_holder_name);
data_types.push_back(
feed.type() == DT_INVALID ? "" : DataType_Name(feed.type()));
if (feed.shape().unknown_rank()) {
shapes.push_back(llvm::None);
continue;
}
std::vector<int> dims;
dims.reserve(feed.shape().dim_size());
absl::c_for_each(feed.shape().dim(), [&](const TensorShapeProto::Dim d) {

View File

@ -502,6 +502,9 @@ def build_toco_convert_protos(input_tensors,
else:
dims.append(int(dim))
input_array.shape.dims.extend(dims)
input_array.shape.unknown_rank = False
else:
input_array.shape.unknown_rank = True
for output_tensor in output_tensors:
if saved_model_dir:

View File

@ -68,7 +68,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
converter.experimental_new_converter = enable_mlir_converter
tflite_model = converter.convert()
# Check values from converted model.
# Check output value from converted model.
expected_value = root.f(input_data)
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
self.assertEqual(expected_value.numpy(), actual_value)
@ -995,6 +995,44 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
tflite_model = converter.convert()
self.assertTrue(tflite_model)
def _createUnknownInputShapeModel(self):
"""Create a simple SavedModel with unknown input."""
saved_model_dir = os.path.join(self.get_temp_dir(), 'unknown_input_shape')
with tf.Graph().as_default():
with tf.compat.v1.Session() as sess:
unknown_shape = tf.TensorShape(None)
in_tensor = tf.compat.v1.placeholder(
shape=unknown_shape, dtype=tf.float32, name='input')
out_tensor = in_tensor + in_tensor
inputs = {'input': in_tensor}
outputs = {'output': out_tensor}
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
return saved_model_dir
@test_util.run_v2_only
def testUnknownInputShapeModel(self):
"""Test a SavedModel with an unknown input shape."""
saved_model_dir = self._createUnknownInputShapeModel()
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Check values from converted model.
interpreter = Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_data = np.array([1., 2., 3.], dtype=np.float32)
interpreter.resize_tensor_input(
input_details[0]['index'], [3], strict=False)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
actual_value = interpreter.get_tensor(output_details[0]['index'])
self.assertEqual([2., 4., 6.], list(actual_value))
class FromKerasModelTest(lite_v2_test_util.ModelTest):

View File

@ -18,7 +18,13 @@ package toco;
import "tensorflow/lite/toco/types.proto";
message InputArrayShape {
// Dimensions of the tensor.
repeated int32 dims = 2;
// If true, the number of dimensions in the shape is unknown.
//
// If true, "dims.size()" must be 0.
optional bool unknown_rank = 3;
}
// Next ID to USE: 7.