PyTorch
Loading...
Searching...
No Matches
Utils.h
Go to the documentation of this file.
1#pragma once
2
3#include <ATen/core/Tensor.h>
4#include <ATen/cuda/Exceptions.h>
5#include <ATen/cudnn/cudnn-wrapper.h>
6#include <ATen/cudnn/Handle.h>
7
8namespace at { namespace native {
9
10// cuDNN has a buggy check for tensor being contiguous (that is, it does
11// not ignore stride for dimension that is equal to 0). This function
12// makes tensors which have zero stride contiguous, by setting the
13// strides to 1 as cuDNN likes.
15 for (auto s : t.strides()) {
16 if (s == 0) return t.contiguous();
17 }
18 return t;
19}
20
21}}
IntArrayRef strides() const
Definition: TensorBase.h:240
Definition: TensorBody.h:90
Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const
Definition: TensorBody.h:120
Tensor contiguousIfZeroInStrides(const Tensor &t)
Definition: Utils.h:14
Definition: TensorBase.h:34
at::Tensor t(const at::Tensor &self)
Definition: Functions.h:7681