Roll forward change to run use MLIR based TensorFlow compiler in XLA on demand compiler

This splits compile_mlir_util lib into two parts. One with TF dialect passes that includes TF constant folding hook and other without it. Constant folding hook depends on the TF eager so splitting the library into two parts is required to avoid the circular dependency.

PiperOrigin-RevId: 302522554
Change-Id: I4f8f0a8e745a9becff3845cc59950f181e6f415a
This commit is contained in:
Smit Hinsu 2020-03-23 14:47:18 -07:00 committed by TensorFlower Gardener
parent aba2ca4603
commit 79b8c700d4
10 changed files with 278 additions and 42 deletions

View File

@ -338,6 +338,7 @@ cc_library(
deps = [
":xla_activity_listener",
":xla_activity_proto_cc",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -33,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
@ -40,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
@ -273,8 +276,30 @@ Status XlaCompilationCache::CompileSingleOp(
const NodeDef& node_def = ctx->op_kernel().def();
TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
bool are_params = absl::c_all_of(args, [](const XlaCompiler::Argument arg) {
return arg.kind == XlaCompiler::Argument::kParameter;
});
const ConfigProto* config = ctx->function_library()->config_proto();
bool use_mlir = config && config->experimental().enable_mlir_bridge();
// Use MLIR bridge if all the arguments are parameters.
// TODO(hinsu): Support other argument types instead of silently falling
// back to the XLA compiler.
if (!are_params || !use_mlir) {
return compiler->CompileGraph(compile_options, node_def.name(),
std::move(graph), args, result);
}
absl::InlinedVector<TensorShape, 4> arg_shapes;
arg_shapes.reserve(args.size());
for (const XlaCompiler::Argument& arg : args) {
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
}
GraphDebugInfo debug_info;
return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
compile_options.use_tuple_arg,
*options.flib_def, debug_info,
options.shape_representation_fn, result);
};
return CompileImpl(options, name, args, compile_op,
/*compile_threshold=*/absl::nullopt,

View File

@ -78,7 +78,9 @@ class XlaCompilationCache : public ResourceBase {
xla::LocalExecutable** out_executable);
// As above, but calls XlaCompiler::CompileSingleOp instead of
// XlaCompiler::CompileFunction.
// XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto
// in OpKernelContext, then uses MLIR bridge for compilation instead of
// XlaCompiler, if possible.
Status CompileSingleOp(
const XlaCompiler::Options& options,
absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,

View File

@ -1052,36 +1052,58 @@ gentbl(
],
)
COMPILE_MLIR_UTIL_DEPS = [
":bridge_logger",
":convert_graphdef",
":convert_type",
":dump_mlir_util",
":error_util",
":mlir_roundtrip_flags",
":tensorflow",
":tensorflow_dialect_registration",
":tensorflow_passes",
":translate_utils",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:logging",
"//tensorflow/stream_executor/lib",
]
# Prefer to link 'compile_mlir_util' library that also links necessary
# TensorFlow passes to the pipeline. This library without tf passes is useful
# if the constant folding is not required on the TensorFlow dialect. For
# example, this is used in XLA ondemand compilation which compiles a single op
# at a time and doesn't require constant folding. Doing so helps avoid a
# circular dependency between c_api and tf passes.
# TODO(hinsu): Split out the constant folding hook and only exclude that in
# this target.
cc_library(
name = "compile_mlir_util",
name = "compile_mlir_util_no_tf_dialect_passes",
srcs = ["utils/compile_mlir_util.cc"],
hdrs = ["utils/compile_mlir_util.h"],
deps = [
":bridge_logger",
":convert_type",
":dump_mlir_util",
":error_util",
":tensorflow",
":tensorflow_dialect_registration",
":tensorflow_passes",
deps = COMPILE_MLIR_UTIL_DEPS,
)
cc_library(
name = "compile_mlir_util",
hdrs = ["utils/compile_mlir_util.h"],
deps = COMPILE_MLIR_UTIL_DEPS + [
"compile_mlir_util_no_tf_dialect_passes",
":tf_dialect_passes",
":translate_utils",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
"//tensorflow/core/platform:logging",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
)
@ -1096,8 +1118,10 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/stream_executor/lib",
],
)

View File

@ -31,6 +31,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
@ -276,19 +278,11 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
return Status::OK();
}
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
static Status CompileMlirToXlaHlo(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result) {
RegisterDialects();
mlir::MLIRContext mlir_context;
mlir::OwningModuleRef mlir_module;
TF_RETURN_IF_ERROR(
ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module));
auto module_op = mlir_module.get();
if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
@ -309,9 +303,14 @@ Status CompileSerializedMlirToXlaHlo(
GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);
auto shape_representation_fn_no_fast_memory =
[shape_representation_fn](const TensorShape& shape, DataType dtype) {
return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
};
[shape_representation_fn](const TensorShape& shape,
DataType dtype) -> StatusOr<xla::Shape> {
if (shape_representation_fn)
return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
return xla_shape;
};
// Compute all input shapes.
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
@ -333,4 +332,38 @@ Status CompileSerializedMlirToXlaHlo(
return Status::OK();
}
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result) {
RegisterDialects();
mlir::MLIRContext mlir_context;
mlir::OwningModuleRef mlir_module;
TF_RETURN_IF_ERROR(
ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module));
return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args,
shape_representation_fn, compilation_result);
}
Status CompileGraphToXlaHlo(
const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
const GraphDebugInfo& debug_info,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result) {
RegisterDialects();
mlir::MLIRContext context;
GraphImportConfig config;
config.graph_as_function = true;
auto module_or =
ConvertGraphToMlir(graph, debug_info, flib_def, config, &context);
if (!module_or.ok()) return module_or.status();
return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes,
use_tuple_args, shape_representation_fn,
compilation_result);
}
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -53,6 +54,15 @@ Status CompileSerializedMlirToXlaHlo(
bool use_tuple_args,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result);
// Same as the above but takes input as TensorFlow Graph.
Status CompileGraphToXlaHlo(
const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
const GraphDebugInfo& debug_info,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_

View File

@ -20,6 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@ -285,5 +288,41 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
status_or_hlo_module.ValueOrDie()->ToString());
}
// Verify that conversion from Graph to MLIR and empty shape representation
// function is successful.
TEST(CompileGraphToXlaHlo, Basic) {
FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
Graph graph(OpRegistry::Global());
Tensor dummy_tensor(DT_FLOAT, TensorShape({1}));
test::FillValues<float>(&dummy_tensor, {-1.0});
Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT);
test::graph::Retval(&graph, 0, arg);
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(CompileGraphToXlaHlo(
graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def,
GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result));
const xla::HloModuleConfig module_config(
result.computation->GetProgramShape().ValueOrDie());
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
result.computation->proto(), module_config);
ASSERT_TRUE(status_or_hlo_module.ok());
string expected_hlo_module_string = R"(HloModule main.3
ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) {
%Arg_0.1 = f32[] parameter(0)
ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1)
}
)";
EXPECT_EQ(expected_hlo_module_string,
status_or_hlo_module.ValueOrDie()->ToString());
}
} // namespace
} // namespace tensorflow

View File

@ -1354,6 +1354,26 @@ tf_xla_py_test(
],
)
# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it.
tf_xla_py_test(
name = "unary_mlir_ops_test",
size = "medium",
srcs = ["unary_mlir_ops_test.py"],
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "fused_batchnorm_test",
size = "medium",

View File

@ -0,0 +1,80 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for XLA JIT compiler."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
class UnaryOpsTest(xla_test.XLATestCase):
"""Test cases for unary operators."""
def __init__(self, method_name='runTest'):
super(UnaryOpsTest, self).__init__(method_name)
context.context().enable_mlir_bridge = True
def _assertOpOutputMatchesExpected(self,
op,
inp,
expected,
equality_test=None,
rtol=1e-3,
atol=1e-5):
"""Verifies that 'op' produces 'expected' when fed input 'inp' .
Args:
op: operator to test
inp: numpy input array to use as input to 'op'.
expected: numpy array representing the expected output of 'op'.
equality_test: either None, or a function that tests two numpy arrays for
equality. If None, self.assertAllClose is used.
rtol: relative tolerance for equality test.
atol: absolute tolerance for equality test.
"""
with self.session() as session:
with self.test_scope():
pinp = array_ops.placeholder(
dtypes.as_dtype(inp.dtype), inp.shape, name='a')
output = op(pinp)
result = session.run(output, {pinp: inp})
if equality_test is None:
self.assertEqual(output.dtype, expected.dtype)
self.assertAllCloseAccordingToType(
expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
else:
equality_test(result, expected, rtol=rtol, atol=atol)
def testNumericOps(self):
# TODO(hinsu): Enable complex types after fixing the failure in export to
# HLOModule.
for dtype in self.numeric_types - {np.int8, np.uint8} - self.complex_types:
self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[2, -1]], dtype=dtype),
expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype))
if __name__ == '__main__':
googletest.main()

View File

@ -154,6 +154,7 @@ genrule(
"@icu//:icu4c/LICENSE",
"@libjpeg_turbo//:LICENSE.md",
"@llvm-project//llvm:LICENSE.TXT",
"@llvm-project//mlir:LICENSE.TXT",
"@lmdb//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
"@local_config_tensorrt//:LICENSE",
@ -234,6 +235,7 @@ genrule(
"@icu//:icu4j/main/shared/licenses/LICENSE",
"@libjpeg_turbo//:LICENSE.md",
"@llvm-project//llvm:LICENSE.TXT",
"@llvm-project//mlir:LICENSE.TXT",
"@lmdb//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
"@local_config_tensorrt//:LICENSE",