Support XlaCompiler arguments of type kConstant in XlaCompilationCache
These arguments types are generated in two ways: 1) XlaCompileOnDemandOp, based on the xla op kernel requirement, provides inputs with constant constraint during compilation. This allows the op to get compiled to HLO and then the actual HLO op can get executed on the required device. 2) _XlaCompileOp gets some of the inputs as "constants" variadic input. These are the constants coming from GuaranteeConst op. TPUCompileMlirOp doesn't support these constants yet. I also found a note in the old bridge to revisit this idea as it has limited utility. (Couldn't find the note now.) kConstant type arguments are supported by inlining the constant arguments after import and dropping those arguments from the signature. Mapping from the old set of arguments to the final list of arguments is set in the input_mapping of compilation result. Some of the compilation tests are disabled with this change as those were going through the old bridge and requires further changes to support those. PiperOrigin-RevId: 310040060 Change-Id: I121bb717e0376728bd727c7646116a04c77dcb53
This commit is contained in:
parent
9d15d75988
commit
ae422505db
|
@ -41,6 +41,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
@ -277,29 +278,25 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||
const NodeDef& node_def = ctx->op_kernel().def();
|
||||
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
|
||||
|
||||
bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
|
||||
return arg.kind == XlaCompiler::Argument::kParameter;
|
||||
});
|
||||
bool are_args_supported =
|
||||
absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
|
||||
return arg.kind == XlaCompiler::Argument::kConstant ||
|
||||
arg.kind == XlaCompiler::Argument::kParameter;
|
||||
});
|
||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
||||
// Use MLIR bridge if all the arguments are parameters.
|
||||
// TODO(hinsu): Support other argument types instead of silently falling
|
||||
// back to the XLA compiler.
|
||||
if (!are_params || !use_mlir) {
|
||||
// TODO(b/155596779): Understand the source of other argument types and
|
||||
// depending on the source either support those or avoid these codepath.
|
||||
if (!use_mlir || !are_args_supported) {
|
||||
return compiler->CompileGraph(compile_options, node_def.name(),
|
||||
std::move(graph), args, result);
|
||||
}
|
||||
|
||||
absl::InlinedVector<TensorShape, 4> arg_shapes;
|
||||
arg_shapes.reserve(args.size());
|
||||
for (const XlaCompiler::Argument& arg : args) {
|
||||
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
|
||||
}
|
||||
GraphDebugInfo debug_info;
|
||||
return CompileGraphToXlaHlo(
|
||||
*graph, {arg_shapes.data(), arg_shapes.size()},
|
||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||
*graph, {args.data(), args.size()}, options.device_type.type_string(),
|
||||
compile_options.use_tuple_arg, *options.flib_def, debug_info,
|
||||
options.shape_representation_fn, result);
|
||||
};
|
||||
return CompileImpl(options, name, args, compile_op,
|
||||
/*compile_threshold=*/absl::nullopt,
|
||||
|
|
|
@ -1114,6 +1114,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
|||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
":convert_tensor",
|
||||
]
|
||||
|
||||
# Prefer to link 'compile_mlir_util' library that also links necessary
|
||||
|
|
|
@ -17,10 +17,13 @@ limitations under the License.
|
|||
|
||||
#include "absl/types/optional.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
|
@ -35,6 +38,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
|
@ -393,14 +397,47 @@ Status CompileSerializedMlirToXlaHlo(
|
|||
std::move(custom_legalization_passes));
|
||||
}
|
||||
|
||||
// Rewrites the given module with specified args. For each of the constant args,
|
||||
// it gets inlined in the "main' function and the corresponding argument is
|
||||
// removed from the signature.
|
||||
// Returns the original indices for the other arguments on success.
|
||||
static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||
mlir::ModuleOp module, llvm::ArrayRef<const XlaCompiler::Argument> args) {
|
||||
mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
|
||||
std::vector<int> params;
|
||||
|
||||
auto builder = mlir::OpBuilder(main_fn.getBody());
|
||||
std::vector<int> args_to_erase;
|
||||
for (int idx = 0; idx < args.size(); idx++) {
|
||||
const XlaCompiler::Argument& xla_arg = args[idx];
|
||||
mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
|
||||
if (xla_arg.kind != XlaCompiler::Argument::kConstant) {
|
||||
params.push_back(idx);
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto value_attr,
|
||||
ConvertTensor(xla_arg.constant_value, &builder));
|
||||
// TODO(hinsu): Use the actual location of the constant.
|
||||
auto constant = builder.create<mlir::TF::ConstOp>(
|
||||
mlir::UnknownLoc::get(module.getContext()), value_attr);
|
||||
mlir_arg.replaceAllUsesWith(constant);
|
||||
args_to_erase.push_back(idx);
|
||||
}
|
||||
|
||||
for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);
|
||||
return params;
|
||||
}
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
RegisterDialects();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
GraphImportConfig config;
|
||||
config.graph_as_function = true;
|
||||
|
@ -408,10 +445,19 @@ Status CompileGraphToXlaHlo(
|
|||
ConvertGraphToMlir(graph, debug_info, flib_def, config, &context);
|
||||
if (!module_or.ok()) return module_or.status();
|
||||
|
||||
return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes,
|
||||
device_type, use_tuple_args,
|
||||
shape_representation_fn, compilation_result,
|
||||
std::move(custom_legalization_passes));
|
||||
mlir::ModuleOp module = module_or.ValueOrDie().get();
|
||||
TF_ASSIGN_OR_RETURN(std::vector<int> remaining_params,
|
||||
RewriteWithArgs(module, {args.data(), args.size()}));
|
||||
llvm::SmallVector<TensorShape, 4> arg_shapes;
|
||||
arg_shapes.reserve(args.size());
|
||||
for (unsigned idx : remaining_params)
|
||||
arg_shapes.push_back(absl::get<TensorShape>(args[idx].shape));
|
||||
|
||||
auto status = CompileMlirToXlaHlo(
|
||||
module, arg_shapes, device_type, use_tuple_args, shape_representation_fn,
|
||||
compilation_result, std::move(custom_legalization_passes));
|
||||
compilation_result->input_mapping = remaining_params;
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -71,7 +71,7 @@ Status CompileSerializedMlirToXlaHlo(
|
|||
|
||||
// Same as the above but takes input as TensorFlow Graph.
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
|
|
|
@ -455,8 +455,12 @@ TEST(CompileGraphToXlaHlo, Basic) {
|
|||
test::graph::Retval(&graph, 0, arg);
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
XlaCompiler::Argument compiler_arg;
|
||||
compiler_arg.kind = XlaCompiler::Argument::kParameter;
|
||||
compiler_arg.shape = TensorShape();
|
||||
|
||||
TF_ASSERT_OK(
|
||||
CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT",
|
||||
CompileGraphToXlaHlo(graph, /*args=*/{compiler_arg}, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
|
||||
/*shape_representation_fn=*/nullptr, &result));
|
||||
|
||||
|
|
|
@ -1102,6 +1102,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
x,
|
||||
expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/155097273): Handle complex dtype constants")
|
||||
def testExpandDims(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
|
@ -1199,6 +1201,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
np.full([1, 1, 3, 5], 3., dtype=np.float32),
|
||||
expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/155097273): Handle complex dtype constants")
|
||||
def testPad(self):
|
||||
for dtype, pad_type in itertools.product(
|
||||
self.numeric_types, [np.int32, np.int64]):
|
||||
|
@ -1230,6 +1234,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
[7, 7, 7, 7, 7, 7]],
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"Requires concatenate op support in MlirHloBuilder")
|
||||
def testSymmetricMirrorPad(self):
|
||||
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
|
||||
for dtype in self.numeric_types:
|
||||
|
@ -1261,6 +1267,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
np.array([[0, 0], [0, 0]], dtype=np.int32),
|
||||
expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"Requires concatenate op support in MlirHloBuilder")
|
||||
def testReflectMirrorPad(self):
|
||||
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
|
||||
for dtype in self.numeric_types:
|
||||
|
@ -1335,6 +1343,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
],
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/155097273): Handle complex dtype constants")
|
||||
def testReshape(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
|
@ -1414,6 +1424,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
],
|
||||
equality_test=self.ListsAreClose)
|
||||
|
||||
@test_util.disable_mlir_bridge("TODO(b/155097657): Debug incorrect answer")
|
||||
def testTile(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
|
@ -1466,6 +1477,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
[1, 2]],
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/155097273): Handle complex dtype constants")
|
||||
def testTranspose(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testBinary(
|
||||
|
@ -1484,6 +1497,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
np.array([1, 0], dtype=np.int32),
|
||||
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"TODO(b/155097273): Handle complex dtype constants")
|
||||
def testConjugateTranspose(self):
|
||||
for dtype in self.complex_types:
|
||||
self._testBinary(
|
||||
|
@ -1521,6 +1536,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
|
||||
expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"Define BroadcastArgs op in TF and const fold it")
|
||||
def testBroadcastArgs(self):
|
||||
self._testBinary(array_ops.broadcast_dynamic_shape,
|
||||
np.array([2, 3, 5], dtype=np.int32),
|
||||
|
@ -1572,6 +1589,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
np.array([2, 1, 5], dtype=np.int32),
|
||||
expected=np.array([2, 3, 5], dtype=np.int32))
|
||||
|
||||
@test_util.disable_mlir_bridge("Error handling")
|
||||
def testBroadcastArgsError(self):
|
||||
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
|
||||
"Incompatible shapes"):
|
||||
self._testBinary(array_ops.broadcast_dynamic_shape,
|
||||
|
@ -1579,6 +1598,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
|||
np.array([4, 5, 6], dtype=np.int32),
|
||||
expected=None)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"Requires BroadcastInDim method in MlirHloBuilder")
|
||||
def testBroadcastTo(self):
|
||||
for dtype in self.all_types:
|
||||
x = np.random.randint(0, high=100, size=[2, 3])
|
||||
|
|
|
@ -47,6 +47,7 @@ class GatherNdTest(xla_test.XLATestCase):
|
|||
np.array([8, 1, 2, 3, 7, 5], dtype=dtype),
|
||||
np.array([[4], [4], [0]], np.int32)))
|
||||
|
||||
@test_util.disable_mlir_bridge("Error handling")
|
||||
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
|
||||
with self.session():
|
||||
params = np.ones((3, 3), dtype=np.float32)
|
||||
|
|
|
@ -195,7 +195,8 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||
args=(np.array([1, 2, 3], dtype=dtype),),
|
||||
expected=np.array([-1, -2, -3], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
@test_util.disable_mlir_bridge(
|
||||
'Requires XlaPad op shape inference to have static result types')
|
||||
def testPad(self):
|
||||
for dtype in self.numeric_types:
|
||||
|
||||
|
@ -308,6 +309,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||
self._assertOpOutputMatchesExpected(
|
||||
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDynamicSlice(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._assertOpOutputMatchesExpected(
|
||||
|
@ -320,6 +322,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||
[[673, 674], [683, 684], [693, 694]]]),
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDynamicSliceWithIncorrectStartIndicesShape(self):
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
|
@ -333,6 +336,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||
(r'start_indices must be a vector with length equal to input rank, '
|
||||
r'but input rank is 3 and start_indices has shape \[2\].*'))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testDynamicSliceWithIncorrectSizeIndicesShape(self):
|
||||
with self.session() as session:
|
||||
with self.test_scope():
|
||||
|
|
Loading…
Reference in New Issue