Back to Mnn

MNN: include/Tensor.hpp 源文件

doc/API/html/_tensor_8hpp_source.html

3.5.0 · 14.3 KB
Original Source

| MNN 1.0 |

Tensor.hpp

浏览该文件的文档.

1 //

2 // Tensor.hpp

3 // MNN

4 //

5 // Created by MNN on 2018/08/14.

6 // Copyright © 2018, Alibaba Group Holding Limited

7 //

8

9 #ifndef Tensor_hpp

10 #define Tensor_hpp

11

12 #include <vector>

13 #include "HalideRuntime.h"

14 #include "MNNDefine.h"

15

16 namespace MNN {

17

25 class MNN_PUBLIC Tensor {

26 public:

27struct InsideDescribe;

28

30enum DimensionType {

32TENSORFLOW,

34CAFFE,

36 CAFFE_C4

37 };

38

40enum HandleDataType {

42 HANDLE_NONE = 0,

44 HANDLE_STRING = 1

45 };

46

48enum DataReorderType {

50 NO_REORDER = 0,

52 REORDER_4 = 1,

54 REORDER_8

55 };

56

57 public:

63Tensor(int dimSize = 4, DimensionType type = CAFFE);

64

72Tensor(const Tensor* tensor, DimensionType type = CAFFE, bool allocMemory = true);

73

75 ~Tensor();

76

77 private:

78 // remove all copy/move constructors and assignment operators

79Tensor(const Tensor& tensor) = delete;

80Tensor(const Tensor&& tensor) = delete;

81Tensor& operator=(const Tensor&) = delete;

82Tensor& operator=(const Tensor&&) = delete;

83

84 public:

93 static Tensor* createDevice(const std::vector<int>& shape, [halide_type_t](structhalide__type__t.html) type, DimensionType dimType = TENSORFLOW);

94

102template <typename T>

103static Tensor* createDevice(const std::vector<int>& shape, DimensionType dimType = TENSORFLOW) {

104return createDevice(shape, halide_type_of<T>(), dimType);

105 }

106

115 static Tensor* create(const std::vector<int>& shape, [halide_type_t](structhalide__type__t.html) type, void* data = NULL,

116 DimensionType dimType = TENSORFLOW);

117

125template <typename T>

126static Tensor* create(const std::vector<int>& shape, void* data = NULL, DimensionType dimType = TENSORFLOW) {

127return create(shape, halide_type_of<T>(), data, dimType);

128 }

129

130 public:

136bool copyFromHostTensor(const Tensor* hostTensor);

137

143bool copyToHostTensor(Tensor* hostTensor) const;

144

151static Tensor* createHostTensorFromDevice(const Tensor* deviceTensor, bool copyData = true);

152

153 public:

154 const [halide_buffer_t](structhalide__buffer__t.html)& buffer() const {

155return mBuffer;

156 }

157 [halide_buffer_t](structhalide__buffer__t.html)& buffer() {

158return mBuffer;

159 }

160

165 DimensionType getDimensionType() const;

166

171 HandleDataType getHandleDataType() const;

172

177void setType(int type);

178

183 inline [halide_type_t](structhalide__type__t.html) getType() const {

184return mBuffer.type;

185 }

186

191template <typename T>

192 T* host() const {

193return (T*)mBuffer.host;

194 }

195

200 uint64_t deviceId() const {

201return mBuffer.device;

202 }

203

204 public:

205int dimensions() const {

206return mBuffer.dimensions;

207 }

208

213 std::vector<int> shape() const;

214

219int size() const;

220

225inline int elementSize() const {

226return size() / mBuffer.type.bytes();

227 }

228

229 public:

230 // for CAFFE tensors only (dimension order: batch, channel, height, width).

231inline int width() const {

232return mBuffer.dim[3].extent;

233 }

234inline int height() const {

235return mBuffer.dim[2].extent;

236 }

237inline int channel() const {

238return mBuffer.dim[1].extent;

239 }

240inline int batch() const {

241return mBuffer.dim[0].extent;

242 }

243

244 // for TENSORFLOW tensors only (dimension order: batch, height, width, channel).

245inline int tfWidth() const {

246return mBuffer.dim[2].extent;

247 }

248inline int tfHeight() const {

249return mBuffer.dim[1].extent;

250 }

251inline int tfChannel() const {

252return mBuffer.dim[3].extent;

253 }

254inline int tfBatch() const {

255return mBuffer.dim[0].extent;

256 }

257

258 // visit dimension's extent & stride

259inline int stride(int index) const {

260return mBuffer.dim[index].stride;

261 }

262inline int length(int index) const {

263return mBuffer.dim[index].extent;

264 }

265inline void setStride(int index, int stride) {

266 mBuffer.dim[index].stride = stride;

267 }

268inline void setLength(int index, int length) {

269 mBuffer.dim[index].extent = length;

270 }

271

272 public:

276void print() const;

277

278 private:

279 [halide_buffer_t](structhalide__buffer__t.html) mBuffer;

280struct InsideDescribe* mDescribe;

281

282 private:

283friend class TensorUtils;

284 };

285 } // namespace MNN

286

287 #endif /* Tensor_hpp */

MNN::Tensor::elementSize

int elementSize() const

calculate the number of elements needed to store the data, taking the reordering flag into account.

Definition: Tensor.hpp:225

MNN::Tensor::batch

int batch() const

Definition: Tensor.hpp:240

MNN::Tensor::createDevice

static Tensor * createDevice(const std::vector< int > &shape, DimensionType dimType=TENSORFLOW)

create tensor with shape and dimension type. data type is represented by T.

Definition: Tensor.hpp:103

MNN::Tensor::CAFFE

Definition: Tensor.hpp:34

MNN::Tensor::DataReorderType

DataReorderType

Definition: Tensor.hpp:48

MNN::Tensor::tfBatch

int tfBatch() const

Definition: Tensor.hpp:254

MNN::Tensor::TENSORFLOW

Definition: Tensor.hpp:32

MNN::Tensor::width

int width() const

Definition: Tensor.hpp:231

MNN::Tensor::buffer

halide_buffer_t & buffer()

Definition: Tensor.hpp:157

MNN::Tensor::length

int length(int index) const

Definition: Tensor.hpp:262

MNN::Tensor::setLength

void setLength(int index, int length)

Definition: Tensor.hpp:268

MNN::Tensor::height

int height() const

Definition: Tensor.hpp:234

MNN::Tensor::tfWidth

int tfWidth() const

Definition: Tensor.hpp:245

MNN::Tensor::stride

int stride(int index) const

Definition: Tensor.hpp:259

HalideRuntime.h

MNN::Tensor

Definition: Tensor.hpp:25

MNN::Tensor::host

T * host() const

visit host memory, data type is represented by T.

Definition: Tensor.hpp:192

MNN_PUBLIC

#define MNN_PUBLIC

Definition: MNNDefine.h:53

MNN::Tensor::buffer

const halide_buffer_t & buffer() const

Definition: Tensor.hpp:154

MNN::Tensor::DimensionType

DimensionType

Definition: Tensor.hpp:30

MNN::Tensor::tfChannel

int tfChannel() const

Definition: Tensor.hpp:251

[halide_type_t](structhalide__type__t.html)

Definition: HalideRuntime.h:82

MNN::Tensor::getType

halide_type_t getType() const

get data type.

Definition: Tensor.hpp:183

MNN

Definition: AutoTime.hpp:16

MNN::Tensor::setStride

void setStride(int index, int stride)

Definition: Tensor.hpp:265

MNN::Tensor::dimensions

int dimensions() const

Definition: Tensor.hpp:205

MNN::Tensor::channel

int channel() const

Definition: Tensor.hpp:237

MNN::Tensor::tfHeight

int tfHeight() const

Definition: Tensor.hpp:248

[halide_buffer_t](structhalide__buffer__t.html)

Definition: HalideRuntime.h:203

MNN::Tensor::deviceId

uint64_t deviceId() const

visit device memory.

Definition: Tensor.hpp:200

MNNDefine.h

MNN::Tensor::create

static Tensor * create(const std::vector< int > &shape, void *data=NULL, DimensionType dimType=TENSORFLOW)

create tensor with shape, data and dimension type. data type is represented by T.

Definition: Tensor.hpp:126

MNN::Tensor::HandleDataType

HandleDataType

Definition: Tensor.hpp:40


制作者 1.8.15