diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 53f9b70c876..35dee1561f7 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -161,6 +161,9 @@ void AllocateAndParseFlags() { Flag("tf_xla_always_defer_compilation", &ops_flags->tf_xla_always_defer_compilation, ""), + Flag("tf_xla_noresolve_compile_time_constants", + &ops_flags->tf_xla_noresolve_compile_time_constants, + "Do not perform constant folding in XlaCompiler::CompileGraph"), Flag("tf_introduce_floating_point_jitter_to_tensors", setter_for_jitter_tensor_names, "", diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 9307874133c..baed7ad778d 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -91,6 +91,14 @@ struct XlaOpsCommonFlags { // If true, _XlaCompile always refuses to compile the cluster, which means the // XLA clusters always run in the TF executor. Defaults to false. bool tf_xla_always_defer_compilation; + + // If true, sets compile_options.resolve_compile_time_constants to false, + // which stops the bridge from using the HloEvaluator for constant resolution + // in XlaCompiler::CompileGraph. + // + // For some models, constant folding during compile graph experiences a + // non-linear blow up, which overshadows both compilation and execution. + bool tf_xla_noresolve_compile_time_constants; }; // Flags for the build_xla_ops pass. diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 0e8bce34fe3..edb19bc4750 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -326,11 +326,8 @@ static Status CompileToLocalExecutable( } XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; - // If we resolve constants we never emit them on the device, meaning that if - // they are needed by a following computation the host has to transfer - // them. Not resolving constants is expected to be faster than resolving - // constants. - compile_options.resolve_compile_time_constants = true; + compile_options.resolve_compile_time_constants = + !GetXlaOpsCommonFlags().tf_xla_noresolve_compile_time_constants; // Optimization: where possible, have the computation return a naked array // rather than a one-element tuple. compile_options.always_return_tuple = false;