diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index ca2152d6c10..3e463cac0e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -12,16 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/lib/util.h" - +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { namespace { @@ -112,23 +117,85 @@ class BitcastOp : public XlaOpKernel { if (src_dtype_ == dst_dtype_) { output = input; - } else { - // Error out if the bitcast has a complex source or destination type and - // the bitcast is not trivial. - OP_REQUIRES(ctx, - !xla::primitive_util::IsComplexType(src_type_) && - !xla::primitive_util::IsComplexType(dst_type_), - errors::Unimplemented("Complex types not supported.")); - // XLA bitcast requires that the bit-width of the source and destination - // matches, and currently only the simple lowering is performed.
- OP_REQUIRES(ctx, - xla::primitive_util::BitWidth(src_type_) == - xla::primitive_util::BitWidth(dst_type_), - errors::Unimplemented( - "Only bitcasts between equally sized types supported.")); - output = xla::BitcastConvertType(input, dst_type_); + ctx->SetOutput(0, output); + return; + } + // Error out if the bitcast has a complex source or destination type and + // the bitcast is not trivial. + OP_REQUIRES(ctx, + !xla::primitive_util::IsComplexType(src_type_) && + !xla::primitive_util::IsComplexType(dst_type_), + errors::Unimplemented("Complex types not supported.")); + auto input_bit_width = xla::primitive_util::BitWidth(src_type_); + auto output_bit_width = xla::primitive_util::BitWidth(dst_type_); + + auto input_logical_type = + xla::primitive_util::UnsignedIntegralTypeForBitWidth(input_bit_width); + auto output_logical_type = + xla::primitive_util::UnsignedIntegralTypeForBitWidth(output_bit_width); + + OP_REQUIRES(ctx, + + output_bit_width % input_bit_width == 0 || + input_bit_width % output_bit_width == 0, + errors::InvalidArgument( + "Neither bit width is a multiple of the other.")); + + // Modify the input as needed so we only need to bitcast to create the + // output. + if (input_bit_width > output_bit_width) { + // Casting to a smaller bit width results in a new inner dimension.
+ auto broadcasted_input_shape = ctx->InputShape(0); + auto reshaped_input_shape = ctx->InputShape(0); + broadcasted_input_shape.AddDim(input_bit_width / output_bit_width); + reshaped_input_shape.AddDim(1); + auto output_bit_width_mask = (int64_t{1} << output_bit_width) - 1; + + auto status_or_input = + BroadcastTo(xla::Reshape(input, reshaped_input_shape.dim_sizes()), + broadcasted_input_shape.dim_sizes()); + OP_REQUIRES_OK(ctx, status_or_input.status()); + input = xla::BitcastConvertType(status_or_input.ConsumeValueOrDie(), + input_logical_type); + auto xla_input_shape_status = ctx->builder()->GetShape(input); + OP_REQUIRES_OK(ctx, xla_input_shape_status.status()); + auto xla_input_shape = xla_input_shape_status.ConsumeValueOrDie(); + + auto iota = xla::Iota(ctx->builder(), xla_input_shape, + xla_input_shape.dimensions_size() - 1); + xla::XlaOp iota_m = + xla::Mul(xla::ScalarLike(input, output_bit_width), iota); + input = xla::And(xla::ShiftRightLogical(input, iota_m), + xla::ScalarLike(input, output_bit_width_mask)); + input = xla::ConvertElementType(input, output_logical_type); + } else if (input_bit_width < output_bit_width) { + // Casting to a larger bit width results in removing the innermost + // dimension. + auto input_shape = ctx->InputShape(0); + xla::Shape input_xla_shape = + TensorShapeToXLAShape(dst_type_, input_shape); + OP_REQUIRES( + ctx, + input_shape.dim_size(input_shape.dims() - 1) == + output_bit_width / input_bit_width, + errors::InvalidArgument( + "Inner dimension of operand should be removed after cast.")); + + auto zero = XlaHelpers::Zero(ctx->builder(), dst_dtype_); + input = xla::ConvertElementType(input, dst_type_); + + // Shift bits and OR them together to reduce the inner dimension.
+ xla::XlaOp iota_m = + xla::Mul(xla::ScalarLike(input, input_bit_width), + xla::Iota(ctx->builder(), input_xla_shape, + input_xla_shape.dimensions_size() - 1)); + input = xla::ShiftLeft(input, iota_m); + input = xla::Reduce(input, zero, + CreateScalarOrComputation(dst_type_, ctx->builder()), + {input_xla_shape.dimensions_size() - 1}); } + output = xla::BitcastConvertType(input, dst_type_); ctx->SetOutput(0, output); } diff --git a/tensorflow/python/kernel_tests/bitcast_op_test.py b/tensorflow/python/kernel_tests/bitcast_op_test.py index 91709f81fa9..b4f9a21a899 100644 --- a/tensorflow/python/kernel_tests/bitcast_op_test.py +++ b/tensorflow/python/kernel_tests/bitcast_op_test.py @@ -38,14 +38,12 @@ class BitcastTest(test.TestCase): self.assertEqual(tf_ans.get_shape(), shape) self.assertEqual(tf_ans.dtype, datatype) - @test_util.disable_xla("Different bitwidths not supported") def testSmaller(self): x = np.random.rand(3, 2) datatype = dtypes.int8 shape = [3, 2, 8] self._testBitcast(x, datatype, shape) - @test_util.disable_xla("Different bitwidths not supported") def testLarger(self): x = np.arange(16, dtype=np.int8).reshape([4, 4]) datatype = dtypes.int32 @@ -69,7 +67,6 @@ class BitcastTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "Cannot bitcast due to shape"): array_ops.bitcast(x, datatype, None) - @test_util.disable_xla("Different bitwidths not supported") def testEmpty(self): x = np.ones([], np.int32) datatype = dtypes.int8