Merge branch 'master' into master

Mihai Maruseac 2020-06-19 17:23:54 +00:00 committed by GitHub
commit fccd345ab3
251 changed files with 6055 additions and 2116 deletions


@@ -24,9 +24,21 @@ cc_library(
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":gcs_helper",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gcs_helper",
srcs = ["gcs_helper.cc"],
hdrs = ["gcs_helper.h"],
linkstatic = 1,
deps = [
"//tensorflow/c:env",
],
)


@@ -15,9 +15,13 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include <fstream>
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
@@ -86,6 +90,20 @@ namespace tf_random_access_file {
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct GCSFile {
const char* bucket;
const char* object;
gcs::Client* gcs_client; // not owned
TempFile outfile;
bool sync_need;
} GCSFile;
static void Cleanup(TF_WritableFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
plugin_memory_free(const_cast<char*>(gcs_file->bucket));
plugin_memory_free(const_cast<char*>(gcs_file->object));
delete gcs_file;
}
// TODO(vnvo2409): Implement later
@@ -119,6 +137,20 @@ static void Init(TF_Filesystem* filesystem, TF_Status* status) {
// TODO(vnvo2409): Implement later
static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
char* bucket;
char* object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
TempFile outfile(TF_GetTempFileName(""), std::ios::binary | std::ios::out);
file->plugin_file = new tf_writable_file::GCSFile(
{bucket, object, gcs_client, std::move(outfile), true});
TF_SetStatus(status, TF_OK, "");
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@@ -126,9 +158,14 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
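Editor's note: the new `tf_writable_file::GCSFile` buffers writes in a local `TempFile` and records a `sync_need` flag, but the append/flush/sync callbacks are still marked TODO in this commit. The sketch below is only an illustration of how a later sync might push the temp file to the bucket with the google-cloud-cpp client; the `SyncImpl` name and the use of `gcs::Client::UploadFile` are assumptions, not part of the change.

// Hypothetical sketch only; this commit leaves Sync/Append as TODO.
static void SyncImpl(TF_WritableFile* file, TF_Status* status) {
  auto gcs_file = static_cast<tf_writable_file::GCSFile*>(file->plugin_file);
  if (!gcs_file->sync_need) {
    TF_SetStatus(status, TF_OK, "");
    return;
  }
  gcs_file->outfile.flush();  // Make sure buffered bytes hit the temp file.
  // Assumed API: google-cloud-cpp's Client::UploadFile(file, bucket, object).
  auto metadata = gcs_file->gcs_client->UploadFile(
      gcs_file->outfile.getName(), gcs_file->bucket, gcs_file->object);
  if (!metadata) {
    TF_SetStatus(status, TF_INTERNAL, metadata.status().message().c_str());
    return;
  }
  gcs_file->sync_need = false;
  TF_SetStatus(status, TF_OK, "");
}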


@@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include <stdio.h>
#include <fstream>
#include <string>
#include <utility>
TempFile::TempFile(const char* temp_file_name, std::ios::openmode mode)
: std::fstream(temp_file_name, mode), name_(temp_file_name) {}
TempFile::TempFile(TempFile&& rhs)
: std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}
TempFile::~TempFile() {
std::fstream::close();
std::remove(name_.c_str());
}
const std::string TempFile::getName() const { return name_; }


@@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#include <fstream>
#include <string>
class TempFile : public std::fstream {
public:
// We should specify openmode each time we call TempFile.
TempFile(const char* temp_file_name, std::ios::openmode mode);
TempFile(TempFile&& rhs);
~TempFile() override;
const std::string getName() const;
private:
const std::string name_;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
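Editor's note: a minimal usage sketch for the helper above, assuming only what the header declares — the file opens on construction, the name is retrievable for a later upload, and the destructor closes and removes the file. The demo path is an illustrative placeholder; the plugin itself obtains a name from TF_GetTempFileName.

#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"

#include <ios>
#include <iostream>

int main() {
  // Openmode is passed explicitly, as the class comment requires.
  TempFile tmp("/tmp/gcs_helper_demo", std::ios::binary | std::ios::out);
  tmp << "staged bytes";            // TempFile is a std::fstream.
  std::cout << tmp.getName() << "\n";
  return 0;                         // ~TempFile closes and deletes the file.
}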


@@ -314,7 +314,6 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
- "transforms/device_index_selector.cc",
"transforms/dilated_conv.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",


@@ -953,14 +953,14 @@ in the batch dimensions and broadcasting.
}];
let arguments = (ins
- TFL_TensorOf<[F32]>:$x,
+ TFL_TensorOf<[F32, QI8]>:$x,
- TFL_TensorOf<[F32]>:$y,
+ TFL_TensorOf<[F32, QI8]>:$y,
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
DefaultValuedAttr<BoolAttr, "false">:$adj_y
);
let results = (outs
- TFL_TensorOf<[F32]>:$output
+ TFL_TensorOf<[F32, QI8]>:$output
);
let hasOptions = 1;


@@ -63,7 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
- pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
+ pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());


@@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
- // Creates function pass to select device index/fold tf.DeviceIndex.
- std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TFL
} // namespace mlir


@@ -475,6 +475,7 @@ cc_library(
"transforms/cluster_outlining.cc",
"transforms/collection_ops_util.cc",
"transforms/decompose_resource_ops_pass.cc",
"transforms/device_index_selector.cc",
"transforms/einsum.cc",
"transforms/executor_island_coarsening.cc",
"transforms/executor_tpuv1_inline_tpu_island.cc",


@@ -164,6 +164,81 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
let hasFolder = 1;
}
def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> {
let summary = "Adjust the contrast of one or more images.";
let description = [{
`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
interpreted as `[height, width, channels]`. The other dimensions only
represent a collection of images, such as `[batch, height, width, channels].`
Contrast is adjusted independently for each channel of each image.
For each channel, the Op first computes the mean of the image pixels in the
channel and then adjusts each component of each pixel to
`(x - mean) * contrast_factor + mean`.
}];
let arguments = (ins
TensorOf<[F16, F32]>:$images,
F32Tensor:$contrast_factor
);
let results = (outs
TensorOf<[F16, F32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
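Editor's note: the per-channel formula in the description above, `(x - mean) * contrast_factor + mean`, is easy to check by hand. The short C++ sketch below is an editor-added illustration applied to one 4-pixel channel; it is not part of the op's kernel.

#include <iostream>
#include <vector>

int main() {
  // One channel of one image; contrast_factor = 2.
  std::vector<float> channel = {0.1f, 0.3f, 0.5f, 0.7f};
  const float contrast_factor = 2.0f;
  float mean = 0.0f;
  for (float x : channel) mean += x;
  mean /= channel.size();  // mean = 0.4
  for (float& x : channel) x = (x - mean) * contrast_factor + mean;
  // Expected output: -0.2 0.2 0.6 1
  for (float x : channel) std::cout << x << " ";
  std::cout << "\n";
}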
def TF_AdjustHueOp : TF_Op<"AdjustHue", [NoSideEffect]> {
let summary = "Adjust the hue of one or more images.";
let description = [{
`images` is a tensor of at least 3 dimensions. The last dimension is
interpreted as channels, and must be three.
The input image is considered in the RGB colorspace. Conceptually, the RGB
colors are first mapped into HSV. A delta is then applied to all the hue values,
and then remapped back to RGB colorspace.
}];
let arguments = (ins
TensorOf<[F16, F32]>:$images,
F32Tensor:$delta
);
let results = (outs
TensorOf<[F16, F32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AdjustSaturationOp : TF_Op<"AdjustSaturation", [NoSideEffect]> {
let summary = "Adjust the saturation of one or more images.";
let description = [{
`images` is a tensor of at least 3 dimensions. The last dimension is
interpreted as channels, and must be three.
The input image is considered in the RGB colorspace. Conceptually, the RGB
colors are first mapped into HSV. A scale is then applied to all the saturation
values, and then remapped back to RGB colorspace.
}];
let arguments = (ins
TensorOf<[F16, F32]>:$images,
F32Tensor:$scale
);
let results = (outs
TensorOf<[F16, F32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
let summary = [{
Computes the "logical and" of elements across dimensions of a tensor.
@@ -3866,6 +3941,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_HSVToRGBOp : TF_Op<"HSVToRGB", [NoSideEffect]> {
let summary = "Convert one or more images from HSV to RGB.";
let description = [{
Outputs a tensor of the same shape as the `images` tensor, containing the RGB
value of the pixels. The output is only well defined if the value in `images`
are in `[0,1]`.
See `rgb_to_hsv` for a description of the HSV encoding.
}];
let arguments = (ins
TF_FpTensor:$images
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
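Editor's note: for reference, the textbook HSV-to-RGB mapping that the description points to (via `rgb_to_hsv`) fits in a few lines. This is an editor-added sketch of the standard conversion for one pixel, not the TensorFlow kernel itself.

#include <array>
#include <cmath>

// Standard HSV -> RGB for one pixel, with h, s, v all in [0, 1].
std::array<float, 3> HsvToRgb(float h, float s, float v) {
  const float h6 = h * 6.0f;
  const int sector = static_cast<int>(h6) % 6;
  const float f = h6 - std::floor(h6);
  const float p = v * (1.0f - s);
  const float q = v * (1.0f - f * s);
  const float t = v * (1.0f - (1.0f - f) * s);
  switch (sector) {
    case 0: return {v, t, p};
    case 1: return {q, v, p};
    case 2: return {p, v, t};
    case 3: return {p, q, v};
    case 4: return {t, p, v};
    default: return {v, p, q};
  }
}
// Example: HsvToRgb(2.0f / 3.0f, 1.0f, 1.0f) yields {0, 0, 1} (pure blue),
// matching the rgb_to_hsv blue-image example shown later in this file.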
def TF_HashTableV2Op : TF_Op<"HashTableV2", []> {
let summary = "Creates a non-initialized hash table.";
@@ -5962,11 +6059,11 @@ I.e., \\(y = -x\\).
}];
let arguments = (ins
- TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
);
let results = (outs
- TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
+ TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@@ -6733,6 +6830,41 @@ the dimension is padded with zeros.
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}
def TF_RGBToHSVOp : TF_Op<"RGBToHSV", [NoSideEffect]> {
let summary = "Converts one or more images from RGB to HSV.";
let description = [{
Outputs a tensor of the same shape as the `images` tensor, containing the HSV
value of the pixels. The output is only well defined if the value in `images`
are in `[0,1]`.
`output[..., 0]` contains hue, `output[..., 1]` contains saturation, and
`output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0
corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue.
Usage Example:
>>> blue_image = tf.stack([
... tf.zeros([5,5]),
... tf.zeros([5,5]),
... tf.ones([5,5])],
... axis=-1)
>>> blue_hsv_image = tf.image.rgb_to_hsv(blue_image)
>>> blue_hsv_image[0,0].numpy()
array([0.6666667, 1. , 1. ], dtype=float32)
}];
let arguments = (ins
TF_FpTensor:$images
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>,
WithBroadcastableBinOpBuilder {
let summary = [{
@@ -7230,6 +7362,27 @@ Input images can be of different types but output images are always float.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ResizeBilinearGradOp : TF_Op<"ResizeBilinearGrad", [NoSideEffect]> {
let summary = "Computes the gradient of bilinear interpolation.";
let description = [{
}];
let arguments = (ins
F32Tensor:$grads,
TF_FpTensor:$original_image,
DefaultValuedAttr<BoolAttr, "false">:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_ResizeNearestNeighborOp : TF_Op<"ResizeNearestNeighbor", [NoSideEffect]> {
let summary = [{
Resize `images` to `size` using nearest neighbor interpolation.


@@ -21,11 +21,11 @@ limitations under the License.
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
- #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+ #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
- namespace TFL {
+ namespace TF {
namespace {
// Folds the DeviceIndex op to a constant value. The DeviceIndex return the
@@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() {
// Convert all the DeviceIndex ops to constant values.
func.getBody().walk([](TF::DeviceIndexOp op) {
// This just selects the default in all cases where DeviceIndex feeds into
- // tf.Case. This could be enhanced based on explicit TFLite specification or
- // TAC in future.
+ // tf.Case. This could be enhanced to have some sort of policy in the
+ // future.
OpBuilder b(op);
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
int index = op.device_names().size();
@@ -79,7 +79,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
}
static PassRegistration<DeviceIndexSelector> pass(
- "tfl-device-index-selector", "Fold tf.DeviceIndex to constant");
+ "tf-device-index-selector", "Fold tf.DeviceIndex to constant");
- } // namespace TFL
+ } // namespace TF
} // namespace mlir


@@ -52,11 +52,6 @@ struct FusedKernelMatcherPass
void runOnFunction() override;
};
- // Returns an op's name with the dialect prefix stripped off.
- StringRef GetOpNameWithoutDialect(Operation *op) {
-   return op->getName().getStringRef().split(".").second;
- }
bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}
@@ -128,8 +123,8 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
}
SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
- SmallVector<Attribute, 2> fused_ops{
-     StringAttr::get(GetOpNameWithoutDialect(bias_add), context)};
+ SmallVector<Attribute, 2> fused_ops{StringAttr::get(
+     bias_add.getOperation()->getName().stripDialect(), context)};
// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);
@@ -143,7 +138,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(
-     StringAttr::get(GetOpNameWithoutDialect(activation), context));
+     StringAttr::get(activation->getName().stripDialect(), context));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();


@@ -147,6 +147,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
// generally used beyond exporting to runtimes that supports these ops. In the
// future these fusions may be codegen'd automatically.
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
// Creates function pass to select device index/fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TF
namespace tf_executor {


@@ -106,53 +106,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
return GetI64ElementsAttr(slice_limits, builder);
}
// Returns the padding value of the given position. If padding_attr is a
// nullptr, returns 0.
static int64_t GetPaddingValue(DenseIntElementsAttr padding_attr,
ArrayRef<uint64_t> index) {
if (!padding_attr) return 0;
return padding_attr.getValue<int64_t>(index);
}
static bool IsOnlyPaddingSpatialDims(Value lhs,
ConvDimensionNumbers dimension_numbers,
DenseIntElementsAttr edge_padding_low,
DenseIntElementsAttr edge_padding_high) {
const int64_t batch_dim = dimension_numbers.input_batch_dimension().getInt();
const int64_t feature_dim =
dimension_numbers.input_feature_dimension().getInt();
if (edge_padding_low.getValue<int64_t>(batch_dim) ||
edge_padding_high.getValue<int64_t>(batch_dim))
return false;
if (edge_padding_low.getValue<int64_t>(feature_dim) ||
edge_padding_high.getValue<int64_t>(feature_dim))
return false;
return true;
}
DenseIntElementsAttr BuildConvPaddingAttrs(
DenseIntElementsAttr edge_padding_low,
DenseIntElementsAttr edge_padding_high, DenseIntElementsAttr padding_attr,
ConvDimensionNumbers dimension_numbers, Builder* builder) {
SmallVector<int64_t, 4> padding_low, padding_high;
for (const auto& dim : dimension_numbers.input_spatial_dimensions()) {
unsigned i = dim.getZExtValue();
padding_low.push_back(edge_padding_low.getValue<int64_t>(i));
padding_high.push_back(edge_padding_high.getValue<int64_t>(i));
}
int rank = padding_low.size();
SmallVector<int64_t, 8> padding;
for (unsigned i = 0, e = rank; i < e; ++i) {
padding.push_back(GetPaddingValue(padding_attr, {i, 0}) + padding_low[i]);
padding.push_back(GetPaddingValue(padding_attr, {i, 1}) + padding_high[i]);
}
// padding_attr.getType() doesn't work because it is an optional attribute,
// which can be a nullptr.
auto type = RankedTensorType::get({rank, 2}, builder->getIntegerType(64));
return DenseIntElementsAttr::get(type, padding);
}
#include "tensorflow/compiler/mlir/xla/transforms/generated_canonicalize.inc" #include "tensorflow/compiler/mlir/xla/transforms/generated_canonicalize.inc"
} // namespace } // namespace
@ -2153,14 +2106,5 @@ LogicalResult deriveShapeFromFirstOperand(
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// ConvOp
//===----------------------------------------------------------------------===//
void ConvOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<FoldPadIntoConv>(context);
}
} // namespace xla_hlo
} // namespace mlir


@@ -929,8 +929,6 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
);
let results = (outs HLO_Tensor);
- let hasCanonicalizer = 1;
}
def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {


@@ -132,6 +132,22 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
if (operand_shapes_with_layout.has_value())
return Unimplemented(
"CustomCall doesn't support operands shapes with layout");
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(false),
builder_.getStringAttr(opaque));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,


@@ -124,6 +124,12 @@ class MlirHloBuilder : public XlaBuilder {
FftType fft_type,
absl::Span<const int64> fft_length) override;
StatusOr<XlaOp> CustomCallInternal(const string& call_target_name,
absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>>
operand_shapes_with_layout) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,


@@ -415,71 +415,6 @@ func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
return %0 : tensor<1x4xf32>
}
// CHECK-LABEL: func @fold_pad_into_conv_f32
func @fold_pad_into_conv_f32(%arg0 : tensor<1x32x32x3xf32>,
%arg1 : tensor<7x7x3x64xf32>)
-> tensor<1x16x16x64xf32> {
// CHECK-NOT: xla_hlo.pad
// CHECK: xla_hlo.convolution
// CHECK-SAME: padding = dense<3> : tensor<2x2xi64>
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.pad"(%arg0, %0) {
edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>,
edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>,
interior_padding = dense<0> : tensor<4xi64>
} : (tensor<1x32x32x3xf32>, tensor<f32>) -> tensor<1x38x38x3xf32>
%2 = "xla_hlo.convolution"(%1, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<0> : tensor<2x2xi64>,
window_strides = dense<2> : tensor<2xi64>
} : (tensor<1x38x38x3xf32>, tensor<7x7x3x64xf32>) -> tensor<1x16x16x64xf32>
return %2 : tensor<1x16x16x64xf32>
}
// CHECK-LABEL: func @fold_pad_into_conv_i32
func @fold_pad_into_conv_i32(%arg0 : tensor<1x32x32x3xi32>,
%arg1 : tensor<7x7x3x64xi32>)
-> tensor<1x16x16x64xi32> {
// CHECK-NOT: xla_hlo.pad
// CHECK: xla_hlo.convolution
// CHECK-SAME: padding = dense<3> : tensor<2x2xi64>
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "xla_hlo.pad"(%arg0, %0) {
edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>,
edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>,
interior_padding = dense<0> : tensor<4xi64>
} : (tensor<1x32x32x3xi32>, tensor<i32>) -> tensor<1x38x38x3xi32>
%2 = "xla_hlo.convolution"(%1, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
window_strides = dense<2> : tensor<2xi64>
} : (tensor<1x38x38x3xi32>, tensor<7x7x3x64xi32>) -> tensor<1x16x16x64xi32>
return %2 : tensor<1x16x16x64xi32>
}
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
// CHECK: xla_hlo.reshape


@@ -35,7 +35,7 @@ func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor
// CHECK-LABEL: unranked_operand
func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: tf.Abs
- // expected-remark@+1 {{lowering requires static shaped operands}}
+ // expected-remark@+1 {{lowering requires static shaped tensor operands}}
%0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
@@ -44,12 +44,20 @@ func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: dynamic_operand
func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: tf.Abs
- // expected-remark@+1 {{lowering requires static shaped operands}}
+ // expected-remark@+1 {{lowering requires static shaped tensor operands}}
%0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: tuple_type
func @tuple_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
// Verifies that the pass can handle operands of non-tensor type like tuple
// from non TensorFlow ops.
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK-LABEL: unsupported_dtype
func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
// CHECK: tf.AddN


@@ -28,54 +28,3 @@ def UnaryEinsumToEinsum : Pat<
(HLO_UnaryEinsumOp $operand, $equation),
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
$operand, (UnaryToBinaryEinsumEq $equation))>;
//===----------------------------------------------------------------------===//
// Conv op patterns.
//===----------------------------------------------------------------------===//
def IsZero : Attr<CPred<
"($_self.isa<DenseFPElementsAttr>() &&"
"$_self.cast<DenseFPElementsAttr>().isSplat() &&"
"$_self.cast<DenseFPElementsAttr>().getSplatValue<FloatAttr>()"
".getValue().isZero()) ||"
"($_self.isa<DenseIntElementsAttr>() &&"
"$_self.cast<DenseIntElementsAttr>().isSplat() &&"
"$_self.cast<DenseIntElementsAttr>().getSplatValue<IntegerAttr>()"
".getInt() == 0)">>;
def IsOnlyPaddingSpatialDims
: Constraint<CPred<"IsOnlyPaddingSpatialDims($0, $1, $2, $3)">>;
def BuildConvPaddingAttrs : NativeCodeCall<
"BuildConvPaddingAttrs($0, $1, $2, $3, &$_builder)">;
def FoldPadIntoConv : Pat<
(HLO_ConvOp
(HLO_PadOp $lhs,
(HLO_ConstOp IsZero:$padding_value),
$edge_padding_low,
$edge_padding_high,
IsZero:$interior_padding),
$rhs,
$window_strides,
$padding,
$lhs_dilation,
$rhs_dilation,
$dimension_numbers,
$feature_group_count,
$batch_group_count,
$precision_config),
(HLO_ConvOp
$lhs,
$rhs,
$window_strides,
(BuildConvPaddingAttrs $edge_padding_low, $edge_padding_high, $padding,
$dimension_numbers),
$lhs_dilation,
$rhs_dilation,
$dimension_numbers,
$feature_group_count,
$batch_group_count,
$precision_config),
[(IsOnlyPaddingSpatialDims $lhs, $dimension_numbers, $edge_padding_low,
$edge_padding_high)]>;


@@ -389,10 +389,13 @@ struct HloLegalizeToLhlo
target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<xla_hlo::XlaHloDialect>();
BufferAssignmentTypeConverter converter;
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
auto inputs = op.getType().getInputs();
- return std::all_of(inputs.begin(), inputs.end(),
-                    [](Type input) { return input.isa<MemRefType>(); });
+ return llvm::all_of(inputs,
+                     [](Type input) { return input.isa<MemRefType>(); }) &&
+        converter.isLegal(&op.getBody());
});
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
return std::all_of(returnOp.operand_type_begin(),
@@ -401,8 +404,7 @@ struct HloLegalizeToLhlo
});
auto module = getOperation();
- BufferAssignmentTypeConverter converter;
- module.walk([&](FuncOp func) {
+ module.walk([&](FuncOp func) -> WalkResult {
BufferAssignmentPlacer bufferAssignment(func);
OwningRewritePatternList patterns;
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
@@ -418,8 +420,7 @@ struct HloLegalizeToLhlo
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
&converter, &patterns);
}
- return WalkResult(
-     applyPartialConversion(func, target, patterns, &converter));
+ return applyPartialConversion(func, target, patterns);
});
}
@@ -463,6 +464,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<xla_hlo::RealOp>,
HloToLhloOpConverter<xla_hlo::RemOp>,
HloToLhloOpConverter<xla_hlo::RsqrtOp>,
HloToLhloOpConverter<xla_hlo::ReshapeOp>,
HloToLhloOpConverter<xla_hlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SqrtOp>,


@@ -5238,8 +5238,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
// Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
DenseSet<Operation *> nonlegalized_ops;
- LogicalResult result = applyPartialConversion(
-     op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops);
+ LogicalResult result =
+     applyPartialConversion(op, target, patterns, &nonlegalized_ops);
// In order to enforce that the conversion result is fully converted,
// fail if there are any nonlegalized ops in the set.
if (failed(result) || !nonlegalized_ops.empty()) {


@@ -88,6 +88,9 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::AddNOp>(),
TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::AngleOp>(),
TypeID::get<TF::AdjustContrastv2Op>(),
TypeID::get<TF::AdjustHueOp>(),
TypeID::get<TF::AdjustSaturationOp>(),
TypeID::get<TF::ApproximateEqualOp>(),
TypeID::get<TF::ArgMaxOp>(),
TypeID::get<TF::ArgMinOp>(),
@@ -127,6 +130,7 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::GatherNdOp>(),
TypeID::get<TF::GreaterEqualOp>(),
TypeID::get<TF::GreaterOp>(),
TypeID::get<TF::HSVToRGBOp>(),
TypeID::get<TF::IFFT2DOp>(),
TypeID::get<TF::IFFT3DOp>(),
TypeID::get<TF::IFFTOp>(),
@@ -157,10 +161,14 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::PowOp>(),
TypeID::get<TF::RFFT2DOp>(),
TypeID::get<TF::RFFT3DOp>(),
TypeID::get<TF::RGBToHSVOp>(),
TypeID::get<TF::RealDivOp>(),
TypeID::get<TF::ReciprocalOp>(),
TypeID::get<TF::ReciprocalGradOp>(),
TypeID::get<TF::Relu6GradOp>(),
TypeID::get<TF::ResizeBilinearOp>(),
TypeID::get<TF::ResizeBilinearGradOp>(),
TypeID::get<TF::ResizeNearestNeighborOp>(),
TypeID::get<TF::ReverseSequenceOp>(),
TypeID::get<TF::RightShiftOp>(),
TypeID::get<TF::RintOp>(),
@@ -337,9 +345,9 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
// Only static shaped operands are supported in XLA builders for now.
for (Type ty : op->getOperandTypes()) {
- auto ranked_ty = ty.cast<ShapedType>();
- if (!ranked_ty.hasStaticShape()) {
-   op->emitRemark() << "lowering requires static shaped operands";
+ auto ranked_ty = ty.dyn_cast<ShapedType>();
+ if (!ranked_ty || !ranked_ty.hasStaticShape()) {
+   op->emitRemark() << "lowering requires static shaped tensor operands";
return success();
}
}


@@ -177,7 +177,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
target.addIllegalOp<ReduceOp>();
auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
- if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+ if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}


@@ -43,7 +43,7 @@ class TestLhloToLLVMPass
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<XlaLhloDialect>();
- if (failed(applyFullConversion(m, target, patterns, &converter))) {
+ if (failed(applyFullConversion(m, target, patterns))) {
signalPassFailure();
}
}


@@ -711,7 +711,7 @@ struct LhloLegalizeToParallelLoops
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
xla_lhlo::SelectAndScatterOp>();
- if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+ if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}


@@ -61,6 +61,7 @@ MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);
MAP_HLO_TO_LHLO(RemOp);
MAP_HLO_TO_LHLO(RsqrtOp);
MAP_HLO_TO_LHLO(SelectOp);


@@ -867,7 +867,7 @@ struct LhloLegalizeToLinalg
auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
- if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+ if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}
@@ -882,7 +882,7 @@ struct HloLegalizeToLinalg
auto func = getFunction();
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
- if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+ if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}


@@ -1,6 +1,6 @@
// Test DeviceIndex selector.
- // RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s
+ // RUN: tf-opt --tf-device-index-selector %s | FileCheck %s
// CHECK-LABEL: func @select
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {


@@ -770,6 +770,7 @@ tf_xla_py_test(
size = "small",
timeout = "long",
srcs = ["image_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 10,
tags = [


@@ -1564,16 +1564,12 @@ XlaOp XlaBuilder::CustomCall(
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
"Invalid custom_call_target \"%s\": Call targets that start with '$' "
"are reserved for internal use.",
call_target_name);
}
- *instr.mutable_shape() = shape.ToProto();
- instr.set_custom_call_target(call_target_name);
- instr.set_backend_config(opaque);
if (operand_shapes_with_layout.has_value()) {
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument(
@@ -1586,7 +1582,6 @@ XlaOp XlaBuilder::CustomCall(
"with constrained layout; given %d shapes, expected %d",
operand_shapes_with_layout->size(), operands.size());
}
- instr.set_constrain_layout(true);
int64 operand_num = 0;
for (const Shape& operand_shape : *operand_shapes_with_layout) {
if (!LayoutUtil::HasLayout(operand_shape)) {
@@ -1595,14 +1590,31 @@ XlaOp XlaBuilder::CustomCall(
"constrained layout.",
operand_num);
}
- *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
++operand_num;
}
}
- return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
+ return CustomCallInternal(call_target_name, operands, shape, opaque,
+                           operand_shapes_with_layout);
});
}
StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
instr.set_backend_config(opaque);
if (operand_shapes_with_layout.has_value()) {
instr.set_constrain_layout(true);
for (const Shape& operand_shape : *operand_shapes_with_layout) {
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
}
}
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}
XlaOp XlaBuilder::CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const Shape& shape, const string& opaque,


@@ -527,6 +527,14 @@ class XlaBuilder {
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
// Internal version of CustomCall without computation that doesn't do op
// specific error handling and expects arguments to be legal. CustomCall
// method above calls this method after error handling.
virtual StatusOr<XlaOp> CustomCallInternal(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const Shape& shape_with_layout,


@@ -141,7 +141,9 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/core:allocator",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:connected_traceme",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/lib:traceme_encode",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/host:host_platform_id",


@@ -127,11 +127,15 @@ std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
} else {
std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
usage_stream_pool_.pop();
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
QCHECK(stream->ok());
return stream;
}
}
void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
QCHECK(stream->ok());
absl::MutexLock lock(&mu_);
usage_stream_pool_.push(std::move(stream));
}


@@ -98,7 +98,9 @@ limitations under the License.
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/event.h"
@@ -749,16 +751,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
// memory that has already been allocated, and a possible Event
// allocation.
se::Stream* h2d_stream = local_device->host_to_device_stream();
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform());
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
- local_device->host_to_device_stream(), literal, buffer));
+ h2d_stream, literal, buffer));
std::shared_ptr<BufferSequencingEvent> event =
device_buffer->definition_events()[0];
TF_CHECK_OK(AddDestinationBufferSynchronization(
- local_device, std::move(device_buffer), event,
- local_device->host_to_device_stream()));
+ local_device, std::move(device_buffer), event, h2d_stream));
// This can sometimes catch the case where the literal memory has been
// freed before the H2D transfer was issued.
h2d_stream->RefreshStatus()
.IgnoreError(); // Can return error::Unimplemented
QCHECK(h2d_stream->ok());
};
client->h2d_transfer_pool()->Schedule(transfer_h2d);
return py_buffer;
@@ -1069,13 +1077,17 @@ Status PjRtBuffer::CopyToHostAsync() {
return Status::OK();
}
- StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral() {
+ StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
+     const bool discard_cached_copy) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
TF_RETURN_IF_ERROR(CopyToHostAsync());
std::shared_ptr<HostValue> host_value;
{
absl::MutexLock lock(&mu_);
host_value = host_value_;
if (discard_cached_copy) {
host_value_ = nullptr;
}
}
if (host_value == nullptr) {
return InvalidArgument("ToLiteral called on invalid buffer");
@@ -1429,10 +1441,9 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers) const {
int device_ordinal = device->local_device_state()->device_ordinal();
- tensorflow::profiler::TraceMe traceme([&] {
-   return absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(),
-                       "#");
- });
+ tensorflow::profiler::TraceMeConsumer activity(
+     "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
+     run_id.ToInt());
VLOG(3) << "Replica " << replica << ", partition " << partition
<< " mapped to device ordinal for execution: " << device_ordinal;
@@ -1721,10 +1732,9 @@ PjRtExecutable::ExecuteOnLocalDevices(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const {
RunId run_id;
- tensorflow::profiler::TraceMe traceme([&] {
-   return absl::StrCat(
-       "LocalExecutable::ExecuteOnLocalDevices#run_id=", run_id.ToInt(), "#");
- });
+ tensorflow::profiler::TraceMeProducer activity(
+     "LocalExecutable::ExecuteOnLocalDevices",
+     tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
const int num_local_devices = local_devices_.size();


@@ -478,8 +478,12 @@ class PjRtBuffer {
// Returns the buffer's value as an XLA Literal. If the value has previously
// been prefetched to the host, then returns the prefetched version, otherwise
- // copies the buffer to the host. Blocks until the value is ready.
- StatusOr<std::shared_ptr<Literal>> ToLiteral();
+ // copies the buffer to the host. Blocks until the value is ready. If
+ // `discard_cached_copy` is true then buffer will no longer keep hold of a
+ // cached copy of the literal (i.e. The reference to the host value will be
+ // removed.)
+ StatusOr<std::shared_ptr<Literal>> ToLiteral(
+     bool discard_cached_copy = false);
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. The value can be retrieved by a later call to
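Editor's note: a minimal usage sketch for the new parameter, under the assumption that `buffer` is any live `PjRtBuffer*` whose device value is ready; passing `true` fetches the literal and drops the buffer's cached host copy in the same call.

// Assumed caller context; `buffer` is a PjRtBuffer* with a ready device value.
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal,
                    buffer->ToLiteral(/*discard_cached_copy=*/true));
// A later ToLiteral() would copy from the device again instead of reusing the
// discarded host-side cache.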


@@ -106,7 +106,6 @@ class BranchVisitor {
boundaries_.emplace_back(operand, i, inst);
continue;
}
worklist_.push_back(operand);
visited_.insert(operand);
}
@@ -197,6 +196,7 @@ bool WorthHoisting(HloInstruction* instruction) {
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kGetTupleElement:
return true;
default:
@@ -206,10 +206,11 @@ bool WorthHoisting(HloInstruction* instruction) {
// Compare if the instructions to be visited at each branches are identical.
bool InstructionWithinBranchIdentical(
- const std::vector<HloInstruction*>& instructions, bool is_layout_senstive) {
+ const std::vector<HloInstruction*>& instructions,
+ bool is_layout_sensitive) {
// Identical includes the shape of each operands are equal.
auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
- bool eq_operands = is_layout_senstive
+ bool eq_operands = is_layout_sensitive
? ShapeUtil::Equal(a->shape(), b->shape())
: ShapeUtil::Compatible(a->shape(), b->shape());
return eq_operands;
@@ -233,7 +234,7 @@ bool InstructionWithinBranchIdentical(
auto old_channel_id = instruction->channel_id();
instruction->set_channel_id(instructions[0]->channel_id());
bool eq_instructions = instructions[0]->Identical(
- *instruction, eq_operand, eq_computations, is_layout_senstive);
+ *instruction, eq_operand, eq_computations, is_layout_sensitive);
instruction->set_channel_id(old_channel_id);
return eq_instructions;
});
@@ -243,7 +244,7 @@ bool InstructionWithinBranchIdentical(
[&](HloInstruction* instruction) {
return instructions[0]->Identical(
*instruction, eq_operand, eq_computations,
- is_layout_senstive);
+ is_layout_sensitive);
});
}
@@ -354,12 +355,228 @@ Status RemoveInstructionFromComputation(
return Status::OK();
}
// Identify converts to be hoisted/rematerialized out of the branch
// computations.
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
int branch_count,
HloInstruction* conditional,
bool is_layout_sensitive) {
absl::flat_hash_set<int64> kspecial_convert;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
continue;
}
bool replica = true;
HloInstruction* kspecial_convert_candidate =
old_root->mutable_operand(operand_num);
// Check whether an identical candidate appears in other branches
for (int others = 1; others < branch_count; ++others) {
HloInstruction* others_root =
conditional->branch_computation(others)->root_instruction();
bool eq_shape =
is_layout_sensitive
? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape())
: ShapeUtil::Compatible(
others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape());
if ((others_root->operand(operand_num)->opcode() ==
HloOpcode::kConvert) &&
eq_shape) {
// Nothing to be done.
} else {
replica = false;
break;
}
}
if (replica) {
kspecial_convert.insert(operand_num);
}
}
return kspecial_convert;
}
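// A minimal sketch of what qualifies (hypothetical branch roots, not from
// this change): if every branch root has the form
//   ROOT %r = tuple(convert(%x), %y)
// with a convert at operand 0 whose shape is equal (or compatible) across all
// branches, then tuple index 0 is returned as a special convert; operand 1,
// which is not a convert, is not.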
// Restructuring the conditional instruction as follows:
// i.e., %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape-change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
HloInstruction* conditional) {
HloInstruction* old_root = computation->root_instruction();
std::vector<HloInstruction*> new_operands;
int cur_index = 0;
for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
++cur_index) {
new_operands.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
conditional, cur_index)));
}
HloInstruction* new_tuple =
computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
if (old_root == conditional) {
computation->set_root_instruction(new_tuple);
} else {
std::vector<HloInstruction*> new_tuple_users;
for (auto conditional_user : conditional->users()) {
auto is_new_gte = absl::c_find_if(
new_operands,
[&](HloInstruction* instr) { return instr == conditional_user; });
if (is_new_gte == new_operands.end()) {
new_tuple_users.push_back(conditional_user);
}
}
for (auto new_tuple_user : new_tuple_users) {
TF_RETURN_IF_ERROR(
conditional->ReplaceUseWith(new_tuple_user, new_tuple));
}
}
VLOG(2) << "computation after root restructure:\n" << computation->ToString();
return Status::OK();
}
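// Illustrative sketch (hypothetical HLO, not from this change): a use such as
//   %c   = (f32[2]) conditional(%pred, %a, %b), ...
//   %old = f32[2] get-tuple-element(%c), index=0
// is rewritten to
//   %c     = (f32[2]) conditional(%pred, %a, %b), ...
//   %gte.0 = f32[2] get-tuple-element(%c), index=0
//   %tpl   = (f32[2]) tuple(%gte.0)
//   %old   = f32[2] get-tuple-element(%tpl), index=0
// so every former user of %c now reads through %tpl, and %c's shape can later
// change (when converts are hoisted) without breaking those users.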
StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
bool is_layout_sensitive) {
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
}
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
if (old_root->opcode() != HloOpcode::kTuple) {
return false;
} else {
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
// Identify the gte using `index'.
auto find_gte = [](const HloInstruction* conditional_result,
int64 index) -> HloInstruction* {
for (HloInstruction* instr : conditional_result->users()) {
if (instr->opcode() != HloOpcode::kGetTupleElement) {
return nullptr;
}
if (instr->tuple_index() == index) {
return instr;
}
}
return nullptr;
};
// Captures tuple indices referring to converts to be rematerialized/hoisted.
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
old_root, branch_count, conditional, is_layout_sensitive);
// Exit if we cannot find any converts to be hoisted.
if (kspecial_convert.empty()) {
return false;
}
TF_RETURN_IF_ERROR(
RestructureConditionalInstruction(conditional->parent(), conditional));
for (int branch = 0; branch < branch_count; branch++) {
old_root = conditional->branch_computation(branch)->root_instruction();
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
std::vector<HloInstruction*> new_operands(old_root->operand_count());
std::unordered_set<HloInstruction*> to_hoist_set;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
operand_num;
}
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
HloInstruction* hoist = old_root->mutable_operand(operand_num);
if (!kspecial_convert.contains(operand_num)) {
new_operands[operand_num] = old_root->mutable_operand(operand_num);
continue;
}
to_hoist_set.insert(hoist);
int64 new_tuple_count = old_root->operand_count();
// Replace the hoisted instr in the tuple with its operand(s). At least one
// operand takes the hoist's place in the tuple; the rest are appended at
// the end.
bool inplace = true;
CHECK(!hoist->operands().empty());
for (HloInstruction* prod : hoist->operands()) {
if (inplace) {
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
new_operands[map_inst_to_tuple_index[hoist]] = prod;
inplace = false;
} else {
map_inst_to_tuple_index[prod] = new_tuple_count++;
new_operands.push_back(prod);
}
}
}
// Create the new root instruction.
HloComputation* cur_branch = conditional->branch_computation(branch);
HloInstruction* new_branch_root =
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
// The shape can vary since the operands to convert are now
// being returned through the branches' root.
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
// Only one of the branches needs to change the conditional->parent().
if (branch != 0) {
continue;
}
HloComputation* conditional_parent = conditional->parent();
HloInstruction* newconditional =
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
cur_branch->root_instruction()->shape(),
conditional->mutable_operand(0),
absl::MakeSpan(conditional->branch_computations()),
absl::MakeSpan(conditional->operands()).subspan(1)));
// Ensure that all the users of conditional refer to the new one.
TF_RETURN_IF_ERROR(
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
conditional = newconditional;
// Add the hoisted instructions in the parent.
for (HloInstruction* hoist : to_hoist_set) {
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
int64 hoist_index = map_inst_to_tuple_index[hoist];
// Find out the gte that captured the hoisted instr result.
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
CHECK(gte_hoist != nullptr);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* op : hoist->operands()) {
HloInstruction* gte = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(
op->shape(), conditional, map_inst_to_tuple_index[op]));
new_operands.push_back(gte);
}
HloInstruction* hoisted = conditional_parent->AddInstruction(
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
}
// No need to explicitly delete a hoisted instruction: if it is dead, the
// subsequent DCE will remove it.
}
}
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
return true;
}
// Hoist identical ops out of the conditional. The definition of identical // Hoist identical ops out of the conditional. The definition of identical
// are the shape of the operands are identical and their properties are // are the shape of the operands are identical and their properties are
// identical. Will start from the root instruction of each branch and get // identical. Will start from the root instruction of each branch and get
// the identical ops to hoist. // the identical ops to hoist.
StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional, StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
bool is_layout_sensitive) { bool is_layout_sensitive) {
VLOG(1) << " visiting conditional:" << conditional->ToString();
int branch_count = conditional->branch_count(); int branch_count = conditional->branch_count();
if (branch_count <= 0) { if (branch_count <= 0) {
return false; return false;
@ -399,7 +616,7 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
} }
} }
if (visitors[0].HoistInstructionSize() <= 1) { if (visitors[0].HoistInstructionSize() < 1) {
return false; return false;
} }
@ -442,7 +659,6 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(), RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(),
conditional->branch_computation(i))); conditional->branch_computation(i)));
} }
return true; return true;
} }
@ -451,9 +667,7 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) { StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
bool changed = false; bool changed = false;
// Gather all the conditional ops in our module. We do this ahead of time so if (pursue_full_conditional_code_motion_) {
// we don't have to worry about mutating the lists of computations or
// instructions as we iterate.
std::vector<HloInstruction*> conditional_ops; std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) { for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) { for (auto* instr : comp->MakeInstructionPostOrder()) {
@ -464,13 +678,44 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
} }
for (HloInstruction* conditional_op : conditional_ops) { for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements( TF_ASSIGN_OR_RETURN(
conditional_op, is_layout_sensitive_)); bool result,
MergeIdenticalElements(conditional_op, is_layout_sensitive_));
changed |= result; changed |= result;
} }
if (changed) { if (changed) {
HloPassPipeline subpipeline("after_conditional_code_motion"); HloPassPipeline subpipeline("after_conditional_code_motion");
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
changed |= cleanup_changed;
}
}
// handling convert rematerialization/hoisting
{
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
}
}
}
for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(
bool convert_result,
ConvertSpecialMove(conditional_op, is_layout_sensitive_));
changed |= convert_result;
}
}
if (changed) {
HloPassPipeline subpipeline(
"after_conditional_code_motion_after_convert_hoisting");
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>(); subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>(); subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));

View File

@ -23,7 +23,11 @@ limitations under the License.
namespace xla { namespace xla {
// HLO pass that moves identical ops out of conditional. // ConditionalCodeMotion specializes in hoisting/rematerializing
// unconditional converts in the default mode.
// When pursue_full_conditional_code_motion_ is set to true, the
// full HLO pass moves identical ops out of a conditional in addition to moving
// converts.
// - The definition of identical are the shape of the operands are identical // - The definition of identical are the shape of the operands are identical
// and their properties are identical. // and their properties are identical.
// - Currently, only some types of instructions is supported. // - Currently, only some types of instructions is supported.
@ -35,13 +39,18 @@ class ConditionalCodeMotion : public HloModulePass {
public: public:
// If is_layout_sensitive is true, then the hoist process preserves layout // If is_layout_sensitive is true, then the hoist process preserves layout
// during identical comparison. Otherwise, layout is ignored. // during identical comparison. Otherwise, layout is ignored.
explicit ConditionalCodeMotion(bool is_layout_sensitive = true) explicit ConditionalCodeMotion(
: is_layout_sensitive_(is_layout_sensitive) {} bool is_layout_sensitive = true,
bool pursue_full_conditional_code_motion = false)
: is_layout_sensitive_(is_layout_sensitive),
pursue_full_conditional_code_motion_(
pursue_full_conditional_code_motion) {}
absl::string_view name() const override { return "conditional-code-motion"; } absl::string_view name() const override { return "conditional-code-motion"; }
StatusOr<bool> Run(HloModule* module) override; StatusOr<bool> Run(HloModule* module) override;
private: private:
const bool is_layout_sensitive_; const bool is_layout_sensitive_;
const bool pursue_full_conditional_code_motion_;
}; };
} // namespace xla } // namespace xla
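A minimal usage sketch (assumed pipeline wiring, not part of this change), where `pipeline` is an HloPassPipeline:

  // Default mode: only hoist/rematerialize unconditional converts.
  pipeline.AddPass<ConditionalCodeMotion>(/*is_layout_sensitive=*/true);
  // Full mode: additionally hoist identical ops out of each conditional branch.
  pipeline.AddPass<ConditionalCodeMotion>(
      /*is_layout_sensitive=*/true,
      /*pursue_full_conditional_code_motion=*/true);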

View File

@ -38,7 +38,86 @@ namespace {
using ConditionalCodeMotionTest = HloTestBase; using ConditionalCodeMotionTest = HloTestBase;
namespace op = xla::testing::opcode_matchers; namespace op = xla::testing::opcode_matchers;
TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) { TEST_F(ConditionalCodeMotionTest, MoveSubsetTupleOut) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.2894, %reshape.8493)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.3604, %add)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
get-first-index.2 = f32[2,512,364]{2,1,0} get-tuple-element(conditional), index=1
ROOT result = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(get-first-index, get-first-index.2)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
ROOT conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOut) {
absl::string_view hlo_string = absl::string_view hlo_string =
R"( R"(
HloModule RemoveDotOpOut HloModule RemoveDotOpOut
@ -65,12 +144,16 @@ ENTRY main {
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index) add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1)
} }
)"; )";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass; ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert()))));
} }
TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) { TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
@ -123,7 +206,7 @@ ENTRY main {
} }
)"; )";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass; ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional = const HloInstruction* conditional =
@ -181,7 +264,7 @@ ENTRY main {
} }
)"; )";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass; ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional = const HloInstruction* conditional =
FindInstruction(module.get(), "conditional"); FindInstruction(module.get(), "conditional");
@ -245,7 +328,7 @@ ENTRY main {
} }
)"; )";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass; ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional = const HloInstruction* conditional =
@ -317,7 +400,7 @@ ENTRY main {
)"; )";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass; ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
} }
@ -390,7 +473,7 @@ ENTRY main {
} }
)"; )";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass; ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional = const HloInstruction* conditional =
FindInstruction(module.get(), "conditional"); FindInstruction(module.get(), "conditional");

View File

@ -226,6 +226,11 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
dims_to_keep.push_back(dim); dims_to_keep.push_back(dim);
} }
} }
// We support fast codegen for three cases:
// 1) Row reduction: (K, R)
// 2) Column reduction: (K, R, K)
// 3) "Batched" row reduction: (R, K, R)
if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
dims_to_keep) && dims_to_keep) &&
!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),

View File

@ -77,8 +77,6 @@ class KernelThunk : public Thunk {
// Will be set by IrEmitterUnnested. // Will be set by IrEmitterUnnested.
LaunchDimensions launch_dimensions_; LaunchDimensions launch_dimensions_;
// Describes how to load this kernel. ExecuteOnStream reuses this loader
// specification for all executions.
mutable tensorflow::mutex mutex_; mutable tensorflow::mutex mutex_;
// Loaded kernels for each `StreamExecutor`. Requires pointer stability of // Loaded kernels for each `StreamExecutor`. Requires pointer stability of

View File

@ -3908,6 +3908,10 @@ const string& HloInstruction::outfeed_config() const {
return Cast<HloOutfeedInstruction>(this)->outfeed_config(); return Cast<HloOutfeedInstruction>(this)->outfeed_config();
} }
void HloInstruction::set_outfeed_config(const string& config) {
return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
}
const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
return Cast<HloCollectiveInstruction>(this)->replica_groups(); return Cast<HloCollectiveInstruction>(this)->replica_groups();
} }

View File

@ -1755,6 +1755,9 @@ class HloInstruction {
// Returns the config for the Outfeed instruction. // Returns the config for the Outfeed instruction.
const string& outfeed_config() const; const string& outfeed_config() const;
// Delegates to HloOutfeedInstruction::set_outfeed_config.
void set_outfeed_config(const string& config);
// Returns the shape for the Outfeed instruction. // Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const; const Shape& outfeed_shape() const;

View File

@ -1141,6 +1141,7 @@ class HloOutfeedInstruction : public HloInstruction {
const Shape& outfeed_shape() const { return outfeed_shape_; } const Shape& outfeed_shape() const { return outfeed_shape_; }
// Returns the config for the Outfeed instruction. // Returns the config for the Outfeed instruction.
const string& outfeed_config() const { return outfeed_config_; } const string& outfeed_config() const { return outfeed_config_; }
void set_outfeed_config(const string& config) { outfeed_config_ = config; }
// Returns a serialized representation of this instruction. // Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override; HloInstructionProto ToProto() const override;

View File

@ -29,36 +29,78 @@ StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
// Propagate the operand subshapes. // Propagate the operand subshapes.
for (int operand_idx = 0; operand_idx < instruction->operand_count(); for (int operand_idx = 0; operand_idx < instruction->operand_count();
++operand_idx) { ++operand_idx) {
modified |= for (const ShapeUtil::IndexedShape& indexed_shape :
PropagateSubshapes(instruction->operand(operand_idx)->shape(), ShapeUtil::GetLeafShapes(
instruction->fused_parameter(operand_idx)); instruction->operand(operand_idx)->shape())) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
modified |= Propagate(indexed_shape.index,
instruction->fused_parameter(operand_idx),
memory_space);
}
} }
// Propagate output subshapes. // Propagate output subshapes.
modified |= PropagateSubshapes(instruction->shape(), for (const ShapeUtil::IndexedShape& indexed_shape :
instruction->fused_expression_root()); ShapeUtil::GetLeafShapes(instruction->shape())) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
modified |=
Propagate(indexed_shape.index,
instruction->fused_expression_root(), memory_space);
}
} }
} }
} }
return modified; return modified;
} }
bool MemorySpacePropagation::PropagateSubshapes( bool MemorySpacePropagation::Propagate(ShapeIndexView index,
const Shape& caller_shape, const HloInstruction* callee_instruction) const { const HloInstruction* callee_instruction,
int64 memory_space) const {
bool modified = false; bool modified = false;
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(caller_shape)) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
const HloValue& value = dataflow_analysis_->GetUniqueValueAt( const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
callee_instruction, indexed_shape.index); callee_instruction, index.ToShapeIndex());
for (const HloPosition& position : value.positions()) { for (const HloPosition& position : value.positions()) {
Shape* shape = ShapeUtil::GetMutableSubshape( HloInstruction* instruction = position.instruction;
position.instruction->mutable_shape(), position.index); Shape* shape = ShapeUtil::GetMutableSubshape(instruction->mutable_shape(),
if (shape->layout().memory_space() != memory_space) { position.index);
if (shape->layout().memory_space() == memory_space) {
continue;
}
shape->mutable_layout()->set_memory_space(memory_space); shape->mutable_layout()->set_memory_space(memory_space);
modified = true; modified = true;
// For fusion outputs, propagate the memory space to the fusion root.
if (instruction->opcode() == HloOpcode::kFusion) {
Propagate(position.index, instruction->fused_expression_root(),
memory_space);
} }
const HloInstruction* parent_fusion =
instruction->parent()->FusionInstruction();
// For nested fusion roots, pop one level up and propagate the memory space
// to the output of the calling fusion instruction.
if (instruction == instruction->parent()->root_instruction() &&
parent_fusion->parent()->IsFusionComputation()) {
Propagate(position.index, parent_fusion, memory_space);
}
// For nested fusion parameters, pop one level up and propagate the memory
// space to the operand of the calling fusion instruction.
if (instruction->opcode() == HloOpcode::kParameter &&
parent_fusion->parent()->IsFusionComputation()) {
const HloInstruction* fusion_operand =
parent_fusion->operand(instruction->parameter_number());
Propagate(position.index, fusion_operand, memory_space);
}
}
for (const HloUse& use : value.uses()) {
// For fusion uses, propagate the memory space to the fusion parameter.
if (use.instruction->opcode() == HloOpcode::kFusion) {
modified |= Propagate(
use.operand_index,
use.instruction->fused_parameter(use.operand_number), memory_space);
} }
} }
return modified; return modified;

View File

@ -31,12 +31,11 @@ class MemorySpacePropagation : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override; StatusOr<bool> Run(HloModule* module) override;
private: private:
// Given the caller shape (operand or output) and its corresponding // Given the shape index (operand or output) and its corresponding instruction
// insturction in the fused computation (parameter or root), propagates the // in the fused computation (parameter or root), propagates the memory space
// memory space to all the subshapes in the callee side. Returns true if the // in the callee side. Returns true if the module is modified.
// module is modified. bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction,
bool PropagateSubshapes(const Shape& caller_shape, int64 memory_space) const;
const HloInstruction* callee_instruction) const;
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_; std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
}; };

View File

@ -199,5 +199,153 @@ TEST_F(MemorySpacePropagationTest, TupleOutput) {
EXPECT_EQ(module->Hash(), ref->Hash()); EXPECT_EQ(module->Hash(), ref->Hash());
} }
TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
// Tests propagating the memory space to nested fusions on the input side.
absl::string_view hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[3,2]{0,1:T(128)} parameter(0)
ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[3,2]{0,1:T(128)} parameter(0)
%fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1)
}
ENTRY %entry {
%param0 = s32[3,2]{0,1:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
absl::string_view expected_hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0)
ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0)
%fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1)
}
ENTRY %entry {
%param0 = s32[3,2]{0,1:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
EXPECT_EQ(module->Hash(), ref->Hash());
}
TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
// Tests propagating the memory space to nested fusions on the output side.
absl::string_view hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[6]{0:T(128)} parameter(0)
ROOT %bitcast = s32[3,2]{0,1:T(128)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)} parameter(0)
%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
ROOT %fusion.1 = s32[3,2]{0,1:T(128)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
}
)";
absl::string_view expected_hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[6]{0:T(128)S(1)} parameter(0)
ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
EXPECT_EQ(module->Hash(), ref->Hash());
}
} // namespace } // namespace
} // namespace xla } // namespace xla

View File

@ -552,7 +552,7 @@ class LowerToNVVMPass
// TODO(csigg): Remove once we support replacing non-root ops. // TODO(csigg): Remove once we support replacing non-root ops.
target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp, target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
::mlir::gpu::YieldOp>(); ::mlir::gpu::YieldOp>();
if (failed(mlir::applyFullConversion(m, target, patterns, &converter))) { if (failed(mlir::applyFullConversion(m, target, patterns))) {
signalPassFailure(); signalPassFailure();
} }
} }

View File

@ -52,16 +52,26 @@ cc_library(
name = "test_macros_header", name = "test_macros_header",
testonly = True, testonly = True,
hdrs = ["test_macros.h"], hdrs = ["test_macros.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
],
) )
# Generate a test_macros_${BACKEND} library per backend with the proper copts. # Generate a test_macros_${BACKEND} library per backend with the proper copts.
generate_backend_test_macros() generate_backend_test_macros()
cc_library(
name = "manifest_checking_test",
testonly = True,
srcs = ["manifest_checking_test.cc"],
hdrs = ["manifest_checking_test.h"],
deps = [
":test_macros_header",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "test_utils", name = "test_utils",
srcs = ["test_utils.cc"], srcs = ["test_utils.cc"],
@ -136,6 +146,7 @@ cc_library(
hdrs = ["hlo_test_base.h"], hdrs = ["hlo_test_base.h"],
deps = [ deps = [
":literal_test_util", ":literal_test_util",
":manifest_checking_test",
":test_utils", ":test_utils",
":verified_hlo_module", ":verified_hlo_module",
"//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:debug_options_flags",
@ -193,6 +204,7 @@ cc_library(
srcs = ["client_library_test_base.cc"], srcs = ["client_library_test_base.cc"],
hdrs = ["client_library_test_base.h"], hdrs = ["client_library_test_base.h"],
deps = [ deps = [
":manifest_checking_test",
"//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:array4d",
@ -273,6 +285,7 @@ cc_library(
hdrs = ["local_client_test_base.h"], hdrs = ["local_client_test_base.h"],
deps = [ deps = [
":client_library_test_base", ":client_library_test_base",
":manifest_checking_test",
":verified_hlo_module", ":verified_hlo_module",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",

View File

@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []):
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
], ],
deps = [ deps = [
"@com_google_absl//absl/container:flat_hash_map", "//tensorflow/core/platform:logging",
"@com_google_absl//absl/strings",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
], ],
) )

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/bitmap.h"
@ -62,7 +63,7 @@ std::vector<TestCase> ExpandUseBfloat16(
} }
// A client library test establishes an in-process XLA client connection. // A client library test establishes an in-process XLA client connection.
class ClientLibraryTestBase : public ::testing::Test { class ClientLibraryTestBase : public ManifestCheckingTest {
protected: protected:
explicit ClientLibraryTestBase(se::Platform* platform = nullptr); explicit ClientLibraryTestBase(se::Platform* platform = nullptr);

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
@ -67,7 +68,7 @@ namespace xla {
// ) // )
// //
// For a more detailed example, see "../tests/sample_text_test.cc". // For a more detailed example, see "../tests/sample_text_test.cc".
class HloTestBase : public ::testing::Test { class HloTestBase : public ManifestCheckingTest {
public: public:
// Creates a new HLO module for a test. The module created will have // Creates a new HLO module for a test. The module created will have
// TestName() for its name; it will also automatically populate its debug // TestName() for its name; it will also automatically populate its debug

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
}; };
// A base class for tests which exercise the LocalClient interface. // A base class for tests which exercise the LocalClient interface.
class LocalClientTestBase : public ::testing::Test { class LocalClientTestBase : public ManifestCheckingTest {
protected: protected:
struct EigenThreadPoolWrapper; struct EigenThreadPoolWrapper;
explicit LocalClientTestBase(se::Platform* platform = nullptr); explicit LocalClientTestBase(se::Platform* platform = nullptr);

View File

@ -0,0 +1,129 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include <fstream>
#include <iterator>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace {
// Mapping from test name (e.g., MyTest.MyTestCase) to platforms on which it
// is disabled - a sequence of regexps.
using ManifestT = absl::flat_hash_map<std::string, std::vector<std::string>>;
ManifestT ReadManifest() {
ManifestT manifest;
absl::string_view path = absl::NullSafeStringView(kDisabledManifestPath);
if (path.empty()) {
return manifest;
}
// Note: parens are required to disambiguate vs function decl.
std::ifstream file_stream((std::string(path)));
std::string contents((std::istreambuf_iterator<char>(file_stream)),
std::istreambuf_iterator<char>());
std::vector<std::string> lines = absl::StrSplit(contents, '\n');
for (std::string& line : lines) {
auto comment = line.find("//");
if (comment != std::string::npos) {
line = line.substr(0, comment);
}
if (line.empty()) {
continue;
}
absl::StripTrailingAsciiWhitespace(&line);
std::vector<std::string> pieces = absl::StrSplit(line, ' ');
CHECK_GE(pieces.size(), 1);
auto& platforms = manifest[pieces[0]];
for (size_t i = 1; i < pieces.size(); ++i) {
platforms.push_back(pieces[i]);
}
}
return manifest;
}
} // namespace
void ManifestCheckingTest::SetUp() {
const testing::TestInfo* test_info =
testing::UnitTest::GetInstance()->current_test_info();
absl::string_view test_case_name = test_info->test_suite_name();
absl::string_view test_name = test_info->name();
VLOG(1) << "test_case_name: " << test_case_name;
VLOG(1) << "test_name: " << test_name;
// Remove the type suffix from the test case name.
if (const char* type_param = test_info->type_param()) {
VLOG(1) << "type_param: " << type_param;
size_t last_slash = test_case_name.rfind('/');
test_case_name = test_case_name.substr(0, last_slash);
VLOG(1) << "test_case_name: " << test_case_name;
}
// Remove the test instantiation name if it is present.
auto first_slash = test_case_name.find('/');
if (first_slash != test_case_name.npos) {
test_case_name.remove_prefix(first_slash + 1);
VLOG(1) << "test_case_name: " << test_case_name;
}
ManifestT manifest = ReadManifest();
// If the test name ends with a slash followed by one or more characters,
// strip that off.
auto last_slash = test_name.rfind('/');
if (last_slash != test_name.npos) {
test_name = test_name.substr(0, last_slash);
VLOG(1) << "test_name: " << test_name;
}
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
return;
}
}
// Expect a full match vs. one of the platform regexps to disable the test.
const std::vector<std::string>& disabled_platforms = it->second;
auto platform_string = kTestPlatform;
for (const auto& s : disabled_platforms) {
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
GTEST_SKIP();
return;
}
}
// We didn't hit in the disabled manifest entries, so don't disable it.
}
} // namespace xla
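For reference, a minimal sketch of the manifest format consumed by ReadManifest above (hypothetical test names and platform regexps):

  // Blank lines and "//" comments are ignored.
  ConvertTest.ConvertR1S8ToR1S8 gpu
  DotOperationTest cpu interpreter

The first token is either "TestCase.TestName" or just a test case name (which disables every test in that case); the remaining tokens are platform regexps, and SetUp() calls GTEST_SKIP() when the current platform fully matches any of them.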

View File

@ -0,0 +1,35 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
#include "tensorflow/core/platform/test.h"
namespace xla {
// This class allows us to intercept the test name and use an arbitrary
// heuristic to decide whether a test case should be disabled. The decision is
// made by resolving the (test case name, test name) pair in a manifest file.
class ManifestCheckingTest : public ::testing::Test {
protected:
// This method runs before each test runs.
void SetUp() override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_

View File

@ -15,93 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_macros.h"
#include <fstream>
#include <streambuf>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla { namespace xla {
namespace {
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is static bool InitModule() {
// disabled - a sequence of regexps. kDisabledManifestPath = XLA_DISABLED_MANIFEST;
using ManifestT = absl::flat_hash_map<string, std::vector<string>>; VLOG(1) << "kDisabledManifestPath: " << kDisabledManifestPath;
kTestPlatform = XLA_PLATFORM;
ManifestT ReadManifest() { VLOG(1) << "kTestPlatform: " << kTestPlatform;
ManifestT manifest; return false;
string path = XLA_DISABLED_MANIFEST;
if (path.empty()) {
return manifest;
}
std::ifstream file_stream(path);
// Note: parens are required to disambiguate vs function decl.
string contents((std::istreambuf_iterator<char>(file_stream)),
std::istreambuf_iterator<char>());
std::vector<string> lines = absl::StrSplit(contents, '\n');
for (string& line : lines) {
auto comment = line.find("//");
if (comment != string::npos) {
line = line.substr(0, comment);
}
if (line.empty()) {
continue;
}
absl::StripTrailingAsciiWhitespace(&line);
std::vector<string> pieces = absl::StrSplit(line, ' ');
CHECK_GE(pieces.size(), 1);
auto& platforms = manifest[pieces[0]];
for (int64 i = 1; i < pieces.size(); ++i) {
platforms.push_back(pieces[i]);
}
}
return manifest;
} }
} // namespace static bool module_initialized = InitModule();
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
absl::string_view test_name) {
ManifestT manifest = ReadManifest();
// If the test name ends with a slash followed by one or more digits, strip
// that off; this is just a shard number, and matching on this would be
// unstable even if someone wanted to do it.
static LazyRE2 shard_num_pattern = {R"(/\d+$)"};
absl::string_view suffix;
if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) {
test_name.remove_suffix(suffix.size());
}
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
return std::string(test_name);
}
}
// Expect a full match vs. one of the platform regexps to disable the test.
const std::vector<string>& disabled_platforms = it->second;
string platform_string = XLA_PLATFORM;
for (const auto& s : disabled_platforms) {
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
return absl::StrCat("DISABLED_", test_name);
}
}
// We didn't hit in the disabled manifest entries, so don't disable it.
return std::string(test_name);
}
} // namespace xla } // namespace xla

View File

@ -28,12 +28,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ #ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ #define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"
#define DISABLED_ON_CPU(X) X #define DISABLED_ON_CPU(X) X
#define DISABLED_ON_GPU(X) X #define DISABLED_ON_GPU(X) X
#define DISABLED_ON_GPU_ROCM(X) X #define DISABLED_ON_GPU_ROCM(X) X
@ -79,117 +73,15 @@ limitations under the License.
namespace xla { namespace xla {
// Reads a disabled manifest file to resolve whether test cases should be inline const char *kDisabledManifestPath = nullptr;
// disabled on a particular platform. For a test that should be disabled, inline const char *kTestPlatform = nullptr;
// returns DISABLED_ prepended to its name; otherwise returns the test name
// unmodified.
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
absl::string_view test_name);
} // namespace xla } // namespace xla
// This is the internal "gtest" class instantiation -- it is identical to the #define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
// GTEST_TEST_ macro, except that we intercept the test name for potential
// modification by PrependDisabledIfIndicated. That file can use an arbitrary
// heuristic to decide whether the test case should be disabled, and we
// determine whether the test case should be disabled by resolving the (test
// case name, test name) in a manifest file.
#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
: public parent_class { \
public: \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
\
private: \
virtual void TestBody(); \
static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)); \
}; \
\
::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)::test_info_ = \
::testing::RegisterTest( \
#test_case_name, \
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
.c_str(), \
nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \
return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \
}); \
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
// This is identical to the TEST_F macro from "gtest", but it potentially #define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
// disables the test based on an external manifest file, DISABLED_MANIFEST.
//
// Per usual, you can see what tests are available via --gunit_list_tests and
// choose to run tests that have been disabled via the manifest via
// --gunit_also_run_disabled_tests.
#define XLA_TEST_F(test_fixture, test_name) \
XLA_GTEST_TEST_(test_fixture, test_name, test_fixture)
// Likewise, this is identical to the TEST_P macro from "gtest", but #define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName)
// potentially disables the test based on the DISABLED_MANIFEST file.
//
// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will
// be properly expanded before the stringification occurs.
#define XLA_TEST_P_IMPL_(test_case_name, test_name) \
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
: public test_case_name { \
public: \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
virtual void TestBody(); \
\
private: \
static int AddToRegistry() { \
::testing::UnitTest::GetInstance() \
->parameterized_test_registry() \
.GetTestCasePatternHolder<test_case_name>( \
#test_case_name, \
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
->AddTestPattern( \
#test_case_name, \
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
.c_str(), \
new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
test_case_name, test_name)>()); \
return 0; \
} \
static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)); \
}; \
int GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)::gtest_registering_dummy_ = \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
#define XLA_TEST_P(test_case_name, test_name) \
XLA_TEST_P_IMPL_(test_case_name, test_name)
// This is identical to the TEST_F macro from "gtest", but it potentially
// disables the test based on an external manifest file, DISABLED_MANIFEST.
#define XLA_TYPED_TEST(CaseName, TestName) \
template <typename gtest_TypeParam_> \
class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
: public CaseName<gtest_TypeParam_> { \
private: \
typedef CaseName<gtest_TypeParam_> TestFixture; \
typedef gtest_TypeParam_ TypeParam; \
virtual void TestBody(); \
}; \
bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
::testing::internal::TypeParameterizedTest< \
CaseName, \
::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
TestName)>, \
GTEST_TYPE_PARAMS_(CaseName)>:: \
Register( \
"", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
#CaseName, \
::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
0); \
template <typename gtest_TypeParam_> \
void GTEST_TEST_CLASS_NAME_(CaseName, \
TestName)<gtest_TypeParam_>::TestBody()
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
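Since XLA_TEST_F now expands to a plain TEST_F and skipping happens at runtime in ManifestCheckingTest::SetUp(), a minimal usage sketch looks like this (hypothetical test, assuming HloTestBase as the fixture):

  class MyHloTest : public HloTestBase {};  // HloTestBase derives from ManifestCheckingTest.

  XLA_TEST_F(MyHloTest, DoesSomething) {  // same as TEST_F(MyHloTest, DoesSomething)
    // ...
  }

A manifest entry such as "MyHloTest.DoesSomething gpu" then skips the test on matching platforms instead of registering it under a DISABLED_ name.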

View File

@ -719,6 +719,7 @@ tf_cuda_library(
visibility = [ visibility = [
"//tensorflow/core:__pkg__", "//tensorflow/core:__pkg__",
"//tensorflow/core/util:__pkg__", "//tensorflow/core/util:__pkg__",
"//tensorflow/security/fuzzing:__subpackages__",
], ],
deps = [ deps = [
":allocation_description_proto_cc", ":allocation_description_proto_cc",

View File

@ -153,16 +153,9 @@ limitations under the License.
#endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines #endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines
// Defines for sets of types. // Defines for sets of types.
// TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES.
//
// The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and
// thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in
// TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the
// TF binary size and performance.
#define TF_CALL_INTEGRAL_TYPES(m) \ #define TF_CALL_INTEGRAL_TYPES(m) \
TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \ TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_int32(m) \
TF_CALL_uint8(m) TF_CALL_int8(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_FLOAT_TYPES(m) \ #define TF_CALL_FLOAT_TYPES(m) \
TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
@ -176,8 +169,8 @@ limitations under the License.
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ #define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \ TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \
TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \ TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \
TF_CALL_int8(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m) #define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
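For illustration only (REGISTER_MY_KERNEL is a hypothetical macro, not part of this diff), the practical effect of the widened list is that every expansion site now also stamps out uint32 and uint64 variants:

#define REGISTER_MY_KERNEL(T) /* e.g. a REGISTER_KERNEL_BUILDER(...) or a CASE(T) arm */
TF_CALL_INTEGRAL_TYPES(REGISTER_MY_KERNEL)
// With this change the line above expands to:
//   REGISTER_MY_KERNEL(uint64) REGISTER_MY_KERNEL(int64)
//   REGISTER_MY_KERNEL(uint32) REGISTER_MY_KERNEL(int32)
//   REGISTER_MY_KERNEL(uint16) REGISTER_MY_KERNEL(int16)
//   REGISTER_MY_KERNEL(uint8)  REGISTER_MY_KERNEL(int8)
#undef REGISTER_MY_KERNEL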

View File

@ -238,11 +238,6 @@ int DataTypeSize(DataType dt) {
TF_CALL_qint16(CASE); TF_CALL_qint16(CASE);
TF_CALL_quint16(CASE); TF_CALL_quint16(CASE);
// uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
// don't want to define kernels for them at this stage to avoid binary
// bloat.
TF_CALL_uint32(CASE);
TF_CALL_uint64(CASE);
default: default:
return 0; return 0;
} }

View File

@ -837,6 +837,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
"ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle", "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
"Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson", "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
"RandomPoissonV2", "RandomPoissonV2",
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
// ReadVariableOp marked as stateful because it consumes DT_RESOURCE, // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
// but it can't generate any observable side-effect. // but it can't generate any observable side-effect.
@ -850,12 +851,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
// the same device_ordinal on the same host. // the same device_ordinal on the same host.
"EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch", "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
"EnqueueTPUEmbeddingSparseTensorBatch", "EnqueueTPUEmbeddingSparseTensorBatch",
"EnqueueTPUEmbeddingRaggedTensorBatch", "EnqueueTPUEmbeddingRaggedTensorBatch"});
// SaveV2 and RestoreV2 should be allowed to operate in parallel on
// multiple hosts.
"SaveV2", "RestoreV2"});
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
return exemption->contains(op); return exemption->contains(op);
} }

View File

@ -4168,6 +4168,25 @@ tf_kernel_library(
]), ]),
) )
tf_cuda_cc_test(
name = "mlir_generated_op_gpu_tanh_test",
size = "small",
srcs = if_mlir_generated_gpu_kernels_enabled(["mlir_generated_op_gpu_tanh_test.cc"]),
tags = tf_cuda_tests_tags() + ["no_rocm"],
deps = [
":cwise_op",
":ops_testutil",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/common_runtime:device",
"//tensorflow/core/common_runtime:device_factory",
],
)
tf_kernel_library( tf_kernel_library(
name = "nextafter_op", name = "nextafter_op",
prefix = "nextafter_op", prefix = "nextafter_op",
@ -4900,7 +4919,9 @@ tf_kernel_library(
"topk_op_gpu_double.cu.cc", "topk_op_gpu_double.cu.cc",
"topk_op_gpu_float.cu.cc", "topk_op_gpu_float.cu.cc",
"topk_op_gpu_half.cu.cc", "topk_op_gpu_half.cu.cc",
"topk_op_gpu_uint64.cu.cc",
"topk_op_gpu_int64.cu.cc", "topk_op_gpu_int64.cu.cc",
"topk_op_gpu_uint32.cu.cc",
"topk_op_gpu_int32.cu.cc", "topk_op_gpu_int32.cu.cc",
"topk_op_gpu_int16.cu.cc", "topk_op_gpu_int16.cu.cc",
"topk_op_gpu_uint16.cu.cc", "topk_op_gpu_uint16.cu.cc",
@ -6802,7 +6823,8 @@ filegroup(
"cwise_op_minimum.cc", "cwise_op_minimum.cc",
"cwise_op_mul_1.cc", "cwise_op_mul_1.cc",
"cwise_op_mul_2.cc", "cwise_op_mul_2.cc",
"cwise_op_neg.cc", "cwise_op_neg_1.cc",
"cwise_op_neg_2.cc",
"cwise_op_pow.cc", "cwise_op_pow.cc",
"cwise_op_real.cc", "cwise_op_real.cc",
"cwise_op_reciprocal.cc", "cwise_op_reciprocal.cc",
@ -8780,7 +8802,8 @@ exports_files([
"cwise_op_mod.cc", "cwise_op_mod.cc",
"cwise_op_mul_1.cc", "cwise_op_mul_1.cc",
"cwise_op_mul_2.cc", "cwise_op_mul_2.cc",
"cwise_op_neg.cc", "cwise_op_neg_1.cc",
"cwise_op_neg_2.cc",
"cwise_op_not_equal_to_1.cc", "cwise_op_not_equal_to_1.cc",
"cwise_op_not_equal_to_2.cc", "cwise_op_not_equal_to_2.cc",
"cwise_op_round.cc", "cwise_op_round.cc",

View File

@ -116,8 +116,6 @@ REGISTER(qint8)
REGISTER(quint16) REGISTER(quint16)
REGISTER(qint16) REGISTER(qint16)
REGISTER(qint32) REGISTER(qint32)
REGISTER(uint32)
REGISTER(uint64)
#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \ #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
!defined(__ANDROID_TYPES_FULL__) !defined(__ANDROID_TYPES_FULL__)

View File

@ -208,8 +208,6 @@ REGISTER_CONCAT(qint8);
REGISTER_CONCAT(quint16); REGISTER_CONCAT(quint16);
REGISTER_CONCAT(qint16); REGISTER_CONCAT(qint16);
REGISTER_CONCAT(qint32); REGISTER_CONCAT(qint32);
REGISTER_CONCAT(uint32);
REGISTER_CONCAT(uint64);
#undef REGISTER_CONCAT #undef REGISTER_CONCAT

View File

@ -211,7 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
// the conversion from uint8 to quint8. // the conversion from uint8 to quint8.
REGISTER_KERNEL(CPU, quint8); REGISTER_KERNEL(CPU, quint8);
REGISTER_KERNEL(CPU, quint16); REGISTER_KERNEL(CPU, quint16);
REGISTER_KERNEL(CPU, uint32);
#undef REGISTER_CPU_KERNEL #undef REGISTER_CPU_KERNEL
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL

View File

@ -101,27 +101,21 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH); TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);
REGISTER_CPU_SWITCH(uint64);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH); TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH); TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_SWITCH(uint64);
TF_CALL_variant(REGISTER_GPU_SWITCH); TF_CALL_variant(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
TF_CALL_bool(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
#undef REGISTER_CPU_SWITCH #undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH #undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH #undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH #undef REGISTER_GPU_REF_SWITCH
// Special GPU kernels for int32, string & resource handles. Requiring all // Special GPU kernels for int32 and string.
// inputs and outputs to be in host memory. // TODO(b/25387198): Also enable int32 in device memory. This kernel
// TODO(b/25387198): Also enable int32 in device memory. // registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type) \ #define REGISTER_GPU_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \ REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_GPU) \ .Device(DEVICE_GPU) \
@ -151,6 +145,8 @@ TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_HOST_KERNEL(int32); REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32); REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_REF_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(tstring); REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring); REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle); REGISTER_GPU_HOST_KERNEL(ResourceHandle);
@ -311,7 +307,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL); TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool); REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool); REGISTER_GPU_REF_KERNEL(bool);
REGISTER_GPU_KERNEL(uint64);
TF_CALL_variant(REGISTER_GPU_KERNEL); TF_CALL_variant(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL #undef REGISTER_GPU_KERNEL
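A hedged sketch of the registration that the newly added REGISTER_GPU_HOST_KERNEL(bool) line produces. Only the first lines of the macro body are visible in this hunk, so the HostMemory attribute names below are assumptions based on the Switch op's inputs and outputs, not confirmed by the diff:

REGISTER_KERNEL_BUILDER(Name("Switch")
                            .Device(DEVICE_GPU)
                            .HostMemory("data")           // assumed attribute names
                            .HostMemory("pred")
                            .HostMemory("output_false")
                            .HostMemory("output_true")
                            .TypeConstraint<bool>("T"),
                        SwitchOp);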

View File

@ -19,8 +19,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace functor { namespace functor {
DEFINE_UNARY7(neg, Eigen::half, float, double, int32, int64, complex64, DEFINE_UNARY4(neg, int8, int16, int32, int64);
complex128); DEFINE_UNARY6(neg, Eigen::half, float, double, bfloat16, complex64, complex128);
} // namespace functor } // namespace functor
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,8 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h" #include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow { namespace tensorflow {
REGISTER8(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32, REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64);
complex64, int64, complex128, bfloat16);
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64); REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);
@ -30,8 +29,7 @@ REGISTER_KERNEL_BUILDER(Name("Neg")
#endif // TENSORFLOW_USE_SYCL #endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64, REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);
complex64, complex128);
// A special GPU kernel for int32. // A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel // TODO(b/25387198): Also enable int32 in device memory. This kernel

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER6(UnaryOp, CPU, "Neg", functor::neg, Eigen::half, float, double,
bfloat16, complex64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, Eigen::half, float, double,
bfloat16, complex64, complex128);
#endif
} // namespace tensorflow
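A hedged sketch of what one of the REGISTERn helpers above ultimately produces (the macros come from cwise_ops_common.h; the expansion shown is an approximation). Splitting the registrations across cwise_op_neg_1.cc (integer types) and cwise_op_neg_2.cc (floating-point and complex types) keeps each translation unit smaller, and the int8/int16 variants are newly registered on CPU and GPU:

// One of the four registrations generated by
// REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64):
REGISTER_KERNEL_BUILDER(
    Name("Neg").Device(DEVICE_CPU).TypeConstraint<int8>("T"),
    UnaryOp<CPUDevice, functor::neg<int8>>);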

View File

@ -220,8 +220,6 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
break; break;
TF_CALL_NUMBER_TYPES(CASE); TF_CALL_NUMBER_TYPES(CASE);
TF_CALL_tstring(CASE); TF_CALL_tstring(CASE);
TF_CALL_uint32(CASE);
TF_CALL_uint64(CASE);
// TODO(feihugis): figure out how to support variant tensors. // TODO(feihugis): figure out how to support variant tensors.
#undef CASE #undef CASE
default: default:

View File

@ -98,7 +98,6 @@ typedef Eigen::SyclDevice SYCLDevice;
TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_ALL_TYPES(REGISTER_KERNELS);
// uint32 not included in ALL_TYPES // uint32 not included in ALL_TYPES
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
// quint16 not included in QUANTIZED_TYPES // quint16 not included in QUANTIZED_TYPES
TF_CALL_quint16(REGISTER_KERNELS); TF_CALL_quint16(REGISTER_KERNELS);

View File

@ -164,8 +164,6 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared {
DynamicPartitionOp<T>) DynamicPartitionOp<T>)
TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION); TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
// For partitioning fingerprints.
TF_CALL_uint64(REGISTER_DYNAMIC_PARTITION);
#undef REGISTER_DYNAMIC_PARTITION #undef REGISTER_DYNAMIC_PARTITION
} // namespace tensorflow } // namespace tensorflow

View File

@ -45,6 +45,8 @@ DEFINE_SETZERO_CPU(Eigen::half);
DEFINE_SETZERO_CPU(bfloat16); DEFINE_SETZERO_CPU(bfloat16);
DEFINE_SETZERO_CPU(float); DEFINE_SETZERO_CPU(float);
DEFINE_SETZERO_CPU(double); DEFINE_SETZERO_CPU(double);
DEFINE_SETZERO_CPU(uint32);
DEFINE_SETZERO_CPU(uint64);
DEFINE_SETZERO_CPU(uint8); DEFINE_SETZERO_CPU(uint8);
DEFINE_SETZERO_CPU(int8); DEFINE_SETZERO_CPU(int8);
DEFINE_SETZERO_CPU(uint16); DEFINE_SETZERO_CPU(uint16);
@ -96,6 +98,8 @@ DEFINE_SETONE_CPU(Eigen::half);
DEFINE_SETONE_CPU(bfloat16); DEFINE_SETONE_CPU(bfloat16);
DEFINE_SETONE_CPU(float); DEFINE_SETONE_CPU(float);
DEFINE_SETONE_CPU(double); DEFINE_SETONE_CPU(double);
DEFINE_SETONE_CPU(uint32);
DEFINE_SETONE_CPU(uint64);
DEFINE_SETONE_CPU(uint8); DEFINE_SETONE_CPU(uint8);
DEFINE_SETONE_CPU(int8); DEFINE_SETONE_CPU(int8);
DEFINE_SETONE_CPU(uint16); DEFINE_SETONE_CPU(uint16);
@ -137,7 +141,6 @@ struct FillFunctor<Eigen::ThreadPoolDevice, T> {
TF_CALL_ALL_TYPES(DEFINE_FILL_CPU); TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
DEFINE_FILL_CPU(quint8); DEFINE_FILL_CPU(quint8);
DEFINE_FILL_CPU(quint16); DEFINE_FILL_CPU(quint16);
DEFINE_FILL_CPU(uint32);
#undef DEFINE_FILL_CPU #undef DEFINE_FILL_CPU
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL

View File

@ -211,8 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU); TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_quint16(REGISTER_GATHER_CPU); TF_CALL_quint16(REGISTER_GATHER_CPU);
TF_CALL_qint16(REGISTER_GATHER_CPU); TF_CALL_qint16(REGISTER_GATHER_CPU);
TF_CALL_uint32(REGISTER_GATHER_CPU);
TF_CALL_uint64(REGISTER_GATHER_CPU);
#undef REGISTER_GATHER_CPU #undef REGISTER_GATHER_CPU

View File

@ -122,7 +122,6 @@ REGISTER_SYCL_HOST_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(Variant); REGISTER_GPU_KERNEL(Variant);
TF_CALL_uint32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL #undef REGISTER_GPU_KERNEL

View File

@ -178,6 +178,9 @@ class MklAddNOp : public OpKernel {
dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format); dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
} }
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));
// Create memory descriptor for MKL-DNN. // Create memory descriptor for MKL-DNN.
// If all input in Tensorflow format, create block memory descriptor, // If all input in Tensorflow format, create block memory descriptor,
// else convert TF format to MKL memory descriptor // else convert TF format to MKL memory descriptor
@ -215,6 +218,7 @@ class MklAddNOp : public OpKernel {
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine)); srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
#endif #endif
src.SetUsrMem(md, &src_tensor); src.SetUsrMem(md, &src_tensor);
src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
inputs.push_back(src.GetOpMem()); inputs.push_back(src.GetOpMem());
} }
@ -240,11 +244,10 @@ class MklAddNOp : public OpKernel {
} }
AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape, AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
output_mkl_shape); output_mkl_shape);
dst.SetUsrMemDataHandle(dst_tensor); dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
// Create Sum op, and submit net for execution. // Create Sum op, and submit net for execution.
std::vector<primitive> net; std::vector<primitive> net;
stream* fwd_cpu_stream = CreateStream(ctx, cpu_engine);
#ifdef ENABLE_MKLDNN_V1 #ifdef ENABLE_MKLDNN_V1
mkldnn::sum sum_op(sum_pd); mkldnn::sum sum_op(sum_pd);
std::unordered_map<int, memory> net_args = { std::unordered_map<int, memory> net_args = {
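A minimal sketch of the stream-aware pattern this hunk adopts (names as they appear in the hunk; an OpKernelContext* ctx and a cpu_engine are assumed to be in scope): the stream is now created before any user-memory handles are bound, and both the source and destination memories are bound against it instead of the old one-argument SetUsrMemDataHandle call:

std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));       // created up front

src.SetUsrMem(md, &src_tensor);
src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);       // stream-aware overload

dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);        // replaces dst.SetUsrMemDataHandle(dst_tensor)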

View File

@ -281,11 +281,19 @@ class MklConcatFwdPrimitive : public MklPrimitive {
std::shared_ptr<stream> fwd_stream) { std::shared_ptr<stream> fwd_stream) {
DCHECK_EQ(in_data.size(), context_.data_mem.size()); DCHECK_EQ(in_data.size(), context_.data_mem.size());
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) { for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.data_mem_shdptr[i]->set_data_handle(
static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
}
context_.dst_mem->set_data_handle(
static_cast<void*>(dst_data.get_data_handle()), *fwd_stream);
#else
context_.data_mem_shdptr[i]->set_data_handle( context_.data_mem_shdptr[i]->set_data_handle(
static_cast<void*>(in_data[i].get_data_handle())); static_cast<void*>(in_data[i].get_data_handle()));
} }
context_.dst_mem->set_data_handle( context_.dst_mem->set_data_handle(
static_cast<void*>(dst_data.get_data_handle())); static_cast<void*>(dst_data.get_data_handle()));
#endif // ENABLE_MKLDNN_THREADPOOL
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) { for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
context_.data_mem[i] = *context_.data_mem_shdptr[i]; context_.data_mem[i] = *context_.data_mem_shdptr[i];
@ -788,11 +796,13 @@ class MklConcatOp : public OpKernel {
dnn_shape_dst); dnn_shape_dst);
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL"; DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
if (dnn_shape_dst.IsMklTensor()) if (dnn_shape_dst.IsMklTensor())
dst_md = dnn_shape_dst.GetMklLayout(); dst_md = dnn_shape_dst.GetMklLayout();
dst.SetUsrMem(dst_md, dst_tensor); dst.SetUsrMem(dst_md, dst_tensor);
std::shared_ptr<stream> fwd_cpu_stream; dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
#ifdef ENABLE_MKLDNN_V1 #ifdef ENABLE_MKLDNN_V1
auto concat_op = concat(concat_pd); auto concat_op = concat(concat_pd);
std::unordered_map<int, memory> net_args = { std::unordered_map<int, memory> net_args = {
@ -830,9 +840,10 @@ class MklConcatOp : public OpKernel {
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout() dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
: dst_md; : dst_md;
dst.SetUsrMem(dst_md, dst_tensor);
std::shared_ptr<stream> fwd_cpu_stream; std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine())); fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
dst.SetUsrMem(dst_md, dst_tensor);
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
// Execute concat // Execute concat
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims, concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
fwd_cpu_stream); fwd_cpu_stream);
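Condensed, the data-handle pattern introduced under ENABLE_MKLDNN_THREADPOOL in the first hunk above is the following (mem stands for one of the std::shared_ptr<mkldnn::memory> members, buffer for the tensor's raw pointer, fwd_stream for the shared stream):

#ifdef ENABLE_MKLDNN_THREADPOOL
  mem->set_data_handle(static_cast<void*>(buffer), *fwd_stream);  // stream-aware overload
#else
  mem->set_data_handle(static_cast<void*>(buffer));
#endif  // ENABLE_MKLDNN_THREADPOOL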

View File

@ -75,6 +75,9 @@ class MklDequantizeOp : public OpKernel {
MklDnnData<T> src(&cpu_engine); MklDnnData<T> src(&cpu_engine);
MklDnnData<float> dst(&cpu_engine); MklDnnData<float> dst(&cpu_engine);
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine));
// If input is in MKL layout, then simply grab input layout; otherwise, // If input is in MKL layout, then simply grab input layout; otherwise,
// construct input TF layout. For TF layout, although input shape // construct input TF layout. For TF layout, although input shape
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
@ -85,6 +88,7 @@ class MklDequantizeOp : public OpKernel {
: memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc); : memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);
src.SetUsrMem(src_md, &src_tensor); src.SetUsrMem(src_md, &src_tensor);
src.SetUsrMemDataHandle(&src_tensor, reorder_stream);
Tensor* output_tensor = nullptr; Tensor* output_tensor = nullptr;
MklDnnShape output_mkl_shape; MklDnnShape output_mkl_shape;
@ -129,6 +133,7 @@ class MklDequantizeOp : public OpKernel {
AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape, AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
output_mkl_shape); output_mkl_shape);
dst.SetUsrMem(dst_md, output_tensor); dst.SetUsrMem(dst_md, output_tensor);
dst.SetUsrMemDataHandle(output_tensor, reorder_stream);
// The quantization logic here for mode SCALED is similar to the logic // The quantization logic here for mode SCALED is similar to the logic
// in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3. // in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
@ -155,8 +160,6 @@ class MklDequantizeOp : public OpKernel {
// Also it does not define round_nearest (enum). // Also it does not define round_nearest (enum).
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
#endif // !ENABLE_MKLDNN_V1 #endif // !ENABLE_MKLDNN_V1
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine));
std::vector<primitive> net; std::vector<primitive> net;
// Create reorder primitive and then execute. // Create reorder primitive and then execute.

View File

@ -137,6 +137,7 @@ class MklLRNOp : public OpKernel {
// that input is in NHWC layout with Channel being the last dimension. // that input is in NHWC layout with Channel being the last dimension.
src_dnn_data.SetUsrMem(src_md, &src_tensor); src_dnn_data.SetUsrMem(src_md, &src_tensor);
src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc); src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
src_dnn_data.SetUsrMemDataHandle(&src_tensor, fwd_stream_);
// dst_dnn_data has the same shape as input. // dst_dnn_data has the same shape as input.
dst_dnn_data.SetUsrMem(src_md); dst_dnn_data.SetUsrMem(src_md);
@ -157,7 +158,7 @@ class MklLRNOp : public OpKernel {
&output_tensor); &output_tensor);
OP_REQUIRES_OK(context, context->status()); OP_REQUIRES_OK(context, context->status());
DCHECK(output_tensor != nullptr); DCHECK(output_tensor != nullptr);
dst_dnn_data.SetUsrMemDataHandle(output_tensor); dst_dnn_data.SetUsrMemDataHandle(output_tensor, fwd_stream_);
// Handle workspace required for MKL-DNN. // Handle workspace required for MKL-DNN.
AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data); AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
@ -393,6 +394,7 @@ class MklLRNGradOp : public OpKernel {
orig_input_dnn_shape.GetSizesAsMklDnnDims(); orig_input_dnn_shape.GetSizesAsMklDnnDims();
orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor); orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc); orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
orig_input_dnn_data.SetUsrMemDataHandle(&orig_input_tensor, bwd_stream_);
// output_dnn_data has the same shape as original input // output_dnn_data has the same shape as original input
output_dnn_data.SetUsrMem(orig_input_md); output_dnn_data.SetUsrMem(orig_input_md);
@ -421,7 +423,7 @@ class MklLRNGradOp : public OpKernel {
orig_input_format, &output_tensor); orig_input_format, &output_tensor);
OP_REQUIRES_OK(context, context->status()); OP_REQUIRES_OK(context, context->status());
DCHECK(output_tensor != nullptr); DCHECK(output_tensor != nullptr);
output_dnn_data.SetUsrMemDataHandle(output_tensor); output_dnn_data.SetUsrMemDataHandle(output_tensor, bwd_stream_);
// Create LRN primitive and add it to the net // Create LRN primitive and add it to the net
// At this point, workspace is enabled, so we don't need // At this point, workspace is enabled, so we don't need

View File

@ -137,6 +137,7 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
memory::dims out_strides = memory::dims out_strides =
ReorderStrides(CalculateTFStrides(out_dims), perm); ReorderStrides(CalculateTFStrides(out_dims), perm);
std::shared_ptr<stream> transpose_stream;
in.SetUsrMem(in_dims, in_strides, &in_tensor); in.SetUsrMem(in_dims, in_strides, &in_tensor);
// Output dimensions are same as input dimensions. We adjust the layout // Output dimensions are same as input dimensions. We adjust the layout
// using strides. // using strides.
@ -144,16 +145,16 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
std::vector<primitive> net; std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1 #ifdef ENABLE_MKLDNN_V1
std::shared_ptr<stream> transpose_stream;
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()); auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
transpose_stream.reset(CreateStream(context, prim->GetEngine())); transpose_stream.reset(CreateStream(context, prim->GetEngine()));
in.SetUsrMemDataHandle(&in_tensor, transpose_stream);
out.SetUsrMemDataHandle(out_tensor, transpose_stream);
net.push_back(*(prim->GetPrimitive())); net.push_back(*(prim->GetPrimitive()));
std::vector<MemoryArgsMap> net_args; std::vector<MemoryArgsMap> net_args;
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()}, net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
{MKLDNN_ARG_TO, *out.GetUsrMem()}}); {MKLDNN_ARG_TO, *out.GetUsrMem()}});
execute_primitives(net, transpose_stream, net_args); execute_primitives(net, transpose_stream, net_args);
#else #else
std::shared_ptr<stream> transpose_stream;
transpose_stream.reset(new CPU_STREAM(cpu_engine)); transpose_stream.reset(new CPU_STREAM(cpu_engine));
net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem())); net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()));
transpose_stream->submit(net).wait(); transpose_stream->submit(net).wait();

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class MlirGeneratedOpGpuTanhTest : public OpsTestBase {
protected:
void SetUp() override {
std::unique_ptr<tensorflow::Device> device_gpu(
tensorflow::DeviceFactory::NewDevice("GPU", {},
"/job:a/replica:0/task:0"));
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
}
template <typename T, typename RT = T>
void RunTanhOp(std::initializer_list<T> input) {
TensorShape shape({2, 7});
TF_ASSERT_OK(NodeDefBuilder("tanh_op", "Tanh")
.Input(FakeInput(DataTypeToEnum<T>::v()))
.Attr("T", DataTypeToEnum<T>::v())
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<T>(shape, input);
TF_ASSERT_OK(RunOpKernel());
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
std::vector<T> expected;
expected.reserve(input.size());
for (const T& inp : input) {
expected.push_back(static_cast<T>(std::tanh(static_cast<RT>(inp))));
}
test::FillValues<T>(&expected_tensor, expected);
test::ExpectClose(expected_tensor, *GetOutput(0));
}
};
TEST_F(MlirGeneratedOpGpuTanhTest, TanhFloat) {
RunTanhOp<float>({-18.0f, -9.0f, -1e-6f, -0.0f, 0.0f, 1e-6, 0.1f, 0.2f, 0.3f,
0.5f, 0.7f, 0.9f, 9.0f, 18.0f});
}
TEST_F(MlirGeneratedOpGpuTanhTest, TanhDouble) {
RunTanhOp<double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5,
0.7, 0.9, 9.0, 18.0});
}
TEST_F(MlirGeneratedOpGpuTanhTest, TanhHalf) {
RunTanhOp<Eigen::half, float>(
{static_cast<Eigen::half>(-18.0), static_cast<Eigen::half>(-9.0),
static_cast<Eigen::half>(-1e-6), static_cast<Eigen::half>(-0.0),
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(1e-6),
static_cast<Eigen::half>(0.1), static_cast<Eigen::half>(0.2),
static_cast<Eigen::half>(0.3), static_cast<Eigen::half>(0.5),
static_cast<Eigen::half>(0.7), static_cast<Eigen::half>(0.9),
static_cast<Eigen::half>(9.0), static_cast<Eigen::half>(18.0)});
}
} // namespace
} // end namespace tensorflow

View File

@ -296,8 +296,6 @@ TF_CALL_tstring(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL); TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL); TF_CALL_qint16(REGISTER_CPU_KERNEL);
TF_CALL_uint32(REGISTER_CPU_KERNEL);
TF_CALL_uint64(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL #undef REGISTER_CPU_KERNEL
#undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE #undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE

View File

@ -308,8 +308,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS); TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS); TF_CALL_qint16(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_uint64(REGISTER_KERNELS);
#undef REGISTER_KERNELS #undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE #undef REGISTER_KERNELS_WITH_SPLIT_TYPE
} // namespace tensorflow } // namespace tensorflow

View File

@ -561,8 +561,6 @@ TF_CALL_string(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL); TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL); TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL); TF_CALL_qint16(REGISTER_CPU_KERNEL);
TF_CALL_uint32(REGISTER_CPU_KERNEL);
TF_CALL_uint64(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL #undef REGISTER_CPU_KERNEL

View File

@ -213,8 +213,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS); TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS); TF_CALL_qint16(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_uint64(REGISTER_KERNELS);
#undef REGISTER_KERNELS #undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE #undef REGISTER_KERNELS_WITH_SPLIT_TYPE
} // namespace tensorflow } // namespace tensorflow

View File

@ -35,6 +35,7 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;
static constexpr int VectorSizeElements = 8;
namespace functor { namespace functor {
// This kernel computes ReluGrad by processing one half2, two fp16, at a time. // This kernel computes ReluGrad by processing one half2, two fp16, at a time.
@ -93,6 +94,66 @@ __global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
} }
} }
__global__ void ReluGradHalfKernelVector(
const Eigen::half* __restrict__ gradient,
const Eigen::half* __restrict__ feature, Eigen::half* __restrict__ backprop,
int32 count) {
int32 half8_count = count / VectorSizeElements;
int32 index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < half8_count) {
// Cast to xx_h8 for vector load and store.
float4 gradient_h8 = reinterpret_cast<const float4*>(gradient)[index];
float4 feature_h8 = reinterpret_cast<const float4*>(feature)[index];
float4* p_backprop_h8 = reinterpret_cast<float4*>(backprop) + index;
half2* gradient_h2 = reinterpret_cast<half2*>(&gradient_h8);
half2* feature_h2 = reinterpret_cast<half2*>(&feature_h8);
float4 backprop_h8;
half2* p_backprop_h2 = reinterpret_cast<half2*>(&backprop_h8);
// Fast path, when half2 primitives are available.
#if __CUDA_ARCH__ >= 530
const half2 kZeroH2 = __float2half2_rn(0.f);
#endif
for (int i = 0; i < VectorSizeElements / 2; i++) {
#if __CUDA_ARCH__ >= 530
// mask = (feature > 0)
half2 mask_h2 = __hgt2(feature_h2[i], kZeroH2);
// backprop = mask * gradient
half2 backprop_h2 = __hmul2(mask_h2, gradient_h2[i]);
#else
// Fall back: convert half2 to float2 for processing.
float2 feature_f2 = __half22float2(feature_h2[i]);
float2 gradient_f2 = __half22float2(gradient_h2[i]);
float2 backprop_f2 =
make_float2((feature_f2.x > 0.0f) ? float(gradient_f2.x) : 0.0f,
(feature_f2.y > 0.0f) ? float(gradient_f2.y) : 0.0f);
// Convert back to half2.
half2 backprop_h2 = __float22half2_rn(backprop_f2);
#endif
p_backprop_h2[i] = backprop_h2;
}
// Write back the result.
*p_backprop_h8 = backprop_h8;
}
int remaining_count = (count % VectorSizeElements);
if (index < remaining_count) {
// Use the first remaining_count threads to process the leftover tail elements.
Eigen::half grad_h = gradient[half8_count * VectorSizeElements + index];
Eigen::half feature_h = feature[half8_count * VectorSizeElements + index];
float grad_f = static_cast<float>(grad_h);
float feature_f = static_cast<float>(feature_h);
float backprop_f = (feature_f > 0) ? grad_f : 0;
Eigen::half backprop_h(backprop_f);
backprop[half8_count * VectorSizeElements + index] = backprop_h;
}
}
template <typename Device> template <typename Device>
struct ReluGrad<Device, Eigen::half> { struct ReluGrad<Device, Eigen::half> {
// Computes ReluGrad backprop. // Computes ReluGrad backprop.
@ -108,16 +169,29 @@ struct ReluGrad<Device, Eigen::half> {
// NOTE: When the activation is exactly zero, we do not propagate the // NOTE: When the activation is exactly zero, we do not propagate the
// associated gradient value. This allows the output of the Relu to be used, // associated gradient value. This allows the output of the Relu to be used,
// as well as its input. // as well as its input.
auto gradient_ptr = reinterpret_cast<uintptr_t>(gradient.data());
auto feature_ptr = reinterpret_cast<uintptr_t>(feature.data());
auto backprop_ptr = reinterpret_cast<uintptr_t>(backprop.data());
bool aligned = gradient_ptr % 16 == 0 && feature_ptr % 16 == 0 &&
backprop_ptr % 16 == 0;
int32 count = gradient.size(); int32 count = gradient.size();
if (count == 0) return;
int32 half2_count = Eigen::divup(count, 2);
constexpr int32 kThreadInBlock = 512; constexpr int32 kThreadInBlock = 512;
if (count == 0) return;
if (aligned) {
int32 half8_count = Eigen::divup(count, VectorSizeElements);
int32 kBlock = Eigen::divup(half8_count, kThreadInBlock);
TF_CHECK_OK(GpuLaunchKernel(
ReluGradHalfKernelVector, kBlock, kThreadInBlock, 0, d.stream(),
gradient.data(), feature.data(), backprop.data(), count));
} else {
int32 half2_count = Eigen::divup(count, 2);
GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize( GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock); half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
TF_CHECK_OK(GpuLaunchKernel( TF_CHECK_OK(GpuLaunchKernel(
ReluGradHalfKernel, config.block_count, config.thread_per_block, 0, ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
d.stream(), gradient.data(), feature.data(), backprop.data(), count)); d.stream(), gradient.data(), feature.data(), backprop.data(), count));
} }
}
}; };
__global__ void Relu_int8x4_kernel(int vect_count, __global__ void Relu_int8x4_kernel(int vect_count,
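A small worked example of the dispatch arithmetic above, with illustrative numbers: for count = 20 Eigen::half elements and VectorSizeElements = 8, the vector path covers count / 8 = 2 float4 loads (16 elements) and the scalar tail covers count % 8 = 4 elements at offsets 16..19; the launch rounds up with Eigen::divup(20, 8) = 3 when sizing the grid, and the 512-thread block provides enough threads for both the 2 vector indices and the 4 tail indices. The 16-byte alignment check corresponds to one float4 per load/store:

// Illustrative only: one float4 vector load moves 8 Eigen::half values,
// which is why 16-byte alignment of all three pointers is required.
static_assert(sizeof(float4) == 8 * sizeof(Eigen::half),
              "a float4 load/store covers 8 half elements");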

View File

@ -512,7 +512,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
#undef REGISTER_KERNELS #undef REGISTER_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -43,7 +43,6 @@ void Split<Eigen::ThreadPoolDevice, T, NDims>::operator()(
TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS) TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
DEFINE_CPU_KERNELS(quint8) DEFINE_CPU_KERNELS(quint8)
DEFINE_CPU_KERNELS(uint64)
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
template <typename T, int NDims> template <typename T, int NDims>

View File

@ -404,7 +404,6 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
TF_CALL_ALL_TYPES(REGISTER_SPLIT); TF_CALL_ALL_TYPES(REGISTER_SPLIT);
REGISTER_SPLIT(quint8); REGISTER_SPLIT(quint8);
REGISTER_SPLIT(uint64);
#undef REGISTER_SPLIT #undef REGISTER_SPLIT

View File

@ -440,8 +440,6 @@ class StridedSliceAssignOp : public OpKernel {
StridedSliceAssignOp<CPUDevice, type, true>) StridedSliceAssignOp<CPUDevice, type, true>)
TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE); TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
TF_CALL_uint32(REGISTER_STRIDED_SLICE);
TF_CALL_uint64(REGISTER_STRIDED_SLICE);
#undef REGISTER_STRIDED_SLICE #undef REGISTER_STRIDED_SLICE

View File

@ -287,8 +287,6 @@ TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
#endif // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU); TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
TF_CALL_uint32(DECLARE_FOR_N_CPU);
TF_CALL_uint64(DECLARE_FOR_N_CPU);
#ifdef TENSORFLOW_USE_SYCL #ifdef TENSORFLOW_USE_SYCL
#define PREVENT_FOR_N_SYCL(T) \ #define PREVENT_FOR_N_SYCL(T) \

View File

@ -52,7 +52,8 @@ class SummaryScalarOp : public OpKernel {
Summary s; Summary s;
for (int i = 0; i < Ttags.size(); i++) { for (int i = 0; i < Ttags.size(); i++) {
Summary::Value* v = s.add_value(); Summary::Value* v = s.add_value();
v->set_tag(string(Ttags(i))); // NOLINT const tstring& Ttags_i = Ttags(i);
v->set_tag(Ttags_i.data(), Ttags_i.size());
v->set_simple_value(float(Tvalues(i))); v->set_simple_value(float(Tvalues(i)));
} }
@ -102,7 +103,8 @@ class SummaryHistoOp : public OpKernel {
Summary s; Summary s;
Summary::Value* v = s.add_value(); Summary::Value* v = s.add_value();
v->set_tag(string(tags.scalar<tstring>()())); // NOLINT const tstring& tags0 = tags.scalar<tstring>()();
v->set_tag(tags0.data(), tags0.size());
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
Tensor* summary_tensor = nullptr; Tensor* summary_tensor = nullptr;
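For context (a hedged note, not part of the diff): protobuf string setters provide a (const char*, size_t) overload, so passing the tstring's data and size avoids constructing a temporary std::string, which is what the removed string(...) casts and their NOLINT suppressions were doing:

const tstring& tag = Ttags(i);         // no copy into std::string
v->set_tag(tag.data(), tag.size());    // uses the (const char*, size_t) setter overload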

View File

@ -258,7 +258,6 @@ namespace functor {
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC); TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
TF_CALL_uint32(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC #undef DECLARE_GPU_SPEC
@ -276,7 +275,6 @@ TF_CALL_uint32(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS); TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS); TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS)
#undef REGISTER_KERNELS #undef REGISTER_KERNELS
#endif // end GOOGLE_CUDA #endif // end GOOGLE_CUDA

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/kernels/topk_op_gpu.h"
namespace tensorflow {
using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, uint32>;
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/kernels/topk_op_gpu.h"
namespace tensorflow {
using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, uint64>;
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -252,7 +252,8 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNELS); TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS); TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS #undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -78,3 +78,32 @@ op {
} }
} }
} }
op {
name: "Acos"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}

View File

@ -78,3 +78,32 @@ op {
} }
} }
} }
op {
name: "Asin"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}

View File

@ -78,3 +78,32 @@ op {
} }
} }
} }
op {
name: "Atan"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}

View File

@ -168,3 +168,32 @@ op {
} }
} }
} }
op {
name: "Inv"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}

Some files were not shown because too many files have changed in this diff.