Update LegalizeNodeName to mutate inplace (NFC).
*NameMappers already return copies of strings for GetName. As the legalization of names is via substituting characters with `.`, this can be done inplace instead. tensorflow name space is removed in preparation of name utils being used for converting locations to strings for XLA OpMetadata tracking. PiperOrigin-RevId: 328762081 Change-Id: I4d460d2fe6ea1cccd0165b4bb02e5cc0d0994167
This commit is contained in:
parent
7806843d74
commit
e5578d2030
@ -106,14 +106,14 @@ bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
|
||||
|
||||
std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
|
||||
if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
|
||||
auto name_from_loc = mlir::tensorflow::GetNameFromLoc(op->getLoc());
|
||||
auto name_from_loc = mlir::GetNameFromLoc(op->getLoc());
|
||||
if (!name_from_loc.empty()) return name_from_loc;
|
||||
// If the location is none of the expected types, then simply use name
|
||||
// generated using the op type.
|
||||
return std::string(op->getName().getStringRef());
|
||||
}
|
||||
auto val = op_or_val.dyn_cast<mlir::Value>();
|
||||
auto name_from_loc = mlir::tensorflow::GetNameFromLoc(val.getLoc());
|
||||
auto name_from_loc = mlir::GetNameFromLoc(val.getLoc());
|
||||
if (!name_from_loc.empty()) return name_from_loc;
|
||||
// If the location is none of the expected types, then simply use name
|
||||
// generated using the op type. Follow TF convention and append the result
|
||||
|
@ -85,8 +85,10 @@ constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
|
||||
class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
|
||||
private:
|
||||
std::string GetName(OpOrVal op_or_val) override {
|
||||
return mlir::tensorflow::LegalizeNodeName(
|
||||
OpOrArgLocNameMapper::GetName(op_or_val));
|
||||
std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
|
||||
assert(!name.empty() && "expected non-empty name");
|
||||
mlir::LegalizeNodeName(name);
|
||||
return name;
|
||||
}
|
||||
};
|
||||
|
||||
@ -490,13 +492,14 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
if (index >= num_data_results) break;
|
||||
// TODO(jpienaar): If there is a result index specified, ensure only one
|
||||
// and that it matches the result index of the op.
|
||||
std::string orig_name(output_names[index]);
|
||||
auto tensor_id = ParseTensorName(orig_name);
|
||||
auto name = mlir::tensorflow::LegalizeNodeName(
|
||||
llvm::StringRef(tensor_id.node().data(), tensor_id.node().size()));
|
||||
std::string name(output_names[index]);
|
||||
auto tensor_id = ParseTensorName(name);
|
||||
std::string tensor_id_node(tensor_id.node());
|
||||
assert(!tensor_id_node.empty() && "expected non-empty name");
|
||||
mlir::LegalizeNodeName(tensor_id_node);
|
||||
|
||||
// Ensure name does not get reused.
|
||||
(void)exporter.op_to_name_.GetUniqueName(name);
|
||||
(void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
|
||||
}
|
||||
}
|
||||
|
||||
@ -504,8 +507,9 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
TF_RET_CHECK(input_names.size() == block.getNumArguments());
|
||||
for (const auto& it : llvm::enumerate(function.getArguments())) {
|
||||
// TODO(lyandy): Update when changing feed/fetch import.
|
||||
std::string orig_name(input_names[it.index()]);
|
||||
std::string name = mlir::tensorflow::LegalizeNodeName(orig_name);
|
||||
std::string name(input_names[it.index()]);
|
||||
assert(!name.empty() && "expected non-empty name");
|
||||
mlir::LegalizeNodeName(name);
|
||||
auto tensor_id = ParseTensorName(name);
|
||||
TF_RET_CHECK(tensor_id.index() == 0)
|
||||
<< "input port designation not supported";
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// Checks if a character is legal for a TensorFlow node name, with special
|
||||
@ -45,21 +44,13 @@ bool IsLegalChar(char c, bool first_char) {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
std::string LegalizeNodeName(llvm::StringRef name) {
|
||||
assert(!name.empty() && "expected non-empty name");
|
||||
void LegalizeNodeName(std::string& name) {
|
||||
if (name.empty()) return;
|
||||
|
||||
std::string legalized_name;
|
||||
bool first = true;
|
||||
for (auto c : name) {
|
||||
if (IsLegalChar(c, first)) {
|
||||
legalized_name += c;
|
||||
} else {
|
||||
legalized_name += '.';
|
||||
}
|
||||
first = false;
|
||||
}
|
||||
if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.';
|
||||
|
||||
return legalized_name;
|
||||
for (char& c : llvm::drop_begin(name, 1))
|
||||
if (!IsLegalChar(c, /*first_char=*/false)) c = '.';
|
||||
}
|
||||
|
||||
std::string GetNameFromLoc(Location loc) {
|
||||
@ -105,5 +96,4 @@ std::string GetNameFromLoc(Location loc) {
|
||||
return "";
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace mlir
|
||||
|
@ -22,16 +22,14 @@ limitations under the License.
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts characters in name that are considered illegal in TensorFlow Node
|
||||
// name to '.'.
|
||||
std::string LegalizeNodeName(llvm::StringRef name);
|
||||
void LegalizeNodeName(std::string& name);
|
||||
|
||||
// Creates a TensorFlow node name from a location.
|
||||
std::string GetNameFromLoc(Location loc);
|
||||
|
||||
} // namespace tensorflow
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
|
||||
|
Loading…
x
Reference in New Issue
Block a user