Expose FindKernelDef with NodeDef components (name, op, device, etc.). Update Grappler util's IsKernelRegisteredForNode to use lower FindKernelDef.
PiperOrigin-RevId: 247104392
This commit is contained in:
parent
7413b165c0
commit
1a40c07e83
@ -21,7 +21,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def.pb_text.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
@ -50,7 +49,7 @@ AttrSlice::AttrSlice(const NodeDef& node_def)
|
||||
|
||||
AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
|
||||
|
||||
static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
|
||||
string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
|
||||
string ret;
|
||||
|
||||
// We sort the attrs so the output is deterministic.
|
||||
@ -120,6 +119,13 @@ string FormatNodeDefForError(const NodeDef& node_def) {
|
||||
return FormatNodeForError(NodeDebugInfo(node_def));
|
||||
}
|
||||
|
||||
string FormatNodeDefForError(
|
||||
StringPiece node_name, bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
|
||||
return FormatNodeForError(NodeDebugInfo(
|
||||
node_name, has_experimental_debug_info, experimental_debug_info));
|
||||
}
|
||||
|
||||
void GetMergedOriginalNodeNames(const NodeDebugInfo& from,
|
||||
const NodeDebugInfo& to,
|
||||
std::set<string>* names) {
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
@ -34,6 +35,7 @@ struct NodeDebugInfo;
|
||||
// We forward declare protos so that kernels don't need to depend on them
|
||||
class NodeDef;
|
||||
class OpDef;
|
||||
class AttrSlice;
|
||||
|
||||
// Name of the attribute used to encode node colocation constraints.
|
||||
//
|
||||
@ -50,12 +52,16 @@ extern const char* const kColocationGroupPrefix;
|
||||
string SummarizeNode(const Node& node);
|
||||
string SummarizeNodeDef(const NodeDef& node_def);
|
||||
string SummarizeAttrs(const NodeDef& node_def);
|
||||
string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
|
||||
|
||||
// Produces a formatted string pattern from the node which can uniquely identify
|
||||
// this node upstream to produce an informative error message. The pattern
|
||||
// followed is: {{node <node_name>}}
|
||||
string FormatNodeForError(const Node& node);
|
||||
string FormatNodeDefForError(const NodeDef& node_def);
|
||||
string FormatNodeDefForError(
|
||||
StringPiece node_name, bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
|
||||
|
||||
// Merges the original node names from the debug information of 'from' to the
|
||||
// debug information of 'to'.
|
||||
|
@ -1132,28 +1132,31 @@ namespace {
|
||||
static const StringPiece kKernelAttr("_kernel");
|
||||
|
||||
// TODO(irving): Replace with const Node& version below.
|
||||
Status FindKernelRegistration(const DeviceType& device_type,
|
||||
const NodeDef& node_def,
|
||||
const KernelRegistration** reg,
|
||||
bool* was_attr_mismatch) {
|
||||
Status FindKernelRegistration(
|
||||
const DeviceType& device_type, StringPiece node_name,
|
||||
bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
|
||||
StringPiece node_op, AttrSlice node_attrs, const KernelRegistration** reg,
|
||||
bool* was_attr_mismatch) {
|
||||
*reg = nullptr;
|
||||
*was_attr_mismatch = false;
|
||||
// Label defaults to empty if not found in NodeDef.
|
||||
const string& label = GetNodeAttrString(node_def, kKernelAttr);
|
||||
const string& label = GetNodeAttrString(node_attrs, kKernelAttr);
|
||||
|
||||
const string key = Key(node_def.op(), device_type, label);
|
||||
const string key = Key(node_op, device_type, label);
|
||||
auto regs = GlobalKernelRegistryTyped()->equal_range(key);
|
||||
for (auto iter = regs.first; iter != regs.second; ++iter) {
|
||||
// If there is a kernel registered for the op and device_type,
|
||||
// check that the attrs match.
|
||||
bool match;
|
||||
TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match));
|
||||
TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_attrs, &match));
|
||||
if (match) {
|
||||
if (*reg != nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"Multiple OpKernel registrations match NodeDef '",
|
||||
FormatNodeDefForError(node_def), "': '",
|
||||
ProtoShortDebugString((*reg)->def), "' and '",
|
||||
FormatNodeDefForError(node_name, has_experimental_debug_info,
|
||||
experimental_debug_info),
|
||||
"': '", ProtoShortDebugString((*reg)->def), "' and '",
|
||||
ProtoShortDebugString(iter->second.def), "'");
|
||||
}
|
||||
*reg = &iter->second;
|
||||
@ -1164,6 +1167,16 @@ Status FindKernelRegistration(const DeviceType& device_type,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FindKernelRegistration(const DeviceType& device_type,
|
||||
const NodeDef& node_def,
|
||||
const KernelRegistration** reg,
|
||||
bool* was_attr_mismatch) {
|
||||
return FindKernelRegistration(
|
||||
device_type, node_def.name(), node_def.has_experimental_debug_info(),
|
||||
node_def.experimental_debug_info(), node_def.op(),
|
||||
AttrSlice(&node_def.attr()), reg, was_attr_mismatch);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool KernelDefAvailable(const DeviceType& device_type,
|
||||
@ -1176,24 +1189,31 @@ bool KernelDefAvailable(const DeviceType& device_type,
|
||||
}
|
||||
|
||||
// TODO(irving): Change const NodeDef& to const Node&
|
||||
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
|
||||
const KernelDef** def, string* kernel_class_name) {
|
||||
Status FindKernelDef(
|
||||
const DeviceType& device_type, StringPiece node_name,
|
||||
bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
|
||||
StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
|
||||
const KernelDef** def, string* kernel_class_name) {
|
||||
const KernelRegistration* reg = nullptr;
|
||||
bool was_attr_mismatch;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FindKernelRegistration(device_type, node_def, ®, &was_attr_mismatch));
|
||||
TF_RETURN_IF_ERROR(FindKernelRegistration(
|
||||
device_type, node_name, has_experimental_debug_info,
|
||||
experimental_debug_info, node_op, node_attrs, ®, &was_attr_mismatch));
|
||||
if (reg == nullptr) {
|
||||
Status s = errors::NotFound(
|
||||
"No registered '", node_def.op(), "' OpKernel for ",
|
||||
"No registered '", node_op, "' OpKernel for ",
|
||||
DeviceTypeString(device_type), " devices compatible with node ",
|
||||
FormatNodeDefForError(node_def));
|
||||
FormatNodeDefForError(node_name, has_experimental_debug_info,
|
||||
experimental_debug_info));
|
||||
if (was_attr_mismatch) {
|
||||
errors::AppendToMessage(
|
||||
&s, " (OpKernel was found, but attributes didn't match) ",
|
||||
"Requested Attributes: ", SummarizeAttrs(node_def));
|
||||
"Requested Attributes: ",
|
||||
SummarizeAttrsHelper(node_attrs, node_device));
|
||||
}
|
||||
errors::AppendToMessage(
|
||||
&s, ". Registered:", KernelsRegisteredForOp(node_def.op()));
|
||||
errors::AppendToMessage(&s,
|
||||
". Registered:", KernelsRegisteredForOp(node_op));
|
||||
return s;
|
||||
}
|
||||
if (def != nullptr) *def = ®->def;
|
||||
@ -1201,6 +1221,14 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
|
||||
const KernelDef** def, string* kernel_class_name) {
|
||||
return FindKernelDef(
|
||||
device_type, node_def.name(), node_def.has_experimental_debug_info(),
|
||||
node_def.experimental_debug_info(), node_def.op(), node_def.device(),
|
||||
AttrSlice(&node_def.attr()), def, kernel_class_name);
|
||||
}
|
||||
|
||||
Status SupportedDeviceTypesForNode(
|
||||
const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
|
||||
PrioritizedDeviceTypeVector* prioritized_device_types) {
|
||||
|
@ -18,9 +18,9 @@ limitations under the License.
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/control_flow.h"
|
||||
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
@ -1436,6 +1437,17 @@ class Name : public KernelDefBuilder {
|
||||
// Checks whether a given kernel is registered on device_type.
|
||||
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);
|
||||
|
||||
// If node of node_name, experimental_debug_info, node_op, node_device and
|
||||
// node_attrs has a corresponding kernel registered on device_type, returns OK
|
||||
// and fill in the kernel def and kernel_class_name. <def> and
|
||||
// <kernel_class_name> may be null.
|
||||
Status FindKernelDef(
|
||||
const DeviceType& device_type, StringPiece node_name,
|
||||
bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
|
||||
StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
|
||||
const KernelDef** def, string* kernel_class_name);
|
||||
|
||||
// If node_def has a corresponding kernel registered on device_type,
|
||||
// returns OK and fill in the kernel def and kernel_class_name. <def> and
|
||||
// <kernel_class_name> may be null.
|
||||
|
@ -315,9 +315,15 @@ Status Node::input_tensor(int idx, OutputTensor* t) const {
|
||||
// NodeDebugInfo
|
||||
|
||||
NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
|
||||
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef) : name(ndef.name()) {
|
||||
if (ndef.has_experimental_debug_info()) {
|
||||
const auto& names = ndef.experimental_debug_info().original_node_names();
|
||||
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
|
||||
: NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
|
||||
ndef.experimental_debug_info()) {}
|
||||
NodeDebugInfo::NodeDebugInfo(
|
||||
StringPiece node_name, bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
|
||||
: name(node_name) {
|
||||
if (has_experimental_debug_info) {
|
||||
const auto& names = experimental_debug_info.original_node_names();
|
||||
original_node_names.assign(names.begin(), names.end());
|
||||
}
|
||||
}
|
||||
|
@ -40,7 +40,9 @@ limitations under the License.
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/edgeset.h"
|
||||
@ -309,6 +311,8 @@ struct NodeDebugInfo {
|
||||
|
||||
NodeDebugInfo(const Node& n);
|
||||
NodeDebugInfo(const NodeDef& ndef);
|
||||
NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
|
||||
};
|
||||
|
||||
// Represents an input of a node, i.e., the `index`-th input to `node`.
|
||||
|
@ -493,13 +493,26 @@ Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IsKernelRegisteredForNode(const NodeDef& node) {
|
||||
Status IsKernelRegisteredForNode(
|
||||
absl::string_view node_name, bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
|
||||
absl::string_view node_op, absl::string_view node_device,
|
||||
AttrSlice node_attrs) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
|
||||
if (!DeviceNameUtils::ParseFullName(node_device, &parsed_name)) {
|
||||
return errors::InvalidArgument("Could not parse device name: ",
|
||||
node.device());
|
||||
node_device);
|
||||
}
|
||||
return FindKernelDef(DeviceType(parsed_name.type), node, nullptr, nullptr);
|
||||
return FindKernelDef(DeviceType(parsed_name.type), node_name,
|
||||
has_experimental_debug_info, experimental_debug_info,
|
||||
node_op, node_device, node_attrs, nullptr, nullptr);
|
||||
}
|
||||
|
||||
Status IsKernelRegisteredForNode(const NodeDef& node) {
|
||||
return IsKernelRegisteredForNode(node.name(),
|
||||
node.has_experimental_debug_info(),
|
||||
node.experimental_debug_info(), node.op(),
|
||||
node.device(), AttrSlice(&node.attr()));
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -298,6 +298,11 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
|
||||
|
||||
// Returns Status::OK() if a kernel is registered for node.op() on the device
|
||||
// type corresponding to node.device().
|
||||
Status IsKernelRegisteredForNode(
|
||||
absl::string_view node_name, bool has_experimental_debug_info,
|
||||
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
|
||||
absl::string_view node_op, absl::string_view node_device,
|
||||
AttrSlice node_attrs);
|
||||
Status IsKernelRegisteredForNode(const NodeDef& node);
|
||||
|
||||
Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
|
||||
|
Loading…
Reference in New Issue
Block a user