[TF/XLA] Ignore _noinline inside force-compiled clusters
The code path that handles _noinline functions is very rarely hit and, as a result, is not well tested. For now, the better approach is to follow a more well-lit code path and to minimize the use of _noinline functions. As a starting point, inline functions even when they are marked _noinline inside force-compiled blocks.

PiperOrigin-RevId: 313280139
Change-Id: I9f2d9b95d4bfe15eb2acea2a3d101b82355c14d5
commit 53037dcd66
parent 1de7105aeb
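To illustrate the scenario the message describes, here is a minimal Python sketch. It is an illustration only, not code from this commit, and assumes a TF 2.x release where tf.function still accepts experimental_compile. It mirrors the DefFunctionTest case in the diff below: both functions are force-compiled, and the gradient functions TensorFlow generates for the TensorArray (TensorList) ops in f are the kind of _noinline callees this change concerns.

# Minimal sketch (illustration only; assumes tf.function(experimental_compile=True)).
import tensorflow as tf


@tf.function(experimental_compile=True)
def f(x):
  # Write 2*x and 3*x into a TensorArray and concatenate them.
  ta = tf.TensorArray(dtype=tf.float32, size=2, element_shape=[3])
  ta = ta.write(0, 2 * x)
  ta = ta.write(1, 3 * x)
  return ta.concat()


@tf.function(experimental_compile=True)
def g():
  x = tf.constant([3.14, 2.68, 7.69])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = f(x)
  # d/dx sum(concat([2*x, 3*x])) == 5 for every element of x.
  return tape.gradient(y, x)


print(g())  # Expected: [5.0, 5.0, 5.0]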
@@ -350,7 +350,6 @@ cc_library(
         ":sharding_util",
         ":side_effect_util",
         ":tf2xla_util",
-        "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/jit:xla_cluster_util",
@@ -20,7 +20,6 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "absl/types/variant.h"
-#include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
@@ -572,10 +571,6 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
   CopyGraph(*fbody->graph, graph.get());
 
-  bool is_inside_mustcompile;
-  TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
-                 &is_inside_mustcompile);
-
   // Performs a first function inlining pass before shape inference, since
   // otherwise shape inference can't see inside functions and a comprehensive
   // shape_map, including function ops, is needed to constant-propagate Shape
@@ -627,8 +622,6 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   graph_optimizer_options.inline_multi_device_functions = true;
   graph_optimizer_options.inline_impl_selection_group_functions = true;
   graph_optimizer_options.inline_with_single_device_body_placer = true;
-  graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
-
   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                      /*device=*/nullptr, &graph, graph_optimizer_options);
 
@@ -42,7 +42,7 @@ void GraphOptimizer::Optimize(
     const NodePredicate& cse_consider_fn, const NodePredicate& cf_consider_fn,
     bool inline_multi_device_functions,
     bool inline_impl_selection_group_functions,
-    bool inline_with_single_device_body_placer, bool ignore_noinline) {
+    bool inline_with_single_device_body_placer) {
   Graph* g = graph->get();
   DumpGraph("Initial", g);
 
@@ -116,11 +116,6 @@ void GraphOptimizer::Optimize(
         .inline_impl_selection_group_functions = true;
   }
 
-  if (ignore_noinline) {
-    expand_inline_opts.multi_device_options.ignore_noinline = true;
-    expand_inline_opts.native_options.ignore_noinline = true;
-  }
-
   bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts);
   if (was_mutated) {
     DumpGraph("ExpandInlineFunctions", g);
@@ -143,11 +138,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
                               const Device* device,
                               std::unique_ptr<Graph>* graph,
                               const Options& options) {
-  Optimize(
-      runtime, env, device, graph, options.shape_map, options.cse_consider_fn,
-      options.cf_consider_fn, options.inline_multi_device_functions,
-      options.inline_impl_selection_group_functions,
-      options.inline_with_single_device_body_placer, options.ignore_noinline);
+  Optimize(runtime, env, device, graph, options.shape_map,
+           options.cse_consider_fn, options.cf_consider_fn,
+           options.inline_multi_device_functions,
+           options.inline_impl_selection_group_functions,
+           options.inline_with_single_device_body_placer);
 }
 
 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
@@ -58,9 +58,6 @@ class GraphOptimizer {
     // If true all functions will be inlined with a single device function
     // body placer strategy.
    bool inline_with_single_device_body_placer = false;
-
-    // If true, the _noinline attribute on functions and callers is ignored.
-    bool ignore_noinline = false;
   };
 
   explicit GraphOptimizer(const OptimizerOptions& opts);
@@ -84,8 +81,7 @@ class GraphOptimizer {
                 const NodePredicate& cf_consider_fn = nullptr,
                 bool inline_multi_device_functions = false,
                 bool inline_impl_selection_group_functions = false,
-                bool inline_with_single_device_body_placer = false,
-                bool ignore_noinline = false);
+                bool inline_with_single_device_body_placer = false);
 
   const OptimizerOptions& options() { return opts_; }
 
@@ -355,27 +355,6 @@ class DefFunctionTest(test.TestCase):
     self.assertAllClose([5.0, 5.0, 5.0], g())
     self.assertAllClose(compiled_g(), g())
 
-  def testTensorListConcatGradNestedCompile(self):
-
-    @def_function.function(experimental_compile=True)
-    def f(x):
-      ta = tensor_array_ops.TensorArray(
-          dtype=dtypes.float32, size=2, element_shape=[3])
-      ta = ta.write(0, 2 * x)
-      ta = ta.write(1, 3 * x)
-      return ta.concat()
-
-    @def_function.function(experimental_compile=True)
-    def g():
-      x = constant_op.constant([3.14, 2.68, 7.69])
-      with backprop.GradientTape() as tape:
-        tape.watch(x)
-        y = f(x)
-      out = tape.gradient(y, x)
-      return out
-
-    self.assertAllClose([5.0, 5.0, 5.0], g())
-
   def testCumsum(self):
 
     @def_function.function(experimental_compile=True)