Strip default attributes before sending a function or op to eager service
We don't need to add default attributes back because: - function instantiation already adds default attributes - conversion from eager AttrBuilder to NodeDef adds default attributes PiperOrigin-RevId: 270928211
This commit is contained in:
parent
1c751f7355
commit
80c49c0d9f
tensorflow/core
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value_util.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
@ -172,6 +173,37 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
|
||||||
|
const AttrValue& attr_value) {
|
||||||
|
// TODO(iga): It might make sense to augment OpRegistrationData with a
|
||||||
|
// {attr_name -> default_attr_value} FlatMap to avoid the loop here.
|
||||||
|
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
|
||||||
|
if (attr_def.name() == attr_name && attr_def.has_default_value() &&
|
||||||
|
AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
|
||||||
|
const OpDef* op_def = nullptr;
|
||||||
|
Status s = OpDefForOp(op_name().c_str(), &op_def);
|
||||||
|
|
||||||
|
for (auto& entry : encoded_attrs_) {
|
||||||
|
attr_tmp_.ParseFromString(entry.second);
|
||||||
|
// Insert the attr-value pair if we did not find the OpDef or if the value
|
||||||
|
// is different from default.
|
||||||
|
if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
|
||||||
|
m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
|
void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
|
||||||
const AttrValue& value) {
|
const AttrValue& value) {
|
||||||
encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
|
encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
|
||||||
|
@ -132,6 +132,13 @@ class AttrBuilder {
|
|||||||
// well as any default attr-value pairs from the associated op_def, if there
|
// well as any default attr-value pairs from the associated op_def, if there
|
||||||
// is one.
|
// is one.
|
||||||
void FillAttrValueMap(AttrValueMap* m) const;
|
void FillAttrValueMap(AttrValueMap* m) const;
|
||||||
|
|
||||||
|
// Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except
|
||||||
|
// when the value matches the default for this attr.
|
||||||
|
// More precisely, if the global op registry contains an OpDef for this op
|
||||||
|
// and if an attribute value is the same as the default (according to the
|
||||||
|
// OpDef), this attr-value pair is not added to `m`.
|
||||||
|
void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const;
|
||||||
const NodeDef& BuildNodeDef();
|
const NodeDef& BuildNodeDef();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -83,5 +83,38 @@ TEST(AttrTypeMap, CacheKey) {
|
|||||||
ASSERT_FALSE(cache_key == a.CacheKey("cpu:0"));
|
ASSERT_FALSE(cache_key == a.CacheKey("cpu:0"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string ToString(const AttrValueMap& m) {
|
||||||
|
std::vector<string> strs;
|
||||||
|
for (const auto& e : m) {
|
||||||
|
strs.push_back(absl::StrCat(e.first, " -> ", e.second.DebugString()));
|
||||||
|
}
|
||||||
|
return absl::StrJoin(strs, "\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(AttrBuilder, FillAttrValueMapWithoutDefaults_MatMul) {
|
||||||
|
AttrBuilder a("MatMul");
|
||||||
|
a.Set("transpose_a", true);
|
||||||
|
a.Set("transpose_b", false);
|
||||||
|
|
||||||
|
AttrValueMap m;
|
||||||
|
a.FillAttrValueMapWithoutDefaults(&m);
|
||||||
|
// Only non-default value must end up in the map
|
||||||
|
ASSERT_EQ(1, m.size()) << ToString(m);
|
||||||
|
ASSERT_EQ(true, m["transpose_a"].b()) << ToString(m);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(AttrBuilder, FillAttrValueMapWithoutDefaults_UnknownOp) {
|
||||||
|
AttrBuilder a("SomeUnknownOp");
|
||||||
|
a.Set("transpose_a", true);
|
||||||
|
a.Set("transpose_b", false);
|
||||||
|
|
||||||
|
AttrValueMap m;
|
||||||
|
a.FillAttrValueMapWithoutDefaults(&m);
|
||||||
|
// Only non-default value must end up in the map
|
||||||
|
ASSERT_EQ(2, m.size()) << ToString(m);
|
||||||
|
ASSERT_EQ(true, m["transpose_a"].b()) << ToString(m);
|
||||||
|
ASSERT_EQ(false, m["transpose_b"].b()) << ToString(m);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device_resolver_local.h"
|
#include "tensorflow/core/common_runtime/device_resolver_local.h"
|
||||||
#include "tensorflow/core/common_runtime/device_set.h"
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
#include "tensorflow/core/common_runtime/process_util.h"
|
#include "tensorflow/core/common_runtime/process_util.h"
|
||||||
|
#include "tensorflow/core/framework/graph_def_util.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
@ -343,9 +344,7 @@ Status EagerContext::FindDeviceByName(const string& name,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void EagerContext::ClearRunMetadata() {
|
void EagerContext::ClearRunMetadata() { run_metadata_.Clear(); }
|
||||||
run_metadata_.Clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
void EagerContext::StartStep() {
|
void EagerContext::StartStep() {
|
||||||
mutex_lock ml(metadata_mu_);
|
mutex_lock ml(metadata_mu_);
|
||||||
@ -386,6 +385,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
|
|||||||
eager::RegisterFunctionRequest request;
|
eager::RegisterFunctionRequest request;
|
||||||
request.set_context_id(GetContextId());
|
request.set_context_id(GetContextId());
|
||||||
*request.mutable_function_def() = fdef;
|
*request.mutable_function_def() = fdef;
|
||||||
|
StripDefaultAttributes(*OpRegistry::Global(),
|
||||||
|
request.mutable_function_def()->mutable_node_def());
|
||||||
std::vector<eager::RegisterFunctionResponse> responses(
|
std::vector<eager::RegisterFunctionResponse> responses(
|
||||||
remote_contexts_.size());
|
remote_contexts_.size());
|
||||||
std::vector<Status> statuses(remote_contexts_.size());
|
std::vector<Status> statuses(remote_contexts_.size());
|
||||||
|
@ -658,7 +658,7 @@ void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
|
|||||||
remote_op->set_id(ctx->RemoteMgr()->NextOpId());
|
remote_op->set_id(ctx->RemoteMgr()->NextOpId());
|
||||||
remote_op->set_name(op->Name());
|
remote_op->set_name(op->Name());
|
||||||
|
|
||||||
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
|
op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs());
|
||||||
remote_op->set_device(op->Device()->name());
|
remote_op->set_device(op->Device()->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
@ -163,6 +164,42 @@ Status RemoveNewDefaultAttrsFromGraphDef(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void StripDefaultAttributes(const OpRegistryInterface& op_registry,
|
||||||
|
protobuf::RepeatedPtrField<NodeDef>* nodes) {
|
||||||
|
for (int i = 0; i < nodes->size(); ++i) {
|
||||||
|
NodeDef* node = nodes->Mutable(i);
|
||||||
|
|
||||||
|
const OpDef* op_def;
|
||||||
|
const OpRegistrationData* op_reg_data = nullptr;
|
||||||
|
Status s = op_registry.LookUp(node->op(), &op_reg_data);
|
||||||
|
if (!s.ok()) {
|
||||||
|
VLOG(1) << "Ignoring encountered unknown operation "
|
||||||
|
<< SummarizeNodeDef(*node)
|
||||||
|
<< " when stripping default attributes. It is likely a function, "
|
||||||
|
"in which case ignoring it is fine";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
op_def = &op_reg_data->op_def;
|
||||||
|
|
||||||
|
for (const OpDef::AttrDef& attr_def : op_def->attr()) {
|
||||||
|
if (attr_def.has_default_value()) {
|
||||||
|
AttrValueMap* attrs = node->mutable_attr();
|
||||||
|
const string& name = attr_def.name();
|
||||||
|
auto iter = attrs->find(name);
|
||||||
|
if (iter != attrs->end()) {
|
||||||
|
const AttrValue& default_value = attr_def.default_value();
|
||||||
|
// The "Fast*" version can return false negatives for very large
|
||||||
|
// AttrValues containing Tensors. There should never be an attribute
|
||||||
|
// whose default value is a tensor larger than 32MB.
|
||||||
|
if (FastAreAttrValuesEqual(iter->second, default_value)) {
|
||||||
|
attrs->erase(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void OpsUsedByGraph(const GraphDef& graph_def,
|
void OpsUsedByGraph(const GraphDef& graph_def,
|
||||||
std::set<string>* ops_used_in_graph) {
|
std::set<string>* ops_used_in_graph) {
|
||||||
// Map function names to definitions.
|
// Map function names to definitions.
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
@ -24,6 +25,7 @@ namespace tensorflow {
|
|||||||
|
|
||||||
// Forward declare proto so that it's symbols can be removed from .so exports
|
// Forward declare proto so that it's symbols can be removed from .so exports
|
||||||
class GraphDef;
|
class GraphDef;
|
||||||
|
class NodeDef;
|
||||||
|
|
||||||
// Produce a human-readable version of a GraphDef that is more concise
|
// Produce a human-readable version of a GraphDef that is more concise
|
||||||
// than a text-format proto.
|
// than a text-format proto.
|
||||||
@ -97,6 +99,18 @@ Status RemoveNewDefaultAttrsFromGraphDef(
|
|||||||
const OpRegistryInterface& producer_op_registry,
|
const OpRegistryInterface& producer_op_registry,
|
||||||
std::set<std::pair<string, string>>* op_attr_removed);
|
std::set<std::pair<string, string>>* op_attr_removed);
|
||||||
|
|
||||||
|
// Goes over the `nodes` and removes attributes that are set to their
|
||||||
|
// default values according to op_registry.
|
||||||
|
// If some node's definition is not found in the `op_registry`, this node is
|
||||||
|
// simply skipped. In most cases, these nodes would be function calls.
|
||||||
|
// If a stricter behavior is desired, one can add FunctionLibraryDefinition
|
||||||
|
// argument to check for functions and their attributes.
|
||||||
|
// This is obvious from signature, but as a warning, if `nodes` contain
|
||||||
|
// nodes calling functions, e.g. PartitionCallOp or FunctionalIf, this
|
||||||
|
// function does not "recurse" into them.
|
||||||
|
void StripDefaultAttributes(const OpRegistryInterface& op_registry,
|
||||||
|
protobuf::RepeatedPtrField<NodeDef>* nodes);
|
||||||
|
|
||||||
// Two functions that collect the ops used by a graph.
|
// Two functions that collect the ops used by a graph.
|
||||||
//
|
//
|
||||||
// This returns the ops used as a set of strings.
|
// This returns the ops used as a set of strings.
|
||||||
|
@ -235,6 +235,40 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
|
|||||||
EXPECT_EQ(expected_removed, op_attr_removed);
|
EXPECT_EQ(expected_removed, op_attr_removed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(StripDefaultAttributesTest, DefaultStripped) {
|
||||||
|
OpList op_list;
|
||||||
|
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
|
||||||
|
op_list.add_op()));
|
||||||
|
OpListOpRegistry registry(&op_list);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
// This adds the default attribute
|
||||||
|
TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", ®istry)
|
||||||
|
.Finalize(graph_def.add_node()));
|
||||||
|
ASSERT_EQ(1, graph_def.node(0).attr_size());
|
||||||
|
ASSERT_EQ(12, graph_def.node(0).attr().at("a").i());
|
||||||
|
|
||||||
|
StripDefaultAttributes(registry, graph_def.mutable_node());
|
||||||
|
ASSERT_EQ(1, graph_def.node_size());
|
||||||
|
ASSERT_EQ(0, graph_def.node(0).attr_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StripDefaultAttributesTest, NonDefaultNotStripped) {
|
||||||
|
OpList op_list;
|
||||||
|
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("OpName1").Attr("a: int = 12"),
|
||||||
|
op_list.add_op()));
|
||||||
|
OpListOpRegistry registry(&op_list);
|
||||||
|
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_ASSERT_OK(NodeDefBuilder("op1", "OpName1", ®istry)
|
||||||
|
.Attr("a", 9)
|
||||||
|
.Finalize(graph_def.add_node()));
|
||||||
|
|
||||||
|
GraphDef expected = graph_def;
|
||||||
|
StripDefaultAttributes(registry, graph_def.mutable_node());
|
||||||
|
TF_EXPECT_GRAPH_EQ(expected, graph_def);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(StrippedOpListForGraphTest, FlatTest) {
|
TEST(StrippedOpListForGraphTest, FlatTest) {
|
||||||
// Make four ops
|
// Make four ops
|
||||||
OpList op_list;
|
OpList op_list;
|
||||||
|
Loading…
Reference in New Issue
Block a user