Introduce Xla{Dot,Conv}V2

These support mixed operand precision (e.g. bf16 @ f32 -> f32, int8 @ int8 -> int32) and thus diverge significantly from the original Xla{Dot,Conv}, which require the operands and the result to share a single type T.
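A minimal usage sketch of the new behavior through the tf2xla Python wrapper, mirroring the int8 test added in this change; the tf.function(jit_compile=True) wrapper is an assumption about how the op would typically be driven and is not part of this change:

import numpy as np
import tensorflow as tf

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.xla import xla_data_pb2

# Contract the last dimension of lhs with the middle dimension of rhs,
# batching over dimension 0.
dnums = xla_data_pb2.DotDimensionNumbers()
dnums.lhs_contracting_dimensions.append(2)
dnums.rhs_contracting_dimensions.append(1)
dnums.lhs_batch_dimensions.append(0)
dnums.rhs_batch_dimensions.append(0)

@tf.function(jit_compile=True)  # the V2 ops are only lowered through XLA
def int8_matmul(lhs, rhs):
  # preferred_element_type requests an int32 result for int8 operands.
  return xla.dot_general(lhs, rhs, dimension_numbers=dnums,
                         preferred_element_type=np.int32)

lhs = np.array([[[1, 2], [3, 4]]], dtype=np.int8)
rhs = np.array([[[1, 2, 3], [4, 5, 6]]], dtype=np.int8)
print(int8_matmul(lhs, rhs))  # int32 tensor [[[9, 12, 15], [19, 26, 33]]]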

PiperOrigin-RevId: 358860920
Change-Id: I316f1b43268b268ce50528f7aa307182bce304f6
David Majnemer 2021-02-22 11:15:54 -08:00 committed by TensorFlower Gardener
parent 85c597aaa2
commit 4856f23a49
11 changed files with 341 additions and 122 deletions

View File

@@ -694,8 +694,10 @@ tensorflow::MemoryTypeVector GetOutputMemoryTypes(
static auto const ops_triggering_xla_compilation =
new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
"XlaConv",
"XlaConvV2",
"XlaDequantize",
"XlaDot",
"XlaDotV2",
"XlaDynamicSlice",
"XlaDynamicUpdateSlice",
"XlaEinsum",

View File

@@ -2057,8 +2057,10 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"While",
"XlaBroadcastHelper",
"XlaConv",
"XlaConvV2",
"XlaDequantize",
"XlaDot",
"XlaDotV2",
"XlaDynamicSlice",
"XlaDynamicUpdateSlice",
"XlaEinsum",

View File

@@ -29,6 +29,14 @@ namespace tensorflow {
.HostMemory("feature_group_count") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaConvV2") \
.HostMemory("window_strides") \
.HostMemory("padding") \
.HostMemory("lhs_dilation") \
.HostMemory("rhs_dilation") \
.HostMemory("feature_group_count") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER( \
Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \
XlaCompileOnDemandOp); \
@@ -38,6 +46,8 @@ namespace tensorflow {
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDotV2").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER( \
Name("XlaDynamicSlice").HostMemory("size_indices").Device(DEVICE), \
XlaCompileOnDemandOp); \

View File

@@ -17682,6 +17682,37 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaConvV2Op : TF_Op<"XlaConvV2", [NoSideEffect]> {
let summary = "Wraps the XLA ConvGeneralDilated operator, documented at";
let description = [{
https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.
}];
let arguments = (ins
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the input tensor}]>:$lhs,
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the kernel tensor}]>:$rhs,
Arg<TF_I32OrI64Tensor, [{the inter-window strides}]>:$window_strides,
Arg<TF_I32OrI64Tensor, [{the padding to apply at the start and end of each input dimension}]>:$padding,
Arg<TF_I32OrI64Tensor, [{dilation to apply between input elements}]>:$lhs_dilation,
Arg<TF_I32OrI64Tensor, [{dilation to apply between kernel elements}]>:$rhs_dilation,
Arg<TF_I32OrI64Tensor, [{number of feature groups for grouped convolution.}]>:$feature_group_count,
StrAttr:$dimension_numbers,
StrAttr:$precision_config
);
let results = (outs
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>;
TF_DerivedResultTypeAttr preferred_element_type = TF_DerivedResultTypeAttr<0>;
}
def TF_XlaDotOp : TF_Op<"XlaDot", [NoSideEffect]> {
let summary = "Wraps the XLA DotGeneral operator, documented at";
@@ -17705,6 +17736,31 @@ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaDotV2Op : TF_Op<"XlaDotV2", [NoSideEffect]> {
let summary = "Wraps the XLA DotGeneral operator, documented at";
let description = [{
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.
}];
let arguments = (ins
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the LHS tensor}]>:$lhs,
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the RHS tensor}]>:$rhs,
StrAttr:$dimension_numbers,
StrAttr:$precision_config
);
let results = (outs
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>;
TF_DerivedResultTypeAttr preferred_element_type = TF_DerivedResultTypeAttr<0>;
}
def TF_XlaDynamicSliceOp : TF_Op<"XlaDynamicSlice", [NoSideEffect]> {
let summary = "Wraps the XLA DynamicSlice operator, documented at";

View File

@@ -264,7 +264,9 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::UpperBoundOp>(),
TypeID::get<TF::XlaBroadcastHelperOp>(),
TypeID::get<TF::XlaConvOp>(),
TypeID::get<TF::XlaConvV2Op>(),
TypeID::get<TF::XlaDotOp>(),
TypeID::get<TF::XlaDotV2Op>(),
TypeID::get<TF::XlaDynamicSliceOp>(),
TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
TypeID::get<TF::XlaEinsumOp>(),

View File

@@ -205,6 +205,35 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
],
dtype=dtype))
def testDotGeneralInt8xInt8ToInt32(self):
def dot_fn(lhs, rhs):
dnums = xla_data_pb2.DotDimensionNumbers()
dnums.lhs_contracting_dimensions.append(2)
dnums.rhs_contracting_dimensions.append(1)
dnums.lhs_batch_dimensions.append(0)
dnums.rhs_batch_dimensions.append(0)
return xla.dot_general(
lhs, rhs, dimension_numbers=dnums, preferred_element_type=np.int32)
lhs = np.array([
[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
], dtype=np.int8)
rhs = np.array([
[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]],
],
dtype=np.int8)
self._assertOpOutputMatchesExpected(
dot_fn,
args=(lhs, rhs),
expected=np.array([
[[9, 12, 15], [19, 26, 33]],
[[95, 106, 117], [129, 144, 159]],
],
dtype=np.int32))
def testNeg(self):
for dtype in self.numeric_types - {np.uint8, np.int8}:
self._assertOpOutputMatchesExpected(

View File

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -38,6 +39,7 @@ class XlaConvOp : public XlaOpKernel {
OP_REQUIRES(context,
precision_config_.ParsePartialFromString(precision_config_attr),
errors::InvalidArgument("Error parsing precision config."));
preferred_element_type_ = absl::nullopt;
}
void Compile(XlaOpKernelContext* context) override {
@@ -77,10 +79,13 @@ class XlaConvOp : public XlaOpKernel {
xla::XlaOp output = xla::ConvGeneralDilated(
context->Input(0), context->Input(1), window_strides, padding,
lhs_dilation, rhs_dilation, dnums_, feature_group_count,
/*batch_group_count=*/1, &precision_config_);
/*batch_group_count=*/1, &precision_config_, preferred_element_type_);
context->SetOutput(0, output);
}
protected:
absl::optional<xla::PrimitiveType> preferred_element_type_;
private:
xla::ConvolutionDimensionNumbers dnums_;
xla::PrecisionConfig precision_config_;
@@ -96,5 +101,29 @@ REGISTER_XLA_OP(Name("XlaConv")
.CompileTimeConstantInput("padding"),
XlaConvOp);
class XlaConvV2Op : public XlaConvOp {
public:
explicit XlaConvV2Op(OpKernelConstruction* context) : XlaConvOp(context) {
DataType preferred_element_dtype;
OP_REQUIRES_OK(context, context->GetAttr("preferred_element_type",
&preferred_element_dtype));
xla::PrimitiveType preferred_element_type;
OP_REQUIRES_OK(context, DataTypeToPrimitiveType(preferred_element_dtype,
&preferred_element_type));
preferred_element_type_ = preferred_element_type;
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(XlaConvV2Op);
};
REGISTER_XLA_OP(Name("XlaConvV2")
.CompileTimeConstantInput("window_strides")
.CompileTimeConstantInput("lhs_dilation")
.CompileTimeConstantInput("rhs_dilation")
.CompileTimeConstantInput("feature_group_count")
.CompileTimeConstantInput("padding"),
XlaConvV2Op);
} // namespace
} // namespace tensorflow

View File

@@ -14,12 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
namespace {
@@ -39,6 +41,7 @@ class XlaDotOp : public XlaOpKernel {
context,
precision_config_.ParsePartialFromString(precision_config_attr),
errors::InvalidArgument("Error parsing convolution dimension numbers"));
preferred_element_type_ = absl::nullopt;
}
void Compile(XlaOpKernelContext* context) override {
@@ -47,19 +50,40 @@ class XlaDotOp : public XlaOpKernel {
// We do only minimal checking, relying on XLA to check the shape
// invariants.
xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1),
dnums_, &precision_config_);
xla::XlaOp output =
xla::DotGeneral(context->Input(0), context->Input(1), dnums_,
&precision_config_, preferred_element_type_);
context->SetOutput(0, output);
}
protected:
absl::optional<xla::PrimitiveType> preferred_element_type_;
private:
xla::DotDimensionNumbers dnums_;
xla::PrecisionConfig precision_config_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
};
REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp);
class XlaDotV2Op : public XlaDotOp {
public:
explicit XlaDotV2Op(OpKernelConstruction* context) : XlaDotOp(context) {
DataType preferred_element_dtype;
OP_REQUIRES_OK(context, context->GetAttr("preferred_element_type",
&preferred_element_dtype));
xla::PrimitiveType preferred_element_type;
OP_REQUIRES_OK(context, DataTypeToPrimitiveType(preferred_element_dtype,
&preferred_element_type));
preferred_element_type_ = preferred_element_type;
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(XlaDotV2Op);
};
REGISTER_XLA_OP(Name("XlaDotV2"), XlaDotV2Op);
} // namespace
} // namespace tensorflow

View File

@@ -161,6 +161,147 @@ dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaConvV2")
.Input("lhs: LhsT")
.Input("rhs: RhsT")
.Input("window_strides: Tindices")
.Input("padding: Tindices")
.Input("lhs_dilation: Tindices")
.Input("rhs_dilation: Tindices")
.Input("feature_group_count: Tindices")
.Attr("LhsT: numbertype")
.Attr("RhsT: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Attr("preferred_element_type: numbertype")
.Output("output: preferred_element_type")
.SetShapeFn(UnchangedRank)
.Doc(R"doc(
Wraps the XLA ConvGeneralDilated operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
.
lhs: the input tensor
rhs: the kernel tensor
window_strides: the inter-window strides
padding: the padding to apply at the start and end of each input dimension
lhs_dilation: dilation to apply between input elements
rhs_dilation: dilation to apply between kernel elements
feature_group_count: number of feature groups for grouped convolution.
dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
preferred_element_type: the element type of the output tensor.
)doc");
static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
if (!c->FullyDefined(lhs_shape_handle) ||
!c->FullyDefined(rhs_shape_handle)) {
return shape_inference::UnknownShape(c);
}
string dimension_numbers_string;
TF_RETURN_IF_ERROR(
c->GetAttr("dimension_numbers", &dimension_numbers_string));
xla::DotDimensionNumbers dimension_numbers;
dimension_numbers.ParseFromString(dimension_numbers_string);
// Check that number of contracting dimensions match.
if (dimension_numbers.lhs_contracting_dimensions_size() !=
dimension_numbers.rhs_contracting_dimensions_size())
return errors::InvalidArgument(
"Must specify the same number of contracting dimensions for lhs "
"and rhs. Got: ",
dimension_numbers.lhs_contracting_dimensions_size(), " and ",
dimension_numbers.rhs_contracting_dimensions_size());
// Check that contracting dimension sizes match.
for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
++i) {
const int64 lhs_contracting_dimension =
dimension_numbers.lhs_contracting_dimensions(i);
const int64 rhs_contracting_dimension =
dimension_numbers.rhs_contracting_dimensions(i);
shape_inference::DimensionOrConstant lhs_contracting_dimension_or_constant(
c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
shape_inference::DimensionOrConstant rhs_contracting_dimension_or_constant(
c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
const int64 lhs_contracting_dimension_size =
c->Value(lhs_contracting_dimension_or_constant);
const int64 rhs_contracting_dimension_size =
c->Value(rhs_contracting_dimension_or_constant);
if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
return errors::InvalidArgument(
"Contracting dimension sizes do not match. Got: ",
lhs_contracting_dimension_size, " and ",
rhs_contracting_dimension_size);
}
}
// Check that number of batch dimensions match.
if (dimension_numbers.lhs_batch_dimensions_size() !=
dimension_numbers.rhs_batch_dimensions_size())
return errors::InvalidArgument(
"Must specify the same number of batch dimensions for lhs "
"and rhs. Got: ",
dimension_numbers.lhs_batch_dimensions_size(), " and ",
dimension_numbers.rhs_batch_dimensions_size());
// Check that batch dimension sizes match.
for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
const int64 lhs_batch_dimension = dimension_numbers.lhs_batch_dimensions(i);
const int64 rhs_batch_dimension = dimension_numbers.rhs_batch_dimensions(i);
shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
const int64 lhs_batch_dimension_size =
c->Value(lhs_batch_dimension_or_constant);
const int64 rhs_batch_dimension_size =
c->Value(rhs_batch_dimension_or_constant);
if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
return errors::InvalidArgument(
"Batch dimension sizes do not match. Got: ", lhs_batch_dimension_size,
" and ", rhs_batch_dimension_size);
}
}
// The ranks of lhs and rhs are decremented by 1 respectively due to the
// contraction, and added for the rank of the result. When an input tensor
// is a scalar, its contribution to the rank of the result is 0. Generate
// the result dimensions in order, rhs dimensions followed by lhs
// dimensions except the contracted and batch dimensions.
std::vector<shape_inference::DimensionHandle> output_dims;
for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
}
const int32 lhs_rank = c->Rank(lhs_shape_handle);
for (int64 i = 0; i < lhs_rank; ++i) {
if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
i) ||
absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
continue;
}
output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
}
const int32 rhs_rank = c->Rank(rhs_shape_handle);
for (int64 i = 0; i < rhs_rank; ++i) {
if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
i) ||
absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
continue;
}
output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
}
c->set_output(0, c->MakeShape(output_dims));
return Status::OK();
}
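For reference, the shape function above orders the result dimensions as batch dimensions, then the non-contracted lhs dimensions, then the non-contracted rhs dimensions. A small self-contained Python sketch of that rule (the helper name and list-based interface are illustrative, not TensorFlow API):

def dot_general_output_shape(lhs_shape, rhs_shape, lhs_batch, rhs_batch,
                             lhs_contracting, rhs_contracting):
  # Batch dimensions come first (taken from lhs; sizes must match rhs).
  out = [lhs_shape[d] for d in lhs_batch]
  # Then the remaining ("free") lhs dimensions...
  out += [s for i, s in enumerate(lhs_shape)
          if i not in lhs_batch and i not in lhs_contracting]
  # ...followed by the free rhs dimensions.
  out += [s for i, s in enumerate(rhs_shape)
          if i not in rhs_batch and i not in rhs_contracting]
  return out

# lhs [2, 2, 2] x rhs [2, 2, 3], batch dim 0, contracting dims 2 (lhs) and 1 (rhs)
# -> [2, 2, 3], matching the int8 test added in this change.
assert dot_general_output_shape([2, 2, 2], [2, 2, 3],
                                [0], [0], [2], [1]) == [2, 2, 3]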
REGISTER_OP("XlaDot")
.Input("lhs: T")
.Input("rhs: T")
@@ -168,120 +309,7 @@ REGISTER_OP("XlaDot")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Output("output: T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
if (!c->FullyDefined(lhs_shape_handle) ||
!c->FullyDefined(rhs_shape_handle)) {
return shape_inference::UnknownShape(c);
}
string dimension_numbers_string;
TF_RETURN_IF_ERROR(
c->GetAttr("dimension_numbers", &dimension_numbers_string));
xla::DotDimensionNumbers dimension_numbers;
dimension_numbers.ParseFromString(dimension_numbers_string);
// Check that number of contracting dimensions match.
if (dimension_numbers.lhs_contracting_dimensions_size() !=
dimension_numbers.rhs_contracting_dimensions_size())
return errors::InvalidArgument(
"Must specify the same number of contracting dimensions for lhs "
"and rhs. Got: ",
dimension_numbers.lhs_contracting_dimensions_size(), " and ",
dimension_numbers.rhs_contracting_dimensions_size());
// Check that contracting dimension sizes match.
for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
++i) {
const int64 lhs_contracting_dimension =
dimension_numbers.lhs_contracting_dimensions(i);
const int64 rhs_contracting_dimension =
dimension_numbers.rhs_contracting_dimensions(i);
shape_inference::DimensionOrConstant
lhs_contracting_dimension_or_constant(
c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
shape_inference::DimensionOrConstant
rhs_contracting_dimension_or_constant(
c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
const int64 lhs_contracting_dimension_size =
c->Value(lhs_contracting_dimension_or_constant);
const int64 rhs_contracting_dimension_size =
c->Value(rhs_contracting_dimension_or_constant);
if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
return errors::InvalidArgument(
"Contracting dimension sizes do not match. Got: ",
lhs_contracting_dimension_size, " and ",
rhs_contracting_dimension_size);
}
}
// Check that number of batch dimensions match.
if (dimension_numbers.lhs_batch_dimensions_size() !=
dimension_numbers.rhs_batch_dimensions_size())
return errors::InvalidArgument(
"Must specify the same number of batch dimensions for lhs "
"and rhs. Got: ",
dimension_numbers.lhs_batch_dimensions_size(), " and ",
dimension_numbers.rhs_batch_dimensions_size());
// Check that batch dimension sizes match.
for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size();
++i) {
const int64 lhs_batch_dimension =
dimension_numbers.lhs_batch_dimensions(i);
const int64 rhs_batch_dimension =
dimension_numbers.rhs_batch_dimensions(i);
shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
const int64 lhs_batch_dimension_size =
c->Value(lhs_batch_dimension_or_constant);
const int64 rhs_batch_dimension_size =
c->Value(rhs_batch_dimension_or_constant);
if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
return errors::InvalidArgument(
"Batch dimension sizes do not match. Got: ",
lhs_batch_dimension_size, " and ", rhs_batch_dimension_size);
}
}
// The ranks of lhs and rhs are decremented by 1 respectively due to the
// contraction, and added for the rank of the result. When an input tensor
// is a scalar, its contribution to the rank of the result is 0. Generate
// the result dimensions in order, rhs dimensions followed by lhs
// dimensions except the contracted and batch dimensions.
std::vector<shape_inference::DimensionHandle> output_dims;
for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
}
const int32 lhs_rank = c->Rank(lhs_shape_handle);
for (int64 i = 0; i < lhs_rank; ++i) {
if (absl::c_linear_search(
dimension_numbers.lhs_contracting_dimensions(), i) ||
absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
i)) {
continue;
}
output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
}
const int32 rhs_rank = c->Rank(rhs_shape_handle);
for (int64 i = 0; i < rhs_rank; ++i) {
if (absl::c_linear_search(
dimension_numbers.rhs_contracting_dimensions(), i) ||
absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
i)) {
continue;
}
output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
}
c->set_output(0, c->MakeShape(output_dims));
return Status::OK();
})
.SetShapeFn(XlaDotShapeFunction)
.Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
@@ -293,6 +321,28 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
REGISTER_OP("XlaDotV2")
.Input("lhs: LhsT")
.Input("rhs: RhsT")
.Attr("LhsT: numbertype")
.Attr("RhsT: numbertype")
.Attr("dimension_numbers: string")
.Attr("precision_config: string")
.Attr("preferred_element_type: numbertype")
.Output("output: preferred_element_type")
.SetShapeFn(XlaDotShapeFunction)
.Doc(R"doc(
Wraps the XLA DotGeneral operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
.
lhs: the LHS tensor
rhs: the RHS tensor
dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
preferred_element_type: the element type of the output tensor.
)doc");
REGISTER_OP("XlaSetBound")
.Input("input: int32")
.Input("bound: int32")

View File

@@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.numpy_ops import np_utils
# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
@@ -249,6 +250,7 @@ def conv(lhs,
dimension_numbers,
feature_group_count=1,
precision_config=None,
preferred_element_type=None,
name=None):
"""Wraps the XLA ConvGeneralDilated operator.
@@ -266,6 +268,7 @@ def conv(lhs,
dimension_numbers: a `ConvolutionDimensionNumbers` proto.
feature_group_count: number of feature groups for grouped convolution.
precision_config: a `xla.PrecisionConfig` proto.
preferred_element_type: the result `dtype`.
name: an optional name for the operator
Returns:
@@ -274,7 +277,9 @@ def conv(lhs,
precision_config_proto = ""
if precision_config:
precision_config_proto = precision_config.SerializeToString()
return gen_xla_ops.xla_conv(
if preferred_element_type is None:
preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
return gen_xla_ops.xla_conv_v2(
lhs,
rhs,
window_strides=window_strides,
@@ -284,6 +289,7 @@ def conv(lhs,
feature_group_count=feature_group_count,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)
@@ -294,15 +300,23 @@ def dot(lhs, rhs, name=None):
return math_ops.tensordot(lhs, rhs, axes=1, name=name)
def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
def dot_general(lhs,
rhs,
dimension_numbers,
precision_config=None,
preferred_element_type=None,
name=None):
precision_config_proto = ""
if precision_config:
precision_config_proto = precision_config.SerializeToString()
return gen_xla_ops.xla_dot(
if preferred_element_type is None:
preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
return gen_xla_ops.xla_dot_v2(
lhs,
rhs,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)
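When preferred_element_type is omitted, both wrappers now default it to np_utils.result_type(lhs.dtype, rhs.dtype), so existing callers keep a sensible output type. Below is a hedged sketch of an explicit mixed-precision convolution through the new path; the NHWC/HWIO dimension numbers, the input shapes, and the tf.function(jit_compile=True) wrapper are illustrative assumptions, not part of this change:

import numpy as np
import tensorflow as tf

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.xla import xla_data_pb2

# 2-D convolution with NHWC activations and HWIO filters.
dnums = xla_data_pb2.ConvolutionDimensionNumbers()
dnums.input_batch_dimension = 0
dnums.input_feature_dimension = 3
dnums.input_spatial_dimensions.extend([1, 2])
dnums.kernel_spatial_dimensions.extend([0, 1])
dnums.kernel_input_feature_dimension = 2
dnums.kernel_output_feature_dimension = 3
dnums.output_batch_dimension = 0
dnums.output_feature_dimension = 3
dnums.output_spatial_dimensions.extend([1, 2])

@tf.function(jit_compile=True)
def mixed_conv(lhs, rhs):
  # bfloat16 activations, float32 filters, float32 result (bf16 @ f32 -> f32).
  return xla.conv(lhs, rhs,
                  window_strides=[1, 1],
                  padding=[[0, 0], [0, 0]],
                  lhs_dilation=[1, 1],
                  rhs_dilation=[1, 1],
                  dimension_numbers=dnums,
                  preferred_element_type=np.float32)

lhs = tf.ones([1, 4, 4, 1], dtype=tf.bfloat16)
rhs = tf.ones([2, 2, 1, 1], dtype=tf.float32)
print(mixed_conv(lhs, rhs).dtype)  # float32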

View File

@@ -143,7 +143,8 @@ bool IsOpTensorCoreEligible(absl::string_view tf_op_name) {
|| absl::StrContains(tf_op_name, "CudnnRNNForward")
|| absl::StrContains(tf_op_name, "CudnnRNNBackprop")
// Special cases.
|| absl::EndsWith(tf_op_name, "XlaDot");
|| absl::EndsWith(tf_op_name, "XlaDot")
|| absl::EndsWith(tf_op_name, "XlaDotV2");
// clang-format on
}