Restrict propagation of internal attributes in functionalization

PiperOrigin-RevId: 319288137
Change-Id: Ide5d7f62674f5c8ebce23ca654d5833aa172fe2f
This commit is contained in:
Michael Gester 2020-07-01 13:38:54 -07:00 committed by TensorFlower Gardener
parent 28766652e6
commit 9cd5f20326
5 changed files with 22 additions and 12 deletions

View File

@ -655,6 +655,7 @@ cc_library(
"functionalize_cond.h", "functionalize_cond.h",
], ],
deps = [ deps = [
":frontend_attributes_util",
":functionalize_control_flow_util", ":functionalize_control_flow_util",
":tf2xla_util", ":tf2xla_util",
"//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/jit:union_find",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/common_runtime/shape_refiner.h"
@ -811,12 +812,14 @@ Status Conditional::BuildIfNode(Graph* graph,
<< PartialTensorShapeUtils::PartialShapeListString(output_shapes); << PartialTensorShapeUtils::PartialShapeListString(output_shapes);
builder.Attr("Tcond", DT_BOOL); builder.Attr("Tcond", DT_BOOL);
// Add all underscore attributes, these need to be propagated. // Add some internal attributes which need to be propagated.
for (const auto& attr : predicate_.node->def().attr()) { // TODO(b/160275126): attributes shouldn't be hard-coded here
const string& name(attr.first); for (const char* attr_name :
const AttrValue& value(attr.second); {kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
if (absl::StartsWith(name, "_")) { kTpuReplicateAttrName}) {
builder.Attr(name, value); string attr_val;
if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) {
builder.Attr(attr_name, attr_val);
} }
} }
builder.Device(predicate_.node->assigned_device_name()); builder.Device(predicate_.node->assigned_device_name());

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
@ -435,12 +436,14 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
builder.Attr("T", arg_types); builder.Attr("T", arg_types);
builder.Attr("cond", cond_name); builder.Attr("cond", cond_name);
builder.Attr("body", body_name); builder.Attr("body", body_name);
// Add all underscore attributes, these need to be propagated. // Add some internal attributes which need to be propagated.
for (const auto& attr : frame->loop_cond->def().attr()) { // TODO(b/160275126): attributes shouldn't be hard-coded here
const string& name(attr.first); for (const char* attr_name :
const AttrValue& value(attr.second); {kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
if (absl::StartsWith(name, "_")) { kTpuReplicateAttrName}) {
builder.Attr(name, value); string attr_val;
if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
builder.Attr(attr_name, attr_val);
} }
} }
std::vector<NodeDefBuilder::NodeOut> inputs; std::vector<NodeDefBuilder::NodeOut> inputs;

View File

@ -302,6 +302,7 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
} // namespace } // namespace
const char kTpuReplicateAttrName[] = "_tpu_replicate";
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
Status ValidateConfig(const tf2xla::Config& config) { Status ValidateConfig(const tf2xla::Config& config) {

View File

@ -212,6 +212,8 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g,
Status RewriteTensorListWithConstElement(Graph* g, Status RewriteTensorListWithConstElement(Graph* g,
FunctionLibraryDefinition* fld); FunctionLibraryDefinition* fld);
extern const char kTpuReplicateAttrName[];
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_