From ae422505dbc32cd185983e311be05771c27c4b67 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Tue, 5 May 2020 16:07:18 -0700 Subject: [PATCH] Support XlaCompiler arguments of type kConstant in XlaCompilationCache These arguments types are generated in two ways: 1) XlaCompileOnDemandOp, based on the xla op kernel requirement, provides inputs with constant constraint during compilation. This allows the op to get compiled to HLO and then the actual HLO op can get executed on the required device. 2) _XlaCompileOp gets some of the inputs as "constants" variadic input. These are the constants coming from GuaranteeConst op. TPUCompileMlirOp doesn't support these constants yet. I also found a note in the old bridge to revisit this idea as it has limited utility. (Couldn't find the note now.) kConstant type arguments are supported by inlining the constant arguments after import and dropping those arguments from the signature. Mapping from the old set of arguments to the final list of arguments is set in the input_mapping of compilation result. Some of the compilation tests are disabled with this change as those were going through the old bridge and requires further changes to support those. PiperOrigin-RevId: 310040060 Change-Id: I121bb717e0376728bd727c7646116a04c77dcb53 --- .../compiler/jit/xla_compilation_cache.cc | 27 ++++----- tensorflow/compiler/mlir/tensorflow/BUILD | 1 + .../tensorflow/utils/compile_mlir_util.cc | 56 +++++++++++++++++-- .../mlir/tensorflow/utils/compile_mlir_util.h | 2 +- .../utils/compile_mlir_util_test.cc | 6 +- tensorflow/compiler/tests/binary_ops_test.py | 21 +++++++ .../compiler/tests/gather_nd_op_test.py | 1 + tensorflow/compiler/tests/xla_ops_test.py | 6 +- 8 files changed, 97 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index c90e8dead76..62b0c0ab4cf 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/version.h" @@ -277,29 +278,25 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) { - return arg.kind == XlaCompiler::Argument::kParameter; - }); + bool are_args_supported = + absl::c_all_of(args, [](const XlaCompiler::Argument arg) { + return arg.kind == XlaCompiler::Argument::kConstant || + arg.kind == XlaCompiler::Argument::kParameter; + }); const ConfigProto* config = ctx->function_library()->config_proto(); bool use_mlir = config && config->experimental().enable_mlir_bridge(); - // Use MLIR bridge if all the arguments are parameters. - // TODO(hinsu): Support other argument types instead of silently falling - // back to the XLA compiler. - if (!are_params || !use_mlir) { + // TODO(b/155596779): Understand the source of other argument types and + // depending on the source either support those or avoid these codepath. + if (!use_mlir || !are_args_supported) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } - absl::InlinedVector arg_shapes; - arg_shapes.reserve(args.size()); - for (const XlaCompiler::Argument& arg : args) { - arg_shapes.push_back(absl::get(arg.shape)); - } GraphDebugInfo debug_info; return CompileGraphToXlaHlo( - *graph, {arg_shapes.data(), arg_shapes.size()}, - options.device_type.type_string(), compile_options.use_tuple_arg, - *options.flib_def, debug_info, options.shape_representation_fn, result); + *graph, {args.data(), args.size()}, options.device_type.type_string(), + compile_options.use_tuple_arg, *options.flib_def, debug_info, + options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 3cc4272561b..f75e45c8b1f 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1114,6 +1114,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/stream_executor/lib", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", + ":convert_tensor", ] # Prefer to link 'compile_mlir_util' library that also links necessary diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 784921393c7..2374687c920 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -17,10 +17,13 @@ limitations under the License. #include "absl/types/optional.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -35,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -393,14 +397,47 @@ Status CompileSerializedMlirToXlaHlo( std::move(custom_legalization_passes)); } +// Rewrites the given module with specified args. For each of the constant args, +// it gets inlined in the "main' function and the corresponding argument is +// removed from the signature. +// Returns the original indices for the other arguments on success. +static StatusOr> RewriteWithArgs( + mlir::ModuleOp module, llvm::ArrayRef args) { + mlir::FuncOp main_fn = module.lookupSymbol("main"); + std::vector params; + + auto builder = mlir::OpBuilder(main_fn.getBody()); + std::vector args_to_erase; + for (int idx = 0; idx < args.size(); idx++) { + const XlaCompiler::Argument& xla_arg = args[idx]; + mlir::BlockArgument mlir_arg = main_fn.getArgument(idx); + if (xla_arg.kind != XlaCompiler::Argument::kConstant) { + params.push_back(idx); + continue; + } + + TF_ASSIGN_OR_RETURN(auto value_attr, + ConvertTensor(xla_arg.constant_value, &builder)); + // TODO(hinsu): Use the actual location of the constant. + auto constant = builder.create( + mlir::UnknownLoc::get(module.getContext()), value_attr); + mlir_arg.replaceAllUsesWith(constant); + args_to_erase.push_back(idx); + } + + for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx); + return params; +} + Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef arg_shapes, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result, std::vector> custom_legalization_passes) { RegisterDialects(); + mlir::MLIRContext context; GraphImportConfig config; config.graph_as_function = true; @@ -408,10 +445,19 @@ Status CompileGraphToXlaHlo( ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); if (!module_or.ok()) return module_or.status(); - return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, - device_type, use_tuple_args, - shape_representation_fn, compilation_result, - std::move(custom_legalization_passes)); + mlir::ModuleOp module = module_or.ValueOrDie().get(); + TF_ASSIGN_OR_RETURN(std::vector remaining_params, + RewriteWithArgs(module, {args.data(), args.size()})); + llvm::SmallVector arg_shapes; + arg_shapes.reserve(args.size()); + for (unsigned idx : remaining_params) + arg_shapes.push_back(absl::get(args[idx].shape)); + + auto status = CompileMlirToXlaHlo( + module, arg_shapes, device_type, use_tuple_args, shape_representation_fn, + compilation_result, std::move(custom_legalization_passes)); + compilation_result->input_mapping = remaining_params; + return status; } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 0218efb83c6..24b60dcb346 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -71,7 +71,7 @@ Status CompileSerializedMlirToXlaHlo( // Same as the above but takes input as TensorFlow Graph. Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef arg_shapes, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 118af434629..91640aff437 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -455,8 +455,12 @@ TEST(CompileGraphToXlaHlo, Basic) { test::graph::Retval(&graph, 0, arg); XlaCompiler::CompilationResult result; + XlaCompiler::Argument compiler_arg; + compiler_arg.kind = XlaCompiler::Argument::kParameter; + compiler_arg.shape = TensorShape(); + TF_ASSERT_OK( - CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT", + CompileGraphToXlaHlo(graph, /*args=*/{compiler_arg}, "XLA_CPU_JIT", /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index d9721a3c8ac..db0f9e2fda8 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1102,6 +1102,8 @@ class BinaryOpsTest(xla_test.XLATestCase): x, expected=np.matmul(x, x.transpose([0, 1, 3, 2]))) + @test_util.disable_mlir_bridge( + "TODO(b/155097273): Handle complex dtype constants") def testExpandDims(self): for dtype in self.numeric_types: self._testBinary( @@ -1199,6 +1201,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.full([1, 1, 3, 5], 3., dtype=np.float32), expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32)) + @test_util.disable_mlir_bridge( + "TODO(b/155097273): Handle complex dtype constants") def testPad(self): for dtype, pad_type in itertools.product( self.numeric_types, [np.int32, np.int64]): @@ -1230,6 +1234,8 @@ class BinaryOpsTest(xla_test.XLATestCase): [7, 7, 7, 7, 7, 7]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "Requires concatenate op support in MlirHloBuilder") def testSymmetricMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC") for dtype in self.numeric_types: @@ -1261,6 +1267,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([[0, 0], [0, 0]], dtype=np.int32), expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "Requires concatenate op support in MlirHloBuilder") def testReflectMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: @@ -1335,6 +1343,8 @@ class BinaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + @test_util.disable_mlir_bridge( + "TODO(b/155097273): Handle complex dtype constants") def testReshape(self): for dtype in self.numeric_types: self._testBinary( @@ -1414,6 +1424,7 @@ class BinaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) + @test_util.disable_mlir_bridge("TODO(b/155097657): Debug incorrect answer") def testTile(self): for dtype in self.numeric_types: self._testBinary( @@ -1466,6 +1477,8 @@ class BinaryOpsTest(xla_test.XLATestCase): [1, 2]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "TODO(b/155097273): Handle complex dtype constants") def testTranspose(self): for dtype in self.numeric_types: self._testBinary( @@ -1484,6 +1497,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1, 3], [2, 4]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "TODO(b/155097273): Handle complex dtype constants") def testConjugateTranspose(self): for dtype in self.complex_types: self._testBinary( @@ -1521,6 +1536,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype), expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype)) + @test_util.disable_mlir_bridge( + "Define BroadcastArgs op in TF and const fold it") def testBroadcastArgs(self): self._testBinary(array_ops.broadcast_dynamic_shape, np.array([2, 3, 5], dtype=np.int32), @@ -1572,6 +1589,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([2, 1, 5], dtype=np.int32), expected=np.array([2, 3, 5], dtype=np.int32)) + @test_util.disable_mlir_bridge("Error handling") + def testBroadcastArgsError(self): with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, "Incompatible shapes"): self._testBinary(array_ops.broadcast_dynamic_shape, @@ -1579,6 +1598,8 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) + @test_util.disable_mlir_bridge( + "Requires BroadcastInDim method in MlirHloBuilder") def testBroadcastTo(self): for dtype in self.all_types: x = np.random.randint(0, high=100, size=[2, 3]) diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 70377af6bdc..bfd79d816f5 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -47,6 +47,7 @@ class GatherNdTest(xla_test.XLATestCase): np.array([8, 1, 2, 3, 7, 5], dtype=dtype), np.array([[4], [4], [0]], np.int32))) + @test_util.disable_mlir_bridge("Error handling") def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): with self.session(): params = np.ones((3, 3), dtype=np.float32) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index df388c655d0..cae10ad51aa 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -195,7 +195,8 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.array([1, 2, 3], dtype=dtype),), expected=np.array([-1, -2, -3], dtype=dtype)) - @test_util.disable_mlir_bridge('Not supported yet') + @test_util.disable_mlir_bridge( + 'Requires XlaPad op shape inference to have static result types') def testPad(self): for dtype in self.numeric_types: @@ -308,6 +309,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): self._assertOpOutputMatchesExpected( lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + @test_util.disable_mlir_bridge('Not supported yet') def testDynamicSlice(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( @@ -320,6 +322,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): [[673, 674], [683, 684], [693, 694]]]), dtype=dtype)) + @test_util.disable_mlir_bridge('Not supported yet') def testDynamicSliceWithIncorrectStartIndicesShape(self): with self.session() as session: with self.test_scope(): @@ -333,6 +336,7 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): (r'start_indices must be a vector with length equal to input rank, ' r'but input rank is 3 and start_indices has shape \[2\].*')) + @test_util.disable_mlir_bridge('Not supported yet') def testDynamicSliceWithIncorrectSizeIndicesShape(self): with self.session() as session: with self.test_scope():