Restrict propagation of internal attributes in functionalization
PiperOrigin-RevId: 319288137 Change-Id: Ide5d7f62674f5c8ebce23ca654d5833aa172fe2f
This commit is contained in:
parent
28766652e6
commit
9cd5f20326
tensorflow/compiler/tf2xla
@ -655,6 +655,7 @@ cc_library(
|
||||
"functionalize_cond.h",
|
||||
],
|
||||
deps = [
|
||||
":frontend_attributes_util",
|
||||
":functionalize_control_flow_util",
|
||||
":tf2xla_util",
|
||||
"//tensorflow/compiler/jit:union_find",
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
@ -811,12 +812,14 @@ Status Conditional::BuildIfNode(Graph* graph,
|
||||
<< PartialTensorShapeUtils::PartialShapeListString(output_shapes);
|
||||
|
||||
builder.Attr("Tcond", DT_BOOL);
|
||||
// Add all underscore attributes, these need to be propagated.
|
||||
for (const auto& attr : predicate_.node->def().attr()) {
|
||||
const string& name(attr.first);
|
||||
const AttrValue& value(attr.second);
|
||||
if (absl::StartsWith(name, "_")) {
|
||||
builder.Attr(name, value);
|
||||
// Add some internal attributes which need to be propagated.
|
||||
// TODO(b/160275126): attributes shouldn't be hard-coded here
|
||||
for (const char* attr_name :
|
||||
{kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
|
||||
kTpuReplicateAttrName}) {
|
||||
string attr_val;
|
||||
if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) {
|
||||
builder.Attr(attr_name, attr_val);
|
||||
}
|
||||
}
|
||||
builder.Device(predicate_.node->assigned_device_name());
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -435,12 +436,14 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
|
||||
builder.Attr("T", arg_types);
|
||||
builder.Attr("cond", cond_name);
|
||||
builder.Attr("body", body_name);
|
||||
// Add all underscore attributes, these need to be propagated.
|
||||
for (const auto& attr : frame->loop_cond->def().attr()) {
|
||||
const string& name(attr.first);
|
||||
const AttrValue& value(attr.second);
|
||||
if (absl::StartsWith(name, "_")) {
|
||||
builder.Attr(name, value);
|
||||
// Add some internal attributes which need to be propagated.
|
||||
// TODO(b/160275126): attributes shouldn't be hard-coded here
|
||||
for (const char* attr_name :
|
||||
{kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
|
||||
kTpuReplicateAttrName}) {
|
||||
string attr_val;
|
||||
if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
|
||||
builder.Attr(attr_name, attr_val);
|
||||
}
|
||||
}
|
||||
std::vector<NodeDefBuilder::NodeOut> inputs;
|
||||
|
@ -302,6 +302,7 @@ Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
|
||||
|
||||
} // namespace
|
||||
|
||||
const char kTpuReplicateAttrName[] = "_tpu_replicate";
|
||||
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
|
||||
|
||||
Status ValidateConfig(const tf2xla::Config& config) {
|
||||
|
@ -212,6 +212,8 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g,
|
||||
Status RewriteTensorListWithConstElement(Graph* g,
|
||||
FunctionLibraryDefinition* fld);
|
||||
|
||||
extern const char kTpuReplicateAttrName[];
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
|
||||
|
Loading…
Reference in New Issue
Block a user