From 79b8c700d416c29ea886f6f27cb78f27db810e1a Mon Sep 17 00:00:00 2001 From: Smit Hinsu <hinsu@google.com> Date: Mon, 23 Mar 2020 14:47:18 -0700 Subject: [PATCH] Roll forward change to run use MLIR based TensorFlow compiler in XLA on demand compiler This splits compile_mlir_util lib into two parts. One with TF dialect passes that includes TF constant folding hook and other without it. Constant folding hook depends on the TF eager so splitting the library into two parts is required to avoid the circular dependency. PiperOrigin-RevId: 302522554 Change-Id: I4f8f0a8e745a9becff3845cc59950f181e6f415a --- tensorflow/compiler/jit/BUILD | 1 + .../compiler/jit/xla_compilation_cache.cc | 29 ++++++- .../compiler/jit/xla_compilation_cache.h | 4 +- tensorflow/compiler/mlir/tensorflow/BUILD | 76 ++++++++++++------ .../tensorflow/utils/compile_mlir_util.cc | 59 +++++++++++--- .../mlir/tensorflow/utils/compile_mlir_util.h | 10 +++ .../utils/compile_mlir_util_test.cc | 39 +++++++++ tensorflow/compiler/tests/BUILD | 20 +++++ .../compiler/tests/unary_mlir_ops_test.py | 80 +++++++++++++++++++ tensorflow/tools/lib_package/BUILD | 2 + 10 files changed, 278 insertions(+), 42 deletions(-) create mode 100644 tensorflow/compiler/tests/unary_mlir_ops_test.py diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f71331af0df..f44a0253464 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -338,6 +338,7 @@ cc_library( deps = [ ":xla_activity_listener", ":xla_activity_proto_cc", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 5540fee7276..5081df28a08 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/metrics.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -40,6 +42,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" @@ -273,8 +276,30 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - return compiler->CompileGraph(compile_options, node_def.name(), - std::move(graph), args, result); + + bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) { + return arg.kind == XlaCompiler::Argument::kParameter; + }); + const ConfigProto* config = ctx->function_library()->config_proto(); + bool use_mlir = config && config->experimental().enable_mlir_bridge(); + // Use MLIR bridge if all the arguments are parameters. + // TODO(hinsu): Support other argument types instead of silently falling + // back to the XLA compiler. + if (!are_params || !use_mlir) { + return compiler->CompileGraph(compile_options, node_def.name(), + std::move(graph), args, result); + } + + absl::InlinedVector<TensorShape, 4> arg_shapes; + arg_shapes.reserve(args.size()); + for (const XlaCompiler::Argument& arg : args) { + arg_shapes.push_back(absl::get<TensorShape>(arg.shape)); + } + GraphDebugInfo debug_info; + return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()}, + compile_options.use_tuple_arg, + *options.flib_def, debug_info, + options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 83a0bda97d5..cd58cf31988 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -78,7 +78,9 @@ class XlaCompilationCache : public ResourceBase { xla::LocalExecutable** out_executable); // As above, but calls XlaCompiler::CompileSingleOp instead of - // XlaCompiler::CompileFunction. + // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto + // in OpKernelContext, then uses MLIR bridge for compilation instead of + // XlaCompiler, if possible. Status CompileSingleOp( const XlaCompiler::Options& options, absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx, diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 8ac33c906bb..7b088cad715 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1052,36 +1052,58 @@ gentbl( ], ) +COMPILE_MLIR_UTIL_DEPS = [ + ":bridge_logger", + ":convert_graphdef", + ":convert_type", + ":dump_mlir_util", + ":error_util", + ":mlir_roundtrip_flags", + ":tensorflow", + ":tensorflow_dialect_registration", + ":tensorflow_passes", + ":translate_utils", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", + "//tensorflow/compiler/mlir/xla:type_to_shape", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:logging", + "//tensorflow/stream_executor/lib", +] + +# Prefer to link 'compile_mlir_util' library that also links necessary +# TensorFlow passes to the pipeline. This library without tf passes is useful +# if the constant folding is not required on the TensorFlow dialect. For +# example, this is used in XLA ondemand compilation which compiles a single op +# at a time and doesn't require constant folding. Doing so helps avoid a +# circular dependency between c_api and tf passes. +# TODO(hinsu): Split out the constant folding hook and only exclude that in +# this target. cc_library( - name = "compile_mlir_util", + name = "compile_mlir_util_no_tf_dialect_passes", srcs = ["utils/compile_mlir_util.cc"], hdrs = ["utils/compile_mlir_util.h"], - deps = [ - ":bridge_logger", - ":convert_type", - ":dump_mlir_util", - ":error_util", - ":tensorflow", - ":tensorflow_dialect_registration", - ":tensorflow_passes", + deps = COMPILE_MLIR_UTIL_DEPS, +) + +cc_library( + name = "compile_mlir_util", + hdrs = ["utils/compile_mlir_util.h"], + deps = COMPILE_MLIR_UTIL_DEPS + [ + "compile_mlir_util_no_tf_dialect_passes", ":tf_dialect_passes", - ":translate_utils", - "//tensorflow/compiler/mlir/xla:hlo", - "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", - "//tensorflow/compiler/mlir/xla:type_to_shape", - "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/core:framework", - "//tensorflow/core/platform:logging", - "//tensorflow/stream_executor/lib", - "@llvm-project//llvm:support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", ], ) @@ -1096,8 +1118,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/stream_executor/lib", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 5394dbfb21a..3fd711b9ef8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -31,6 +31,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" @@ -276,19 +278,11 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, return Status::OK(); } -Status CompileSerializedMlirToXlaHlo( - llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes, +static Status CompileMlirToXlaHlo( + mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { - RegisterDialects(); - mlir::MLIRContext mlir_context; - mlir::OwningModuleRef mlir_module; - - TF_RETURN_IF_ERROR( - ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); - auto module_op = mlir_module.get(); - if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -309,9 +303,14 @@ Status CompileSerializedMlirToXlaHlo( GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping); auto shape_representation_fn_no_fast_memory = - [shape_representation_fn](const TensorShape& shape, DataType dtype) { - return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); - }; + [shape_representation_fn](const TensorShape& shape, + DataType dtype) -> StatusOr<xla::Shape> { + if (shape_representation_fn) + return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; // Compute all input shapes. TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args, @@ -333,4 +332,38 @@ Status CompileSerializedMlirToXlaHlo( return Status::OK(); } +Status CompileSerializedMlirToXlaHlo( + llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes, + bool use_tuple_args, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result) { + RegisterDialects(); + mlir::MLIRContext mlir_context; + mlir::OwningModuleRef mlir_module; + + TF_RETURN_IF_ERROR( + ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); + return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args, + shape_representation_fn, compilation_result); +} + +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result) { + RegisterDialects(); + mlir::MLIRContext context; + GraphImportConfig config; + config.graph_as_function = true; + auto module_or = + ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); + if (!module_or.ok()) return module_or.status(); + + return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, + use_tuple_args, shape_representation_fn, + compilation_result); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 19423adfe17..0dd4b8c5efe 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { @@ -53,6 +54,15 @@ Status CompileSerializedMlirToXlaHlo( bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); + +// Same as the above but takes input as TensorFlow Graph. +Status CompileGraphToXlaHlo( + const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::CompilationResult* compilation_result); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 7db3d34a4ad..f65fcc1016d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -20,6 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -285,5 +288,41 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { status_or_hlo_module.ValueOrDie()->ToString()); } +// Verify that conversion from Graph to MLIR and empty shape representation +// function is successful. +TEST(CompileGraphToXlaHlo, Basic) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + Graph graph(OpRegistry::Global()); + + Tensor dummy_tensor(DT_FLOAT, TensorShape({1})); + test::FillValues<float>(&dummy_tensor, {-1.0}); + + Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT); + test::graph::Retval(&graph, 0, arg); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(CompileGraphToXlaHlo( + graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def, + GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); + + const xla::HloModuleConfig module_config( + result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + result.computation->proto(), module_config); + ASSERT_TRUE(status_or_hlo_module.ok()); + + string expected_hlo_module_string = R"(HloModule main.3 + +ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { + %Arg_0.1 = f32[] parameter(0) + ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1) +} + +)"; + + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 77cd3dc074c..d586b8178c5 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1354,6 +1354,26 @@ tf_xla_py_test( ], ) +# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it. +tf_xla_py_test( + name = "unary_mlir_ops_test", + size = "medium", + srcs = ["unary_mlir_ops_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fused_batchnorm_test", size = "medium", diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py new file mode 100644 index 00000000000..2b3dec3d5a7 --- /dev/null +++ b/tensorflow/compiler/tests/unary_mlir_ops_test.py @@ -0,0 +1,80 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA JIT compiler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class UnaryOpsTest(xla_test.XLATestCase): + """Test cases for unary operators.""" + + def __init__(self, method_name='runTest'): + super(UnaryOpsTest, self).__init__(method_name) + context.context().enable_mlir_bridge = True + + def _assertOpOutputMatchesExpected(self, + op, + inp, + expected, + equality_test=None, + rtol=1e-3, + atol=1e-5): + """Verifies that 'op' produces 'expected' when fed input 'inp' . + + Args: + op: operator to test + inp: numpy input array to use as input to 'op'. + expected: numpy array representing the expected output of 'op'. + equality_test: either None, or a function that tests two numpy arrays for + equality. If None, self.assertAllClose is used. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. + """ + with self.session() as session: + with self.test_scope(): + pinp = array_ops.placeholder( + dtypes.as_dtype(inp.dtype), inp.shape, name='a') + output = op(pinp) + result = session.run(output, {pinp: inp}) + if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + else: + equality_test(result, expected, rtol=rtol, atol=atol) + + def testNumericOps(self): + # TODO(hinsu): Enable complex types after fixing the failure in export to + # HLOModule. + for dtype in self.numeric_types - {np.int8, np.uint8} - self.complex_types: + self._assertOpOutputMatchesExpected( + math_ops.abs, + np.array([[2, -1]], dtype=dtype), + expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 30ab95e370d..89ec2a0c7c3 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -154,6 +154,7 @@ genrule( "@icu//:icu4c/LICENSE", "@libjpeg_turbo//:LICENSE.md", "@llvm-project//llvm:LICENSE.TXT", + "@llvm-project//mlir:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@local_config_tensorrt//:LICENSE", @@ -234,6 +235,7 @@ genrule( "@icu//:icu4j/main/shared/licenses/LICENSE", "@libjpeg_turbo//:LICENSE.md", "@llvm-project//llvm:LICENSE.TXT", + "@llvm-project//mlir:LICENSE.TXT", "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@local_config_tensorrt//:LICENSE",