3#include <ATen/mkl/Exceptions.h>
5#include <ATen/Tensor.h>
7namespace at {
namespace native {
11 if (desc !=
nullptr) {
12 MKL_DFTI_CHECK(DftiFreeDescriptor(&desc));
19 void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) {
20 if (desc_ !=
nullptr) {
21 throw std::runtime_error(
"DFTI DESCRIPTOR can only be initialized once");
23 DFTI_DESCRIPTOR *raw_desc;
24 if (signal_ndim == 1) {
25 MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
27 MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, signal_ndim, sizes));
29 desc_.reset(raw_desc);
32 DFTI_DESCRIPTOR *
get()
const {
33 if (desc_ ==
nullptr) {
34 throw std::runtime_error(
"DFTI DESCRIPTOR has not been initialized");
40 std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
Definition: Descriptors.h:17
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG *sizes)
Definition: Descriptors.h:19
DFTI_DESCRIPTOR * get() const
Definition: Descriptors.h:32
Definition: TensorBase.h:34
Definition: Descriptors.h:9
void operator()(DFTI_DESCRIPTOR *desc)
Definition: Descriptors.h:10