6#include <ATen/cuda/Exceptions.h> 
    8#include <ATen/cudnn/cudnn-wrapper.h> 
   11#include <ATen/TensorUtils.h> 
   12#include <ATen/cuda/ATenCUDAGeneral.h> 
   15#ifndef AT_PER_OPERATOR_HEADERS 
   18#include <ATen/ops/empty.h> 
   21namespace at { 
namespace native {
 
   23std::string cudnnTypeToString(cudnnDataType_t dtype);
 
   30#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200 
   31    case CUDNN_DATA_BFLOAT16:
 
   33    case CUDNN_DATA_HALF: 
return 2;
 
   34    case CUDNN_DATA_FLOAT: 
return 4;
 
   53  std::vector<int> permutation(dim);
 
   56    permutation[
index++] = 1;
 
   58  for (
int d = dim-1; d > 1; d--) {
 
   59    permutation[
index++] = d;
 
   62    permutation[
index++] = 1;
 
   64  permutation[
index++] = 0;
 
   65  for (
int d : permutation) {
 
   74template <
typename T, cudnnStatus_t (*dtor)(T*)>
 
   78      AT_CUDNN_CHECK(dtor(x));
 
   92template <
typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
 
  101  T* 
desc()
 const { 
return desc_.get(); }
 
  102  T* 
desc() { 
return desc_.get(); }
 
  111    if (desc_ == 
nullptr) {
 
  113      AT_CUDNN_CHECK(ctor(&raw_desc));
 
  114      desc_.reset(raw_desc);
 
  118  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
 
  123                                               &cudnnCreateTensorDescriptor,
 
  124                                               &cudnnDestroyTensorDescriptor> {
 
  145  void set(
const at::Tensor &t, at::MemoryFormat memory_format, 
size_t pad = 0);
 
  146  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, 
size_t pad = 0);
 
  151  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, 
size_t pad, 
bool nhwc);
 
  153  void set(cudnnDataType_t dataType, 
int dim, 
int* size, 
int* stride, 
bool nhwc) {
 
  154    fixSizeOneDimStride<int>(dim, size, stride, nhwc);
 
  155    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride));
 
  163                                               &cudnnCreateFilterDescriptor,
 
  164                                               &cudnnDestroyFilterDescriptor> {
 
  167    set(
t, at::MemoryFormat::Contiguous, 
pad);
 
  170  void set(
const at::Tensor &t, 
const at::MemoryFormat memory_format, int64_t pad = 0);
 
  174  void set(cudnnDataType_t dataType, 
int dim, 
int* size, cudnnTensorFormat_t filter_format) {
 
  175    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
 
  183          cudnnConvolutionStruct,
 
  184          &cudnnCreateConvolutionDescriptor,
 
  185          &cudnnDestroyConvolutionDescriptor> {
 
  186  void set(cudnnDataType_t dataType, 
int dim, 
int* pad, 
int* stride, 
int * upscale , 
int groups, 
bool allow_tf32) {
 
  187    cudnnDataType_t mathType = dataType;
 
  188    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
 
  189    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, 
pad, 
stride, upscale,
 
  190                                          CUDNN_CROSS_CORRELATION, mathType));
 
  191    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
 
  193    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
 
  194    if(dataType == CUDNN_DATA_HALF) {
 
  195      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
 
  196    } 
else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
 
  197#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 
  198      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
 
  206          cudnnSpatialTransformerStruct,
 
  207          &cudnnCreateSpatialTransformerDescriptor,
 
  208          &cudnnDestroySpatialTransformerDescriptor> {
 
  209  void set(cudnnDataType_t dataType, 
int dim, 
int* size) {
 
  210    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
 
  217          &cudnnCreateDropoutDescriptor,
 
  218          &cudnnDestroyDropoutDescriptor> {
 
  223  void initialize_rng(cudnnHandle_t handle, 
float dropout, 
long long int seed, 
const TensorOptions& options) {
 
  226    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
 
  227    AT_ASSERT(options.device().type() == kCUDA);
 
  229    state = 
at::empty({
static_cast<int64_t
>(state_size)}, options);
 
  230    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 
dropout, state.
data_ptr(), state_size, seed));
 
  238    size_t state_size = state.
size(0);
 
  240    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, 
dropout, state_ptr, state_size, 0 ));
 
  249    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 , 
nullptr, 0 , 0 ));
 
  255                                             &cudnnCreateRNNDescriptor,
 
  256                                             &cudnnDestroyRNNDescriptor> {
 
  258  void set(cudnnHandle_t handle, 
int hidden_size, 
int proj_size, 
int num_layers, 
DropoutDescriptor&& dropout_desc,
 
  259           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
 
  260           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, 
bool allow_tf32) {
 
  261    dropout_desc_ = std::move(dropout_desc);
 
  263    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
 
  268          dropout_desc_.desc(),
 
  274    if (proj_size != 0) {
 
  275      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
 
  282    if (prop->major >= 7) {
 
  283      if (input_type == CUDNN_DATA_HALF) {
 
  284        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
 
  286#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 
  287      else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
 
  288        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
 
  294        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
 
  303          &cudnnCreateCTCLossDescriptor,
 
  304          &cudnnDestroyCTCLossDescriptor> {
 
  305  void set(cudnnDataType_t datatype) {
 
  306    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
 
  308#if CUDNN_VERSION >= 7600 
  310      cudnnDataType_t datatype,
 
  311      cudnnLossNormalizationMode_t normMode,
 
  312      cudnnNanPropagation_t gradMode) {
 
  314        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
 
  321          cudnnActivationStruct,
 
  322          &cudnnCreateActivationDescriptor,
 
  323          &cudnnDestroyActivationDescriptor> {
 
  324  void set(cudnnActivationMode_t mode) {
 
  326        mode == CUDNN_ACTIVATION_RELU,
 
  327        "TODO: support more cuDNN activation modes");
 
  328    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
 
  331        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
 
  332        std::numeric_limits<double>::max()));
 
  341    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
 
  342      f = 
static_cast<float>(value);
 
#define TORCH_INTERNAL_ASSERT(cond,...)
Definition: Exception.h:377
 
#define AT_ASSERT(...)
Definition: Exception.h:654
 
void * data_ptr() const
Definition: TensorBase.h:543
 
Definition: TensorBody.h:90
 
int64_t size(at::Dimname dim) const
Definition: TensorBody.h:3350
 
Definition: Descriptors.h:93
 
T * desc()
Definition: Descriptors.h:102
 
T * mut_desc()
Definition: Descriptors.h:108
 
void init()
Definition: Descriptors.h:110
 
T * desc() const
Definition: Descriptors.h:101
 
Definition: Descriptors.h:164
 
void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad=0)
 
void set(const at::Tensor &t, int64_t pad=0)
Definition: Descriptors.h:166
 
Definition: Descriptors.h:124
 
void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad=0)
 
TensorDescriptor()=default
 
void set(const at::Tensor &t, size_t pad=0)
 
void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad=0)
 
TensorDescriptor(const at::Tensor &t, size_t pad=0)
Definition: Descriptors.h:127
 
cudaDeviceProp * getCurrentDeviceProperties()
 
std::ostream & operator<<(std::ostream &out, const TensorDescriptor &d)
 
static void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc)
Definition: Descriptors.h:50
 
int dataSize(cudnnDataType_t dataType)
Definition: Descriptors.h:27
 
Definition: TensorBase.h:34
 
at::Tensor pad(const at::Tensor &self, at::IntArrayRef pad, c10::string_view mode="constant", c10::optional< double > value=c10::nullopt)
Definition: Functions.h:14477
 
at::Tensor dropout(const at::Tensor &input, double p, bool train)
Definition: Functions.h:313
 
inline ::std::tuple< at::Tensor, at::Tensor > mode(const at::Tensor &self, int64_t dim=-1, bool keepdim=false)
Definition: Functions.h:4848
 
int64_t stride(const at::Tensor &self, at::Dimname dim)
Definition: Functions.h:7467
 
at::Tensor set(const at::Tensor &self, at::Storage source)
Definition: Functions.h:23868
 
int64_t size(const at::Tensor &self, at::Dimname dim)
Definition: Functions.h:7031
 
at::Tensor empty(at::IntArrayRef size, c10::optional< at::DimnameList > names, at::TensorOptions options={}, c10::optional< at::MemoryFormat > memory_format=c10::nullopt)
Definition: Functions.h:2592
 
at::Tensor index(const at::Tensor &self, const c10::List< c10::optional< at::Tensor > > &indices)
Definition: Functions.h:3622
 
at::Tensor t(const at::Tensor &self)
Definition: Functions.h:7681
 
Definition: Descriptors.h:323
 
void set(cudnnActivationMode_t mode)
Definition: Descriptors.h:324
 
Definition: Descriptors.h:304
 
void set(cudnnDataType_t datatype)
Definition: Descriptors.h:305
 
Definition: Descriptors.h:185
 
void set(cudnnDataType_t dataType, int dim, int *pad, int *stride, int *upscale, int groups, bool allow_tf32)
Definition: Descriptors.h:186
 
Definition: Descriptors.h:75
 
void operator()(T *x)
Definition: Descriptors.h:76
 
Definition: Descriptors.h:218
 
at::Tensor state
Definition: Descriptors.h:219
 
void set_no_dropout(cudnnHandle_t handle)
Definition: Descriptors.h:244
 
void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions &options)
Definition: Descriptors.h:223
 
void set(cudnnHandle_t handle, float dropout, at::Tensor state_)
Definition: Descriptors.h:234
 
Definition: Descriptors.h:256
 
DropoutDescriptor dropout_desc_
Definition: Descriptors.h:257
 
void set(cudnnHandle_t handle, int hidden_size, int proj_size, int num_layers, DropoutDescriptor &&dropout_desc, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional, cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32)
Definition: Descriptors.h:258
 
Definition: Descriptors.h:337
 
double d
Definition: Descriptors.h:339
 
float f
Definition: Descriptors.h:338
 
Constant(cudnnDataType_t dataType, double value)
Definition: Descriptors.h:340