parent
94d7e348d8
commit
fc7bce9b4a
@ -0,0 +1,24 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32>
|
||||
attributes {tf.entry_function = {inputs = "foo,bar", outputs = "Add"}} {
|
||||
%0 = "tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32>
|
||||
%1 = "tf.Placeholder.input"(%arg1) {device = "", dtype = "tfdtype$DT_INT32", shape = "tfshape$dim { size: 10 }"} : (tensor<10xi32>) -> tensor<10xi32>
|
||||
// This node would be renamed to bar1
|
||||
%2 = "tf.Identity"(%1) {device = "", dtype = "tfdtype$DT_INT32"} : (tensor<10xi32>) -> tensor<10xi32> loc ("bar")
|
||||
// The following node would be renamed to bar2
|
||||
%3 = "tf.Identity"(%2) {device = "", dtype = "tfdtype$DT_INT32"} : (tensor<10xi32>) -> tensor<10xi32> loc ("bar")
|
||||
%4 = "tf.Add"(%0, %3) {T = "tfdtype$DT_INT32", device = ""} : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> loc("Add")
|
||||
return %4 : tensor<10xi32>
|
||||
}
|
||||
|
||||
// CHECK: name: "bar1"
|
||||
// CHECK-NEXT: op: "Identity"
|
||||
// CHECK: name: "bar2"
|
||||
// CHECK-NEXT: op: "Identity"
|
||||
// CHECK: name: "Add"
|
||||
// CHECK-NEXT: op: "Add"
|
||||
// CHECK: name: "foo"
|
||||
// CHECK-NEXT: op: "Placeholder"
|
||||
// CHECK: name: "bar"
|
||||
// CHECK-NEXT: op: "Placeholder"
|
@ -0,0 +1,18 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main() {
|
||||
^bb0:
|
||||
// CHECK: name: "foo"
|
||||
%0 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<i32>} : () -> (tensor<i32>) loc("foo")
|
||||
// CHECK: name: "foo1"
|
||||
%1 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<1> : tensor<i32>} : () -> (tensor<i32>) loc("foo")
|
||||
// CHECK: name: "foo11"
|
||||
%2 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : () -> (tensor<i32>) loc("foo1")
|
||||
// CHECK: name: "foo2"
|
||||
%3 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<2> : tensor<i32>} : () -> (tensor<i32>) loc("foo")
|
||||
// CHECK: name: "2"
|
||||
%4 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<3> : tensor<i32>} : () -> (tensor<i32>) loc("2")
|
||||
// CHECK: name: "3"
|
||||
%5 = "tf.Const"() {dtype = "tfdtype$DT_INT32", value = dense<3> : tensor<i32>} : () -> (tensor<i32>) loc("3")
|
||||
return
|
||||
}
|
@ -171,10 +171,16 @@ class Exporter {
|
||||
};
|
||||
|
||||
std::string Exporter::UniqueName(llvm::StringRef prefix) {
|
||||
// Keep incrementing the counter until we find a unique name.
|
||||
std::string name = prefix;
|
||||
auto& val = name_to_count_[name];
|
||||
if (val) name = (prefix + llvm::Twine(val)).str();
|
||||
++val;
|
||||
auto& prefix_count = name_to_count_[name];
|
||||
int64 val = prefix_count;
|
||||
while (val != 0) {
|
||||
name = (prefix + llvm::Twine(prefix_count)).str();
|
||||
++prefix_count;
|
||||
val = name_to_count_[name];
|
||||
}
|
||||
name_to_count_[name] = 1;
|
||||
return name;
|
||||
}
|
||||
|
||||
@ -330,6 +336,9 @@ Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) {
|
||||
}
|
||||
for (auto* r : input->getResults()) state.types.push_back(r->getType());
|
||||
auto* inst = builder.createOperation(state);
|
||||
// If it is one of the specified input names, then the new
|
||||
// instruction should have the same name.
|
||||
op_to_name_[inst].assign(op_to_name_[input]);
|
||||
for (int index = 0, e = input->getNumResults(); index != e; ++index) {
|
||||
input->getResult(index)->replaceAllUsesWith(inst->getResult(index));
|
||||
}
|
||||
@ -393,7 +402,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
|
||||
TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(*flib));
|
||||
Exporter exporter(graph.get(), tf_dialect);
|
||||
|
||||
// Set input and output names.
|
||||
// Set input and output names and increment the use counter for them to help
|
||||
// generate unique names.
|
||||
if (!output_names.empty()) {
|
||||
auto term = block.getTerminator();
|
||||
TF_RET_CHECK(output_names.size() == term->getNumOperands())
|
||||
@ -401,12 +411,14 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(const ExporterConfigs& confs,
|
||||
<< ") != terminator operands (" << term->getNumOperands() << ")";
|
||||
int i = 0;
|
||||
for (auto it : term->getOperands()) {
|
||||
exporter.name_to_count_[output_names[i].str()] = 1;
|
||||
exporter.op_to_name_[it->getDefiningOp()] = output_names[i++];
|
||||
}
|
||||
}
|
||||
if (!input_names.empty()) {
|
||||
TF_RET_CHECK(input_names.size() == block.getNumArguments());
|
||||
for (auto it : llvm::enumerate(function.getArguments())) {
|
||||
exporter.name_to_count_[input_names[it.index()].str()] = 1;
|
||||
exporter.op_to_name_[*it.value()->user_begin()] = input_names[it.index()];
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user