[TF2XLA] [NFC] Make a function only used in a single library fileprivate
PiperOrigin-RevId: 332588570 Change-Id: Id65323ba198b1caed6b73b43fe047fe2ad659a6d
This commit is contained in:
		
							parent
							
								
									42afb8a459
								
							
						
					
					
						commit
						877527646a
					
				@ -1708,40 +1708,6 @@ std::atomic<int64>* GetPointerToFuel(int64 initial_value) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
}  // anonymous namespace
 | 
					}  // anonymous namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
 | 
					 | 
				
			||||||
                  RecursiveCompilabilityChecker::UncompilableNodesMap*
 | 
					 | 
				
			||||||
                      uncompilable_node_info) {
 | 
					 | 
				
			||||||
  Device* device = flr->device();
 | 
					 | 
				
			||||||
  const XlaOpRegistry::DeviceRegistration* registration;
 | 
					 | 
				
			||||||
  CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
 | 
					 | 
				
			||||||
                                            ®istration));
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // We can always *compile* resource operations, stateful RNGs and dummy ops,
 | 
					 | 
				
			||||||
  // even if we are sometimes unable to auto-cluster them.
 | 
					 | 
				
			||||||
  RecursiveCompilabilityChecker::OperationFilter op_filter;
 | 
					 | 
				
			||||||
  op_filter.allow_resource_ops_in_called_functions = true;
 | 
					 | 
				
			||||||
  op_filter.allow_stack_ops = true;
 | 
					 | 
				
			||||||
  op_filter.allow_tensor_array_ops = true;
 | 
					 | 
				
			||||||
  op_filter.allow_stateful_rng_ops = true;
 | 
					 | 
				
			||||||
  op_filter.allow_control_trigger = true;
 | 
					 | 
				
			||||||
  op_filter.allow_eliding_assert_and_checknumerics_ops = true;
 | 
					 | 
				
			||||||
  op_filter.allow_ops_producing_or_consuming_variant = true;
 | 
					 | 
				
			||||||
  op_filter.allow_slow_ops = true;
 | 
					 | 
				
			||||||
  op_filter.allow_inaccurate_ops = true;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  RecursiveCompilabilityChecker checker{
 | 
					 | 
				
			||||||
      op_filter, DeviceType{registration->compilation_device_name}};
 | 
					 | 
				
			||||||
  if (!uncompilable_node_info) {
 | 
					 | 
				
			||||||
    // We do not need uncompilable node info. Just return the result.
 | 
					 | 
				
			||||||
    return checker.IsCompilableCall(ndef, flr);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
 | 
					 | 
				
			||||||
      checker.FindUncompilableNodes(ndef, flr);
 | 
					 | 
				
			||||||
  uncompilable_node_info->swap(uncompilable_node_result);
 | 
					 | 
				
			||||||
  return uncompilable_node_info->empty();
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
Status MarkForCompilationPass::Run(
 | 
					Status MarkForCompilationPass::Run(
 | 
				
			||||||
    const GraphOptimizationPassOptions& options) {
 | 
					    const GraphOptimizationPassOptions& options) {
 | 
				
			||||||
  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
 | 
					  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
 | 
				
			||||||
 | 
				
			|||||||
@ -50,14 +50,6 @@ class MarkForCompilationPass : public GraphOptimizationPass {
 | 
				
			|||||||
  friend class MarkForCompilationPassTestHelper;
 | 
					  friend class MarkForCompilationPassTestHelper;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Returns true iff 'ndef' is a call to a function that is compilable.  A
 | 
					 | 
				
			||||||
// function is compilable iff every operator in the function body is
 | 
					 | 
				
			||||||
// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
 | 
					 | 
				
			||||||
// null, we will populate 'uncompilable_node_info' with uncompilable node info.
 | 
					 | 
				
			||||||
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
 | 
					 | 
				
			||||||
                  RecursiveCompilabilityChecker::UncompilableNodesMap*
 | 
					 | 
				
			||||||
                      uncompilable_node_info = nullptr);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
 | 
					absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace testing {
 | 
					namespace testing {
 | 
				
			||||||
 | 
				
			|||||||
@ -21,7 +21,6 @@ limitations under the License.
 | 
				
			|||||||
#include "tensorflow/compiler/jit/defs.h"
 | 
					#include "tensorflow/compiler/jit/defs.h"
 | 
				
			||||||
#include "tensorflow/compiler/jit/flags.h"
 | 
					#include "tensorflow/compiler/jit/flags.h"
 | 
				
			||||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
 | 
					#include "tensorflow/compiler/jit/kernels/xla_ops.h"
 | 
				
			||||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 | 
					 | 
				
			||||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
 | 
					#include "tensorflow/compiler/tf2xla/const_analysis.h"
 | 
				
			||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 | 
					#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 | 
				
			||||||
#include "tensorflow/core/common_runtime/function.h"
 | 
					#include "tensorflow/core/common_runtime/function.h"
 | 
				
			||||||
@ -32,6 +31,44 @@ limitations under the License.
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
namespace tensorflow {
 | 
					namespace tensorflow {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Returns true iff 'ndef' is a call to a function that is compilable.  A
 | 
				
			||||||
 | 
					// function is compilable iff every operator in the function body is
 | 
				
			||||||
 | 
					// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not
 | 
				
			||||||
 | 
					// null, we will populate 'uncompilable_node_info' with uncompilable node info.
 | 
				
			||||||
 | 
					static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
 | 
				
			||||||
 | 
					                         RecursiveCompilabilityChecker::UncompilableNodesMap*
 | 
				
			||||||
 | 
					                             uncompilable_node_info) {
 | 
				
			||||||
 | 
					  Device* device = flr->device();
 | 
				
			||||||
 | 
					  const XlaOpRegistry::DeviceRegistration* registration;
 | 
				
			||||||
 | 
					  CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
 | 
				
			||||||
 | 
					                                            ®istration));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // We can always *compile* resource operations, stateful RNGs and dummy ops,
 | 
				
			||||||
 | 
					  // even if we are sometimes unable to auto-cluster them.
 | 
				
			||||||
 | 
					  RecursiveCompilabilityChecker::OperationFilter op_filter;
 | 
				
			||||||
 | 
					  op_filter.allow_resource_ops_in_called_functions = true;
 | 
				
			||||||
 | 
					  op_filter.allow_stack_ops = true;
 | 
				
			||||||
 | 
					  op_filter.allow_tensor_array_ops = true;
 | 
				
			||||||
 | 
					  op_filter.allow_stateful_rng_ops = true;
 | 
				
			||||||
 | 
					  op_filter.allow_control_trigger = true;
 | 
				
			||||||
 | 
					  op_filter.allow_eliding_assert_and_checknumerics_ops = true;
 | 
				
			||||||
 | 
					  op_filter.allow_ops_producing_or_consuming_variant = true;
 | 
				
			||||||
 | 
					  op_filter.allow_slow_ops = true;
 | 
				
			||||||
 | 
					  op_filter.allow_inaccurate_ops = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  RecursiveCompilabilityChecker checker{
 | 
				
			||||||
 | 
					      op_filter, DeviceType{registration->compilation_device_name}};
 | 
				
			||||||
 | 
					  if (!uncompilable_node_info) {
 | 
				
			||||||
 | 
					    // We do not need uncompilable node info. Just return the result.
 | 
				
			||||||
 | 
					    return checker.IsCompilableCall(ndef, flr);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result =
 | 
				
			||||||
 | 
					      checker.FindUncompilableNodes(ndef, flr);
 | 
				
			||||||
 | 
					  uncompilable_node_info->swap(uncompilable_node_result);
 | 
				
			||||||
 | 
					  return uncompilable_node_info->empty();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bool XlaKernelCreator::CanCreateKernel(
 | 
					bool XlaKernelCreator::CanCreateKernel(
 | 
				
			||||||
    const FunctionLibraryRuntime& flr,
 | 
					    const FunctionLibraryRuntime& flr,
 | 
				
			||||||
    const std::shared_ptr<const NodeProperties>& props) const {
 | 
					    const std::shared_ptr<const NodeProperties>& props) const {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user