Strip default attributes before sending a function or op to eager service
We don't need to add default attributes back because: - function instantiation already adds default attributes - conversion from eager AttrBuilder to NodeDef adds default attributes PiperOrigin-RevId: 270928211
This commit is contained in:
parent
1c751f7355
commit
80c49c0d9f
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
@ -172,6 +173,37 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
|
||||
const AttrValue& attr_value) {
|
||||
// TODO(iga): It might make sense to augment OpRegistrationData with a
|
||||
// {attr_name -> default_attr_value} FlatMap to avoid the loop here.
|
||||
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
|
||||
if (attr_def.name() == attr_name && attr_def.has_default_value() &&
|
||||
AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
|
||||
const OpDef* op_def = nullptr;
|
||||
Status s = OpDefForOp(op_name().c_str(), &op_def);
|
||||
|
||||
for (auto& entry : encoded_attrs_) {
|
||||
attr_tmp_.ParseFromString(entry.second);
|
||||
// Insert the attr-value pair if we did not find the OpDef or if the value
|
||||
// is different from default.
|
||||
if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
|
||||
m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
|
||||
const AttrValue& value) {
|
||||
encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
|
||||
|
@ -132,6 +132,13 @@ class AttrBuilder {
|
||||
// well as any default attr-value pairs from the associated op_def, if there
|
||||
// is one.
|
||||
void FillAttrValueMap(AttrValueMap* m) const;
|
||||
|
||||
// Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except
|
||||
// when the value matches the default for this attr.
|
||||
// More precisely, if the global op registry contains an OpDef for this op
|
||||
// and if an attribute value is the same as the default (according to the
|
||||
// OpDef), this attr-value pair is not added to `m`.
|
||||
void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const;
|
||||
const NodeDef& BuildNodeDef();
|
||||
|
||||
private:
|
||||
|
@ -83,5 +83,38 @@ TEST(AttrTypeMap, CacheKey) {
|
||||
ASSERT_FALSE(cache_key == a.CacheKey("cpu:0"));
|
||||
}
|
||||
|
||||
string ToString(const AttrValueMap& m) {
|
||||
std::vector<string> strs;
|
||||
for (const auto& e : m) {
|
||||
strs.push_back(absl::StrCat(e.first, " -> ", e.second.DebugString()));
|
||||
}
|
||||
return absl::StrJoin(strs, "\n");
|
||||
}
|
||||
|
||||
TEST(AttrBuilder, FillAttrValueMapWithoutDefaults_MatMul) {
|
||||
AttrBuilder a("MatMul");
|
||||
a.Set("transpose_a", true);
|
||||
a.Set("transpose_b", false);
|
||||
|
||||
AttrValueMap m;
|
||||
a.FillAttrValueMapWithoutDefaults(&m);
|
||||
// Only non-default value must end up in the map
|
||||
ASSERT_EQ(1, m.size()) << ToString(m);
|
||||
ASSERT_EQ(true, m["transpose_a"].b()) << ToString(m);
|
||||
}
|
||||
|
||||
TEST(AttrBuilder, FillAttrValueMapWithoutDefaults_UnknownOp) {
|
||||
AttrBuilder a("SomeUnknownOp");
|
||||
a.Set("transpose_a", true);
|
||||
a.Set("transpose_b", false);
|
||||
|
||||
AttrValueMap m;
|
||||
a.FillAttrValueMapWithoutDefaults(&m);
|
||||
// Only non-default value must end up in the map
|
||||
ASSERT_EQ(2, m.size()) << ToString(m);
|
||||
ASSERT_EQ(true, m["transpose_a"].b()) << ToString(m);
|
||||
ASSERT_EQ(false, m["transpose_b"].b()) << ToString(m);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_resolver_local.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
@ -343,9 +344,7 @@ Status EagerContext::FindDeviceByName(const string& name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void EagerContext::ClearRunMetadata() {
|
||||
run_metadata_.Clear();
|
||||
}
|
||||
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
|
||||
|
||||
void EagerContext::StartStep() {
|
||||
mutex_lock ml(metadata_mu_);
|
||||
@ -386,6 +385,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
|
||||
eager::RegisterFunctionRequest request;
|
||||
request.set_context_id(GetContextId());
|
||||
*request.mutable_function_def() = fdef;
|
||||
StripDefaultAttributes(*OpRegistry::Global(),
|
||||
request.mutable_function_def()->mutable_node_def());
|
||||
std::vector<eager::RegisterFunctionResponse> responses(
|
||||
remote_contexts_.size());
|
||||
std::vector<Status> statuses(remote_contexts_.size());
|
||||
|
@ -658,7 +658,7 @@ void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
|
||||
remote_op->set_id(ctx->RemoteMgr()->NextOpId());
|
||||
remote_op->set_name(op->Name());
|
||||
|
||||
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
|
||||
op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
|
||||
remote_op->set_device(op->Device()->name());
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -163,6 +164,42 @@ Status RemoveNewDefaultAttrsFromGraphDef(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void StripDefaultAttributes(const OpRegistryInterface& op_registry,
|
||||
protobuf::RepeatedPtrField<NodeDef>* nodes) {
|
||||
for (int i = 0; i < nodes->size(); ++i) {
|
||||
NodeDef* node = nodes->Mutable(i);
|
||||
|
||||
const OpDef* op_def;
|
||||
const OpRegistrationData* op_reg_data = nullptr;
|
||||
Status s = op_registry.LookUp(node->op(), &op_reg_data);
|
||||
if (!s.ok()) {
|
||||
VLOG(1) << "Ignoring encountered unknown operation "
|
||||
<< SummarizeNodeDef(*node)
|
||||
<< " when stripping default attributes. It is likely a function, "
|
||||
"in which case ignoring it is fine";
|
||||
continue;
|
||||
}
|
||||
op_def = &op_reg_data->op_def;
|
||||
|
||||
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
|
||||
if (attr_def.has_default_value()) {
|
||||
AttrValueMap* attrs = node->mutable_attr();
|
||||
const string& name = attr_def.name();
|
||||
auto iter = attrs->find(name);
|
||||
if (iter != attrs->end()) {
|
||||
const AttrValue& default_value = attr_def.default_value();
|
||||
// The "Fast*" version can return false negatives for very large
|
||||
// AttrValues containing Tensors. There should never be an attribute
|
||||
// whose default value is a tensor larger than 32MB.
|
||||
if (FastAreAttrValuesEqual(iter->second, default_value)) {
|
||||
attrs->erase(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OpsUsedByGraph(const GraphDef& graph_def,
|
||||
std::set<string>* ops_used_in_graph) {
|
||||
// Map function names to definitions.
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
@ -24,6 +25,7 @@ namespace tensorflow {
|
||||
|
||||
// Forward declare proto so that it's symbols can be removed from .so exports
|
||||
class GraphDef;
|
||||
class NodeDef;
|
||||
|
||||
// Produce a human-readable version of a GraphDef that is more concise
|
||||
// than a text-format proto.
|
||||
@ -97,6 +99,18 @@ Status RemoveNewDefaultAttrsFromGraphDef(
|
||||
const OpRegistryInterface& producer_op_registry,
|
||||
std::set<std::pair<string, string>>* op_attr_removed);
|
||||
|
||||
// Goes over the `nodes` and removes attributes that are set to their
|
||||
// default values according to op_registry.
|
||||
// If some node's definition is not found in the `op_registry`, this node is
|
||||
// simply skipped. In most cases, these nodes would be function calls.
|
||||
// If a stricter behavior is desired, one can add FunctionLibraryDefinition
|
||||
// argument to check for functions and their attributes.
|
||||
// This is obvious from signature, but as a warning, if `nodes` contain
|
||||
// nodes calling functions, e.g. PartitionCallOp or FunctionalIf, this
|
||||
// function does not "recurse" into them.
|
||||
void StripDefaultAttributes(const OpRegistryInterface& op_registry,
|
||||
protobuf::RepeatedPtrField<NodeDef>* nodes);
|
||||
|
||||
// Two functions that collect the ops used by a graph.
|
||||
//
|
||||
// This returns the ops used as a set of strings.
|
||||
|
@ -235,6 +235,40 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
|
||||
EXPECT_EQ(expected_removed, op_attr_removed);
|
||||
}
|
||||
|
||||
TEST(StripDefaultAttributesTest, DefaultStripped) {
|
||||
OpList op_list;
|
||||
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
|
||||
op_list.add_op()));
|
||||
OpListOpRegistry registry(&op_list);
|
||||
|
||||
GraphDef graph_def;
|
||||
// This adds the default attribute
|
||||
TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", ®istry)
|
||||
.Finalize(graph_def.add_node()));
|
||||
ASSERT_EQ(1, graph_def.node(0).attr_size());
|
||||
ASSERT_EQ(12, graph_def.node(0).attr().at("a").i());
|
||||
|
||||
StripDefaultAttributes(registry, graph_def.mutable_node());
|
||||
ASSERT_EQ(1, graph_def.node_size());
|
||||
ASSERT_EQ(0, graph_def.node(0).attr_size());
|
||||
}
|
||||
|
||||
TEST(StripDefaultAttributesTest, NonDefaultNotStripped) {
|
||||
OpList op_list;
|
||||
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
|
||||
op_list.add_op()));
|
||||
OpListOpRegistry registry(&op_list);
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", ®istry)
|
||||
.Attr("a", 9)
|
||||
.Finalize(graph_def.add_node()));
|
||||
|
||||
GraphDef expected = graph_def;
|
||||
StripDefaultAttributes(registry, graph_def.mutable_node());
|
||||
TF_EXPECT_GRAPH_EQ(expected, graph_def);
|
||||
}
|
||||
|
||||
TEST(StrippedOpListForGraphTest, FlatTest) {
|
||||
// Make four ops
|
||||
OpList op_list;
|
||||
|
Loading…
Reference in New Issue
Block a user