A refactor that adds helper functions {Split, Concat}, which do Tensor dtype deduction automatically.
PiperOrigin-RevId: 316703807
Change-Id: I234bea9fce3cf3b2cb352be12246ee9f4e8c405a
parent 9ee6864c2c
commit f6b8e93a5f
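The caller-facing effect is easiest to see as a before/after pair. A condensed sketch of the pattern, reusing the names from the BatchResource hunk below:

  // Before: every call site dispatched on dtype by hand.
  const DataType type = to_concatenate[0].dtype();
  Status concat_status;
  Tensor concatenated_tensor;
  switch (type) {
#define CASE(type)                                                   \
  case DataTypeToEnum<type>::value:                                  \
    concat_status =                                                  \
        Concat<type>(context, to_concatenate, &concatenated_tensor); \
    break;
    TF_CALL_ALL_TYPES(CASE);
#undef CASE
    default:
      concat_status = errors::InvalidArgument("Unsupported data type: ", type);
  }

  // After: the overload deduces the dtype from its first input.
  Tensor concatenated_tensor;
  Status concat_status =
      Concat(context, to_concatenate, &concatenated_tensor);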
@@ -100,9 +100,9 @@ typedef Eigen::SyclDevice SYCLDevice;
 #endif  // TENSORFLOW_USE_SYCL
 
 // Concatenates 'inputs' into a single tensor along the zeroth dimension.
-// Requires that all elements of 'inputs' have element type T. Writes to the
-// op's output at position 'output_index', using 'context' for the allocation to
-// ensure proper device placement.
+// Requires that all elements of 'inputs' have element type T. Writes to
+// 'output' using 'context' for the allocation to ensure proper device
+// placement.
 template <typename T>
 Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
               Tensor* output) {
@@ -157,6 +157,25 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
   return Status::OK();
 }
 
+// Same as 'Concat' above, but handles Tensor dtype deduction automatically.
+Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor> inputs,
+              Tensor* output) {
+  const DataType type = inputs[0].dtype();
+  Status concat_status;
+  switch (type) {
+#define CASE(type)                                         \
+  case DataTypeToEnum<type>::value:                        \
+    concat_status = Concat<type>(context, inputs, output); \
+    break;
+    TF_CALL_ALL_TYPES(CASE);
+#undef CASE
+    default:
+      concat_status = errors::InvalidArgument("Unsupported data type: ", type);
+      break;
+  }
+  return concat_status;
+}
+
 // The Split*() functions split 'input' with element type T into 'sizes.size()'
 // tensors along the zeroth dimension, with the ith split having zeroth-
 // dimension size 'sizes[i]'. They allocate the output tensors using 'context',
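Note on the dispatch mechanism: TF_CALL_ALL_TYPES(CASE) stamps out the CASE macro once per supported dtype, so for float the switch above gains a case equivalent to:

  case DataTypeToEnum<float>::value:
    concat_status = Concat<float>(context, inputs, output);
    break;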
@@ -268,6 +287,25 @@ Status Split(OpKernelContext* context, const Tensor& input,
   return SplitCPU<T>(context, input, sizes, outputs);
 }
 
+// Same as 'Split' above, but handles Tensor dtype automatically.
+Status Split(OpKernelContext* context, const Tensor& input,
+             const gtl::ArraySlice<int64> sizes, std::vector<Tensor>* outputs) {
+  const DataType type = input.dtype();
+  Status split_status;
+  switch (type) {
+#define CASE(type)                                              \
+  case DataTypeToEnum<type>::value:                             \
+    split_status = Split<type>(context, input, sizes, outputs); \
+    break;
+    TF_CALL_ALL_TYPES(CASE);
+#undef CASE
+    default:
+      split_status = errors::InvalidArgument("Unsupported data type: ", type);
+      break;
+  }
+  return split_status;
+}
+
 // A class encapsulating the state and logic for batching tensors.
 class BatchResource : public ResourceBase {
  public:
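With the overload in place, a call site only needs the runtime tensor and the split sizes; the UnbatchResource hunk further down reduces to exactly this shape:

  std::vector<Tensor> split_inputs;
  TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));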
@@ -449,22 +487,9 @@ class BatchResource : public ResourceBase {
       }
     }
 
-    const DataType type = to_concatenate[0].dtype();
-    Status concat_status;
     Tensor concatenated_tensor;
-    switch (type) {
-#define CASE(type)                                                   \
-  case DataTypeToEnum<type>::value:                                  \
-    concat_status =                                                  \
-        Concat<type>(context, to_concatenate, &concatenated_tensor); \
-    break;
-      TF_CALL_ALL_TYPES(CASE);
-#undef CASE
-      default:
-        concat_status =
-            errors::InvalidArgument("Unsupported data type: ", type);
-        break;
-    }
+    Status concat_status =
+        Concat(context, to_concatenate, &concatenated_tensor);
     TF_RETURN_IF_ERROR(concat_status);
     concatenated_tensors->push_back(concatenated_tensor);
   }
@@ -1001,17 +1026,7 @@ class UnbatchResource : public ResourceBase {
         batch_keys.push_back(batch_indices(i, 0));
       }
 
-      const DataType type = data_t.dtype();
-      switch (type) {
-#define CASE(type)                                                          \
-  case DataTypeToEnum<type>::value:                                         \
-    TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
-    break;
-        TF_CALL_ALL_TYPES(CASE);
-#undef CASE
-        default:
-          return errors::InvalidArgument("Unsupported data type: ", type);
-      }
+      TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
     }
 
     // Critical section.