sync
This commit is contained in:
parent
c48f3fa872
commit
099e67544d
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -tfl-prepare-tf=tfl-allow-bf16-type-legalization=true %s | FileCheck %s
|
||||
// RUN: tf-opt -tfl-prepare-tf=tfl-allow-bf16-and-f16-type-legalization=true %s | FileCheck %s
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
|
||||
@ -23,4 +23,11 @@ func @depthwise_conv_2d_bf16(%arg0 : tensor<256x32x32x3xbf16>, %arg1 : tensor<3x
|
||||
// CHECK: "tfl.depthwise_conv_2d"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: conv_2d_f16
|
||||
func @conv_2d_f16(%arg0 : tensor<256x32x32x3xf16>, %arg1 : tensor<3x3x3x16xf16>) -> tensor<256x8x7x16xf16> {
|
||||
%0 = "tf.Conv2D"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf16>, tensor<3x3x3x16xf16>) -> tensor<256x8x7x16xf16>
|
||||
return %0 : tensor<256x8x7x16xf16>
|
||||
// CHECK: "tfl.conv_2d"
|
||||
}
|
||||
|
||||
}
|
@ -189,7 +189,8 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// the TFLite dialect.
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreatePrepareTFPass(
|
||||
pass_config.unfold_batch_matmul,
|
||||
/*allow_bf16_type_legalization=*/!pass_config.runtime_verification));
|
||||
/*allow_bf16_and_f16_type_legalization=*/!pass_config
|
||||
.runtime_verification));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
if (pass_config.shape_inference) {
|
||||
// Add a shape inference pass to optimize away the unnecessary casts.
|
||||
|
@ -41,7 +41,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateOptimizePass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PrepareTF pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreatePrepareTFPass(
|
||||
bool unfold_batch_matmul, bool allow_bf16_type_legalization);
|
||||
bool unfold_batch_matmul, bool allow_bf16_and_f16_type_legalization);
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
// pass.
|
||||
|
@ -84,9 +84,10 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
|
||||
PrepareTFPass() = default;
|
||||
PrepareTFPass(const PrepareTFPass &) {}
|
||||
explicit PrepareTFPass(bool unfold_batch_matmul,
|
||||
bool allow_bf16_type_legalization) {
|
||||
bool allow_bf16_and_f16_type_legalization) {
|
||||
unfold_batch_matmul_ = unfold_batch_matmul;
|
||||
allow_bf16_type_legalization_ = allow_bf16_type_legalization;
|
||||
allow_bf16_and_f16_type_legalization_ =
|
||||
allow_bf16_and_f16_type_legalization;
|
||||
}
|
||||
void runOnFunction() override;
|
||||
|
||||
@ -101,8 +102,8 @@ class PrepareTFPass : public PassWrapper<PrepareTFPass, FunctionPass> {
|
||||
llvm::cl::desc("Unfold BatchMatMul into individual MatMul ops."),
|
||||
llvm::cl::init(true)};
|
||||
|
||||
Option<bool> allow_bf16_type_legalization_{
|
||||
*this, "tfl-allow-bf16-type-legalization",
|
||||
Option<bool> allow_bf16_and_f16_type_legalization_{
|
||||
*this, "tfl-allow-bf16-and-f16-type-legalization",
|
||||
llvm::cl::desc("Allow bf16 type legalization."), llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
@ -291,10 +292,12 @@ struct ConvertTFConvOpMatchState {
|
||||
template <typename ConcreteType, typename TFConvOpType>
|
||||
class ConvertTFConvOp : public RewritePattern {
|
||||
public:
|
||||
ConvertTFConvOp(MLIRContext *context, bool allow_bf16_type_legalization)
|
||||
ConvertTFConvOp(MLIRContext *context,
|
||||
bool allow_bf16_and_f16_type_legalization)
|
||||
: RewritePattern(TFConvOpType::getOperationName(), 1, context),
|
||||
intAttrOne(Builder(context).getI32IntegerAttr(1)),
|
||||
allow_bf16_type_legalization_(allow_bf16_type_legalization) {}
|
||||
allow_bf16_and_f16_type_legalization_(
|
||||
allow_bf16_and_f16_type_legalization) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@ -311,8 +314,8 @@ class ConvertTFConvOp : public RewritePattern {
|
||||
TFConvOpType tf_op = cast<TFConvOpType>(op);
|
||||
|
||||
if (!TFTypeIsFloat32Tensor(tf_op.input()) &&
|
||||
!(allow_bf16_type_legalization_ &&
|
||||
TFTypeIsBFloat16Tensor(tf_op.input())))
|
||||
!(allow_bf16_and_f16_type_legalization_ &&
|
||||
TFTypeIsBFloat16OrHalfTensor(tf_op.input())))
|
||||
return failure();
|
||||
|
||||
if (!TFDataFormatIsNHWC(op)) return failure();
|
||||
@ -374,7 +377,7 @@ class ConvertTFConvOp : public RewritePattern {
|
||||
const IntegerAttr intAttrOne;
|
||||
|
||||
private:
|
||||
bool allow_bf16_type_legalization_;
|
||||
bool allow_bf16_and_f16_type_legalization_;
|
||||
};
|
||||
|
||||
class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
|
||||
@ -1342,7 +1345,7 @@ void PrepareTFPass::runOnFunction() {
|
||||
phase_2_patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo,
|
||||
ConvertTFStridedSlice, ConvertRfftToRfft2d>(ctx);
|
||||
phase_2_patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative>(
|
||||
ctx, allow_bf16_type_legalization_);
|
||||
ctx, allow_bf16_and_f16_type_legalization_);
|
||||
|
||||
applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
|
||||
}
|
||||
|
@ -64,6 +64,18 @@ inline bool TFTypeIsBFloat16Tensor(Value value) {
|
||||
return tensorType.getElementType().isBF16();
|
||||
}
|
||||
|
||||
// Returns true iff the given value is a f16 tensor.
|
||||
inline bool TFTypeIsHalfTensor(Value value) {
|
||||
auto tensorType = value.getType().dyn_cast<TensorType>();
|
||||
if (!tensorType) return false;
|
||||
return tensorType.getElementType().isF16();
|
||||
}
|
||||
|
||||
// Returns true iff the given value is a f16 or bf16 tensor.
|
||||
inline bool TFTypeIsBFloat16OrHalfTensor(Value value) {
|
||||
return TFTypeIsBFloat16Tensor(value) || TFTypeIsHalfTensor(value);
|
||||
}
|
||||
|
||||
// Returns true iff the given TensorFlow op has a `padding` attribute whose
|
||||
// value is "SAME" or "VALID", and writes the attribute to `padding`.
|
||||
inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
|
||||
|
@ -15321,6 +15321,40 @@ On GPU, if an out of bound index is found, the index is ignored.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TensorScatterMaxOp : TF_Op<"TensorScatterMax", [NoSideEffect]> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$tensor,
|
||||
TF_I32OrI64Tensor:$indices,
|
||||
TF_Tensor:$updates
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TensorScatterMinOp : TF_Op<"TensorScatterMin", [NoSideEffect]> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$tensor,
|
||||
TF_I32OrI64Tensor:$indices,
|
||||
TF_Tensor:$updates
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_TensorScatterSubOp : TF_Op<"TensorScatterSub", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Subtracts sparse `updates` from an existing tensor according to `indices`.
|
||||
|
@ -1,21 +1,14 @@
|
||||
"""Build rules for Tensorflow/XLA testing."""
|
||||
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_is_configured")
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
"tf_exec_properties",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
def all_backends():
|
||||
b = ["cpu"] + plugins.keys()
|
||||
if cuda_is_configured() or rocm_is_configured():
|
||||
return b + ["gpu"]
|
||||
else:
|
||||
return b
|
||||
all_backends = ["cpu", "gpu"] + plugins.keys()
|
||||
|
||||
def tf_xla_py_test(
|
||||
name,
|
||||
@ -32,7 +25,7 @@ def tf_xla_py_test(
|
||||
"""Generates py_test targets, one per XLA backend.
|
||||
|
||||
This rule generates py_test() targets named name_backend, for each backend
|
||||
in all_backends(). The rule also generates a test suite with named `name` that
|
||||
in all_backends. The rule also generates a test suite with named `name` that
|
||||
tests all backends for the test.
|
||||
|
||||
For example, the following rule generates test cases foo_test_cpu,
|
||||
@ -62,7 +55,7 @@ def tf_xla_py_test(
|
||||
**kwargs: keyword arguments passed onto the generated py_test() rules.
|
||||
"""
|
||||
if enabled_backends == None:
|
||||
enabled_backends = all_backends()
|
||||
enabled_backends = all_backends
|
||||
if disabled_backends == None:
|
||||
disabled_backends = []
|
||||
if type(disabled_backends) != "list":
|
||||
@ -140,6 +133,6 @@ def tf_xla_py_test(
|
||||
def generate_backend_suites(backends = []):
|
||||
"""Generates per-backend test_suites that run all tests for a backend."""
|
||||
if not backends:
|
||||
backends = all_backends()
|
||||
backends = all_backends
|
||||
for backend in backends:
|
||||
native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend])
|
||||
|
@ -1989,19 +1989,21 @@ xla_test(
|
||||
name = "collective_ops_test",
|
||||
srcs = ["collective_ops_test.cc"],
|
||||
args = ["--xla_force_host_platform_device_count=4"],
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
tags = [
|
||||
backend_tags = {
|
||||
# This test is tagged "manual" because it requires multiple GPUs, and
|
||||
# Forge only supports single-GPU tests. Guitar skips "manual" tests
|
||||
# unless they're also tagged "guitar".
|
||||
"guitar",
|
||||
"manual",
|
||||
"multi_gpu",
|
||||
"no_oss",
|
||||
"notap",
|
||||
"gpu": [
|
||||
"guitar",
|
||||
"manual",
|
||||
"multi_gpu",
|
||||
"no_oss",
|
||||
"notap",
|
||||
],
|
||||
},
|
||||
backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
|
@ -1,33 +1,18 @@
|
||||
"""Build rules for XLA testing."""
|
||||
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_is_configured")
|
||||
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
|
||||
load(
|
||||
"//tensorflow/stream_executor:build_defs.bzl",
|
||||
"if_gpu_is_configured",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
"tf_gpu_tests_tags",
|
||||
)
|
||||
|
||||
all_backends = ["cpu", "gpu"] + plugins.keys()
|
||||
|
||||
def filter_backends(backends):
|
||||
"""Removes "gpu" from a backend list if CUDA or ROCm is not enabled.
|
||||
|
||||
This allows us to simply hardcode lists including "gpu" here and in the
|
||||
BUILD file, without causing failures when CUDA or ROCm isn't enabled.'
|
||||
|
||||
Args:
|
||||
backends: A list of backends to filter.
|
||||
|
||||
Returns:
|
||||
The filtered list of backends.
|
||||
"""
|
||||
if cuda_is_configured() or rocm_is_configured():
|
||||
return backends
|
||||
else:
|
||||
return [backend for backend in backends if backend != "gpu"]
|
||||
|
||||
def xla_test(
|
||||
name,
|
||||
srcs,
|
||||
@ -132,7 +117,7 @@ def xla_test(
|
||||
deps = deps,
|
||||
)
|
||||
|
||||
for backend in filter_backends(backends):
|
||||
for backend in backends:
|
||||
test_name = "%s_%s" % (name, backend)
|
||||
this_backend_tags = ["xla_%s" % backend]
|
||||
this_backend_copts = []
|
||||
@ -142,9 +127,9 @@ def xla_test(
|
||||
backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
|
||||
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
|
||||
elif backend == "gpu":
|
||||
backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
|
||||
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
|
||||
this_backend_tags += tf_cuda_tests_tags()
|
||||
backend_deps = if_gpu_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"])
|
||||
backend_deps += if_gpu_is_configured(["//tensorflow/compiler/xla/tests:test_macros_gpu"])
|
||||
this_backend_tags += tf_gpu_tests_tags()
|
||||
elif backend in plugins:
|
||||
backend_deps = []
|
||||
backend_deps += plugins[backend]["deps"]
|
||||
@ -219,7 +204,7 @@ def xla_test_library(
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
|
||||
for backend in filter_backends(backends):
|
||||
for backend in backends:
|
||||
this_backend_copts = []
|
||||
if backend in ["cpu", "gpu"]:
|
||||
backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
|
||||
@ -242,7 +227,7 @@ def xla_test_library(
|
||||
def generate_backend_suites(backends = []):
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
for backend in filter_backends(backends):
|
||||
for backend in backends:
|
||||
native.test_suite(
|
||||
name = "%s_tests" % backend,
|
||||
tags = ["xla_%s" % backend, "-broken", "manual"],
|
||||
@ -251,7 +236,7 @@ def generate_backend_suites(backends = []):
|
||||
def generate_backend_test_macros(backends = []):
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
for backend in filter_backends(backends):
|
||||
for backend in backends:
|
||||
manifest = ""
|
||||
if backend in plugins:
|
||||
manifest = plugins[backend]["disabled_manifest"]
|
||||
|
@ -765,8 +765,8 @@ XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) {
|
||||
result);
|
||||
}
|
||||
}
|
||||
// TODO(b/178047150): Fails on GPU with wrong answers.
|
||||
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllGather_Dim1)) {
|
||||
|
||||
XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) {
|
||||
const char* const kModuleStr = R"(
|
||||
HloModule test
|
||||
ENTRY test_computation {
|
||||
@ -789,7 +789,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllGather_Dim1)) {
|
||||
/*use_threads=*/true, /*run_hlo_passes=*/true));
|
||||
ASSERT_EQ(results.size(), kNumReplicas);
|
||||
for (const Literal& result : results) {
|
||||
LiteralTestUtil::ExpectR1Equal<uint32>({10, 15, 11, 16, 12, 17, 13, 18},
|
||||
LiteralTestUtil::ExpectR1Equal<uint32>({10, 11, 12, 13, 15, 16, 17, 18},
|
||||
result);
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "rocm_is_configured")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
|
||||
|
||||
def stream_executor_friends():
|
||||
return ["//tensorflow/..."]
|
||||
@ -18,9 +18,7 @@ def tf_additional_cudnn_plugin_deps():
|
||||
|
||||
# Returns whether any GPU backend is configuered.
|
||||
def if_gpu_is_configured(x):
|
||||
if cuda_is_configured() or rocm_is_configured():
|
||||
return x
|
||||
return []
|
||||
return if_cuda_is_configured(x) + if_rocm_is_configured(x)
|
||||
|
||||
def if_cuda_or_rocm(x):
|
||||
return if_gpu_is_configured(x)
|
||||
|
@ -158,12 +158,20 @@ cc_library(
|
||||
deps = ["//tensorflow/stream_executor:platform"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rocblas_if_static",
|
||||
deps = if_static([
|
||||
"@local_config_rocm//rocm:rocblas",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rocblas_plugin",
|
||||
srcs = if_rocm_is_configured(["rocm_blas.cc"]),
|
||||
hdrs = if_rocm_is_configured(["rocm_blas.h"]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = if_rocm_is_configured([
|
||||
":rocblas_if_static",
|
||||
":rocm_gpu_executor",
|
||||
":rocm_platform_id",
|
||||
"//third_party/eigen3",
|
||||
@ -184,18 +192,24 @@ cc_library(
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:rocblas",
|
||||
])),
|
||||
]),
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rocfft_if_static",
|
||||
deps = if_static([
|
||||
"@local_config_rocm//rocm:rocfft",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rocfft_plugin",
|
||||
srcs = if_rocm_is_configured(["rocm_fft.cc"]),
|
||||
hdrs = if_rocm_is_configured(["rocm_fft.h"]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = if_rocm_is_configured([
|
||||
":rocfft_if_static",
|
||||
":rocm_platform_id",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"//tensorflow/stream_executor:fft",
|
||||
@ -210,12 +224,17 @@ cc_library(
|
||||
"//tensorflow/stream_executor/platform",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:rocfft",
|
||||
])),
|
||||
]),
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "miopen_if_static",
|
||||
deps = if_static([
|
||||
"@local_config_rocm//rocm:miopen",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "miopen_plugin",
|
||||
srcs = if_rocm_is_configured(["rocm_dnn.cc"]),
|
||||
@ -227,6 +246,7 @@ cc_library(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = if_rocm_is_configured([
|
||||
":miopen_if_static",
|
||||
":rocm_diagnostics",
|
||||
":rocm_driver",
|
||||
":rocm_gpu_executor",
|
||||
@ -248,17 +268,23 @@ cc_library(
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:miopen",
|
||||
])),
|
||||
]),
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hiprand_if_static",
|
||||
deps = if_static([
|
||||
"@local_config_rocm//rocm:hiprand",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rocrand_plugin",
|
||||
srcs = if_rocm_is_configured(["rocm_rng.cc"]),
|
||||
hdrs = if_rocm_is_configured([]),
|
||||
deps = if_rocm_is_configured([
|
||||
":hiprand_if_static",
|
||||
":rocm_gpu_executor",
|
||||
":rocm_platform_id",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
@ -273,26 +299,30 @@ cc_library(
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:hiprand",
|
||||
])),
|
||||
]),
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipsparse_if_static",
|
||||
deps = if_static([
|
||||
"@local_config_rocm//rocm:hipsparse",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hipsparse_wrapper",
|
||||
srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
|
||||
hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
|
||||
deps = if_rocm_is_configured([
|
||||
":hipsparse_if_static",
|
||||
":rocm_gpu_executor",
|
||||
":rocm_platform_id",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/stream_executor/platform",
|
||||
"//tensorflow/stream_executor/platform:dso_loader",
|
||||
] + if_static([
|
||||
"@local_config_rocm//rocm:hiprand",
|
||||
])),
|
||||
]),
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
|
6
tf/third_party/gpus/cuda/build_defs.bzl.tpl
vendored
6
tf/third_party/gpus/cuda/build_defs.bzl.tpl
vendored
@ -50,10 +50,6 @@ def cuda_default_copts():
|
||||
["-O3"]
|
||||
)
|
||||
|
||||
def cuda_is_configured():
|
||||
"""Returns true if CUDA was enabled during the configure process."""
|
||||
return %{cuda_is_configured}
|
||||
|
||||
def cuda_gpu_architectures():
|
||||
"""Returns a list of supported GPU architectures."""
|
||||
return %{cuda_gpu_architectures}
|
||||
@ -64,7 +60,7 @@ def if_cuda_is_configured(x):
|
||||
Unlike if_cuda(), this does not require that we are building with
|
||||
--config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.
|
||||
"""
|
||||
if cuda_is_configured():
|
||||
if %{cuda_is_configured}:
|
||||
return select({"//conditions:default": x})
|
||||
return select({"//conditions:default": []})
|
||||
|
||||
|
10
tf/third_party/gpus/rocm/build_defs.bzl.tpl
vendored
10
tf/third_party/gpus/rocm/build_defs.bzl.tpl
vendored
@ -30,10 +30,6 @@ def rocm_copts(opts = []):
|
||||
]),
|
||||
}) + if_rocm_is_configured(opts)
|
||||
|
||||
def rocm_is_configured():
|
||||
"""Returns true if ROCm was enabled during the configure process."""
|
||||
return %{rocm_is_configured}
|
||||
|
||||
def rocm_gpu_architectures():
|
||||
"""Returns a list of supported GPU architectures."""
|
||||
return %{rocm_gpu_architectures}
|
||||
@ -44,9 +40,9 @@ def if_rocm_is_configured(x):
|
||||
Unlike if_rocm(), this does not require that we are building with
|
||||
--config=rocm. Used to allow non-ROCm code to depend on ROCm libraries.
|
||||
"""
|
||||
if rocm_is_configured():
|
||||
return x
|
||||
return []
|
||||
if %{rocm_is_configured}:
|
||||
return select({"//conditions:default": x})
|
||||
return select({"//conditions:default": []})
|
||||
|
||||
def rocm_library(copts = [], **kwargs):
|
||||
"""Wrapper over cc_library which adds default ROCm options."""
|
||||
|
Loading…
Reference in New Issue
Block a user