A refactor to add helper functions {Split, Concat}, which do Tensor dtype deduction automatically.
PiperOrigin-RevId: 316703807 Change-Id: I234bea9fce3cf3b2cb352be12246ee9f4e8c405a
This commit is contained in:
parent
9ee6864c2c
commit
f6b8e93a5f
@ -100,9 +100,9 @@ typedef Eigen::SyclDevice SYCLDevice;
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
// Concatenates 'inputs' into a single tensor along the zeroth dimension.
|
||||
// Requires that all elements of 'inputs' have element type T. Writes to the
|
||||
// op's output at position 'output_index', using 'context' for the allocation to
|
||||
// ensure proper device placement.
|
||||
// Requires that all elements of 'inputs' have element type T. Writes to
|
||||
// 'output' using 'context' for the allocation to ensure proper device
|
||||
// placement.
|
||||
template <typename T>
|
||||
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
|
||||
Tensor* output) {
|
||||
@ -157,6 +157,25 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
|
||||
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
|
||||
Tensor* output) {
|
||||
const DataType type = inputs[0].dtype();
|
||||
Status concat_status;
|
||||
switch (type) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: \
|
||||
concat_status = Concat<type>(context, inputs, output); \
|
||||
break;
|
||||
TF_CALL_ALL_TYPES(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
concat_status = errors::InvalidArgument("Unsupported data type: ", type);
|
||||
break;
|
||||
}
|
||||
return concat_status;
|
||||
}
|
||||
|
||||
// The Split*() functions split 'input' with element type T into 'sizes.size()'
|
||||
// tensors along the zeroth dimension, with the ith split having zeroth-
|
||||
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
|
||||
@ -268,6 +287,25 @@ Status Split(OpKernelContext* context, const Tensor& input,
|
||||
return SplitCPU<T>(context, input, sizes, outputs);
|
||||
}
|
||||
|
||||
// Same as 'Split' above, but handles Tensor dtype automatically.
|
||||
Status Split(OpKernelContext* context, const Tensor& input,
|
||||
const gtl::ArraySlice<int64> sizes, std::vector<Tensor>* outputs) {
|
||||
const DataType type = input.dtype();
|
||||
Status split_status;
|
||||
switch (type) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: \
|
||||
split_status = Split<type>(context, input, sizes, outputs); \
|
||||
break;
|
||||
TF_CALL_ALL_TYPES(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
split_status = errors::InvalidArgument("Unsupported data type: ", type);
|
||||
break;
|
||||
}
|
||||
return split_status;
|
||||
}
|
||||
|
||||
// A class encapsulating the state and logic for batching tensors.
|
||||
class BatchResource : public ResourceBase {
|
||||
public:
|
||||
@ -449,22 +487,9 @@ class BatchResource : public ResourceBase {
|
||||
}
|
||||
}
|
||||
|
||||
const DataType type = to_concatenate[0].dtype();
|
||||
Status concat_status;
|
||||
Tensor concatenated_tensor;
|
||||
switch (type) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: \
|
||||
concat_status = \
|
||||
Concat<type>(context, to_concatenate, &concatenated_tensor); \
|
||||
break;
|
||||
TF_CALL_ALL_TYPES(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
concat_status =
|
||||
errors::InvalidArgument("Unsupported data type: ", type);
|
||||
break;
|
||||
}
|
||||
Status concat_status =
|
||||
Concat(context, to_concatenate, &concatenated_tensor);
|
||||
TF_RETURN_IF_ERROR(concat_status);
|
||||
concatenated_tensors->push_back(concatenated_tensor);
|
||||
}
|
||||
@ -1001,17 +1026,7 @@ class UnbatchResource : public ResourceBase {
|
||||
batch_keys.push_back(batch_indices(i, 0));
|
||||
}
|
||||
|
||||
const DataType type = data_t.dtype();
|
||||
switch (type) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: \
|
||||
TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
|
||||
break;
|
||||
TF_CALL_ALL_TYPES(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
return errors::InvalidArgument("Unsupported data type: ", type);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
|
||||
}
|
||||
|
||||
// Critical section.
|
||||
|
Loading…
Reference in New Issue
Block a user