diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f71331af0df..f44a0253464 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -338,6 +338,7 @@ cc_library( deps = [ ":xla_activity_listener", ":xla_activity_proto_cc", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 5540fee7276..5081df28a08 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/metrics.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -40,6 +42,7 @@ limitations under the License. 
#include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/dump_graph.h"

@@ -273,8 +276,39 @@ Status XlaCompilationCache::CompileSingleOp(

     const NodeDef& node_def = ctx->op_kernel().def();
     TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
-    return compiler->CompileGraph(compile_options, node_def.name(),
-                                  std::move(graph), args, result);
+
+    // The MLIR path below reads a TensorShape out of every argument, so only
+    // take it when each argument is a parameter that actually holds one.
+    const bool are_params =
+        absl::c_all_of(args, [](const XlaCompiler::Argument& arg) {
+          return arg.kind == XlaCompiler::Argument::kParameter &&
+                 absl::holds_alternative<TensorShape>(arg.shape);
+        });
+    // Tolerate a context without a function library instead of dereferencing
+    // a null pointer; a missing ConfigProto simply disables the MLIR bridge.
+    auto* flib_runtime = ctx->function_library();
+    const ConfigProto* config =
+        flib_runtime ? flib_runtime->config_proto() : nullptr;
+    const bool use_mlir = config && config->experimental().enable_mlir_bridge();
+    // Use MLIR bridge if all the arguments are parameters.
+    // TODO(hinsu): Support other argument types instead of silently falling
+    // back to the XLA compiler.
+    if (!are_params || !use_mlir) {
+      return compiler->CompileGraph(compile_options, node_def.name(),
+                                    std::move(graph), args, result);
+    }
+
+    // Safe: are_params guarantees every argument holds a TensorShape.
+    absl::InlinedVector<TensorShape, 4> arg_shapes;
+    arg_shapes.reserve(args.size());
+    for (const XlaCompiler::Argument& arg : args) {
+      arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
+    }
+    GraphDebugInfo debug_info;
+    return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
+                                compile_options.use_tuple_arg,
+                                *options.flib_def, debug_info,
+                                options.shape_representation_fn, result);
   };
   return CompileImpl(options, name, args, compile_op,
                      /*compile_threshold=*/absl::nullopt,
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 83a0bda97d5..cd58cf31988 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -78,7 +78,10 @@ class XlaCompilationCache : public ResourceBase {
                  xla::LocalExecutable** out_executable);

   // As above, but calls XlaCompiler::CompileSingleOp instead of
-  // XlaCompiler::CompileFunction.
+  // XlaCompiler::CompileFunction. If the MLIR bridge is enabled through the
+  // ConfigProto reachable from `ctx`, compiles with the MLIR bridge instead of
+  // XlaCompiler when every argument is a parameter; otherwise falls back to
+  // XlaCompiler.
Status CompileSingleOp(
       const XlaCompiler::Options& options,
       absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 8ac33c906bb..7b088cad715 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1052,36 +1052,60 @@ gentbl(
     ],
 )

+COMPILE_MLIR_UTIL_DEPS = [
+    ":bridge_logger",
+    ":convert_graphdef",
+    ":convert_type",
+    ":dump_mlir_util",
+    ":error_util",
+    ":mlir_roundtrip_flags",
+    ":tensorflow",
+    ":tensorflow_dialect_registration",
+    ":tensorflow_passes",
+    ":translate_utils",
+    "@llvm-project//llvm:support",
+    "@llvm-project//mlir:IR",
+    "@llvm-project//mlir:Parser",
+    "@llvm-project//mlir:Pass",
+    "@llvm-project//mlir:Shape",
+    "@llvm-project//mlir:StandardOps",
+    "@llvm-project//mlir:TransformUtils",
+    "@llvm-project//mlir:Transforms",
+    "//tensorflow/compiler/mlir/xla:hlo",
+    "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
+    "//tensorflow/compiler/mlir/xla:type_to_shape",
+    "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
+    "//tensorflow/compiler/tf2xla:common",
+    "//tensorflow/compiler/tf2xla:xla_compiler",
+    "//tensorflow/core:framework",
+    "//tensorflow/core:protos_all_cc",
+    "//tensorflow/core/platform:logging",
+    "//tensorflow/stream_executor/lib",
+]
+
+# Prefer to link 'compile_mlir_util' library that also links necessary
+# TensorFlow passes to the pipeline. This library without tf passes is useful
+# if the constant folding is not required on the TensorFlow dialect. For
+# example, this is used in XLA ondemand compilation which compiles a single op
+# at a time and doesn't require constant folding. Doing so helps avoid a
+# circular dependency between c_api and tf passes.
+# TODO(hinsu): Split out the constant folding hook and only exclude that in
+# this target.
 cc_library(
-    name = "compile_mlir_util",
+    name = "compile_mlir_util_no_tf_dialect_passes",
     srcs = ["utils/compile_mlir_util.cc"],
     hdrs = ["utils/compile_mlir_util.h"],
-    deps = [
-        ":bridge_logger",
-        ":convert_type",
-        ":dump_mlir_util",
-        ":error_util",
-        ":tensorflow",
-        ":tensorflow_dialect_registration",
-        ":tensorflow_passes",
+    deps = COMPILE_MLIR_UTIL_DEPS,
+)
+
+# Same API as 'compile_mlir_util_no_tf_dialect_passes', with the TensorFlow
+# dialect passes (constant folding) linked in as well.
+cc_library(
+    name = "compile_mlir_util",
+    hdrs = ["utils/compile_mlir_util.h"],
+    deps = COMPILE_MLIR_UTIL_DEPS + [
+        ":compile_mlir_util_no_tf_dialect_passes",
         ":tf_dialect_passes",
-        ":translate_utils",
-        "//tensorflow/compiler/mlir/xla:hlo",
-        "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
-        "//tensorflow/compiler/mlir/xla:type_to_shape",
-        "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
-        "//tensorflow/compiler/tf2xla:common",
-        "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/core:framework",
-        "//tensorflow/core/platform:logging",
-        "//tensorflow/stream_executor/lib",
-        "@llvm-project//llvm:support",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:StandardOps",
-        "@llvm-project//mlir:TransformUtils",
-        "@llvm-project//mlir:Transforms",
     ],
 )

@@ -1096,8 +1120,10 @@ tf_cc_test(
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
         "//tensorflow/stream_executor/lib",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 5394dbfb21a..3fd711b9ef8 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -31,6 +31,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
@@ -276,19 +278,11 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
   return Status::OK();
 }

-Status CompileSerializedMlirToXlaHlo(
-    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
+static Status CompileMlirToXlaHlo(
+    mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
     bool use_tuple_args,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result) {
-  RegisterDialects();
-  mlir::MLIRContext mlir_context;
-  mlir::OwningModuleRef mlir_module;
-
-  TF_RETURN_IF_ERROR(
-      ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module));
-  auto module_op = mlir_module.get();
-
   if (VLOG_IS_ON(1))
     tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);

@@ -309,9 +303,16 @@ Status CompileSerializedMlirToXlaHlo(
   GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);

   auto shape_representation_fn_no_fast_memory =
-      [shape_representation_fn](const TensorShape& shape, DataType dtype) {
-        return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
-      };
+      [shape_representation_fn](const TensorShape& shape,
+                                DataType dtype) -> StatusOr<xla::Shape> {
+    // Without a caller-provided function, fall back to TensorShapeToXLAShape.
+    if (!shape_representation_fn) {
+      xla::Shape default_shape;
+      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &default_shape));
+      return default_shape;
+    }
+    return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
+  };

   // Compute all input shapes.
   TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
@@ -333,4 +334,48 @@ Status CompileSerializedMlirToXlaHlo(
   return Status::OK();
 }

+// Parses `mlir_module_string` into an MLIR module and compiles it to XLA HLO
+// through CompileMlirToXlaHlo. Errors from ParseMlirModule are propagated.
+Status CompileSerializedMlirToXlaHlo(
+    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args,
+    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    XlaCompiler::CompilationResult* compilation_result) {
+  RegisterDialects();
+
+  mlir::MLIRContext parse_context;
+  mlir::OwningModuleRef parsed_module;
+  TF_RETURN_IF_ERROR(
+      ParseMlirModule(mlir_module_string, &parse_context, &parsed_module));
+
+  auto module_op = parsed_module.get();
+  return CompileMlirToXlaHlo(module_op, arg_shapes, use_tuple_args,
+                             shape_representation_fn, compilation_result);
+}
+
+// Imports `graph` into an MLIR module (with graph_as_function set on the
+// import config) and compiles it to XLA HLO through CompileMlirToXlaHlo.
+// Errors from ConvertGraphToMlir are propagated.
+Status CompileGraphToXlaHlo(
+    const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
+    const GraphDebugInfo& debug_info,
+    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    XlaCompiler::CompilationResult* compilation_result) {
+  RegisterDialects();
+
+  mlir::MLIRContext mlir_context;
+  GraphImportConfig import_config;
+  import_config.graph_as_function = true;
+  auto imported_module = ConvertGraphToMlir(graph, debug_info, flib_def,
+                                            import_config, &mlir_context);
+  if (!imported_module.ok()) {
+    return imported_module.status();
+  }
+
+  auto module_op = imported_module.ValueOrDie().get();
+  return CompileMlirToXlaHlo(module_op, arg_shapes, use_tuple_args,
+                             shape_representation_fn, compilation_result);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
index 19423adfe17..0dd4b8c5efe 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/IR/Module.h"  // from @llvm-project
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/stream_executor/lib/statusor.h"

 namespace tensorflow {
@@ -53,6 +54,18 @@ Status CompileSerializedMlirToXlaHlo(
     bool use_tuple_args,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result);
+
+// Same as the above but takes input as TensorFlow Graph. `flib_def` and
+// `debug_info` are passed to the Graph-to-MLIR import. A null
+// `shape_representation_fn` makes the compilation fall back to
+// TensorShapeToXLAShape for shape representation.
+Status CompileGraphToXlaHlo(
+    const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
+    const GraphDebugInfo& debug_info,
+    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    XlaCompiler::CompilationResult* compilation_result);
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
index 7db3d34a4ad..f65fcc1016d 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -20,6 +20,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/stream_executor/lib/statusor.h"

@@ -285,5 +288,40 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
             status_or_hlo_module.ValueOrDie()->ToString());
 }

+// Verifies that a Graph is imported to MLIR and compiled to HLO when no shape
+// representation function is supplied.
+TEST(CompileGraphToXlaHlo, Basic) {
+  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+  Graph graph(OpRegistry::Global());
+
+  // Identity graph: a single float argument returned unchanged.
+  Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT);
+  test::graph::Retval(&graph, 0, arg);
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(CompileGraphToXlaHlo(
+      graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def,
+      GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result));
+
+  const xla::HloModuleConfig module_config(
+      result.computation->GetProgramShape().ValueOrDie());
+  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
+      result.computation->proto(), module_config);
+  // TF_ASSERT_OK reports the failing status message, unlike ASSERT_TRUE(ok()).
+  TF_ASSERT_OK(status_or_hlo_module.status());
+
+  string expected_hlo_module_string = R"(HloModule main.3
+
+ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) {
+  %Arg_0.1 = f32[] parameter(0)
+  ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1)
+}
+
+)";
+
+  EXPECT_EQ(expected_hlo_module_string,
+            status_or_hlo_module.ValueOrDie()->ToString());
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 77cd3dc074c..d586b8178c5 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1354,6 +1354,26 @@ tf_xla_py_test(
     ],
 )

+# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it.
+tf_xla_py_test(
+    name = "unary_mlir_ops_test",
+    size = "medium",
+    srcs = ["unary_mlir_ops_test.py"],
+    python_version = "PY3",
+    tags = [
+        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
+    ],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn_ops",
+        "//tensorflow/python:nn_ops_gen",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 tf_xla_py_test(
     name = "fused_batchnorm_test",
     size = "medium",
diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py
new file mode 100644
index 00000000000..2b3dec3d5a7
--- /dev/null
+++ b/tensorflow/compiler/tests/unary_mlir_ops_test.py
@@ -0,0 +1,90 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for unary ops compiled by the XLA JIT with the MLIR bridge enabled."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class UnaryOpsTest(xla_test.XLATestCase):
+  """Test cases for unary operators, run with the MLIR bridge turned on."""
+
+  def __init__(self, method_name='runTest'):
+    super(UnaryOpsTest, self).__init__(method_name)
+    # Turn on the MLIR bridge for every test in this class.
+    # NOTE(review): this is set on the shared eager context and is never reset,
+    # so it presumably stays on for the rest of the process -- confirm.
+    context.context().enable_mlir_bridge = True
+
+  def _assertOpOutputMatchesExpected(self,
+                                     op,
+                                     inp,
+                                     expected,
+                                     equality_test=None,
+                                     rtol=1e-3,
+                                     atol=1e-5):
+    """Verifies that 'op' produces 'expected' when fed input 'inp'.
+
+    A placeholder with the dtype and shape of 'inp' is created inside the test
+    scope, 'op' is applied to it, and the result is evaluated in a session.
+
+    Args:
+      op: operator to test
+      inp: numpy input array to use as input to 'op'.
+      expected: numpy array representing the expected output of 'op'.
+      equality_test: either None, or a function that tests two numpy arrays for
+        equality (called with result, expected, rtol and atol). If None, the
+        output dtype is compared with expected.dtype and
+        self.assertAllCloseAccordingToType is used.
+      rtol: relative tolerance for equality test.
+      atol: absolute tolerance for equality test.
+    """
+    with self.session() as session:
+      with self.test_scope():
+        pinp = array_ops.placeholder(
+            dtypes.as_dtype(inp.dtype), inp.shape, name='a')
+        output = op(pinp)
+      result = session.run(output, {pinp: inp})
+      if equality_test is None:
+        self.assertEqual(output.dtype, expected.dtype)
+        self.assertAllCloseAccordingToType(
+            expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
+      else:
+        equality_test(result, expected, rtol=rtol, atol=atol)
+
+  def testNumericOps(self):
+    # Checks math_ops.abs on every supported numeric type except int8/uint8 and
+    # the complex types.
+    # TODO(hinsu): Enable complex types after fixing the failure in export to
+    # HLOModule.
+    for dtype in self.numeric_types - {np.int8, np.uint8} - self.complex_types:
+      self._assertOpOutputMatchesExpected(
+          math_ops.abs,
+          np.array([[2, -1]], dtype=dtype),
+          expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype))
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 30ab95e370d..89ec2a0c7c3 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -154,6 +154,7 @@ genrule(
         "@icu//:icu4c/LICENSE",
         "@libjpeg_turbo//:LICENSE.md",
         "@llvm-project//llvm:LICENSE.TXT",
+        "@llvm-project//mlir:LICENSE.TXT",
         "@lmdb//:LICENSE",
         "@local_config_sycl//sycl:LICENSE.text",
         "@local_config_tensorrt//:LICENSE",
@@ -234,6 +235,7 @@ genrule(
         "@icu//:icu4j/main/shared/licenses/LICENSE",
         "@libjpeg_turbo//:LICENSE.md",
         "@llvm-project//llvm:LICENSE.TXT",
+        "@llvm-project//mlir:LICENSE.TXT",
         "@lmdb//:LICENSE",
         "@local_config_sycl//sycl:LICENSE.text",
         "@local_config_tensorrt//:LICENSE",