#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at { namespace native {

std::string cudnnTypeToString(cudnnDataType_t dtype);
inline int dataSize(cudnnDataType_t dataType) {
  switch (dataType) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
    case CUDNN_DATA_BFLOAT16:
#endif
    case CUDNN_DATA_HALF: return 2;
    case CUDNN_DATA_FLOAT: return 4;
    default: return 8;
  }
}
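// Usage sketch (illustrative only, not part of the original header):
// dataSize() reports the element width in bytes for a cuDNN data type,
// e.g. when sizing a scratch buffer by hand.
//
//   size_t bytes = numel * dataSize(CUDNN_DATA_HALF);  // 2 bytes per element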
// cuDNN expects contiguous-style strides even for size-1 dimensions (whose
// strides are otherwise arbitrary), so rewrite those strides in place.
template <typename T>
static inline void fixSizeOneDimStride(int dim, const T* size, T* stride, bool nhwc) {
  int64_t z = 1;
  int index = 0;
  std::vector<int> permutation(dim);
  if (nhwc) {
    permutation[index++] = 1;
  }
  for (int d = dim - 1; d > 1; d--) {
    permutation[index++] = d;
  }
  if (!nhwc) {
    permutation[index++] = 1;
  }
  permutation[index++] = 0;
  for (int d : permutation) {
    const auto size_d = size[d];
    if (size_d == 1) {
      stride[d] = z;
    } else {
      z *= size_d;
    }
  }
}
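// Example (sketch, not part of the original header): a contiguous NCHW tensor
// of size {N, 1, H, W} has an arbitrary stride for the size-1 channel dim;
// after fixSizeOneDimStride the strides satisfy cuDNN's contiguity check.
//
//   int size[4]   = {8, 1, 32, 32};
//   int stride[4] = {1024, 9999, 32, 1};            // stride[1] is arbitrary
//   fixSizeOneDimStride<int>(4, size, stride, /*nhwc=*/false);
//   // stride is now {1024, 1024, 32, 1}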
template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      AT_CUDNN_CHECK(dtor(x));
    }
  }
};

// A generic wrapper for cuDNN descriptor types.  Instantiate it with the
// struct the descriptor handle points to plus its create/destroy functions;
// subclasses are responsible for defining set() to fill in the descriptor.
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
class TORCH_CUDA_CPP_API Descriptor {
 public:
  // desc() gives read-only access; it returns nullptr if the descriptor
  // was never initialized.
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }

  // mut_desc() initializes the descriptor on first use; use it when the
  // underlying descriptor is about to be modified.
  T* mut_desc() { init(); return desc_.get(); }
 protected:
  void init() {
    if (desc_ == nullptr) {
      T* raw_desc;
      AT_CUDNN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
  }
 private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
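// Usage pattern (sketch): subclasses expose a set() overload and are passed
// to cuDNN calls via desc().  A hypothetical FooDescriptor wrapping
// cudnnFooStruct would be used roughly as:
//
//   FooDescriptor fdesc;
//   fdesc.set(...);                       // calls cudnnSetFoo* on mut_desc()
//   someCudnnCall(handle, fdesc.desc());  // read-only access afterwards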
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
                                                cudnnTensorStruct,
                                                &cudnnCreateTensorDescriptor,
                                                &cudnnDestroyTensorDescriptor> {
 public:
  TensorDescriptor() = default;
  explicit TensorDescriptor(const at::Tensor& t, size_t pad = 0) {
    set(t, pad);
  }

  void set(const at::Tensor& t, size_t pad = 0);
  void set(const at::Tensor& t, at::MemoryFormat memory_format, size_t pad = 0);
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

 private:
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);

  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
    fixSizeOneDimStride<int>(dim, size, stride, nhwc);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride));
  }
};

std::ostream& operator<<(std::ostream& out, const TensorDescriptor& d);
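// Usage sketch (illustrative): describe an existing CUDA tensor for cuDNN,
// padding the descriptor to 4 dims as most convolution routines expect.
//
//   at::Tensor x = at::empty({8, 3, 32, 32},
//       at::TensorOptions().device(at::kCUDA).dtype(at::kFloat));
//   TensorDescriptor xdesc(x, /*pad=*/4);
//   // xdesc.desc() can now be passed wherever a cudnnTensorDescriptor_t
//   // is expected.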
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
                                                cudnnFilterStruct,
                                                &cudnnCreateFilterDescriptor,
                                                &cudnnDestroyFilterDescriptor> {
 public:
  void set(const at::Tensor& t, int64_t pad = 0) {
    set(t, at::MemoryFormat::Contiguous, pad);
  }

  void set(const at::Tensor& t, const at::MemoryFormat memory_format, int64_t pad = 0);

 private:
  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
  }
};
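// Usage sketch (illustrative): convolution weights are described with
// FilterDescriptor rather than TensorDescriptor; the memory-format overload
// selects the NCHW vs NHWC filter layout.
//
//   at::Tensor w = at::empty({16, 3, 3, 3},
//       at::TensorOptions().device(at::kCUDA).dtype(at::kFloat));
//   FilterDescriptor wdesc;
//   wdesc.set(w, /*pad=*/4);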
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
  : public Descriptor<cudnnConvolutionStruct,
                      &cudnnCreateConvolutionDescriptor,
                      &cudnnDestroyConvolutionDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride,
           int* upscale /* aka dilation */, int groups, bool allow_tf32) {
    // Half-precision convolutions accumulate in float.
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                                   CUDNN_CROSS_CORRELATION, mathType));
    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
    if (dataType == CUDNN_DATA_HALF) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
    } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
#endif
    }
  }
};
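// Usage sketch (illustrative): set up a 2-D convolution descriptor.  The dim
// argument counts spatial dimensions only, and pad/stride/dilation have one
// entry per spatial dim.
//
//   int padA[2] = {1, 1}, strideA[2] = {1, 1}, dilationA[2] = {1, 1};
//   ConvolutionDescriptor cdesc;
//   cdesc.set(CUDNN_DATA_FLOAT, /*dim=*/2, padA, strideA, dilationA,
//             /*groups=*/1, /*allow_tf32=*/true);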
struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
  : public Descriptor<cudnnSpatialTransformerStruct,
                      &cudnnCreateSpatialTransformerDescriptor,
                      &cudnnDestroySpatialTransformerDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* size) {
    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(
        mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
  }
};
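// Usage sketch (illustrative): the sampler is always bilinear; size describes
// the output tensor, e.g. {N, C, H, W} for a 4-D grid sample.
//
//   int sizeA[4] = {8, 3, 32, 32};
//   SpatialTransformerDescriptor stdesc;
//   stdesc.set(CUDNN_DATA_FLOAT, /*dim=*/4, sizeA);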
struct TORCH_CUDA_CPP_API DropoutDescriptor
  : public Descriptor<cudnnDropoutStruct,
                      &cudnnCreateDropoutDescriptor,
                      &cudnnDestroyDropoutDescriptor> {
  at::Tensor state;

  // Initialize a dropout descriptor's RNG state.
  // WARNING: This function is very expensive, avoid calling this function!
  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    size_t state_size;
    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
    AT_ASSERT(options.device().type() == kCUDA);
    AT_ASSERT(options.dtype() == kByte);
    state = at::empty({static_cast<int64_t>(state_size)}, options);
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
  }

  // Restore a dropout descriptor given a dropout probability and existing RNG state.
  void set(cudnnHandle_t handle, float dropout, at::Tensor state_) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    state = state_;
    void* state_ptr = state.data_ptr();
    size_t state_size = state.size(0);
    // NB: The seed does not matter when restoring from an existing state.
    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
  }

  // Restore a dropout descriptor corresponding to no dropout.
  void set_no_dropout(cudnnHandle_t handle) {
    // NB: The seed does not matter when dropout = 0, and cudnnSetDropoutDescriptor
    // is cheap in that case because no RNG initialization takes place.
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */));
  }
};
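// Usage sketch (illustrative; assumes a live cuDNN handle, e.g. from
// at::native::getCudnnHandle() in ATen/cudnn/Handle.h):
//
//   DropoutDescriptor ddesc;
//   auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kByte);
//   ddesc.initialize_rng(handle, /*dropout=*/0.5f, /*seed=*/42, opts);  // expensive
//   // Later, reuse the saved RNG state instead of re-initializing:
//   DropoutDescriptor ddesc2;
//   ddesc2.set(handle, /*dropout=*/0.5f, ddesc.state);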
struct TORCH_CUDA_CPP_API RNNDescriptor
  : public Descriptor<cudnnRNNStruct,
                      &cudnnCreateRNNDescriptor,
                      &cudnnDestroyRNNDescriptor> {
  DropoutDescriptor dropout_desc_;

  void set(cudnnHandle_t handle, int hidden_size, int proj_size, int num_layers,
           DropoutDescriptor&& dropout_desc,
           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type,
           cudnnRNNAlgo_t algo, bool allow_tf32) {
    dropout_desc_ = std::move(dropout_desc);
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
        handle, mut_desc(),
        hidden_size, num_layers,
        dropout_desc_.desc(),
        input_mode, bidirectional, mode, algo, datatype));
    if (proj_size != 0) {
      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
          handle, mut_desc(),
          /*recProjSize=*/proj_size, /*outProjSize=*/0));
    }
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
      }
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
      else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
      }
#endif
      else {
        // Technically, as the default, this does not need to be set explicitly.
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
      }
    }
  }
};
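// Note (sketch): callers hand ownership of the DropoutDescriptor to the RNN
// descriptor via std::move, so the dropout state tensor stays alive as long
// as the RNN descriptor does.  A minimal call, assuming a live handle:
//
//   RNNDescriptor rdesc;
//   rdesc.set(handle, /*hidden_size=*/256, /*proj_size=*/0, /*num_layers=*/2,
//             std::move(ddesc), CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL,
//             CUDNN_LSTM, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT,
//             CUDNN_RNN_ALGO_STANDARD, /*allow_tf32=*/true);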
struct TORCH_CUDA_CPP_API CTCLossDescriptor
  : public Descriptor<cudnnCTCLossStruct,
                      &cudnnCreateCTCLossDescriptor,
                      &cudnnDestroyCTCLossDescriptor> {
  void set(cudnnDataType_t datatype) {
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
  }
#if CUDNN_VERSION >= 7600
  void setEx(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode) {
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
  }
#endif
};
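// Usage sketch (illustrative): the extended setter selects loss normalization
// and NaN-propagation behaviour on cuDNN >= 7.6.
//
//   CTCLossDescriptor ctc;
//   ctc.setEx(CUDNN_DATA_FLOAT, CUDNN_LOSS_NORMALIZATION_SOFTMAX,
//             CUDNN_NOT_PROPAGATE_NAN);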
struct TORCH_CUDA_CPP_API ActivationDescriptor
  : public Descriptor<cudnnActivationStruct,
                      &cudnnCreateActivationDescriptor,
                      &cudnnDestroyActivationDescriptor> {
  void set(cudnnActivationMode_t mode) {
    AT_ASSERT(
        mode == CUDNN_ACTIVATION_RELU,
        "TODO: support more cuDNN activation modes");
    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
        mut_desc(),
        mode,
        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
        std::numeric_limits<double>::max()));
  }
};
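// Usage sketch (illustrative): only ReLU is supported by this wrapper; the
// coefficient passed to cuDNN is the ReLU clipping threshold, left effectively
// unbounded here.
//
//   ActivationDescriptor adesc;
//   adesc.set(CUDNN_ACTIVATION_RELU);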
// cuDNN scaling factors (alpha/beta) are passed as float for half/float data
// and as double otherwise; Constant picks the right member for a given type.
union Constant {
  float f;
  double d;
  Constant(cudnnDataType_t dataType, double value) {
    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
      f = static_cast<float>(value);
    } else {
      d = value;
    }
  }
};

}}  // namespace at::native
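// Usage sketch (illustrative, trailing note): Constant is how alpha/beta
// scaling factors are typically built before a cuDNN call, e.g.
//
//   at::native::Constant one(CUDNN_DATA_FLOAT, 1.0);
//   at::native::Constant zero(CUDNN_DATA_FLOAT, 0.0);
//   // &one and &zero can then be passed as the alpha/beta pointers of a
//   // cudnnConvolutionForward-style call.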