Update CompileGraphToXlaHlo to use llvm::ArrayRef<XlaArgument> instead of llvm::ArrayRef<const XlaArgument> (NFC).
This will allow for std::vector<XlaArgument> and llvm::SmallVector arg parameters in CompileGraphToXlaHlo to be used under different builds. PiperOrigin-RevId: 329757301 Change-Id: I1025f3106af21b2672e2157c3f5b80af07ef0d0f
This commit is contained in:
parent
bf94fa24d2
commit
0651d1ac60
@ -332,6 +332,7 @@ cc_library(
|
|||||||
":flags",
|
":flags",
|
||||||
":xla_activity_listener",
|
":xla_activity_listener",
|
||||||
":xla_activity_proto_cc",
|
":xla_activity_proto_cc",
|
||||||
|
"//tensorflow/compiler/mlir:array_container_utils",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||||
"//tensorflow/compiler/tf2xla:common",
|
"//tensorflow/compiler/tf2xla:common",
|
||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
#include "tensorflow/compiler/jit/xla_activity.pb.h"
|
||||||
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
#include "tensorflow/compiler/jit/xla_activity_listener.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||||
|
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
@ -293,9 +294,9 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||||||
|
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
return CompileGraphToXlaHlo(
|
return CompileGraphToXlaHlo(
|
||||||
*graph, {args.data(), args.size()}, options.device_type.type_string(),
|
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||||
compile_options.use_tuple_arg, *options.flib_def, debug_info,
|
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||||
options.shape_representation_fn, result);
|
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||||
};
|
};
|
||||||
return CompileImpl(options, name, args, compile_op,
|
return CompileImpl(options, name, args, compile_op,
|
||||||
/*compile_threshold=*/absl::nullopt,
|
/*compile_threshold=*/absl::nullopt,
|
||||||
|
@ -434,7 +434,7 @@ Status CompileSerializedMlirToXlaHlo(
|
|||||||
// removed from the signature. For resource args, their subtypes are populated.
|
// removed from the signature. For resource args, their subtypes are populated.
|
||||||
// Returns the original indices for the other arguments on success.
|
// Returns the original indices for the other arguments on success.
|
||||||
static StatusOr<std::vector<int>> RewriteWithArgs(
|
static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||||
mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) {
|
mlir::ModuleOp module, llvm::ArrayRef<XlaArgument> args) {
|
||||||
mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
|
mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
|
||||||
std::vector<int> params;
|
std::vector<int> params;
|
||||||
|
|
||||||
@ -495,7 +495,7 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status CompileGraphToXlaHlo(
|
Status CompileGraphToXlaHlo(
|
||||||
const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
|
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||||
llvm::StringRef device_type, bool use_tuple_args,
|
llvm::StringRef device_type, bool use_tuple_args,
|
||||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||||
|
@ -75,7 +75,7 @@ Status CompileSerializedMlirToXlaHlo(
|
|||||||
// Same as the above but takes input as TensorFlow Graph.
|
// Same as the above but takes input as TensorFlow Graph.
|
||||||
// TODO(lyandy): Allow populating of targets/control outputs.
|
// TODO(lyandy): Allow populating of targets/control outputs.
|
||||||
Status CompileGraphToXlaHlo(
|
Status CompileGraphToXlaHlo(
|
||||||
const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
|
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||||
llvm::StringRef device_type, bool use_tuple_args,
|
llvm::StringRef device_type, bool use_tuple_args,
|
||||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||||
|
@ -352,6 +352,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/jit:common",
|
"//tensorflow/compiler/jit:common",
|
||||||
"//tensorflow/compiler/jit:flags",
|
"//tensorflow/compiler/jit:flags",
|
||||||
"//tensorflow/compiler/jit:shape_inference",
|
"//tensorflow/compiler/jit:shape_inference",
|
||||||
|
"//tensorflow/compiler/mlir:array_container_utils",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||||
"//tensorflow/compiler/xla:protobuf_util",
|
"//tensorflow/compiler/xla:protobuf_util",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/flags.h"
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||||
|
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||||
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
||||||
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
#include "tensorflow/compiler/tf2xla/rearrange_function_argument.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
@ -733,7 +734,7 @@ Status XlaCompiler::CompileFunction(
|
|||||||
VLOG(1) << "Using MLIR bridge";
|
VLOG(1) << "Using MLIR bridge";
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
|
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
|
||||||
std::move(*graph), {args.data(), args.size()},
|
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||||
options_.device_type.type_string(), options.use_tuple_arg,
|
options_.device_type.type_string(), options.use_tuple_arg,
|
||||||
*options_.flib_def, debug_info, options_.shape_representation_fn,
|
*options_.flib_def, debug_info, options_.shape_representation_fn,
|
||||||
result));
|
result));
|
||||||
|
Loading…
Reference in New Issue
Block a user