Add the tf2xla_supported_ops tool, which dumps ops supported by tf2xla.
Also fix a TODO in XlaOpRegistry to filter by the types allowed by the OpDef. Also see #14798.

PiperOrigin-RevId: 177986664
Commit 4b0a236848 (parent e72ecbdb7a)
tensorflow/compiler/tf2xla/BUILD:

```diff
@@ -1,6 +1,6 @@
 licenses(["notice"])  # Apache 2.0
 
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
 
 package_group(
     name = "internal",
@@ -25,6 +25,30 @@ package(
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
 load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
 
+cc_library(
+    name = "tf2xla_supported_ops_lib",
+    srcs = ["tf2xla_supported_ops.cc"],
+    hdrs = ["tf2xla_supported_ops.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":xla_compiler",
+        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
+        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:ops",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
+tf_cc_binary(
+    name = "tf2xla_supported_ops",
+    srcs = ["tf2xla_supported_ops_main.cc"],
+    visibility = ["//visibility:public"],
+    deps = [":tf2xla_supported_ops_lib"],
+)
+
 xla_proto_library(
     name = "tf2xla_proto",
     srcs = ["tf2xla.proto"],
```
New file: tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md (242 lines)

**Supported operators for device: XLA_CPU_JIT**

Operator | Type Constraint
------------------------------------- | ---------------
`Abs` | `T={double,float,int32,int64}`
`Acosh` | `T={complex64,double,float}`
`Add` | `T={complex64,double,float,int32,int64}`
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`All` | `Tidx={int32,int64}`
`Angle` | `Tout={double,float}`<br>`T={complex64}`
`Any` | `Tidx={int32,int64}`
`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`ArgMax` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={float}`
`ArgMin` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Asinh` | `T={complex64,double,float}`
`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Atan2` | `T={double,float}`
`Atanh` | `T={complex64,double,float}`
`AvgPool` | `T={double,float}`
`AvgPool3D` | `T={double,float}`
`AvgPool3DGrad` | `T={double,float}`
`AvgPoolGrad` | `T={double,float}`
`BatchMatMul` | `T={complex64,double,float,int32}`
`BatchToSpace` | `Tidx={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`BatchToSpaceND` | `Tcrops={int32,int64}`<br>`Tblock_shape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`BitwiseAnd` | `T={int32,int64,uint32,uint64}`
`BitwiseOr` | `T={int32,int64,uint32,uint64}`
`BroadcastArgs` | `T={int32,int64}`
`BroadcastGradientArgs` | `T={int32,int64}`
`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Ceil` | `T={double,float}`
`Cholesky` | `T={complex64,double,float}`
`Complex` | `Tout={complex64}`<br>`T={double,float}`
`ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ConcatOffset` |
`ConcatV2` | `Tidx={int32}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Conj` | `T={complex64}`
`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ControlTrigger` |
`Conv2D` | `T={float}`
`Conv2DBackpropFilter` | `T={float}`
`Conv2DBackpropInput` | `T={float}`
`Conv3D` | `T={double,float}`
`Conv3DBackpropFilterV2` | `T={double,float}`
`Conv3DBackpropInputV2` | `T={double,float}`
`Cos` | `T={complex64,double,float}`
`Cosh` | `T={complex64,double,float}`
`Cross` | `T={double,float,int32,int64,uint32,uint64}`
`Cumprod` | `Tidx={int32,int64}`<br>`T={float}`
`Cumsum` | `Tidx={int32,int64}`<br>`T={float}`
`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`DepthwiseConv2dNative` | `T={double,float}`
`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}`
`DepthwiseConv2dNativeBackpropInput` | `T={double,float}`
`Diag` | `T={complex64,double,float,int32,int64}`
`DiagPart` | `T={complex64,double,float,int32,int64}`
`Div` | `T={complex64,double,float,int32,int64}`
`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Elu` | `T={double,float}`
`EluGrad` | `T={double,float}`
`Equal` | `T={bool,complex64,double,float,int32,int64}`
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Floor` | `T={double,float}`
`FloorDiv` | `T={complex64,double,float,int32,int64}`
`FloorMod` | `T={double,float,int32,int64}`
`FusedBatchNorm` | `T={float}`
`FusedBatchNormGrad` | `T={float}`
`FusedBatchNormGradV2` | `U={float}`<br>`T={float}`
`FusedBatchNormV2` | `U={float}`<br>`T={float}`
`Gather` | `Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Imag` | `Tout={double,float}`<br>`T={complex64}`
`Inv` | `T={complex64,double,float,int32,int64}`
`Invert` | `T={int32,int64,uint32,uint64}`
`InvertPermutation` | `T={int32}`
`IsFinite` | `T={double,float}`
`IsInf` | `T={double,float}`
`IsNan` | `T={double,float}`
`L2Loss` | `T={double,float}`
`LRN` | `T={float}`
`LRNGrad` | `T={float}`
`LeftShift` | `T={int32,int64,uint32,uint64}`
`Less` | `T={double,float,int32,int64,uint32,uint64}`
`LessEqual` | `T={double,float,int32,int64,uint32,uint64}`
`LinSpace` | `Tidx={int32,int64}`<br>`T={double,float}`
`Log` | `T={complex64,double,float}`
`Log1p` | `T={complex64,double,float}`
`LogSoftmax` | `T={double,float}`
`LogicalAnd` |
`LogicalNot` |
`LogicalOr` |
`MatMul` | `T={complex64,double,float}`
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`MaxPool` | `T={double,float,int32,int64}`
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Minimum` | `T={double,float,int32,int64}`
`MirrorPad` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Mod` | `T={double,float,int32,int64}`
`Mul` | `T={complex64,double,float,int32,int64}`
`Multinomial` | `output_dtype={int32,int64}`<br>`T={double,float,int32,int64,uint32,uint64}`
`Neg` | `T={complex64,double,float,int32,int64}`
`NoOp` |
`NotEqual` | `T={bool,complex64,double,float,int32,int64}`
`OneHot` | `TI={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`OnesLike` | `T={bool,complex64,double,float,int32,int64}`
`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Pad` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`PadV2` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Pow` | `T={complex64,double,float,int32,int64}`
`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`QuantizeAndDequantizeV2` | `T={double,float}`
`RandomStandardNormal` | `dtype={float}`
`RandomUniform` | `T={int32,int64}`<br>`dtype={double,float}`
`RandomUniformInt` | `T={int32,int64}`<br>`Tout={int32,int64}`
`Range` | `Tidx={double,float,int32,int64}`
`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Real` | `Tout={double,float}`<br>`T={complex64}`
`RealDiv` | `T={complex64,double,float,int32,int64}`
`Reciprocal` | `T={complex64,double,float,int32,int64}`
`ReciprocalGrad` | `T={complex64,double,float}`
`Relu` | `T={double,float,int32,int64,uint32,uint64}`
`Relu6` | `T={double,float,int32,int64,uint32,uint64}`
`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ResourceApplyAdagrad` | `T={double,float}`
`ResourceApplyAdam` | `T={double,float}`
`ResourceApplyFtrl` | `T={double,float}`
`ResourceApplyFtrlV2` | `T={double,float}`
`ResourceApplyGradientDescent` | `T={double,float}`
`ResourceApplyMomentum` | `T={double,float}`
`ResourceApplyRMSProp` | `T={double,float}`
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`
`Round` | `T={complex64,double,float,int32,int64}`
`Rsqrt` | `T={complex64,double,float}`
`RsqrtGrad` | `T={complex64,double,float}`
`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Selu` | `T={double,float}`
`SeluGrad` | `T={double,float}`
`Shape` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ShapeN` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Sigmoid` | `T={complex64,double,float}`
`SigmoidGrad` | `T={complex64,double,float}`
`Sign` | `T={complex64,double,float,int32,int64}`
`Sin` | `T={complex64,double,float}`
`Sinh` | `T={complex64,double,float}`
`Size` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Slice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Softmax` | `T={double,float}`
`SoftmaxCrossEntropyWithLogits` | `T={double,float}`
`Softplus` | `T={double,float,int32,int64,uint32,uint64}`
`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Softsign` | `T={double,float,int32,int64,uint32,uint64}`
`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}`
`SpaceToBatch` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SpaceToBatchND` | `Tblock_shape={int32,int64}`<br>`Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SparseMatMul` | `Tb={float}`<br>`Ta={float}`
`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`<br>`T={double,float}`
`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SplitV` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Sqrt` | `T={complex64,double,float}`
`SqrtGrad` | `T={complex64,double,float}`
`Square` | `T={complex64,double,float,int32,int64}`
`SquaredDifference` | `T={complex64,double,float,int32,int64}`
`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StackCloseV2` |
`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StatelessRandomNormal` | `Tseed={int32}`<br>`T={int32,int64}`<br>`dtype={float}`
`StatelessRandomUniform` | `Tseed={int32}`<br>`T={int32,int64}`<br>`dtype={float}`
`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StridedSlice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StridedSliceGrad` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Sub` | `T={complex64,double,float,int32,int64}`
`Sum` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tan` | `T={complex64,double,float,int32,int64}`
`Tanh` | `T={complex64,double,float}`
`TanhGrad` | `T={complex64,double,float}`
`TensorArrayCloseV3` |
`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayGradV3` |
`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArraySizeV3` |
`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tile` | `Tmultiples={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Transpose` | `Tperm={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TruncateDiv` | `T={complex64,double,float,int32,int64}`
`TruncateMod` | `T={double,float,int32,int64}`
`TruncatedNormal` | `T={int32,int64}`<br>`dtype={double,float}`
`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`<br>`Tindices={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`VarIsInitializedOp` |
`VariableShape` | `out_type={int32,int64}`
`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}`
`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}`
`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`

To regenerate this table, run:

```shell
bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_CPU_JIT
```
New file: tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md (238 lines)

**Supported operators for device: XLA_GPU_JIT**

Operator | Type Constraint
------------------------------------- | ---------------
`Abs` | `T={double,float,int32,int64}`
`Acosh` | `T={complex64,double,float}`
`Add` | `T={complex64,double,float,int32,int64}`
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`All` | `Tidx={int32,int64}`
`Angle` | `Tout={double,float}`<br>`T={complex64}`
`Any` | `Tidx={int32,int64}`
`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`ArgMax` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`ArgMin` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Asinh` | `T={complex64,double,float}`
`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Atan2` | `T={double,float}`
`Atanh` | `T={complex64,double,float}`
`AvgPool` | `T={double,float}`
`AvgPool3D` | `T={double,float}`
`AvgPool3DGrad` | `T={double,float}`
`AvgPoolGrad` | `T={double,float}`
`BatchMatMul` | `T={complex64,double,float,int32}`
`BatchToSpace` | `Tidx={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`BatchToSpaceND` | `Tcrops={int32,int64}`<br>`Tblock_shape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`BitwiseAnd` | `T={int32,int64,uint32,uint64}`
`BitwiseOr` | `T={int32,int64,uint32,uint64}`
`BroadcastArgs` | `T={int32,int64}`
`BroadcastGradientArgs` | `T={int32,int64}`
`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Ceil` | `T={double,float}`
`Cholesky` | `T={complex64,double,float}`
`Complex` | `Tout={complex64}`<br>`T={double,float}`
`ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ConcatOffset` |
`ConcatV2` | `Tidx={int32}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Conj` | `T={complex64}`
`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ControlTrigger` |
`Conv2D` | `T={float}`
`Conv2DBackpropFilter` | `T={float}`
`Conv2DBackpropInput` | `T={float}`
`Conv3D` | `T={double,float}`
`Conv3DBackpropFilterV2` | `T={double,float}`
`Conv3DBackpropInputV2` | `T={double,float}`
`Cos` | `T={complex64,double,float}`
`Cosh` | `T={complex64,double,float}`
`Cross` | `T={double,float,int32,int64,uint32,uint64}`
`Cumprod` | `Tidx={int32,int64}`<br>`T={float}`
`Cumsum` | `Tidx={int32,int64}`<br>`T={float}`
`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`DepthwiseConv2dNative` | `T={double,float}`
`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}`
`DepthwiseConv2dNativeBackpropInput` | `T={double,float}`
`Diag` | `T={complex64,double,float,int32,int64}`
`DiagPart` | `T={complex64,double,float,int32,int64}`
`Div` | `T={complex64,double,float,int32,int64}`
`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Elu` | `T={double,float}`
`EluGrad` | `T={double,float}`
`Equal` | `T={bool,complex64,double,float,int32,int64}`
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Floor` | `T={double,float}`
`FloorDiv` | `T={complex64,double,float,int32,int64}`
`FloorMod` | `T={double,float,int32,int64}`
`FusedBatchNorm` | `T={float}`
`FusedBatchNormGrad` | `T={float}`
`FusedBatchNormGradV2` | `U={float}`<br>`T={float}`
`FusedBatchNormV2` | `U={float}`<br>`T={float}`
`Gather` | `Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Imag` | `Tout={double,float}`<br>`T={complex64}`
`Inv` | `T={complex64,double,float,int32,int64}`
`Invert` | `T={int32,int64,uint32,uint64}`
`InvertPermutation` | `T={int32}`
`IsFinite` | `T={double,float}`
`IsInf` | `T={double,float}`
`IsNan` | `T={double,float}`
`L2Loss` | `T={double,float}`
`LRN` | `T={float}`
`LRNGrad` | `T={float}`
`LeftShift` | `T={int32,int64,uint32,uint64}`
`Less` | `T={double,float,int32,int64,uint32,uint64}`
`LessEqual` | `T={double,float,int32,int64,uint32,uint64}`
`LinSpace` | `Tidx={int32,int64}`<br>`T={double,float}`
`Log` | `T={complex64,double,float}`
`Log1p` | `T={complex64,double,float}`
`LogSoftmax` | `T={double,float}`
`LogicalAnd` |
`LogicalNot` |
`LogicalOr` |
`MatMul` | `T={complex64,double,float}`
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`MaxPool` | `T={double,float,int32,int64}`
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Minimum` | `T={double,float,int32,int64}`
`MirrorPad` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Mod` | `T={double,float,int32,int64}`
`Mul` | `T={complex64,double,float,int32,int64}`
`Multinomial` | `output_dtype={int32,int64}`<br>`T={double,float,int32,int64,uint32,uint64}`
`Neg` | `T={complex64,double,float,int32,int64}`
`NoOp` |
`NotEqual` | `T={bool,complex64,double,float,int32,int64}`
`OneHot` | `TI={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`OnesLike` | `T={bool,complex64,double,float,int32,int64}`
`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Pad` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`PadV2` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Pow` | `T={complex64,double,float,int32,int64}`
`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`QuantizeAndDequantizeV2` | `T={double,float}`
`Range` | `Tidx={double,float,int32,int64}`
`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Real` | `Tout={double,float}`<br>`T={complex64}`
`RealDiv` | `T={complex64,double,float,int32,int64}`
`Reciprocal` | `T={complex64,double,float,int32,int64}`
`ReciprocalGrad` | `T={complex64,double,float}`
`Relu` | `T={double,float,int32,int64,uint32,uint64}`
`Relu6` | `T={double,float,int32,int64,uint32,uint64}`
`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ResourceApplyAdagrad` | `T={double,float}`
`ResourceApplyAdam` | `T={double,float}`
`ResourceApplyFtrl` | `T={double,float}`
`ResourceApplyFtrlV2` | `T={double,float}`
`ResourceApplyGradientDescent` | `T={double,float}`
`ResourceApplyMomentum` | `T={double,float}`
`ResourceApplyRMSProp` | `T={double,float}`
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`
`Round` | `T={complex64,double,float,int32,int64}`
`Rsqrt` | `T={complex64,double,float}`
`RsqrtGrad` | `T={complex64,double,float}`
`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Selu` | `T={double,float}`
`SeluGrad` | `T={double,float}`
`Shape` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ShapeN` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Sigmoid` | `T={complex64,double,float}`
`SigmoidGrad` | `T={complex64,double,float}`
`Sign` | `T={complex64,double,float,int32,int64}`
`Sin` | `T={complex64,double,float}`
`Sinh` | `T={complex64,double,float}`
`Size` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Slice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Softmax` | `T={double,float}`
`SoftmaxCrossEntropyWithLogits` | `T={double,float}`
`Softplus` | `T={double,float,int32,int64,uint32,uint64}`
`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Softsign` | `T={double,float,int32,int64,uint32,uint64}`
`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}`
`SpaceToBatch` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SpaceToBatchND` | `Tblock_shape={int32,int64}`<br>`Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SparseMatMul` | `Tb={float}`<br>`Ta={float}`
`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`<br>`T={double,float}`
`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SplitV` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Sqrt` | `T={complex64,double,float}`
`SqrtGrad` | `T={complex64,double,float}`
`Square` | `T={complex64,double,float,int32,int64}`
`SquaredDifference` | `T={complex64,double,float,int32,int64}`
`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StackCloseV2` |
`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StatelessRandomNormal` | `Tseed={int32}`<br>`T={int32,int64}`<br>`dtype={float}`
`StatelessRandomUniform` | `Tseed={int32}`<br>`T={int32,int64}`<br>`dtype={float}`
`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StridedSlice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`StridedSliceGrad` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Sub` | `T={complex64,double,float,int32,int64}`
`Sum` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tan` | `T={complex64,double,float,int32,int64}`
`Tanh` | `T={complex64,double,float}`
`TanhGrad` | `T={complex64,double,float}`
`TensorArrayCloseV3` |
`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayGradV3` |
`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArraySizeV3` |
`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tile` | `Tmultiples={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Transpose` | `Tperm={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`TruncateDiv` | `T={complex64,double,float,int32,int64}`
`TruncateMod` | `T={double,float,int32,int64}`
`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`<br>`Tindices={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`VarIsInitializedOp` |
`VariableShape` | `out_type={int32,int64}`
`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}`
`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}`
`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`

To regenerate this table, run:

```shell
bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_GPU_JIT
```
New file: tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc (97 lines)

```cpp
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h"

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace tf2xla {
namespace {

void PrintSupportedOps(const string& device, const string& regen_run) {
  XlaOpRegistry::RegisterCompilationKernels();

  std::vector<const KernelDef*> kdefs =
      XlaOpRegistry::DeviceKernels(device,
                                   /*include_compilation_only_kernels=*/true);
  std::sort(
      kdefs.begin(), kdefs.end(),
      [](const KernelDef* a, const KernelDef* b) { return a->op() < b->op(); });

  std::cout << "**Supported operators for device: " << device << "**\n\n"
            << "Operator | Type Constraint\n"
            << "-------- | ---------------" << std::endl;
  for (const KernelDef* kdef : kdefs) {
    std::vector<string> constraints;
    for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) {
      std::vector<string> types;
      for (int type : constraint.allowed_values().list().type()) {
        types.push_back(DataTypeString(static_cast<DataType>(type)));
      }
      std::sort(types.begin(), types.end());
      constraints.push_back("`" + constraint.name() + "={" +
                            str_util::Join(types, ",") + "}`");
    }
    std::cout << "`" << kdef->op() << "` | "
              << str_util::Join(constraints, "<br>") << std::endl;
  }

  std::cout << "\nTo regenerate this table, run:\n\n```shell\n"
            << regen_run << " --device=" << device << "\n```" << std::endl;
}

}  // namespace

void SupportedOpsMain(int argc, char** argv, const char* regen_run) {
  std::vector<string> device_names = XlaOpRegistry::BackendNames();
  std::sort(device_names.begin(), device_names.end());

  // Set up and parse flags.
  string device;
  std::vector<Flag> flag_list = {
      {"device", &device,
       "Name of the compilation device for which to print supported ops, "
       "one of: " +
           str_util::Join(device_names, ",")},
  };
  string usage = Flags::Usage(argv[0], flag_list);
  bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list);
  QCHECK(parsed_flags_ok) << "\n" << usage;
  QCHECK(XlaOpRegistry::IsBackendRegistered(device))
      << "\nUnknown device: " << device << "\n"
      << usage;

  // Run the program.
  port::InitMain(usage.c_str(), &argc, &argv);
  QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
                       "other than flags\n\n"
                    << usage;
  PrintSupportedOps(device, regen_run);
}

}  // namespace tf2xla
}  // namespace tensorflow
```
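To make the formatting above concrete, here is a minimal, self-contained sketch (not part of the commit) that builds the kind of `KernelDef` proto the tool consumes; `PrintSupportedOps` renders a def like this one as the `Abs` row in the tables above:

```cpp
#include <iostream>

#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"

int main() {
  // A KernelDef for `Abs` with one type attribute, T, constrained to four
  // types. XlaOpRegistry::DeviceKernels() returns defs of exactly this shape.
  tensorflow::KernelDef kdef;
  kdef.set_op("Abs");
  tensorflow::KernelDef::AttrConstraint* constraint = kdef.add_constraint();
  constraint->set_name("T");
  auto* types = constraint->mutable_allowed_values()->mutable_list();
  types->add_type(tensorflow::DT_DOUBLE);
  types->add_type(tensorflow::DT_FLOAT);
  types->add_type(tensorflow::DT_INT32);
  types->add_type(tensorflow::DT_INT64);
  // PrintSupportedOps would emit: `Abs` | `T={double,float,int32,int64}`
  std::cout << kdef.DebugString();
}
```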
New file: tensorflow/compiler/tf2xla/tf2xla_supported_ops.h (33 lines)

```cpp
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_

namespace tensorflow {
namespace tf2xla {

// The implementation of a main function for a binary that prints a table of
// supported tf2xla operators for a given device, along with their type
// constraints, to stdout.
//
// Pass the argc and argv from main, unmodified. Use regen_run to specify the
// command used to regenerate the table.
void SupportedOpsMain(int argc, char** argv, const char* regen_run);

}  // namespace tf2xla
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_
```
New file: tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc (22 lines)

```cpp
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h"

int main(int argc, char** argv) {
  const char* regen_run =
      "bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops";
  tensorflow::tf2xla::SupportedOpsMain(argc, argv, regen_run);
}
```
tensorflow/compiler/tf2xla/xla_op_registry.cc:

```diff
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 
@@ -187,22 +188,39 @@ void XlaOpRegistry::RegisterCompilationKernels() {
 
     // Constrain each type attribute to the intersection of:
     // a) the types supported by the backend, and
-    // b) the attribute's type constraints.
-    // TODO(phawkins): it may be necessary to also take the intersection with
-    // the set of types supported by the OpDef.
+    // b) the types allowed by the OpDef, and
+    // c) the type constraints.
     for (const string& type_attr : type_attrs) {
       KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
       attr_constraint->set_name(type_attr);
       auto* allowed_values =
           attr_constraint->mutable_allowed_values()->mutable_list();
 
-      auto it = op_registration->type_constraints.find(type_attr);
+      const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
+      const auto* op_def_allowed_types =
+          op_def_attr.has_allowed_values()
+              ? &op_def_attr.allowed_values().list().type()
+              : nullptr;
+      auto constraint_it = op_registration->type_constraints.find(type_attr);
+      const std::set<DataType>* type_constraints =
+          constraint_it != op_registration->type_constraints.end()
+              ? &constraint_it->second
+              : nullptr;
       for (DataType dtype : backend.second.supported_types) {
-        if (it == op_registration->type_constraints.end() ||
-            (it != op_registration->type_constraints.end() &&
-             it->second.find(dtype) != it->second.end())) {
-          allowed_values->add_type(dtype);
+        // Filter out types that aren't allowed by the OpDef.
+        if (op_def_allowed_types != nullptr &&
+            std::find(op_def_allowed_types->begin(),
+                      op_def_allowed_types->end(),
+                      dtype) == op_def_allowed_types->end()) {
+          continue;
         }
+        // Filter out types based on the type constraints.
+        if (type_constraints != nullptr &&
+            type_constraints->find(dtype) == type_constraints->end()) {
+          continue;
+        }
+        // Passed all the filters, this type is allowed.
+        allowed_values->add_type(dtype);
       }
       if (op_registration->allow_resource_types) {
         allowed_values->add_type(DT_RESOURCE);
@@ -245,6 +263,22 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
   return kernels;
 }
 
+std::vector<string> XlaOpRegistry::BackendNames() {
+  std::vector<string> names;
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+  for (const auto& backend_pair : registry.backends_) {
+    names.push_back(backend_pair.first);
+  }
+  return names;
+}
+
+bool XlaOpRegistry::IsBackendRegistered(const string& name) {
+  XlaOpRegistry& registry = Instance();
+  mutex_lock lock(registry.mutex_);
+  return registry.backends_.find(name) != registry.backends_.end();
+}
+
 XlaOpRegistry& XlaOpRegistry::Instance() {
   static XlaOpRegistry* r = new XlaOpRegistry;
   return *r;
```
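The filtering in the middle hunk is the substance of the fix: a type survives only if it passes all three checks. A standalone sketch of the same three-way intersection on plain standard-library containers (the names here are illustrative, not the registry's own):

```cpp
#include <algorithm>
#include <set>
#include <vector>

// Illustrative: computes the types a kernel may advertise. A null pointer
// means "no restriction from that source", matching the nullptr checks in
// the patch above.
template <typename T>
std::vector<T> AllowedTypes(const std::vector<T>& backend_supported,
                            const std::vector<T>* op_def_allowed,
                            const std::set<T>* type_constraints) {
  std::vector<T> result;
  for (const T& dtype : backend_supported) {
    // (b) Filter out types that aren't allowed by the OpDef.
    if (op_def_allowed != nullptr &&
        std::find(op_def_allowed->begin(), op_def_allowed->end(), dtype) ==
            op_def_allowed->end()) {
      continue;
    }
    // (c) Filter out types excluded by the registration's constraint.
    if (type_constraints != nullptr &&
        type_constraints->find(dtype) == type_constraints->end()) {
      continue;
    }
    result.push_back(dtype);  // Passed every filter.
  }
  return result;
}
```

For example, a backend supporting {float, double, int32} intersected with an OpDef allowing {float, double} and a constraint of {float} yields just {float}; filter (b) was previously missing, which is the TODO the commit message refers to.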
tensorflow/compiler/tf2xla/xla_op_registry.h:

```diff
@@ -97,6 +97,12 @@ class XlaOpRegistry {
       gtl::ArraySlice<DataType> supported_types,
       BackendOpFilter op_filter);
 
+  // Returns the names of the registered backends.
+  static std::vector<string> BackendNames();
+
+  // Returns true iff a backend with the given name is registered.
+  static bool IsBackendRegistered(const string& name);
+
   // Registers `device_name` for XLA compilation, using information from
   // `registration`.
   static void RegisterCompilationDevice(const string& device_name,
@@ -116,8 +122,8 @@ class XlaOpRegistry {
   static void RegisterCompilationKernels();
 
   // Returns KernelDefs for compilation ops registered on
-  // 'compilation_device_name'.
-  // Does not include kernels registered as CompilationOnly.
+  // 'compilation_device_name'. Does not include kernels registered as
+  // CompilationOnly, iff include_compilation_only_kernels=false.
   static std::vector<const KernelDef*> DeviceKernels(
       const string& compilation_device_name,
       bool include_compilation_only_kernels);
```
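The tool added in this commit is the first caller of these two accessors; a minimal sketch of the intended pattern, mirroring the `--device` flag validation in `SupportedOpsMain` (the helper name is hypothetical, and `tensorflow::string` is assumed to be TensorFlow's usual std::string alias):

```cpp
#include <algorithm>
#include <iostream>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

// Hypothetical helper: print the registered XLA compilation backends and
// report whether `device` names one of them.
bool PrintBackendsAndCheck(const tensorflow::string& device) {
  auto names = tensorflow::XlaOpRegistry::BackendNames();
  std::sort(names.begin(), names.end());
  for (const auto& name : names) {
    std::cout << name << "\n";  // e.g. XLA_CPU_JIT, XLA_GPU_JIT
  }
  return tensorflow::XlaOpRegistry::IsBackendRegistered(device);
}
```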