Expose FindKernelDef with NodeDef components (name, op, device, etc.). Update Grappler util's IsKernelRegisteredForNode to use lower FindKernelDef.

PiperOrigin-RevId: 247104392
This commit is contained in:
Andy Ly 2019-05-07 15:38:40 -07:00 committed by TensorFlower Gardener
parent 7413b165c0
commit 1a40c07e83
8 changed files with 108 additions and 28 deletions

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
@ -50,7 +49,7 @@ AttrSlice::AttrSlice(const NodeDef& node_def)
AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
string ret;
// We sort the attrs so the output is deterministic.
@ -120,6 +119,13 @@ string FormatNodeDefForError(const NodeDef& node_def) {
return FormatNodeForError(NodeDebugInfo(node_def));
}
string FormatNodeDefForError(
StringPiece node_name, bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
return FormatNodeForError(NodeDebugInfo(
node_name, has_experimental_debug_info, experimental_debug_info));
}
void GetMergedOriginalNodeNames(const NodeDebugInfo& from,
const NodeDebugInfo& to,
std::set<string>* names) {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@ -34,6 +35,7 @@ struct NodeDebugInfo;
// We forward declare protos so that kernels don't need to depend on them
class NodeDef;
class OpDef;
class AttrSlice;
// Name of the attribute used to encode node colocation constraints.
//
@ -50,12 +52,16 @@ extern const char* const kColocationGroupPrefix;
string SummarizeNode(const Node& node);
string SummarizeNodeDef(const NodeDef& node_def);
string SummarizeAttrs(const NodeDef& node_def);
string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
// Produces a formatted string pattern from the node which can uniquely identify
// this node upstream to produce an informative error message. The pattern
// followed is: {{node <node_name>}}
string FormatNodeForError(const Node& node);
string FormatNodeDefForError(const NodeDef& node_def);
string FormatNodeDefForError(
StringPiece node_name, bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
// Merges the original node names from the debug information of 'from' to the
// debug information of 'to'.

View File

@ -1132,28 +1132,31 @@ namespace {
static const StringPiece kKernelAttr("_kernel");
// TODO(irving): Replace with const Node& version below.
Status FindKernelRegistration(const DeviceType& device_type,
const NodeDef& node_def,
const KernelRegistration** reg,
bool* was_attr_mismatch) {
Status FindKernelRegistration(
const DeviceType& device_type, StringPiece node_name,
bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
bool* was_attr_mismatch) {
*reg = nullptr;
*was_attr_mismatch = false;
// Label defaults to empty if not found in NodeDef.
const string& label = GetNodeAttrString(node_def, kKernelAttr);
const string& label = GetNodeAttrString(node_attrs, kKernelAttr);
const string key = Key(node_def.op(), device_type, label);
const string key = Key(node_op, device_type, label);
auto regs = GlobalKernelRegistryTyped()->equal_range(key);
for (auto iter = regs.first; iter != regs.second; ++iter) {
// If there is a kernel registered for the op and device_type,
// check that the attrs match.
bool match;
TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match));
TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
if (match) {
if (*reg != nullptr) {
return errors::InvalidArgument(
"Multiple OpKernel registrations match NodeDef '",
FormatNodeDefForError(node_def), "': '",
ProtoShortDebugString((*reg)->def), "' and '",
FormatNodeDefForError(node_name, has_experimental_debug_info,
experimental_debug_info),
"': '", ProtoShortDebugString((*reg)->def), "' and '",
ProtoShortDebugString(iter->second.def), "'");
}
*reg = &iter->second;
@ -1164,6 +1167,16 @@ Status FindKernelRegistration(const DeviceType& device_type,
return Status::OK();
}
Status FindKernelRegistration(const DeviceType& device_type,
const NodeDef& node_def,
const KernelRegistration** reg,
bool* was_attr_mismatch) {
return FindKernelRegistration(
device_type, node_def.name(), node_def.has_experimental_debug_info(),
node_def.experimental_debug_info(), node_def.op(),
AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
}
} // namespace
bool KernelDefAvailable(const DeviceType& device_type,
@ -1176,24 +1189,31 @@ bool KernelDefAvailable(const DeviceType& device_type,
}
// TODO(irving): Change const NodeDef& to const Node&
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
const KernelDef** def, string* kernel_class_name) {
Status FindKernelDef(
const DeviceType& device_type, StringPiece node_name,
bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
const KernelDef** def, string* kernel_class_name) {
const KernelRegistration* reg = nullptr;
bool was_attr_mismatch;
TF_RETURN_IF_ERROR(
FindKernelRegistration(device_type, node_def, &reg, &was_attr_mismatch));
TF_RETURN_IF_ERROR(FindKernelRegistration(
device_type, node_name, has_experimental_debug_info,
experimental_debug_info, node_op, node_attrs, &reg, &was_attr_mismatch));
if (reg == nullptr) {
Status s = errors::NotFound(
"No registered '", node_def.op(), "' OpKernel for ",
"No registered '", node_op, "' OpKernel for ",
DeviceTypeString(device_type), " devices compatible with node ",
FormatNodeDefForError(node_def));
FormatNodeDefForError(node_name, has_experimental_debug_info,
experimental_debug_info));
if (was_attr_mismatch) {
errors::AppendToMessage(
&s, " (OpKernel was found, but attributes didn't match) ",
"Requested Attributes: ", SummarizeAttrs(node_def));
"Requested Attributes: ",
SummarizeAttrsHelper(node_attrs, node_device));
}
errors::AppendToMessage(
&s, ". Registered:", KernelsRegisteredForOp(node_def.op()));
errors::AppendToMessage(&s,
". Registered:", KernelsRegisteredForOp(node_op));
return s;
}
if (def != nullptr) *def = &reg->def;
@ -1201,6 +1221,14 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
return Status::OK();
}
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
const KernelDef** def, string* kernel_class_name) {
return FindKernelDef(
device_type, node_def.name(), node_def.has_experimental_debug_info(),
node_def.experimental_debug_info(), node_def.op(), node_def.device(),
AttrSlice(&node_def.attr()), def, kernel_class_name);
}
Status SupportedDeviceTypesForNode(
const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
PrioritizedDeviceTypeVector* prioritized_device_types) {

View File

@ -18,9 +18,9 @@ limitations under the License.
#include <atomic>
#include <functional>
#include <utility>
#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/rendezvous.h"
@ -1436,6 +1437,17 @@ class Name : public KernelDefBuilder {
// Checks whether a given kernel is registered on device_type.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);
// If node of node_name, experimental_debug_info, node_op, node_device and
// node_attrs has a corresponding kernel registered on device_type, returns OK
// and fill in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(
const DeviceType& device_type, StringPiece node_name,
bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
const KernelDef** def, string* kernel_class_name);
// If node_def has a corresponding kernel registered on device_type,
// returns OK and fill in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.

View File

@ -315,9 +315,15 @@ Status Node::input_tensor(int idx, OutputTensor* t) const {
// NodeDebugInfo
NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef) : name(ndef.name()) {
if (ndef.has_experimental_debug_info()) {
const auto& names = ndef.experimental_debug_info().original_node_names();
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
: NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
ndef.experimental_debug_info()) {}
NodeDebugInfo::NodeDebugInfo(
StringPiece node_name, bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
: name(node_name) {
if (has_experimental_debug_info) {
const auto& names = experimental_debug_info.original_node_names();
original_node_names.assign(names.begin(), names.end());
}
}

View File

@ -40,7 +40,9 @@ limitations under the License.
#include <functional>
#include <string>
#include <vector>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
@ -309,6 +311,8 @@ struct NodeDebugInfo {
NodeDebugInfo(const Node& n);
NodeDebugInfo(const NodeDef& ndef);
NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
};
// Represents an input of a node, i.e., the `index`-th input to `node`.

View File

@ -493,13 +493,26 @@ Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
return Status::OK();
}
Status IsKernelRegisteredForNode(const NodeDef& node) {
Status IsKernelRegisteredForNode(
absl::string_view node_name, bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
absl::string_view node_op, absl::string_view node_device,
AttrSlice node_attrs) {
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
if (!DeviceNameUtils::ParseFullName(node_device, &parsed_name)) {
return errors::InvalidArgument("Could not parse device name: ",
node.device());
node_device);
}
return FindKernelDef(DeviceType(parsed_name.type), node, nullptr, nullptr);
return FindKernelDef(DeviceType(parsed_name.type), node_name,
has_experimental_debug_info, experimental_debug_info,
node_op, node_device, node_attrs, nullptr, nullptr);
}
Status IsKernelRegisteredForNode(const NodeDef& node) {
return IsKernelRegisteredForNode(node.name(),
node.has_experimental_debug_info(),
node.experimental_debug_info(), node.op(),
node.device(), AttrSlice(&node.attr()));
}
} // end namespace grappler

View File

@ -298,6 +298,11 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
// Returns Status::OK() if a kernel is registered for node.op() on the device
// type corresponding to node.device().
Status IsKernelRegisteredForNode(
absl::string_view node_name, bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
absl::string_view node_op, absl::string_view node_device,
AttrSlice node_attrs);
Status IsKernelRegisteredForNode(const NodeDef& node);
Status SetTensorValue(DataType dtype, int value, Tensor* tensor);