Strip default attributes before sending a function or op to eager service

We don't need to add default attributes back because:
 - function instantiation already adds default attributes
 - conversion from eager AttrBuilder to NodeDef adds default attributes

PiperOrigin-RevId: 270928211
This commit is contained in:
Igor Ganichev 2019-09-24 10:02:55 -07:00 committed by TensorFlower Gardener
parent 1c751f7355
commit 80c49c0d9f
8 changed files with 162 additions and 4 deletions

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@ -172,6 +173,37 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
}
}
namespace {
bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
const AttrValue& attr_value) {
// TODO(iga): It might make sense to augment OpRegistrationData with a
// {attr_name -> default_attr_value} FlatMap to avoid the loop here.
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
if (attr_def.name() == attr_name && attr_def.has_default_value() &&
AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
return true;
}
}
return false;
}
} // namespace
void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
const OpDef* op_def = nullptr;
Status s = OpDefForOp(op_name().c_str(), &op_def);
for (auto& entry : encoded_attrs_) {
attr_tmp_.ParseFromString(entry.second);
// Insert the attr-value pair if we did not find the OpDef or if the value
// is different from default.
if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
}
}
}
void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
const AttrValue& value) {
encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());

View File

@ -132,6 +132,13 @@ class AttrBuilder {
// well as any default attr-value pairs from the associated op_def, if there
// is one.
void FillAttrValueMap(AttrValueMap* m) const;
// Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except
// when the value matches the default for this attr.
// More precisely, if the global op registry contains an OpDef for this op
// and if an attribute value is the same as the default (according to the
// OpDef), this attr-value pair is not added to `m`.
void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const;
const NodeDef& BuildNodeDef();
private:

View File

@ -83,5 +83,38 @@ TEST(AttrTypeMap, CacheKey) {
ASSERT_FALSE(cache_key == a.CacheKey("cpu:0"));
}
string ToString(const AttrValueMap& m) {
std::vector<string> strs;
for (const auto& e : m) {
strs.push_back(absl::StrCat(e.first, " -> ", e.second.DebugString()));
}
return absl::StrJoin(strs, "\n");
}
TEST(AttrBuilder, FillAttrValueMapWithoutDefaults_MatMul) {
AttrBuilder a("MatMul");
a.Set("transpose_a", true);
a.Set("transpose_b", false);
AttrValueMap m;
a.FillAttrValueMapWithoutDefaults(&m);
// Only non-default value must end up in the map
ASSERT_EQ(1, m.size()) << ToString(m);
ASSERT_EQ(true, m["transpose_a"].b()) << ToString(m);
}
TEST(AttrBuilder, FillAttrValueMapWithoutDefaults_UnknownOp) {
AttrBuilder a("SomeUnknownOp");
a.Set("transpose_a", true);
a.Set("transpose_b", false);
AttrValueMap m;
a.FillAttrValueMapWithoutDefaults(&m);
// Only non-default value must end up in the map
ASSERT_EQ(2, m.size()) << ToString(m);
ASSERT_EQ(true, m["transpose_a"].b()) << ToString(m);
ASSERT_EQ(false, m["transpose_b"].b()) << ToString(m);
}
} // namespace
} // namespace tensorflow

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -343,9 +344,7 @@ Status EagerContext::FindDeviceByName(const string& name,
return Status::OK();
}
void EagerContext::ClearRunMetadata() {
run_metadata_.Clear();
}
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
void EagerContext::StartStep() {
mutex_lock ml(metadata_mu_);
@ -386,6 +385,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
eager::RegisterFunctionRequest request;
request.set_context_id(GetContextId());
*request.mutable_function_def() = fdef;
StripDefaultAttributes(*OpRegistry::Global(),
request.mutable_function_def()->mutable_node_def());
std::vector<eager::RegisterFunctionResponse> responses(
remote_contexts_.size());
std::vector<Status> statuses(remote_contexts_.size());

View File

@ -658,7 +658,7 @@ void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
remote_op->set_id(ctx->RemoteMgr()->NextOpId());
remote_op->set_name(op->Name());
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
remote_op->set_device(op->Device()->name());
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -163,6 +164,42 @@ Status RemoveNewDefaultAttrsFromGraphDef(
return Status::OK();
}
void StripDefaultAttributes(const OpRegistryInterface& op_registry,
protobuf::RepeatedPtrField<NodeDef>* nodes) {
for (int i = 0; i < nodes->size(); ++i) {
NodeDef* node = nodes->Mutable(i);
const OpDef* op_def;
const OpRegistrationData* op_reg_data = nullptr;
Status s = op_registry.LookUp(node->op(), &op_reg_data);
if (!s.ok()) {
VLOG(1) << "Ignoring encountered unknown operation "
<< SummarizeNodeDef(*node)
<< " when stripping default attributes. It is likely a function, "
"in which case ignoring it is fine";
continue;
}
op_def = &op_reg_data->op_def;
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
if (attr_def.has_default_value()) {
AttrValueMap* attrs = node->mutable_attr();
const string& name = attr_def.name();
auto iter = attrs->find(name);
if (iter != attrs->end()) {
const AttrValue& default_value = attr_def.default_value();
// The "Fast*" version can return false negatives for very large
// AttrValues containing Tensors. There should never be an attribute
// whose default value is a tensor larger than 32MB.
if (FastAreAttrValuesEqual(iter->second, default_value)) {
attrs->erase(name);
}
}
}
}
}
}
void OpsUsedByGraph(const GraphDef& graph_def,
std::set<string>* ops_used_in_graph) {
// Map function names to definitions.

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
#include <set>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
@ -24,6 +25,7 @@ namespace tensorflow {
// Forward declare proto so that it's symbols can be removed from .so exports
class GraphDef;
class NodeDef;
// Produce a human-readable version of a GraphDef that is more concise
// than a text-format proto.
@ -97,6 +99,18 @@ Status RemoveNewDefaultAttrsFromGraphDef(
const OpRegistryInterface& producer_op_registry,
std::set<std::pair<string, string>>* op_attr_removed);
// Goes over the `nodes` and removes attributes that are set to their
// default values according to op_registry.
// If some node's definition is not found in the `op_registry`, this node is
// simply skipped. In most cases, these nodes would be function calls.
// If a stricter behavior is desired, one can add FunctionLibraryDefinition
// argument to check for functions and their attributes.
// This is obvious from signature, but as a warning, if `nodes` contain
// nodes calling functions, e.g. PartitionCallOp or FunctionalIf, this
// function does not "recurse" into them.
void StripDefaultAttributes(const OpRegistryInterface& op_registry,
protobuf::RepeatedPtrField<NodeDef>* nodes);
// Two functions that collect the ops used by a graph.
//
// This returns the ops used as a set of strings.

View File

@ -235,6 +235,40 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
EXPECT_EQ(expected_removed, op_attr_removed);
}
TEST(StripDefaultAttributesTest, DefaultStripped) {
OpList op_list;
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
op_list.add_op()));
OpListOpRegistry registry(&op_list);
GraphDef graph_def;
// This adds the default attribute
TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", &registry)
.Finalize(graph_def.add_node()));
ASSERT_EQ(1, graph_def.node(0).attr_size());
ASSERT_EQ(12, graph_def.node(0).attr().at("a").i());
StripDefaultAttributes(registry, graph_def.mutable_node());
ASSERT_EQ(1, graph_def.node_size());
ASSERT_EQ(0, graph_def.node(0).attr_size());
}
TEST(StripDefaultAttributesTest, NonDefaultNotStripped) {
OpList op_list;
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
op_list.add_op()));
OpListOpRegistry registry(&op_list);
GraphDef graph_def;
TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", &registry)
.Attr("a", 9)
.Finalize(graph_def.add_node()));
GraphDef expected = graph_def;
StripDefaultAttributes(registry, graph_def.mutable_node());
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
TEST(StrippedOpListForGraphTest, FlatTest) {
// Make four ops
OpList op_list;