diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index 55eecbcbce2..be0f4a009da 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -172,6 +173,42 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
   }
 }
 
+namespace {
+
+// Returns true iff `op_def` declares an attr named `attr_name` that has a
+// default value equal to `attr_value`.
+bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
+                         const AttrValue& attr_value) {
+  // TODO(iga): It might make sense to augment OpRegistrationData with a
+  // {attr_name -> default_attr_value} FlatMap to avoid the loop here.
+  for (const OpDef::AttrDef& attr_def : op_def->attr()) {
+    if (attr_def.name() != attr_name) continue;
+    // Attr names are unique within a valid OpDef, so the first name match is
+    // the only candidate; stop scanning here.
+    return attr_def.has_default_value() &&
+           AreAttrValuesEqual(attr_def.default_value(), attr_value);
+  }
+  return false;
+}
+
+}  // namespace
+
+void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
+  // Without a registered OpDef there are no defaults to compare against, so
+  // every attr is forwarded unchanged.
+  const OpDef* op_def = nullptr;
+  const bool has_op_def = OpDefForOp(op_name().c_str(), &op_def).ok();
+
+  for (const auto& entry : encoded_attrs_) {
+    attr_tmp_.ParseFromString(entry.second);
+    // Skip attrs whose value equals the OpDef default.
+    if (has_op_def && ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
+      continue;
+    }
+    m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
+  }
+}
+
 void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
                                       const AttrValue& value) {
   encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index f66ab0a8277..aaf9950faae 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -132,6 +132,13 @@ class AttrBuilder {
   // well as any default attr-value pairs from the associated op_def, if there
   // is one.
   void FillAttrValueMap(AttrValueMap* m) const;
+
+  // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except
+  // when the value matches the default for this attr.
+  // More precisely, if the global op registry contains an OpDef for this op
+  // and if an attribute value is the same as the default (according to the
+  // OpDef), this attr-value pair is not added to `m`.
+  void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const;
   const NodeDef& BuildNodeDef();
 
  private:
diff --git a/tensorflow/core/common_runtime/eager/attr_builder_test.cc b/tensorflow/core/common_runtime/eager/attr_builder_test.cc
index ebe0779cd74..d3d39193056 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder_test.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder_test.cc
@@ -83,5 +83,38 @@ TEST(AttrTypeMap, CacheKey) {
   ASSERT_FALSE(cache_key == a.CacheKey("cpu:0"));
 }
 
+string ToString(const AttrValueMap& m) {
+  std::vector<string> strs;
+  for (const auto& e : m) {
+    strs.push_back(absl::StrCat(e.first, " -> ", e.second.DebugString()));
+  }
+  return absl::StrJoin(strs, "\n");
+}
+
+TEST(AttrBuilder, FillAttrValueMapWithoutDefaults_MatMul) {
+  AttrBuilder a("MatMul");
+  a.Set("transpose_a", true);
+  a.Set("transpose_b", false);
+
+  AttrValueMap m;
+  a.FillAttrValueMapWithoutDefaults(&m);
+  // Only non-default value must end up in the map
+  ASSERT_EQ(1, m.size()) << ToString(m);
+  ASSERT_TRUE(m["transpose_a"].b()) << ToString(m);
+}
+
+TEST(AttrBuilder, FillAttrValueMapWithoutDefaults_UnknownOp) {
+  AttrBuilder a("SomeUnknownOp");
+  a.Set("transpose_a", true);
+  a.Set("transpose_b", false);
+
+  AttrValueMap m;
+  a.FillAttrValueMapWithoutDefaults(&m);
+  // Unknown op: no defaults are known, so every value that was set is kept
+  ASSERT_EQ(2, m.size()) << ToString(m);
+  ASSERT_TRUE(m["transpose_a"].b()) << ToString(m);
+  ASSERT_FALSE(m["transpose_b"].b()) << ToString(m);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 3bb1954bfb4..9bd0960e398 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_resolver_local.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/device_name_utils.h" @@ -343,9 +344,7 @@ Status EagerContext::FindDeviceByName(const string& name, return Status::OK(); } -void EagerContext::ClearRunMetadata() { - run_metadata_.Clear(); -} +void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); } void EagerContext::StartStep() { mutex_lock ml(metadata_mu_); @@ -386,6 +385,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { eager::RegisterFunctionRequest request; request.set_context_id(GetContextId()); *request.mutable_function_def() = fdef; + StripDefaultAttributes(*OpRegistry::Global(), + request.mutable_function_def()->mutable_node_def()); std::vector responses( remote_contexts_.size()); std::vector statuses(remote_contexts_.size()); diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 7c160c5e4f9..217b1c69b8b 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -658,7 +658,7 @@ void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) { remote_op->set_id(ctx->RemoteMgr()->NextOpId()); remote_op->set_name(op->Name()); - op->Attrs().FillAttrValueMap(remote_op->mutable_attrs()); + op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs()); remote_op->set_device(op->Device()->name()); } diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index e86a88c661b..978ed814e8e 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -163,6 +164,42 @@ Status RemoveNewDefaultAttrsFromGraphDef( return Status::OK(); } +void StripDefaultAttributes(const OpRegistryInterface& op_registry, + protobuf::RepeatedPtrField* nodes) { + for (int i = 0; i < nodes->size(); ++i) { + NodeDef* node = nodes->Mutable(i); + + const OpDef* op_def; + const OpRegistrationData* op_reg_data = nullptr; + Status s = op_registry.LookUp(node->op(), &op_reg_data); + if (!s.ok()) { + VLOG(1) << "Ignoring encountered unknown operation " + << SummarizeNodeDef(*node) + << " when stripping default attributes. It is likely a function, " + "in which case ignoring it is fine"; + continue; + } + op_def = &op_reg_data->op_def; + + for (const OpDef::AttrDef& attr_def : op_def->attr()) { + if (attr_def.has_default_value()) { + AttrValueMap* attrs = node->mutable_attr(); + const string& name = attr_def.name(); + auto iter = attrs->find(name); + if (iter != attrs->end()) { + const AttrValue& default_value = attr_def.default_value(); + // The "Fast*" version can return false negatives for very large + // AttrValues containing Tensors. There should never be an attribute + // whose default value is a tensor larger than 32MB. + if (FastAreAttrValuesEqual(iter->second, default_value)) { + attrs->erase(name); + } + } + } + } + } +} + void OpsUsedByGraph(const GraphDef& graph_def, std::set* ops_used_in_graph) { // Map function names to definitions. diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h index 2f8d5e8f511..9ebe610bdca 100644 --- a/tensorflow/core/framework/graph_def_util.h +++ b/tensorflow/core/framework/graph_def_util.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_ #include + #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/core/status.h" @@ -24,6 +25,7 @@ namespace tensorflow { // Forward declare proto so that it's symbols can be removed from .so exports class GraphDef; +class NodeDef; // Produce a human-readable version of a GraphDef that is more concise // than a text-format proto. @@ -97,6 +99,18 @@ Status RemoveNewDefaultAttrsFromGraphDef( const OpRegistryInterface& producer_op_registry, std::set>* op_attr_removed); +// Goes over the `nodes` and removes attributes that are set to their +// default values according to op_registry. +// If some node's definition is not found in the `op_registry`, this node is +// simply skipped. In most cases, these nodes would be function calls. +// If a stricter behavior is desired, one can add FunctionLibraryDefinition +// argument to check for functions and their attributes. +// This is obvious from signature, but as a warning, if `nodes` contain +// nodes calling functions, e.g. PartitionCallOp or FunctionalIf, this +// function does not "recurse" into them. +void StripDefaultAttributes(const OpRegistryInterface& op_registry, + protobuf::RepeatedPtrField* nodes); + // Two functions that collect the ops used by a graph. // // This returns the ops used as a set of strings. 
diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc
index 08cb4c28b20..14089aca51e 100644
--- a/tensorflow/core/framework/graph_def_util_test.cc
+++ b/tensorflow/core/framework/graph_def_util_test.cc
@@ -235,6 +235,55 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
   EXPECT_EQ(expected_removed, op_attr_removed);
 }
 
+TEST(StripDefaultAttributesTest, DefaultStripped) {
+  OpList op_list;
+  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
+                             op_list.add_op()));
+  OpListOpRegistry registry(&op_list);
+
+  GraphDef graph_def;
+  // This adds the default attribute
+  TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", &registry)
+                   .Finalize(graph_def.add_node()));
+  ASSERT_EQ(1, graph_def.node(0).attr_size());
+  ASSERT_EQ(12, graph_def.node(0).attr().at("a").i());
+
+  StripDefaultAttributes(registry, graph_def.mutable_node());
+  ASSERT_EQ(1, graph_def.node_size());
+  ASSERT_EQ(0, graph_def.node(0).attr_size());
+}
+
+TEST(StripDefaultAttributesTest, NonDefaultNotStripped) {
+  OpList op_list;
+  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
+                             op_list.add_op()));
+  OpListOpRegistry registry(&op_list);
+
+  GraphDef graph_def;
+  TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", &registry)
+                   .Attr("a", 9)
+                   .Finalize(graph_def.add_node()));
+
+  GraphDef expected = graph_def;
+  StripDefaultAttributes(registry, graph_def.mutable_node());
+  TF_EXPECT_GRAPH_EQ(expected, graph_def);
+}
+
+TEST(StripDefaultAttributesTest, UnknownOpSkipped) {
+  OpList op_list;
+  OpListOpRegistry registry(&op_list);
+
+  GraphDef graph_def;
+  NodeDef* node = graph_def.add_node();
+  node->set_name("op1");
+  node->set_op("UnregisteredOp");
+  (*node->mutable_attr())["a"].set_i(12);
+
+  GraphDef expected = graph_def;
+  StripDefaultAttributes(registry, graph_def.mutable_node());
+  TF_EXPECT_GRAPH_EQ(expected, graph_def);
+}
+
 TEST(StrippedOpListForGraphTest, FlatTest) {
   // Make four ops
   OpList op_list;