Propagate unknown rank of tensor shape information to MLIR
PiperOrigin-RevId: 342808569 Change-Id: Idbe02af18ba2b8ee7dbc6d65d66cc1f647d62525
This commit is contained in:
parent
b858de3779
commit
63d973a730
@ -55,7 +55,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
// Parse input arrays.
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
|
@ -128,7 +128,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
// Parse input arrays.
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
|
||||
std::vector<llvm::Optional<double>> node_mins;
|
||||
std::vector<llvm::Optional<double>> node_maxs;
|
||||
|
||||
|
@ -185,7 +185,7 @@ Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs) {
|
||||
quant_specs->inference_input_type =
|
||||
@ -210,8 +210,12 @@ Status PopulateQuantizationSpecs(
|
||||
node_dtypes->push_back(
|
||||
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
|
||||
}
|
||||
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||
flag.shape().dims().end()));
|
||||
if (flag.shape().unknown_rank()) {
|
||||
node_shapes->push_back(llvm::None);
|
||||
} else {
|
||||
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
|
||||
flag.shape().dims().end()));
|
||||
}
|
||||
// Currently, only UINT8 and INT8 require inputs stats
|
||||
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
|
||||
if (flag.has_mean_value() && flag.has_std_value()) {
|
||||
|
@ -41,7 +41,7 @@ Status PopulateQuantizationSpecs(
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
|
||||
std::vector<string>* node_dtypes,
|
||||
std::vector<std::vector<int>>* node_shapes,
|
||||
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
|
||||
std::vector<llvm::Optional<double>>* node_mins,
|
||||
std::vector<llvm::Optional<double>>* node_maxs);
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck --check-prefix=NONE %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=SOME %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=*:* -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=UNKNOWN %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=?,1,?:1,?,1 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=DYNAMIC %s
|
||||
|
||||
node {
|
||||
name: "Add"
|
||||
@ -61,3 +63,19 @@ versions {
|
||||
# NONE-SAME: outputs = "Add"
|
||||
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# NONE: fetch %[[add]]
|
||||
|
||||
# UNKNOWN-LABEL: func @main
|
||||
# UNKNOWN-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<*xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<*xi32>) -> tensor<*xi32>
|
||||
# UNKNOWN-SAME: control_outputs = ""
|
||||
# UNKNOWN-SAME: inputs = "input0,input1"
|
||||
# UNKNOWN-SAME: outputs = "Add"
|
||||
# UNKNOWN: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# UNKNOWN: fetch %[[add]]
|
||||
|
||||
# DYNAMIC-LABEL: func @main
|
||||
# DYNAMIC-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<?x1x?xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<1x?x1xi32>) -> tensor<*xi32>
|
||||
# DYNAMIC-SAME: control_outputs = ""
|
||||
# DYNAMIC-SAME: inputs = "input0,input1"
|
||||
# DYNAMIC-SAME: outputs = "Add"
|
||||
# DYNAMIC: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# DYNAMIC: fetch %[[add]]
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -54,17 +55,18 @@ Status ParseInputArrayInfo(absl::string_view array_names,
|
||||
GraphImportConfig::InputArrays* inputs) {
|
||||
std::vector<string> node_names;
|
||||
std::vector<string> node_dtypes;
|
||||
std::vector<std::vector<int>> node_shapes;
|
||||
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
|
||||
TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names));
|
||||
TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes));
|
||||
TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes));
|
||||
return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs);
|
||||
}
|
||||
|
||||
Status ParseInputArrayInfo(const std::vector<string>& node_names,
|
||||
const std::vector<string>& node_dtypes,
|
||||
const std::vector<std::vector<int>>& node_shapes,
|
||||
GraphImportConfig::InputArrays* inputs) {
|
||||
Status ParseInputArrayInfo(
|
||||
const std::vector<string>& node_names,
|
||||
const std::vector<string>& node_dtypes,
|
||||
const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
|
||||
GraphImportConfig::InputArrays* inputs) {
|
||||
std::vector<std::string> used_node_dtypes;
|
||||
if (node_dtypes.empty()) {
|
||||
// Mark all the node dtypes Invalid, so the importer can handle them by
|
||||
@ -110,7 +112,11 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
|
||||
}
|
||||
|
||||
if (!node_shapes.empty()) {
|
||||
for (auto& dim : node_shapes[i]) {
|
||||
if (!node_shapes[i].hasValue()) {
|
||||
info.shape.set_unknown_rank(true);
|
||||
continue;
|
||||
}
|
||||
for (auto& dim : node_shapes[i].getValue()) {
|
||||
info.shape.add_dim()->set_size(dim);
|
||||
}
|
||||
}
|
||||
@ -118,17 +124,26 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ParseNodeShapes(absl::string_view shapes_str,
|
||||
std::vector<std::vector<int>>& shapes_vector) {
|
||||
Status ParseNodeShapes(
|
||||
absl::string_view shapes_str,
|
||||
std::vector<llvm::Optional<std::vector<int>>>& shapes_vector) {
|
||||
shapes_vector.clear();
|
||||
if (!shapes_str.empty()) {
|
||||
std::vector<string> node_shapes_str = absl::StrSplit(shapes_str, ':');
|
||||
for (int i = 0; i < node_shapes_str.size(); i++) {
|
||||
if (node_shapes_str[i] == "*") {
|
||||
shapes_vector.push_back(llvm::None);
|
||||
continue;
|
||||
}
|
||||
std::vector<int> dims;
|
||||
for (const absl::string_view dim_str :
|
||||
absl::StrSplit(node_shapes_str[i], ',')) {
|
||||
// Treats empty input shape as scalar
|
||||
if (dim_str.empty()) continue;
|
||||
if (dim_str == "?") {
|
||||
dims.push_back(-1);
|
||||
continue;
|
||||
}
|
||||
int size;
|
||||
TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
|
||||
dims.push_back(size);
|
||||
|
@ -92,16 +92,18 @@ Status ParseInputArrayInfo(absl::string_view array_names,
|
||||
absl::string_view shapes,
|
||||
GraphImportConfig::InputArrays* inputs);
|
||||
|
||||
Status ParseInputArrayInfo(const std::vector<string>& node_names,
|
||||
const std::vector<string>& node_dtypes,
|
||||
const std::vector<std::vector<int>>& node_shapes,
|
||||
GraphImportConfig::InputArrays* inputs);
|
||||
Status ParseInputArrayInfo(
|
||||
const std::vector<string>& node_names,
|
||||
const std::vector<string>& node_dtypes,
|
||||
const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
|
||||
GraphImportConfig::InputArrays* inputs);
|
||||
|
||||
// Parses shapes from the given string into shapes_vector which is a structured
|
||||
// format.
|
||||
// NOTE: If shapes_str is empty, shapes_vector will also be empty.
|
||||
Status ParseNodeShapes(absl::string_view shapes_str,
|
||||
std::vector<std::vector<int>>& shapes_vector);
|
||||
Status ParseNodeShapes(
|
||||
absl::string_view shapes_str,
|
||||
std::vector<llvm::Optional<std::vector<int>>>& shapes_vector);
|
||||
|
||||
// Parses names from the given string into the names_vector.
|
||||
// NOTE: If names_str is empty, names_vector will also be empty.
|
||||
|
@ -46,7 +46,7 @@ static StatusOr<mlir::OwningModuleRef> GraphdefToMlirImport(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
const std::vector<std::string>& input_arrays,
|
||||
const std::vector<std::string>& input_dtypes,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
|
||||
const std::vector<std::string>& output_arrays,
|
||||
const std::vector<std::string>& control_output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
@ -104,7 +104,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
const std::vector<std::string>& input_arrays,
|
||||
const std::vector<std::string>& input_dtypes,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
|
||||
const std::vector<std::string>& output_arrays,
|
||||
const std::vector<std::string>& control_output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
@ -130,7 +130,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
|
||||
bool enable_shape_inference, mlir::MLIRContext* context) {
|
||||
std::vector<std::string> input_array_vector;
|
||||
std::vector<std::string> input_dtype_vector;
|
||||
std::vector<std::vector<int>> input_shapes_vector;
|
||||
std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
|
||||
std::vector<std::string> output_array_vector;
|
||||
std::vector<std::string> control_output_array_vector;
|
||||
TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
|
||||
@ -219,7 +219,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
const std::vector<std::string>& input_arrays,
|
||||
const std::vector<std::string>& input_dtypes,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
|
||||
const std::vector<std::string>& output_arrays,
|
||||
const std::vector<std::string>& control_output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
@ -275,7 +275,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToSplattedMlirTranslateFunction(
|
||||
bool enable_shape_inference, mlir::MLIRContext* context) {
|
||||
std::vector<std::string> input_array_vector;
|
||||
std::vector<std::string> input_dtype_vector;
|
||||
std::vector<std::vector<int>> input_shapes_vector;
|
||||
std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
|
||||
std::vector<std::string> output_array_vector;
|
||||
std::vector<std::string> control_output_array_vector;
|
||||
TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector));
|
||||
|
@ -40,7 +40,7 @@ StatusOr<mlir::OwningModuleRef> GraphdefToMlirTranslateFunction(
|
||||
llvm::StringRef input, absl::string_view debug_info_file,
|
||||
const std::vector<std::string>& input_arrays,
|
||||
const std::vector<std::string>& input_dtypes,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const std::vector<llvm::Optional<std::vector<int>>>& input_shapes,
|
||||
const std::vector<std::string>& output_arrays,
|
||||
const std::vector<std::string>& control_output_arrays,
|
||||
bool prune_unused_nodes, bool convert_legacy_fed_inputs,
|
||||
|
@ -117,12 +117,18 @@ Status ParseArgumentShapes(
|
||||
absl::string_view input_shapes_str,
|
||||
llvm::SmallVectorImpl<TensorOrResourceShape>& arg_shapes) {
|
||||
arg_shapes.clear();
|
||||
std::vector<std::vector<int>> input_shapes_vector;
|
||||
std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
|
||||
TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes_str, input_shapes_vector));
|
||||
arg_shapes.resize(input_shapes_vector.size());
|
||||
for (const auto& shape : llvm::enumerate(input_shapes_vector))
|
||||
for (const auto& shape : llvm::enumerate(input_shapes_vector)) {
|
||||
if (!shape.value().hasValue()) {
|
||||
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
|
||||
static_cast<int*>(nullptr), 0, &arg_shapes[shape.index()].shape));
|
||||
continue;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
|
||||
shape.value(), &arg_shapes[shape.index()].shape));
|
||||
shape.value().getValue(), &arg_shapes[shape.index()].shape));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -180,7 +186,7 @@ Status ParseXlaArguments(absl::string_view input_shapes_str,
|
||||
absl::string_view arg_kinds_str,
|
||||
llvm::SmallVectorImpl<XlaArgument>& xla_arguments) {
|
||||
xla_arguments.clear();
|
||||
std::vector<std::vector<int>> input_shapes_vector;
|
||||
std::vector<llvm::Optional<std::vector<int>>> input_shapes_vector;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::ParseNodeShapes(input_shapes_str, input_shapes_vector));
|
||||
llvm::SmallVector<DataType, 4> dtypes_vector;
|
||||
@ -209,8 +215,14 @@ Status ParseXlaArguments(absl::string_view input_shapes_str,
|
||||
arg_kinds_vector)) {
|
||||
XlaArgument& arg = std::get<0>(arg_components);
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorShapeUtils::MakeShape(std::get<1>(arg_components), &shape));
|
||||
auto input_shapes = std::get<1>(arg_components);
|
||||
if (input_shapes.hasValue()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorShapeUtils::MakeShape(input_shapes.getValue(), &shape));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorShapeUtils::MakeShape(static_cast<int*>(nullptr), 0, &shape));
|
||||
}
|
||||
arg.shape = std::move(shape);
|
||||
arg.type = std::get<2>(arg_components);
|
||||
arg.kind = std::get<3>(arg_components);
|
||||
|
@ -60,13 +60,17 @@ Status ConvertInputInfo(
|
||||
GraphImportConfig* specs) {
|
||||
std::vector<std::string> array_names;
|
||||
std::vector<std::string> data_types;
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<llvm::Optional<std::vector<int>>> shapes;
|
||||
for (const tf2xla::Feed& feed : config.feed()) {
|
||||
std::string place_holder_name =
|
||||
feed_name_remap.at(TensorIdToString(feed.id()));
|
||||
array_names.push_back(place_holder_name);
|
||||
data_types.push_back(
|
||||
feed.type() == DT_INVALID ? "" : DataType_Name(feed.type()));
|
||||
if (feed.shape().unknown_rank()) {
|
||||
shapes.push_back(llvm::None);
|
||||
continue;
|
||||
}
|
||||
std::vector<int> dims;
|
||||
dims.reserve(feed.shape().dim_size());
|
||||
absl::c_for_each(feed.shape().dim(), [&](const TensorShapeProto::Dim d) {
|
||||
|
@ -502,6 +502,9 @@ def build_toco_convert_protos(input_tensors,
|
||||
else:
|
||||
dims.append(int(dim))
|
||||
input_array.shape.dims.extend(dims)
|
||||
input_array.shape.unknown_rank = False
|
||||
else:
|
||||
input_array.shape.unknown_rank = True
|
||||
|
||||
for output_tensor in output_tensors:
|
||||
if saved_model_dir:
|
||||
|
@ -68,7 +68,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
|
||||
converter.experimental_new_converter = enable_mlir_converter
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# Check values from converted model.
|
||||
# Check output value from converted model.
|
||||
expected_value = root.f(input_data)
|
||||
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
|
||||
self.assertEqual(expected_value.numpy(), actual_value)
|
||||
@ -995,6 +995,44 @@ class FromSavedModelTest(lite_v2_test_util.ModelTest):
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
def _createUnknownInputShapeModel(self):
|
||||
"""Create a simple SavedModel with unknown input."""
|
||||
saved_model_dir = os.path.join(self.get_temp_dir(), 'unknown_input_shape')
|
||||
with tf.Graph().as_default():
|
||||
with tf.compat.v1.Session() as sess:
|
||||
unknown_shape = tf.TensorShape(None)
|
||||
in_tensor = tf.compat.v1.placeholder(
|
||||
shape=unknown_shape, dtype=tf.float32, name='input')
|
||||
out_tensor = in_tensor + in_tensor
|
||||
inputs = {'input': in_tensor}
|
||||
outputs = {'output': out_tensor}
|
||||
saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
|
||||
return saved_model_dir
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testUnknownInputShapeModel(self):
|
||||
"""Test a SavedModel with an unknown input shape."""
|
||||
saved_model_dir = self._createUnknownInputShapeModel()
|
||||
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
|
||||
tflite_model = converter.convert()
|
||||
self.assertTrue(tflite_model)
|
||||
|
||||
# Check values from converted model.
|
||||
interpreter = Interpreter(model_content=tflite_model)
|
||||
input_details = interpreter.get_input_details()
|
||||
output_details = interpreter.get_output_details()
|
||||
|
||||
input_data = np.array([1., 2., 3.], dtype=np.float32)
|
||||
interpreter.resize_tensor_input(
|
||||
input_details[0]['index'], [3], strict=False)
|
||||
interpreter.allocate_tensors()
|
||||
|
||||
interpreter.set_tensor(input_details[0]['index'], input_data)
|
||||
interpreter.invoke()
|
||||
actual_value = interpreter.get_tensor(output_details[0]['index'])
|
||||
self.assertEqual([2., 4., 6.], list(actual_value))
|
||||
|
||||
|
||||
class FromKerasModelTest(lite_v2_test_util.ModelTest):
|
||||
|
||||
|
@ -18,7 +18,13 @@ package toco;
|
||||
import "tensorflow/lite/toco/types.proto";
|
||||
|
||||
message InputArrayShape {
|
||||
// Dimensions of the tensor.
|
||||
repeated int32 dims = 2;
|
||||
|
||||
// If true, the number of dimensions in the shape is unknown.
|
||||
//
|
||||
// If true, "dims.size()" must be 0.
|
||||
optional bool unknown_rank = 3;
|
||||
}
|
||||
|
||||
// Next ID to USE: 7.
|
||||
|
Loading…
Reference in New Issue
Block a user