diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5992b45e209..748dce4fab3 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -332,6 +332,7 @@ cc_library( ":flags", ":xla_activity_listener", ":xla_activity_proto_cc", + "//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index b5bb2fab0ed..fb184d62e27 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -293,9 +294,9 @@ Status XlaCompilationCache::CompileSingleOp( GraphDebugInfo debug_info; return CompileGraphToXlaHlo( - *graph, {args.data(), args.size()}, options.device_type.type_string(), - compile_options.use_tuple_arg, *options.flib_def, debug_info, - options.shape_representation_fn, result); + *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), + options.device_type.type_string(), compile_options.use_tuple_arg, + *options.flib_def, debug_info, options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 0dbda2e4f9c..6ddd0856112 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -434,7 +434,7 @@ Status CompileSerializedMlirToXlaHlo( // removed from the signature. For resource args, their subtypes are populated. // Returns the original indices for the other arguments on success. static StatusOr<std::vector<int>> RewriteWithArgs( - mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) { + mlir::ModuleOp module, llvm::ArrayRef<XlaArgument> args) { mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main"); std::vector<int> params; @@ -495,7 +495,7 @@ static StatusOr<std::vector<int>> RewriteWithArgs( } Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef<const XlaArgument> args, + const Graph& graph, llvm::ArrayRef<XlaArgument> args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 5c64a65ecbd..cba646d40a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -75,7 +75,7 @@ Status CompileSerializedMlirToXlaHlo( // Same as the above but takes input as TensorFlow Graph. // TODO(lyandy): Allow populating of targets/control outputs. Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef<const XlaArgument> args, + const Graph& graph, llvm::ArrayRef<XlaArgument> args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index e9bcbcc6d83..50ebc035404 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -352,6 +352,7 @@ cc_library( "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", + "//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 4d8f6f96811..b22dc05eaa1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -733,7 +734,7 @@ Status XlaCompiler::CompileFunction( VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( - std::move(*graph), {args.data(), args.size()}, + std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args), options_.device_type.type_string(), options.use_tuple_arg, *options_.flib_def, debug_info, options_.shape_representation_fn, result));