Merge branch 'master' into master

Mihai Maruseac 2020-06-19 17:23:54 +00:00 committed by GitHub
commit fccd345ab3
251 changed files with 6055 additions and 2116 deletions

View File

@ -24,9 +24,21 @@ cc_library(
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":gcs_helper",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gcs_helper",
srcs = ["gcs_helper.cc"],
hdrs = ["gcs_helper.h"],
linkstatic = 1,
deps = [
"//tensorflow/c:env",
],
)

View File

@ -15,9 +15,13 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include <fstream>
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
@ -86,6 +90,20 @@ namespace tf_random_access_file {
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct GCSFile {
const char* bucket;
const char* object;
gcs::Client* gcs_client; // not owned
TempFile outfile;
bool sync_need;
} GCSFile;
static void Cleanup(TF_WritableFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
plugin_memory_free(const_cast<char*>(gcs_file->bucket));
plugin_memory_free(const_cast<char*>(gcs_file->object));
delete gcs_file;
}
// TODO(vnvo2409): Implement later
@ -119,6 +137,20 @@ static void Init(TF_Filesystem* filesystem, TF_Status* status) {
// TODO(vnvo2409): Implement later
static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
char* bucket;
char* object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
TempFile outfile(TF_GetTempFileName(""), std::ios::binary | std::ios::out);
file->plugin_file = new tf_writable_file::GCSFile(
{bucket, object, gcs_client, std::move(outfile), true});
TF_SetStatus(status, TF_OK, "");
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@ -126,9 +158,14 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include <stdio.h>
#include <fstream>
#include <string>
#include <utility>
TempFile::TempFile(const char* temp_file_name, std::ios::openmode mode)
: std::fstream(temp_file_name, mode), name_(temp_file_name) {}
TempFile::TempFile(TempFile&& rhs)
: std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}
TempFile::~TempFile() {
std::fstream::close();
std::remove(name_.c_str());
}
const std::string TempFile::getName() const { return name_; }

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#include <fstream>
#include <string>
class TempFile : public std::fstream {
public:
// An openmode must be passed explicitly every time a TempFile is constructed.
TempFile(const char* temp_file_name, std::ios::openmode mode);
TempFile(TempFile&& rhs);
~TempFile() override;
const std::string getName() const;
private:
const std::string name_;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
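
A minimal usage sketch of TempFile, assuming only what the header above declares (the path and payload below are placeholders, not part of this commit): the openmode is passed explicitly, the object behaves like a std::fstream, and the destructor closes the stream and removes the file from disk.

#include <ios>
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"

void StageUploadSketch() {
  // Explicit openmode, as required by the constructor.
  TempFile staging("/tmp/gcs_upload_example", std::ios::binary | std::ios::out);
  staging << "bytes to upload";  // inherited std::fstream interface
  staging.flush();
  // staging.getName() would be handed to the GCS client for the upload.
}  // ~TempFile closes the stream and deletes /tmp/gcs_upload_example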

View File

@ -314,7 +314,6 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/device_index_selector.cc",
"transforms/dilated_conv.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",

View File

@ -953,14 +953,14 @@ in the batch dimensions and broadcasting.
}];
let arguments = (ins
TFL_TensorOf<[F32]>:$x,
TFL_TensorOf<[F32]>:$y,
TFL_TensorOf<[F32, QI8]>:$x,
TFL_TensorOf<[F32, QI8]>:$y,
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
DefaultValuedAttr<BoolAttr, "false">:$adj_y
);
let results = (outs
TFL_TensorOf<[F32]>:$output
TFL_TensorOf<[F32, QI8]>:$output
);
let hasOptions = 1;

View File

@ -63,7 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());

View File

@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
// Creates function pass to select device index/fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TFL
} // namespace mlir

View File

@ -475,6 +475,7 @@ cc_library(
"transforms/cluster_outlining.cc",
"transforms/collection_ops_util.cc",
"transforms/decompose_resource_ops_pass.cc",
"transforms/device_index_selector.cc",
"transforms/einsum.cc",
"transforms/executor_island_coarsening.cc",
"transforms/executor_tpuv1_inline_tpu_island.cc",

View File

@ -164,6 +164,81 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
let hasFolder = 1;
}
def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> {
let summary = "Adjust the contrast of one or more images.";
let description = [{
`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
interpreted as `[height, width, channels]`. The other dimensions only
represent a collection of images, such as `[batch, height, width, channels].`
Contrast is adjusted independently for each channel of each image.
For each channel, the Op first computes the mean of the image pixels in the
channel and then adjusts each component of each pixel to
`(x - mean) * contrast_factor + mean`.
}];
let arguments = (ins
TensorOf<[F16, F32]>:$images,
F32Tensor:$contrast_factor
);
let results = (outs
TensorOf<[F16, F32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
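
A rough scalar sketch of the per-channel adjustment described above (my own illustration, not the TensorFlow kernel; `channel` stands for one channel of one image flattened into a vector): compute the channel mean, then apply (x - mean) * contrast_factor + mean to every pixel.

#include <numeric>
#include <vector>

// Adjusts the contrast of a single, non-empty channel in place.
void AdjustContrastChannel(std::vector<float>& channel, float contrast_factor) {
  const float mean =
      std::accumulate(channel.begin(), channel.end(), 0.0f) / channel.size();
  for (float& x : channel) x = (x - mean) * contrast_factor + mean;
}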
def TF_AdjustHueOp : TF_Op<"AdjustHue", [NoSideEffect]> {
let summary = "Adjust the hue of one or more images.";
let description = [{
`images` is a tensor of at least 3 dimensions. The last dimension is
interpreted as channels, and must be three.
The input image is considered in the RGB colorspace. Conceptually, the RGB
colors are first mapped into HSV. A delta is then applied to all the hue
values, and the result is then remapped back to the RGB colorspace.
}];
let arguments = (ins
TensorOf<[F16, F32]>:$images,
F32Tensor:$delta
);
let results = (outs
TensorOf<[F16, F32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AdjustSaturationOp : TF_Op<"AdjustSaturation", [NoSideEffect]> {
let summary = "Adjust the saturation of one or more images.";
let description = [{
`images` is a tensor of at least 3 dimensions. The last dimension is
interpreted as channels, and must be three.
The input image is considered in the RGB colorspace. Conceptually, the RGB
colors are first mapped into HSV. A scale is then applied to all the
saturation values, and the result is then remapped back to the RGB colorspace.
}];
let arguments = (ins
TensorOf<[F16, F32]>:$images,
F32Tensor:$scale
);
let results = (outs
TensorOf<[F16, F32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
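
A conceptual sketch covering both AdjustHue and AdjustSaturation (an illustration under stated assumptions, not the TF kernels): it operates on a pixel already converted to HSV and omits the RGB<->HSV conversions the ops perform internally. Hue is rotated by `delta` with wrap-around into [0, 1); saturation is multiplied by `scale` and clamped to [0, 1].

#include <algorithm>
#include <array>
#include <cmath>

// hsv = {hue, saturation, value}, each component in [0, 1].
std::array<float, 3> AdjustHsvSketch(std::array<float, 3> hsv, float delta,
                                     float scale) {
  const float h = hsv[0] + delta;
  hsv[0] = h - std::floor(h);                               // AdjustHue
  hsv[1] = std::min(std::max(hsv[1] * scale, 0.0f), 1.0f);  // AdjustSaturation
  return hsv;
}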
def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
let summary = [{
Computes the "logical and" of elements across dimensions of a tensor.
@ -3866,6 +3941,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_HSVToRGBOp : TF_Op<"HSVToRGB", [NoSideEffect]> {
let summary = "Convert one or more images from HSV to RGB.";
let description = [{
Outputs a tensor of the same shape as the `images` tensor, containing the RGB
value of the pixels. The output is only well defined if the values in `images`
are in `[0,1]`.
See `rgb_to_hsv` for a description of the HSV encoding.
}];
let arguments = (ins
TF_FpTensor:$images
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_HashTableV2Op : TF_Op<"HashTableV2", []> {
let summary = "Creates a non-initialized hash table.";
@ -5962,11 +6059,11 @@ I.e., \\(y = -x\\).
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -6733,6 +6830,41 @@ the dimension is padded with zeros.
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}
def TF_RGBToHSVOp : TF_Op<"RGBToHSV", [NoSideEffect]> {
let summary = "Converts one or more images from RGB to HSV.";
let description = [{
Outputs a tensor of the same shape as the `images` tensor, containing the HSV
value of the pixels. The output is only well defined if the values in `images`
are in `[0,1]`.
`output[..., 0]` contains hue, `output[..., 1]` contains saturation, and
`output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0
corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue.
Usage Example:
>>> blue_image = tf.stack([
... tf.zeros([5,5]),
... tf.zeros([5,5]),
... tf.ones([5,5])],
... axis=-1)
>>> blue_hsv_image = tf.image.rgb_to_hsv(blue_image)
>>> blue_hsv_image[0,0].numpy()
array([0.6666667, 1. , 1. ], dtype=float32)
}];
let arguments = (ins
TF_FpTensor:$images
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
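
A standalone sketch of the per-pixel RGB-to-HSV conversion described above (again an illustration, not the registered kernel); for the pure-blue pixel from the usage example it yields hue 2/3, matching the [0.6666667, 1, 1] doctest output.

#include <algorithm>
#include <array>

std::array<float, 3> RgbToHsvPixel(float r, float g, float b) {
  const float v = std::max({r, g, b});
  const float m = std::min({r, g, b});
  const float range = v - m;
  float h = 0.0f, s = 0.0f;
  if (range > 0.0f) {
    s = range / v;  // v > 0 whenever range > 0
    if (v == r)      h = (g - b) / range;         // between yellow and magenta
    else if (v == g) h = 2.0f + (b - r) / range;  // between cyan and yellow
    else             h = 4.0f + (r - g) / range;  // between magenta and cyan
    h /= 6.0f;
    if (h < 0.0f) h += 1.0f;
  }
  return {h, s, v};  // e.g. RgbToHsvPixel(0, 0, 1) == {2/3, 1, 1}
}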
def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>,
WithBroadcastableBinOpBuilder {
let summary = [{
@ -7230,6 +7362,27 @@ Input images can be of different types but output images are always float.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ResizeBilinearGradOp : TF_Op<"ResizeBilinearGrad", [NoSideEffect]> {
let summary = "Computes the gradient of bilinear interpolation.";
let description = [{
}];
let arguments = (ins
F32Tensor:$grads,
TF_FpTensor:$original_image,
DefaultValuedAttr<BoolAttr, "false">:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_ResizeNearestNeighborOp : TF_Op<"ResizeNearestNeighbor", [NoSideEffect]> {
let summary = [{
Resize `images` to `size` using nearest neighbor interpolation.

View File

@ -21,11 +21,11 @@ limitations under the License.
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TFL {
namespace TF {
namespace {
// Folds the DeviceIndex op to a constant value. The DeviceIndex return the
@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() {
// Convert all the DeviceIndex ops to constant values.
func.getBody().walk([](TF::DeviceIndexOp op) {
// This just selects the default in all cases where DeviceIndex feeds into
// tf.Case. This could be enhanced based on explicit TFLite specification or
// TAC in future.
// tf.Case. This could be enhanced to have some sort of policy in the
// future.
OpBuilder b(op);
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
int index = op.device_names().size();
@ -79,7 +79,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
}
static PassRegistration<DeviceIndexSelector> pass(
"tfl-device-index-selector", "Fold tf.DeviceIndex to constant");
"tf-device-index-selector", "Fold tf.DeviceIndex to constant");
} // namespace TFL
} // namespace TF
} // namespace mlir

View File

@ -52,11 +52,6 @@ struct FusedKernelMatcherPass
void runOnFunction() override;
};
// Returns an op's name with the dialect prefix stripped off.
StringRef GetOpNameWithoutDialect(Operation *op) {
return op->getName().getStringRef().split(".").second;
}
bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}
@ -128,8 +123,8 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
}
SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{
StringAttr::get(GetOpNameWithoutDialect(bias_add), context)};
SmallVector<Attribute, 2> fused_ops{StringAttr::get(
bias_add.getOperation()->getName().stripDialect(), context)};
// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);
@ -143,7 +138,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(
StringAttr::get(GetOpNameWithoutDialect(activation), context));
StringAttr::get(activation->getName().stripDialect(), context));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();

View File

@ -147,6 +147,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
// generally used beyond exporting to runtimes that supports these ops. In the
// future these fusions may be codegen'd automatically.
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
// Creates function pass to select device index/fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TF
namespace tf_executor {

View File

@ -106,53 +106,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
return GetI64ElementsAttr(slice_limits, builder);
}
// Returns the padding value of the given position. If padding_attr is a
// nullptr, returns 0.
static int64_t GetPaddingValue(DenseIntElementsAttr padding_attr,
ArrayRef<uint64_t> index) {
if (!padding_attr) return 0;
return padding_attr.getValue<int64_t>(index);
}
static bool IsOnlyPaddingSpatialDims(Value lhs,
ConvDimensionNumbers dimension_numbers,
DenseIntElementsAttr edge_padding_low,
DenseIntElementsAttr edge_padding_high) {
const int64_t batch_dim = dimension_numbers.input_batch_dimension().getInt();
const int64_t feature_dim =
dimension_numbers.input_feature_dimension().getInt();
if (edge_padding_low.getValue<int64_t>(batch_dim) ||
edge_padding_high.getValue<int64_t>(batch_dim))
return false;
if (edge_padding_low.getValue<int64_t>(feature_dim) ||
edge_padding_high.getValue<int64_t>(feature_dim))
return false;
return true;
}
DenseIntElementsAttr BuildConvPaddingAttrs(
DenseIntElementsAttr edge_padding_low,
DenseIntElementsAttr edge_padding_high, DenseIntElementsAttr padding_attr,
ConvDimensionNumbers dimension_numbers, Builder* builder) {
SmallVector<int64_t, 4> padding_low, padding_high;
for (const auto& dim : dimension_numbers.input_spatial_dimensions()) {
unsigned i = dim.getZExtValue();
padding_low.push_back(edge_padding_low.getValue<int64_t>(i));
padding_high.push_back(edge_padding_high.getValue<int64_t>(i));
}
int rank = padding_low.size();
SmallVector<int64_t, 8> padding;
for (unsigned i = 0, e = rank; i < e; ++i) {
padding.push_back(GetPaddingValue(padding_attr, {i, 0}) + padding_low[i]);
padding.push_back(GetPaddingValue(padding_attr, {i, 1}) + padding_high[i]);
}
// padding_attr.getType() doesn't work because it is an optional attribute,
// which can be a nullptr.
auto type = RankedTensorType::get({rank, 2}, builder->getIntegerType(64));
return DenseIntElementsAttr::get(type, padding);
}
#include "tensorflow/compiler/mlir/xla/transforms/generated_canonicalize.inc"
} // namespace
@ -2153,14 +2106,5 @@ LogicalResult deriveShapeFromFirstOperand(
return success();
}
//===----------------------------------------------------------------------===//
// ConvOp
//===----------------------------------------------------------------------===//
void ConvOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<FoldPadIntoConv>(context);
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -929,8 +929,6 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
);
let results = (outs HLO_Tensor);
let hasCanonicalizer = 1;
}
def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {

View File

@ -132,6 +132,22 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
if (operand_shapes_with_layout.has_value())
return Unimplemented(
"CustomCall doesn't support operands shapes with layout");
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(false),
builder_.getStringAttr(opaque));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,

View File

@ -124,6 +124,12 @@ class MlirHloBuilder : public XlaBuilder {
FftType fft_type,
absl::Span<const int64> fft_length) override;
StatusOr<XlaOp> CustomCallInternal(const string& call_target_name,
absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>>
operand_shapes_with_layout) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,

View File

@ -415,71 +415,6 @@ func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
return %0 : tensor<1x4xf32>
}
// CHECK-LABEL: func @fold_pad_into_conv_f32
func @fold_pad_into_conv_f32(%arg0 : tensor<1x32x32x3xf32>,
%arg1 : tensor<7x7x3x64xf32>)
-> tensor<1x16x16x64xf32> {
// CHECK-NOT: xla_hlo.pad
// CHECK: xla_hlo.convolution
// CHECK-SAME: padding = dense<3> : tensor<2x2xi64>
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.pad"(%arg0, %0) {
edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>,
edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>,
interior_padding = dense<0> : tensor<4xi64>
} : (tensor<1x32x32x3xf32>, tensor<f32>) -> tensor<1x38x38x3xf32>
%2 = "xla_hlo.convolution"(%1, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<0> : tensor<2x2xi64>,
window_strides = dense<2> : tensor<2xi64>
} : (tensor<1x38x38x3xf32>, tensor<7x7x3x64xf32>) -> tensor<1x16x16x64xf32>
return %2 : tensor<1x16x16x64xf32>
}
// CHECK-LABEL: func @fold_pad_into_conv_i32
func @fold_pad_into_conv_i32(%arg0 : tensor<1x32x32x3xi32>,
%arg1 : tensor<7x7x3x64xi32>)
-> tensor<1x16x16x64xi32> {
// CHECK-NOT: xla_hlo.pad
// CHECK: xla_hlo.convolution
// CHECK-SAME: padding = dense<3> : tensor<2x2xi64>
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "xla_hlo.pad"(%arg0, %0) {
edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>,
edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>,
interior_padding = dense<0> : tensor<4xi64>
} : (tensor<1x32x32x3xi32>, tensor<i32>) -> tensor<1x38x38x3xi32>
%2 = "xla_hlo.convolution"(%1, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
window_strides = dense<2> : tensor<2xi64>
} : (tensor<1x38x38x3xi32>, tensor<7x7x3x64xi32>) -> tensor<1x16x16x64xi32>
return %2 : tensor<1x16x16x64xi32>
}
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
// CHECK: xla_hlo.reshape

View File

@ -35,7 +35,7 @@ func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor
// CHECK-LABEL: unranked_operand
func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: tf.Abs
// expected-remark@+1 {{lowering requires static shaped operands}}
// expected-remark@+1 {{lowering requires static shaped tensor operands}}
%0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
@ -44,12 +44,20 @@ func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: dynamic_operand
func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: tf.Abs
// expected-remark@+1 {{lowering requires static shaped operands}}
// expected-remark@+1 {{lowering requires static shaped tensor operands}}
%0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: tuple_type
func @tuple_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
// Verifies that the pass can handle operands of non-tensor type like tuple
// from non TensorFlow ops.
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK-LABEL: unsupported_dtype
func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
// CHECK: tf.AddN

View File

@ -28,54 +28,3 @@ def UnaryEinsumToEinsum : Pat<
(HLO_UnaryEinsumOp $operand, $equation),
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
$operand, (UnaryToBinaryEinsumEq $equation))>;
//===----------------------------------------------------------------------===//
// Conv op patterns.
//===----------------------------------------------------------------------===//
def IsZero : Attr<CPred<
"($_self.isa<DenseFPElementsAttr>() &&"
"$_self.cast<DenseFPElementsAttr>().isSplat() &&"
"$_self.cast<DenseFPElementsAttr>().getSplatValue<FloatAttr>()"
".getValue().isZero()) ||"
"($_self.isa<DenseIntElementsAttr>() &&"
"$_self.cast<DenseIntElementsAttr>().isSplat() &&"
"$_self.cast<DenseIntElementsAttr>().getSplatValue<IntegerAttr>()"
".getInt() == 0)">>;
def IsOnlyPaddingSpatialDims
: Constraint<CPred<"IsOnlyPaddingSpatialDims($0, $1, $2, $3)">>;
def BuildConvPaddingAttrs : NativeCodeCall<
"BuildConvPaddingAttrs($0, $1, $2, $3, &$_builder)">;
def FoldPadIntoConv : Pat<
(HLO_ConvOp
(HLO_PadOp $lhs,
(HLO_ConstOp IsZero:$padding_value),
$edge_padding_low,
$edge_padding_high,
IsZero:$interior_padding),
$rhs,
$window_strides,
$padding,
$lhs_dilation,
$rhs_dilation,
$dimension_numbers,
$feature_group_count,
$batch_group_count,
$precision_config),
(HLO_ConvOp
$lhs,
$rhs,
$window_strides,
(BuildConvPaddingAttrs $edge_padding_low, $edge_padding_high, $padding,
$dimension_numbers),
$lhs_dilation,
$rhs_dilation,
$dimension_numbers,
$feature_group_count,
$batch_group_count,
$precision_config),
[(IsOnlyPaddingSpatialDims $lhs, $dimension_numbers, $edge_padding_low,
$edge_padding_high)]>;

View File

@ -389,10 +389,13 @@ struct HloLegalizeToLhlo
target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<xla_hlo::XlaHloDialect>();
BufferAssignmentTypeConverter converter;
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
auto inputs = op.getType().getInputs();
return std::all_of(inputs.begin(), inputs.end(),
[](Type input) { return input.isa<MemRefType>(); });
return llvm::all_of(inputs,
[](Type input) { return input.isa<MemRefType>(); }) &&
converter.isLegal(&op.getBody());
});
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
return std::all_of(returnOp.operand_type_begin(),
@ -401,8 +404,7 @@ struct HloLegalizeToLhlo
});
auto module = getOperation();
BufferAssignmentTypeConverter converter;
module.walk([&](FuncOp func) {
module.walk([&](FuncOp func) -> WalkResult {
BufferAssignmentPlacer bufferAssignment(func);
OwningRewritePatternList patterns;
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
@ -418,8 +420,7 @@ struct HloLegalizeToLhlo
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
&converter, &patterns);
}
return WalkResult(
applyPartialConversion(func, target, patterns, &converter));
return applyPartialConversion(func, target, patterns);
});
}
@ -463,6 +464,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<xla_hlo::RealOp>,
HloToLhloOpConverter<xla_hlo::RemOp>,
HloToLhloOpConverter<xla_hlo::RsqrtOp>,
HloToLhloOpConverter<xla_hlo::ReshapeOp>,
HloToLhloOpConverter<xla_hlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SqrtOp>,

View File

@ -5238,8 +5238,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
// Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
DenseSet<Operation *> nonlegalized_ops;
LogicalResult result = applyPartialConversion(
op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops);
LogicalResult result =
applyPartialConversion(op, target, patterns, &nonlegalized_ops);
// In order to enforce that the conversion result is fully converted,
// fail if there are any nonlegalized ops in the set.
if (failed(result) || !nonlegalized_ops.empty()) {

View File

@ -88,6 +88,9 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::AddNOp>(),
TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::AngleOp>(),
TypeID::get<TF::AdjustContrastv2Op>(),
TypeID::get<TF::AdjustHueOp>(),
TypeID::get<TF::AdjustSaturationOp>(),
TypeID::get<TF::ApproximateEqualOp>(),
TypeID::get<TF::ArgMaxOp>(),
TypeID::get<TF::ArgMinOp>(),
@ -127,6 +130,7 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::GatherNdOp>(),
TypeID::get<TF::GreaterEqualOp>(),
TypeID::get<TF::GreaterOp>(),
TypeID::get<TF::HSVToRGBOp>(),
TypeID::get<TF::IFFT2DOp>(),
TypeID::get<TF::IFFT3DOp>(),
TypeID::get<TF::IFFTOp>(),
@ -157,10 +161,14 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::PowOp>(),
TypeID::get<TF::RFFT2DOp>(),
TypeID::get<TF::RFFT3DOp>(),
TypeID::get<TF::RGBToHSVOp>(),
TypeID::get<TF::RealDivOp>(),
TypeID::get<TF::ReciprocalOp>(),
TypeID::get<TF::ReciprocalGradOp>(),
TypeID::get<TF::Relu6GradOp>(),
TypeID::get<TF::ResizeBilinearOp>(),
TypeID::get<TF::ResizeBilinearGradOp>(),
TypeID::get<TF::ResizeNearestNeighborOp>(),
TypeID::get<TF::ReverseSequenceOp>(),
TypeID::get<TF::RightShiftOp>(),
TypeID::get<TF::RintOp>(),
@ -337,9 +345,9 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
// Only static shaped operands are supported in XLA builders for now.
for (Type ty : op->getOperandTypes()) {
auto ranked_ty = ty.cast<ShapedType>();
if (!ranked_ty.hasStaticShape()) {
op->emitRemark() << "lowering requires static shaped operands";
auto ranked_ty = ty.dyn_cast<ShapedType>();
if (!ranked_ty || !ranked_ty.hasStaticShape()) {
op->emitRemark() << "lowering requires static shaped tensor operands";
return success();
}
}

View File

@ -177,7 +177,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
target.addIllegalOp<ReduceOp>();
auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}

View File

@ -43,7 +43,7 @@ class TestLhloToLLVMPass
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<XlaLhloDialect>();
if (failed(applyFullConversion(m, target, patterns, &converter))) {
if (failed(applyFullConversion(m, target, patterns))) {
signalPassFailure();
}
}

View File

@ -711,7 +711,7 @@ struct LhloLegalizeToParallelLoops
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
xla_lhlo::SelectAndScatterOp>();
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}

View File

@ -61,6 +61,7 @@ MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);
MAP_HLO_TO_LHLO(RemOp);
MAP_HLO_TO_LHLO(RsqrtOp);
MAP_HLO_TO_LHLO(SelectOp);

View File

@ -867,7 +867,7 @@ struct LhloLegalizeToLinalg
auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}
@ -882,7 +882,7 @@ struct HloLegalizeToLinalg
auto func = getFunction();
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}

View File

@ -1,6 +1,6 @@
// Test DeviceIndex selector.
// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s
// RUN: tf-opt --tf-device-index-selector %s | FileCheck %s
// CHECK-LABEL: func @select
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {

View File

@ -770,6 +770,7 @@ tf_xla_py_test(
size = "small",
timeout = "long",
srcs = ["image_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 10,
tags = [

View File

@ -1564,16 +1564,12 @@ XlaOp XlaBuilder::CustomCall(
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
"Invalid custom_call_target \"%s\": Call targets that start with '$' "
"are reserved for internal use.",
call_target_name);
}
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
instr.set_backend_config(opaque);
if (operand_shapes_with_layout.has_value()) {
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument(
@ -1586,7 +1582,6 @@ XlaOp XlaBuilder::CustomCall(
"with constrained layout; given %d shapes, expected %d",
operand_shapes_with_layout->size(), operands.size());
}
instr.set_constrain_layout(true);
int64 operand_num = 0;
for (const Shape& operand_shape : *operand_shapes_with_layout) {
if (!LayoutUtil::HasLayout(operand_shape)) {
@ -1595,14 +1590,31 @@ XlaOp XlaBuilder::CustomCall(
"constrained layout.",
operand_num);
}
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
++operand_num;
}
}
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
return CustomCallInternal(call_target_name, operands, shape, opaque,
operand_shapes_with_layout);
});
}
StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
instr.set_backend_config(opaque);
if (operand_shapes_with_layout.has_value()) {
instr.set_constrain_layout(true);
for (const Shape& operand_shape : *operand_shapes_with_layout) {
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
}
}
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}
XlaOp XlaBuilder::CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const Shape& shape, const string& opaque,

View File

@ -527,6 +527,14 @@ class XlaBuilder {
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
// Internal version of CustomCall, without a computation, that does no
// op-specific error handling and expects its arguments to be legal. The
// CustomCall method above calls this method after error handling.
virtual StatusOr<XlaOp> CustomCallInternal(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const Shape& shape_with_layout,

View File

@ -141,7 +141,9 @@ cc_library(
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
"//tensorflow/core:allocator",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:connected_traceme",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/profiler/lib:traceme_encode",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/host:host_platform_id",

View File

@ -127,11 +127,15 @@ std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
} else {
std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
usage_stream_pool_.pop();
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
QCHECK(stream->ok());
return stream;
}
}
void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
QCHECK(stream->ok());
absl::MutexLock lock(&mu_);
usage_stream_pool_.push(std::move(stream));
}

View File

@ -98,7 +98,9 @@ limitations under the License.
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/event.h"
@ -749,16 +751,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
// memory that has already been allocated, and a possible Event
// allocation.
se::Stream* h2d_stream = local_device->host_to_device_stream();
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform());
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
local_device->host_to_device_stream(), literal, buffer));
h2d_stream, literal, buffer));
std::shared_ptr<BufferSequencingEvent> event =
device_buffer->definition_events()[0];
TF_CHECK_OK(AddDestinationBufferSynchronization(
local_device, std::move(device_buffer), event,
local_device->host_to_device_stream()));
local_device, std::move(device_buffer), event, h2d_stream));
// This can sometimes catch the case where the literal memory has been
// freed before the H2D transfer was issued.
h2d_stream->RefreshStatus()
.IgnoreError(); // Can return error::Unimplemented
QCHECK(h2d_stream->ok());
};
client->h2d_transfer_pool()->Schedule(transfer_h2d);
return py_buffer;
@ -1069,13 +1077,17 @@ Status PjRtBuffer::CopyToHostAsync() {
return Status::OK();
}
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral() {
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
const bool discard_cached_copy) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
TF_RETURN_IF_ERROR(CopyToHostAsync());
std::shared_ptr<HostValue> host_value;
{
absl::MutexLock lock(&mu_);
host_value = host_value_;
if (discard_cached_copy) {
host_value_ = nullptr;
}
}
if (host_value == nullptr) {
return InvalidArgument("ToLiteral called on invalid buffer");
@ -1429,10 +1441,9 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers) const {
int device_ordinal = device->local_device_state()->device_ordinal();
tensorflow::profiler::TraceMe traceme([&] {
return absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(),
"#");
});
tensorflow::profiler::TraceMeConsumer activity(
"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
run_id.ToInt());
VLOG(3) << "Replica " << replica << ", partition " << partition
<< " mapped to device ordinal for execution: " << device_ordinal;
@ -1721,10 +1732,9 @@ PjRtExecutable::ExecuteOnLocalDevices(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const {
RunId run_id;
tensorflow::profiler::TraceMe traceme([&] {
return absl::StrCat(
"LocalExecutable::ExecuteOnLocalDevices#run_id=", run_id.ToInt(), "#");
});
tensorflow::profiler::TraceMeProducer activity(
"LocalExecutable::ExecuteOnLocalDevices",
tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
const int num_local_devices = local_devices_.size();

View File

@ -478,8 +478,12 @@ class PjRtBuffer {
// Returns the buffer's value as an XLA Literal. If the value has previously
// been prefetched to the host, then returns the prefetched version, otherwise
// copies the buffer to the host. Blocks until the value is ready.
StatusOr<std::shared_ptr<Literal>> ToLiteral();
// copies the buffer to the host. Blocks until the value is ready. If
// `discard_cached_copy` is true then the buffer will no longer keep hold of
// a cached copy of the literal (i.e., the reference to the host value will
// be dropped).
StatusOr<std::shared_ptr<Literal>> ToLiteral(
bool discard_cached_copy = false);
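
A hedged caller sketch for the new parameter (the helper name and the include path are assumptions, not part of this commit): passing discard_cached_copy=true returns the literal while telling the buffer to stop retaining its cached host copy.

#include <memory>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"  // assumed header path

xla::StatusOr<std::shared_ptr<xla::Literal>> FetchLiteralOnce(
    xla::PjRtBuffer* buffer) {
  // Drop the buffer's cached host value once the caller owns the literal.
  return buffer->ToLiteral(/*discard_cached_copy=*/true);
}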
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. The value can be retrieved by a later call to

View File

@ -106,7 +106,6 @@ class BranchVisitor {
boundaries_.emplace_back(operand, i, inst);
continue;
}
worklist_.push_back(operand);
visited_.insert(operand);
}
@ -197,6 +196,7 @@ bool WorthHoisting(HloInstruction* instruction) {
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
case HloOpcode::kSqrt:
case HloOpcode::kGetTupleElement:
return true;
default:
@ -206,10 +206,11 @@ bool WorthHoisting(HloInstruction* instruction) {
// Compare whether the instructions to be visited in each branch are identical.
bool InstructionWithinBranchIdentical(
const std::vector<HloInstruction*>& instructions, bool is_layout_senstive) {
const std::vector<HloInstruction*>& instructions,
bool is_layout_sensitive) {
// Identical includes the shape of each operands are equal.
auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
bool eq_operands = is_layout_senstive
bool eq_operands = is_layout_sensitive
? ShapeUtil::Equal(a->shape(), b->shape())
: ShapeUtil::Compatible(a->shape(), b->shape());
return eq_operands;
@ -233,7 +234,7 @@ bool InstructionWithinBranchIdentical(
auto old_channel_id = instruction->channel_id();
instruction->set_channel_id(instructions[0]->channel_id());
bool eq_instructions = instructions[0]->Identical(
*instruction, eq_operand, eq_computations, is_layout_senstive);
*instruction, eq_operand, eq_computations, is_layout_sensitive);
instruction->set_channel_id(old_channel_id);
return eq_instructions;
});
@ -243,7 +244,7 @@ bool InstructionWithinBranchIdentical(
[&](HloInstruction* instruction) {
return instructions[0]->Identical(
*instruction, eq_operand, eq_computations,
is_layout_senstive);
is_layout_sensitive);
});
}
@ -354,12 +355,228 @@ Status RemoveInstructionFromComputation(
return Status::OK();
}
// Identify converts to be hoisted/rematerialized out of the branch
// computations.
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
int branch_count,
HloInstruction* conditional,
bool is_layout_sensitive) {
absl::flat_hash_set<int64> kspecial_convert;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
continue;
}
bool replica = true;
HloInstruction* kspecial_convert_candidate =
old_root->mutable_operand(operand_num);
// Check whether an identical candidate appears in other branches
for (int others = 1; others < branch_count; ++others) {
HloInstruction* others_root =
conditional->branch_computation(others)->root_instruction();
bool eq_shape =
is_layout_sensitive
? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape())
: ShapeUtil::Compatible(
others_root->operand(operand_num)->shape(),
kspecial_convert_candidate->shape());
if ((others_root->operand(operand_num)->opcode() ==
HloOpcode::kConvert) &&
eq_shape) {
// Nothing to be done.
} else {
replica = false;
break;
}
}
if (replica) {
kspecial_convert.insert(operand_num);
}
}
return kspecial_convert;
}
// Restructuring the conditional instruction as follows:
// i.e., %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape-change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
HloInstruction* conditional) {
HloInstruction* old_root = computation->root_instruction();
std::vector<HloInstruction*> new_operands;
int cur_index = 0;
for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
++cur_index) {
new_operands.push_back(
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
conditional, cur_index)));
}
HloInstruction* new_tuple =
computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
if (old_root == conditional) {
computation->set_root_instruction(new_tuple);
} else {
std::vector<HloInstruction*> new_tuple_users;
for (auto conditional_user : conditional->users()) {
auto is_new_gte = absl::c_find_if(
new_operands,
[&](HloInstruction* instr) { return instr == conditional_user; });
if (is_new_gte == new_operands.end()) {
new_tuple_users.push_back(conditional_user);
}
}
for (auto new_tuple_user : new_tuple_users) {
TF_RETURN_IF_ERROR(
conditional->ReplaceUseWith(new_tuple_user, new_tuple));
}
}
VLOG(2) << "computation after root restructure:\n" << computation->ToString();
return Status::OK();
}
StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
bool is_layout_sensitive) {
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
}
HloInstruction* old_root =
conditional->branch_computation(0)->root_instruction();
if (old_root->opcode() != HloOpcode::kTuple) {
return false;
} else {
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
// Identify the gte using `index'.
auto find_gte = [](const HloInstruction* conditional_result,
int64 index) -> HloInstruction* {
for (HloInstruction* instr : conditional_result->users()) {
if (instr->opcode() != HloOpcode::kGetTupleElement) {
return nullptr;
}
if (instr->tuple_index() == index) {
return instr;
}
}
return nullptr;
};
// Captures tuple indices referring to converts to be rematerialized/hoisted.
absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
old_root, branch_count, conditional, is_layout_sensitive);
// Exit if we cannot find any converts to be hoisted.
if (kspecial_convert.empty()) {
return false;
}
TF_RETURN_IF_ERROR(
RestructureConditionalInstruction(conditional->parent(), conditional));
for (int branch = 0; branch < branch_count; branch++) {
old_root = conditional->branch_computation(branch)->root_instruction();
absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
std::vector<HloInstruction*> new_operands(old_root->operand_count());
std::unordered_set<HloInstruction*> to_hoist_set;
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
operand_num;
}
for (int64 operand_num = 0; operand_num < old_root->operand_count();
++operand_num) {
HloInstruction* hoist = old_root->mutable_operand(operand_num);
if (!kspecial_convert.contains(operand_num)) {
new_operands[operand_num] = old_root->mutable_operand(operand_num);
continue;
}
to_hoist_set.insert(hoist);
int64 new_tuple_count = old_root->operand_count();
// Replace the hoisted instr in the tuple with the operand/operands.
// We will replace at least one of the operands of the hoist at the
// tuple place; the rest will be added at the end.
bool inplace = true;
CHECK(!hoist->operands().empty());
for (HloInstruction* prod : hoist->operands()) {
if (inplace) {
map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
new_operands[map_inst_to_tuple_index[hoist]] = prod;
inplace = false;
} else {
map_inst_to_tuple_index[prod] = new_tuple_count++;
new_operands.push_back(prod);
}
}
}
// Create the new root instruction.
HloComputation* cur_branch = conditional->branch_computation(branch);
HloInstruction* new_branch_root =
cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
// The shape can vary since the operands to convert are now
// being returned through the branches' root.
cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
// Only one of the branches needs to change the conditional->parent().
if (branch != 0) {
continue;
}
HloComputation* conditional_parent = conditional->parent();
HloInstruction* newconditional =
conditional_parent->AddInstruction(HloInstruction::CreateConditional(
cur_branch->root_instruction()->shape(),
conditional->mutable_operand(0),
absl::MakeSpan(conditional->branch_computations()),
absl::MakeSpan(conditional->operands()).subspan(1)));
// Ensure that all the users of conditional refer to the new one.
TF_RETURN_IF_ERROR(
conditional->ReplaceAllUsesWithDifferentShape(newconditional));
TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
conditional = newconditional;
// Add the hoisted instructions in the parent.
for (HloInstruction* hoist : to_hoist_set) {
VLOG(2) << "Hoisting instruction:" << hoist->ToString();
int64 hoist_index = map_inst_to_tuple_index[hoist];
// Find out the gte that captured the hoisted instr result.
HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
CHECK(gte_hoist != nullptr);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* op : hoist->operands()) {
HloInstruction* gte = conditional_parent->AddInstruction(
HloInstruction::CreateGetTupleElement(
op->shape(), conditional, map_inst_to_tuple_index[op]));
new_operands.push_back(gte);
}
HloInstruction* hoisted = conditional_parent->AddInstruction(
hoist->CloneWithNewOperands(hoist->shape(), new_operands));
VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
}
// No need to explicitly delete a hoisted instruction: if it is dead,
// the subsequent DCE will remove it.
}
}
VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
return true;
}
// Hoists identical ops out of the conditional. Ops are considered identical
// when the shapes of their operands are identical and their properties are
// identical. The pass starts from the root instruction of each branch and
// collects the identical ops to hoist.
StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
bool is_layout_sensitive) {
VLOG(1) << " visiting conditional:" << conditional->ToString();
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
@ -399,7 +616,7 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
}
}
if (visitors[0].HoistInstructionSize() <= 1) {
if (visitors[0].HoistInstructionSize() < 1) {
return false;
}
@ -442,7 +659,6 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(),
conditional->branch_computation(i)));
}
return true;
}
@ -451,26 +667,55 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
bool changed = false;
// Gather all the conditional ops in our module. We do this ahead of time so
// we don't have to worry about mutating the lists of computations or
// instructions as we iterate.
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
if (pursue_full_conditional_code_motion_) {
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
}
}
}
for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(
bool result,
MergeIdenticalElements(conditional_op, is_layout_sensitive_));
changed |= result;
}
if (changed) {
HloPassPipeline subpipeline("after_conditional_code_motion");
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
changed |= cleanup_changed;
}
}
for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements(
conditional_op, is_layout_sensitive_));
changed |= result;
// handling convert rematerialization/hoisting
{
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kConditional) {
conditional_ops.push_back(instr);
}
}
}
for (HloInstruction* conditional_op : conditional_ops) {
TF_ASSIGN_OR_RETURN(
bool convert_result,
ConvertSpecialMove(conditional_op, is_layout_sensitive_));
changed |= convert_result;
}
}
if (changed) {
HloPassPipeline subpipeline("after_conditional_code_motion");
HloPassPipeline subpipeline(
"after_conditional_code_motion_after_convert_hoisting");
subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));

View File

@ -23,7 +23,11 @@ limitations under the License.
namespace xla {
// HLO pass that moves identical ops out of conditional.
// ConditionalCodeMotion specializes in hoisting/rematerializing
// unconditional converts in the default mode.
// When pursue_full_conditional_code_motion_ is set to true, the pass also
// moves identical ops out of a conditional, in addition to moving converts.
// - Ops are considered identical if the shapes of their operands are identical
// and their properties are identical.
// - Currently, only some types of instructions are supported.
@ -35,13 +39,18 @@ class ConditionalCodeMotion : public HloModulePass {
public:
// If is_layout_sensitive is true, then the hoist process preserves layout
// during identical comparison. Otherwise, layout is ignored.
explicit ConditionalCodeMotion(bool is_layout_sensitive = true)
: is_layout_sensitive_(is_layout_sensitive) {}
explicit ConditionalCodeMotion(
bool is_layout_sensitive = true,
bool pursue_full_conditional_code_motion = false)
: is_layout_sensitive_(is_layout_sensitive),
pursue_full_conditional_code_motion_(
pursue_full_conditional_code_motion) {}
absl::string_view name() const override { return "conditional-code-motion"; }
StatusOr<bool> Run(HloModule* module) override;
private:
const bool is_layout_sensitive_;
const bool pursue_full_conditional_code_motion_;
};
} // namespace xla

View File

@ -38,7 +38,86 @@ namespace {
using ConditionalCodeMotionTest = HloTestBase;
namespace op = xla::testing::opcode_matchers;
TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) {
TEST_F(ConditionalCodeMotionTest, MoveSubsetTupleOut) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.2894, %reshape.8493)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.3604, %add)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
get-first-index.2 = f32[2,512,364]{2,1,0} get-tuple-element(conditional), index=1
ROOT result = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(get-first-index, get-first-index.2)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
on_true {
%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493)
%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493)
ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}
on_false {
%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
%sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717)
%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
ROOT conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
}
TEST_F(ConditionalCodeMotionTest, MoveConvertOut) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
@ -65,12 +144,16 @@ ENTRY main {
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index)
add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert()))));
}
TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
@ -123,7 +206,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
@ -181,7 +264,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
@ -245,7 +328,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
@ -317,7 +400,7 @@ ENTRY main {
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
}
@ -390,7 +473,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
ConditionalCodeMotion pass;
ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");

View File

@ -226,6 +226,11 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
dims_to_keep.push_back(dim);
}
}
// We support fast codegen for three cases:
// 1) Row reduction: (K, R)
// 2) Column reduction: (K, R, K)
// 3) "Batched" row reduction: (R, K, R)
if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
dims_to_keep) &&
!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
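A standalone sketch of how those three patterns classify concrete reduces; the shapes and the classifier are illustrative only and not part of the XLA change:

#include <iostream>
#include <string>
#include <vector>

// Collapse a reduced(R)/kept(K) mask over dimensions into a pattern string.
std::string ReductionPattern(const std::vector<bool>& reduced_dims) {
  std::string pattern;
  for (bool reduced : reduced_dims) {
    char tag = reduced ? 'R' : 'K';
    if (pattern.empty() || pattern.back() != tag) pattern.push_back(tag);
  }
  return pattern;
}

int main() {
  // f32[1024,512] reduced over {1}    -> "KR"  : row reduction.
  std::cout << ReductionPattern({false, true}) << "\n";
  // f32[64,512,32] reduced over {1}   -> "KRK" : column reduction.
  std::cout << ReductionPattern({false, true, false}) << "\n";
  // f32[8,1024,16] reduced over {0,2} -> "RKR" : "batched" row reduction.
  std::cout << ReductionPattern({true, false, true}) << "\n";
  return 0;
}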

View File

@ -77,8 +77,6 @@ class KernelThunk : public Thunk {
// Will be set by IrEmitterUnnested.
LaunchDimensions launch_dimensions_;
// Describes how to load this kernel. ExecuteOnStream reuses this loader
// specification for all executions.
mutable tensorflow::mutex mutex_;
// Loaded kernels for each `StreamExecutor`. Requires pointer stability of

View File

@ -3908,6 +3908,10 @@ const string& HloInstruction::outfeed_config() const {
return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}
void HloInstruction::set_outfeed_config(const string& config) {
return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
}
const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
return Cast<HloCollectiveInstruction>(this)->replica_groups();
}

View File

@ -1755,6 +1755,9 @@ class HloInstruction {
// Returns the config for the Outfeed instruction.
const string& outfeed_config() const;
// Delegates to HloOutfeedInstruction::set_outfeed_config.
void set_outfeed_config(const string& config);
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const;

View File

@ -1141,6 +1141,7 @@ class HloOutfeedInstruction : public HloInstruction {
const Shape& outfeed_shape() const { return outfeed_shape_; }
// Returns the config for the Outfeed instruction.
const string& outfeed_config() const { return outfeed_config_; }
void set_outfeed_config(const string& config) { outfeed_config_ = config; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;

View File

@ -29,36 +29,78 @@ StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
// Propagate the operand subshapes.
for (int operand_idx = 0; operand_idx < instruction->operand_count();
++operand_idx) {
modified |=
PropagateSubshapes(instruction->operand(operand_idx)->shape(),
instruction->fused_parameter(operand_idx));
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(
instruction->operand(operand_idx)->shape())) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
modified |= Propagate(indexed_shape.index,
instruction->fused_parameter(operand_idx),
memory_space);
}
}
// Propagate output subshapes.
modified |= PropagateSubshapes(instruction->shape(),
instruction->fused_expression_root());
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(instruction->shape())) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
modified |=
Propagate(indexed_shape.index,
instruction->fused_expression_root(), memory_space);
}
}
}
}
return modified;
}
bool MemorySpacePropagation::PropagateSubshapes(
const Shape& caller_shape, const HloInstruction* callee_instruction) const {
bool MemorySpacePropagation::Propagate(ShapeIndexView index,
const HloInstruction* callee_instruction,
int64 memory_space) const {
bool modified = false;
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(caller_shape)) {
int64 memory_space = indexed_shape.shape.layout().memory_space();
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
callee_instruction, indexed_shape.index);
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
callee_instruction, index.ToShapeIndex());
for (const HloPosition& position : value.positions()) {
Shape* shape = ShapeUtil::GetMutableSubshape(
position.instruction->mutable_shape(), position.index);
if (shape->layout().memory_space() != memory_space) {
shape->mutable_layout()->set_memory_space(memory_space);
modified = true;
}
for (const HloPosition& position : value.positions()) {
HloInstruction* instruction = position.instruction;
Shape* shape = ShapeUtil::GetMutableSubshape(instruction->mutable_shape(),
position.index);
if (shape->layout().memory_space() == memory_space) {
continue;
}
shape->mutable_layout()->set_memory_space(memory_space);
modified = true;
// For fusion outputs, propagate the memory space to the fusion root.
if (instruction->opcode() == HloOpcode::kFusion) {
Propagate(position.index, instruction->fused_expression_root(),
memory_space);
}
const HloInstruction* parent_fusion =
instruction->parent()->FusionInstruction();
// For nested fusion roots, pop one level up and propagate the memory space
// to the output of the calling fusion instruction.
if (instruction == instruction->parent()->root_instruction() &&
parent_fusion->parent()->IsFusionComputation()) {
Propagate(position.index, parent_fusion, memory_space);
}
// For nested fusion parameters, pop one level up and propagate the memory
// space to the operand of the calling fusion instruction.
if (instruction->opcode() == HloOpcode::kParameter &&
parent_fusion->parent()->IsFusionComputation()) {
const HloInstruction* fusion_operand =
parent_fusion->operand(instruction->parameter_number());
Propagate(position.index, fusion_operand, memory_space);
}
}
for (const HloUse& use : value.uses()) {
// For fusion uses, propagate the memory space to the fusion parameter.
if (use.instruction->opcode() == HloOpcode::kFusion) {
modified |= Propagate(
use.operand_index,
use.instruction->fused_parameter(use.operand_number), memory_space);
}
}
return modified;

View File

@ -31,12 +31,11 @@ class MemorySpacePropagation : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override;
private:
// Given the caller shape (operand or output) and its corresponding
// instruction in the fused computation (parameter or root), propagates the
// memory space to all the subshapes in the callee side. Returns true if the
// module is modified.
bool PropagateSubshapes(const Shape& caller_shape,
const HloInstruction* callee_instruction) const;
// Given the shape index (operand or output) and its corresponding instruction
// in the fused computation (parameter or root), propagates the memory space
// on the callee side. Returns true if the module is modified.
bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction,
int64 memory_space) const;
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
};

View File

@ -199,5 +199,153 @@ TEST_F(MemorySpacePropagationTest, TupleOutput) {
EXPECT_EQ(module->Hash(), ref->Hash());
}
TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
// Tests propagating the memory space to nested fusions on the input side.
absl::string_view hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[3,2]{0,1:T(128)} parameter(0)
ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[3,2]{0,1:T(128)} parameter(0)
%fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1)
}
ENTRY %entry {
%param0 = s32[3,2]{0,1:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
absl::string_view expected_hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0)
ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0)
%fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1)
}
ENTRY %entry {
%param0 = s32[3,2]{0,1:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
EXPECT_EQ(module->Hash(), ref->Hash());
}
TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
// Tests propagating the memory space to nested fusions on the output side.
absl::string_view hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[6]{0:T(128)} parameter(0)
ROOT %bitcast = s32[3,2]{0,1:T(128)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)} parameter(0)
%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
ROOT %fusion.1 = s32[3,2]{0,1:T(128)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
}
)";
absl::string_view expected_hlo_string = R"(
HloModule NestedFusion
%bitcast_fusion {
%bf_param = s32[6]{0:T(128)S(1)} parameter(0)
ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param)
}
%fused_computation {
%param_1.3 = s32[1]{0:T(128)} parameter(1)
%constant.2 = s32[]{:T(128)} constant(-2147483648)
%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}
ENTRY %entry {
%param0 = s32[6]{0:T(128)} parameter(0)
%param1 = s32[1]{0:T(128)} parameter(1)
%param2 = s32[5]{0:T(128)} parameter(2)
%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
%arg1 = s32[1]{0:T(128)} copy(%param1)
%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
TF_EXPECT_OK(Verify(module.get()));
TF_ASSERT_OK_AND_ASSIGN(auto ref,
ParseAndReturnVerifiedModule(expected_hlo_string));
EXPECT_EQ(module->Hash(), ref->Hash());
}
} // namespace
} // namespace xla
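As a hedged reading of the layout notation in these tests, the S(1) suffix in a shape like s32[6]{0:T(128)S(1)} marks memory space 1, which is what Propagate() reads and writes per leaf shape. A minimal sketch of that per-leaf update follows, reusing only APIs visible in the .cc diff above; the helper itself is hypothetical:

#include <cstdint>

#include "tensorflow/compiler/xla/shape_util.h"

// Hypothetical helper mirroring the core update in Propagate(): write the
// caller-side memory space into one leaf layout on the callee side.
bool SetLeafMemorySpace(xla::Shape* shape, const xla::ShapeIndex& index,
                        int64_t memory_space) {
  xla::Shape* subshape = xla::ShapeUtil::GetMutableSubshape(shape, index);
  if (subshape->layout().memory_space() == memory_space) return false;
  subshape->mutable_layout()->set_memory_space(memory_space);
  return true;
}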

View File

@ -552,7 +552,7 @@ class LowerToNVVMPass
// TODO(csigg): Remove once we support replacing non-root ops.
target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
::mlir::gpu::YieldOp>();
if (failed(mlir::applyFullConversion(m, target, patterns, &converter))) {
if (failed(mlir::applyFullConversion(m, target, patterns))) {
signalPassFailure();
}
}

View File

@ -52,16 +52,26 @@ cc_library(
name = "test_macros_header",
testonly = True,
hdrs = ["test_macros.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:test",
"@com_google_absl//absl/strings",
],
)
# Generate a test_macros_${BACKEND} library per backend with the proper copts.
generate_backend_test_macros()
cc_library(
name = "manifest_checking_test",
testonly = True,
srcs = ["manifest_checking_test.cc"],
hdrs = ["manifest_checking_test.h"],
deps = [
":test_macros_header",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "test_utils",
srcs = ["test_utils.cc"],
@ -136,6 +146,7 @@ cc_library(
hdrs = ["hlo_test_base.h"],
deps = [
":literal_test_util",
":manifest_checking_test",
":test_utils",
":verified_hlo_module",
"//tensorflow/compiler/xla:debug_options_flags",
@ -193,6 +204,7 @@ cc_library(
srcs = ["client_library_test_base.cc"],
hdrs = ["client_library_test_base.h"],
deps = [
":manifest_checking_test",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
@ -273,6 +285,7 @@ cc_library(
hdrs = ["local_client_test_base.h"],
deps = [
":client_library_test_base",
":manifest_checking_test",
":verified_hlo_module",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",

View File

@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []):
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
],
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core/platform:logging",
],
)

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
@ -62,7 +63,7 @@ std::vector<TestCase> ExpandUseBfloat16(
}
// A client library test establishes an in-process XLA client connection.
class ClientLibraryTestBase : public ::testing::Test {
class ClientLibraryTestBase : public ManifestCheckingTest {
protected:
explicit ClientLibraryTestBase(se::Platform* platform = nullptr);

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -67,7 +68,7 @@ namespace xla {
// )
//
// For a more detailed example, see "../tests/sample_text_test.cc".
class HloTestBase : public ::testing::Test {
class HloTestBase : public ManifestCheckingTest {
public:
// Creates a new HLO module for a test. The module created will have
// TestName() for its name; it will also automatically populate its debug

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
};
// A base class for tests which exercise the LocalClient interface.
class LocalClientTestBase : public ::testing::Test {
class LocalClientTestBase : public ManifestCheckingTest {
protected:
struct EigenThreadPoolWrapper;
explicit LocalClientTestBase(se::Platform* platform = nullptr);

View File

@ -0,0 +1,129 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include <fstream>
#include <iterator>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace {
// Mapping from a test name (e.g. MyTest.MyTestCase) to the platforms on which
// it is disabled - a sequence of regexps.
using ManifestT = absl::flat_hash_map<std::string, std::vector<std::string>>;
ManifestT ReadManifest() {
ManifestT manifest;
absl::string_view path = absl::NullSafeStringView(kDisabledManifestPath);
if (path.empty()) {
return manifest;
}
// Note: parens are required to disambiguate vs function decl.
std::ifstream file_stream((std::string(path)));
std::string contents((std::istreambuf_iterator<char>(file_stream)),
std::istreambuf_iterator<char>());
std::vector<std::string> lines = absl::StrSplit(contents, '\n');
for (std::string& line : lines) {
auto comment = line.find("//");
if (comment != std::string::npos) {
line = line.substr(0, comment);
}
if (line.empty()) {
continue;
}
absl::StripTrailingAsciiWhitespace(&line);
std::vector<std::string> pieces = absl::StrSplit(line, ' ');
CHECK_GE(pieces.size(), 1);
auto& platforms = manifest[pieces[0]];
for (size_t i = 1; i < pieces.size(); ++i) {
platforms.push_back(pieces[i]);
}
}
return manifest;
}
} // namespace
void ManifestCheckingTest::SetUp() {
const testing::TestInfo* test_info =
testing::UnitTest::GetInstance()->current_test_info();
absl::string_view test_case_name = test_info->test_suite_name();
absl::string_view test_name = test_info->name();
VLOG(1) << "test_case_name: " << test_case_name;
VLOG(1) << "test_name: " << test_name;
// Remove the type suffix from the test case name.
if (const char* type_param = test_info->type_param()) {
VLOG(1) << "type_param: " << type_param;
size_t last_slash = test_case_name.rfind('/');
test_case_name = test_case_name.substr(0, last_slash);
VLOG(1) << "test_case_name: " << test_case_name;
}
// Remove the test instantiation name if it is present.
auto first_slash = test_case_name.find('/');
if (first_slash != test_case_name.npos) {
test_case_name.remove_prefix(first_slash + 1);
VLOG(1) << "test_case_name: " << test_case_name;
}
ManifestT manifest = ReadManifest();
// If the test name ends with a slash followed by one or more characters,
// strip that off.
auto last_slash = test_name.rfind('/');
if (last_slash != test_name.npos) {
test_name = test_name.substr(0, last_slash);
VLOG(1) << "test_name: " << test_name;
}
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
return;
}
}
// Expect a full match vs. one of the platform regexps to disable the test.
const std::vector<std::string>& disabled_platforms = it->second;
auto platform_string = kTestPlatform;
for (const auto& s : disabled_platforms) {
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
GTEST_SKIP();
return;
}
}
// We didn't hit in the disabled manifest entries, so don't disable it.
}
} // namespace xla
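For reference, a hypothetical manifest in the format this parser accepts: one entry per line, the test (or test case) name first, followed by space-separated platform regexps, with "//" comments stripped. SetUp() above also removes gtest instantiation prefixes, type-parameter suffixes, and a trailing "/..." suffix from the name before the lookup. The entries below are invented:

// Disable one test on platforms whose name fully matches either regexp.
ArrayElementwiseOpTest.NegConstantF32 CPU Host.*
// A bare test-case name disables every test in that suite on the platform.
DotOperationTest GPU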

View File

@ -0,0 +1,35 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
#include "tensorflow/core/platform/test.h"
namespace xla {
// This class allows us to intercept the test name and use an arbitrary
// heuristic to decide whether the test case should be disabled. The decision
// is made by resolving the (test case name, test name) pair in a manifest
// file.
class ManifestCheckingTest : public ::testing::Test {
protected:
// This method runs before each test runs.
void SetUp() override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
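A hedged usage sketch: fixtures opt in simply by deriving from ManifestCheckingTest, as HloTestBase, ClientLibraryTestBase, and LocalClientTestBase now do in this change; the fixture and test below are hypothetical.

#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

// Hypothetical fixture: inherits the manifest check performed in SetUp().
class MyManifestedTest : public ManifestCheckingTest {
 protected:
  void SetUp() override {
    ManifestCheckingTest::SetUp();  // May GTEST_SKIP() based on the manifest.
    // Additional per-test setup goes here.
  }
};

// Skipped at runtime if "MyManifestedTest.RunsUnlessDisabled" (or the bare
// suite name) is listed for the current platform in the manifest.
TEST_F(MyManifestedTest, RunsUnlessDisabled) { EXPECT_TRUE(true); }

}  // namespace xla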

View File

@ -15,93 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include <fstream>
#include <streambuf>
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace {
// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
// disabled - a sequence of regexps.
using ManifestT = absl::flat_hash_map<string, std::vector<string>>;
ManifestT ReadManifest() {
ManifestT manifest;
string path = XLA_DISABLED_MANIFEST;
if (path.empty()) {
return manifest;
}
std::ifstream file_stream(path);
// Note: parens are required to disambiguate vs function decl.
string contents((std::istreambuf_iterator<char>(file_stream)),
std::istreambuf_iterator<char>());
std::vector<string> lines = absl::StrSplit(contents, '\n');
for (string& line : lines) {
auto comment = line.find("//");
if (comment != string::npos) {
line = line.substr(0, comment);
}
if (line.empty()) {
continue;
}
absl::StripTrailingAsciiWhitespace(&line);
std::vector<string> pieces = absl::StrSplit(line, ' ');
CHECK_GE(pieces.size(), 1);
auto& platforms = manifest[pieces[0]];
for (int64 i = 1; i < pieces.size(); ++i) {
platforms.push_back(pieces[i]);
}
}
return manifest;
static bool InitModule() {
kDisabledManifestPath = XLA_DISABLED_MANIFEST;
VLOG(1) << "kDisabledManifestPath: " << kDisabledManifestPath;
kTestPlatform = XLA_PLATFORM;
VLOG(1) << "kTestPlatform: " << kTestPlatform;
return false;
}
} // namespace
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
absl::string_view test_name) {
ManifestT manifest = ReadManifest();
// If the test name ends with a slash followed by one or more digits, strip
// that off; this is just a shard number, and matching on this would be
// unstable even if someone wanted to do it.
static LazyRE2 shard_num_pattern = {R"(/\d+$)"};
absl::string_view suffix;
if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) {
test_name.remove_suffix(suffix.size());
}
// First try full match: test_case_name.test_name
// If that fails, try to find just the test_case_name; this would disable all
// tests in the test case.
auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
if (it == manifest.end()) {
it = manifest.find(test_case_name);
if (it == manifest.end()) {
return std::string(test_name);
}
}
// Expect a full match vs. one of the platform regexps to disable the test.
const std::vector<string>& disabled_platforms = it->second;
string platform_string = XLA_PLATFORM;
for (const auto& s : disabled_platforms) {
if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
return absl::StrCat("DISABLED_", test_name);
}
}
// We didn't hit in the disabled manifest entries, so don't disable it.
return std::string(test_name);
}
static bool module_initialized = InitModule();
} // namespace xla

View File

@ -28,12 +28,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"
#define DISABLED_ON_CPU(X) X
#define DISABLED_ON_GPU(X) X
#define DISABLED_ON_GPU_ROCM(X) X
@ -79,117 +73,15 @@ limitations under the License.
namespace xla {
// Reads a disabled manifest file to resolve whether test cases should be
// disabled on a particular platform. For a test that should be disabled,
// returns DISABLED_ prepended to its name; otherwise returns the test name
// unmodified.
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
absl::string_view test_name);
inline const char *kDisabledManifestPath = nullptr;
inline const char *kTestPlatform = nullptr;
} // namespace xla
// This is the internal "gtest" class instantiation -- it is identical to the
// GTEST_TEST_ macro, except that we intercept the test name for potential
// modification by PrependDisabledIfIndicated. That file can use an arbitrary
// heuristic to decide whether the test case should be disabled, and we
// determine whether the test case should be disabled by resolving the (test
// case name, test name) in a manifest file.
#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
: public parent_class { \
public: \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
\
private: \
virtual void TestBody(); \
static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)); \
}; \
\
::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)::test_info_ = \
::testing::RegisterTest( \
#test_case_name, \
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
.c_str(), \
nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \
return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \
}); \
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
// This is identical to the TEST_F macro from "gtest", but it potentially
// disables the test based on an external manifest file, DISABLED_MANIFEST.
//
// Per usual, you can see what tests are available via --gunit_list_tests and
// choose to run tests that have been disabled via the manifest via
// --gunit_also_run_disabled_tests.
#define XLA_TEST_F(test_fixture, test_name) \
XLA_GTEST_TEST_(test_fixture, test_name, test_fixture)
#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
// Likewise, this is identical to the TEST_P macro from "gtest", but
// potentially disables the test based on the DISABLED_MANIFEST file.
//
// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will
// be properly expanded before the stringification occurs.
#define XLA_TEST_P_IMPL_(test_case_name, test_name) \
class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
: public test_case_name { \
public: \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
virtual void TestBody(); \
\
private: \
static int AddToRegistry() { \
::testing::UnitTest::GetInstance() \
->parameterized_test_registry() \
.GetTestCasePatternHolder<test_case_name>( \
#test_case_name, \
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
->AddTestPattern( \
#test_case_name, \
::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
.c_str(), \
new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
test_case_name, test_name)>()); \
return 0; \
} \
static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)); \
}; \
int GTEST_TEST_CLASS_NAME_(test_case_name, \
test_name)::gtest_registering_dummy_ = \
GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \
void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
#define XLA_TEST_P(test_case_name, test_name) \
XLA_TEST_P_IMPL_(test_case_name, test_name)
// This is identical to the TEST_F macro from "gtest", but it potentially
// disables the test based on an external manifest file, DISABLED_MANIFEST.
#define XLA_TYPED_TEST(CaseName, TestName) \
template <typename gtest_TypeParam_> \
class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
: public CaseName<gtest_TypeParam_> { \
private: \
typedef CaseName<gtest_TypeParam_> TestFixture; \
typedef gtest_TypeParam_ TypeParam; \
virtual void TestBody(); \
}; \
bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
::testing::internal::TypeParameterizedTest< \
CaseName, \
::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
TestName)>, \
GTEST_TYPE_PARAMS_(CaseName)>:: \
Register( \
"", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
#CaseName, \
::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
0); \
template <typename gtest_TypeParam_> \
void GTEST_TEST_CLASS_NAME_(CaseName, \
TestName)<gtest_TypeParam_>::TestBody()
#define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName)
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_

View File

@ -719,6 +719,7 @@ tf_cuda_library(
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/core/util:__pkg__",
"//tensorflow/security/fuzzing:__subpackages__",
],
deps = [
":allocation_description_proto_cc",

View File

@ -153,16 +153,9 @@ limitations under the License.
#endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines
// Defines for sets of types.
// TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES.
//
// The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and
// thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in
// TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the
// TF binary size and performance.
#define TF_CALL_INTEGRAL_TYPES(m) \
TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_INTEGRAL_TYPES(m) \
TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_int32(m) \
TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_FLOAT_TYPES(m) \
TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
@ -174,10 +167,10 @@ limitations under the License.
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \
TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \
TF_CALL_int8(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \
TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \
TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
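A hedged sketch of how these dispatch macros are consumed; the kernel-registration files later in this diff pass a REGISTER_* macro instead, but the runnable toy below makes the effect of the wider TF_CALL_INTEGRAL_TYPES list visible (the macro name PRINT_DTYPE is illustrative):

#include <iostream>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"

// Print the DataType enum value for every type in TF_CALL_INTEGRAL_TYPES.
// Before this change the list covered int8/int16/int32/int64/uint8/uint16;
// after it, uint32 and uint64 are instantiated as well.
#define PRINT_DTYPE(T) \
  std::cout << tensorflow::DataTypeToEnum<T>::value << "\n";

int main() {
  TF_CALL_INTEGRAL_TYPES(PRINT_DTYPE)
  return 0;
}

#undef PRINT_DTYPE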

View File

@ -238,11 +238,6 @@ int DataTypeSize(DataType dt) {
TF_CALL_qint16(CASE);
TF_CALL_quint16(CASE);
// uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
// don't want to define kernels for them at this stage to avoid binary
// bloat.
TF_CALL_uint32(CASE);
TF_CALL_uint64(CASE);
default:
return 0;
}

View File

@ -837,6 +837,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
"ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
"Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
"RandomPoissonV2",
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
// ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
// but it can't generate any observable side-effect.
@ -850,12 +851,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
// the same device_ordinal on the same host.
"EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
"EnqueueTPUEmbeddingSparseTensorBatch",
"EnqueueTPUEmbeddingRaggedTensorBatch",
// SaveV2 and RestoreV2 should be allowed to operate in parallel on
// multiple hosts.
"SaveV2", "RestoreV2"});
// LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
"EnqueueTPUEmbeddingRaggedTensorBatch"});
return exemption->contains(op);
}

View File

@ -4168,6 +4168,25 @@ tf_kernel_library(
]),
)
tf_cuda_cc_test(
name = "mlir_generated_op_gpu_tanh_test",
size = "small",
srcs = if_mlir_generated_gpu_kernels_enabled(["mlir_generated_op_gpu_tanh_test.cc"]),
tags = tf_cuda_tests_tags() + ["no_rocm"],
deps = [
":cwise_op",
":ops_testutil",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/common_runtime:device",
"//tensorflow/core/common_runtime:device_factory",
],
)
tf_kernel_library(
name = "nextafter_op",
prefix = "nextafter_op",
@ -4900,7 +4919,9 @@ tf_kernel_library(
"topk_op_gpu_double.cu.cc",
"topk_op_gpu_float.cu.cc",
"topk_op_gpu_half.cu.cc",
"topk_op_gpu_uint64.cu.cc",
"topk_op_gpu_int64.cu.cc",
"topk_op_gpu_uint32.cu.cc",
"topk_op_gpu_int32.cu.cc",
"topk_op_gpu_int16.cu.cc",
"topk_op_gpu_uint16.cu.cc",
@ -6802,7 +6823,8 @@ filegroup(
"cwise_op_minimum.cc",
"cwise_op_mul_1.cc",
"cwise_op_mul_2.cc",
"cwise_op_neg.cc",
"cwise_op_neg_1.cc",
"cwise_op_neg_2.cc",
"cwise_op_pow.cc",
"cwise_op_real.cc",
"cwise_op_reciprocal.cc",
@ -8780,7 +8802,8 @@ exports_files([
"cwise_op_mod.cc",
"cwise_op_mul_1.cc",
"cwise_op_mul_2.cc",
"cwise_op_neg.cc",
"cwise_op_neg_1.cc",
"cwise_op_neg_2.cc",
"cwise_op_not_equal_to_1.cc",
"cwise_op_not_equal_to_2.cc",
"cwise_op_round.cc",

View File

@ -116,8 +116,6 @@ REGISTER(qint8)
REGISTER(quint16)
REGISTER(qint16)
REGISTER(qint32)
REGISTER(uint32)
REGISTER(uint64)
#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
!defined(__ANDROID_TYPES_FULL__)

View File

@ -208,8 +208,6 @@ REGISTER_CONCAT(qint8);
REGISTER_CONCAT(quint16);
REGISTER_CONCAT(qint16);
REGISTER_CONCAT(qint32);
REGISTER_CONCAT(uint32);
REGISTER_CONCAT(uint64);
#undef REGISTER_CONCAT

View File

@ -211,7 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
// the conversion from uint8 to quint8.
REGISTER_KERNEL(CPU, quint8);
REGISTER_KERNEL(CPU, quint16);
REGISTER_KERNEL(CPU, uint32);
#undef REGISTER_CPU_KERNEL
#ifdef TENSORFLOW_USE_SYCL

View File

@ -101,27 +101,21 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);
REGISTER_CPU_SWITCH(uint64);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_SWITCH(uint64);
TF_CALL_variant(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
TF_CALL_bool(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH
// Special GPU kernels for int32, string & resource handles. Requiring all
// inputs and outputs to be in host memory.
// TODO(b/25387198): Also enable int32 in device memory.
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_GPU) \
@ -151,6 +145,8 @@ TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_REF_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);
@ -311,7 +307,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
REGISTER_GPU_KERNEL(uint64);
TF_CALL_variant(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

View File

@ -19,8 +19,8 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_UNARY7(neg, Eigen::half, float, double, int32, int64, complex64,
complex128);
DEFINE_UNARY4(neg, int8, int16, int32, int64);
DEFINE_UNARY6(neg, Eigen::half, float, double, bfloat16, complex64, complex128);
} // namespace functor
} // namespace tensorflow

View File

@ -16,8 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER8(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
complex64, int64, complex128, bfloat16);
REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64);
#ifdef TENSORFLOW_USE_SYCL
REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);
@ -30,8 +29,7 @@ REGISTER_KERNEL_BUILDER(Name("Neg")
#endif // TENSORFLOW_USE_SYCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64,
complex64, complex128);
REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
REGISTER6(UnaryOp, CPU, "Neg", functor::neg, Eigen::half, float, double,
bfloat16, complex64, complex128);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, Eigen::half, float, double,
bfloat16, complex64, complex128);
#endif
} // namespace tensorflow

View File

@ -220,8 +220,6 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
break;
TF_CALL_NUMBER_TYPES(CASE);
TF_CALL_tstring(CASE);
TF_CALL_uint32(CASE);
TF_CALL_uint64(CASE);
// TODO(feihugis): figure out how to support variant tensors.
#undef CASE
default:

View File

@ -98,7 +98,6 @@ typedef Eigen::SyclDevice SYCLDevice;
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
// uint32 not included in ALL_TYPES
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
// quint16 not included in QUANTIZED_TYPES
TF_CALL_quint16(REGISTER_KERNELS);

View File

@ -164,8 +164,6 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared {
DynamicPartitionOp<T>)
TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
// For partitioning fingerprints.
TF_CALL_uint64(REGISTER_DYNAMIC_PARTITION);
#undef REGISTER_DYNAMIC_PARTITION
} // namespace tensorflow

View File

@ -45,6 +45,8 @@ DEFINE_SETZERO_CPU(Eigen::half);
DEFINE_SETZERO_CPU(bfloat16);
DEFINE_SETZERO_CPU(float);
DEFINE_SETZERO_CPU(double);
DEFINE_SETZERO_CPU(uint32);
DEFINE_SETZERO_CPU(uint64);
DEFINE_SETZERO_CPU(uint8);
DEFINE_SETZERO_CPU(int8);
DEFINE_SETZERO_CPU(uint16);
@ -96,6 +98,8 @@ DEFINE_SETONE_CPU(Eigen::half);
DEFINE_SETONE_CPU(bfloat16);
DEFINE_SETONE_CPU(float);
DEFINE_SETONE_CPU(double);
DEFINE_SETONE_CPU(uint32);
DEFINE_SETONE_CPU(uint64);
DEFINE_SETONE_CPU(uint8);
DEFINE_SETONE_CPU(int8);
DEFINE_SETONE_CPU(uint16);
@ -137,7 +141,6 @@ struct FillFunctor<Eigen::ThreadPoolDevice, T> {
TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
DEFINE_FILL_CPU(quint8);
DEFINE_FILL_CPU(quint16);
DEFINE_FILL_CPU(uint32);
#undef DEFINE_FILL_CPU
#ifdef TENSORFLOW_USE_SYCL

View File

@ -211,8 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_quint16(REGISTER_GATHER_CPU);
TF_CALL_qint16(REGISTER_GATHER_CPU);
TF_CALL_uint32(REGISTER_GATHER_CPU);
TF_CALL_uint64(REGISTER_GATHER_CPU);
#undef REGISTER_GATHER_CPU

View File

@ -122,7 +122,6 @@ REGISTER_SYCL_HOST_KERNEL(bool);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(Variant);
TF_CALL_uint32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

View File

@ -178,6 +178,9 @@ class MklAddNOp : public OpKernel {
dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
}
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));
// Create memory descriptor for MKL-DNN.
// If all inputs are in Tensorflow format, create a block memory descriptor;
// otherwise, convert the TF format to an MKL memory descriptor.
@ -215,6 +218,7 @@ class MklAddNOp : public OpKernel {
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
#endif
src.SetUsrMem(md, &src_tensor);
src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
inputs.push_back(src.GetOpMem());
}
@ -240,11 +244,10 @@ class MklAddNOp : public OpKernel {
}
AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
output_mkl_shape);
dst.SetUsrMemDataHandle(dst_tensor);
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
// Create Sum op, and submit net for execution.
std::vector<primitive> net;
stream* fwd_cpu_stream = CreateStream(ctx, cpu_engine);
#ifdef ENABLE_MKLDNN_V1
mkldnn::sum sum_op(sum_pd);
std::unordered_map<int, memory> net_args = {

View File

@ -281,11 +281,19 @@ class MklConcatFwdPrimitive : public MklPrimitive {
std::shared_ptr<stream> fwd_stream) {
DCHECK_EQ(in_data.size(), context_.data_mem.size());
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.data_mem_shdptr[i]->set_data_handle(
static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
}
context_.dst_mem->set_data_handle(
static_cast<void*>(dst_data.get_data_handle()), *fwd_stream);
#else
context_.data_mem_shdptr[i]->set_data_handle(
static_cast<void*>(in_data[i].get_data_handle()));
}
context_.dst_mem->set_data_handle(
static_cast<void*>(dst_data.get_data_handle()));
#endif // ENABLE_MKLDNN_THREADPOOL
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
context_.data_mem[i] = *context_.data_mem_shdptr[i];
@ -788,11 +796,13 @@ class MklConcatOp : public OpKernel {
dnn_shape_dst);
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
if (dnn_shape_dst.IsMklTensor())
dst_md = dnn_shape_dst.GetMklLayout();
dst.SetUsrMem(dst_md, dst_tensor);
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
#ifdef ENABLE_MKLDNN_V1
auto concat_op = concat(concat_pd);
std::unordered_map<int, memory> net_args = {
@ -830,9 +840,10 @@ class MklConcatOp : public OpKernel {
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
: dst_md;
dst.SetUsrMem(dst_md, dst_tensor);
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
dst.SetUsrMem(dst_md, dst_tensor);
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
// Execute concat
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
fwd_cpu_stream);

View File

@ -75,6 +75,9 @@ class MklDequantizeOp : public OpKernel {
MklDnnData<T> src(&cpu_engine);
MklDnnData<float> dst(&cpu_engine);
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine));
// If input is in MKL layout, then simply grab input layout; otherwise,
// construct input TF layout. For TF layout, although input shape
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
@ -85,6 +88,7 @@ class MklDequantizeOp : public OpKernel {
: memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);
src.SetUsrMem(src_md, &src_tensor);
src.SetUsrMemDataHandle(&src_tensor, reorder_stream);
Tensor* output_tensor = nullptr;
MklDnnShape output_mkl_shape;
@ -129,6 +133,7 @@ class MklDequantizeOp : public OpKernel {
AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
output_mkl_shape);
dst.SetUsrMem(dst_md, output_tensor);
dst.SetUsrMemDataHandle(output_tensor, reorder_stream);
// The quantization logic here for mode SCALED is similar to the logic
// in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
@ -155,8 +160,6 @@ class MklDequantizeOp : public OpKernel {
// Also it does not define round_nearest (enum).
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
#endif // !ENABLE_MKLDNN_V1
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine));
std::vector<primitive> net;
// Create reorder primitive and then execute.

View File

@ -137,6 +137,7 @@ class MklLRNOp : public OpKernel {
// that input is in NHWC layout with Channel being the last dimension.
src_dnn_data.SetUsrMem(src_md, &src_tensor);
src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
src_dnn_data.SetUsrMemDataHandle(&src_tensor, fwd_stream_);
// dst_dnn_data has the same shape as input.
dst_dnn_data.SetUsrMem(src_md);
@ -157,7 +158,7 @@ class MklLRNOp : public OpKernel {
&output_tensor);
OP_REQUIRES_OK(context, context->status());
DCHECK(output_tensor != nullptr);
dst_dnn_data.SetUsrMemDataHandle(output_tensor);
dst_dnn_data.SetUsrMemDataHandle(output_tensor, fwd_stream_);
// Handle workspace required for MKL-DNN.
AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
@ -393,6 +394,7 @@ class MklLRNGradOp : public OpKernel {
orig_input_dnn_shape.GetSizesAsMklDnnDims();
orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
orig_input_dnn_data.SetUsrMemDataHandle(&orig_input_tensor, bwd_stream_);
// output_dnn_data has the same shape as original input
output_dnn_data.SetUsrMem(orig_input_md);
@ -421,7 +423,7 @@ class MklLRNGradOp : public OpKernel {
orig_input_format, &output_tensor);
OP_REQUIRES_OK(context, context->status());
DCHECK(output_tensor != nullptr);
output_dnn_data.SetUsrMemDataHandle(output_tensor);
output_dnn_data.SetUsrMemDataHandle(output_tensor, bwd_stream_);
// Create LRN primitive and add it to the net
// At this point, workspace is enabled, so we don't need

View File

@ -137,6 +137,7 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
memory::dims out_strides =
ReorderStrides(CalculateTFStrides(out_dims), perm);
std::shared_ptr<stream> transpose_stream;
in.SetUsrMem(in_dims, in_strides, &in_tensor);
// Output dimensions are same as input dimensions. We adjust the layout
// using strides.
@ -144,16 +145,16 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
std::shared_ptr<stream> transpose_stream;
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
transpose_stream.reset(CreateStream(context, prim->GetEngine()));
in.SetUsrMemDataHandle(&in_tensor, transpose_stream);
out.SetUsrMemDataHandle(out_tensor, transpose_stream);
net.push_back(*(prim->GetPrimitive()));
std::vector<MemoryArgsMap> net_args;
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
{MKLDNN_ARG_TO, *out.GetUsrMem()}});
execute_primitives(net, transpose_stream, net_args);
#else
std::shared_ptr<stream> transpose_stream;
transpose_stream.reset(new CPU_STREAM(cpu_engine));
net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()));
transpose_stream->submit(net).wait();

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class MlirGeneratedOpGpuTanhTest : public OpsTestBase {
protected:
void SetUp() override {
std::unique_ptr<tensorflow::Device> device_gpu(
tensorflow::DeviceFactory::NewDevice("GPU", {},
"/job:a/replica:0/task:0"));
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
}
template <typename T, typename RT = T>
void RunTanhOp(std::initializer_list<T> input) {
TensorShape shape({2, 7});
TF_ASSERT_OK(NodeDefBuilder("tanh_op", "Tanh")
.Input(FakeInput(DataTypeToEnum<T>::v()))
.Attr("T", DataTypeToEnum<T>::v())
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<T>(shape, input);
TF_ASSERT_OK(RunOpKernel());
Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
std::vector<T> expected;
expected.reserve(input.size());
for (const T& inp : input) {
expected.push_back(static_cast<T>(std::tanh(static_cast<RT>(inp))));
}
test::FillValues<T>(&expected_tensor, expected);
test::ExpectClose(expected_tensor, *GetOutput(0));
}
};
TEST_F(MlirGeneratedOpGpuTanhTest, TanhFloat) {
RunTanhOp<float>({-18.0f, -9.0f, -1e-6f, -0.0f, 0.0f, 1e-6f, 0.1f, 0.2f, 0.3f,
0.5f, 0.7f, 0.9f, 9.0f, 18.0f});
}
TEST_F(MlirGeneratedOpGpuTanhTest, TanhDouble) {
RunTanhOp<double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5,
0.7, 0.9, 9.0, 18.0});
}
TEST_F(MlirGeneratedOpGpuTanhTest, TanhHalf) {
RunTanhOp<Eigen::half, float>(
{static_cast<Eigen::half>(-18.0), static_cast<Eigen::half>(-9.0),
static_cast<Eigen::half>(-1e-6), static_cast<Eigen::half>(-0.0),
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(1e-6),
static_cast<Eigen::half>(0.1), static_cast<Eigen::half>(0.2),
static_cast<Eigen::half>(0.3), static_cast<Eigen::half>(0.5),
static_cast<Eigen::half>(0.7), static_cast<Eigen::half>(0.9),
static_cast<Eigen::half>(9.0), static_cast<Eigen::half>(18.0)});
}
} // namespace
} // end namespace tensorflow


@ -296,8 +296,6 @@ TF_CALL_tstring(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL);
TF_CALL_uint32(REGISTER_CPU_KERNEL);
TF_CALL_uint64(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
#undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE


@ -308,8 +308,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_uint64(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
} // namespace tensorflow


@ -561,8 +561,6 @@ TF_CALL_string(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL);
TF_CALL_uint32(REGISTER_CPU_KERNEL);
TF_CALL_uint64(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL


@ -213,8 +213,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_uint64(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
} // namespace tensorflow


@ -35,6 +35,7 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
static constexpr int VectorSizeElements = 8;
namespace functor {
// This kernel computes ReluGrad by processing one half2, i.e. two fp16 values, at a time.
@ -93,6 +94,66 @@ __global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
}
}
__global__ void ReluGradHalfKernelVector(
const Eigen::half* __restrict__ gradient,
const Eigen::half* __restrict__ feature, Eigen::half* __restrict__ backprop,
int32 count) {
int32 half8_count = count / VectorSizeElements;
int32 index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < half8_count) {
// Cast to xx_h8 for vector load and store.
float4 gradient_h8 = reinterpret_cast<const float4*>(gradient)[index];
float4 feature_h8 = reinterpret_cast<const float4*>(feature)[index];
float4* p_backprop_h8 = reinterpret_cast<float4*>(backprop) + index;
half2* gradient_h2 = reinterpret_cast<half2*>(&gradient_h8);
half2* feature_h2 = reinterpret_cast<half2*>(&feature_h8);
float4 backprop_h8;
half2* p_backprop_h2 = reinterpret_cast<half2*>(&backprop_h8);
// Fast path, when half2 primitives are available.
#if __CUDA_ARCH__ >= 530
const half2 kZeroH2 = __float2half2_rn(0.f);
#endif
for (int i = 0; i < VectorSizeElements / 2; i++) {
#if __CUDA_ARCH__ >= 530
// mask = (feature > 0)
half2 mask_h2 = __hgt2(feature_h2[i], kZeroH2);
// backprop = mask * gradient
half2 backprop_h2 = __hmul2(mask_h2, gradient_h2[i]);
#else
// Fall back: convert half2 to float2 for processing.
float2 feature_f2 = __half22float2(feature_h2[i]);
float2 gradient_f2 = __half22float2(gradient_h2[i]);
float2 backprop_f2 =
make_float2((feature_f2.x > 0.0f) ? float(gradient_f2.x) : 0.0f,
(feature_f2.y > 0.0f) ? float(gradient_f2.y) : 0.0f);
// Convert back to half2.
half2 backprop_h2 = __float22half2_rn(backprop_f2);
#endif
p_backprop_h2[i] = backprop_h2;
}
// Write back the result.
*p_backprop_h8 = backprop_h8;
}
int remaining_count = (count % VectorSizeElements);
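// For example, with count = 100 the vector path above covers elements 0..95
// (half8_count = 12 float4 loads) and remaining_count = 4, so the first four
// threads handle elements 96..99 in the scalar tail below.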
if (index < remaining_count) {
// Use first threads to process the remaining elements.
Eigen::half grad_h = gradient[half8_count * VectorSizeElements + index];
Eigen::half feature_h = feature[half8_count * VectorSizeElements + index];
float grad_f = static_cast<float>(grad_h);
float feature_f = static_cast<float>(feature_h);
float backprop_f = (feature_f > 0) ? grad_f : 0;
Eigen::half backprop_h(backprop_f);
backprop[half8_count * VectorSizeElements + index] = backprop_h;
}
}
template <typename Device>
struct ReluGrad<Device, Eigen::half> {
// Computes ReluGrad backprop.
@ -108,15 +169,28 @@ struct ReluGrad<Device, Eigen::half> {
// NOTE: When the activation is exactly zero, we do not propagate the
// associated gradient value. This allows the output of the Relu to be used,
// as well as its input.
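// The vectorized kernel loads float4 (eight fp16 values, 16 bytes) at a time,
// so it is only used when all three buffers are 16-byte aligned.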
auto gradient_ptr = reinterpret_cast<uintptr_t>(gradient.data());
auto feature_ptr = reinterpret_cast<uintptr_t>(feature.data());
auto backprop_ptr = reinterpret_cast<uintptr_t>(backprop.data());
bool aligned = gradient_ptr % 16 == 0 && feature_ptr % 16 == 0 &&
backprop_ptr % 16 == 0;
int32 count = gradient.size();
if (count == 0) return;
constexpr int32 kThreadInBlock = 512;
if (aligned) {
int32 half8_count = Eigen::divup(count, VectorSizeElements);
int32 kBlock = Eigen::divup(half8_count, kThreadInBlock);
TF_CHECK_OK(GpuLaunchKernel(
ReluGradHalfKernelVector, kBlock, kThreadInBlock, 0, d.stream(),
gradient.data(), feature.data(), backprop.data(), count));
} else {
int32 half2_count = Eigen::divup(count, 2);
GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
TF_CHECK_OK(GpuLaunchKernel(
ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
d.stream(), gradient.data(), feature.data(), backprop.data(), count));
}
}
};
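
For reference, the rule both kernels above implement is the usual ReLU gradient; a minimal scalar sketch (not part of this change; names are illustrative):

#include <cstddef>
#include <vector>

// backprop[i] = gradient[i] wherever the forward output feature[i] was > 0,
// and 0 otherwise -- a feature value of exactly zero propagates no gradient,
// matching the NOTE above.
std::vector<float> ReluGradReference(const std::vector<float>& gradient,
                                     const std::vector<float>& feature) {
  std::vector<float> backprop(gradient.size());
  for (std::size_t i = 0; i < gradient.size(); ++i) {
    backprop[i] = feature[i] > 0.0f ? gradient[i] : 0.0f;
  }
  return backprop;
}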


@ -512,7 +512,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -43,7 +43,6 @@ void Split<Eigen::ThreadPoolDevice, T, NDims>::operator()(
TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
DEFINE_CPU_KERNELS(quint8)
DEFINE_CPU_KERNELS(uint64)
#ifdef TENSORFLOW_USE_SYCL
template <typename T, int NDims>


@ -404,7 +404,6 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
TF_CALL_ALL_TYPES(REGISTER_SPLIT);
REGISTER_SPLIT(quint8);
REGISTER_SPLIT(uint64);
#undef REGISTER_SPLIT


@ -440,8 +440,6 @@ class StridedSliceAssignOp : public OpKernel {
StridedSliceAssignOp<CPUDevice, type, true>)
TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
TF_CALL_uint32(REGISTER_STRIDED_SLICE);
TF_CALL_uint64(REGISTER_STRIDED_SLICE);
#undef REGISTER_STRIDED_SLICE


@ -287,8 +287,6 @@ TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
#endif // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
TF_CALL_uint32(DECLARE_FOR_N_CPU);
TF_CALL_uint64(DECLARE_FOR_N_CPU);
#ifdef TENSORFLOW_USE_SYCL
#define PREVENT_FOR_N_SYCL(T) \


@ -52,7 +52,8 @@ class SummaryScalarOp : public OpKernel {
Summary s;
for (int i = 0; i < Ttags.size(); i++) {
Summary::Value* v = s.add_value();
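// Copy the tag bytes straight from the tstring into the proto string field.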
const tstring& Ttags_i = Ttags(i);
v->set_tag(Ttags_i.data(), Ttags_i.size());
v->set_simple_value(float(Tvalues(i)));
}
@ -102,7 +103,8 @@ class SummaryHistoOp : public OpKernel {
Summary s;
Summary::Value* v = s.add_value();
const tstring& tags0 = tags.scalar<tstring>()();
v->set_tag(tags0.data(), tags0.size());
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
Tensor* summary_tensor = nullptr;


@ -258,7 +258,6 @@ namespace functor {
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
TF_CALL_uint32(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
@ -276,7 +275,6 @@ TF_CALL_uint32(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS)
#undef REGISTER_KERNELS
#endif // end GOOGLE_CUDA


@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/kernels/topk_op_gpu.h"
namespace tensorflow {
using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, uint32>;
} // namespace tensorflow
#endif // GOOGLE_CUDA
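
A further dtype would presumably follow the same one-instantiation-per-file pattern; a hypothetical sketch (not part of this commit, and it would also need the matching kernel registration):

/* topk_op_gpu_uint16.cu.cc -- hypothetical */
#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/kernels/topk_op_gpu.h"

namespace tensorflow {
using Eigen::GpuDevice;

// Explicit instantiation keeps this dtype's GPU code in its own translation
// unit, mirroring the uint32 and uint64 files in this change.
template struct functor::TopKFunctor<GPUDevice, uint16>;
}  // namespace tensorflow

#endif  // GOOGLE_CUDA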


@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/kernels/topk_op_gpu.h"
namespace tensorflow {
using Eigen::GpuDevice;
template struct functor::TopKFunctor<GPUDevice, uint64>;
} // namespace tensorflow
#endif // GOOGLE_CUDA


@ -252,7 +252,8 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -78,3 +78,32 @@ op {
}
}
}
op {
name: "Acos"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}


@ -78,3 +78,32 @@ op {
}
}
}
op {
name: "Asin"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}


@ -78,3 +78,32 @@ op {
}
}
}
op {
name: "Atan"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}


@ -168,3 +168,32 @@ op {
}
}
}
op {
name: "Inv"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
}
