[TF/XLA] Rollback of rollback of 313256383, with a UB fix.

PiperOrigin-RevId: 313319715
Change-Id: I4b73f95a228b3e6e4fed524492c9389a19629f02
This commit is contained in:
George Karpenkov 2020-05-26 20:42:09 -07:00 committed by TensorFlower Gardener
parent fa0a9c876a
commit 0dda89c61e
5 changed files with 45 additions and 7 deletions

View File

@@ -350,6 +350,7 @@ cc_library(
":sharding_util",
":side_effect_util",
":tf2xla_util",
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:shape_inference",
"//tensorflow/compiler/jit:xla_cluster_util",

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
@@ -571,6 +572,10 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
CopyGraph(*fbody->graph, graph.get());
bool is_inside_mustcompile = false;
TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
&is_inside_mustcompile);
// Performs a first function inlining pass before shape inference, since
// otherwise shape inference can't see inside functions and a comprehensive
// shape_map, including function ops, is needed to constant-propagate Shape
@@ -622,6 +627,8 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
graph_optimizer_options.inline_multi_device_functions = true;
graph_optimizer_options.inline_impl_selection_group_functions = true;
graph_optimizer_options.inline_with_single_device_body_placer = true;
graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
/*device=*/nullptr, &graph, graph_optimizer_options);

View File

@@ -42,7 +42,7 @@ void GraphOptimizer::Optimize(
const NodePredicate& cse_consider_fn, const NodePredicate& cf_consider_fn,
bool inline_multi_device_functions,
bool inline_impl_selection_group_functions,
bool inline_with_single_device_body_placer) {
bool inline_with_single_device_body_placer, bool ignore_noinline) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -116,6 +116,11 @@ void GraphOptimizer::Optimize(
.inline_impl_selection_group_functions = true;
}
if (ignore_noinline) {
expand_inline_opts.multi_device_options.ignore_noinline = true;
expand_inline_opts.native_options.ignore_noinline = true;
}
bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts);
if (was_mutated) {
DumpGraph("ExpandInlineFunctions", g);
@@ -138,11 +143,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
const Device* device,
std::unique_ptr<Graph>* graph,
const Options& options) {
Optimize(runtime, env, device, graph, options.shape_map,
options.cse_consider_fn, options.cf_consider_fn,
options.inline_multi_device_functions,
options.inline_impl_selection_group_functions,
options.inline_with_single_device_body_placer);
Optimize(
runtime, env, device, graph, options.shape_map, options.cse_consider_fn,
options.cf_consider_fn, options.inline_multi_device_functions,
options.inline_impl_selection_group_functions,
options.inline_with_single_device_body_placer, options.ignore_noinline);
}
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,

View File

@@ -58,6 +58,9 @@ class GraphOptimizer {
// If true all functions will be inlined with a single device function
// body placer strategy.
bool inline_with_single_device_body_placer = false;
// If true, the _noinline attribute on functions and callers is ignored.
bool ignore_noinline = false;
};
explicit GraphOptimizer(const OptimizerOptions& opts);
@@ -81,7 +84,8 @@ class GraphOptimizer {
const NodePredicate& cf_consider_fn = nullptr,
bool inline_multi_device_functions = false,
bool inline_impl_selection_group_functions = false,
bool inline_with_single_device_body_placer = false);
bool inline_with_single_device_body_placer = false,
bool ignore_noinline = false);
const OptimizerOptions& options() { return opts_; }

View File

@@ -355,6 +355,27 @@ class DefFunctionTest(test.TestCase):
self.assertAllClose([5.0, 5.0, 5.0], g())
self.assertAllClose(compiled_g(), g())
def testTensorListConcatGradNestedCompile(self):
@def_function.function(experimental_compile=True)
def f(x):
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=2, element_shape=[3])
ta = ta.write(0, 2 * x)
ta = ta.write(1, 3 * x)
return ta.concat()
@def_function.function(experimental_compile=True)
def g():
x = constant_op.constant([3.14, 2.68, 7.69])
with backprop.GradientTape() as tape:
tape.watch(x)
y = f(x)
out = tape.gradient(y, x)
return out
self.assertAllClose([5.0, 5.0, 5.0], g())
def testCumsum(self):
@def_function.function(experimental_compile=True)