STT-tensorflow/tensorflow/core/grappler/utils.cc
TensorFlower Gardener 18db4c71cd Merge pull request #40496 from tg-at-google:sign-compare-warning-fixes-batch-2
PiperOrigin-RevId: 316926815
Change-Id: I3e3788e3472fe7d3a9c0d2f2c44a335b3300210c
2020-06-17 11:22:40 -07:00

522 lines
17 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils.h"
#include <iterator>
#include <memory>
#include <queue>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace grappler {
namespace {
template <typename T>
bool SafeSetDoubleScalarTensorValue(double value, Tensor* tensor) {
using RealType = typename Eigen::NumTraits<T>::Real;
if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
return false;
}
tensor->flat<T>()(0) = static_cast<T>(value);
return true;
}
template <typename T>
bool SafeSetIntScalarTensorValue(int value, Tensor* tensor) {
using RealType = typename Eigen::NumTraits<T>::Real;
if (value > static_cast<int>(Eigen::NumTraits<RealType>::highest()) ||
value < static_cast<int>(Eigen::NumTraits<RealType>::lowest())) {
return false;
}
tensor->flat<T>()(0) = static_cast<T>(value);
return true;
}
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
// TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
// TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
bool IsShapeConsumer(const NodeDef& node) {
const string& op = node.op();
return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
}
} // namespace
// Indexes every node of `graph` by name and records, for each producer name,
// the set of nodes that consume it.
// If several nodes share a name, the first one wins the name slot (the
// duplicates are logged), and the duplicates' inputs are attributed to that
// first ("canonical") node in the outputs index.
NodeMap::NodeMap(GraphDef* graph) {
  CHECK(graph != nullptr);
  const int num_nodes = graph->node_size();
  nodes_.reserve(num_nodes);
  outputs_.reserve(num_nodes);
  for (NodeDef& node : *graph->mutable_node()) {
    const string& name = node.name();
    const auto inserted = nodes_.emplace(name, &node);
    const bool is_first = inserted.second;
    if (!is_first) {
      LOG(WARNING) << "Duplicated node in the graph: " << name;
    }
    NodeDef* owner = is_first ? &node : inserted.first->second;
    for (const string& input : node.input()) {
      outputs_[NodeName(input)].insert(owner);
    }
  }
}
// Formats `tensor_id` as an input string. Output 0 is written as the bare node
// name; any other index uses TensorId::ToString().
string TensorIdToString(const TensorId& tensor_id) {
  if (tensor_id.index() != 0) return tensor_id.ToString();
  return string(tensor_id.node());
}
// SafeTensorId counterpart of TensorIdToString: output 0 becomes the bare node
// name, anything else uses SafeTensorId::ToString().
string SafeTensorIdToString(const SafeTensorId& tensor_id) {
  if (tensor_id.index() != 0) return tensor_id.ToString();
  return tensor_id.node();
}
bool IsSameInput(const string& name1, const string& name2) {
if (name1 == name2) return true;
TensorId tensor1 = ParseTensorName(name1);
TensorId tensor2 = ParseTensorName(name2);
return tensor1 == tensor2;
}
// Returns true if the input string is a control input, i.e. starts with '^'.
bool IsControlInput(const string& name) {
  if (name.empty()) return false;
  return name.front() == '^';
}
// Returns true if `tensor_id` carries a negative output index; such ids are
// treated as control dependencies.
bool IsControlInput(const TensorId& tensor_id) {
  const auto index = tensor_id.index();
  return index < 0;
}
// Returns `prefix` + `delimiter` + `name`. When `name` is a control input
// ("^x"), the '^' marker is kept in front: "^" + prefix + delimiter + "x".
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter) {
  const bool is_control = !name.empty() && name[0] == '^';
  if (!is_control) return absl::StrCat(prefix, delimiter, name);
  return absl::StrCat("^", prefix, delimiter, name.substr(1));
}
// Convenience overload that joins `prefix` and `name` with "/".
string AddPrefixToNodeName(const string& name, const string& prefix) {
  static const char kDefaultDelimiter[] = "/";
  return AddPrefixToNodeName(name, prefix, kDefaultDelimiter);
}
// Runs `fn` and reports whether it finished within `timeout_in_ms`.
//
// A non-positive timeout runs `fn` synchronously on the calling thread and
// always returns true. Otherwise `fn` is scheduled on `thread_pool` and this
// call waits for at most the given time, returning false if the deadline
// passes first. A timed-out `fn` is not cancelled: it keeps running on the
// pool, which is why the notification is shared-owned by the closure rather
// than living in this stack frame.
bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
                        thread::ThreadPool* const thread_pool) {
  if (timeout_in_ms <= 0) {
    fn();
    return true;
  }
  auto done = std::make_shared<Notification>();
  // `fn` was taken by value; move it into the closure instead of copying it
  // (and whatever state it captures) a second time.
  thread_pool->Schedule([done, fn = std::move(fn)]() {
    fn();
    done->Notify();
  });
  // WaitForNotificationWithTimeout takes its timeout in microseconds, as the
  // ms -> us conversion here implies.
  return WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
}
// Returns the control-input spelling ("^name") of `node`.
string AsControlDependency(const NodeDef& node) {
  const string& target = node.name();
  return absl::StrCat("^", target);
}
// Returns `node_name` as a control input ("^name"). A name that already starts
// with '^' is returned unchanged. `node_name` must not be empty (CHECK-fails).
string AsControlDependency(const string& node_name) {
  CHECK(!node_name.empty());
  // The CHECK above guarantees that index 0 exists.
  if (node_name[0] == '^') return node_name;
  return absl::StrCat("^", node_name);
}
bool NodeIsOnCpu(const NodeDef* node) {
string task, device;
return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
absl::StartsWith(device, DEVICE_CPU);
}
bool NodeIsOnGpu(const NodeDef* node) {
string task, device;
return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
absl::StartsWith(device, DEVICE_GPU);
}
// Returns the number of output tensors produced by `node`.
//
// For ops in the global op registry, each OpDef output arg contributes one of
// the following:
//   * the length of its type-list attr,
//   * the value of its number attr, or
//   * 1 for a plain single-tensor output.
// Otherwise the op is looked up as a function in `graph`'s function library,
// and the number of output args in its signature is returned.
//
// Returns 0 if the op cannot be resolved, or if `graph` is null and the op is
// not in the registry.
// NOTE(review): the `attr().at(...)` lookups are unguarded; a node missing the
// referenced type-list/number attr is not handled gracefully. Confirm that
// callers only pass nodes with those attrs populated.
int NumOutputs(const NodeDef& node, GraphDef* graph) {
  const OpDef* op_def = nullptr;
  if (OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
    int num_outputs = 0;
    for (const auto& output : op_def->output_arg()) {
      if (!output.type_list_attr().empty()) {
        // One tensor per entry of the type list.
        num_outputs +=
            node.attr().at(output.type_list_attr()).list().type_size();
      } else if (!output.number_attr().empty()) {
        // N tensors, where N is stored in an int attr.
        num_outputs +=
            static_cast<int>(node.attr().at(output.number_attr()).i());
      } else {
        ++num_outputs;
      }
    }
    return num_outputs;
  }

  // Not a registered op; it may be a function defined in the graph's library.
  // Without a graph there is no library to consult, so report "unknown" (0)
  // instead of dereferencing a null pointer.
  if (graph == nullptr) return 0;
  FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
  if (fdef.LookUpOpDef(node.op(), &op_def).ok()) {
    return op_def->output_arg_size();
  }
  return 0;
}
// Returns true if the last input of `node` is a control input.
// Assumes control inputs are listed after all regular inputs.
bool HasControlInputs(const NodeDef& node) {
  const int n = node.input_size();
  return n > 0 && IsControlInput(node.input(n - 1));
}
// Returns true if the first input of `node` exists and is not a control input.
// Assumes regular inputs are listed before control inputs.
bool HasRegularInputs(const NodeDef& node) {
  if (node.input_size() == 0) return false;
  return !IsControlInput(node.input(0));
}
int NumNonControlInputs(const NodeDef& node) {
int num_inputs = 0;
for (; num_inputs < node.input_size(); ++num_inputs) {
const string& input = node.input(num_inputs);
if (IsControlInput(input)) {
return num_inputs;
}
}
return num_inputs;
}
int NumControlInputs(const NodeDef& node) {
int num_inputs = 0;
for (; num_inputs < node.input_size(); ++num_inputs) {
const string& input = node.input(node.input_size() - num_inputs - 1);
if (!IsControlInput(input)) {
return num_inputs;
}
}
return num_inputs;
}
bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& node_as_input : output->input()) {
if (IsControlInput(node_as_input)) break;
TensorId tensor = ParseTensorName(node_as_input);
if (tensor.node() == node.name()) {
return true;
}
}
}
return false;
}
bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (int idx = output->input_size() - 1; idx >= 0; --idx) {
const string& node_as_input = output->input(idx);
if (!IsControlInput(node_as_input)) break;
TensorId tensor = ParseTensorName(node_as_input);
if (tensor.node() == node.name()) {
return true;
}
}
}
return false;
}
int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_outputs = 0;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (int idx = output->input_size() - 1; idx >= 0; --idx) {
const string& node_as_input = output->input(idx);
if (!IsControlInput(node_as_input)) break;
TensorId tensor = ParseTensorName(node_as_input);
if (tensor.node() == node.name()) {
++num_outputs;
}
}
}
return num_outputs;
}
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_outputs = 0;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& node_as_input : output->input()) {
if (IsControlInput(node_as_input)) {
break;
}
if (node_as_input == node.name()) {
++num_outputs;
} else {
const TensorId tensor = ParseTensorName(node_as_input);
if (tensor.node() == node.name()) {
++num_outputs;
}
}
}
}
return num_outputs;
}
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_data_outputs = 0;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
if (IsShapeConsumer(*output)) continue;
for (int i = 0; i < output->input_size(); ++i) {
const string& input = output->input(i);
if (!IsControlInput(input) && NodeName(input) == node.name()) {
++num_data_outputs;
break;
}
}
}
return num_data_outputs;
}
// Returns the data type stored in attribute `type_attr` of `node`.
// Returns DT_INVALID if the attribute is missing or does not hold a type value.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr) {
  // A single find() replaces the previous count() + at() double lookup.
  const auto it = node.attr().find(type_attr);
  if (it == node.attr().end()) return DT_INVALID;
  const AttrValue& attr = it->second;
  // An attr of another kind (list, int, ...) is not a type.
  if (attr.value_case() != AttrValue::kType) return DT_INVALID;
  return attr.type();
}
// Follows the chain of first inputs starting at `source` and returns the last
// node reached.
//
// From the current node, the next candidate is the node named by its input(0).
// The walk stops, returning the current node, in any of these cases:
//   * the current node has no inputs;
//   * input(0) is a control input and `follow_control_input` is false;
//   * the candidate is not in `node_map` (logged as an error);
//   * `pred_fn` rejects the candidate.
// `pred_fn` is never evaluated on `source`, and a candidate equal to `source`
// is always followed.
// NOTE(review): no visited set is kept. A cycle through `source`, or a cycle
// whose nodes all satisfy `pred_fn`, therefore never terminates. Confirm that
// callers only walk acyclic chains.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
                        bool follow_control_input,
                        const std::function<bool(const NodeDef&)>& pred_fn) {
  const NodeDef* tail = &source;
  while (true) {
    if (tail->input_size() == 0) break;
    const string& first_input = tail->input(0);
    if (!follow_control_input && IsControlInput(first_input)) break;
    const NodeDef* candidate = node_map.GetNode(first_input);
    if (candidate == nullptr) {
      LOG(ERROR) << "Node not found: " << first_input;
      break;
    }
    if (candidate != &source && !pred_fn(*candidate)) break;
    tail = candidate;
  }
  return const_cast<NodeDef*>(tail);
}
// Reorders graph->node() in place.
//
// Without inversion, the node originally at index i ends up at index
// (*permutation)[i]. If `invert_permutation` is true, the inverse is applied
// instead: the node originally at index (*permutation)[j] ends up at index j.
//
// Every permutation is a product of one or more cycles. Iterate over the cycles
// in the permutation, and convert each of those into a product of
// transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
//
// `*permutation` must have exactly graph->node_size() entries (CHECK-enforced).
// It is used as scratch space and is left as the identity on return.
// NOTE(review): entries are not validated; anything that is not a permutation
// of [0, node_size()) leads to out-of-bounds access.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
                         bool invert_permutation) {
  // Compare as size_t: node_size() is an int and is never negative.
  CHECK_EQ(static_cast<size_t>(graph->node_size()), permutation->size());
  const int num_nodes = graph->node_size();
  if (invert_permutation) {
    // Only pay for the scratch vector when it is actually needed.
    std::vector<int> inv_perm(permutation->size(), 0);
    for (size_t n = 0; n < permutation->size(); ++n) {
      inv_perm[(*permutation)[n]] = static_cast<int>(n);
    }
    permutation->swap(inv_perm);
  }
  // Once the first num_nodes - 1 slots are fixed, the last one is too.
  for (int n = 0; n + 1 < num_nodes; ++n) {
    while (n != (*permutation)[n]) {
      // Send the node at n to its destination r. The node previously at r
      // lands at n and inherits r's target.
      const int r = (*permutation)[n];
      graph->mutable_node()->SwapElements(n, r);
      std::swap((*permutation)[n], (*permutation)[r]);
    }
  }
}
// Removes redundant control inputs from `node`.
//
// A control input "^x" is dropped when an input already scanned (regular or
// control) refers to node x. Regular inputs are never removed. A duplicate is
// swapped with the last input and then removed, so the relative order of the
// remaining control inputs may change.
void DedupControlInputs(NodeDef* node) {
  absl::flat_hash_set<string> seen_names;
  auto* inputs = node->mutable_input();
  int idx = 0;
  while (idx < inputs->size()) {
    const string& current = inputs->Get(idx);
    const bool is_duplicate = !seen_names.insert(NodeName(current)).second;
    if (is_duplicate && IsControlInput(current)) {
      inputs->SwapElements(idx, inputs->size() - 1);
      inputs->RemoveLast();
      // Re-examine slot `idx`, which now holds the former last input.
      continue;
    }
    ++idx;
  }
}
namespace {

// Removes the nodes at the indices in `nodes_to_delete` from `graph`.
// Expects ascending, duplicate-free indices; the reverse walk relies on this.
// Starting from the largest index, each doomed node is swapped into the
// shrinking tail of the node list. The whole tail is then dropped with one
// DeleteSubrange call. The relative order of surviving nodes is NOT preserved.
template <typename UniqueContainer>
void EraseNodesFromGraphImpl(const UniqueContainer& nodes_to_delete,
                             GraphDef* graph) {
  static_assert(std::is_same<typename UniqueContainer::value_type, int>::value,
                "Need to pass container of ints");
  auto* nodes = graph->mutable_node();
  int tail = nodes->size();
  for (auto it = nodes_to_delete.rbegin(); it != nodes_to_delete.rend(); ++it) {
    --tail;
    nodes->SwapElements(*it, tail);
  }
  nodes->DeleteSubrange(tail, nodes_to_delete.size());
}

// Sorts `*v` ascending and erases duplicate elements.
template <typename T>
inline void STLSortAndRemoveDuplicates(T* v) {
  std::sort(v->begin(), v->end());
  const auto new_end = std::unique(v->begin(), v->end());
  v->erase(new_end, v->end());
}

}  // namespace
// Deletes the nodes at the given indices from `graph`. A std::set is already
// sorted and duplicate-free, so it is handed to the implementation as-is.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete,
                         GraphDef* graph) {
  const std::set<int>& sorted_unique = nodes_to_delete;
  EraseNodesFromGraphImpl(sorted_unique, graph);
}
// Deletes the nodes at the given indices from `graph`. The caller's vector is
// sorted and deduplicated in place first.
void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph) {
  std::vector<int>& indices = nodes_to_delete;
  STLSortAndRemoveDuplicates(&indices);
  EraseNodesFromGraphImpl(indices, graph);
}
void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
GraphDef* graph) {
std::vector<int> nodes_idx_to_delete;
nodes_idx_to_delete.reserve(nodes_to_delete.size());
for (int i = 0; i < graph->node_size(); ++i) {
if (nodes_to_delete.count(graph->node(i).name()))
nodes_idx_to_delete.push_back(i);
}
EraseNodesFromGraphImpl(nodes_idx_to_delete, graph);
}
// Case helpers for SetTensorValue. Each expands to a `case DTYPE:` label that
// stores `value` into element 0 of `tensor` through the range-checked setter
// for that type. It returns InvalidArgument when the value does not fit.
// Both macros are #undef'd right after SetTensorValue.
#define HANDLE_DOUBLE_CASE(DTYPE)                                     \
  case DTYPE:                                                         \
    if (!SafeSetDoubleScalarTensorValue<EnumToDataType<DTYPE>::Type>( \
            static_cast<double>(value), tensor)) {                    \
      return errors::InvalidArgument("Cannot store value ", value,    \
                                     " in tensor of type " #DTYPE);   \
    }                                                                 \
    break
#define HANDLE_INT_CASE(DTYPE)                                               \
  case DTYPE:                                                                \
    if (!SafeSetIntScalarTensorValue<EnumToDataType<DTYPE>::Type>(value,     \
                                                                  tensor)) { \
      return errors::InvalidArgument("Cannot store value ", value,           \
                                     " in tensor of type " #DTYPE);          \
    }                                                                        \
    break
// Stores `value` as the single element of `tensor`, converted to the C++ type
// that corresponds to `dtype`.
//
// Returns InvalidArgument in any of these cases:
//   * `tensor` does not have exactly one element;
//   * `value` lies outside the range representable by `dtype` (for complex
//     types, by their real component type);
//   * `dtype` is not handled below.
// Quantized types go through the int-based setter; all others go through the
// double-based one.
// NOTE(review): `dtype` is trusted to match `tensor`'s own dtype; this
// function does not compare them.
Status SetTensorValue(DataType dtype, int value, Tensor* tensor) {
  // TODO(rmlarsen): Support more general shapes.
  // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported.
  if (tensor->NumElements() != 1) {
    return errors::InvalidArgument(
        "Expected scalar tensor, got num_elements = ", tensor->NumElements());
  }
  switch (dtype) {
    HANDLE_DOUBLE_CASE(DT_HALF);
    HANDLE_DOUBLE_CASE(DT_BFLOAT16);
    HANDLE_DOUBLE_CASE(DT_BOOL);
    HANDLE_DOUBLE_CASE(DT_FLOAT);
    HANDLE_DOUBLE_CASE(DT_DOUBLE);
    HANDLE_DOUBLE_CASE(DT_UINT8);
    HANDLE_DOUBLE_CASE(DT_INT8);
    HANDLE_DOUBLE_CASE(DT_UINT16);
    HANDLE_DOUBLE_CASE(DT_INT16);
    HANDLE_DOUBLE_CASE(DT_INT32);
    HANDLE_DOUBLE_CASE(DT_INT64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX64);
    HANDLE_DOUBLE_CASE(DT_COMPLEX128);
    HANDLE_INT_CASE(DT_QINT8);
    HANDLE_INT_CASE(DT_QUINT8);
    HANDLE_INT_CASE(DT_QINT16);
    HANDLE_INT_CASE(DT_QUINT16);
    HANDLE_INT_CASE(DT_QINT32);
    default:
      return errors::InvalidArgument("Unsupported type ",
                                     DataTypeString(dtype));
  }
  return Status::OK();
}
// Undefine the helpers actually defined above; `HANDLE_CASE` never existed, so
// the old `#undef HANDLE_CASE` leaked both macros into the rest of the file.
#undef HANDLE_DOUBLE_CASE
#undef HANDLE_INT_CASE
// Returns OK if `node` has an attr named `key`. Otherwise returns
// InvalidArgument naming the node and the missing key, followed by the node's
// short debug string.
Status CheckAttrExists(const NodeDef& node, const string& key) {
  if (HasNodeAttr(node, key)) return Status::OK();
  return errors::InvalidArgument("Node '", node.name(), "' lacks '", key,
                                 "' attr: ", node.ShortDebugString());
}
// Checks every key in `keys` with CheckAttrExists and returns the first
// failure, or OK if all attrs are present.
Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
  for (size_t i = 0; i < keys.size(); ++i) {
    TF_RETURN_IF_ERROR(CheckAttrExists(node, keys[i]));
  }
  return Status::OK();
}
// Returns OK if FindKernelDef finds a kernel for the described node on the
// device type parsed from `node_device`. Returns InvalidArgument if the device
// name cannot be parsed; otherwise returns whatever FindKernelDef reports.
Status IsKernelRegisteredForNode(
    absl::string_view node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    absl::string_view node_op, absl::string_view node_device,
    AttrSlice node_attrs) {
  DeviceNameUtils::ParsedName parsed;
  const bool parsed_ok = DeviceNameUtils::ParseFullName(node_device, &parsed);
  if (!parsed_ok) {
    return errors::InvalidArgument("Could not parse device name: ",
                                   node_device);
  }
  const DeviceType device_type(parsed.type);
  // The two trailing out-params are not requested; only the Status matters.
  return FindKernelDef(device_type, node_name, has_experimental_debug_info,
                       experimental_debug_info, node_op, node_device,
                       node_attrs, nullptr, nullptr);
}
// NodeDef convenience overload: unpacks the node's fields and forwards them to
// the field-wise IsKernelRegisteredForNode.
Status IsKernelRegisteredForNode(const NodeDef& node) {
  const AttrSlice attrs(&node.attr());
  return IsKernelRegisteredForNode(
      node.name(), node.has_experimental_debug_info(),
      node.experimental_debug_info(), node.op(), node.device(), attrs);
}
} // end namespace grappler
} // end namespace tensorflow