[TF/XLA] Ignore _noinline inside force-compiled clusters

The code paths that handle _noinline functions are rarely exercised and,
as a result, are not well tested. For now, the better approach is to stay
on a more well-lit code path and minimize the use of _noinline functions.
As a starting point, inline functions even when they are marked _noinline,
as long as they sit inside force-compiled blocks.

PiperOrigin-RevId: 313280139
Change-Id: I9f2d9b95d4bfe15eb2acea2a3d101b82355c14d5
commit 53037dcd66 (parent 1de7105aeb)
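For context, the "force-compiled blocks" mentioned above are clusters of functions marked must-compile, which in the TF 2.2-era Python API means `@tf.function(experimental_compile=True)` (later renamed `jit_compile`). A minimal sketch of such a cluster, using only public API; the function names are illustrative and not taken from this diff. `_noinline` is a FunctionDef attribute that normally keeps the function inliner away from a function body:

import tensorflow as tf

@tf.function
def inner(x):
  # If this function's FunctionDef carried the _noinline attribute, the
  # function inliner would normally leave it out-of-line.
  return 2.0 * x

@tf.function(experimental_compile=True)  # the force-compiled cluster
def outer(x):
  # inner() is reached from inside the cluster, so it is compiled with it.
  return inner(x) + 1.0

print(outer(tf.constant([1.0, 2.0])))  # tf.Tensor([3. 5.], ...)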
@@ -350,7 +350,6 @@ cc_library(
         ":sharding_util",
         ":side_effect_util",
         ":tf2xla_util",
-        "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/jit:xla_cluster_util",

@@ -20,7 +20,6 @@ limitations under the License.

 #include "absl/memory/memory.h"
 #include "absl/types/variant.h"
-#include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"

@@ -572,10 +571,6 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
   CopyGraph(*fbody->graph, graph.get());

-  bool is_inside_mustcompile;
-  TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
-                 &is_inside_mustcompile);
-
   // Performs a first function inlining pass before shape inference, since
   // otherwise shape inference can't see inside functions and a comprehensive
   // shape_map, including function ops, is needed to constant-propagate Shape

@@ -627,8 +622,6 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   graph_optimizer_options.inline_multi_device_functions = true;
   graph_optimizer_options.inline_impl_selection_group_functions = true;
   graph_optimizer_options.inline_with_single_device_body_placer = true;
-  graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
-
   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                      /*device=*/nullptr, &graph, graph_optimizer_options);

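For reference, kXlaMustCompileAttr (declared in tensorflow/compiler/jit/defs.h) is the FunctionDef attribute that `experimental_compile=True` stamps onto a function; the `is_inside_mustcompile` lines above read it back with TryGetNodeAttr to decide whether GetGraph is compiling inside a must-compile cluster. A rough Python-side way to observe the attribute; the `"_XlaMustCompile"` spelling and the `ConcreteFunction.function_def` accessor are assumptions for illustration, not taken from this diff:

import tensorflow as tf

@tf.function(experimental_compile=True)
def f(x):
  return x + 1.0

cf = f.get_concrete_function(tf.TensorSpec([3], tf.float32))
# The attribute the C++ code above looked up via kXlaMustCompileAttr:
print(cf.function_def.attr["_XlaMustCompile"])  # b: true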
@@ -42,7 +42,7 @@ void GraphOptimizer::Optimize(
     const NodePredicate& cse_consider_fn, const NodePredicate& cf_consider_fn,
     bool inline_multi_device_functions,
     bool inline_impl_selection_group_functions,
-    bool inline_with_single_device_body_placer, bool ignore_noinline) {
+    bool inline_with_single_device_body_placer) {
   Graph* g = graph->get();
   DumpGraph("Initial", g);

@@ -116,11 +116,6 @@ void GraphOptimizer::Optimize(
         .inline_impl_selection_group_functions = true;
   }

-  if (ignore_noinline) {
-    expand_inline_opts.multi_device_options.ignore_noinline = true;
-    expand_inline_opts.native_options.ignore_noinline = true;
-  }
-
   bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts);
   if (was_mutated) {
     DumpGraph("ExpandInlineFunctions", g);

@@ -143,11 +138,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
                               const Device* device,
                               std::unique_ptr<Graph>* graph,
                               const Options& options) {
-  Optimize(
-      runtime, env, device, graph, options.shape_map, options.cse_consider_fn,
-      options.cf_consider_fn, options.inline_multi_device_functions,
-      options.inline_impl_selection_group_functions,
-      options.inline_with_single_device_body_placer, options.ignore_noinline);
+  Optimize(runtime, env, device, graph, options.shape_map,
+           options.cse_consider_fn, options.cf_consider_fn,
+           options.inline_multi_device_functions,
+           options.inline_impl_selection_group_functions,
+           options.inline_with_single_device_body_placer);
 }

 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,

@@ -58,9 +58,6 @@ class GraphOptimizer {
     // If true all functions will be inlined with a single device function
    // body placer strategy.
    bool inline_with_single_device_body_placer = false;
-
-    // If true, the _noinline attribute on functions and callers is ignored.
-    bool ignore_noinline = false;
   };

   explicit GraphOptimizer(const OptimizerOptions& opts);

@@ -84,8 +81,7 @@ class GraphOptimizer {
                 const NodePredicate& cf_consider_fn = nullptr,
                 bool inline_multi_device_functions = false,
                 bool inline_impl_selection_group_functions = false,
-                bool inline_with_single_device_body_placer = false,
-                bool ignore_noinline = false);
+                bool inline_with_single_device_body_placer = false);

   const OptimizerOptions& options() { return opts_; }

@@ -355,27 +355,6 @@ class DefFunctionTest(test.TestCase):
     self.assertAllClose([5.0, 5.0, 5.0], g())
     self.assertAllClose(compiled_g(), g())

-  def testTensorListConcatGradNestedCompile(self):
-
-    @def_function.function(experimental_compile=True)
-    def f(x):
-      ta = tensor_array_ops.TensorArray(
-          dtype=dtypes.float32, size=2, element_shape=[3])
-      ta = ta.write(0, 2 * x)
-      ta = ta.write(1, 3 * x)
-      return ta.concat()
-
-    @def_function.function(experimental_compile=True)
-    def g():
-      x = constant_op.constant([3.14, 2.68, 7.69])
-      with backprop.GradientTape() as tape:
-        tape.watch(x)
-        y = f(x)
-      out = tape.gradient(y, x)
-      return out
-
-    self.assertAllClose([5.0, 5.0, 5.0], g())
-
   def testCumsum(self):

     @def_function.function(experimental_compile=True)
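The deleted test checks a value that follows from the math: f concatenates 2*x and 3*x, and tape.gradient of a non-scalar target sums it, so each element of x receives d/dx (2x + 3x) = 5. A standalone re-creation using public tf.* names in place of the internal test imports; it is a sketch of the same scenario, not code from this diff:

import tensorflow as tf

@tf.function(experimental_compile=True)
def f(x):
  ta = tf.TensorArray(dtype=tf.float32, size=2, element_shape=[3])
  ta = ta.write(0, 2.0 * x)
  ta = ta.write(1, 3.0 * x)
  return ta.concat()  # shape [6]: first 2*x, then 3*x

x = tf.constant([3.14, 2.68, 7.69])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f(x)
# gradient() sums the non-scalar y, so each x[j] gets 2 + 3 = 5.
print(tape.gradient(y, x))  # tf.Tensor([5. 5. 5.], ...)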