Mihai Maruseac 2021-01-21 11:13:01 -08:00
parent c48f3fa872
commit 099e67544d
14 changed files with 155 additions and 98 deletions

View File

@@ -1,4 +1,4 @@
// RUN: tf-opt -tfl-prepare-tf=tfl-allow-bf16-type-legalization=true %s | FileCheck %s
// RUN: tf-opt -tfl-prepare-tf=tfl-allow-bf16-and-f16-type-legalization=true %s | FileCheck %s
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
@@ -23,4 +23,11 @@ func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x
// CHECK: "tfl.depthwise_conv_2d"
}
// CHECK-LABEL: conv_2d_f16
func @conv_2d_f16(%arg0 : tensor<256x32x32x3xf16>, %arg1 : tensor<3x3x3x16xf16>) -> tensor<256x8x7x16xf16> {
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf16>, tensor<3x3x3x16xf16>) -> tensor<256x8x7x16xf16>
return %0 : tensor<256x8x7x16xf16>
// CHECK: "tfl.conv_2d"
}
}

View File

@@ -189,7 +189,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// the TFLite dialect.
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
pass_config.unfold_batch_matmul,
/*allow_bf16_type_legalization=*/!pass_config.runtime_verification));
/*allow_bf16_and_f16_type_legalization=*/!pass_config
.runtime_verification));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
if (pass_config.shape_inference) {
// Add a shape inference pass to optimize away the unnecessary casts.
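Note: for orientation, a minimal Python sketch of how fp16 conversion is typically requested from the user-facing converter, which assembles this MLIR pipeline (including PrepareTF) internally; the saved-model path is hypothetical, and the exact wiring from these converter options to this pass flag is not shown in this diff:

import tensorflow as tf

# Hypothetical model path; fp16 is requested through converter options,
# which configure the MLIR pass pipeline internally.
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()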

View File

@@ -41,7 +41,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
bool unfold_batch_matmul, bool allow_bf16_type_legalization);
bool unfold_batch_matmul, bool allow_bf16_and_f16_type_legalization);
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
// pass.

View File

@@ -84,9 +84,10 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
PrepareTFPass() = default;
PrepareTFPass(const PrepareTFPass &) {}
explicit PrepareTFPass(bool unfold_batch_matmul,
bool allow_bf16_type_legalization) {
bool allow_bf16_and_f16_type_legalization) {
unfold_batch_matmul_ = unfold_batch_matmul;
allow_bf16_type_legalization_ = allow_bf16_type_legalization;
allow_bf16_and_f16_type_legalization_ =
allow_bf16_and_f16_type_legalization;
}
void runOnFunction() override;
@@ -101,8 +102,8 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
llvm::cl::init(true)};
Option<bool> allow_bf16_type_legalization_{
*this, "tfl-allow-bf16-type-legalization",
Option<bool> allow_bf16_and_f16_type_legalization_{
*this, "tfl-allow-bf16-and-f16-type-legalization",
llvm::cl::desc("Allow bf16 type legalization."), llvm::cl::init(false)};
};
@@ -291,10 +292,12 @@ struct ConvertTFConvOpMatchState {
template <typename ConcreteType, typename TFConvOpType>
class ConvertTFConvOp : public RewritePattern {
public:
ConvertTFConvOp(MLIRContext *context, bool allow_bf16_type_legalization)
ConvertTFConvOp(MLIRContext *context,
bool allow_bf16_and_f16_type_legalization)
: RewritePattern(TFConvOpType::getOperationName(), 1, context),
intAttrOne(Builder(context).getI32IntegerAttr(1)),
allow_bf16_type_legalization_(allow_bf16_type_legalization) {}
allow_bf16_and_f16_type_legalization_(
allow_bf16_and_f16_type_legalization) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
@@ -311,8 +314,8 @@ class ConvertTFConvOp : public RewritePattern {
TFConvOpType tf_op = cast<TFConvOpType>(op);
if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
!(allow_bf16_type_legalization_ &&
TFTypeIsBFloat16Tensor(tf_op.input())))
!(allow_bf16_and_f16_type_legalization_ &&
TFTypeIsBFloat16OrHalfTensor(tf_op.input())))
return failure();
if (!TFDataFormatIsNHWC(op)) return failure();
@@ -374,7 +377,7 @@ class ConvertTFConvOp : public RewritePattern {
const IntegerAttr intAttrOne;
private:
bool allow_bf16_type_legalization_;
bool allow_bf16_and_f16_type_legalization_;
};
class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
@@ -1342,7 +1345,7 @@ void PrepareTFPass::runOnFunction() {
phase_2_patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
ConvertTFStridedSlice, ConvertRfftToRfft2d>(ctx);
phase_2_patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
ctx, allow_bf16_type_legalization_);
ctx, allow_bf16_and_f16_type_legalization_);
applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}

View File

@@ -64,6 +64,18 @@ inline bool TFTypeIsBFloat16Tensor(Value value) {
return tensorType.getElementType().isBF16();
}
// Returns true iff the given value is an f16 tensor.
inline bool TFTypeIsHalfTensor(Value value) {
auto tensorType = value.getType().dyn_cast<TensorType>();
if (!tensorType) return false;
return tensorType.getElementType().isF16();
}
// Returns true iff the given value is an f16 or bf16 tensor.
inline bool TFTypeIsBFloat16OrHalfTensor(Value value) {
return TFTypeIsBFloat16Tensor(value) || TFTypeIsHalfTensor(value);
}
// Returns true iff the given TensorFlow op has a `padding` attribute whose
// value is "SAME" or "VALID", and writes the attribute to `padding`.
inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {

View File

@@ -15321,6 +15321,40 @@ On GPU, if an out of bound index is found, the index is ignored.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_TensorScatterMaxOp : TF_Op<"TensorScatterMax", [NoSideEffect]> {
let summary = "";
let arguments = (ins
TF_Tensor:$tensor,
TF_I32OrI64Tensor:$indices,
TF_Tensor:$updates
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_TensorScatterMinOp : TF_Op<"TensorScatterMin", [NoSideEffect]> {
let summary = "";
let arguments = (ins
TF_Tensor:$tensor,
TF_I32OrI64Tensor:$indices,
TF_Tensor:$updates
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_TensorScatterSubOp : TF_Op<"TensorScatterSub", [NoSideEffect]> {
let summary = [{
Subtracts sparse `updates` from an existing tensor according to `indices`.
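These generated ODS definitions front the new TensorScatterMax/TensorScatterMin kernels. A minimal Python sketch of the element-wise-maximum scatter semantics, assuming the standard tf.tensor_scatter_nd_max wrapper for this raw op (values are illustrative):

import tensorflow as tf

tensor = tf.zeros([8], dtype=tf.int32)
indices = tf.constant([[1], [3]])
updates = tf.constant([5, 7], dtype=tf.int32)
# Element-wise maximum of the existing values and `updates` at the given
# indices; all other elements pass through unchanged.
result = tf.tensor_scatter_nd_max(tensor, indices, updates)
# result == [0, 5, 0, 7, 0, 0, 0, 0]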

View File

@@ -1,21 +1,14 @@
"""Build rules for Tensorflow/XLA testing."""
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_is_configured")
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
"tf_exec_properties",
)
load("//tensorflow:tensorflow.bzl", "py_test")
def all_backends():
b = ["cpu"] + plugins.keys()
if cuda_is_configured() or rocm_is_configured():
return b + ["gpu"]
else:
return b
all_backends = ["cpu", "gpu"] + plugins.keys()
def tf_xla_py_test(
name,
@@ -32,7 +25,7 @@ def tf_xla_py_test(
"""Generates py_test targets, one per XLA backend.
This rule generates py_test() targets named name_backend, for each backend
in all_backends(). The rule also generates a test suite named `name` that
in all_backends. The rule also generates a test suite named `name` that
tests all backends for the test.
For example, the following rule generates test cases foo_test_cpu,
@@ -62,7 +55,7 @@ def tf_xla_py_test(
**kwargs: keyword arguments passed onto the generated py_test() rules.
"""
if enabled_backends == None:
enabled_backends = all_backends()
enabled_backends = all_backends
if disabled_backends == None:
disabled_backends = []
if type(disabled_backends) != "list":
@@ -140,6 +133,6 @@ def tf_xla_py_test(
def generate_backend_suites(backends = []):
"""Generates per-backend test_suites that run all tests for a backend."""
if not backends:
backends = all_backends()
backends = all_backends
for backend in backends:
native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend])
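With the configure-time check gone, `all_backends` is a plain list and the gpu variants are always generated; whether they run is decided by their tags. A hedged sketch of a typical call site (target and file names are hypothetical):

tf_xla_py_test(
    name = "foo_test",
    srcs = ["foo_test.py"],
    # Omitting enabled_backends defaults to all_backends; the gpu variant
    # is generated unconditionally and gated at test time by its tags.
)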

View File

@@ -1989,19 +1989,21 @@ xla_test(
name = "collective_ops_test",
srcs = ["collective_ops_test.cc"],
args = ["--xla_force_host_platform_device_count=4"],
backends = [
"gpu",
"cpu",
],
tags = [
backend_tags = {
# This test is tagged "manual" because it requires multiple GPUs, and
# Forge only supports single-GPU tests. Guitar skips "manual" tests
# unless they're also tagged "guitar".
"guitar",
"manual",
"multi_gpu",
"no_oss",
"notap",
"gpu": [
"guitar",
"manual",
"multi_gpu",
"no_oss",
"notap",
],
},
backends = [
"gpu",
"cpu",
],
deps = [
":client_library_test_base",

View File

@@ -1,33 +1,18 @@
"""Build rules for XLA testing."""
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_is_configured")
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
load(
"//tensorflow/stream_executor:build_defs.bzl",
"if_gpu_is_configured",
)
load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
"tf_gpu_tests_tags",
)
all_backends = ["cpu", "gpu"] + plugins.keys()
def filter_backends(backends):
"""Removes "gpu" from a backend list if CUDA or ROCm is not enabled.
This allows us to simply hardcode lists including "gpu" here and in the
BUILD file, without causing failures when CUDA or ROCm isn't enabled.
Args:
backends: A list of backends to filter.
Returns:
The filtered list of backends.
"""
if cuda_is_configured() or rocm_is_configured():
return backends
else:
return [backend for backend in backends if backend != "gpu"]
def xla_test(
name,
srcs,
@@ -132,7 +117,7 @@ def xla_test(
deps = deps,
)
for backend in filter_backends(backends):
for backend in backends:
test_name = "%s_%s" % (name, backend)
this_backend_tags = ["xla_%s" % backend]
this_backend_copts = []
@@ -142,9 +127,9 @@
backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
elif backend == "gpu":
backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
this_backend_tags += tf_cuda_tests_tags()
backend_deps = if_gpu_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"])
backend_deps += if_gpu_is_configured(["//tensorflow/compiler/xla/tests:test_macros_gpu"])
this_backend_tags += tf_gpu_tests_tags()
elif backend in plugins:
backend_deps = []
backend_deps += plugins[backend]["deps"]
@@ -219,7 +204,7 @@ def xla_test_library(
if not backends:
backends = all_backends
for backend in filter_backends(backends):
for backend in backends:
this_backend_copts = []
if backend in ["cpu", "gpu"]:
backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
@@ -242,7 +227,7 @@ def xla_test_library(
def generate_backend_suites(backends = []):
if not backends:
backends = all_backends
for backend in filter_backends(backends):
for backend in backends:
native.test_suite(
name = "%s_tests" % backend,
tags = ["xla_%s" % backend, "-broken", "manual"],
@@ -251,7 +236,7 @@ def generate_backend_suites(backends = []):
def generate_backend_test_macros(backends = []):
if not backends:
backends = all_backends
for backend in filter_backends(backends):
for backend in backends:
manifest = ""
if backend in plugins:
manifest = plugins[backend]["disabled_manifest"]
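Sketch of the net effect: with `filter_backends` removed, gpu variants always exist as targets. When neither CUDA nor ROCm is configured, each `if_gpu_is_configured([...])` in the gpu branch above collapses to an empty select(), so a gpu test still builds, just without plugin deps, and `tf_gpu_tests_tags()` keeps it out of non-GPU test runs:

# With no GPU backend configured, both calls resolve to [] at analysis time.
backend_deps = if_gpu_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"])
backend_deps += if_gpu_is_configured(["//tensorflow/compiler/xla/tests:test_macros_gpu"])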

View File

@@ -765,8 +765,8 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) {
result);
}
}
// TODO(b/178047150): Fails on GPU with wrong answers.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllGather_Dim1)) {
XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
@@ -789,7 +789,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllGather_Dim1)) {
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<uint32>({10, 15, 11, 16, 12, 17, 13, 18},
LiteralTestUtil::ExpectR1Equal<uint32>({10, 11, 12, 13, 15, 16, 17, 18},
result);
}
}
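For intuition about the corrected literal, a hedged NumPy sketch: it assumes each of the two replicas contributes a (1, 4) shard, which matches the new expectation but is not confirmed by the truncated HLO above:

import numpy as np

shard0 = np.array([[10, 11, 12, 13]])  # replica 0 (assumed 1x4 shard)
shard1 = np.array([[15, 16, 17, 18]])  # replica 1
# All-gather along dim 1 concatenates on axis 1; row-major flattening
# then yields the literal the test now expects.
print(np.concatenate([shard0, shard1], axis=1).ravel())
# [10 11 12 13 15 16 17 18]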

View File

@@ -1,5 +1,5 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_is_configured")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
def stream_executor_friends():
return ["//tensorflow/..."]
@@ -18,9 +18,7 @@ def tf_additional_cudnn_plugin_deps():
# Returns whether any GPU backend is configured.
def if_gpu_is_configured(x):
if cuda_is_configured() or rocm_is_configured():
return x
return []
return if_cuda_is_configured(x) + if_rocm_is_configured(x)
def if_cuda_or_rocm(x):
return if_gpu_is_configured(x)
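The rewrite trades a loading-time boolean for composition of the two select()-based helpers: the result contains `x` when CUDA and/or ROCm is configured (in practice only one of the two is enabled at a time) and is empty otherwise, with the decision deferred to analysis time. A minimal usage sketch with a hypothetical target:

cc_library(
    name = "gpu_only_helpers",  # hypothetical
    # Resolves to [":gpu_impl"] when a GPU backend is configured,
    # and to [] otherwise.
    deps = if_gpu_is_configured([":gpu_impl"]),
)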

View File

@@ -158,12 +158,20 @@ cc_library(
deps = ["//tensorflow/stream_executor:platform"],
)
cc_library(
name = "rocblas_if_static",
deps = if_static([
"@local_config_rocm//rocm:rocblas",
]),
)
cc_library(
name = "rocblas_plugin",
srcs = if_rocm_is_configured(["rocm_blas.cc"]),
hdrs = if_rocm_is_configured(["rocm_blas.h"]),
visibility = ["//visibility:public"],
deps = if_rocm_is_configured([
":rocblas_if_static",
":rocm_gpu_executor",
":rocm_platform_id",
"//third_party/eigen3",
@@ -184,18 +192,24 @@ cc_library(
"//tensorflow/stream_executor/platform:dso_loader",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:rocm_headers",
] + if_static([
"@local_config_rocm//rocm:rocblas",
])),
]),
alwayslink = True,
)
cc_library(
name = "rocfft_if_static",
deps = if_static([
"@local_config_rocm//rocm:rocfft",
]),
)
cc_library(
name = "rocfft_plugin",
srcs = if_rocm_is_configured(["rocm_fft.cc"]),
hdrs = if_rocm_is_configured(["rocm_fft.h"]),
visibility = ["//visibility:public"],
deps = if_rocm_is_configured([
":rocfft_if_static",
":rocm_platform_id",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:fft",
@@ -210,12 +224,17 @@ cc_library(
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
"@local_config_rocm//rocm:rocm_headers",
] + if_static([
"@local_config_rocm//rocm:rocfft",
])),
]),
alwayslink = True,
)
cc_library(
name = "miopen_if_static",
deps = if_static([
"@local_config_rocm//rocm:miopen",
]),
)
cc_library(
name = "miopen_plugin",
srcs = if_rocm_is_configured(["rocm_dnn.cc"]),
@@ -227,6 +246,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = if_rocm_is_configured([
":miopen_if_static",
":rocm_diagnostics",
":rocm_driver",
":rocm_gpu_executor",
@@ -248,17 +268,23 @@ cc_library(
"//tensorflow/stream_executor/platform:dso_loader",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:rocm_headers",
] + if_static([
"@local_config_rocm//rocm:miopen",
])),
]),
alwayslink = True,
)
cc_library(
name = "hiprand_if_static",
deps = if_static([
"@local_config_rocm//rocm:hiprand",
]),
)
cc_library(
name = "rocrand_plugin",
srcs = if_rocm_is_configured(["rocm_rng.cc"]),
hdrs = if_rocm_is_configured([]),
deps = if_rocm_is_configured([
":hiprand_if_static",
":rocm_gpu_executor",
":rocm_platform_id",
"@local_config_rocm//rocm:rocm_headers",
@@ -273,26 +299,30 @@ cc_library(
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
] + if_static([
"@local_config_rocm//rocm:hiprand",
])),
]),
alwayslink = True,
)
cc_library(
name = "hipsparse_if_static",
deps = if_static([
"@local_config_rocm//rocm:hipsparse",
]),
)
cc_library(
name = "hipsparse_wrapper",
srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
deps = if_rocm_is_configured([
":hipsparse_if_static",
":rocm_gpu_executor",
":rocm_platform_id",
"@local_config_rocm//rocm:rocm_headers",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
] + if_static([
"@local_config_rocm//rocm:hiprand",
])),
]),
alwayslink = True,
)
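Each plugin in this file now follows the same recurring shape: the if_static() portion is hoisted into a small `*_if_static` library so the plugin's `deps` remains a single if_rocm_is_configured() select rather than a select concatenated with a second list. Schematically, with hypothetical names:

cc_library(
    name = "foo_if_static",  # hypothetical
    deps = if_static(["@local_config_rocm//rocm:foo"]),
)

cc_library(
    name = "foo_plugin",  # hypothetical
    srcs = if_rocm_is_configured(["rocm_foo.cc"]),
    # One select(); the static-vs-dynamic linkage choice lives in the
    # intermediate library above.
    deps = if_rocm_is_configured([":foo_if_static"]),
    alwayslink = True,
)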

View File

@@ -50,10 +50,6 @@ def cuda_default_copts():
["-O3"]
)
def cuda_is_configured():
"""Returns true if CUDA was enabled during the configure process."""
return %{cuda_is_configured}
def cuda_gpu_architectures():
"""Returns a list of supported GPU architectures."""
return %{cuda_gpu_architectures}
@@ -64,7 +60,7 @@ def if_cuda_is_configured(x):
Unlike if_cuda(), this does not require that we are building with
--config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.
"""
if cuda_is_configured():
if %{cuda_is_configured}:
return select({"//conditions:default": x})
return select({"//conditions:default": []})

View File

@@ -30,10 +30,6 @@ def rocm_copts(opts = []):
]),
}) + if_rocm_is_configured(opts)
def rocm_is_configured():
"""Returns true if ROCm was enabled during the configure process."""
return %{rocm_is_configured}
def rocm_gpu_architectures():
"""Returns a list of supported GPU architectures."""
return %{rocm_gpu_architectures}
@@ -44,9 +40,9 @@ def if_rocm_is_configured(x):
Unlike if_rocm(), this does not require that we are building with
--config=rocm. Used to allow non-ROCm code to depend on ROCm libraries.
"""
if rocm_is_configured():
return x
return []
if %{rocm_is_configured}:
return select({"//conditions:default": x})
return select({"//conditions:default": []})
def rocm_library(copts = [], **kwargs):
"""Wrapper over cc_library which adds default ROCm options."""