Merge branch 'master' of github.com:ashahba/tensorflow into ashahba/fix-horovod

commit c21a9d30f7
@@ -198,6 +198,9 @@ LogicalResult MarkUncompilableOps(
  int outside_compiled_cluster_counter = 0;
  block->walk([&](Operation* op) {
    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      VLOG(3) << "Cloud TPU: Op " << op->getName().getStringRef().str()
              << " isn't compilable, adding outside_compilation attr. "
                 "This op will automatically be placed on CPU.";
      op->setAttr(
          kXlaOutsideCompilationAttr,
          StringAttr::get(
@@ -250,6 +250,7 @@ cc_library(
    ],
    hdrs = ["mlir_hlo_to_hlo.h"],
    deps = [
        ":attribute_exporter",
        ":type_to_shape",
        "//tensorflow/compiler/mlir:name_utils",
        "//tensorflow/compiler/mlir/hlo",
@@ -338,6 +339,24 @@ cc_library(
    ],
)

cc_library(
    name = "attribute_exporter",
    srcs = ["attribute_exporter.cc"],
    hdrs = ["attribute_exporter.h"],
    deps = [
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor:dnn",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)

cc_library(
    name = "translate_cl_options",
    srcs = ["xla_mlir_translate_cl.cc"],
tensorflow/compiler/mlir/xla/attribute_exporter.cc (new file, 87 lines)
@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"

#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/dnn.h"

namespace xla {

ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
    mlir::mhlo::ConvDimensionNumbers input) {
  ConvolutionDimensionNumbers output;

  output.set_input_batch_dimension(
      input.input_batch_dimension().getValue().getSExtValue());
  output.set_input_feature_dimension(
      input.input_feature_dimension().getValue().getSExtValue());

  for (auto v : input.input_spatial_dimensions().getValues<int64>()) {
    output.add_input_spatial_dimensions(v);
  }

  output.set_kernel_input_feature_dimension(
      input.kernel_input_feature_dimension().getValue().getSExtValue());
  output.set_kernel_output_feature_dimension(
      input.kernel_output_feature_dimension().getValue().getSExtValue());

  for (auto v : input.kernel_spatial_dimensions().getValues<int64>()) {
    output.add_kernel_spatial_dimensions(v);
  }

  output.set_output_batch_dimension(
      input.output_batch_dimension().getValue().getSExtValue());
  output.set_output_feature_dimension(
      input.output_feature_dimension().getValue().getSExtValue());

  for (auto v : input.output_spatial_dimensions().getValues<int64>()) {
    output.add_output_spatial_dimensions(v);
  }

  return output;
}

StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
    llvm::StringRef input) {
  llvm::Optional<mlir::lmhlo_gpu::Activation> activation =
      mlir::lmhlo_gpu::symbolizeActivation(input);
  if (!activation) {
    return InternalError("Unexpected activation");
  }

  switch (activation.getValue()) {
    case mlir::lmhlo_gpu::Activation::None:
      return stream_executor::dnn::kNone;
    case mlir::lmhlo_gpu::Activation::Sigmoid:
      return stream_executor::dnn::kSigmoid;
    case mlir::lmhlo_gpu::Activation::Tanh:
      return stream_executor::dnn::kTanh;
    case mlir::lmhlo_gpu::Activation::Relu:
      return stream_executor::dnn::kRelu;
    case mlir::lmhlo_gpu::Activation::Relu6:
      return stream_executor::dnn::kRelu6;
    case mlir::lmhlo_gpu::Activation::ReluX:
      return stream_executor::dnn::kReluX;
    case mlir::lmhlo_gpu::Activation::BandPass:
      return stream_executor::dnn::kBandPass;
    default:
      return InternalError("Unexpected activation");
  }
}

}  // namespace xla
tensorflow/compiler/mlir/xla/attribute_exporter.h (new file, 37 lines)
@@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/dnn.h"

namespace xla {

// Converts the conv dimensions attribute to XLA HLO.
ConvolutionDimensionNumbers ConvertConvDimensionNumbers(
    mlir::mhlo::ConvDimensionNumbers input);

StatusOr<stream_executor::dnn::ActivationMode> ConvertConvActivationMode(
    llvm::StringRef input);

}  // namespace xla
#endif  // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
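Note (illustrative sketch, not part of this commit): a caller holding one of the
lmhlo_gpu convolution ops, say a hypothetical `op` of type
mlir::lmhlo_gpu::ConvForwardFusedOp, would use the two helpers declared above
roughly as follows; the variable names are made up and error handling is
abbreviated.

  // Convert the MLIR convolution attributes into their XLA / StreamExecutor
  // equivalents before building a convolution descriptor.
  xla::ConvolutionDimensionNumbers dnums =
      xla::ConvertConvDimensionNumbers(op.dimension_numbers());
  xla::StatusOr<stream_executor::dnn::ActivationMode> mode =
      xla::ConvertConvActivationMode(op.activation_mode());
  if (!mode.ok()) return mode.status();  // unknown activation string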
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
@@ -297,36 +298,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers(

static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
    mlir::mhlo::ConvDimensionNumbers input) {
  xla::ConvolutionDimensionNumbers output;

  output.set_input_batch_dimension(
      input.input_batch_dimension().getValue().getSExtValue());
  output.set_input_feature_dimension(
      input.input_feature_dimension().getValue().getSExtValue());

  for (int64 v : input.input_spatial_dimensions().getValues<int64>()) {
    output.add_input_spatial_dimensions(v);
  }

  output.set_kernel_input_feature_dimension(
      input.kernel_input_feature_dimension().getValue().getSExtValue());
  output.set_kernel_output_feature_dimension(
      input.kernel_output_feature_dimension().getValue().getSExtValue());

  for (int64 v : input.kernel_spatial_dimensions().getValues<int64>()) {
    output.add_kernel_spatial_dimensions(v);
  }

  output.set_output_batch_dimension(
      input.output_batch_dimension().getValue().getSExtValue());
  output.set_output_feature_dimension(
      input.output_feature_dimension().getValue().getSExtValue());

  for (int64 v : input.output_spatial_dimensions().getValues<int64>()) {
    output.add_output_spatial_dimensions(v);
  }

  return output;
  return xla::ConvertConvDimensionNumbers(input);
}

xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
@@ -179,7 +179,7 @@ Status ConvBackpropComputeDimensionsV2XlaShapes(

}  // anonymous namespace

absl::Span<const DataType> GetXlaConvTypes() {
std::vector<DataType> GetXlaConvTypes() {
  return {DT_FLOAT, DT_BFLOAT16, DT_HALF, DT_DOUBLE};
}

@@ -37,7 +37,7 @@ namespace tensorflow {

// We don't support integers for convolutions, so we list the supported types
// here.
absl::Span<const DataType> GetXlaConvTypes();
std::vector<DataType> GetXlaConvTypes();

// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
// convolution.
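Note (assumption, not stated in the commit): a plausible reason for changing
GetXlaConvTypes() from absl::Span<const DataType> to std::vector<DataType> is
that returning a Span constructed from a braced initializer list hands the
caller a view of a temporary array that is destroyed at the end of the return
expression. A minimal illustration; `BadGetXlaConvTypes` is a hypothetical name
used only to show the hazard.

  // Dangling: the Span points at the initializer_list's backing storage, which
  // goes away as soon as the return statement finishes.
  absl::Span<const DataType> BadGetXlaConvTypes() {
    return {DT_FLOAT, DT_BFLOAT16, DT_HALF, DT_DOUBLE};
  }

  // Safe: the vector owns its elements and is returned by value.
  std::vector<DataType> GetXlaConvTypes() {
    return {DT_FLOAT, DT_BFLOAT16, DT_HALF, DT_DOUBLE};
  }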
@@ -270,6 +270,7 @@ cc_library(
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:lhlo",
        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
        "//tensorflow/compiler/mlir/xla:attribute_exporter",
        "//tensorflow/compiler/mlir/xla:hlo_module_importer",
        "//tensorflow/compiler/mlir/xla:hlo_utils",
        "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
|
||||
@ -255,15 +256,17 @@ Status RunGpuConvImpl(const GpuConvParams& params,
|
||||
} // anonymous namespace
|
||||
|
||||
StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
const HloCustomCallInstruction* cudnn_call) {
|
||||
const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
|
||||
GpuConvConfig config;
|
||||
|
||||
config.input_type = cudnn_call->operand(0)->shape().element_type();
|
||||
config.output_type = cudnn_call->shape().tuple_shapes(0).element_type();
|
||||
const Shape& operand0_shape = desc.operand0_shape;
|
||||
const Shape& operand1_shape = desc.operand1_shape;
|
||||
const Shape& result_shape = desc.result_shape;
|
||||
const CudnnConvBackendConfig& backend_config = desc.backend_config;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
|
||||
cudnn_call->backend_config<CudnnConvBackendConfig>());
|
||||
TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call));
|
||||
config.input_type = operand0_shape.element_type();
|
||||
config.output_type = result_shape.element_type();
|
||||
config.kind = desc.kind;
|
||||
|
||||
// The third field is scratch size stored from conv_algorithm_picker
|
||||
// The operand is added to the shape field of the conv instruction
|
||||
@ -271,13 +274,9 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
config.algorithm = se::dnn::AlgorithmConfig(
|
||||
se::dnn::AlgorithmDesc(backend_config.algorithm(),
|
||||
backend_config.tensor_ops_enabled()),
|
||||
cudnn_call->shape().tuple_shapes(1).dimensions(0));
|
||||
desc.scratch_size);
|
||||
config.conv_result_scale = backend_config.conv_result_scale();
|
||||
|
||||
Shape operand0_shape = cudnn_call->operand(0)->shape();
|
||||
Shape operand1_shape = cudnn_call->operand(1)->shape();
|
||||
Shape result_shape = cudnn_call->shape().tuple_shapes(0);
|
||||
|
||||
switch (config.kind) {
|
||||
case CudnnConvKind::kForward:
|
||||
case CudnnConvKind::kForwardActivation:
|
||||
@ -311,9 +310,8 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
fusion.side_input_scale = backend_config.side_input_scale();
|
||||
}
|
||||
|
||||
const Window& window = cudnn_call->window();
|
||||
const ConvolutionDimensionNumbers& dnums =
|
||||
cudnn_call->convolution_dimension_numbers();
|
||||
const Window& window = desc.window;
|
||||
const ConvolutionDimensionNumbers& dnums = desc.dnums;
|
||||
|
||||
VLOG(3) << "Convolution Algorithm: "
|
||||
<< config.algorithm.algorithm()->algo_id();
|
||||
@ -330,7 +328,7 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
|
||||
|
||||
const int num_dimensions = window.dimensions_size();
|
||||
CHECK_LE(num_dimensions, 3) << cudnn_call->ToString();
|
||||
CHECK_LE(num_dimensions, 3) << inst_as_string;
|
||||
|
||||
// cuDNN does not support 1D convolutions. We therefore express 1D
|
||||
// convolutions as 2D convolutions where the first spatial dimension is 1.
|
||||
@ -344,18 +342,18 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();
|
||||
|
||||
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
|
||||
<< cudnn_call->ToString();
|
||||
<< inst_as_string;
|
||||
CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
|
||||
<< cudnn_call->ToString();
|
||||
<< inst_as_string;
|
||||
CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
|
||||
<< cudnn_call->ToString();
|
||||
<< inst_as_string;
|
||||
for (const WindowDimension& dim : window.dimensions()) {
|
||||
CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString();
|
||||
CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString();
|
||||
CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
|
||||
CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
|
||||
CHECK_EQ(dim.base_dilation(), 1)
|
||||
<< "cudnn does not support base dilation; it "
|
||||
"must be made explicit with a kPad: "
|
||||
<< cudnn_call->ToString();
|
||||
<< inst_as_string;
|
||||
}
|
||||
|
||||
// cuDNN's convolution APIs support the BDYX layout for activations/output and
|
||||
@ -364,43 +362,43 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
FilterLayout filter_dl;
|
||||
DataLayout output_dl;
|
||||
|
||||
const Shape* input_shape = &config.input_shape;
|
||||
const Shape* filter_shape = &config.filter_shape;
|
||||
const Shape* output_shape = &config.output_shape;
|
||||
const Shape& input_shape = config.input_shape;
|
||||
const Shape& filter_shape = config.filter_shape;
|
||||
const Shape& output_shape = config.output_shape;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
|
||||
XlaConvLayoutsToStreamExecutorLayouts(
|
||||
dnums, input_shape->layout(), filter_shape->layout(),
|
||||
output_shape->layout()));
|
||||
dnums, input_shape.layout(), filter_shape.layout(),
|
||||
output_shape.layout()));
|
||||
|
||||
BatchDescriptor& input_descriptor = config.input_descriptor;
|
||||
input_descriptor = BatchDescriptor(effective_num_dimensions);
|
||||
input_descriptor.set_layout(input_dl)
|
||||
.set_feature_map_count(
|
||||
input_shape->dimensions(dnums.input_feature_dimension()))
|
||||
.set_count(input_shape->dimensions(dnums.input_batch_dimension()));
|
||||
input_shape.dimensions(dnums.input_feature_dimension()))
|
||||
.set_count(input_shape.dimensions(dnums.input_batch_dimension()));
|
||||
for (int dim = 0; dim < num_dimensions; ++dim) {
|
||||
// Note that the dimensions are reversed. The same holds below.
|
||||
input_descriptor.set_spatial_dim(
|
||||
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
|
||||
input_shape->dimensions(dnums.input_spatial_dimensions(dim)));
|
||||
input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
|
||||
}
|
||||
|
||||
FilterDescriptor& filter_descriptor = config.filter_descriptor;
|
||||
filter_descriptor = FilterDescriptor(effective_num_dimensions);
|
||||
filter_descriptor.set_layout(filter_dl)
|
||||
.set_input_feature_map_count(
|
||||
filter_shape->dimensions(dnums.kernel_input_feature_dimension()))
|
||||
filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
|
||||
.set_output_feature_map_count(
|
||||
filter_shape->dimensions(dnums.kernel_output_feature_dimension()));
|
||||
filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
|
||||
for (int dim = 0; dim < num_dimensions; ++dim) {
|
||||
filter_descriptor.set_spatial_dim(
|
||||
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
|
||||
filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim)));
|
||||
filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
|
||||
}
|
||||
|
||||
config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
|
||||
config.conv_desc.set_group_count(cudnn_call->feature_group_count());
|
||||
config.conv_desc.set_group_count(desc.feature_group_count);
|
||||
config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
|
||||
for (int dim = 0; dim < num_dimensions; ++dim) {
|
||||
config.conv_desc
|
||||
@ -419,12 +417,12 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
|
||||
output_descriptor = BatchDescriptor(effective_num_dimensions);
|
||||
output_descriptor.set_layout(output_dl)
|
||||
.set_feature_map_count(
|
||||
output_shape->dimensions(dnums.output_feature_dimension()))
|
||||
.set_count(output_shape->dimensions(dnums.output_batch_dimension()));
|
||||
output_shape.dimensions(dnums.output_feature_dimension()))
|
||||
.set_count(output_shape.dimensions(dnums.output_batch_dimension()));
|
||||
for (int dim = 0; dim < num_dimensions; ++dim) {
|
||||
output_descriptor.set_spatial_dim(
|
||||
static_cast<DimIndex>(effective_num_dimensions - dim - 1),
|
||||
output_shape->dimensions(dnums.output_spatial_dimensions(dim)));
|
||||
output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
|
||||
}
|
||||
|
||||
// Add a singleton dimension in the 1D convolution case.
|
||||
@@ -439,6 +437,23 @@ StatusOr<GpuConvConfig> GetGpuConvConfig(
  return config;
}

StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call) {
  GpuConvDescriptor descriptor;

  TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
  TF_ASSIGN_OR_RETURN(descriptor.backend_config,
                      cudnn_call->backend_config<CudnnConvBackendConfig>());
  descriptor.operand0_shape = cudnn_call->operand(0)->shape();
  descriptor.operand1_shape = cudnn_call->operand(1)->shape();
  descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
  descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
  descriptor.window = cudnn_call->window();
  descriptor.dnums = cudnn_call->convolution_dimension_numbers();
  descriptor.feature_group_count = cudnn_call->feature_group_count();
  return GetGpuConvConfig(descriptor, cudnn_call->ToString());
}

StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& config,
    absl::Span<se::DeviceMemoryBase> operand_buffers,
@@ -119,9 +119,31 @@ Status RunGpuConv(const GpuConvConfig& conv_config,
                  se::ScratchAllocator* scratch_allocator, se::Stream* stream,
                  RunConvOptions = {});

// Struct to describe properties of a convolution without being tied to specific
// IR. Will be used to help build Convolution thunks from either XLA HLO or
// LHLO GPU dialect in MLIR.
struct GpuConvDescriptor {
  CudnnConvKind kind;
  CudnnConvBackendConfig backend_config;
  Shape operand0_shape;
  Shape operand1_shape;
  Shape result_shape;
  size_t scratch_size;
  Window window;
  ConvolutionDimensionNumbers dnums;
  int64 feature_group_count;
};

// Returns the convolution configuration given a XLA HLO instruction.
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call);

// Returns the convolution configuration given a convolution descriptor `desc`
// and a string representation of the convolution instruction `inst_as_string`
// (for error reporting).
StatusOr<GpuConvConfig> GetGpuConvConfig(const GpuConvDescriptor& desc,
                                         absl::string_view inst_as_string);

// Implementation details exposed for debugging and log analysis.
StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& conv_config,
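Note (illustrative sketch, not part of this commit): the descriptor decouples
the convolution's properties from any particular IR, so both GetGpuConvConfig
overloads above converge on the same code path. Assuming `dump` is a string
holding the instruction's textual form, used only for error messages:

  // MLIR / LHLO-GPU path: fill the descriptor by hand, then build the config.
  GpuConvDescriptor desc;
  desc.kind = CudnnConvKind::kForward;  // chosen from the op's type
  // ... operand/result shapes, window, dnums, backend_config filled in ...
  StatusOr<GpuConvConfig> config = GetGpuConvConfig(desc, /*inst_as_string=*/dump);

  // XLA HLO path: the HloCustomCallInstruction* overload builds the same
  // descriptor internally and forwards to the overload above.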
@ -50,6 +50,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
@ -62,6 +63,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
|
||||
@ -1046,8 +1048,12 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
return EmitGemmThunkFromMlir(input);
|
||||
}
|
||||
|
||||
if (IsCustomCallToDnnConvolution(*custom_call)) {
|
||||
return ThunkEmitter(this).HandleCustomCall(custom_call);
|
||||
if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
|
||||
mlir::lmhlo_gpu::ConvForwardFusedOp,
|
||||
mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
|
||||
mlir::lmhlo_gpu::ConvBackwardFilterOp,
|
||||
mlir::lmhlo_gpu::ConvBackwardInputOp>(input.op)) {
|
||||
return EmitConvolutionThunkFromMlir(input);
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
@ -1060,6 +1066,118 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
custom_call->custom_call_target());
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
|
||||
using mlir::dyn_cast;
|
||||
using mlir::lmhlo_gpu::Activation;
|
||||
using mlir::lmhlo_gpu::ConvBackwardFilterOp;
|
||||
using mlir::lmhlo_gpu::ConvBackwardInputOp;
|
||||
using mlir::lmhlo_gpu::ConvForwardFusedOp;
|
||||
using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
|
||||
using mlir::lmhlo_gpu::ConvForwardOp;
|
||||
|
||||
// Last 2 operands of the convolution operation are the result and scratch.
|
||||
std::vector<BufferAllocation::Slice> operand_slices;
|
||||
int64 num_operands = input.op->getNumOperands();
|
||||
operand_slices.reserve(num_operands - 2);
|
||||
for (mlir::Value operand : input.op->getOperands().drop_back(2)) {
|
||||
TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
|
||||
operand_slices.push_back(slice);
|
||||
}
|
||||
|
||||
mlir::Value conv_result = input.op->getOperand(num_operands - 2);
|
||||
mlir::Value scratch_result = input.op->getOperand(num_operands - 1);
|
||||
TF_ASSIGN_OR_RETURN(auto conv_result_slice,
|
||||
GetAllocationSliceForMlir(conv_result));
|
||||
TF_ASSIGN_OR_RETURN(auto scratch_slice,
|
||||
GetAllocationSliceForMlir(scratch_result));
|
||||
|
||||
auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) {
|
||||
mlir::SmallVector<int64, 4> minor_to_major = llvm::to_vector<4>(
|
||||
llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64 {
|
||||
return static_cast<int64>(a.cast<mlir::IntegerAttr>().getInt());
|
||||
}));
|
||||
return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
|
||||
shape.dimensions(), minor_to_major);
|
||||
};
|
||||
|
||||
GpuConvDescriptor descriptor;
|
||||
|
||||
auto fill_conv_descriptor = [&](auto op) {
|
||||
descriptor.operand0_shape =
|
||||
apply_layout(TypeToShape(input.op->getOperand(0).getType()),
|
||||
op.backend_config().operand_0_layout());
|
||||
descriptor.operand1_shape =
|
||||
apply_layout(TypeToShape(input.op->getOperand(1).getType()),
|
||||
op.backend_config().operand_1_layout());
|
||||
descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
|
||||
op.backend_config().result_layout());
|
||||
descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
|
||||
mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
|
||||
mlir::DenseIntElementsAttr padding = op.padding().getValue();
|
||||
mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
|
||||
mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
|
||||
mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue();
|
||||
for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
|
||||
WindowDimension* dim = descriptor.window.add_dimensions();
|
||||
// Window size for a convolution is the same as the kernel size.
|
||||
// Kernel size of the convolution is operand1_shape. We need to look at
|
||||
// the convolution dimension numbers kernel spatial dimensions to get
|
||||
// the window size.
|
||||
int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
|
||||
dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
|
||||
dim->set_stride(window_strides.getValue<int64>(index));
|
||||
dim->set_padding_low(padding.getValue<int64>(index));
|
||||
dim->set_padding_high(padding.getValue<int64>(index));
|
||||
dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
|
||||
dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
|
||||
dim->set_window_reversal(window_reversal.getValue<bool>(index));
|
||||
}
|
||||
descriptor.feature_group_count = op.feature_group_count();
|
||||
descriptor.backend_config.set_algorithm(
|
||||
op.backend_config().algorithm().getInt());
|
||||
descriptor.backend_config.set_tensor_ops_enabled(
|
||||
op.backend_config().tensor_ops_enabled().getValue());
|
||||
descriptor.backend_config.set_conv_result_scale(
|
||||
op.result_scale().convertToDouble());
|
||||
};
|
||||
|
||||
auto set_activation_mode = [&](auto op) -> Status {
|
||||
TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
|
||||
ConvertConvActivationMode(op.activation_mode()));
|
||||
descriptor.backend_config.set_activation_mode(
|
||||
static_cast<int64>(activation_mode));
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
if (auto op = dyn_cast<ConvForwardOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kForward;
|
||||
fill_conv_descriptor(op);
|
||||
} else if (auto op = dyn_cast<ConvBackwardInputOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kBackwardInput;
|
||||
fill_conv_descriptor(op);
|
||||
} else if (auto op = dyn_cast<ConvBackwardFilterOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kBackwardFilter;
|
||||
fill_conv_descriptor(op);
|
||||
} else if (auto op = dyn_cast<ConvForwardFusedOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kForwardActivation;
|
||||
fill_conv_descriptor(op);
|
||||
TF_RETURN_IF_ERROR(set_activation_mode(op));
|
||||
} else if (auto op = dyn_cast<ConvForwardFusedSideInputOp>(input.op)) {
|
||||
descriptor.kind = CudnnConvKind::kForwardActivation;
|
||||
fill_conv_descriptor(op);
|
||||
TF_RETURN_IF_ERROR(set_activation_mode(op));
|
||||
descriptor.backend_config.set_side_input_scale(
|
||||
op.side_input_scale().convertToDouble());
|
||||
} else {
|
||||
return InternalError("Unexpected operation");
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
|
||||
AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
|
||||
input.thunk_info, std::move(config), std::move(operand_slices),
|
||||
conv_result_slice, scratch_slice));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
|
||||
auto build_gemm_config = [](auto op) {
|
||||
GpuGemmConfig config;
|
||||
|
@@ -168,6 +168,7 @@ class IrEmitterUnnested : public IrEmitter,
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);
  Status EmitGemmThunkFromMlir(MlirEmitterInput input);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
  Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
@@ -183,6 +183,12 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
    return GetProperty(key, per_second_rates_);
  }

  // Return the key that is used to index into Properties for the specified
  // input/output at the shape index.
  static std::string GetOperandBytesAccessedKey(int64 operand_num,
                                                ShapeIndex index = {});
  static std::string GetOutputBytesAccessedKey(ShapeIndex index = {});

 protected:
  typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties;

@@ -229,12 +235,6 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
  void SetOutputBytesAccessed(float value);
  void SetOutputBytesAccessed(ShapeIndex index, float value);

  // Return the key that is used to index into Properties for the specified
  // input/output at the shape index.
  static std::string GetOperandBytesAccessedKey(int64 operand_num,
                                                ShapeIndex index = {});
  static std::string GetOutputBytesAccessedKey(ShapeIndex index = {});

  // Function which computes the size of the top-level of a given shape (not
  // including nested elements, if any). If null then bytes_accessed methods
  // return an error.
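Note (illustrative sketch, not part of this commit): with the two key helpers
moved into the public section, code outside HloCostAnalysis can construct the
same strings that index the per-instruction Properties map; the argument values
below are arbitrary examples.

  // Key for operand 0 as a whole, and for the output element at tuple index {1}.
  const std::string operand_key =
      HloCostAnalysis::GetOperandBytesAccessedKey(/*operand_num=*/0);
  const std::string output_key =
      HloCostAnalysis::GetOutputBytesAccessedKey(/*index=*/{1});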
@ -417,6 +417,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
|
||||
if (try_reshard.has_value()) {
|
||||
return try_reshard.value();
|
||||
}
|
||||
try_reshard = ReshardPartialReplicateWithAllToAll(target);
|
||||
if (try_reshard.has_value()) {
|
||||
return try_reshard.value();
|
||||
}
|
||||
}
|
||||
|
||||
if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
|
||||
@ -424,6 +428,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
|
||||
if (try_reshard.has_value()) {
|
||||
return try_reshard.value();
|
||||
}
|
||||
try_reshard = ReshardPartialReplicateWithAllToAll(target);
|
||||
if (try_reshard.has_value()) {
|
||||
return try_reshard.value();
|
||||
}
|
||||
}
|
||||
|
||||
// If not replicated yet, first replicate and then reshard to use one of the
|
||||
@ -1216,6 +1224,92 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
|
||||
.ReshardWithAllToAll(target, remaining_source_target_dims);
|
||||
}
|
||||
|
||||
absl::optional<PartitionedHlo>
|
||||
PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
|
||||
bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
|
||||
const auto& partial_replicate_sharding =
|
||||
source_is_partial_replicate ? sharding() : target;
|
||||
// If neither the source nor the target is partial replicate, return null.
|
||||
if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
|
||||
// If both source and target are partial replicate, should be supported in
|
||||
// Reshard with AllToAll already.
|
||||
if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
|
||||
// to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
|
||||
// the last tile dim will be replicate first before all-to-all.
|
||||
// Or resharding from
|
||||
// sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
|
||||
// to sharding={devices=[2,3]0,1,2,3,4,5}, where
|
||||
// the last tile dim will be sharded after all-to-all.
|
||||
const int num_replicas =
|
||||
partial_replicate_sharding.tile_assignment().dimensions().back();
|
||||
if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
|
||||
partial_replicate_sharding.tile_assignment().num_dimensions()) ||
|
||||
(partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
int to_replicate_dim = -1;
|
||||
for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
|
||||
--i) {
|
||||
if (tile_sharding.tile_assignment().dim(i) > 1 &&
|
||||
(to_replicate_dim == -1)) {
|
||||
if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
to_replicate_dim = i;
|
||||
}
|
||||
|
||||
if (tile_sharding.tile_assignment().dim(i) !=
|
||||
partial_replicate_sharding.tile_assignment().dim(i + 1)) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
if (to_replicate_dim == -1) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// Check if core assignments for source and the target are the same.
|
||||
auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
|
||||
reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
|
||||
if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
auto tmp_tile_assignment = tile_sharding.tile_assignment();
|
||||
auto tmp_tile_assignment_dimensions =
|
||||
tile_sharding.tile_assignment().dimensions();
|
||||
tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
|
||||
tmp_tile_assignment_dimensions.push_back(num_replicas);
|
||||
tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
|
||||
auto tmp_partial_replicate_sharding =
|
||||
HloSharding::PartialTile(tmp_tile_assignment);
|
||||
|
||||
if (source_is_partial_replicate) {
|
||||
if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
|
||||
sharding(), tmp_partial_replicate_sharding)) {
|
||||
auto partitioned_hlo =
|
||||
ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
|
||||
return partitioned_hlo.Reshard(target);
|
||||
}
|
||||
} else {
|
||||
auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
|
||||
|
||||
if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
|
||||
partitioned_hlo.sharding(), target)) {
|
||||
return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
|
||||
}
|
||||
}
|
||||
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
|
||||
const HloSharding& target) const {
|
||||
CHECK(CanReshardWithCollectivePermute(sharding(), target))
|
||||
|
@@ -338,6 +338,10 @@ class PartitionedHlo {
  absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using AllToAll.
  absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
      const HloSharding& target);

  // SPMD instruction.
  HloInstruction* hlo_;
@ -5542,6 +5542,55 @@ ENTRY entry {
|
||||
EXPECT_THAT(root, partially_replicated);
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
%param0 = f32[8,8] parameter(0),
|
||||
sharding={devices=[2,3]0,1,2,3,4,5}
|
||||
ROOT %copy0 = f32[8,8] copy(%param0),
|
||||
sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/6));
|
||||
VLOG(1) << module->ToString();
|
||||
auto tiled = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
|
||||
auto partially_replicated = AllOf(
|
||||
op::Shape("f32[8,4]"),
|
||||
op::Copy(op::Reshape(
|
||||
op::Transpose(op::AllToAll(op::Reshape(op::Slice(op::AllReduce(
|
||||
op::DynamicUpdateSlice(op::Broadcast(), tiled, _, _)))))))));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, partially_replicated);
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
%param0 = f32[8,8] parameter(0),
|
||||
sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
|
||||
ROOT %copy0 = f32[8,8] copy(%param0),
|
||||
sharding={devices=[2,3]0,1,2,3,4,5}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
PartitionComputation(hlo_string, /*num_devices=*/6));
|
||||
VLOG(1) << module->ToString();
|
||||
auto partial_replicated = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
|
||||
auto tiled = AllOf(
|
||||
op::Shape("f32[4,3]"),
|
||||
op::Copy(op::DynamicSlice(op::Pad(op::Reshape(op::Transpose(op::AllToAll(
|
||||
op::Reshape(partial_replicated)))),
|
||||
_),
|
||||
_, _)));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, tiled);
|
||||
}
|
||||
|
||||
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
|
||||
const char* const hlo_string = R"(
|
||||
HloModule module
|
||||
|
@@ -0,0 +1,5 @@
op {
  graph_op_name: "CollectiveBcastRecvV2"
  summary: "Receives a tensor value broadcast from another device."
  visibility: HIDDEN
}
@@ -0,0 +1,5 @@
op {
  graph_op_name: "CollectiveBcastSendV2"
  summary: "Broadcasts a tensor value to one or more other devices."
  visibility: HIDDEN
}
@ -114,6 +114,7 @@ filegroup(
|
||||
"gpu_managed_allocator.h",
|
||||
"gpu_process_state.h",
|
||||
"gpu_util.h",
|
||||
"gpu_virtual_mem_allocator.h",
|
||||
"//tensorflow/core/common_runtime:gpu_runtime_headers",
|
||||
"//tensorflow/core/common_runtime/device:device_runtime_headers",
|
||||
],
|
||||
@ -137,6 +138,7 @@ tf_cuda_library(
|
||||
cuda_deps = [
|
||||
"@local_config_cuda//cuda:cudnn_header",
|
||||
"//tensorflow/stream_executor/cuda:cuda_platform",
|
||||
":gpu_virtual_mem_allocator",
|
||||
],
|
||||
deps = [
|
||||
":gpu_bfc_allocator",
|
||||
@ -187,6 +189,7 @@ tf_cuda_library(
|
||||
features = ["parse_headers"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gpu_virtual_mem_allocator",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -195,6 +198,31 @@ tf_cuda_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "gpu_virtual_mem_allocator",
|
||||
srcs = [
|
||||
"gpu_virtual_mem_allocator.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gpu_virtual_mem_allocator.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
features = ["parse_headers"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gpu_id",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/framework:allocator",
|
||||
"//tensorflow/core/platform:stream_executor",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
"//tensorflow/stream_executor:stream_executor_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "gpu_init",
|
||||
hdrs = [
|
||||
@ -403,3 +431,21 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/stream_executor:platform",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gpu_virtual_mem_allocator_test",
|
||||
size = "small",
|
||||
srcs = ["gpu_virtual_mem_allocator_test.cc"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":gpu_init",
|
||||
":gpu_virtual_mem_allocator",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/framework:allocator",
|
||||
"//tensorflow/core/platform:stream_executor",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
)
|
||||
|
tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.cc (new file, 186 lines)
@ -0,0 +1,186 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h"
|
||||
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
|
||||
#if CUDA_VERSION >= 10020
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
using ::stream_executor::gpu::GpuContext;
|
||||
using ::stream_executor::gpu::GpuDevicePtr;
|
||||
using ::stream_executor::gpu::GpuDriver;
|
||||
|
||||
// Rounds value up to the specified power of two alignment.
|
||||
size_t AlignUp(size_t value, size_t alignment) {
|
||||
DCHECK_EQ(alignment & (alignment - 1), 0)
|
||||
<< "Alignment must be a power of two; alignment=" << alignment;
|
||||
return (value + alignment - 1) & ~(alignment - 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ stream_executor::port::StatusOr<
|
||||
std::unique_ptr<GpuVirtualMemAllocator>>
|
||||
GpuVirtualMemAllocator::Create(const std::vector<Visitor>& alloc_visitors,
|
||||
const std::vector<Visitor>& free_visitors,
|
||||
GpuContext& gpu_context, PlatformGpuId gpu_id,
|
||||
size_t virtual_address_space_size,
|
||||
const std::vector<PlatformGpuId>& peer_gpu_ids) {
|
||||
std::vector<int> access_gpu_ordinals;
|
||||
access_gpu_ordinals.reserve(peer_gpu_ids.size() + 1);
|
||||
access_gpu_ordinals.push_back(gpu_id.value());
|
||||
for (const auto& peer_id : peer_gpu_ids) {
|
||||
access_gpu_ordinals.push_back(peer_id.value());
|
||||
}
|
||||
|
||||
// Find the min granularity for all devices that have access to this memory;
|
||||
// that is, the maximum min granularity among all devices.
|
||||
size_t max_granularity = 1;
|
||||
for (const int device_ordinal : access_gpu_ordinals) {
|
||||
TF_ASSIGN_OR_RETURN(size_t granularity,
|
||||
GpuDriver::GetMinAllocationGranularity(device_ordinal));
|
||||
max_granularity = std::max(max_granularity, granularity);
|
||||
}
|
||||
|
||||
// Create the virtual memory reservation. Must be aligned to system page size,
|
||||
// and larger than the CUDA min granularity. Empirically, the granularity
|
||||
// check is sufficient as the granularity is some multiple of the page size.
|
||||
// TODO(imintz): Create OS agnostic page size utility for completeness.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
GpuDriver::VmemSpan vmem,
|
||||
GpuDriver::ReserveVirtualMemory(
|
||||
&gpu_context, AlignUp(virtual_address_space_size, max_granularity)));
|
||||
VLOG(1) << "Reserved GPU virtual memory at " << vmem.base << " of size "
|
||||
<< strings::HumanReadableNumBytes(vmem.size_bytes) << " bytes";
|
||||
|
||||
return std::unique_ptr<GpuVirtualMemAllocator>(new GpuVirtualMemAllocator(
|
||||
alloc_visitors, free_visitors, gpu_context, gpu_id,
|
||||
std::move(access_gpu_ordinals), vmem, max_granularity));
|
||||
}
|
||||
|
||||
GpuVirtualMemAllocator::GpuVirtualMemAllocator(
|
||||
const std::vector<Visitor>& alloc_visitors,
|
||||
const std::vector<Visitor>& free_visitors, GpuContext& gpu_context,
|
||||
PlatformGpuId gpu_id, const std::vector<int> access_gpu_ordinals,
|
||||
GpuDriver::VmemSpan vmem, size_t granularity)
|
||||
: SubAllocator(alloc_visitors, free_visitors),
|
||||
gpu_context_(gpu_context),
|
||||
gpu_id_(gpu_id),
|
||||
access_gpu_ordinals_(access_gpu_ordinals),
|
||||
vmem_(vmem),
|
||||
granularity_(granularity) {}
|
||||
|
||||
GpuVirtualMemAllocator::~GpuVirtualMemAllocator() {
|
||||
for (const auto mapping : mappings_) {
|
||||
GpuDriver::UnmapMemory(&gpu_context_, mapping.va, mapping.physical.bytes);
|
||||
GpuDriver::ReleaseMemoryHandle(&gpu_context_, std::move(mapping.physical));
|
||||
}
|
||||
GpuDriver::FreeVirtualMemory(&gpu_context_, vmem_);
|
||||
}
|
||||
|
||||
void* GpuVirtualMemAllocator::Alloc(size_t alignment, size_t num_bytes,
|
||||
size_t* bytes_received) {
|
||||
if (num_bytes == 0) return nullptr;
|
||||
size_t padded_bytes = (num_bytes + granularity_ - 1) & ~(granularity_ - 1);
|
||||
|
||||
GpuDevicePtr next_va = vmem_.base + next_alloc_offset_;
|
||||
|
||||
// TODO(imintz): Attempt to extend the vmem allocation by reserving additional
|
||||
// virtual memory at the specific address at the end of the initial vmem
|
||||
// reservation.
|
||||
if (next_va + padded_bytes > vmem_.base + vmem_.size_bytes) {
|
||||
LOG(ERROR) << "OOM in GPU virtual memory allocator when attempting to "
|
||||
"allocate {request: "
|
||||
<< strings::HumanReadableNumBytes(num_bytes)
|
||||
<< ", aligned: " << padded_bytes << "} bytes.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Create physical memory backing allocation.
|
||||
auto maybe_handle =
|
||||
GpuDriver::CreateMemoryHandle(&gpu_context_, padded_bytes);
|
||||
if (!maybe_handle.ok()) {
|
||||
LOG(ERROR) << maybe_handle.status();
|
||||
return nullptr;
|
||||
}
|
||||
GpuDriver::GenericMemoryHandle handle = std::move(maybe_handle).ValueOrDie();
|
||||
|
||||
// Map VAs for this physical memory.
|
||||
auto status = GpuDriver::MapMemory(&gpu_context_, next_va, handle,
|
||||
access_gpu_ordinals_);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
GpuDriver::ReleaseMemoryHandle(&gpu_context_, std::move(handle));
|
||||
return nullptr;
|
||||
}
|
||||
next_alloc_offset_ += handle.bytes;
|
||||
mappings_.push_back({next_va, std::move(handle)});
|
||||
VisitAlloc(reinterpret_cast<void*>(next_va), gpu_id_.value(), padded_bytes);
|
||||
*bytes_received = padded_bytes;
|
||||
return reinterpret_cast<void*>(next_va);
|
||||
}
|
||||
|
||||
void GpuVirtualMemAllocator::Free(void* ptr, size_t num_bytes) {
|
||||
if (ptr == nullptr) return;
|
||||
|
||||
auto mapping_it =
|
||||
std::lower_bound(mappings_.begin(), mappings_.end(), ptr,
|
||||
[](const Mapping& mapping, const void* ptr) {
|
||||
return reinterpret_cast<const void*>(mapping.va) < ptr;
|
||||
});
|
||||
if (mapping_it == mappings_.end() ||
|
||||
(reinterpret_cast<void*>(mapping_it->va) != ptr)) {
|
||||
LOG(ERROR) << "Could not find GPU vmem mapping for address at "
|
||||
<< reinterpret_cast<uintptr_t>(ptr);
|
||||
return;
|
||||
}
|
||||
|
||||
int num_mappings_to_free = 0;
|
||||
int total_bytes = 0;
|
||||
for (auto it = mapping_it; it != mappings_.end() && total_bytes < num_bytes;
|
||||
++it) {
|
||||
++num_mappings_to_free;
|
||||
total_bytes += it->physical.bytes;
|
||||
}
|
||||
if (total_bytes != num_bytes) {
|
||||
LOG(ERROR) << "Invalid size requested for freeing GPU vmem mapping. Got "
|
||||
<< strings::HumanReadableNumBytes(num_bytes) << " but expected "
|
||||
<< strings::HumanReadableNumBytes(mapping_it->physical.bytes);
|
||||
return;
|
||||
}
|
||||
|
||||
VLOG(1) << "Freeing " << num_mappings_to_free << " mappings for a total of "
|
||||
<< total_bytes << " bytes";
|
||||
for (auto it = mapping_it; it < mapping_it + num_mappings_to_free; ++it) {
|
||||
GpuDriver::UnmapMemory(&gpu_context_, it->va, it->physical.bytes);
|
||||
GpuDriver::ReleaseMemoryHandle(&gpu_context_, std::move(it->physical));
|
||||
}
|
||||
|
||||
// Move back the next_alloc_offset_ if this free was at the end.
|
||||
if (mapping_it + num_mappings_to_free == mappings_.end()) {
|
||||
next_alloc_offset_ = mapping_it->va - vmem_.base;
|
||||
}
|
||||
|
||||
mappings_.erase(mapping_it, mapping_it + num_mappings_to_free);
|
||||
VisitFree(ptr, gpu_id_.value(), num_bytes);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif
|
tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h (new file, 113 lines)
@ -0,0 +1,113 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// CUDA virtual memory API is only available in CUDA versions greater than 10.2.
|
||||
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VMEM_ALLOCATOR_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VMEM_ALLOCATOR_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_types.h"
|
||||
#endif
|
||||
|
||||
#if CUDA_VERSION >= 10020
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// GpuVirtualMemAllocator is a SubAllocator for use with BFCAllocator which
|
||||
// provides contiguous allocations with each call to Alloc. This is done by
|
||||
// reserving a large chunk of virtual addresses at construction and then mapping
|
||||
// physical memory pages to this virtual address range as requested.
|
||||
//
|
||||
// This class is not thread-safe.
|
||||
class GpuVirtualMemAllocator : public SubAllocator {
|
||||
public:
|
||||
static stream_executor::port::StatusOr<
|
||||
std::unique_ptr<GpuVirtualMemAllocator>>
|
||||
Create(const std::vector<Visitor>& alloc_visitors,
|
||||
const std::vector<Visitor>& free_visitors,
|
||||
stream_executor::gpu::GpuContext& gpu_context, PlatformGpuId gpu_id,
|
||||
size_t virtual_address_space_size,
|
||||
const std::vector<PlatformGpuId>& peer_gpu_ids);
|
||||
~GpuVirtualMemAllocator() override;
|
||||
|
||||
// Allocates memory at least as large as requested by num_bytes. Will be
|
||||
// aligned to the min allocation granularity (typically 2MiB).
|
||||
// alignment is ignored by this allocator.
|
||||
void* Alloc(size_t alignment, size_t num_bytes,
|
||||
size_t* bytes_received) override;
|
||||
|
||||
// Frees should only happen at the end of the contiguous memory allocations or
|
||||
// else we introduce pointless fragmentation...But, this is supported. If the
|
||||
// allocation happens at the end, then the next_alloc_offset_ is moved back,
|
||||
// otherwise a hole is created.
|
||||
//
|
||||
// Holes are not re-used, all allocations continue to come at the end of the
|
||||
// next_alloc_offset_. To accommodate this, the virtual_address_space_size
|
||||
// should be much larger than the max physical size of the allocator.
|
||||
//
|
||||
// In practice, since the BFC allocator coalesces adjacent AllocationRegions,
|
||||
// this free function should never be invoked.
|
||||
void Free(void* ptr, size_t num_bytes) override;
|
||||
|
||||
private:
|
||||
GpuVirtualMemAllocator(const std::vector<Visitor>& alloc_visitors,
|
||||
const std::vector<Visitor>& free_visitors,
|
||||
::stream_executor::gpu::GpuContext& gpu_context,
|
||||
PlatformGpuId gpu_id,
|
||||
std::vector<int> access_gpu_ordinals,
|
||||
stream_executor::gpu::GpuDriver::VmemSpan vmem,
|
||||
size_t granularity);
|
||||
|
||||
stream_executor::gpu::GpuContext& gpu_context_;
|
||||
PlatformGpuId gpu_id_;
|
||||
|
||||
// Peer access is configured at mmap time so the allocator must be aware of
|
||||
// all gpus that may want to read the memory. This list also includes the
|
||||
// above gpu_id_ to facilitate the invocation of the GpuDriver::MapMemory
|
||||
// function.
|
||||
const std::vector<int> access_gpu_ordinals_;
|
||||
|
||||
// The virtual memory span held by this allocator.
|
||||
stream_executor::gpu::GpuDriver::VmemSpan vmem_;
|
||||
// The next offset from the vmem base address that will be allocated. This
|
||||
// corresponds to the size of physically pinned memory if holes haven't been
|
||||
// created with "free".
|
||||
size_t next_alloc_offset_ = 0;
|
||||
|
||||
// Smallest allocation as determined by CUDA.
|
||||
const size_t granularity_;
|
||||
|
||||
struct Mapping {
|
||||
stream_executor::gpu::GpuDevicePtr va;
|
||||
stream_executor::gpu::GpuDriver::GenericMemoryHandle physical;
|
||||
};
|
||||
// List of mappings, sorted by va.
|
||||
std::vector<Mapping> mappings_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GpuVirtualMemAllocator);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // CUDA_VERSION >= 10200
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_VMEM_ALLOCATOR_H_
|
@ -0,0 +1,185 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h"

#if CUDA_VERSION >= 10020

#include "tensorflow/core/common_runtime/device/device_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

using ::stream_executor::gpu::GpuContext;
using ::stream_executor::gpu::GpuDevicePtr;
using ::stream_executor::gpu::GpuDriver;

// Empirically the min allocation granularity.
constexpr size_t k2MiB{2 << 20};

// Creates an allocator with 8 MiB of virtual address space.
std::unique_ptr<GpuVirtualMemAllocator> CreateAllocator() {
  PlatformGpuId gpu_id(0);
  auto executor =
      DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(), gpu_id)
          .ValueOrDie();
  GpuContext* gpu_context = reinterpret_cast<GpuContext*>(
      executor->implementation()->GpuContextHack());
  return GpuVirtualMemAllocator::Create(
             {}, {}, *gpu_context, gpu_id,
             /*virtual_address_space_size=*/4 * k2MiB, {})
      .ValueOrDie();
}

TEST(GpuVirtualMemAllocatorTest, SimpleAlloc) {
  PlatformGpuId gpu_id(0);
  auto executor =
      DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(), gpu_id)
          .ValueOrDie();
  GpuContext* gpu_context = reinterpret_cast<GpuContext*>(
      executor->implementation()->GpuContextHack());
  auto allocator = GpuVirtualMemAllocator::Create(
                       {}, {}, *gpu_context, gpu_id,
                       /*virtual_address_space_size=*/4 * k2MiB, {})
                       .ValueOrDie();
  size_t bytes_received;  // Ignored in this test.
  void* gpu_block =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(gpu_block, nullptr);

  constexpr size_t kBufSize{256};
  void* host_mem[2] = {GpuDriver::HostAllocate(gpu_context, kBufSize),
                       GpuDriver::HostAllocate(gpu_context, kBufSize)};
  std::memset(host_mem[0], 'z', kBufSize);
  std::memset(host_mem[1], 0, kBufSize);

  GpuDevicePtr gpu_buf = reinterpret_cast<GpuDevicePtr>(gpu_block) + 2048;
  ASSERT_TRUE(GpuDriver::SynchronousMemcpyH2D(gpu_context, gpu_buf, host_mem[0],
                                              kBufSize)
                  .ok());
  ASSERT_TRUE(GpuDriver::SynchronousMemcpyD2H(gpu_context, host_mem[1], gpu_buf,
                                              kBufSize)
                  .ok());
  for (int i = 0; i < kBufSize; ++i) {
    ASSERT_EQ('z', reinterpret_cast<const char*>(host_mem[1])[i]);
  }
}

TEST(GpuVirtualMemAllocatorTest, AllocPaddedUp) {
  auto allocator = CreateAllocator();
  size_t bytes_received;
  void* gpu_block =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/256, &bytes_received);
  ASSERT_NE(gpu_block, nullptr);
  ASSERT_EQ(bytes_received, k2MiB);
}

TEST(GpuVirtualMemAllocatorTest, AllocsContiguous) {
  auto allocator = CreateAllocator();
  size_t bytes_received;  // Ignored in this test.
  void* first_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(first_alloc, nullptr);
  void* second_alloc = allocator->Alloc(
      /*alignment=*/0, /*num_bytes=*/2 * k2MiB, &bytes_received);
  ASSERT_NE(second_alloc, nullptr);

  ASSERT_EQ(second_alloc, reinterpret_cast<const char*>(first_alloc) + k2MiB);

  void* third_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(third_alloc, nullptr);

  ASSERT_EQ(third_alloc,
            reinterpret_cast<const char*>(second_alloc) + 2 * k2MiB);
}

TEST(GpuVirtualMemAllocator, OverAllocate) {
  auto allocator = CreateAllocator();
  size_t bytes_received;  // Ignored in this test.
  void* first_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(first_alloc, nullptr);
  void* over_alloc = allocator->Alloc(/*alignment=*/0, /*num_bytes=*/4 * k2MiB,
                                      &bytes_received);
  ASSERT_EQ(over_alloc, nullptr);
}

TEST(GpuVirtualMemAllocatorTest, FreeAtEnd) {
  auto allocator = CreateAllocator();
  size_t bytes_received;  // Ignored in this test.
  void* first_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(first_alloc, nullptr);
  void* second_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(second_alloc, nullptr);

  allocator->Free(second_alloc, k2MiB);

  void* re_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_EQ(re_alloc, second_alloc);
}

TEST(GpuVirtualMemAllocatorTest, FreeHole) {
  auto allocator = CreateAllocator();
  size_t bytes_received;  // Ignored in this test.
  void* first_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(first_alloc, nullptr);
  void* second_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(second_alloc, nullptr);

  allocator->Free(first_alloc, k2MiB);

  void* third_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(third_alloc, nullptr);

  // Expect that allocation still happens at the end.
  ASSERT_EQ(third_alloc, reinterpret_cast<const char*>(second_alloc) + k2MiB);
}

TEST(GpuVirtualMemAllocatorTest, FreeRange) {
  auto allocator = CreateAllocator();
  size_t bytes_received;  // Ignored in this test.
  void* first_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(first_alloc, nullptr);
  void* second_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(second_alloc, nullptr);
  void* third_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(third_alloc, nullptr);

  allocator->Free(first_alloc, 3 * k2MiB);

  void* re_alloc =
      allocator->Alloc(/*alignment=*/0, /*num_bytes=*/k2MiB, &bytes_received);
  ASSERT_NE(re_alloc, nullptr);
  ASSERT_EQ(re_alloc, first_alloc);
}

}  // namespace
}  // namespace tensorflow

#endif
@ -830,7 +830,8 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
      // to run asynchronously to avoid deadlock.
      "CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
      "CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
      "NcclAllReduce", "Send", "Recv",
      "CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
      "Send", "Recv",

      // Legacy random ops.
      // See details in tensorflow/python/framework/auto_control_deps.py.
@ -20,7 +20,7 @@ namespace tensorflow {
namespace grappler {

const NodeScopeAndName ParseNodeScopeAndName(const string& node_name) {
  auto pos = node_name.find_last_of("/");
  auto pos = node_name.find_last_of('/');
  if (pos == string::npos) {
    return {"", node_name};
  } else {
@ -48,13 +48,13 @@ const char kScopedAllocatorAttrName[] = "_scoped_allocator";
// matches op_name, i.e. it looks from the name like this node is
// of that op type.
bool HasOpName(const string& node_name, const string& op_name) {
  size_t begin = node_name.rfind("/");
  size_t begin = node_name.rfind('/');
  if (begin == string::npos) {
    begin = 0;
  } else {
    ++begin;
  }
  size_t end = node_name.rfind("_");
  size_t end = node_name.rfind('_');
  if (end != string::npos) {
    size_t p = end + 1;
    while (p < node_name.size()) {
@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@ -742,5 +743,261 @@ REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2")
|
||||
.HostMemory("instance_key"),
|
||||
CollectiveGatherV2OpKernel);
|
||||
|
||||
class CollectiveBcastSendV2OpKernel : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
|
||||
const bool is_source = true;
|
||||
name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
|
||||
}
|
||||
|
||||
protected:
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
name_),
|
||||
done);
|
||||
const Tensor& input = c->input(0);
|
||||
const Tensor& group_size = c->input(1);
|
||||
const Tensor& group_key = c->input(2);
|
||||
const Tensor& instance_key = c->input(3);
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, group_size.dims() == 0,
|
||||
errors::Internal("Unexpected dimensions on input group_size"), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, group_key.dims() == 0,
|
||||
errors::Internal("Unexpected dimensions on input group_key"), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, instance_key.dims() == 0,
|
||||
errors::Internal("Unexpected dimensions on input instance_key"), done);
|
||||
|
||||
auto col_params = new CollectiveParams();
|
||||
col_params->name = name_;
|
||||
col_params->group.device_type = device_type_;
|
||||
col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
|
||||
col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
|
||||
col_params->instance.type = BROADCAST_COLLECTIVE;
|
||||
col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
|
||||
col_params->instance.data_type = data_type_;
|
||||
col_params->instance.impl_details.communication_hint = communication_hint_;
|
||||
col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
|
||||
col_params->is_source = true;
|
||||
// Add a default value for subdiv offsets, which is the same as the default
|
||||
// value in the V1 op's attribute.
|
||||
col_params->instance.impl_details.subdiv_offsets.push_back(0);
|
||||
VLOG(1) << "CollectiveBcastSendV2 group_size "
|
||||
<< col_params->group.group_size << " group_key "
|
||||
<< col_params->group.group_key << " instance_key "
|
||||
<< col_params->instance.instance_key;
|
||||
|
||||
auto done_with_cleanup = [col_params, done = std::move(done)]() {
|
||||
delete col_params;
|
||||
done();
|
||||
};
|
||||
|
||||
// Allocate the output tensor, trying to reuse the input.
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output),
|
||||
done_with_cleanup);
|
||||
col_params->instance.shape = input.shape();
|
||||
|
||||
// Resolve the collective params.
|
||||
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
|
||||
// blocking work because it's not guaranteed that this call cannot block.
|
||||
c->collective_executor()->RunClosure([c,
|
||||
done = std::move(done_with_cleanup),
|
||||
col_params, col_exec]() {
|
||||
VLOG(1) << "CollectiveBcastSendV2 CompleteParams for collective "
|
||||
<< col_params->name << " device " << c->device()->name()
|
||||
<< " group " << col_params->group.group_key << " instance "
|
||||
<< col_params->instance.instance_key;
|
||||
col_exec->CompleteParamsAsync(
|
||||
c->device()->attributes(), col_params, c->cancellation_manager(),
|
||||
[c, done = std::move(done), col_params, col_exec](const Status& s) {
|
||||
if (s.ok()) {
|
||||
auto actual_done = [c, group_key = col_params->group.group_key,
|
||||
instance_key =
|
||||
col_params->instance.instance_key,
|
||||
done = std::move(done)](const Status& s) {
|
||||
VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync done for "
|
||||
"collective "
|
||||
<< c->op_kernel().name() << " device "
|
||||
<< c->device()->name() << " group " << group_key
|
||||
<< " instance " << instance_key << " status " << s;
|
||||
OP_REQUIRES_OK_ASYNC(c, s, done);
|
||||
done();
|
||||
};
|
||||
VLOG(1) << "CollectiveBcastSendV2 ExecuteAsync start for "
|
||||
"collective "
|
||||
<< col_params->name << " device " << c->device()->name()
|
||||
<< " group " << col_params->group.group_key
|
||||
<< " instance " << col_params->instance.instance_key;
|
||||
col_exec->ExecuteAsync(
|
||||
c, *col_params,
|
||||
CollectiveKey(c, col_params->group.group_key,
|
||||
col_params->instance.instance_key),
|
||||
actual_done);
|
||||
} else {
|
||||
c->SetStatus(s);
|
||||
done();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
DeviceType device_type_;
|
||||
DataType data_type_ = DT_INVALID;
|
||||
string communication_hint_;
|
||||
float timeout_seconds_ = 0;
|
||||
string name_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
|
||||
CollectiveBcastSendV2OpKernel);
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("group_size")
|
||||
.HostMemory("group_key")
|
||||
.HostMemory("instance_key"),
|
||||
CollectiveBcastSendV2OpKernel);
|
||||
|
||||
class CollectiveBcastRecvV2OpKernel : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
|
||||
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
|
||||
OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
|
||||
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
|
||||
const bool is_source = false;
|
||||
name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
|
||||
}
|
||||
|
||||
protected:
|
||||
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
|
||||
CollectiveExecutor* col_exec = c->collective_executor();
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, col_exec,
|
||||
errors::Internal(
|
||||
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
|
||||
name_),
|
||||
done);
|
||||
const Tensor& group_size = c->input(0);
|
||||
const Tensor& group_key = c->input(1);
|
||||
const Tensor& instance_key = c->input(2);
|
||||
const Tensor& shape_tensor = c->input(3);
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, group_size.dims() == 0,
|
||||
errors::Internal("Unexpected dimensions on input group_size"), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, group_key.dims() == 0,
|
||||
errors::Internal("Unexpected dimensions on input group_key"), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
c, instance_key.dims() == 0,
|
||||
errors::Internal("Unexpected dimensions on input instance_key"), done);
|
||||
|
||||
auto col_params = new CollectiveParams();
|
||||
auto done_with_cleanup = [col_params, done = std::move(done)]() {
|
||||
delete col_params;
|
||||
done();
|
||||
};
|
||||
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
c, tensor::MakeShape(shape_tensor, &col_params->instance.shape),
|
||||
done_with_cleanup);
|
||||
col_params->name = name_;
|
||||
col_params->group.device_type = device_type_;
|
||||
col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
|
||||
col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
|
||||
col_params->instance.type = BROADCAST_COLLECTIVE;
|
||||
col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
|
||||
col_params->instance.data_type = data_type_;
|
||||
col_params->instance.impl_details.communication_hint = communication_hint_;
|
||||
col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
|
||||
col_params->is_source = false;
|
||||
// Add a default value for subdiv offsets, which is the same as the default
|
||||
// value in the V1 op's attribute.
|
||||
col_params->instance.impl_details.subdiv_offsets.push_back(0);
|
||||
VLOG(1) << "CollectiveBcastRecvV2 group_size "
|
||||
<< col_params->group.group_size << " group_key "
|
||||
<< col_params->group.group_key << " instance_key "
|
||||
<< col_params->instance.instance_key;
|
||||
|
||||
// Allocate the output tensor.
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK_ASYNC(c,
|
||||
c->forward_input_or_allocate_output(
|
||||
{0}, 0, col_params->instance.shape, &output),
|
||||
done_with_cleanup);
|
||||
|
||||
// Resolve the collective params.
|
||||
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
|
||||
// blocking work because it's not guaranteed that this call cannot block.
|
||||
c->collective_executor()->RunClosure([c,
|
||||
done = std::move(done_with_cleanup),
|
||||
col_params, col_exec]() {
|
||||
VLOG(1) << "CollectiveBcastRecvV2 CompleteParams for collective "
|
||||
<< col_params->name << " device " << c->device()->name()
|
||||
<< " group " << col_params->group.group_key << " instance "
|
||||
<< col_params->instance.instance_key;
|
||||
col_exec->CompleteParamsAsync(
|
||||
c->device()->attributes(), col_params, c->cancellation_manager(),
|
||||
[c, done = std::move(done), col_params, col_exec](const Status& s) {
|
||||
if (s.ok()) {
|
||||
auto actual_done = [c, group_key = col_params->group.group_key,
|
||||
instance_key =
|
||||
col_params->instance.instance_key,
|
||||
done = std::move(done)](const Status& s) {
|
||||
VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync done for "
|
||||
"collective "
|
||||
<< c->op_kernel().name() << " device "
|
||||
<< c->device()->name() << " group " << group_key
|
||||
<< " instance " << instance_key << " status " << s;
|
||||
OP_REQUIRES_OK_ASYNC(c, s, done);
|
||||
done();
|
||||
};
|
||||
VLOG(1) << "CollectiveBcastRecvV2 ExecuteAsync start for "
|
||||
"collective "
|
||||
<< col_params->name << " device " << c->device()->name()
|
||||
<< " group " << col_params->group.group_key
|
||||
<< " instance " << col_params->instance.instance_key;
|
||||
col_exec->ExecuteAsync(
|
||||
c, *col_params,
|
||||
CollectiveKey(c, col_params->group.group_key,
|
||||
col_params->instance.instance_key),
|
||||
actual_done);
|
||||
} else {
|
||||
c->SetStatus(s);
|
||||
done();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
DeviceType device_type_;
|
||||
DataType data_type_ = DT_INVALID;
|
||||
string communication_hint_;
|
||||
float timeout_seconds_ = 0;
|
||||
string name_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
|
||||
CollectiveBcastRecvV2OpKernel);
|
||||
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("group_size")
|
||||
.HostMemory("group_key")
|
||||
.HostMemory("instance_key")
|
||||
.HostMemory("shape"),
|
||||
CollectiveBcastRecvV2OpKernel);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -946,6 +946,17 @@ AsyncOpKernel* IteratorGetNextOp::AsAsync() {
  return type_string() == "IteratorGetNextSync" ? nullptr : this;
}

void RecordElementSize(const std::vector<Tensor> element,
                       profiler::TraceMe* traceme) {
  traceme->AppendMetadata([&]() {
    int64 element_size = 0;
    for (const auto& component : element) {
      element_size += component.TotalBytes();
    }
    return profiler::TraceMeEncode({{"element_size", element_size}});
  });
}

Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
  profiler::TraceMe traceme(
      [&] {
@ -968,6 +979,7 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
  }
  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
  RecordElementSize(components, &traceme);
  for (int i = 0; i < components.size(); ++i) {
    ctx->set_output(i, components[i]);
  }
@ -995,6 +1007,7 @@ Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
  if (end_of_sequence) {
    return WriteOptionalNoneToOutput(ctx, 0);
  } else {
    RecordElementSize(components, &traceme);
    for (int i = 0; i < components.size(); ++i) {
      if (components[i].dtype() != output_types_[i]) {
        return errors::InvalidArgument(
@ -299,14 +299,22 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {

  data::TraceMeMetadata GetTraceMeMetadata() const override {
    int64 limit = -1, size = -1;
    data::TraceMeMetadata result;
    // NOTE: We only set the parallelism value if the lock can be acquired
    // right away to avoid introducing tracing overhead.
    if (mu_->try_lock()) {
      limit = buffer_limit();
      size = buffer_.size();
      if (!buffer_.empty()) {
        std::vector<std::string> shapes(buffer_.front().value.size());
        for (const auto& component : buffer_.front().value) {
          shapes.push_back(component.shape().DebugString());
        }
        result.push_back(std::make_pair("next_element_shapes",
                                        absl::StrJoin(shapes, ",")));
      }
      mu_->unlock();
    }
    data::TraceMeMetadata result;
    result.push_back(std::make_pair(
        "buffer_limit",
        strings::Printf("%lld", static_cast<long long>(limit))));
@ -197,6 +197,7 @@ tf_cuda_cc_test(
        "no_cuda_asan",  # TODO(b/171341759): re-enable.
    ],
    deps = [
        ":gpu_ops_test_util",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:tensorflow",
@ -48,10 +48,12 @@ absl::InlinedVector<T, 10> RepeatInputToMatchShape(
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Helper functions to get default input values.
|
||||
/// Helper functions to get default input shapes.
|
||||
|
||||
TensorShape DefaultInputShape();
|
||||
|
||||
/// Helper functions to get default input data.
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
@ -72,17 +74,10 @@ T DefaultScalarInput() {
|
||||
return static_cast<T>(true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::InlinedVector<T, 10> InfZeroInput() {
|
||||
return InputAsVector<T, double>({-std::numeric_limits<double>::infinity(),
|
||||
-0.1, -0.0, 0.0, 0.1,
|
||||
std::numeric_limits<float>::infinity()});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
|
||||
// Only generate values less than the bitwidth of the data type.
|
||||
if (op_name == "LeftShift" || op_name == "RightShift") {
|
||||
auto max_shift = sizeof(T) * 8 - 1;
|
||||
@ -96,16 +91,65 @@ absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
|
||||
return InputAsVector<T, double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1,
|
||||
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, bool>::value, bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name) {
|
||||
absl::InlinedVector<T, 10> DefaultInput(absl::string_view op_name = "") {
|
||||
return InputAsVector<T, bool>({true, false, true, true, false});
|
||||
}
|
||||
|
||||
/// Helper functions to get more specific input data.
|
||||
|
||||
template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<std::complex<T>, 10> DefaultComplexInput() {
|
||||
auto input = test::DefaultInput<T>();
|
||||
absl::InlinedVector<std::complex<T>, 10> complex_input;
|
||||
for (T value : input) {
|
||||
complex_input.emplace_back(value, -value);
|
||||
}
|
||||
return complex_input;
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> NearZeroAndExtremeInput() {
|
||||
return InputAsVector<T, double>({-std::numeric_limits<double>::infinity(),
|
||||
-0.1, -0.0, 0.0, 0.1,
|
||||
std::numeric_limits<float>::infinity()});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
std::enable_if_t<llvm::is_one_of<T, int8, int16, int32, int64>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> NearZeroAndExtremeInput() {
|
||||
return InputAsVector<T, T>({std::numeric_limits<T>::min(),
|
||||
std::numeric_limits<T>::min() + 1, -1, 0, 1,
|
||||
std::numeric_limits<T>::max()});
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInputGreaterThanZero() {
|
||||
return test::InputAsVector<T, double>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1,
|
||||
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<
|
||||
llvm::is_one_of<T, Eigen::half, float, double>::value,
|
||||
bool> = true>
|
||||
absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
|
||||
return test::InputAsVector<T, double>({18.0, 9.0, 1e-6, 0.0, 0.1, 1e-6, 0.1,
|
||||
0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/gpu_ops_test_util.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -44,20 +45,10 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
|
||||
}
|
||||
|
||||
// 'T' is the input type, 'RT' is the input type for the callback function,
|
||||
// 'OutT' is the output type, 'ROutT' is the output type for the callback
|
||||
// function. In most cases it is enough to just provide the input type,
|
||||
// because all the types are the same.
|
||||
template <typename T, typename RT = T, typename OutT = T, typename ROutT = RT>
|
||||
void Run(std::vector<int64> input_shape, absl::InlinedVector<T, 10> input,
|
||||
const std::string op_name, ROutT (*expected_callback)(RT),
|
||||
bool expect_equal = true, bool add_tout = false,
|
||||
bool expect_buffer_reuse = true, bool add_t = true) {
|
||||
assert(std::accumulate(input_shape.begin(), input_shape.end(), 1,
|
||||
std::multiplies<int64>()) == input.size() &&
|
||||
"Expected input length to equal to shape's number of elements.");
|
||||
|
||||
TensorShape shape(input_shape);
|
||||
template <typename T, typename OutT>
|
||||
void SetOpKernel(const std::string& op_name, const TensorShape& shape,
|
||||
const absl::InlinedVector<T, 10>& input, bool add_t,
|
||||
bool add_tout) {
|
||||
NodeDefBuilder builder("some_name", op_name);
|
||||
builder.Input(FakeInput(DataTypeToEnum<T>::v()));
|
||||
if (add_t) {
|
||||
@ -70,6 +61,15 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<T>(shape, input);
|
||||
}
|
||||
|
||||
template <typename T, typename OutT>
|
||||
void RunAndExpectResult(const std::string& op_name, const TensorShape& shape,
|
||||
const absl::InlinedVector<T, 10>& input,
|
||||
const absl::InlinedVector<OutT, 10>& expected_output,
|
||||
bool add_t, bool add_tout, bool expect_buffer_reuse,
|
||||
bool expect_equal) {
|
||||
SetOpKernel<T, OutT>(op_name, shape, input, add_t, add_tout);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Assert buffer reuse if expected.
|
||||
@ -81,13 +81,7 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
|
||||
// Assert expected results.
|
||||
Tensor expected_tensor(allocator(), DataTypeToEnum<OutT>::value, shape);
|
||||
absl::InlinedVector<OutT, 14> expected;
|
||||
expected.reserve(input.size());
|
||||
for (const T& inp : input) {
|
||||
expected.push_back(
|
||||
static_cast<OutT>(expected_callback(static_cast<RT>(inp))));
|
||||
}
|
||||
test::FillValues<OutT>(&expected_tensor, expected);
|
||||
test::FillValues<OutT>(&expected_tensor, expected_output);
|
||||
if (expect_equal) {
|
||||
test::ExpectEqual(expected_tensor, *GetOutput(0));
|
||||
} else {
|
||||
@ -95,241 +89,225 @@ class GpuUnaryOpTest : public OpsTestBase {
|
||||
}
|
||||
}
|
||||
|
||||
// Some helper functions to get default input values.
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
typename BaselineOutT>
|
||||
void Test(const std::string op_name, const TensorShape& shape,
|
||||
absl::InlinedVector<T, 10> input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT),
|
||||
bool expect_equal = true, bool add_tout = false,
|
||||
bool expect_buffer_reuse = true, bool add_t = true) {
|
||||
// Prepare inputs and compute expected results.
|
||||
auto repeated_input =
|
||||
test::RepeatInputToMatchShape(input, shape.num_elements());
|
||||
absl::InlinedVector<OutT, 10> expected_output =
|
||||
ComputeExpectedOutput<T, BaselineT, OutT, BaselineOutT>(
|
||||
repeated_input, baseline_callback);
|
||||
|
||||
std::vector<int64> DefaultInputShape() { return std::vector<int64>{2, 7}; }
|
||||
|
||||
template <typename T>
|
||||
absl::InlinedVector<T, 10> DefaultInput() {
|
||||
return InputAsVector<T>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3,
|
||||
0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::InlinedVector<std::complex<T>, 10> DefaultComplexInput() {
|
||||
auto input = DefaultInput<T>();
|
||||
absl::InlinedVector<std::complex<T>, 10> complex_input;
|
||||
for (T value : input) {
|
||||
complex_input.emplace_back(value, -value);
|
||||
}
|
||||
return complex_input;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::InlinedVector<T, 10> DefaultInputGreaterThanZero() {
|
||||
return InputAsVector<T>({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
|
||||
0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::InlinedVector<T, 10> DefaultInputGreaterOrEqualToZero() {
|
||||
return InputAsVector<T>({18.0, 9.0, 1e-6, 0.0, 0.1, 1e-6, 0.1, 0.2, 0.3,
|
||||
0.5, 0.7, 0.9, 9.0, 18.0});
|
||||
RunAndExpectResult<T, OutT>(op_name, shape, repeated_input, expected_output,
|
||||
add_t, add_tout, expect_buffer_reuse,
|
||||
expect_equal);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
absl::InlinedVector<T, 10> InputAsVector(
|
||||
std::initializer_list<double> input) {
|
||||
absl::InlinedVector<T, 10> result;
|
||||
result.reserve(input.size());
|
||||
for (const auto& value : input) {
|
||||
result.push_back(static_cast<T>(value));
|
||||
template <typename T, typename BaselineT, typename OutT,
|
||||
typename BaselineOutT>
|
||||
absl::InlinedVector<OutT, 10> ComputeExpectedOutput(
|
||||
absl::InlinedVector<T, 10> input,
|
||||
BaselineOutT (*baseline_callback)(BaselineT)) {
|
||||
absl::InlinedVector<OutT, 10> expected_output;
|
||||
for (int i = 0; i < input.size(); i++) {
|
||||
auto arg = static_cast<BaselineT>(input[i]);
|
||||
auto result = static_cast<OutT>(baseline_callback(arg));
|
||||
expected_output.push_back(result);
|
||||
}
|
||||
return result;
|
||||
return expected_output;
|
||||
}
|
||||
};
|
||||
|
||||
/// Test `tf.Abs`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsFloat) {
|
||||
Run<float>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{-std::numeric_limits<float>::infinity(), -0.1f, -0.0f, 0.0f, 0.1f,
|
||||
std::numeric_limits<float>::infinity()},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
Test<float, float, float, float>(
|
||||
/*op_name=*/"Abs", test::DefaultInputShape(),
|
||||
test::NearZeroAndExtremeInput<float>(),
|
||||
/*baseline_callback=*/std::abs,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsDouble) {
|
||||
Run<double>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{-std::numeric_limits<double>::infinity(), -0.1, -0.0, 0.0, 0.1,
|
||||
std::numeric_limits<double>::infinity()},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
Test<double, double, double, double>(
|
||||
/*op_name=*/"Abs", test::DefaultInputShape(),
|
||||
test::NearZeroAndExtremeInput<double>(),
|
||||
/*baseline_callback=*/std::abs,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsHalf) {
|
||||
Run<Eigen::half, float>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{static_cast<Eigen::half>(-std::numeric_limits<double>::infinity()),
|
||||
static_cast<Eigen::half>(-0.1), static_cast<Eigen::half>(-0.0),
|
||||
static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(0.1),
|
||||
static_cast<Eigen::half>(std::numeric_limits<double>::infinity())},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Abs", test::DefaultInputShape(),
|
||||
test::NearZeroAndExtremeInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::abs,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsInt32) {
|
||||
Run<int32>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{std::numeric_limits<int32>::min(), std::numeric_limits<int32>::min() + 1,
|
||||
-1, 0, 1, std::numeric_limits<int32>::max()},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
Test<int32, int32, int32, int32>(
|
||||
/*op_name=*/"Abs", test::DefaultInputShape(),
|
||||
test::NearZeroAndExtremeInput<int32>(),
|
||||
/*baseline_callback=*/std::abs,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, AbsInt64) {
|
||||
Run<int64>(
|
||||
/*input_shape=*/{2, 3},
|
||||
/*input=*/
|
||||
{std::numeric_limits<int64>::min(), std::numeric_limits<int64>::min() + 1,
|
||||
-1, 0, 1, std::numeric_limits<int64>::max()},
|
||||
/*op_name=*/"Abs",
|
||||
/*expected_callback=*/std::abs,
|
||||
Test<int64, int64, int64, int64>(
|
||||
/*op_name=*/"Abs", test::DefaultInputShape(),
|
||||
test::NearZeroAndExtremeInput<int64>(),
|
||||
/*baseline_callback=*/std::abs,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
/// Test `tf.Ceil`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, CeilFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"Ceil",
|
||||
/*expected_callback=*/std::ceil,
|
||||
/*expect_equal=*/true);
|
||||
Test<float, float, float, float>(
|
||||
/*op_name=*/"Ceil", test::DefaultInputShape(),
|
||||
test::DefaultInput<float>("Ceil"),
|
||||
/*baseline_callback=*/std::ceil,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, CeilDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"Ceil",
|
||||
/*expected_callback=*/std::ceil,
|
||||
/*expect_equal=*/true);
|
||||
Test<double, double, double, double>(
|
||||
/*op_name=*/"Ceil", test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/std::ceil,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, CeilHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"Ceil",
|
||||
/*expected_callback=*/std::ceil,
|
||||
/*expect_equal=*/true);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Ceil", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::ceil,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
/// Test `tf.Conj`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ConjFloat) {
|
||||
Run<std::complex<float>, const std::complex<float>&, std::complex<float>,
|
||||
std::complex<float>>(DefaultInputShape(), DefaultComplexInput<float>(),
|
||||
/*op_name=*/"Conj",
|
||||
/*expected_callback=*/std::conj,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/false,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ConjDouble) {
|
||||
Run<std::complex<double>, const std::complex<double>&, std::complex<double>,
|
||||
std::complex<double>>(DefaultInputShape(), DefaultComplexInput<double>(),
|
||||
/*op_name=*/"Conj",
|
||||
/*expected_callback=*/std::conj,
|
||||
Test<std::complex<float>, const std::complex<float>&, std::complex<float>,
|
||||
std::complex<float>>(/*op_name=*/"Conj", test::DefaultInputShape(),
|
||||
test::DefaultComplexInput<float>(),
|
||||
/*baseline_callback=*/std::conj,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/false,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ConjDouble) {
|
||||
Test<std::complex<double>, const std::complex<double>&, std::complex<double>,
|
||||
std::complex<double>>(
|
||||
/*op_name=*/"Conj", test::DefaultInputShape(),
|
||||
test::DefaultComplexInput<double>(),
|
||||
/*baseline_callback=*/std::conj,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/false,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Cos`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, CosFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"Cos",
|
||||
/*expected_callback=*/std::cos,
|
||||
/*expect_equal=*/false);
|
||||
Test<float, float, float, float>(
|
||||
/*op_name=*/"Cos", test::DefaultInputShape(), test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/std::cos,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, CosDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"Cos",
|
||||
/*expected_callback=*/std::cos,
|
||||
/*expect_equal=*/false);
|
||||
Test<double, double, double, double>(/*op_name=*/"Cos",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/std::cos,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, CosHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"Cos",
|
||||
/*expected_callback=*/std::cos,
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Cos", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::cos,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Exp`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ExpFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"Exp",
|
||||
/*expected_callback=*/std::exp,
|
||||
/*expect_equal=*/false);
|
||||
Test<float, float, float, float>(/*op_name=*/"Exp", test::DefaultInputShape(),
|
||||
test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/std::exp,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ExpDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"Exp",
|
||||
/*expected_callback=*/std::exp,
|
||||
/*expect_equal=*/false);
|
||||
Test<double, double, double, double>(/*op_name=*/"Exp",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/std::exp,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ExpHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"Exp",
|
||||
/*expected_callback=*/std::exp,
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Exp", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::exp,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Floor`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, FloorFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"Floor",
|
||||
/*expected_callback=*/std::floor,
|
||||
/*expect_equal=*/true);
|
||||
Test<float, float, float, float>(/*op_name=*/"Floor",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/std::floor,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, FloorDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"Floor",
|
||||
/*expected_callback=*/std::floor,
|
||||
/*expect_equal=*/true);
|
||||
Test<double, double, double, double>(/*op_name=*/"Floor",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/std::floor,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, FloorHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"Floor",
|
||||
/*expected_callback=*/std::floor,
|
||||
/*expect_equal=*/true);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Floor", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::floor,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
/// Test `tf.Imag`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ImagFloat) {
|
||||
Run<std::complex<float>, const std::complex<float>&, float, float>(
|
||||
DefaultInputShape(), DefaultComplexInput<float>(),
|
||||
/*op_name=*/"Imag",
|
||||
/*expected_callback=*/std::imag,
|
||||
Test<std::complex<float>, const std::complex<float>&, float, float>(
|
||||
/*op_name=*/"Imag", test::DefaultInputShape(),
|
||||
test::DefaultComplexInput<float>(),
|
||||
/*baseline_callback=*/std::imag,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/true,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, ImagDouble) {
|
||||
Run<std::complex<double>, const std::complex<double>&, double, double>(
|
||||
DefaultInputShape(), DefaultComplexInput<double>(),
|
||||
/*op_name=*/"Imag",
|
||||
/*expected_callback=*/std::imag,
|
||||
Test<std::complex<double>, const std::complex<double>&, double, double>(
|
||||
/*op_name=*/"Imag", test::DefaultInputShape(),
|
||||
test::DefaultComplexInput<double>(),
|
||||
/*baseline_callback=*/std::imag,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/true,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
@ -338,64 +316,65 @@ TEST_F(GpuUnaryOpTest, ImagDouble) {
|
||||
/// Test `tf.IsInf`.
|
||||
|
||||
// TODO(b/162575339): The tests currently still fails with CUDA_ILLEGAL_ADDRESS
|
||||
// when run with unranked kernels.
|
||||
// when Test with unranked kernels.
|
||||
TEST_F(GpuUnaryOpTest, DISABLED_IsInfFloat) {
|
||||
Run<float, float, bool, bool>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"IsInf",
|
||||
/*expected_callback=*/std::isinf,
|
||||
/*expect_equal=*/true);
|
||||
Test<float, float, bool, bool>(/*op_name=*/"IsInf", test::DefaultInputShape(),
|
||||
test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/std::isinf,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, DISABLED_IsInfDouble) {
|
||||
// Workaround for gcc bug, it would fail with "unresolved overloaded function
|
||||
// type" if passing std::isinf with type double. So we use type float for
|
||||
// comparing expected values.
|
||||
Run<double, float, bool, bool>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"IsInf",
|
||||
/*expected_callback=*/std::isinf,
|
||||
/*expect_equal=*/true);
|
||||
Test<double, float, bool, bool>(/*op_name=*/"IsInf",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/std::isinf,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, DISABLED_IsInfHalf) {
|
||||
Run<Eigen::half, float, bool, bool>(DefaultInputShape(),
|
||||
DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"IsInf",
|
||||
/*expected_callback=*/std::isinf,
|
||||
/*expect_equal=*/true);
|
||||
Test<Eigen::half, float, bool, bool>(/*op_name=*/"IsInf",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::isinf,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
/// Test `tf.Log`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, LogFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInputGreaterThanZero<float>(),
|
||||
/*op_name=*/"Log",
|
||||
/*expected_callback=*/std::log,
|
||||
/*expect_equal=*/false);
|
||||
Test<float, float, float, float>(/*op_name=*/"Log", test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterThanZero<float>(),
|
||||
/*baseline_callback=*/std::log,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, LogDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInputGreaterThanZero<double>(),
|
||||
/*op_name=*/"Log",
|
||||
/*expected_callback=*/std::log,
|
||||
/*expect_equal=*/false);
|
||||
Test<double, double, double, double>(
|
||||
/*op_name=*/"Log", test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterThanZero<double>(),
|
||||
/*baseline_callback=*/std::log,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, LogHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(),
|
||||
/*input=*/
|
||||
DefaultInputGreaterThanZero<Eigen::half>(),
|
||||
/*op_name=*/"Log",
|
||||
/*expected_callback=*/std::log,
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Log", test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterThanZero<Eigen::half>(),
|
||||
/*baseline_callback=*/std::log,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.LogicalNot`
|
||||
|
||||
TEST_F(GpuUnaryOpTest, LogicalNot) {
|
||||
Run<bool, bool, bool, bool>(
|
||||
DefaultInputShape(), DefaultInput<bool>(),
|
||||
/*op_name=*/"LogicalNot",
|
||||
/*expected_callback=*/[](bool v) { return !v; },
|
||||
Test<bool, bool, bool, bool>(
|
||||
/*op_name=*/"LogicalNot", test::DefaultInputShape(),
|
||||
test::DefaultInput<bool>(),
|
||||
/*baseline_callback=*/[](bool v) { return !v; },
|
||||
/*expect_equal=*/true,
|
||||
/*add_tout=*/false,
|
||||
/*expect_buffer_reuse=*/true,
|
||||
@ -406,69 +385,71 @@ TEST_F(GpuUnaryOpTest, LogicalNot) {
|
||||
|
||||
/// Reference implementation.
|
||||
template <typename T>
|
||||
T expected_neg(T x) {
|
||||
T baseline_neg(T x) {
|
||||
return -x;
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/false);
|
||||
Test<float, float, float, float>(
|
||||
/*op_name=*/"Neg", test::DefaultInputShape(), test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/baseline_neg,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/false);
|
||||
Test<double, double, double, double>(
|
||||
/*op_name=*/"Neg", test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/baseline_neg,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Neg", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/baseline_neg,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegInt8) {
|
||||
Run<int8>(DefaultInputShape(), DefaultInput<int8>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/true);
|
||||
Test<int8, int8, int8, int8>(
|
||||
/*op_name=*/"Neg", test::DefaultInputShape(), test::DefaultInput<int8>(),
|
||||
/*baseline_callback=*/baseline_neg,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegInt16) {
|
||||
Run<int16>(DefaultInputShape(), DefaultInput<int16>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/true);
|
||||
Test<int16, int16, int16, int16>(/*op_name=*/"Neg", test::DefaultInputShape(),
|
||||
test::DefaultInput<int16>(),
|
||||
/*baseline_callback=*/baseline_neg,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, NegInt64) {
|
||||
Run<int64>(DefaultInputShape(), DefaultInput<int64>(),
|
||||
/*op_name=*/"Neg",
|
||||
/*expected_callback=*/expected_neg,
|
||||
/*expect_equal=*/true);
|
||||
Test<int64, int64, int64, int64>(/*op_name=*/"Neg", test::DefaultInputShape(),
|
||||
test::DefaultInput<int64>(),
|
||||
/*baseline_callback=*/baseline_neg,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
/// Test `tf.Real`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, RealFloat) {
|
||||
Run<std::complex<float>, const std::complex<float>&, float, float>(
|
||||
DefaultInputShape(), DefaultComplexInput<float>(),
|
||||
/*op_name=*/"Real",
|
||||
/*expected_callback=*/std::real,
|
||||
Test<std::complex<float>, const std::complex<float>&, float, float>(
|
||||
/*op_name=*/"Real", test::DefaultInputShape(),
|
||||
test::DefaultComplexInput<float>(),
|
||||
/*baseline_callback=*/std::real,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/true,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, RealDouble) {
|
||||
Run<std::complex<double>, const std::complex<double>&, double, double>(
|
||||
DefaultInputShape(), DefaultComplexInput<double>(),
|
||||
/*op_name=*/"Real",
|
||||
/*expected_callback=*/std::real,
|
||||
Test<std::complex<double>, const std::complex<double>&, double, double>(
|
||||
/*op_name=*/"Real", test::DefaultInputShape(),
|
||||
test::DefaultComplexInput<double>(),
|
||||
/*baseline_callback=*/std::real,
|
||||
/*expect_equal=*/false,
|
||||
/*add_tout=*/true,
|
||||
/*expect_buffer_reuse=*/false);
|
||||
@ -478,141 +459,153 @@ TEST_F(GpuUnaryOpTest, RealDouble) {
|
||||
|
||||
/// Reference implementation.
|
||||
template <typename T>
|
||||
T expected_rsqrt(T x) {
|
||||
T baseline_rsqrt(T x) {
|
||||
return 1.0 / std::sqrt(x);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, RsqrtFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInputGreaterThanZero<float>(),
|
||||
/*op_name=*/"Rsqrt",
|
||||
/*expected_callback=*/expected_rsqrt,
|
||||
/*expect_equal=*/false);
|
||||
Test<float, float, float, float>(/*op_name=*/"Rsqrt",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterThanZero<float>(),
|
||||
/*baseline_callback=*/baseline_rsqrt,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, RsqrtDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInputGreaterThanZero<double>(),
|
||||
/*op_name=*/"Rsqrt",
|
||||
/*expected_callback=*/expected_rsqrt,
|
||||
/*expect_equal=*/false);
|
||||
Test<double, double, double, double>(
|
||||
/*op_name=*/"Rsqrt", test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterThanZero<double>(),
|
||||
/*baseline_callback=*/baseline_rsqrt,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, RsqrtHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(),
|
||||
/*input=*/
|
||||
DefaultInputGreaterThanZero<Eigen::half>(),
|
||||
/*op_name=*/"Rsqrt",
|
||||
/*expected_callback=*/expected_rsqrt,
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Rsqrt", test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterThanZero<Eigen::half>(),
|
||||
/*baseline_callback=*/baseline_rsqrt,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Sign`.
|
||||
|
||||
// Reference implementation
|
||||
template <typename T>
|
||||
T expected_sign(T x) {
|
||||
T baseline_sign(T x) {
|
||||
if (x == 0) return 0;
|
||||
if (x < 0) return -1;
|
||||
return 1;
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SignFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"Sign",
|
||||
/*expected_callback=*/expected_sign,
|
||||
/*expect_equal=*/true);
|
||||
Test<float, float, float, float>(/*op_name=*/"Sign",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/baseline_sign,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SignDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"Sign",
|
||||
/*expected_callback=*/expected_sign,
|
||||
/*expect_equal=*/true);
|
||||
Test<double, double, double, double>(/*op_name=*/"Sign",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/baseline_sign,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SignHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"Sign",
|
||||
/*expected_callback=*/expected_sign,
|
||||
// TODO(b/162577610): We should actually use true
|
||||
// here. This requires returning 0.0 for input -0.0.
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Sign", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*expected_callback=*/baseline_sign,
|
||||
// TODO(b/162577610): We should actually use true
|
||||
// here. This requires returning 0.0 for input -0.0.
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SignInt64) {
|
||||
Run<int64>(DefaultInputShape(), DefaultInput<int64>(),
|
||||
/*op_name=*/"Sign",
|
||||
/*expected_callback=*/expected_sign,
|
||||
/*expect_equal=*/true);
|
||||
Test<int64, int64, int64, int64>(
|
||||
/*op_name=*/"Sign", test::DefaultInputShape(),
|
||||
test::DefaultInput<int64>(),
|
||||
/*expected_callback=*/baseline_sign,
|
||||
/*expect_equal=*/true);
|
||||
}
|
||||
|
||||
/// Test `tf.Sin`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SinFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"Sin",
|
||||
/*expected_callback=*/std::sin,
|
||||
/*expect_equal=*/false);
|
||||
Test<float, float, float, float>(/*op_name=*/"Sin", test::DefaultInputShape(),
|
||||
test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/std::sin,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SinDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"Sin",
|
||||
/*expected_callback=*/std::sin,
|
||||
/*expect_equal=*/false);
|
||||
Test<double, double, double, double>(/*op_name=*/"Sin",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/std::sin,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SinHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"Sin",
|
||||
/*expected_callback=*/std::sin,
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Sin", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::sin,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Sqrt`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SqrtFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInputGreaterOrEqualToZero<float>(),
|
||||
/*op_name=*/"Sqrt",
|
||||
/*expected_callback=*/std::sqrt,
|
||||
/*expect_equal=*/false);
|
||||
Test<float, float, float, float>(
|
||||
/*op_name=*/"Sqrt", test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterOrEqualToZero<float>(),
|
||||
/*baseline_callback=*/std::sqrt,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SqrtDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInputGreaterOrEqualToZero<double>(),
|
||||
/*op_name=*/"Sqrt",
|
||||
/*expected_callback=*/std::sqrt,
|
||||
/*expect_equal=*/false);
|
||||
Test<double, double, double, double>(
|
||||
/*op_name=*/"Sqrt", test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterOrEqualToZero<double>(),
|
||||
/*baseline_callback=*/std::sqrt,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, SqrtHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(),
|
||||
DefaultInputGreaterOrEqualToZero<Eigen::half>(),
|
||||
/*op_name=*/"Sqrt",
|
||||
/*expected_callback=*/std::sqrt,
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Sqrt", test::DefaultInputShape(),
|
||||
test::DefaultInputGreaterOrEqualToZero<Eigen::half>(),
|
||||
/*baseline_callback=*/std::sqrt,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
/// Test `tf.Tanh`.
|
||||
|
||||
TEST_F(GpuUnaryOpTest, TanhFloat) {
|
||||
Run<float>(DefaultInputShape(), DefaultInput<float>(),
|
||||
/*op_name=*/"Tanh",
|
||||
/*expected_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
Test<float, float, float, float>(/*op_name=*/"Tanh",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<float>(),
|
||||
/*baseline_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, TanhDouble) {
|
||||
Run<double>(DefaultInputShape(), DefaultInput<double>(),
|
||||
/*op_name=*/"Tanh",
|
||||
/*expected_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
Test<double, double, double, double>(/*op_name=*/"Tanh",
|
||||
test::DefaultInputShape(),
|
||||
test::DefaultInput<double>(),
|
||||
/*baseline_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
TEST_F(GpuUnaryOpTest, TanhHalf) {
|
||||
Run<Eigen::half, float>(DefaultInputShape(), DefaultInput<Eigen::half>(),
|
||||
/*op_name=*/"Tanh",
|
||||
/*expected_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
Test<Eigen::half, float, Eigen::half, float>(
|
||||
/*op_name=*/"Tanh", test::DefaultInputShape(),
|
||||
test::DefaultInput<Eigen::half>(),
|
||||
/*baseline_callback=*/std::tanh,
|
||||
/*expect_equal=*/false);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -145,4 +145,35 @@ REGISTER_OP("CollectiveGatherV2")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("CollectiveBcastSendV2")
|
||||
.Input("input: T")
|
||||
.Output("data: T")
|
||||
.Attr("T: {bool, float, float16, float64, int32, int64}")
|
||||
.Input("group_size: int32")
|
||||
.Input("group_key: int32")
|
||||
.Input("instance_key: int32")
|
||||
.Attr("communication_hint: string = 'auto'")
|
||||
.Attr("timeout_seconds: float = 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("CollectiveBcastRecvV2")
|
||||
.Output("data: T")
|
||||
.Attr("T: {bool, float, float16, float64, int32, int64}")
|
||||
.Input("group_size: int32")
|
||||
.Input("group_key: int32")
|
||||
.Input("instance_key: int32")
|
||||
.Input("shape: Tshape")
|
||||
.Attr("Tshape: {int32, int64} = DT_INT32")
|
||||
.Attr("communication_hint: string = 'auto'")
|
||||
.Attr("timeout_seconds: float = 0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
// The output shape is given by the `shape` input at index 3.
|
||||
shape_inference::ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(/*input_idx=*/3, &out));
|
||||
c->set_output(/*idx=*/0, out);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
} // namespace tensorflow
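The V2 receiver takes its output shape from the runtime `shape` input (op input index 3) rather than a static attribute, which is exactly what the shape function above encodes via MakeShapeFromShapeTensor. A minimal Python sketch of calling the op through the wrapper added later in this change; the group/instance keys are made-up values, and the call only completes once a matching send runs in the same group:

# Sketch only: assumes a matching CollectiveBcastSendV2 executes elsewhere in
# the same group; group/instance keys here are illustrative values.
import tensorflow as tf
from tensorflow.python.ops import collective_ops

received = collective_ops.broadcast_recv_v2(
    shape=tf.constant([2, 3]),   # runtime shape input (op input index 3)
    dtype=tf.float32,
    group_size=2, group_key=1, instance_key=1)
# The output is a float32 tensor whose shape follows the `shape` tensor: [2, 3].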
|
||||
|
@ -0,0 +1,65 @@
|
||||
op {
|
||||
name: "CollectiveBcastRecvV2"
|
||||
input_arg {
|
||||
name: "group_size"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "group_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "instance_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "shape"
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BOOL
|
||||
type: DT_FLOAT
|
||||
type: DT_HALF
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "communication_hint"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "auto"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "timeout_seconds"
|
||||
type: "float"
|
||||
default_value {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
op {
|
||||
name: "CollectiveBcastSendV2"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "group_size"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "group_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "instance_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BOOL
|
||||
type: DT_FLOAT
|
||||
type: DT_HALF
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "communication_hint"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "auto"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "timeout_seconds"
|
||||
type: "float"
|
||||
default_value {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
@ -7441,6 +7441,71 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "CollectiveBcastRecvV2"
|
||||
input_arg {
|
||||
name: "group_size"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "group_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "instance_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "shape"
|
||||
type_attr: "Tshape"
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BOOL
|
||||
type: DT_FLOAT
|
||||
type: DT_HALF
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tshape"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_INT32
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "communication_hint"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "auto"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "timeout_seconds"
|
||||
type: "float"
|
||||
default_value {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "CollectiveBcastSend"
|
||||
input_arg {
|
||||
@ -7497,6 +7562,58 @@ op {
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "CollectiveBcastSendV2"
|
||||
input_arg {
|
||||
name: "input"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "group_size"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "group_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "instance_key"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_BOOL
|
||||
type: DT_FLOAT
|
||||
type: DT_HALF
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "communication_hint"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: "auto"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "timeout_seconds"
|
||||
type: "float"
|
||||
default_value {
|
||||
f: 0
|
||||
}
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "CollectiveGather"
|
||||
input_arg {
|
||||
|
@ -54,6 +54,7 @@ namespace {
|
||||
const char* const kTPUReplicatedInput = "TPUReplicatedInput";
|
||||
const char* const kTPUReplicatedOutput = "TPUReplicatedOutput";
|
||||
const char* const kPivotForClusterAttr = "_pivot_for_cluster";
|
||||
const char* const kTPUPartitionedInput = "TPUPartitionedInput";
|
||||
|
||||
// Finds the `index` of an _Arg or _Retval node.
|
||||
Status GetIndexAttr(const Node& n, int num_args, int* index) {
|
||||
@ -1586,7 +1587,18 @@ void RemoveUnusedTPUReplicatedInputs(Graph* graph) {
|
||||
}
|
||||
}
|
||||
if (!has_output) {
|
||||
// Remove any TPUPartitionedInput node from the src nodes of the
|
||||
// to-be-removed TPUReplicatedInput node
|
||||
std::vector<Node*> to_be_removed_src_nodes;
|
||||
for (const auto& e_in : n->in_edges()) {
|
||||
if (!e_in->IsControlEdge() &&
|
||||
e_in->src()->type_string() == kTPUPartitionedInput)
|
||||
to_be_removed_src_nodes.push_back(e_in->src());
|
||||
}
|
||||
graph->RemoveNode(n);
|
||||
for (Node* node : to_be_removed_src_nodes) {
|
||||
graph->RemoveNode(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4620,6 +4620,45 @@ func KmeansPlusPlusInitialization(scope *Scope, points tf.Output, num_to_sample
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// CollectiveBcastRecvV2Attr is an optional argument to CollectiveBcastRecvV2.
|
||||
type CollectiveBcastRecvV2Attr func(optionalAttr)
|
||||
|
||||
// CollectiveBcastRecvV2CommunicationHint sets the optional communication_hint attribute to value.
|
||||
// If not specified, defaults to "auto"
|
||||
func CollectiveBcastRecvV2CommunicationHint(value string) CollectiveBcastRecvV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["communication_hint"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// CollectiveBcastRecvV2TimeoutSeconds sets the optional timeout_seconds attribute to value.
|
||||
// If not specified, defaults to 0
|
||||
func CollectiveBcastRecvV2TimeoutSeconds(value float32) CollectiveBcastRecvV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["timeout_seconds"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Receives a tensor value broadcast from another device.
|
||||
func CollectiveBcastRecvV2(scope *Scope, group_size tf.Output, group_key tf.Output, instance_key tf.Output, shape tf.Output, T tf.DataType, optional ...CollectiveBcastRecvV2Attr) (data tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{"T": T}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "CollectiveBcastRecvV2",
|
||||
Input: []tf.Input{
|
||||
group_size, group_key, instance_key, shape,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// AbortAttr is an optional argument to Abort.
|
||||
type AbortAttr func(optionalAttr)
|
||||
|
||||
@ -49360,6 +49399,45 @@ func LoadTPUEmbeddingRMSPropParameters(scope *Scope, parameters tf.Output, ms tf
|
||||
return scope.AddOperation(opspec)
|
||||
}
|
||||
|
||||
// CollectiveBcastSendV2Attr is an optional argument to CollectiveBcastSendV2.
|
||||
type CollectiveBcastSendV2Attr func(optionalAttr)
|
||||
|
||||
// CollectiveBcastSendV2CommunicationHint sets the optional communication_hint attribute to value.
|
||||
// If not specified, defaults to "auto"
|
||||
func CollectiveBcastSendV2CommunicationHint(value string) CollectiveBcastSendV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["communication_hint"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// CollectiveBcastSendV2TimeoutSeconds sets the optional timeout_seconds attribute to value.
|
||||
// If not specified, defaults to 0
|
||||
func CollectiveBcastSendV2TimeoutSeconds(value float32) CollectiveBcastSendV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["timeout_seconds"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcasts a tensor value to one or more other devices.
|
||||
func CollectiveBcastSendV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, optional ...CollectiveBcastSendV2Attr) (data tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "CollectiveBcastSendV2",
|
||||
Input: []tf.Input{
|
||||
input, group_size, group_key, instance_key,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// InfeedEnqueueTupleAttr is an optional argument to InfeedEnqueueTuple.
|
||||
type InfeedEnqueueTupleAttr func(optionalAttr)
|
||||
|
||||
|
@ -60,6 +60,7 @@ cc_library(
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/kernels:kernel_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -187,7 +187,9 @@ class DelegatedInterpreter {
|
||||
|
||||
class InterpreterFp16 : public DelegatedInterpreter {
|
||||
public:
|
||||
explicit InterpreterFp16(TfLiteBuiltinOperator op) : DelegatedInterpreter(3) {
|
||||
explicit InterpreterFp16(TfLiteBuiltinOperator op,
|
||||
bool const_dequantize_inputs = true)
|
||||
: DelegatedInterpreter(3) {
|
||||
void* builtin_data = malloc(sizeof(int));
|
||||
EXPECT_EQ(interpreter_.AddTensors(5), kTfLiteOk);
|
||||
EXPECT_EQ(interpreter_.SetInputs({0, 1}), kTfLiteOk);
|
||||
@ -243,6 +245,15 @@ class InterpreterFp16 : public DelegatedInterpreter {
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
if (const_dequantize_inputs) {
|
||||
// This simulates the dequantize inputs being constants in the graph.
|
||||
// If this is not true, FP16GraphPartitionHelper should not consider the
|
||||
// corresponding DEQUANTIZE ops.
|
||||
auto* tensor0 = interpreter_.tensor(0);
|
||||
auto* tensor2 = interpreter_.tensor(2);
|
||||
tensor0->allocation_type = kTfLiteMmapRo;
|
||||
tensor2->allocation_type = kTfLiteMmapRo;
|
||||
}
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
1, TfLiteType::kTfLiteFloat32, "t1", dims, quantization, false),
|
||||
@ -337,6 +348,64 @@ TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
|
||||
TfLiteIntArrayFree(ops_to_replace);
|
||||
}
|
||||
|
||||
InterpreterFp16* interpreter_fp16_non_constant =
|
||||
new InterpreterFp16(kTfLiteBuiltinAdd, /*const_dequantize_inputs=*/false);
|
||||
|
||||
// Same as GetOpsToReplaceAcceptsFp16DequantizeNodes, but the DEQUANTIZE inputs
|
||||
// are not constant. As a result, we don't allow the delegate to accept them.
|
||||
TEST(ModelBuilderTest, GetOpsToReplaceRejectsNonConstantFp16DequantizeNodes) {
|
||||
TfLiteContext* context = interpreter_fp16_non_constant->context();
|
||||
|
||||
// These functions are meant to be called inside delegates. Swap out
|
||||
// for similar functions to permit direct calling of GetOpsToReplace.
|
||||
context->GetExecutionPlan = [](struct TfLiteContext* context,
|
||||
TfLiteIntArray** execution_plan) {
|
||||
*execution_plan = interpreter_fp16_non_constant->exec_plan();
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
*node = interpreter_fp16_non_constant->node(node_index);
|
||||
*registration = interpreter_fp16_non_constant->registration(node_index);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->PreviewDelegatePartitioning =
|
||||
[](struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
|
||||
TfLiteDelegateParams** partition_params_array, int* num_partitions) {
|
||||
// The partitioner should accept only the Add op initially.
|
||||
EXPECT_EQ(nodes_to_replace->size, 1);
|
||||
// Single partition output.
|
||||
auto params = interpreter_fp16_non_constant->add_delegate_params();
|
||||
params->nodes_to_replace = TfLiteIntArrayCreate(1);
|
||||
params->nodes_to_replace->data[0] = 2;
|
||||
params->input_tensors = TfLiteIntArrayCreate(2);
|
||||
params->input_tensors->data[0] = 1;
|
||||
params->input_tensors->data[1] = 3;
|
||||
params->output_tensors = TfLiteIntArrayCreate(1);
|
||||
params->output_tensors->data[0] = 4;
|
||||
|
||||
*partition_params_array =
|
||||
interpreter_fp16_non_constant->delegate_params();
|
||||
*num_partitions = interpreter_fp16_non_constant->num_delegate_params();
|
||||
return kTfLiteOk;
|
||||
};
|
||||
|
||||
TfLiteIntArray* ops_to_replace = GetOpsToReplace(context);
|
||||
|
||||
// Only ADD is delegated, with FP32 (dequantized) inputs.
|
||||
EXPECT_EQ(ops_to_replace->size, 1);
|
||||
TfLiteNode* node = nullptr;
|
||||
TfLiteRegistration* registration = nullptr;
|
||||
context->GetNodeAndRegistration(context, ops_to_replace->data[0], &node,
|
||||
®istration);
|
||||
EXPECT_EQ(context->tensors[node->inputs->data[0]].type,
|
||||
TfLiteType::kTfLiteFloat32);
|
||||
EXPECT_EQ(context->tensors[node->inputs->data[1]].type,
|
||||
TfLiteType::kTfLiteFloat32);
|
||||
TfLiteIntArrayFree(ops_to_replace);
|
||||
}
|
||||
|
||||
InterpreterFp16* interpreter_fp16_gt_op =
|
||||
new InterpreterFp16(kTfLiteBuiltinGreater);
|
||||
|
||||
@ -800,6 +869,13 @@ class InterpreterMultiNode : public DelegatedInterpreter {
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
2, TfLiteType::kTfLiteFloat16, "t2", dims, quantization, false),
|
||||
kTfLiteOk);
|
||||
// Simulate DEQUANTIZE inputs being constants.
|
||||
auto* tensor0 = interpreter_.tensor(0);
|
||||
auto* tensor1 = interpreter_.tensor(1);
|
||||
auto* tensor2 = interpreter_.tensor(2);
|
||||
tensor0->allocation_type = kTfLiteMmapRo;
|
||||
tensor1->allocation_type = kTfLiteMmapRo;
|
||||
tensor2->allocation_type = kTfLiteMmapRo;
|
||||
EXPECT_EQ(
|
||||
interpreter_.SetTensorParametersReadWrite(
|
||||
3, TfLiteType::kTfLiteFloat32, "t3", dims, quantization, false),
|
||||
|
@ -135,6 +135,7 @@ cc_test(
|
||||
":nnapi_delegate_mock_test",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/kernels:deprecated_backends",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"//tensorflow/lite/nnapi:nnapi_implementation",
|
||||
"//tensorflow/lite/nnapi:nnapi_lib",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/builtin_ops.h"
|
||||
#include "tensorflow/lite/context_util.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace delegates {
|
||||
@ -183,7 +184,8 @@ FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
||||
// its value in delegated_dequant_consumers.
|
||||
for (int j = 0; j < node->inputs->size; ++j) {
|
||||
const int input_tid = node->inputs->data[j];
|
||||
if (dequant_consumers_.find(input_tid) != dequant_consumers_.end()) {
|
||||
if (constant_dequant_consumers_.find(input_tid) !=
|
||||
constant_dequant_consumers_.end()) {
|
||||
delegated_dequant_consumers[input_tid] += 1;
|
||||
}
|
||||
}
|
||||
@ -192,9 +194,10 @@ FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
||||
// If the number of delegated consumers is same as total number of consumers,
|
||||
// add the corresponding DEQUANTIZE op to the delegated nodes.
|
||||
for (auto tensor_and_consumers : delegated_dequant_consumers) {
|
||||
if (dequant_consumers_[tensor_and_consumers.first] ==
|
||||
if (constant_dequant_consumers_[tensor_and_consumers.first] ==
|
||||
tensor_and_consumers.second) {
|
||||
ops_to_replace.emplace_back(dequant_nodes_[tensor_and_consumers.first]);
|
||||
ops_to_replace.emplace_back(
|
||||
constant_dequant_nodes_[tensor_and_consumers.first]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,16 +219,21 @@ FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitionsImpl(
|
||||
bool FP16GraphPartitionHelper::IsNodeSupported(
|
||||
TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
|
||||
int node_id, std::string* unsupported_details) {
|
||||
if (registration->builtin_code == kTfLiteBuiltinDequantize &&
|
||||
context_->tensors[node->inputs->data[0]].type ==
|
||||
TfLiteType::kTfLiteFloat16) {
|
||||
// Update mappings if this node is a fp16 DEQUANTIZE node.
|
||||
dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
|
||||
dequant_nodes_[node->outputs->data[0]] = node_id;
|
||||
// We do not accept these ops right now.
|
||||
// This is done to support use-cases where a DEQUANTIZE output might be
|
||||
// consumed by a CPU op.
|
||||
return false;
|
||||
if (registration->builtin_code == kTfLiteBuiltinDequantize) {
|
||||
auto& dequantize_input = context_->tensors[node->inputs->data[0]];
|
||||
if (dequantize_input.type == kTfLiteFloat16 &&
|
||||
IsConstantTensor(&dequantize_input)) {
|
||||
// Update mappings if this node is a fp16 DEQUANTIZE node that
|
||||
// works on a **constant** input tensor.
|
||||
// If the input is not a constant, the remapping that we do here will
|
||||
// cause bugs due to preceding ops such as DENSIFY.
|
||||
constant_dequant_map_[node->outputs->data[0]] = node->inputs->data[0];
|
||||
constant_dequant_nodes_[node->outputs->data[0]] = node_id;
|
||||
// We do not accept these ops right now.
|
||||
// This is done to support use-cases where a DEQUANTIZE output might be
|
||||
// consumed by a CPU op.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// To check if a (possibly) FP16 node is supported, we temporarily point the
|
||||
@ -234,7 +242,7 @@ bool FP16GraphPartitionHelper::IsNodeSupported(
|
||||
// we remap the original node inputs, so that the TFLite graph remains the
|
||||
// same.
|
||||
std::vector<int> orig_inputs;
|
||||
if (!dequant_nodes_.empty()) {
|
||||
if (!constant_dequant_nodes_.empty()) {
|
||||
RemapFp16InputTensors(node, &orig_inputs);
|
||||
}
|
||||
|
||||
@ -245,10 +253,11 @@ bool FP16GraphPartitionHelper::IsNodeSupported(
|
||||
// Remapping happened. Restore original inputs.
|
||||
for (int j = 0; j < node->inputs->size; ++j) {
|
||||
node->inputs->data[j] = orig_inputs[j];
|
||||
if (dequant_nodes_.find(orig_inputs[j]) != dequant_nodes_.end()) {
|
||||
if (constant_dequant_nodes_.find(orig_inputs[j]) !=
|
||||
constant_dequant_nodes_.end()) {
|
||||
// If it's an fp16 tensor, increment the number of consumers of the
|
||||
// corresponding DEQUANTIZE.
|
||||
dequant_consumers_[orig_inputs[j]] += 1;
|
||||
constant_dequant_consumers_[orig_inputs[j]] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -289,8 +298,8 @@ void FP16GraphPartitionHelper::RemapFp16InputTensors(
|
||||
bool is_remapped = false;
|
||||
for (int j = 0; j < inputs->size; ++j) {
|
||||
const int input_tid = inputs->data[j];
|
||||
const auto it = dequant_map_.find(input_tid);
|
||||
if (it != dequant_map_.end()) {
|
||||
const auto it = constant_dequant_map_.find(input_tid);
|
||||
if (it != constant_dequant_map_.end()) {
|
||||
inputs->data[j] = it->second;
|
||||
is_remapped = true;
|
||||
}
|
||||
|
@ -131,8 +131,8 @@ class GraphPartitionHelper {
|
||||
// Specialized partitioner for graphs that possibly contain fp16 tensors.
|
||||
//
|
||||
// From nodes that accept fp16 inputs, this delegates the following:
|
||||
// 1. All nodes (except DEQUANTIZE) that are supported with fp16 inputs by the
|
||||
// delegate (in the TFLite graph, these nodes take in dequantized FP32
|
||||
// 1. All nodes (except DEQUANTIZE) that are supported with constant fp16 inputs
|
||||
// by the delegate (in the TFLite graph, these nodes take in dequantized FP32
|
||||
// outputs).
|
||||
// 2. All fp16 DEQUANTIZE nodes that have *all* their consumers in the *first*
|
||||
// delegated partition. This is because TFLite's partitioning algorithm
|
||||
@ -168,11 +168,12 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper {
|
||||
|
||||
// ('dequantize' here refers to fp16 DEQUANTIZE)
|
||||
// Mapping of dequantize nodes' output tensor-id to its node id.
|
||||
std::unordered_map<int, int> dequant_nodes_;
|
||||
// TODO(b/156707497): Use absl hash_maps here.
|
||||
std::unordered_map<int, int> constant_dequant_nodes_;
|
||||
// Mapping of DEQUANTIZE node's output (fp32) to its input (fp16).
|
||||
std::unordered_map<int, int> dequant_map_;
|
||||
std::unordered_map<int, int> constant_dequant_map_;
|
||||
// mapping of DEQUANTIZE output tensor-id to its number of consumers.
|
||||
std::unordered_map<int, int> dequant_consumers_;
|
||||
std::unordered_map<int, int> constant_dequant_consumers_;
|
||||
};
|
||||
|
||||
} // namespace delegates
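Restating the partitioning rule above outside the C++: a constant fp16 DEQUANTIZE node is pulled into the delegated set only when every consumer of its output was itself delegated. A minimal Python sketch of that bookkeeping, with names mirroring the member maps above; this is an illustration, not the implementation:

# Illustration only: all dicts are keyed by the DEQUANTIZE output tensor id,
# mirroring constant_dequant_consumers_ / constant_dequant_nodes_ above.
def dequantize_nodes_to_delegate(total_consumers, delegated_consumers,
                                 dequant_node_ids):
  ops_to_replace = []
  for tensor_id, delegated_count in delegated_consumers.items():
    # Delegate the DEQUANTIZE only if *all* of its consumers were delegated.
    if total_consumers.get(tensor_id, 0) == delegated_count:
      ops_to_replace.append(dequant_node_ids[tensor_id])
  return ops_to_replace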
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -65,11 +65,12 @@ import org.checkerframework.checker.nullness.qual.NonNull;
|
||||
* model with Toco, as are the default shapes of the inputs.
|
||||
*
|
||||
* <p>When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will
|
||||
* be implicitly resized according to that array's shape. When inputs are provided as {@link Buffer}
|
||||
* types, no implicit resizing is done; the caller must ensure that the {@link Buffer} byte size
|
||||
* either matches that of the corresponding tensor, or that they first resize the tensor via {@link
|
||||
* #resizeInput()}. Tensor shape and type information can be obtained via the {@link Tensor} class,
|
||||
* available via {@link #getInputTensor(int)} and {@link #getOutputTensor(int)}.
|
||||
* be implicitly resized according to that array's shape. When inputs are provided as {@link
|
||||
* java.nio.Buffer} types, no implicit resizing is done; the caller must ensure that the {@link
|
||||
* java.nio.Buffer} byte size either matches that of the corresponding tensor, or that they first
|
||||
* resize the tensor via {@link #resizeInput(int, int[])}. Tensor shape and type information can be
|
||||
* obtained via the {@link Tensor} class, available via {@link #getInputTensor(int)} and {@link
|
||||
* #getOutputTensor(int)}.
|
||||
*
|
||||
* <p><b>WARNING:</b>Instances of a {@code Interpreter} is <b>not</b> thread-safe. A {@code
|
||||
* Interpreter} owns resources that <b>must</b> be explicitly freed by invoking {@link #close()}
|
||||
@ -269,7 +270,7 @@ public final class Interpreter implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
|
||||
* {@link #Options}.
|
||||
* {@link Interpreter.Options}.
|
||||
*
|
||||
* <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
|
||||
* {@code ByteBuffer} can be either a {@link MappedByteBuffer} that memory-maps a model file, or a
|
||||
@ -285,33 +286,35 @@ public final class Interpreter implements AutoCloseable {
|
||||
/**
|
||||
* Runs model inference if the model takes only one input, and provides only one output.
|
||||
*
|
||||
* <p>Warning: The API is more efficient if a {@link Buffer} (preferably direct, but not required)
|
||||
* is used as the input/output data type. Please consider using {@link Buffer} to feed and fetch
|
||||
* primitive data for better performance. The following concrete {@link Buffer} types are
|
||||
* supported:
|
||||
* <p>Warning: The API is more efficient if a {@link java.nio.Buffer} (preferably direct, but not
|
||||
* required) is used as the input/output data type. Please consider using {@link java.nio.Buffer}
|
||||
* to feed and fetch primitive data for better performance. The following concrete {@link
|
||||
* java.nio.Buffer} types are supported:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
|
||||
* <li>{@link FloatBuffer} - compatible with float Tensors.
|
||||
* <li>{@link IntBuffer} - compatible with int32 Tensors.
|
||||
* <li>{@link LongBuffer} - compatible with int64 Tensors.
|
||||
* <li>{@link java.nio.FloatBuffer} - compatible with float Tensors.
|
||||
* <li>{@link java.nio.IntBuffer} - compatible with int32 Tensors.
|
||||
* <li>{@link java.nio.LongBuffer} - compatible with int64 Tensors.
|
||||
* </ul>
|
||||
*
|
||||
* Note that boolean types are only supported as arrays, not {@link Buffer}s, or as scalar inputs.
|
||||
* Note that boolean types are only supported as arrays, not {@link java.nio.Buffer}s, or as
|
||||
* scalar inputs.
|
||||
*
|
||||
* @param input an array or multidimensional array, or a {@link Buffer} of primitive types
|
||||
* including int, float, long, and byte. {@link Buffer} is the preferred way to pass large
|
||||
* input data for primitive types, whereas string types require using the (multi-dimensional)
|
||||
* array input path. When a {@link Buffer} is used, its content should remain unchanged until
|
||||
* model inference is done, and the caller must ensure that the {@link Buffer} is at the
|
||||
* appropriate read position. A {@code null} value is allowed only if the caller is using a
|
||||
* {@link Delegate} that allows buffer handle interop, and such a buffer has been bound to the
|
||||
* input {@link Tensor}.
|
||||
* @param output a multidimensional array of output data, or a {@link Buffer} of primitive types
|
||||
* including int, float, long, and byte. When a {@link Buffer} is used, the caller must ensure
|
||||
* that it is set the appropriate write position. A null value is allowed only if the caller
|
||||
* is using a {@link Delegate} that allows buffer handle interop, and such a buffer has been
|
||||
* bound to the output {@link Tensor}. See {@link Options#setAllowBufferHandleOutput()}.
|
||||
* @param input an array or multidimensional array, or a {@link java.nio.Buffer} of primitive
|
||||
* types including int, float, long, and byte. {@link java.nio.Buffer} is the preferred way to
|
||||
* pass large input data for primitive types, whereas string types require using the
|
||||
* (multi-dimensional) array input path. When a {@link java.nio.Buffer} is used, its content
|
||||
* should remain unchanged until model inference is done, and the caller must ensure that the
|
||||
* {@link java.nio.Buffer} is at the appropriate read position. A {@code null} value is
|
||||
* allowed only if the caller is using a {@link Delegate} that allows buffer handle interop,
|
||||
* and such a buffer has been bound to the input {@link Tensor}.
|
||||
* @param output a multidimensional array of output data, or a {@link java.nio.Buffer} of
|
||||
* primitive types including int, float, long, and byte. When a {@link java.nio.Buffer} is
|
||||
* used, the caller must ensure that it is set the appropriate write position. A null value is
|
||||
* allowed only if the caller is using a {@link Delegate} that allows buffer handle interop,
|
||||
* and such a buffer has been bound to the output {@link Tensor}. See {@link
|
||||
* Interpreter.Options#setAllowBufferHandleOutput(boolean)}.
|
||||
* @throws IllegalArgumentException if {@code input} or {@code output} is null or empty, or if
|
||||
* error occurs when running the inference.
|
||||
* @throws IllegalArgumentException (EXPERIMENTAL, subject to change) if the inference is
|
||||
@ -327,35 +330,36 @@ public final class Interpreter implements AutoCloseable {
|
||||
/**
|
||||
* Runs model inference if the model takes multiple inputs, or returns multiple outputs.
|
||||
*
|
||||
* <p>Warning: The API is more efficient if {@link Buffer}s (preferably direct, but not required)
|
||||
* are used as the input/output data types. Please consider using {@link Buffer} to feed and fetch
|
||||
* primitive data for better performance. The following concrete {@link Buffer} types are
|
||||
* supported:
|
||||
* <p>Warning: The API is more efficient if {@link java.nio.Buffer}s (preferably direct, but not
|
||||
* required) are used as the input/output data types. Please consider using {@link
|
||||
* java.nio.Buffer} to feed and fetch primitive data for better performance. The following
|
||||
* concrete {@link java.nio.Buffer} types are supported:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
|
||||
* <li>{@link FloatBuffer} - compatible with float Tensors.
|
||||
* <li>{@link IntBuffer} - compatible with int32 Tensors.
|
||||
* <li>{@link LongBuffer} - compatible with int64 Tensors.
|
||||
* <li>{@link java.nio.FloatBuffer} - compatible with float Tensors.
|
||||
* <li>{@link java.nio.IntBuffer} - compatible with int32 Tensors.
|
||||
* <li>{@link java.nio.LongBuffer} - compatible with int64 Tensors.
|
||||
* </ul>
|
||||
*
|
||||
* Note that boolean types are only supported as arrays, not {@link Buffer}s, or as scalar inputs.
|
||||
* Note that boolean types are only supported as arrays, not {@link java.nio.Buffer}s, or as
|
||||
* scalar inputs.
|
||||
*
|
||||
* <p>Note: {@code null} values for invididual elements of {@code inputs} and {@code outputs} is
|
||||
* allowed only if the caller is using a {@link Delegate} that allows buffer handle interop, and
|
||||
* such a buffer has been bound to the corresponding input or output {@link Tensor}(s).
|
||||
*
|
||||
* @param inputs an array of input data. The inputs should be in the same order as inputs of the
|
||||
* model. Each input can be an array or multidimensional array, or a {@link Buffer} of
|
||||
* primitive types including int, float, long, and byte. {@link Buffer} is the preferred way
|
||||
* to pass large input data, whereas string types require using the (multi-dimensional) array
|
||||
* input path. When {@link Buffer} is used, its content should remain unchanged until model
|
||||
* inference is done, and the caller must ensure that the {@link Buffer} is at the appropriate
|
||||
* read position.
|
||||
* model. Each input can be an array or multidimensional array, or a {@link java.nio.Buffer}
|
||||
* of primitive types including int, float, long, and byte. {@link java.nio.Buffer} is the
|
||||
* preferred way to pass large input data, whereas string types require using the
|
||||
* (multi-dimensional) array input path. When {@link java.nio.Buffer} is used, its content
|
||||
* should remain unchanged until model inference is done, and the caller must ensure that the
|
||||
* {@link java.nio.Buffer} is at the appropriate read position.
|
||||
* @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
|
||||
* Buffer}s of primitive types including int, float, long, and byte. It only needs to keep
|
||||
* entries for the outputs to be used. When a {@link Buffer} is used, the caller must ensure
|
||||
* that it is set the appropriate write position.
|
||||
* java.nio.Buffer}s of primitive types including int, float, long, and byte. It only needs to
|
||||
* keep entries for the outputs to be used. When a {@link java.nio.Buffer} is used, the caller
|
||||
* must ensure that it is set the appropriate write position.
|
||||
* @throws IllegalArgumentException if {@code inputs} or {@code outputs} is null or empty, or if
|
||||
* error occurs when running the inference.
|
||||
*/
|
||||
@ -494,8 +498,8 @@ public final class Interpreter implements AutoCloseable {
|
||||
/**
|
||||
* Sets the number of threads to be used for ops that support multi-threading.
|
||||
*
|
||||
* @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling thread
|
||||
* multi-threading. This method will be removed in a future release.
|
||||
* @deprecated Prefer using {@link Interpreter.Options#setNumThreads(int)} directly for
|
||||
* controlling thread multi-threading. This method will be removed in a future release.
|
||||
*/
|
||||
@Deprecated
|
||||
public void setNumThreads(int numThreads) {
|
||||
@ -507,8 +511,8 @@ public final class Interpreter implements AutoCloseable {
|
||||
* Advanced: Modifies the graph with the provided {@link Delegate}.
|
||||
*
|
||||
* @throws IllegalArgumentException if error occurs when modifying graph with {@code delegate}.
|
||||
* @deprecated Prefer using {@link Options#addDelegate} to provide delegates at creation time.
|
||||
* This method will be removed in a future release.
|
||||
* @deprecated Prefer using {@link Interpreter.Options#addDelegate} to provide delegates at
|
||||
* creation time. This method will be removed in a future release.
|
||||
*/
|
||||
@Deprecated
|
||||
public void modifyGraphWithDelegate(Delegate delegate) {
|
||||
|
@ -85,7 +85,7 @@ public final class TensorFlowLite {
|
||||
}
|
||||
}
|
||||
|
||||
public static native String nativeRuntimeVersion();
|
||||
private static native String nativeRuntimeVersion();
|
||||
|
||||
public static native String nativeSchemaVersion();
|
||||
private static native String nativeSchemaVersion();
|
||||
}
|
||||
|
@ -349,7 +349,9 @@ cc_library(
|
||||
"//conditions:default": ["-DTFLITE_HAVE_CPUINFO"],
|
||||
}),
|
||||
deps = [
|
||||
":deprecated_backends", # TODO(b/168923364): Move to dependent targets.
|
||||
# TODO(b/168923364): Remove deprecated_backends after it is added to all
|
||||
# necessary targets.
|
||||
":deprecated_backends",
|
||||
":tflite_with_ruy",
|
||||
":op_macros",
|
||||
# For now this unconditionally depends on both ruy and gemmlowp.
|
||||
|
@ -21,13 +21,5 @@ machine LoadPlatformDescription @platforms/cpus/stm32f103.repl
|
||||
# These lines are needed to show the results of DebugLog calls in the output.
|
||||
machine LoadPlatformDescriptionFromString "uartSemihosting: UART.SemihostingUart @ cpu"
|
||||
showAnalyzer cpu.uartSemihosting Antmicro.Renode.Analyzers.LoggingUartAnalyzer
|
||||
|
||||
logFile $logfile
|
||||
|
||||
macro reset
|
||||
"""
|
||||
sysbus LoadELF $bin
|
||||
"""
|
||||
|
||||
runMacro $reset
|
||||
cpu.uartSemihosting CreateFileBackend $logfile true
|
||||
|
||||
|
26
tensorflow/lite/micro/testing/robot.resource.txt
Normal file
@ -0,0 +1,26 @@
|
||||
*** Variables ***
|
||||
${UART} sysbus.cpu.uartSemihosting
|
||||
|
||||
*** Keywords ***
|
||||
Teardown With Custom Message
|
||||
Test Teardown
|
||||
[Documentation] Replace robot fail message with whole UART output
|
||||
${UART_LOGS} Get File ${UART_LOG}
|
||||
Set Test Message UART OUTPUT:\n\n${UART_LOGS}
|
||||
Remove File ${UART_LOG}
|
||||
|
||||
Create Platform
|
||||
Execute Command $logfile=@${UART_LOG}
|
||||
Execute Script ${RESC}
|
||||
Provides ready-platform
|
||||
|
||||
Test Binary
|
||||
[Arguments] ${BIN}
|
||||
Requires ready-platform
|
||||
Execute Command sysbus LoadELF ${BIN}
|
||||
|
||||
Create Terminal Tester ${UART} timeout=2
|
||||
Start Emulation
|
||||
|
||||
Wait For Line On Uart ${UART_LINE_ON_SUCCESS}
|
||||
|
@ -13,12 +13,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
#
|
||||
#
|
||||
# Parameters:
|
||||
# ${1} - path to a binary to test or directory (all *_test will be run).
|
||||
|
||||
set -e
|
||||
|
||||
TARGET=bluepill
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
TFLM_ROOT_DIR=${SCRIPT_DIR}/..
|
||||
|
||||
# The renode script for the board being emulated.
|
||||
RESC_PATH=${TFLM_ROOT_DIR}/testing/bluepill.resc
|
||||
RESC_PATH=${TFLM_ROOT_DIR}/testing/${TARGET}.resc
|
||||
|
||||
# Robot file with definition of custom keywords used in test suite.
|
||||
ROBOT_RESOURCE=${TFLM_ROOT_DIR}/testing/robot.resource.txt
|
||||
|
||||
# Renode's entrypoint for using the Robot Framework.
|
||||
RENODE_TEST_SCRIPT=${TFLM_ROOT_DIR}/tools/make/downloads/renode/test.sh
|
||||
@ -38,39 +49,59 @@ then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit_code=0
|
||||
|
||||
# The logs from this script will go in the RESULTS_DIRECTORY. These include:
|
||||
# 1. RENODE_LOG: Output log from the renode process.
|
||||
# Files generated by this script will go in the RESULTS_DIRECTORY. These include:
|
||||
# 1. UART_LOG: Output log from the renode uart.
|
||||
# 2. html and xml files generated by the Robot Framework.
|
||||
# 3. ROBOT_SCRIPT: Generated test suite.
|
||||
#
|
||||
# Note that with the current approach (in bluepill.robot), multiple test
|
||||
# binaries are run in a loop and RENODE_LOG only has logs from the last test
|
||||
# binary since it is deleted prior to running each test binary.
|
||||
RESULTS_DIRECTORY=/tmp/renode_bluepill_logs
|
||||
# Note that with the current approach (in generated ROBOT_SCRIPT), multiple test
|
||||
# binaries are run in the same test suite and UART_LOG only has logs from the last test
# binary since it is deleted prior to running each test binary. If some test fails,
# the UART_LOG will be printed to the console log before being deleted.
|
||||
RESULTS_DIRECTORY=/tmp/renode_${TARGET}_logs
|
||||
mkdir -p ${RESULTS_DIRECTORY}
|
||||
RENODE_LOG=${RESULTS_DIRECTORY}/renode_log.txt
|
||||
|
||||
ROBOT_COMMAND="${RENODE_TEST_SCRIPT} ${TFLM_ROOT_DIR}/testing/bluepill.robot \
|
||||
UART_LOG=${RESULTS_DIRECTORY}/uart_log.txt
|
||||
|
||||
ROBOT_SCRIPT=${RESULTS_DIRECTORY}/${TARGET}.robot
|
||||
|
||||
echo -e "*** Settings ***\n" \
|
||||
"Suite Setup Setup\n" \
|
||||
"Suite Teardown Teardown\n" \
|
||||
"Test Setup Reset Emulation\n" \
|
||||
"Test Teardown Teardown With Custom Message\n" \
|
||||
"Resource \${RENODEKEYWORDS}\n" \
|
||||
"Resource ${ROBOT_RESOURCE}\n" \
|
||||
"Default Tags tensorflow\n" \
|
||||
"\n" \
|
||||
"*** Variables ***\n" \
|
||||
"\${RESC} undefined_RESC\n" \
|
||||
"\${UART_LOG} /tmp/uart.log\n" \
|
||||
"\${UART_LINE_ON_SUCCESS} ~~~ALL TESTS PASSED~~~\n" \
|
||||
"\${CREATE_SNAPSHOT_ON_FAIL} False\n" \
|
||||
"\n" \
|
||||
"*** Test Cases ***\n" \
|
||||
"Should Create Platform\n" \
|
||||
" Create Platform\n" > $ROBOT_SCRIPT
|
||||
|
||||
declare -a FILES
|
||||
if [[ -d ${1} ]]; then
|
||||
FILES=`ls -1 ${1}/*_test`
|
||||
else
|
||||
FILES=${1}
|
||||
fi
|
||||
|
||||
for binary in ${FILES}
|
||||
do
|
||||
echo -e "Should Run $(basename ${binary})\n"\
|
||||
" Test Binary @$(realpath ${binary})\n" >> ${ROBOT_SCRIPT}
|
||||
done
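For two hypothetical binaries, foo_test and bar_test, the loop above would make the tail of the generated ${ROBOT_SCRIPT} look roughly like this (paths and exact whitespace are illustrative):

*** Test Cases ***
Should Create Platform
    Create Platform
Should Run foo_test
    Test Binary    @/path/to/bin/foo_test
Should Run bar_test
    Test Binary    @/path/to/bin/bar_test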
|
||||
|
||||
ROBOT_COMMAND="${RENODE_TEST_SCRIPT} ${ROBOT_SCRIPT} \
|
||||
-r ${RESULTS_DIRECTORY} \
|
||||
--variable RESC:${RESC_PATH} \
|
||||
--variable RENODE_LOG:${RENODE_LOG} \
|
||||
--variable DIR_WITH_TESTS:${1}"
|
||||
--variable UART_LOG:${UART_LOG}"
|
||||
|
||||
echo "${ROBOT_COMMAND}"
|
||||
|
||||
if ! ${ROBOT_COMMAND}
|
||||
then
|
||||
exit_code=1
|
||||
fi
|
||||
|
||||
if [ $exit_code -eq 0 ]
|
||||
then
|
||||
echo "PASS"
|
||||
else
|
||||
echo "UART LOGS:"
|
||||
# Extract output from renode log
|
||||
cat ${RENODE_LOG} |grep 'uartSemihosting' |sed 's/^.*from start] *//g'
|
||||
fi
|
||||
|
||||
exit $exit_code
|
||||
echo ""
|
||||
${ROBOT_COMMAND}
|
||||
|
@ -39,9 +39,4 @@ readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARG
|
||||
# Next, build w/o release so that we can run the tests and get additional
|
||||
# debugging info on failures.
|
||||
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
|
||||
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} build
|
||||
|
||||
# TODO(b/172939049): Using renode to run the tests is not currently integrated
|
||||
# with the Makefile. So, we manually run the test script with the correct path
|
||||
# to the bluepill generated files.
|
||||
tensorflow/lite/micro/testing/test_bluepill_binary.sh tensorflow/lite/micro/tools/make/gen/bluepill_cortex-m3/bin/
|
||||
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} test
|
||||
|
@ -525,6 +525,10 @@ $(HOST_OS) \
|
||||
arduino \
|
||||
chre
|
||||
|
||||
# ${TARGET}_makefile.inc can set this to true to allow it to define a custom
|
||||
# implementation for `make test`. See bluepill_makefile as an example.
|
||||
TARGET_SPECIFIC_MAKE_TEST:=0
|
||||
|
||||
ifeq ($(findstring $(TARGET),$(TARGETS_WITHOUT_MAKEFILES)),)
|
||||
include $(MAKEFILE_DIR)/targets/$(TARGET)_makefile.inc
|
||||
endif
|
||||
@ -644,7 +648,9 @@ $(eval $(call microlite_test,$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET)
|
||||
$(foreach TEST_TARGET,$(filter tensorflow/lite/micro/kernels/%,$(MICROLITE_TEST_SRCS)),\
|
||||
$(eval $(call microlite_test,kernel_$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET))))
|
||||
|
||||
ifeq ($(TARGET_SPECIFIC_MAKE_TEST),0)
|
||||
test: $(MICROLITE_TEST_TARGETS)
|
||||
endif
|
||||
|
||||
# Just build the test targets
|
||||
build: $(MICROLITE_BUILD_TARGETS)
|
||||
|
@ -59,3 +59,10 @@ MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE
|
||||
|
||||
TEST_SCRIPT := tensorflow/lite/micro/testing/test_bluepill_binary.sh
|
||||
|
||||
# We are setting this variable to non-zero to allow us to have a custom
|
||||
# implementation of `make test` for bluepill
|
||||
TARGET_SPECIFIC_MAKE_TEST := 1
|
||||
test: build
|
||||
$(TEST_SCRIPT) $(BINDIR)
|
||||
|
||||
|
||||
|
@ -45,7 +45,9 @@ ASYNC_STATEFUL_OPS = [
|
||||
"CollectiveReduce",
|
||||
"CollectiveReduceV2",
|
||||
"CollectiveBcastSend",
|
||||
"CollectiveBcastSendV2",
|
||||
"CollectiveBcastRecv",
|
||||
"CollectiveBcastRecvV2",
|
||||
"NcclAllReduce",
|
||||
# We do not add "Send" here since we want it to be added as a control output
|
||||
# in order to avoid being pruned.
|
||||
|
@ -232,6 +232,7 @@ py_test(
|
||||
":benchmark_util",
|
||||
":profiler_lib",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python/keras/optimizer_v2",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
|
||||
# pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -32,6 +34,26 @@ _PRINT_EVAL_STEP_EVERY_SEC = 60.0
|
||||
_ITERATIONS_UNINITIALIZED = -1
|
||||
|
||||
|
||||
def list_checkpoint_attributes(ckpt_dir_or_file):
|
||||
"""Lists all the attributes in a checkpoint.
|
||||
|
||||
Checkpoint keys are paths in a checkpoint graph, and the attribute is the
first element in the path. For example, with the checkpoint key
"optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE", "optimizer" is the attribute. The
|
||||
attribute is also used to save/restore a variable in a checkpoint,
|
||||
e.g. tf.train.Checkpoint(optimizer=optimizer, model=model).
|
||||
|
||||
Args:
|
||||
ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
|
||||
|
||||
Returns:
|
||||
Set of attributes in a checkpoint.
|
||||
"""
|
||||
reader = checkpoint_utils.load_checkpoint(ckpt_dir_or_file)
|
||||
variable_map = reader.get_variable_to_shape_map()
|
||||
return {name.split('/')[0] for name in variable_map.keys()}
|
||||
|
||||
|
||||
class SidecarEvaluator(object):
|
||||
"""A class designed for a dedicated evaluator task.
|
||||
|
||||
@ -148,6 +170,21 @@ class SidecarEvaluator(object):
|
||||
# `expect_partial` because the checkpoint can have other `Trackable`s
|
||||
# such as `optimizer`.
|
||||
checkpoint.restore(latest_checkpoint).expect_partial()
|
||||
checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
|
||||
# The checkpoint should contain model and optimizer for SidecarEvaluator
|
||||
# to work. But the model weights saved by the ModelCheckpoint callback do
# not contain a model attribute. To make SidecarEvaluator work in this case,
# if the model attribute is not found but a layer_with_weights attribute is
# found, use model.load_weights to load the model's weights, while
# self._iterations is still restored from the checkpoint variable.
|
||||
if 'model' not in checkpoint_attributes:
|
||||
for attribute in checkpoint_attributes:
|
||||
# check whether the checkpoint has the required attributes for
|
||||
# model.load_weights to work.
|
||||
if re.match(r'^layer_with_weights-[\d+]', attribute) is not None:
|
||||
self.model.load_weights(latest_checkpoint)
|
||||
break
|
||||
except (errors_impl.OpError,) as e:
|
||||
# A couple errors can happen here with the coordinator racing to write
|
||||
# checkpoint:
|
||||
|
@ -20,7 +20,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
@ -36,6 +35,8 @@ from tensorflow.python.summary import summary_iterator
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training.tracking import util as tracking_util
|
||||
|
||||
_BATCH_SIZE = 32
|
||||
|
||||
|
||||
class SidecarEvaluatorTest(test.TestCase):
|
||||
|
||||
@ -130,7 +131,6 @@ class SidecarEvaluatorTest(test.TestCase):
|
||||
|
||||
self.assertSummaryEventsWritten(log_dir)
|
||||
|
||||
@unittest.skip('b/172976255')
|
||||
def testSidecarEvaluatorOutputsSummarySavedWithCallback(self):
|
||||
checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints')
|
||||
log_dir = os.path.join(self.get_temp_dir(), 'summary')
|
||||
@ -139,7 +139,7 @@ class SidecarEvaluatorTest(test.TestCase):
|
||||
data = np.random.random((1000, 32))
|
||||
labels = np.random.random((1000, 10))
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
|
||||
dataset = dataset.batch(32)
|
||||
dataset = dataset.batch(_BATCH_SIZE)
|
||||
save_callback = keras.callbacks.ModelCheckpoint(
|
||||
filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
|
||||
save_weights_only=True)
|
||||
@ -152,17 +152,22 @@ class SidecarEvaluatorTest(test.TestCase):
|
||||
# Create a new model used for evaluation.
|
||||
eval_model = self.createTestModel(compile_model=True)
|
||||
# Have a sidecar_evaluator evaluate once.
|
||||
sidecar_evaluator_lib.SidecarEvaluator(
|
||||
sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
|
||||
eval_model,
|
||||
data=dataset,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
max_evaluations=1).start()
|
||||
max_evaluations=1)
|
||||
sidecar_evaluator.start()
|
||||
|
||||
# Eval model has been restored to the same state as the original model, so
|
||||
# their weights should match. If not, restoration of the model didn't
|
||||
# work.
|
||||
self.assertModelsSameVariables(model, eval_model)
|
||||
|
||||
# check the iterations is restored.
|
||||
self.assertEqual(sidecar_evaluator._iterations.numpy(), _BATCH_SIZE)
|
||||
|
||||
self.assertSummaryEventsWritten(log_dir)
|
||||
|
||||
|
||||
|
@ -41,6 +41,7 @@ py_library(
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras/engine:input_spec",
|
||||
"//tensorflow/python/keras/legacy_tf_layers:layers_base",
|
||||
"//tensorflow/python/keras/saving",
|
||||
"//tensorflow/python/keras/utils:tf_utils",
|
||||
"//tensorflow/python/training/tracking:base",
|
||||
],
|
||||
|
@ -346,6 +346,12 @@ class RNNCell(base_layer.Layer):
|
||||
def get_config(self): # pylint: disable=useless-super-delegation
|
||||
return super(RNNCell, self).get_config()
|
||||
|
||||
@property
|
||||
def _use_input_spec_as_call_signature(self):
|
||||
# We do not store the shape information for the state argument in the call
|
||||
# function for legacy RNN cells, so do not generate an input signature.
|
||||
return False
|
||||
|
||||
|
||||
class LayerRNNCell(RNNCell):
|
||||
"""Subclass of RNNCells that act like proper `tf.Layer` objects.
|
||||
|
@ -3128,6 +3128,10 @@ cuda_py_test(
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/saved_model:load",
|
||||
"//tensorflow/python/saved_model:save",
|
||||
"//tensorflow/python/training/tracking",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
|
@ -43,6 +43,8 @@ from tensorflow.python.platform import test
|
||||
class CollectiveOpsV1(object):
|
||||
all_reduce = _collective_ops.all_reduce
|
||||
all_gather = _collective_ops.all_gather
|
||||
broadcast_send = _collective_ops.broadcast_send
|
||||
broadcast_recv = _collective_ops.broadcast_recv
|
||||
|
||||
|
||||
class CollectiveOpsV2(object):
|
||||
@ -63,6 +65,25 @@ class CollectiveOpsV2(object):
|
||||
return _collective_ops.all_gather_v2(t, group_size, group_key, instance_key,
|
||||
*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
|
||||
*args, **kwargs):
|
||||
group_size = array_ops.identity(group_size)
|
||||
group_key = array_ops.identity(group_key)
|
||||
instance_key = array_ops.identity(instance_key)
|
||||
return _collective_ops.broadcast_send_v2(t, group_size, group_key,
|
||||
instance_key, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def broadcast_recv(shape, dtype, group_size, group_key, instance_key, *args,
|
||||
**kwargs):
|
||||
group_size = array_ops.identity(group_size)
|
||||
group_key = array_ops.identity(group_key)
|
||||
instance_key = array_ops.identity(instance_key)
|
||||
shape = array_ops.identity(shape)
|
||||
return _collective_ops.broadcast_recv_v2(
|
||||
shape, dtype, group_size, group_key, instance_key, *args, **kwargs)
|
||||
|
||||
|
||||
device_combination = (
|
||||
combinations.combine(device='CPU', communication='RING', required_gpus=0) +
|
||||
@ -191,6 +212,42 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
for result in run_all_gather_2devices():
|
||||
self.assertAllClose(result, [1., 1.], rtol=1e-5, atol=1e-5)
|
||||
|
||||
def testBroadcast(self, collective_ops, device, communication):
|
||||
dev0 = '/device:%s:0' % device
|
||||
dev1 = '/device:%s:1' % device
|
||||
|
||||
@def_function.function
|
||||
def run_broadcast_2devices():
|
||||
shape = [3]
|
||||
in_value = constant_op.constant([1., 2., 3.], shape=shape)
|
||||
group_size = 2
|
||||
group_key = 2
|
||||
instance_key = 2
|
||||
collectives = []
|
||||
with ops.device(dev0):
|
||||
collectives.append(
|
||||
collective_ops.broadcast_send(
|
||||
in_value,
|
||||
shape,
|
||||
in_value.dtype,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication))
|
||||
with ops.device(dev1):
|
||||
collectives.append(
|
||||
collective_ops.broadcast_recv(
|
||||
shape,
|
||||
in_value.dtype,
|
||||
group_size,
|
||||
group_key,
|
||||
instance_key,
|
||||
communication_hint=communication))
|
||||
return collectives
|
||||
|
||||
for result in run_broadcast_2devices():
|
||||
self.assertAllClose(result, [1., 2., 3.], rtol=1e-5, atol=1e-5)
|
||||
|
||||
def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device,
|
||||
communication):
|
||||
if device == 'GPU' and context.num_gpus() < 4:
|
||||
|
@ -26,14 +26,16 @@ import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@ -47,6 +49,9 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest

@ -3060,6 +3065,29 @@ class RNNCellTest(test.TestCase, parameterized.TestCase):
      reconstructed_wrapper = wrapper_cls.from_config(config_copy)
      self.assertFalse(reconstructed_wrapper._dropout_state_filter(None))

  def testSavedModel(self):
    if test_util.is_gpu_available():
      self.skipTest("b/175887901")

    with self.cached_session():
      root = tracking.AutoTrackable()
      root.cell = rnn_cell_impl.LSTMCell(8)

      @def_function.function(input_signature=[tensor_spec.TensorSpec([3, 8])])
      def call(x):
        state = root.cell.zero_state(3, dtype=x.dtype)
        y, _ = root.cell(x, state)
        return y

      root.call = call
      expected = root.call(array_ops.zeros((3, 8)))
      self.evaluate(variables_lib.global_variables_initializer())

      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
      save.save(root, save_dir)
      loaded = load.load(save_dir)
      self.evaluate(variables_lib.global_variables_initializer())
      self.assertAllClose(
          expected, loaded.call(array_ops.zeros((3, 8))))


@test_util.run_all_in_graph_and_eager_modes
@test_util.run_all_without_tensor_float_32(
@ -261,6 +261,40 @@ def broadcast_send(t,
      timeout_seconds=timeout)


def broadcast_send_v2(t,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    group_size: an int32 tensor. One plus the number of receiving tensors, i.e.
      the total number of devices participating. Each tensor must reside on a
      different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: the timeout value in seconds. If set to a nonzero value, a
      completion timeout is set to detect staleness; if the timer expires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    An Op implementing the distributed broadcast send.
  """
  return gen_collective_ops.collective_bcast_send_v2(
      t,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
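As context for reviewers (not part of this diff): a minimal sketch of calling the new wrapper. Per the docstring above, the group parameters are int32 tensors rather than Python attributes, so they can be produced at run time. The key values below are arbitrary illustrations, and the send only completes once a matching receive with the same group and instance keys runs on another device.

# Illustrative only -- not part of this change; key values are arbitrary.
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import collective_ops

value = constant_op.constant([1., 2., 3.])
send = collective_ops.broadcast_send_v2(
    value,
    group_size=constant_op.constant(2),     # sender plus one receiver
    group_key=constant_op.constant(7),      # arbitrary group identifier
    instance_key=constant_op.constant(1),   # arbitrary instance identifier
    communication_hint='auto',
    timeout=10)                             # seconds; 0 disables the timeout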


def broadcast_recv(shape,
                   dtype,
                   group_size,
@ -302,3 +336,41 @@ def broadcast_recv(shape,
      instance_key=instance_key,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)


def broadcast_recv_v2(shape,
                      dtype,
                      group_size,
                      group_key,
                      instance_key,
                      communication_hint='auto',
                      timeout=0):
  """Receives a broadcast tensor, across devices.

  Args:
    shape: an int tensor. Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: an int32 tensor. One plus the number of receiving tensors, i.e.
      the total number of devices participating. Each tensor must reside on a
      different device.
    group_key: an int32 tensor identifying the group of devices.
    instance_key: an int32 tensor identifying the participating group of Ops.
    communication_hint: preferred collective communication. The implementation
      may fall back to another mechanism. Options include `auto`, `ring`, and
      `nccl`.
    timeout: the timeout value in seconds. If set to a nonzero value, a
      completion timeout is set to detect staleness; if the timer expires, a
      DeadlineExceededError is raised. This feature is experimental.

  Returns:
    An Op implementing the broadcast receive.
  """
  return gen_collective_ops.collective_bcast_recv_v2(
      T=dtype,
      group_size=group_size,
      group_key=group_key,
      instance_key=instance_key,
      shape=shape,
      communication_hint=communication_hint.lower(),
      timeout_seconds=timeout)
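Again as context, not part of this diff: a minimal sketch pairing the two new wrappers across two devices, mirroring the testBroadcast pattern earlier in this commit. It assumes two logical devices are available (for example, two virtual CPUs configured with tf.config.set_logical_device_configuration); the device strings and key values are illustrative.

# Illustrative sketch -- assumes '/device:CPU:0' and '/device:CPU:1' exist.
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import collective_ops


@def_function.function
def broadcast_2devices_v2():
  shape = constant_op.constant([3])
  group_size = constant_op.constant(2)
  group_key = constant_op.constant(2)
  instance_key = constant_op.constant(2)
  results = []
  with ops.device('/device:CPU:0'):
    # Send side: the value to broadcast originates here.
    results.append(
        collective_ops.broadcast_send_v2(
            constant_op.constant([1., 2., 3.]),
            group_size, group_key, instance_key))
  with ops.device('/device:CPU:1'):
    # Receive side: shape and dtype must match the sent tensor.
    results.append(
        collective_ops.broadcast_recv_v2(
            shape, dtypes.float32, group_size, group_key, instance_key))
  return results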

@ -1015,7 +1015,7 @@ def tf_gen_op_wrapper_py(
    native.py_library(
        name = generated_target_name,
        srcs = [out],
        srcs_version = "PY2AND3",
        srcs_version = "PY3",
        visibility = visibility,
        deps = [
            clean_dep("//tensorflow/python:framework_for_generated_wrappers_v2"),
@ -1828,7 +1828,7 @@ def tf_custom_op_py_library(
        srcs = [],
        dso = [],
        kernels = [],
        srcs_version = "PY2AND3",
        srcs_version = "PY3",
        visibility = None,
        deps = [],
        **kwargs):
@ -2019,7 +2019,7 @@ def pywrap_tensorflow_macro(
    native.py_library(
        name = name,
        srcs = [":" + name + ".py"],
        srcs_version = "PY2AND3",
        srcs_version = "PY3",
        data = select({
            clean_dep("//tensorflow:windows"): [":" + cc_library_pyd_name],
            "//conditions:default": [":" + cc_library_name],
@ -2139,7 +2139,7 @@ def tf_py_test(
            deps.append(clean_dep(to_add))

    # Python version placeholder
    kwargs.setdefault("srcs_version", "PY2AND3")
    kwargs.setdefault("srcs_version", "PY3")
    py_test(
        name = name,
        size = size,
@ -2500,7 +2500,7 @@ def pybind_extension(
        module_name,
        hdrs = [],
        features = [],
        srcs_version = "PY2AND3",
        srcs_version = "PY3",
        data = [],
        copts = [],
        linkopts = [],
@ -752,10 +752,18 @@ tf_module {
    name: "CollectiveBcastRecv"
    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastRecvV2"
    argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastSend"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastSendV2"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveGather"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "

@ -752,10 +752,18 @@ tf_module {
    name: "CollectiveBcastRecv"
    argspec: "args=[\'T\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastRecvV2"
    argspec: "args=[\'group_size\', \'group_key\', \'instance_key\', \'shape\', \'T\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastSend"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveBcastSendV2"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "
  }
  member_method {
    name: "CollectiveGather"
    argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'shape\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], "