Restrict propagation of internal attributes in functionalization

PiperOrigin-RevId: 319288137
Change-Id: Ide5d7f62674f5c8ebce23ca654d5833aa172fe2f
This commit is contained in:
Michael Gester 2020-07-01 13:38:54 -07:00 committed by TensorFlower Gardener
parent 28766652e6
commit 9cd5f20326
5 changed files with 22 additions and 12 deletions

View File

@ -655,6 +655,7 @@ cc_library(
"functionalize_cond.h",
],
deps = [
":frontend_attributes_util",
":functionalize_control_flow_util",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
@ -811,12 +812,14 @@ Status Conditional::BuildIfNode(Graph* graph,
<< PartialTensorShapeUtils::PartialShapeListString(output_shapes);
builder.Attr("Tcond", DT_BOOL);
// Add all underscore attributes, these need to be propagated.
for (const auto& attr : predicate_.node->def().attr()) {
const string& name(attr.first);
const AttrValue& value(attr.second);
if (absl::StartsWith(name, "_")) {
builder.Attr(name, value);
// Add some internal attributes which need to be propagated.
// TODO(b/160275126): attributes shouldn't be hard-coded here
for (const char* attr_name :
{kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
kTpuReplicateAttrName}) {
string attr_val;
if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) {
builder.Attr(attr_name, attr_val);
}
}
builder.Device(predicate_.node->assigned_device_name());

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -435,12 +436,14 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
// Add all underscore attributes, these need to be propagated.
for (const auto& attr : frame->loop_cond->def().attr()) {
const string& name(attr.first);
const AttrValue& value(attr.second);
if (absl::StartsWith(name, "_")) {
builder.Attr(name, value);
// Add some internal attributes which need to be propagated.
// TODO(b/160275126): attributes shouldn't be hard-coded here
for (const char* attr_name :
{kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
kTpuReplicateAttrName}) {
string attr_val;
if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
builder.Attr(attr_name, attr_val);
}
}
std::vector<NodeDefBuilder::NodeOut> inputs;

View File

@ -302,6 +302,7 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
} // namespace
const char kTpuReplicateAttrName[] = "_tpu_replicate";
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
Status ValidateConfig(const tf2xla::Config& config) {

View File

@ -212,6 +212,8 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g,
Status RewriteTensorListWithConstElement(Graph* g,
FunctionLibraryDefinition* fld);
extern const char kTpuReplicateAttrName[];
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_