FLImaging 6.5.16.1
ComputationalGraphTransConv2D.h
1#pragma once
2
3#if _MSC_VER >= 1900 && defined(_M_X64)
4
5#include "ComputationalGraph.h"
6#include "Tensor.h"
7#include "BackendConv2D.h"
8
9#include <vector>
10
11namespace FLImaging
12{
13 namespace AI
14 {
#ifdef CUDNN_MODE
	// cuDNN-specific companion type. Only a forward declaration is needed here
	// because CComputationalGraphTransConv2D refers to it through a pointer
	// member (m_pCudnn); the full definition lives elsewhere.
	template <class T>
	class CCuda_ComputationalGraphTransConv2D_Cudnn;
#endif
19
20 template <typename T>
21 class FL_EXPORT CComputationalGraphTransConv2D : public CComputationalGraph<T>
22 {
23 private:
24 CComputationalGraphTransConv2D();
25
26 protected:
27 CComputationalGraphTransConv2D(const CComputationalGraphTransConv2D<T>& cg);
28
29 public:
30 // Operand1
31 // 4th dim : Batch size
32 // 3rd dim : Channels
33 // 2nd dim : Height
34 // 1st dim : Width
35
36 // Kernel
37 // 4th dim : Num of kernels
38 // 3rd dim : Channels
39 // 2nd dim : Height
40 // 1st dim : Width
41
42 // Bias
43 // 1st dim : Num of kernels
44 CComputationalGraphTransConv2D(const CComputationalBase<T>& cbOperand1, const CTensor<T>& tsrKernel, int64_t i64StrideX = 1, int64_t i64StrideY = 1, int64_t i64PaddingX = 0, int64_t i64PaddingY = 0, int64_t i64OutputPaddingX = 0, int64_t i64OutputPaddingY = 0, int64_t i64DilationX = 1, int64_t i64DilationY = 1);
45 virtual ~CComputationalGraphTransConv2D();
46
47 virtual CTensor<T>& Forward() override;
48 virtual CTensor<T>* Backward() override;
49 virtual CComputationalBase<T>* Clone() const override;
50 virtual const CResult PrintNodeParamInfo() const override;
51
52 virtual const CResult GetBinaryData(Base::CFLData& fldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
53 virtual const CResult GetBinaryData(Base::CFLData* pFldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
54
55 virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr) override;
56 virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr) override;
57
58 virtual const std::vector<int64_t>& GetEstimatedShape(bool bRecursive = true) const override;
59
60 virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0) const override;
61 virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const override;
62
63
64 DeclareGetClassType();
65 SupportToDuplicateObjectWithoutCreateNewObject(CComputationalGraphTransConv2D, *this);
66
67 protected:
68 /*
69 virtual const CResult DerivativeKernel(bool bAddGradient);
70 virtual const CResult DerivativeImage(bool bAddGradient);
71 */
72 int64_t m_i64StrideX;
73 int64_t m_i64StrideY;
74 int64_t m_i64PaddingX;
75 int64_t m_i64PaddingY;
76 int64_t m_i64OutputPaddingX;
77 int64_t m_i64OutputPaddingY;
78 int64_t m_i64DilationX;
79 int64_t m_i64DilationY;
80
81 CTensor<T> m_tsrPaddingX;
82 CTensor<T> m_tsrPaddingY;
83 CTensor<T> m_tsrInputTranspose;
84 CTensor<T> m_tsrKernelTranspose;
85 CTensor<T> m_tsrDerivativeTranspose;
86
87 CBackendConv2D<T> m_backendConv2D;
88
89 #ifdef CUDNN_MODE
90 #ifdef CUDNN_MODE
91 CCuda_ComputationalGraphTransConv2D_Cudnn<T>* m_pCudnn;
92 #endif
93 #endif
94
95 public:
96 DeclareGetSignletonObject(CComputationalGraphTransConv2D);
97 };
98
// Convenience factories for CComputationalGraphTransConv2D.
// Each expands to `(*(new CComputationalGraphTransConv2D<T>(args...)))`: the node
// is heap-allocated with `new` and used by reference. Nothing here deletes it —
// ownership presumably passes to the computational-graph framework (confirm).

// Generic form: the element type is given explicitly as the first argument.
#define CCGTTransConv2D(T, ...) (*(new CComputationalGraphTransConv2D<T>(__VA_ARGS__)))

// Shorthands for the float and double instantiations.
#define CCGFTransConv2D(...) CCGTTransConv2D(float, __VA_ARGS__)
#define CCGDTransConv2D(...) CCGTTransConv2D(double, __VA_ARGS__)
103 }
104}
105
106#endif