// FLImaging 6.5.16.1
// ComputationalGraphTransAtrousConv2D.h
#pragma once

// This header only declares anything when compiled with MSVC 2015 (_MSC_VER 1900) or
// newer targeting x64; on every other compiler/architecture it is empty.
#if _MSC_VER >= 1900 && defined(_M_X64)

#include "ComputationalGraph.h"
#include "Tensor.h"

#include <cstdint>
#include <vector>

namespace FLImaging
{
	namespace AI
	{
#ifdef CUDNN_MODE
		// Forward declaration of the cuDNN helper type held by m_pCudnn.
		// NOTE(review): this node references the *TransConv2D* cuDNN helper rather than a
		// TransAtrousConv2D-specific one; presumably intentional (dilation handled by the
		// same helper) — confirm against the implementation.
		template<typename T>
		class CCuda_ComputationalGraphTransConv2D_Cudnn;
#endif

		/**
		 * @brief Computational-graph node for a 2D transposed atrous (dilated) convolution.
		 *
		 * Derives from CComputationalGraph<T>; Forward/Backward/Clone and the
		 * serialization / memory-estimation hooks below override that base.
		 *
		 * NOTE: this class is exported (FL_EXPORT). Do not reorder virtual functions or data
		 * members — doing so changes the vtable/object layout seen by prebuilt binaries.
		 */
		template <typename T>
		class FL_EXPORT CComputationalGraphTransAtrousConv2D : public CComputationalGraph<T>
		{
		private:
			// Default construction is not available to users.
			CComputationalGraphTransAtrousConv2D();

		protected:
			// Copy constructor; not public. Presumably used by Clone()/duplication — TODO confirm.
			CComputationalGraphTransAtrousConv2D(const CComputationalGraphTransAtrousConv2D<T>& cg);

		public:
			/**
			 * @brief Creates a transposed atrous convolution of @p cbOperand1 with @p tsrKernel.
			 *
			 * Operand1 layout:
			 *   4th dim : Batch size
			 *   3rd dim : Channels
			 *   2nd dim : Height
			 *   1st dim : Width
			 *
			 * Kernel layout:
			 *   4th dim : Num of kernels
			 *   3rd dim : Channels
			 *   2nd dim : Height
			 *   1st dim : Width
			 *
			 * NOTE(review): a previous version of this comment also described a Bias tensor
			 * (1st dim : Num of kernels), but this constructor takes no bias argument, so that
			 * entry was stale. Confirm whether a bias is applied elsewhere in the graph.
			 *
			 * @param cbOperand1    Input node.
			 * @param tsrKernel     Kernel tensor.
			 * @param i64DilationX  Dilation along X (default 1).
			 * @param i64DilationY  Dilation along Y (default 1).
			 * @param i64StrideX    Stride along X (default 1).
			 * @param i64StrideY    Stride along Y (default 1).
			 * @param i64PaddingX   Padding along X (default 0).
			 * @param i64PaddingY   Padding along Y (default 0).
			 */
			CComputationalGraphTransAtrousConv2D(const CComputationalBase<T>& cbOperand1, const CTensor<T>& tsrKernel, int64_t i64DilationX = 1, int64_t i64DilationY = 1, int64_t i64StrideX = 1, int64_t i64StrideY = 1, int64_t i64PaddingX = 0, int64_t i64PaddingY = 0);
			virtual ~CComputationalGraphTransAtrousConv2D();

			// Graph evaluation hooks (overrides of CComputationalGraph<T>).
			virtual CTensor<T>& Forward() override;
			virtual CTensor<T>* Backward() override;
			virtual CComputationalBase<T>* Clone() const override;
			virtual const CResult PrintNodeParamInfo() const override;

			// Serialization to/from Base::CFLData (reference and pointer overloads).
			virtual const CResult GetBinaryData(Base::CFLData& fldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
			virtual const CResult GetBinaryData(Base::CFLData* pFldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;

			virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr) override;
			virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr) override;

			// Output-shape estimation; bRecursive presumably controls whether inputs are re-estimated first.
			virtual const std::vector<int64_t>& GetEstimatedShape(bool bRecursive = true) const override;

			// Memory-requirement estimation hooks (overrides of the base class).
			virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0) const override;
			virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const override;

			// Project macros; presumably declare class-type reflection and in-place duplication
			// support — see their definitions for the exact members they add.
			DeclareGetClassType();
			SupportToDuplicateObjectWithoutCreateNewObject(CComputationalGraphTransAtrousConv2D, *this);

		protected:
			// Core kernels: a generic path and a cuDNN path for the forward op and each gradient.
			virtual const CResult TransAtrousConvolve2D(CTensor<T>* pTsrOperand, CTensor<T>* pTsrKernel, CTensor<T>* pTsrResult);
			virtual const CResult TransAtrousConvolve2D_CUDNN(CTensor<T>* pTsrOperand, CTensor<T>* pTsrKernel, CTensor<T>* pTsrResult);
			// Gradient w.r.t. the input image. bAddGradient presumably selects accumulate vs.
			// overwrite in the cuDNN variant — TODO confirm.
			virtual const CResult DerivativeImage();
			virtual const CResult DerivativeImage_CUDNN(bool bAddGradient);
			// Gradient w.r.t. the kernel.
			virtual const CResult DerivativeKernel();
			virtual const CResult DerivativeKernel_CUDNN(bool bAddGradient);

			// Convolution geometry as passed to the constructor.
			int64_t m_i64StrideX;
			int64_t m_i64StrideY;
			int64_t m_i64PaddingX;
			int64_t m_i64PaddingY;
			int64_t m_i64DilationX;
			int64_t m_i64DilationY;

			// Scratch tensors; presumably hold transposed copies used by the non-cuDNN path — TODO confirm.
			CTensor<T> m_tsrInputTranspose;
			CTensor<T> m_tsrKernelTranspose;
			CTensor<T> m_tsrDerivativeTranspose;

#ifdef CUDNN_MODE
			// cuDNN helper; ownership/lifetime is managed in the implementation file (not visible here).
			CCuda_ComputationalGraphTransConv2D_Cudnn<T>* m_pCudnn;
#endif

		public:
			// Project macro; presumably declares a singleton accessor for this class type.
			DeclareGetSignletonObject(CComputationalGraphTransAtrousConv2D);
		};

		// Convenience factories. Each allocates a node with `new` and yields a reference to it.
		// NOTE(review): nothing here deletes the node — presumably the graph framework takes
		// ownership. Confirm before using these outside graph construction.
#define CCGFTransAtrousConv2D(...) (*(new CComputationalGraphTransAtrousConv2D<float>(__VA_ARGS__)))
#define CCGDTransAtrousConv2D(...) (*(new CComputationalGraphTransAtrousConv2D<double>(__VA_ARGS__)))

#define CCGTTransAtrousConv2D(T, ...) (*(new CComputationalGraphTransAtrousConv2D<T>(__VA_ARGS__)))
	}
}

#endif // _MSC_VER >= 1900 && defined(_M_X64)