[TF:XLA] Handle bitcasts between different bitwidths

PiperOrigin-RevId: 251810931
Authored by A. Unique TensorFlower on 2019-06-06 02:09:32 -07:00, committed by TensorFlower Gardener
parent e3b76130b4
commit 3ed65f64e6
2 changed files with 83 additions and 19 deletions
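
For context, a rough user-level sketch (Python, not part of the diff) of what bitcasting between different bitwidths means once this lowering exists. It assumes standard tf.bitcast semantics; the shapes mirror the tests updated below, and the values are purely illustrative.

import numpy as np
import tensorflow as tf

# Narrowing bitcast: each float64 (64 bits) splits into eight int8 lanes,
# so the result gains a new inner dimension of size 8.
x = tf.constant(np.random.rand(3, 2))   # shape [3, 2], dtype float64
small = tf.bitcast(x, tf.int8)          # shape [3, 2, 8]

# Widening bitcast: four int8 lanes fuse into one int32, so the innermost
# dimension (which must equal 4) is removed.
y = tf.constant(np.arange(16, dtype=np.int8).reshape([4, 4]))
wide = tf.bitcast(y, tf.int32)          # shape [4]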


@@ -12,16 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
 #include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"

 namespace tensorflow {
 namespace {
@@ -112,23 +117,85 @@ class BitcastOp : public XlaOpKernel {
     if (src_dtype_ == dst_dtype_) {
       output = input;
-    } else {
+      ctx->SetOutput(0, output);
+      return;
+    }
+
     // Error out if the bitcast has a complex source or destination type and
     // the bitcast is not trivial.
     OP_REQUIRES(ctx,
                 !xla::primitive_util::IsComplexType(src_type_) &&
                     !xla::primitive_util::IsComplexType(dst_type_),
                 errors::Unimplemented("Complex types not supported."));
-      // XLA bitcast requires that the bit-width of the source and destination
-      // matches, and currently only the simple lowering is performed.
+    auto input_bit_width = xla::primitive_util::BitWidth(src_type_);
+    auto output_bit_width = xla::primitive_util::BitWidth(dst_type_);
+    auto input_logical_type =
+        xla::primitive_util::UnsignedIntegralTypeForBitWidth(input_bit_width);
+    auto output_logical_type =
+        xla::primitive_util::UnsignedIntegralTypeForBitWidth(output_bit_width);
     OP_REQUIRES(ctx,
-                  xla::primitive_util::BitWidth(src_type_) ==
-                      xla::primitive_util::BitWidth(dst_type_),
-                  errors::Unimplemented(
-                      "Only bitcasts between equally sized types supported."));
-      output = xla::BitcastConvertType(input, dst_type_);
+                output_bit_width % input_bit_width == 0 ||
+                    input_bit_width % output_bit_width == 0,
+                errors::InvalidArgument(
+                    "Neither bit width is a multiple of the other."));
+
+    // Modify the input as needed so we only need to bitcast to create the
+    // output.
+    if (input_bit_width > output_bit_width) {
+      // Casting to a smaller bit width results in a new inner dimension.
+      auto broadcasted_input_shape = ctx->InputShape(0);
+      auto reshaped_input_shape = ctx->InputShape(0);
+      broadcasted_input_shape.AddDim(input_bit_width / output_bit_width);
+      reshaped_input_shape.AddDim(1);
+      auto output_bit_width_mask = (1 << output_bit_width) - 1;
+      auto status_or_input =
+          BroadcastTo(xla::Reshape(input, reshaped_input_shape.dim_sizes()),
+                      broadcasted_input_shape.dim_sizes());
+      OP_REQUIRES_OK(ctx, status_or_input.status());
+      input = xla::BitcastConvertType(status_or_input.ConsumeValueOrDie(),
+                                      input_logical_type);
+      auto xla_input_shape_status = ctx->builder()->GetShape(input);
+      OP_REQUIRES_OK(ctx, xla_input_shape_status.status());
+      auto xla_input_shape = xla_input_shape_status.ConsumeValueOrDie();
+      auto iota = xla::Iota(ctx->builder(), xla_input_shape,
+                            xla_input_shape.dimensions_size() - 1);
+      xla::XlaOp iota_m =
+          xla::Mul(xla::ScalarLike(input, output_bit_width), iota);
+      input = xla::And(xla::ShiftRightLogical(input, iota_m),
+                       xla::ScalarLike(input, output_bit_width_mask));
+      input = xla::ConvertElementType(input, output_logical_type);
+    } else if (input_bit_width < output_bit_width) {
+      // Casting to a larger bit width results in removing the innermost
+      // dimension.
+      auto input_shape = ctx->InputShape(0);
+      xla::Shape input_xla_shape =
+          TensorShapeToXLAShape(dst_type_, input_shape);
+      OP_REQUIRES(
+          ctx,
+          input_shape.dim_size(input_shape.dims() - 1) ==
+              output_bit_width / input_bit_width,
+          errors::InvalidArgument(
+              "Inner dimension of operand should be removed after cast."));
+      auto zero = XlaHelpers::Zero(ctx->builder(), dst_dtype_);
+      input = xla::ConvertElementType(input, dst_type_);
+      // Shift bits and OR them together to reduce the inner dimension.
+      xla::XlaOp iota_m =
+          xla::Mul(xla::ScalarLike(input, input_bit_width),
+                   xla::Iota(ctx->builder(), input_xla_shape,
+                             input_xla_shape.dimensions_size() - 1));
+      input = xla::ShiftLeft(input, iota_m);
+      input = xla::Reduce(input, zero,
+                          CreateScalarOrComputation(dst_type_, ctx->builder()),
+                          {input_xla_shape.dimensions_size() - 1});
     }
+
+    output = xla::BitcastConvertType(input, dst_type_);
     ctx->SetOutput(0, output);
   }
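
The two branches above only rearrange bits with shifts, masks, and an Or-reduction so that a single BitcastConvertType at the end produces the result. Below is a minimal numpy sketch of that idea, not the kernel itself: narrow_bitcast and widen_bitcast are hypothetical helper names, and the 32-bit/8-bit pair stands in for any pair of widths where one divides the other.

import numpy as np

def narrow_bitcast(x32):
  """Reinterpret int32 values as uint8 lanes, adding an inner dim of size 4."""
  u = x32.astype(np.int32).view(np.uint32)               # logical unsigned view
  lanes = np.broadcast_to(u[..., None], u.shape + (4,))  # new inner dimension
  shifts = np.arange(4, dtype=np.uint32) * 8             # iota * output_bit_width
  return ((lanes >> shifts) & 0xFF).astype(np.uint8)     # shift right and mask

def widen_bitcast(x8):
  """Reinterpret uint8 lanes (inner dim of size 4) as int32, removing that dim."""
  shifts = np.arange(4, dtype=np.uint32) * 8              # iota * input_bit_width
  widened = x8.astype(np.uint32) << shifts                # shift each lane into place
  return np.bitwise_or.reduce(widened, axis=-1).view(np.int32)  # Or-reduce inner dim

x = np.array([1, -2, 0x12345678], dtype=np.int32)
assert np.array_equal(widen_bitcast(narrow_bitcast(x)), x)      # round-trips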


@@ -38,14 +38,12 @@ class BitcastTest(test.TestCase):
     self.assertEqual(tf_ans.get_shape(), shape)
     self.assertEqual(tf_ans.dtype, datatype)

-  @test_util.disable_xla("Different bitwidths not supported")
   def testSmaller(self):
     x = np.random.rand(3, 2)
     datatype = dtypes.int8
     shape = [3, 2, 8]
     self._testBitcast(x, datatype, shape)

-  @test_util.disable_xla("Different bitwidths not supported")
   def testLarger(self):
     x = np.arange(16, dtype=np.int8).reshape([4, 4])
     datatype = dtypes.int32
@@ -69,7 +67,6 @@ class BitcastTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, "Cannot bitcast due to shape"):
       array_ops.bitcast(x, datatype, None)

-  @test_util.disable_xla("Different bitwidths not supported")
   def testEmpty(self):
     x = np.ones([], np.int32)
     datatype = dtypes.int8