Add the tf2xla_supported_ops tool, which dumps ops supported by tf2xla.
Also fix a TODO in XlaOpRegistry to filter by the types allowed by the OpDef. Also see #14798 PiperOrigin-RevId: 177986664
This commit is contained in:
parent
e72ecbdb7a
commit
4b0a236848
@ -1,6 +1,6 @@
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
@ -25,6 +25,30 @@ package(
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
|
||||
|
||||
cc_library(
|
||||
name = "tf2xla_supported_ops_lib",
|
||||
srcs = ["tf2xla_supported_ops.cc"],
|
||||
hdrs = ["tf2xla_supported_ops.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf2xla_supported_ops",
|
||||
srcs = ["tf2xla_supported_ops_main.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":tf2xla_supported_ops_lib"],
|
||||
)
|
||||
|
||||
xla_proto_library(
|
||||
name = "tf2xla_proto",
|
||||
srcs = ["tf2xla.proto"],
|
||||
|
242
tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md
Normal file
242
tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md
Normal file
@ -0,0 +1,242 @@
|
||||
**Supported operators for device: XLA_CPU_JIT**
|
||||
|
||||
Operator | Type Constraint
|
||||
------------------------------------- | ---------------
|
||||
`Abs` | `T={double,float,int32,int64}`
|
||||
`Acosh` | `T={complex64,double,float}`
|
||||
`Add` | `T={complex64,double,float,int32,int64}`
|
||||
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`All` | `Tidx={int32,int64}`
|
||||
`Angle` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`Any` | `Tidx={int32,int64}`
|
||||
`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ArgMax` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={float}`
|
||||
`ArgMin` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Asinh` | `T={complex64,double,float}`
|
||||
`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Atan2` | `T={double,float}`
|
||||
`Atanh` | `T={complex64,double,float}`
|
||||
`AvgPool` | `T={double,float}`
|
||||
`AvgPool3D` | `T={double,float}`
|
||||
`AvgPool3DGrad` | `T={double,float}`
|
||||
`AvgPoolGrad` | `T={double,float}`
|
||||
`BatchMatMul` | `T={complex64,double,float,int32}`
|
||||
`BatchToSpace` | `Tidx={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BatchToSpaceND` | `Tcrops={int32,int64}`<br>`Tblock_shape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BitwiseAnd` | `T={int32,int64,uint32,uint64}`
|
||||
`BitwiseOr` | `T={int32,int64,uint32,uint64}`
|
||||
`BroadcastArgs` | `T={int32,int64}`
|
||||
`BroadcastGradientArgs` | `T={int32,int64}`
|
||||
`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Ceil` | `T={double,float}`
|
||||
`Cholesky` | `T={complex64,double,float}`
|
||||
`Complex` | `Tout={complex64}`<br>`T={double,float}`
|
||||
`ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ConcatOffset` |
|
||||
`ConcatV2` | `Tidx={int32}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Conj` | `T={complex64}`
|
||||
`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ControlTrigger` |
|
||||
`Conv2D` | `T={float}`
|
||||
`Conv2DBackpropFilter` | `T={float}`
|
||||
`Conv2DBackpropInput` | `T={float}`
|
||||
`Conv3D` | `T={double,float}`
|
||||
`Conv3DBackpropFilterV2` | `T={double,float}`
|
||||
`Conv3DBackpropInputV2` | `T={double,float}`
|
||||
`Cos` | `T={complex64,double,float}`
|
||||
`Cosh` | `T={complex64,double,float}`
|
||||
`Cross` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Cumprod` | `Tidx={int32,int64}`<br>`T={float}`
|
||||
`Cumsum` | `Tidx={int32,int64}`<br>`T={float}`
|
||||
`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`DepthwiseConv2dNative` | `T={double,float}`
|
||||
`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}`
|
||||
`DepthwiseConv2dNativeBackpropInput` | `T={double,float}`
|
||||
`Diag` | `T={complex64,double,float,int32,int64}`
|
||||
`DiagPart` | `T={complex64,double,float,int32,int64}`
|
||||
`Div` | `T={complex64,double,float,int32,int64}`
|
||||
`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Elu` | `T={double,float}`
|
||||
`EluGrad` | `T={double,float}`
|
||||
`Equal` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`Exp` | `T={complex64,double,float}`
|
||||
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Expm1` | `T={complex64,double,float}`
|
||||
`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Floor` | `T={double,float}`
|
||||
`FloorDiv` | `T={complex64,double,float,int32,int64}`
|
||||
`FloorMod` | `T={double,float,int32,int64}`
|
||||
`FusedBatchNorm` | `T={float}`
|
||||
`FusedBatchNormGrad` | `T={float}`
|
||||
`FusedBatchNormGradV2` | `U={float}`<br>`T={float}`
|
||||
`FusedBatchNormV2` | `U={float}`<br>`T={float}`
|
||||
`Gather` | `Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Imag` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`Inv` | `T={complex64,double,float,int32,int64}`
|
||||
`Invert` | `T={int32,int64,uint32,uint64}`
|
||||
`InvertPermutation` | `T={int32}`
|
||||
`IsFinite` | `T={double,float}`
|
||||
`IsInf` | `T={double,float}`
|
||||
`IsNan` | `T={double,float}`
|
||||
`L2Loss` | `T={double,float}`
|
||||
`LRN` | `T={float}`
|
||||
`LRNGrad` | `T={float}`
|
||||
`LeftShift` | `T={int32,int64,uint32,uint64}`
|
||||
`Less` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`LessEqual` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`LinSpace` | `Tidx={int32,int64}`<br>`T={double,float}`
|
||||
`Log` | `T={complex64,double,float}`
|
||||
`Log1p` | `T={complex64,double,float}`
|
||||
`LogSoftmax` | `T={double,float}`
|
||||
`LogicalAnd` |
|
||||
`LogicalNot` |
|
||||
`LogicalOr` |
|
||||
`MatMul` | `T={complex64,double,float}`
|
||||
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`MaxPool` | `T={double,float,int32,int64}`
|
||||
`MaxPool3D` | `T={float}`
|
||||
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
|
||||
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Maximum` | `T={double,float,int32,int64}`
|
||||
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Minimum` | `T={double,float,int32,int64}`
|
||||
`MirrorPad` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Mod` | `T={double,float,int32,int64}`
|
||||
`Mul` | `T={complex64,double,float,int32,int64}`
|
||||
`Multinomial` | `output_dtype={int32,int64}`<br>`T={double,float,int32,int64,uint32,uint64}`
|
||||
`Neg` | `T={complex64,double,float,int32,int64}`
|
||||
`NoOp` |
|
||||
`NotEqual` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`OneHot` | `TI={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`OnesLike` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Pad` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`PadV2` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Pow` | `T={complex64,double,float,int32,int64}`
|
||||
`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`QuantizeAndDequantizeV2` | `T={double,float}`
|
||||
`RandomStandardNormal` | `dtype={float}`
|
||||
`RandomUniform` | `T={int32,int64}`<br>`dtype={double,float}`
|
||||
`RandomUniformInt` | `T={int32,int64}`<br>`Tout={int32,int64}`
|
||||
`Range` | `Tidx={double,float,int32,int64}`
|
||||
`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Real` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`RealDiv` | `T={complex64,double,float,int32,int64}`
|
||||
`Reciprocal` | `T={complex64,double,float,int32,int64}`
|
||||
`ReciprocalGrad` | `T={complex64,double,float}`
|
||||
`Relu` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Relu6` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ResourceApplyAdagrad` | `T={double,float}`
|
||||
`ResourceApplyAdam` | `T={double,float}`
|
||||
`ResourceApplyFtrl` | `T={double,float}`
|
||||
`ResourceApplyFtrlV2` | `T={double,float}`
|
||||
`ResourceApplyGradientDescent` | `T={double,float}`
|
||||
`ResourceApplyMomentum` | `T={double,float}`
|
||||
`ResourceApplyRMSProp` | `T={double,float}`
|
||||
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
|
||||
`RightShift` | `T={int32,int64,uint32,uint64}`
|
||||
`Rint` | `T={double,float}`
|
||||
`Round` | `T={complex64,double,float,int32,int64}`
|
||||
`Rsqrt` | `T={complex64,double,float}`
|
||||
`RsqrtGrad` | `T={complex64,double,float}`
|
||||
`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Selu` | `T={double,float}`
|
||||
`SeluGrad` | `T={double,float}`
|
||||
`Shape` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ShapeN` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Sigmoid` | `T={complex64,double,float}`
|
||||
`SigmoidGrad` | `T={complex64,double,float}`
|
||||
`Sign` | `T={complex64,double,float,int32,int64}`
|
||||
`Sin` | `T={complex64,double,float}`
|
||||
`Sinh` | `T={complex64,double,float}`
|
||||
`Size` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Slice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Softmax` | `T={double,float}`
|
||||
`SoftmaxCrossEntropyWithLogits` | `T={double,float}`
|
||||
`Softplus` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Softsign` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`SpaceToBatch` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SpaceToBatchND` | `Tblock_shape={int32,int64}`<br>`Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SparseMatMul` | `Tb={float}`<br>`Ta={float}`
|
||||
`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`<br>`T={double,float}`
|
||||
`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SplitV` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Sqrt` | `T={complex64,double,float}`
|
||||
`SqrtGrad` | `T={complex64,double,float}`
|
||||
`Square` | `T={complex64,double,float,int32,int64}`
|
||||
`SquaredDifference` | `T={complex64,double,float,int32,int64}`
|
||||
`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StackCloseV2` |
|
||||
`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StatelessRandomNormal` | `Tseed={int32}`<br>`T={int32,int64}`<br>`dtype={float}`
|
||||
`StatelessRandomUniform` | `Tseed={int32}`<br>`T={int32,int64}`<br>`dtype={float}`
|
||||
`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StridedSlice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StridedSliceGrad` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Sub` | `T={complex64,double,float,int32,int64}`
|
||||
`Sum` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Tan` | `T={complex64,double,float,int32,int64}`
|
||||
`Tanh` | `T={complex64,double,float}`
|
||||
`TanhGrad` | `T={complex64,double,float}`
|
||||
`TensorArrayCloseV3` |
|
||||
`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayGradV3` |
|
||||
`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArraySizeV3` |
|
||||
`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Tile` | `Tmultiples={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Transpose` | `Tperm={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TruncateDiv` | `T={complex64,double,float,int32,int64}`
|
||||
`TruncateMod` | `T={double,float,int32,int64}`
|
||||
`TruncatedNormal` | `T={int32,int64}`<br>`dtype={double,float}`
|
||||
`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`<br>`Tindices={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`VarIsInitializedOp` |
|
||||
`VariableShape` | `out_type={int32,int64}`
|
||||
`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}`
|
||||
`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}`
|
||||
`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
|
||||
To regenerate this table, run:
|
||||
|
||||
```shell
|
||||
bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_CPU_JIT
|
||||
```
|
238
tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md
Normal file
238
tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md
Normal file
@ -0,0 +1,238 @@
|
||||
**Supported operators for device: XLA_GPU_JIT**
|
||||
|
||||
Operator | Type Constraint
|
||||
------------------------------------- | ---------------
|
||||
`Abs` | `T={double,float,int32,int64}`
|
||||
`Acosh` | `T={complex64,double,float}`
|
||||
`Add` | `T={complex64,double,float,int32,int64}`
|
||||
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`All` | `Tidx={int32,int64}`
|
||||
`Angle` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`Any` | `Tidx={int32,int64}`
|
||||
`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ArgMax` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ArgMin` | `Tidx={int32,int64}`<br>`output_type={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Asinh` | `T={complex64,double,float}`
|
||||
`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Atan2` | `T={double,float}`
|
||||
`Atanh` | `T={complex64,double,float}`
|
||||
`AvgPool` | `T={double,float}`
|
||||
`AvgPool3D` | `T={double,float}`
|
||||
`AvgPool3DGrad` | `T={double,float}`
|
||||
`AvgPoolGrad` | `T={double,float}`
|
||||
`BatchMatMul` | `T={complex64,double,float,int32}`
|
||||
`BatchToSpace` | `Tidx={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BatchToSpaceND` | `Tcrops={int32,int64}`<br>`Tblock_shape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`BitwiseAnd` | `T={int32,int64,uint32,uint64}`
|
||||
`BitwiseOr` | `T={int32,int64,uint32,uint64}`
|
||||
`BroadcastArgs` | `T={int32,int64}`
|
||||
`BroadcastGradientArgs` | `T={int32,int64}`
|
||||
`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Ceil` | `T={double,float}`
|
||||
`Cholesky` | `T={complex64,double,float}`
|
||||
`Complex` | `Tout={complex64}`<br>`T={double,float}`
|
||||
`ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ConcatOffset` |
|
||||
`ConcatV2` | `Tidx={int32}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Conj` | `T={complex64}`
|
||||
`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ControlTrigger` |
|
||||
`Conv2D` | `T={float}`
|
||||
`Conv2DBackpropFilter` | `T={float}`
|
||||
`Conv2DBackpropInput` | `T={float}`
|
||||
`Conv3D` | `T={double,float}`
|
||||
`Conv3DBackpropFilterV2` | `T={double,float}`
|
||||
`Conv3DBackpropInputV2` | `T={double,float}`
|
||||
`Cos` | `T={complex64,double,float}`
|
||||
`Cosh` | `T={complex64,double,float}`
|
||||
`Cross` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Cumprod` | `Tidx={int32,int64}`<br>`T={float}`
|
||||
`Cumsum` | `Tidx={int32,int64}`<br>`T={float}`
|
||||
`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`DepthwiseConv2dNative` | `T={double,float}`
|
||||
`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}`
|
||||
`DepthwiseConv2dNativeBackpropInput` | `T={double,float}`
|
||||
`Diag` | `T={complex64,double,float,int32,int64}`
|
||||
`DiagPart` | `T={complex64,double,float,int32,int64}`
|
||||
`Div` | `T={complex64,double,float,int32,int64}`
|
||||
`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Elu` | `T={double,float}`
|
||||
`EluGrad` | `T={double,float}`
|
||||
`Equal` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`Exp` | `T={complex64,double,float}`
|
||||
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Expm1` | `T={complex64,double,float}`
|
||||
`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Floor` | `T={double,float}`
|
||||
`FloorDiv` | `T={complex64,double,float,int32,int64}`
|
||||
`FloorMod` | `T={double,float,int32,int64}`
|
||||
`FusedBatchNorm` | `T={float}`
|
||||
`FusedBatchNormGrad` | `T={float}`
|
||||
`FusedBatchNormGradV2` | `U={float}`<br>`T={float}`
|
||||
`FusedBatchNormV2` | `U={float}`<br>`T={float}`
|
||||
`Gather` | `Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Imag` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`Inv` | `T={complex64,double,float,int32,int64}`
|
||||
`Invert` | `T={int32,int64,uint32,uint64}`
|
||||
`InvertPermutation` | `T={int32}`
|
||||
`IsFinite` | `T={double,float}`
|
||||
`IsInf` | `T={double,float}`
|
||||
`IsNan` | `T={double,float}`
|
||||
`L2Loss` | `T={double,float}`
|
||||
`LRN` | `T={float}`
|
||||
`LRNGrad` | `T={float}`
|
||||
`LeftShift` | `T={int32,int64,uint32,uint64}`
|
||||
`Less` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`LessEqual` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`LinSpace` | `Tidx={int32,int64}`<br>`T={double,float}`
|
||||
`Log` | `T={complex64,double,float}`
|
||||
`Log1p` | `T={complex64,double,float}`
|
||||
`LogSoftmax` | `T={double,float}`
|
||||
`LogicalAnd` |
|
||||
`LogicalNot` |
|
||||
`LogicalOr` |
|
||||
`MatMul` | `T={complex64,double,float}`
|
||||
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`MaxPool` | `T={double,float,int32,int64}`
|
||||
`MaxPool3D` | `T={float}`
|
||||
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
|
||||
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Maximum` | `T={double,float,int32,int64}`
|
||||
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Minimum` | `T={double,float,int32,int64}`
|
||||
`MirrorPad` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Mod` | `T={double,float,int32,int64}`
|
||||
`Mul` | `T={complex64,double,float,int32,int64}`
|
||||
`Multinomial` | `output_dtype={int32,int64}`<br>`T={double,float,int32,int64,uint32,uint64}`
|
||||
`Neg` | `T={complex64,double,float,int32,int64}`
|
||||
`NoOp` |
|
||||
`NotEqual` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`OneHot` | `TI={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`OnesLike` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Pad` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`PadV2` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Pow` | `T={complex64,double,float,int32,int64}`
|
||||
`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`QuantizeAndDequantizeV2` | `T={double,float}`
|
||||
`Range` | `Tidx={double,float,int32,int64}`
|
||||
`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Real` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`RealDiv` | `T={complex64,double,float,int32,int64}`
|
||||
`Reciprocal` | `T={complex64,double,float,int32,int64}`
|
||||
`ReciprocalGrad` | `T={complex64,double,float}`
|
||||
`Relu` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Relu6` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ResourceApplyAdagrad` | `T={double,float}`
|
||||
`ResourceApplyAdam` | `T={double,float}`
|
||||
`ResourceApplyFtrl` | `T={double,float}`
|
||||
`ResourceApplyFtrlV2` | `T={double,float}`
|
||||
`ResourceApplyGradientDescent` | `T={double,float}`
|
||||
`ResourceApplyMomentum` | `T={double,float}`
|
||||
`ResourceApplyRMSProp` | `T={double,float}`
|
||||
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
|
||||
`RightShift` | `T={int32,int64,uint32,uint64}`
|
||||
`Rint` | `T={double,float}`
|
||||
`Round` | `T={complex64,double,float,int32,int64}`
|
||||
`Rsqrt` | `T={complex64,double,float}`
|
||||
`RsqrtGrad` | `T={complex64,double,float}`
|
||||
`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Selu` | `T={double,float}`
|
||||
`SeluGrad` | `T={double,float}`
|
||||
`Shape` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ShapeN` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Sigmoid` | `T={complex64,double,float}`
|
||||
`SigmoidGrad` | `T={complex64,double,float}`
|
||||
`Sign` | `T={complex64,double,float,int32,int64}`
|
||||
`Sin` | `T={complex64,double,float}`
|
||||
`Sinh` | `T={complex64,double,float}`
|
||||
`Size` | `out_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Slice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Softmax` | `T={double,float}`
|
||||
`SoftmaxCrossEntropyWithLogits` | `T={double,float}`
|
||||
`Softplus` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Softsign` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`SpaceToBatch` | `Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SpaceToBatchND` | `Tblock_shape={int32,int64}`<br>`Tpaddings={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SparseMatMul` | `Tb={float}`<br>`Ta={float}`
|
||||
`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`<br>`T={double,float}`
|
||||
`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SplitV` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Sqrt` | `T={complex64,double,float}`
|
||||
`SqrtGrad` | `T={complex64,double,float}`
|
||||
`Square` | `T={complex64,double,float,int32,int64}`
|
||||
`SquaredDifference` | `T={complex64,double,float,int32,int64}`
|
||||
`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StackCloseV2` |
|
||||
`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StatelessRandomNormal` | `Tseed={int32}`<br>`T={int32,int64}`<br>`dtype={float}`
|
||||
`StatelessRandomUniform` | `Tseed={int32}`<br>`T={int32,int64}`<br>`dtype={float}`
|
||||
`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StridedSlice` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`StridedSliceGrad` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Sub` | `T={complex64,double,float,int32,int64}`
|
||||
`Sum` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Tan` | `T={complex64,double,float,int32,int64}`
|
||||
`Tanh` | `T={complex64,double,float}`
|
||||
`TanhGrad` | `T={complex64,double,float}`
|
||||
`TensorArrayCloseV3` |
|
||||
`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayGradV3` |
|
||||
`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArraySizeV3` |
|
||||
`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Tile` | `Tmultiples={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Transpose` | `Tperm={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`TruncateDiv` | `T={complex64,double,float,int32,int64}`
|
||||
`TruncateMod` | `T={double,float,int32,int64}`
|
||||
`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`<br>`Tindices={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`VarIsInitializedOp` |
|
||||
`VariableShape` | `out_type={int32,int64}`
|
||||
`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}`
|
||||
`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}`
|
||||
`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
|
||||
To regenerate this table, run:
|
||||
|
||||
```shell
|
||||
bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_GPU_JIT
|
||||
```
|
97
tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc
Normal file
97
tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc
Normal file
@ -0,0 +1,97 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tf2xla {
|
||||
namespace {
|
||||
|
||||
void PrintSupportedOps(const string& device, const string& regen_run) {
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
|
||||
std::vector<const KernelDef*> kdefs =
|
||||
XlaOpRegistry::DeviceKernels(device,
|
||||
/*include_compilation_only_kernels=*/true);
|
||||
std::sort(
|
||||
kdefs.begin(), kdefs.end(),
|
||||
[](const KernelDef* a, const KernelDef* b) { return a->op() < b->op(); });
|
||||
|
||||
std::cout << "**Supported operators for device: " << device << "**\n\n"
|
||||
<< "Operator | Type Constraint\n"
|
||||
<< "-------- | ---------------" << std::endl;
|
||||
for (const KernelDef* kdef : kdefs) {
|
||||
std::vector<string> constraints;
|
||||
for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) {
|
||||
std::vector<string> types;
|
||||
for (int type : constraint.allowed_values().list().type()) {
|
||||
types.push_back(DataTypeString(static_cast<DataType>(type)));
|
||||
}
|
||||
std::sort(types.begin(), types.end());
|
||||
constraints.push_back("`" + constraint.name() + "={" +
|
||||
str_util::Join(types, ",") + "}`");
|
||||
}
|
||||
std::cout << "`" << kdef->op() << "` | "
|
||||
<< str_util::Join(constraints, "<br>") << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "\nTo regenerate this table, run:\n\n```shell\n"
|
||||
<< regen_run << " --device=" << device << "\n```" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SupportedOpsMain(int argc, char** argv, const char* regen_run) {
|
||||
std::vector<string> device_names = XlaOpRegistry::BackendNames();
|
||||
std::sort(device_names.begin(), device_names.end());
|
||||
|
||||
// Set up and parse flags.
|
||||
string device;
|
||||
std::vector<Flag> flag_list = {
|
||||
{"device", &device,
|
||||
"Name of the compilation device for which to print supported ops, "
|
||||
"one of: " +
|
||||
str_util::Join(device_names, ",")},
|
||||
};
|
||||
string usage = Flags::Usage(argv[0], flag_list);
|
||||
bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list);
|
||||
QCHECK(parsed_flags_ok) << "\n" << usage;
|
||||
QCHECK(XlaOpRegistry::IsBackendRegistered(device))
|
||||
<< "\nUnknown device: " << device << "\n"
|
||||
<< usage;
|
||||
|
||||
// Run the program.
|
||||
port::InitMain(usage.c_str(), &argc, &argv);
|
||||
QCHECK(argc == 1) << "\nERROR: This command does not take any arguments "
|
||||
"other than flags\n\n"
|
||||
<< usage;
|
||||
PrintSupportedOps(device, regen_run);
|
||||
}
|
||||
|
||||
} // namespace tf2xla
|
||||
} // namespace tensorflow
|
33
tensorflow/compiler/tf2xla/tf2xla_supported_ops.h
Normal file
33
tensorflow/compiler/tf2xla/tf2xla_supported_ops.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tf2xla {
|
||||
|
||||
// The implementation of a main function for a binary that prints a table of
|
||||
// supported tf2xla operators for a given device, along with their type
|
||||
// constraints, to stdout.
|
||||
//
|
||||
// Pass the argc and argv from main, unmodified. Use regen_run to specify the
|
||||
// command used to regenerate the table.
|
||||
void SupportedOpsMain(int argc, char** argv, const char* regen_run);
|
||||
|
||||
} // namespace tf2xla
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_
|
22
tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc
Normal file
22
tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc
Normal file
@ -0,0 +1,22 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
const char* regen_run =
|
||||
"bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops";
|
||||
tensorflow::tf2xla::SupportedOpsMain(argc, argv, regen_run);
|
||||
}
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def_util.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
|
||||
@ -187,22 +188,39 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
||||
|
||||
// Constrain each type attribute to the intersection of:
|
||||
// a) the types supported by the backend, and
|
||||
// b) the attribute's type constraints.
|
||||
// TODO(phawkins): it may be necessary to also take the intersection with
|
||||
// the set of types supported by the OpDef.
|
||||
// b) the types allowed by the OpDef, and
|
||||
// c) the type constraints.
|
||||
for (const string& type_attr : type_attrs) {
|
||||
KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
|
||||
attr_constraint->set_name(type_attr);
|
||||
auto* allowed_values =
|
||||
attr_constraint->mutable_allowed_values()->mutable_list();
|
||||
|
||||
auto it = op_registration->type_constraints.find(type_attr);
|
||||
const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
|
||||
const auto* op_def_allowed_types =
|
||||
op_def_attr.has_allowed_values()
|
||||
? &op_def_attr.allowed_values().list().type()
|
||||
: nullptr;
|
||||
auto constraint_it = op_registration->type_constraints.find(type_attr);
|
||||
const std::set<DataType>* type_constraints =
|
||||
constraint_it != op_registration->type_constraints.end()
|
||||
? &constraint_it->second
|
||||
: nullptr;
|
||||
for (DataType dtype : backend.second.supported_types) {
|
||||
if (it == op_registration->type_constraints.end() ||
|
||||
(it != op_registration->type_constraints.end() &&
|
||||
it->second.find(dtype) != it->second.end())) {
|
||||
allowed_values->add_type(dtype);
|
||||
// Filter out types that aren't allowed by the OpDef.
|
||||
if (op_def_allowed_types != nullptr &&
|
||||
std::find(op_def_allowed_types->begin(),
|
||||
op_def_allowed_types->end(),
|
||||
dtype) == op_def_allowed_types->end()) {
|
||||
continue;
|
||||
}
|
||||
// Filter out types based on the type constraints.
|
||||
if (type_constraints != nullptr &&
|
||||
type_constraints->find(dtype) == type_constraints->end()) {
|
||||
continue;
|
||||
}
|
||||
// Passed all the filters, this type is allowed.
|
||||
allowed_values->add_type(dtype);
|
||||
}
|
||||
if (op_registration->allow_resource_types) {
|
||||
allowed_values->add_type(DT_RESOURCE);
|
||||
@ -245,6 +263,22 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
|
||||
return kernels;
|
||||
}
|
||||
|
||||
std::vector<string> XlaOpRegistry::BackendNames() {
|
||||
std::vector<string> names;
|
||||
XlaOpRegistry& registry = Instance();
|
||||
mutex_lock lock(registry.mutex_);
|
||||
for (const auto& backend_pair : registry.backends_) {
|
||||
names.push_back(backend_pair.first);
|
||||
}
|
||||
return names;
|
||||
}
|
||||
|
||||
bool XlaOpRegistry::IsBackendRegistered(const string& name) {
|
||||
XlaOpRegistry& registry = Instance();
|
||||
mutex_lock lock(registry.mutex_);
|
||||
return registry.backends_.find(name) != registry.backends_.end();
|
||||
}
|
||||
|
||||
XlaOpRegistry& XlaOpRegistry::Instance() {
|
||||
static XlaOpRegistry* r = new XlaOpRegistry;
|
||||
return *r;
|
||||
|
@ -97,6 +97,12 @@ class XlaOpRegistry {
|
||||
gtl::ArraySlice<DataType> supported_types,
|
||||
BackendOpFilter op_filter);
|
||||
|
||||
// Returns the names of the registered backends.
|
||||
static std::vector<string> BackendNames();
|
||||
|
||||
// Returns true iff a backend with the given name is registered.
|
||||
static bool IsBackendRegistered(const string& name);
|
||||
|
||||
// Registers `device_name` for XLA compilation, using information from
|
||||
// `registration`.
|
||||
static void RegisterCompilationDevice(const string& device_name,
|
||||
@ -116,8 +122,8 @@ class XlaOpRegistry {
|
||||
static void RegisterCompilationKernels();
|
||||
|
||||
// Returns KernelDefs for compilation ops registered on
|
||||
// 'compilation_device_name'.
|
||||
// Does not include kernels registered as CompilationOnly.
|
||||
// 'compilation_device_name'. Does not include kernels registered as
|
||||
// CompilationOnly, iff include_compilation_only_kernels=false.
|
||||
static std::vector<const KernelDef*> DeviceKernels(
|
||||
const string& compilation_device_name,
|
||||
bool include_compilation_only_kernels);
|
||||
|
Loading…
Reference in New Issue
Block a user