diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index e94aef641e3..6449a399573 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -100,9 +100,9 @@ typedef Eigen::SyclDevice SYCLDevice;
 #endif  // TENSORFLOW_USE_SYCL
 
 // Concatenates 'inputs' into a single tensor along the zeroth dimension.
-// Requires that all elements of 'inputs' have element type T. Writes to the
-// op's output at position 'output_index', using 'context' for the allocation to
-// ensure proper device placement.
+// Requires that all elements of 'inputs' have element type T. Writes to
+// 'output' using 'context' for the allocation to ensure proper device
+// placement.
 template <typename T>
 Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
               Tensor* output) {
@@ -157,6 +157,25 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
   return Status::OK();
 }
 
+// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
+Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
+              Tensor* output) {
+  const DataType type = inputs[0].dtype();
+  Status concat_status;
+  switch (type) {
+#define CASE(type)                                         \
+  case DataTypeToEnum<type>::value:                        \
+    concat_status = Concat<type>(context, inputs, output); \
+    break;
+    TF_CALL_ALL_TYPES(CASE);
+#undef CASE
+    default:
+      concat_status = errors::InvalidArgument("Unsupported data type: ", type);
+      break;
+  }
+  return concat_status;
+}
+
 // The Split*() functions split 'input' with element type T into 'sizes.size()'
 // tensors along the zeroth dimension, with the ith split having zeroth-
 // dimension size 'sizes[i]'. They allocate the output tensors using 'context',
@@ -268,6 +287,25 @@ Status Split(OpKernelContext* context, const Tensor& input,
   return SplitCPU<T>(context, input, sizes, outputs);
 }
 
+// Same as 'Split' above, but handles Tensor dtype automatically.
+Status Split(OpKernelContext* context, const Tensor& input,
+             const gtl::ArraySlice<int64> sizes, std::vector<Tensor>* outputs) {
+  const DataType type = input.dtype();
+  Status split_status;
+  switch (type) {
+#define CASE(type)                                              \
+  case DataTypeToEnum<type>::value:                             \
+    split_status = Split<type>(context, input, sizes, outputs); \
+    break;
+    TF_CALL_ALL_TYPES(CASE);
+#undef CASE
+    default:
+      split_status = errors::InvalidArgument("Unsupported data type: ", type);
+      break;
+  }
+  return split_status;
+}
+
 // A class encapsulating the state and logic for batching tensors.
 class BatchResource : public ResourceBase {
  public:
@@ -449,22 +487,9 @@ class BatchResource : public ResourceBase {
       }
     }
 
-    const DataType type = to_concatenate[0].dtype();
-    Status concat_status;
     Tensor concatenated_tensor;
-    switch (type) {
-#define CASE(type)                                                   \
-  case DataTypeToEnum<type>::value:                                  \
-    concat_status =                                                  \
-        Concat<type>(context, to_concatenate, &concatenated_tensor); \
-    break;
-      TF_CALL_ALL_TYPES(CASE);
-#undef CASE
-      default:
-        concat_status =
-            errors::InvalidArgument("Unsupported data type: ", type);
-        break;
-    }
+    Status concat_status =
+        Concat(context, to_concatenate, &concatenated_tensor);
     TF_RETURN_IF_ERROR(concat_status);
     concatenated_tensors->push_back(concatenated_tensor);
   }
@@ -1001,17 +1026,7 @@ class UnbatchResource : public ResourceBase {
       batch_keys.push_back(batch_indices(i, 0));
     }
 
-    const DataType type = data_t.dtype();
-    switch (type) {
-#define CASE(type)                                                          \
-  case DataTypeToEnum<type>::value:                                         \
-    TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
-    break;
-      TF_CALL_ALL_TYPES(CASE);
-#undef CASE
-      default:
-        return errors::InvalidArgument("Unsupported data type: ", type);
-    }
+    TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
 
     // Critical section.
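
The diff deduplicates an open-coded dtype switch that previously appeared at both call sites. The pattern it factors out: TF_CALL_ALL_TYPES is an X-macro that expands its argument once per supported element type, DataTypeToEnum<T>::value maps a compile-time type to its runtime DataType tag, and the new non-template Concat/Split overloads switch on the runtime dtype exactly once to forward to the matching template instantiation. Below is a minimal, self-contained C++ sketch of that dispatch pattern under stated assumptions; MyType, TypeToEnum, MY_CALL_ALL_TYPES, DoWork, and Dispatch are hypothetical stand-ins for illustration, not TensorFlow APIs.

// Minimal sketch of the dtype-dispatch pattern factored out above.
// All names here (MyType, TypeToEnum, MY_CALL_ALL_TYPES, DoWork, Dispatch)
// are hypothetical stand-ins for illustration, not TensorFlow APIs.
#include <cstdint>
#include <iostream>
#include <typeinfo>

// Runtime tags for the supported element types; kInvalid exercises the
// default branch, as an unsupported DataType would.
enum class MyType { kFloat, kInt32, kInt64, kInvalid };

// Compile-time map from a C++ type to its runtime tag, playing the role of
// DataTypeToEnum<T>::value.
template <typename T> struct TypeToEnum;
template <> struct TypeToEnum<float> { static constexpr MyType value = MyType::kFloat; };
template <> struct TypeToEnum<int32_t> { static constexpr MyType value = MyType::kInt32; };
template <> struct TypeToEnum<int64_t> { static constexpr MyType value = MyType::kInt64; };

// X-macro that expands its argument once per supported type, playing the
// role of TF_CALL_ALL_TYPES(CASE).
#define MY_CALL_ALL_TYPES(CASE) CASE(float) CASE(int32_t) CASE(int64_t)

// The typed worker: the analogue of the template <typename T> Concat/Split.
template <typename T>
void DoWork() {
  std::cout << "working on element type " << typeid(T).name() << "\n";
}

// Type-erased wrapper: the analogue of the new non-template Concat/Split
// overloads. The switch over the runtime dtype lives here, exactly once;
// call sites no longer repeat it.
bool Dispatch(MyType dtype) {
  switch (dtype) {
#define CASE(T)              \
  case TypeToEnum<T>::value: \
    DoWork<T>();             \
    return true;
    MY_CALL_ALL_TYPES(CASE)
#undef CASE
    default:
      std::cerr << "unsupported data type\n";
      return false;
  }
}

int main() {
  Dispatch(MyType::kFloat);    // forwards to DoWork<float>()
  Dispatch(MyType::kInt64);    // forwards to DoWork<int64_t>()
  Dispatch(MyType::kInvalid);  // hits the default branch, returns false
}

A note on the design this diff adopts: keeping the typed worker as a template and the dtype switch in a single type-erased wrapper means each call site pays no dispatch boilerplate, and supporting a new element type is a change to the X-macro list rather than to every caller.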