diff --git a/tensorflow/core/ops/tpu_functional_ops.cc b/tensorflow/core/ops/tpu_functional_ops.cc index aa81e8b24b5..789b4398689 100644 --- a/tensorflow/core/ops/tpu_functional_ops.cc +++ b/tensorflow/core/ops/tpu_functional_ops.cc @@ -26,6 +26,7 @@ REGISTER_OP("TPUPartitionedCall") .Attr("Tin: list(type) >= 0") .Attr("Tout: list(type) >= 0") .Attr("f: func") + .Attr("autotuner_thresh: int = 0") .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 41b260ffb21..bce52c64434 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -560,6 +560,13 @@ message ConfigProto { // If this option is set to true when a session is created, the // `RunOptions.output_partition_graphs` options must not be set. bool disable_output_partition_graphs = 14; + + // Minimum number of batches run through the XLA graph before XLA fusion + // autotuner is enabled. Default value of zero disables the autotuner. + // + // The XLA fusion autotuner can improve performance by executing a heuristic + // search on the compiler parameters. + int64 xla_fusion_autotuner_thresh = 15; }; Experimental experimental = 16; diff --git a/tensorflow/core/protobuf/tpu/compile_metadata.proto b/tensorflow/core/protobuf/tpu/compile_metadata.proto index 2a8607d3821..47304cb2039 100644 --- a/tensorflow/core/protobuf/tpu/compile_metadata.proto +++ b/tensorflow/core/protobuf/tpu/compile_metadata.proto @@ -86,4 +86,10 @@ message TPUCompileMetadataProto { // The location of step markers that XLA compile will instrument. xla.DebugOptions.StepMarkerLocation step_marker_location = 12; + + // Minimum number of batches run through the XLA graph before XLA fusion + // autotuner is enabled. Default value of zero disables the autotuner. + // The XLA fusion autotuner can improve performance by executing a heuristic + // search on the compiler parameters. + int64 xla_fusion_autotuner_thresh = 13; } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt index b34809b568a..cde90e76f5d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt @@ -81,6 +81,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "xla_fusion_autotuner_thresh" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } reserved_range { start: 2 end: 3 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt index db4ba6a54d4..2802a584421 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt @@ -204,6 +204,12 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "xla_fusion_autotuner_thresh" + number: 15 + label: LABEL_OPTIONAL + type: TYPE_INT64 + } reserved_range { start: 2 end: 3 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index f7f3565e180..192bf689d42 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -4282,7 +4282,7 @@ tf_module { } member_method { name: "TPUPartitionedCall" - argspec: "args=[\'args\', \'device_ordinal\', \'Tout\', \'f\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'args\', \'device_ordinal\', \'Tout\', \'f\', \'autotuner_thresh\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " } member_method { name: "TPUReplicateMetadata" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index f7f3565e180..192bf689d42 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -4282,7 +4282,7 @@ tf_module { } member_method { name: "TPUPartitionedCall" - argspec: "args=[\'args\', \'device_ordinal\', \'Tout\', \'f\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'args\', \'device_ordinal\', \'Tout\', \'f\', \'autotuner_thresh\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " } member_method { name: "TPUReplicateMetadata"