Descriptors.h
#pragma once

#include <string>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cuda.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif
namespace at { namespace native {

std::string cudnnTypeToString(cudnnDataType_t dtype);

// TODO: Add constructors for all of the descriptors

inline int dataSize(cudnnDataType_t dataType)
{
  switch (dataType) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
    case CUDNN_DATA_BFLOAT16:
#endif
    case CUDNN_DATA_HALF: return 2;
    case CUDNN_DATA_FLOAT: return 4;
    default: return 8;
  }
}

// The stride for a size-1 dimension is not uniquely determined; in
// fact, it can be anything you want, because the fact that the
// tensor is size 1 at this dimension means that you will never actually
// try advancing your pointer by this stride.
//
// However, CuDNN has a much more stringent requirement on strides:
// if you are passing a contiguous input, it better be the case
// that the stride for dim i is the product of the sizes of dims
// i+1 to the end. This stride is indeed uniquely determined. This
// function modifies 'stride' in place so this invariant holds.
template <typename T>
static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
  int64_t z = 1;
  int index = 0;
  std::vector<int> permutation(dim);

  if (nhwc) {
    permutation[index++] = 1;
  }
  for (int d = dim-1; d > 1; d--) {
    permutation[index++] = d;
  }
  if (!nhwc) {
    permutation[index++] = 1;
  }
  permutation[index++] = 0;
  for (int d : permutation) {
    if (size[d] == 1) {
      stride[d] = z;
    } else {
      z *= size[d];
    }
  }
}
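
// Worked example of the function above: with nhwc == false and a contiguous
// NCHW tensor of sizes [1, 3, 5, 5], the dims are visited in the order
// {3, 2, 1, 0}; z grows to 5, 25, 75, and the size-1 batch dimension has its
// stride forced to 75, matching the unique contiguous layout [75, 25, 5, 1].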

// unique_ptr deleter that destroys a cuDNN descriptor with the given
// destructor function, checking the returned status.
template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      AT_CUDNN_CHECK(dtor(x));
    }
  }
};

// A generic class for wrapping cuDNN descriptor types. All you need
// to give is the underlying type that the Descriptor_t points to (e.g.,
// cudnnTensorDescriptor_t points to cudnnTensorStruct), the constructor,
// and the destructor. Subclasses are responsible for defining a set()
// function that actually sets the descriptor.
//
// Descriptors default construct to a nullptr, and have a descriptor
// initialized the first time you call set() or any other initializing
// function.
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
class TORCH_CUDA_CPP_API Descriptor {
 public:
  // TODO: Figure out why const-correctness doesn't work here

  // Use desc() to access the underlying descriptor pointer in
  // a read-only fashion. Most client code should use this.
  // If the descriptor was never initialized, this will return
  // nullptr.
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }

  // Use mut_desc() to access the underlying descriptor pointer
  // if you intend to modify what it points to (e.g., using
  // cudnnSetFooDescriptor). This will ensure that the descriptor
  // is initialized. Code in this file will use this function.
  T* mut_desc() { init(); return desc_.get(); }
protected:
  void init() {
    if (desc_ == nullptr) {
      T* raw_desc;
      AT_CUDNN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
  }
private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
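
// Illustrative lazy-initialization behavior for any Descriptor subclass
// (FooDescriptor and cudnnCreateFooDescriptor stand in for a concrete type):
//   FooDescriptor d;   // holds nullptr; no cuDNN call has been made yet
//   d.desc();          // still nullptr
//   d.mut_desc();      // invokes the ctor (e.g. cudnnCreateFooDescriptor) once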

class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
                                  cudnnTensorStruct,
                                  &cudnnCreateTensorDescriptor,
                                  &cudnnDestroyTensorDescriptor> {
 public:
  TensorDescriptor() = default;
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    set(t, pad);
  }

  // Note [CuDNN broadcast padding]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // pad specifies the minimum dimensionality of the tensor descriptor
  // we produce (it doesn't have anything to do with, e.g., convolution
  // padding). If 't' has fewer dimensions than 'pad', the remaining
  // dimensions (on the right) are padded with ones. This doesn't
  // affect the underlying data layout. This is particularly useful for
  // dealing with a peculiarity of the CuDNN API: broadcasting in CuDNN
  // is done in two steps. First, the client code is expected to pad out
  // the dimensions of input tensors to match the dimensionality of the
  // broadcast target; then, CuDNN takes care of actually broadcasting
  // the size-1 dimensions.
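  //
  // For example, a 2-D tensor of sizes [2, 3] set with pad = 4 produces a
  // 4-D descriptor with sizes [2, 3, 1, 1]; the underlying data is untouched.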

  void set(const at::Tensor &t, size_t pad = 0);
  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

private:
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);

  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
    fixSizeOneDimStride<int>(dim, size, stride, nhwc);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride));
  }
};
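
// Illustrative use (t is assumed to be a CUDA tensor; the raw pointer is then
// handed to whichever cuDNN call needs a tensor descriptor):
//   TensorDescriptor desc(t, 4);   // describe t, padded to at least 4 dims
//   cudnnTensorDescriptor_t raw = desc.desc();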

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);

class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
                                  cudnnFilterStruct,
                                  &cudnnCreateFilterDescriptor,
                                  &cudnnDestroyFilterDescriptor> {
 public:
  void set(const at::Tensor &t, int64_t pad = 0) {
    set(t, at::MemoryFormat::Contiguous, pad);
  }

  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);

  void print();
private:
  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
  }
};

std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);

struct TORCH_CUDA_CPP_API ConvolutionDescriptor
  : public Descriptor<
        cudnnConvolutionStruct,
        &cudnnCreateConvolutionDescriptor,
        &cudnnDestroyConvolutionDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int* upscale /* aka dilation */, int groups, bool allow_tf32) {
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                                   CUDNN_CROSS_CORRELATION, mathType));
    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
    // See Note [behavior of cudnnFind and cudnnGet]
    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
    if (dataType == CUDNN_DATA_HALF) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
    } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
#endif
    }
  }
};
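
// Illustrative 2-D convolution setup (all values hypothetical; dim counts
// spatial dimensions only):
//   int pad[2] = {1, 1}, stride[2] = {1, 1}, dilation[2] = {1, 1};
//   ConvolutionDescriptor conv;
//   conv.set(CUDNN_DATA_FLOAT, 2, pad, stride, dilation, /*groups=*/1, /*allow_tf32=*/true);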

struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
  : public Descriptor<
        cudnnSpatialTransformerStruct,
        &cudnnCreateSpatialTransformerDescriptor,
        &cudnnDestroySpatialTransformerDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* size) {
    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
  }
};

struct TORCH_CUDA_CPP_API DropoutDescriptor
  : public Descriptor<
        cudnnDropoutStruct,
        &cudnnCreateDropoutDescriptor,
        &cudnnDestroyDropoutDescriptor> {
  at::Tensor state;

  // Initialize a dropout descriptor's RNG state.
  // WARNING: This function is very expensive, avoid calling this function!
  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    size_t state_size;
    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
    AT_ASSERT(options.device().type() == kCUDA);
    AT_ASSERT(options.dtype() == kByte);
    state = at::empty({static_cast<int64_t>(state_size)}, options);
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
  }

  // Restore a dropout descriptor given a dropout probability and existing RNG state.
  void set(cudnnHandle_t handle, float dropout, at::Tensor state_) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    state = state_;
    void *state_ptr = state.data_ptr();
    size_t state_size = state.size(0);
    // NB: The seed doesn't actually matter, so we give a dummy value
    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
  }

  // Restore a dropout descriptor corresponding to no dropout
  void set_no_dropout(cudnnHandle_t handle) {
    // NB: seed doesn't matter when dropout = 0, because no random number
    // initialization actually takes place when there is no dropout.
    // NB: Empirically, cudnnSetDropoutDescriptor is cheap when
    // dropout == 0
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */));
  }
};
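
// Illustrative lifecycle (handle is assumed to be an existing cudnnHandle_t and
// seed an arbitrary integer; options must describe a CUDA byte tensor, as the
// asserts above require):
//   DropoutDescriptor d;
//   d.initialize_rng(handle, 0.5f, seed,
//                    at::TensorOptions().device(at::kCUDA).dtype(at::kByte));
//   // Later, rebuild a descriptor from the saved RNG state without re-seeding:
//   DropoutDescriptor d2;
//   d2.set(handle, 0.5f, d.state);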

struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
                                cudnnRNNStruct,
                                &cudnnCreateRNNDescriptor,
                                &cudnnDestroyRNNDescriptor> {
  DropoutDescriptor dropout_desc_;
  void set(cudnnHandle_t handle, int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
    dropout_desc_ = std::move(dropout_desc);

    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
          handle,
          mut_desc(),
          hidden_size,
          num_layers,
          dropout_desc_.desc(),
          input_mode,
          bidirectional,
          mode,
          algo,
          datatype));
    if (proj_size != 0) {
      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
            handle,
            /*rnnDesc=*/mut_desc(),
            /*recProjSize=*/proj_size,
            /*outProjSize=*/0));
    }
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
      }
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
      else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
      }
#endif
      else {
        // Technically, as the default it's not necessary to explicitly
        // set this.
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
      }
    }
  }
};

struct TORCH_CUDA_CPP_API CTCLossDescriptor
  : public Descriptor<
        cudnnCTCLossStruct,
        &cudnnCreateCTCLossDescriptor,
        &cudnnDestroyCTCLossDescriptor> {
  void set(cudnnDataType_t datatype) {
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
  }
#if CUDNN_VERSION >= 7600
  void setEx(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode) {
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
  }
#endif
};

struct TORCH_CUDA_CPP_API ActivationDescriptor
  : public Descriptor<
        cudnnActivationStruct,
        &cudnnCreateActivationDescriptor,
        &cudnnDestroyActivationDescriptor> {
  void set(cudnnActivationMode_t mode) {
    AT_ASSERT(
        mode == CUDNN_ACTIVATION_RELU,
        "TODO: support more cuDNN activation modes");
    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
        mut_desc(),
        mode,
        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
        std::numeric_limits<double>::max()));
  }
};

union Constant
{
  float f;
  double d;
  Constant(cudnnDataType_t dataType, double value) {
    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
      f = static_cast<float>(value);
    } else {
      d = value;
    }
  }
};
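
// Illustrative use as cuDNN alpha/beta scaling factors, which cuDNN reads
// through a void* as float or double depending on the data type (the
// cudnnAddTensor call below is just an example of such an API):
//   Constant one(dataType, 1.0);
//   Constant zero(dataType, 0.0);
//   cudnnAddTensor(handle, &one, aDesc, a, &zero, cDesc, c);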

}} // namespace