PyTorch
Loading...
Searching...
No Matches
Descriptors.h
Go to the documentation of this file.
1#pragma once
2
3#include <ATen/mkl/Exceptions.h>
4#include <mkl_dfti.h>
5#include <ATen/Tensor.h>
6
7namespace at { namespace native {
8
10 void operator()(DFTI_DESCRIPTOR* desc) {
11 if (desc != nullptr) {
12 MKL_DFTI_CHECK(DftiFreeDescriptor(&desc));
13 }
14 }
15};
16
18public:
19 void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) {
20 if (desc_ != nullptr) {
21 throw std::runtime_error("DFTI DESCRIPTOR can only be initialized once");
22 }
23 DFTI_DESCRIPTOR *raw_desc;
24 if (signal_ndim == 1) {
25 MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
26 } else {
27 MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, signal_ndim, sizes));
28 }
29 desc_.reset(raw_desc);
30 }
31
32 DFTI_DESCRIPTOR *get() const {
33 if (desc_ == nullptr) {
34 throw std::runtime_error("DFTI DESCRIPTOR has not been initialized");
35 }
36 return desc_.get();
37 }
38
39private:
40 std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
41};
42
43
44}} // at::native
Definition: Descriptors.h:17
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG *sizes)
Definition: Descriptors.h:19
DFTI_DESCRIPTOR * get() const
Definition: Descriptors.h:32
Definition: TensorBase.h:34
Definition: Descriptors.h:9
void operator()(DFTI_DESCRIPTOR *desc)
Definition: Descriptors.h:10