A refactor to add helper functions {Split, Concat}, which do Tensor dtype deduction automatically.

PiperOrigin-RevId: 316703807
Change-Id: I234bea9fce3cf3b2cb352be12246ee9f4e8c405a
This commit is contained in:
Mingming Liu 2020-06-16 10:09:39 -07:00 committed by TensorFlower Gardener
parent 9ee6864c2c
commit f6b8e93a5f

View File

@ -100,9 +100,9 @@ typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Writes to the
// op's output at position 'output_index', using 'context' for the allocation to
// ensure proper device placement.
// Requires that all elements of 'inputs' have element type T. Writes to
// 'output' using 'context' for the allocation to ensure proper device
// placement.
template <typename T>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
Tensor* output) {
@ -157,6 +157,25 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
return Status::OK();
}
// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
Tensor* output) {
const DataType type = inputs[0].dtype();
Status concat_status;
switch (type) {
#define CASE(type) \
case DataTypeToEnum<type>::value: \
concat_status = Concat<type>(context, inputs, output); \
break;
TF_CALL_ALL_TYPES(CASE);
#undef CASE
default:
concat_status = errors::InvalidArgument("Unsupported data type: ", type);
break;
}
return concat_status;
}
// The Split*() functions split 'input' with element type T into 'sizes.size()'
// tensors along the zeroth dimension, with the ith split having zeroth-
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
@ -268,6 +287,25 @@ Status Split(OpKernelContext* context, const Tensor& input,
return SplitCPU<T>(context, input, sizes, outputs);
}
// Same as 'Split' above, but handles Tensor dtype automatically.
Status Split(OpKernelContext* context, const Tensor& input,
const gtl::ArraySlice<int64> sizes, std::vector<Tensor>* outputs) {
const DataType type = input.dtype();
Status split_status;
switch (type) {
#define CASE(type) \
case DataTypeToEnum<type>::value: \
split_status = Split<type>(context, input, sizes, outputs); \
break;
TF_CALL_ALL_TYPES(CASE);
#undef CASE
default:
split_status = errors::InvalidArgument("Unsupported data type: ", type);
break;
}
return split_status;
}
// A class encapsulating the state and logic for batching tensors.
class BatchResource : public ResourceBase {
public:
@ -449,22 +487,9 @@ class BatchResource : public ResourceBase {
}
}
const DataType type = to_concatenate[0].dtype();
Status concat_status;
Tensor concatenated_tensor;
switch (type) {
#define CASE(type) \
case DataTypeToEnum<type>::value: \
concat_status = \
Concat<type>(context, to_concatenate, &concatenated_tensor); \
break;
TF_CALL_ALL_TYPES(CASE);
#undef CASE
default:
concat_status =
errors::InvalidArgument("Unsupported data type: ", type);
break;
}
Status concat_status =
Concat(context, to_concatenate, &concatenated_tensor);
TF_RETURN_IF_ERROR(concat_status);
concatenated_tensors->push_back(concatenated_tensor);
}
@ -1001,17 +1026,7 @@ class UnbatchResource : public ResourceBase {
batch_keys.push_back(batch_indices(i, 0));
}
const DataType type = data_t.dtype();
switch (type) {
#define CASE(type) \
case DataTypeToEnum<type>::value: \
TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
break;
TF_CALL_ALL_TYPES(CASE);
#undef CASE
default:
return errors::InvalidArgument("Unsupported data type: ", type);
}
TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
}
// Critical section.