Update CompileGraphToXlaHlo to use llvm::ArrayRef<XlaArgument> instead of llvm::ArrayRef<const XlaArgument> (NFC).
This will allow for std::vector<XlaArgument> and llvm::SmallVector arg parameters in CompileGraphToXlaHlo to be used under different builds. PiperOrigin-RevId: 329757301 Change-Id: I1025f3106af21b2672e2157c3f5b80af07ef0d0f
This commit is contained in:
parent
bf94fa24d2
commit
0651d1ac60
tensorflow/compiler
jit
mlir/tensorflow/utils
tf2xla
@ -332,6 +332,7 @@ cc_library(
|
||||
":flags",
|
||||
":xla_activity_listener",
|
||||
":xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/mlir:array_container_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
@ -293,9 +294,9 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
return CompileGraphToXlaHlo(
|
||||
*graph, {args.data(), args.size()}, options.device_type.type_string(),
|
||||
compile_options.use_tuple_arg, *options.flib_def, debug_info,
|
||||
options.shape_representation_fn, result);
|
||||
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||
};
|
||||
return CompileImpl(options, name, args, compile_op,
|
||||
/*compile_threshold=*/absl::nullopt,
|
||||
|
@ -434,7 +434,7 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
// removed from the signature. For resource args, their subtypes are populated.
|
||||
// Returns the original indices for the other arguments on success.
|
||||
static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||
mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) {
|
||||
mlir::ModuleOp module, llvm::ArrayRef<XlaArgument> args) {
|
||||
mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
|
||||
std::vector<int> params;
|
||||
|
||||
@ -495,7 +495,7 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||
}
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
|
||||
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
|
@ -75,7 +75,7 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
// Same as the above but takes input as TensorFlow Graph.
|
||||
// TODO(lyandy): Allow populating of targets/control outputs.
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
|
||||
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
|
@ -352,6 +352,7 @@ cc_library(
|
||||
"//tensorflow/compiler/jit:common",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:shape_inference",
|
||||
"//tensorflow/compiler/mlir:array_container_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
@ -733,7 +734,7 @@ Status XlaCompiler::CompileFunction(
|
||||
VLOG(1) << "Using MLIR bridge";
|
||||
GraphDebugInfo debug_info;
|
||||
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
|
||||
std::move(*graph), {args.data(), args.size()},
|
||||
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||
options_.device_type.type_string(), options.use_tuple_arg,
|
||||
*options_.flib_def, debug_info, options_.shape_representation_fn,
|
||||
result));
|
||||
|
Loading…
Reference in New Issue
Block a user