Update LegalizeNodeName to mutate inplace (NFC).

*NameMappers already return copies of strings for GetName. As the legalization of names is via substituting characters with `.`, this can be done inplace instead.
tensorflow name space is removed in preparation of name utils being used for converting locations to strings for XLA OpMetadata tracking.

PiperOrigin-RevId: 328762081
Change-Id: I4d460d2fe6ea1cccd0165b4bb02e5cc0d0994167
This commit is contained in:
Andy Ly 2020-08-27 10:17:11 -07:00 committed by TensorFlower Gardener
parent 7806843d74
commit e5578d2030
4 changed files with 21 additions and 29 deletions

View File

@ -106,14 +106,14 @@ bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
auto name_from_loc = mlir::tensorflow::GetNameFromLoc(op->getLoc());
auto name_from_loc = mlir::GetNameFromLoc(op->getLoc());
if (!name_from_loc.empty()) return name_from_loc;
// If the location is none of the expected types, then simply use name
// generated using the op type.
return std::string(op->getName().getStringRef());
}
auto val = op_or_val.dyn_cast<mlir::Value>();
auto name_from_loc = mlir::tensorflow::GetNameFromLoc(val.getLoc());
auto name_from_loc = mlir::GetNameFromLoc(val.getLoc());
if (!name_from_loc.empty()) return name_from_loc;
// If the location is none of the expected types, then simply use name
// generated using the op type. Follow TF convention and append the result

View File

@ -85,8 +85,10 @@ constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
private:
std::string GetName(OpOrVal op_or_val) override {
return mlir::tensorflow::LegalizeNodeName(
OpOrArgLocNameMapper::GetName(op_or_val));
std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
assert(!name.empty() && "expected non-empty name");
mlir::LegalizeNodeName(name);
return name;
}
};
@ -490,13 +492,14 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
if (index >= num_data_results) break;
// TODO(jpienaar): If there is a result index specified, ensure only one
// and that it matches the result index of the op.
std::string orig_name(output_names[index]);
auto tensor_id = ParseTensorName(orig_name);
auto name = mlir::tensorflow::LegalizeNodeName(
llvm::StringRef(tensor_id.node().data(), tensor_id.node().size()));
std::string name(output_names[index]);
auto tensor_id = ParseTensorName(name);
std::string tensor_id_node(tensor_id.node());
assert(!tensor_id_node.empty() && "expected non-empty name");
mlir::LegalizeNodeName(tensor_id_node);
// Ensure name does not get reused.
(void)exporter.op_to_name_.GetUniqueName(name);
(void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
}
}
@ -504,8 +507,9 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
TF_RET_CHECK(input_names.size() == block.getNumArguments());
for (const auto& it : llvm::enumerate(function.getArguments())) {
// TODO(lyandy): Update when changing feed/fetch import.
std::string orig_name(input_names[it.index()]);
std::string name = mlir::tensorflow::LegalizeNodeName(orig_name);
std::string name(input_names[it.index()]);
assert(!name.empty() && "expected non-empty name");
mlir::LegalizeNodeName(name);
auto tensor_id = ParseTensorName(name);
TF_RET_CHECK(tensor_id.index() == 0)
<< "input port designation not supported";

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "mlir/IR/Identifier.h" // from @llvm-project
namespace mlir {
namespace tensorflow {
namespace {
// Checks if a character is legal for a TensorFlow node name, with special
@ -45,21 +44,13 @@ bool IsLegalChar(char c, bool first_char) {
}
} // anonymous namespace
std::string LegalizeNodeName(llvm::StringRef name) {
assert(!name.empty() && "expected non-empty name");
void LegalizeNodeName(std::string& name) {
if (name.empty()) return;
std::string legalized_name;
bool first = true;
for (auto c : name) {
if (IsLegalChar(c, first)) {
legalized_name += c;
} else {
legalized_name += '.';
}
first = false;
}
if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.';
return legalized_name;
for (char& c : llvm::drop_begin(name, 1))
if (!IsLegalChar(c, /*first_char=*/false)) c = '.';
}
std::string GetNameFromLoc(Location loc) {
@ -105,5 +96,4 @@ std::string GetNameFromLoc(Location loc) {
return "";
}
} // namespace tensorflow
} // namespace mlir

View File

@ -22,16 +22,14 @@ limitations under the License.
#include "mlir/IR/Location.h" // from @llvm-project
namespace mlir {
namespace tensorflow {
// Converts characters in name that are considered illegal in TensorFlow Node
// name to '.'.
std::string LegalizeNodeName(llvm::StringRef name);
void LegalizeNodeName(std::string& name);
// Creates a TensorFlow node name from a location.
std::string GetNameFromLoc(Location loc);
} // namespace tensorflow
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_