Update CompileGraphToXlaHlo to use llvm::ArrayRef<XlaArgument> instead of llvm::ArrayRef<const XlaArgument> (NFC).

This allows std::vector&lt;XlaArgument&gt; and llvm::SmallVector argument containers to be passed to CompileGraphToXlaHlo under different builds.

PiperOrigin-RevId: 329757301
Change-Id: I1025f3106af21b2672e2157c3f5b80af07ef0d0f
Andy Ly, 2020-09-02 11:51:35 -07:00 (committed by TensorFlower Gardener)
commit 0651d1ac60 (parent bf94fa24d2)
6 changed files with 11 additions and 7 deletions
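
To illustrate why the signature change matters, here is a minimal standalone sketch (not part of this commit): llvm::ArrayRef&lt;T&gt; converts implicitly from std::vector&lt;T&gt; and llvm::SmallVector&lt;T&gt;, while ArrayRef&lt;const T&gt; does not, which is why call sites previously spelled out {args.data(), args.size()} by hand. The `Argument` struct below is a hypothetical stand-in for XlaCompiler::Argument (aliased as XlaArgument).

// Standalone sketch, not part of this commit. `Argument` is a hypothetical
// stand-in for tensorflow::XlaCompiler::Argument (aliased as XlaArgument).
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

struct Argument {};

// New-style parameter: llvm::ArrayRef<Argument>.
void TakesArgs(llvm::ArrayRef<Argument> args) { (void)args; }

int main() {
  std::vector<Argument> vec(3);
  llvm::SmallVector<Argument, 4> small(3);

  // Both containers convert implicitly to llvm::ArrayRef<Argument>.
  TakesArgs(vec);
  TakesArgs(small);

  // With the old llvm::ArrayRef<const Argument> parameter these implicit
  // conversions do not apply (ArrayRef's container constructors would expect
  // std::vector<const Argument> / SmallVector<const Argument>), which is why
  // call sites passed the pointer-and-length form {args.data(), args.size()}.
  return 0;
}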

@@ -332,6 +332,7 @@ cc_library(
         ":flags",
         ":xla_activity_listener",
         ":xla_activity_proto_cc",
+        "//tensorflow/compiler/mlir:array_container_utils",
         "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_activity.pb.h"
 #include "tensorflow/compiler/jit/xla_activity_listener.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -293,9 +294,9 @@ Status XlaCompilationCache::CompileSingleOp(
     GraphDebugInfo debug_info;
     return CompileGraphToXlaHlo(
-        *graph, {args.data(), args.size()}, options.device_type.type_string(),
-        compile_options.use_tuple_arg, *options.flib_def, debug_info,
-        options.shape_representation_fn, result);
+        *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
+        options.device_type.type_string(), compile_options.use_tuple_arg,
+        *options.flib_def, debug_info, options.shape_representation_fn, result);
   };
   return CompileImpl(options, name, args, compile_op,
                      /*compile_threshold=*/absl::nullopt,
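
CompileSingleOp receives its arguments as an absl::Span (hence the old {args.data(), args.size()} spelling), while CompileGraphToXlaHlo now takes llvm::ArrayRef, so the call goes through mlir::SpanToArrayRef from the newly added array_container_utils dependency. A rough sketch of the idea behind such an adapter, assuming it is essentially a pointer/length handoff; the real declaration in tensorflow/compiler/mlir/utils/array_container_utils.h may differ:

// Sketch only: assumes the helper is a zero-copy pointer/length handoff.
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"

template <typename T>
llvm::ArrayRef<T> SpanToArrayRef(absl::Span<const T> span) {
  // absl::Span and llvm::ArrayRef are both non-owning {pointer, size} views,
  // so no copy is made; only the view type changes.
  return llvm::ArrayRef<T>(span.data(), span.size());
}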

@@ -434,7 +434,7 @@ Status CompileSerializedMlirToXlaHlo(
 // removed from the signature. For resource args, their subtypes are populated.
 // Returns the original indices for the other arguments on success.
 static StatusOr<std::vector<int>> RewriteWithArgs(
-    mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) {
+    mlir::ModuleOp module, llvm::ArrayRef<XlaArgument> args) {
   mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
   std::vector<int> params;
@@ -495,7 +495,7 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
 }
 Status CompileGraphToXlaHlo(
-    const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
+    const Graph& graph, llvm::ArrayRef<XlaArgument> args,
     llvm::StringRef device_type, bool use_tuple_args,
     const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,

@@ -75,7 +75,7 @@ Status CompileSerializedMlirToXlaHlo(
 // Same as the above but takes input as TensorFlow Graph.
 // TODO(lyandy): Allow populating of targets/control outputs.
 Status CompileGraphToXlaHlo(
-    const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
+    const Graph& graph, llvm::ArrayRef<XlaArgument> args,
     llvm::StringRef device_type, bool use_tuple_args,
     const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
     const XlaHelpers::ShapeRepresentationFn shape_representation_fn,

@@ -352,6 +352,7 @@ cc_library(
         "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
+        "//tensorflow/compiler/mlir:array_container_utils",
         "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
         "//tensorflow/compiler/xla:protobuf_util",
         "//tensorflow/compiler/xla:shape_util",

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
 #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -733,7 +734,7 @@ Status XlaCompiler::CompileFunction(
     VLOG(1) << "Using MLIR bridge";
     GraphDebugInfo debug_info;
     TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
-        std::move(*graph), {args.data(), args.size()},
+        std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
        options_.device_type.type_string(), options.use_tuple_arg,
        *options_.flib_def, debug_info, options_.shape_representation_fn,
        result));