// FLImaging 6.8.21.2
// ComputationalGraphInvolution.h
#pragma once

#if _MSC_VER >= 1900 && defined(_M_X64)

#include "ComputationalGraph.h"
#include "Tensor.h"
#include "BackendConv2D.h"
#include "BackendFold.h"
#include "BackendAdaptiveAvgPool2D.h"
#include "BackendBatchNorm2D.h"
#include "BackendReLU.h"

#include <cstdint>
#include <vector>

namespace FLImaging
{
	namespace AI
	{
		/// @brief Computational-graph node for an involution layer.
		///
		/// The node holds backends for adaptive average pooling, two 2D convolutions
		/// (a "reduce" convolution and a "spatial" convolution), 2D batch normalization
		/// and ReLU, together with the intermediate tensors those stages use.
		/// NOTE(review): the order in which these stages run is defined in the
		/// out-of-line implementation, not here; the summary above is inferred from
		/// the member list — confirm against the .cpp.
		template <typename T>
		class FL_EXPORT CComputationalGraphInvolution : public CComputationalGraph<T>
		{
		private:
			// Not publicly constructible without operands; presumably used by the
			// singleton/duplication macros declared below — confirm.
			CComputationalGraphInvolution();

		protected:
			/// @brief Copy constructor (protected; presumably reached through Clone()).
			CComputationalGraphInvolution(const CComputationalGraphInvolution<T>& cg);

		public:
			// Operand1 (constructor argument cbOperand1)
			//   4th dim : Batch size
			//   3rd dim : Channels
			//   2nd dim : Height
			//   1st dim : Width
			//
			// Operand2 (Channel Reduce Kernel; presumably tsrKernel1, by argument order)
			//   4th dim : Channels / Reduction ratio
			//   3rd dim : Channels
			//   2nd dim : 1
			//   1st dim : 1
			//
			// Operand3 (Channel to spatial Kernel; presumably tsrKernel2, by argument order)
			//   4th dim : KernelX * KernelY
			//   3rd dim : Channels / Reduction ratio
			//   2nd dim : 1
			//   1st dim : 1
			//   NOTE(review): with i64GroupCount > 1, involution usually needs
			//   KernelX * KernelY * GroupCount outputs here — confirm against CheckShape().

			/// @brief Builds an involution node on top of an existing graph node.
			/// @param cbOperand1    Input node (Operand1 above).
			/// @param tsrKernel1    Channel-reduce kernel tensor (Operand2 above).
			/// @param tsrKernel2    Channel-to-spatial kernel tensor (Operand3 above).
			/// @param i64KernelX    Involution kernel size along X (default 3).
			/// @param i64KernelY    Involution kernel size along Y (default 3).
			/// @param i64StrideX    Stride along X (default 1).
			/// @param i64StrideY    Stride along Y (default 1).
			/// @param i64GroupCount Group count (default 1).
			/// @param tEpsilon      Epsilon, presumably forwarded to the batch-norm stage (default FL_EPSILON_FLOAT).
			/// @param tMomentum     Momentum, presumably forwarded to the batch-norm stage (default 0.1).
			/// @param bAffine       Affine flag, presumably for the batch-norm stage (default true).
			CComputationalGraphInvolution(const CComputationalBase<T>& cbOperand1, const CTensor<T>& tsrKernel1, const CTensor<T>& tsrKernel2, int64_t i64KernelX = 3, int64_t i64KernelY = 3, int64_t i64StrideX = 1, int64_t i64StrideY = 1, int64_t i64GroupCount = 1, T tEpsilon = static_cast<T>(FL_EPSILON_FLOAT), T tMomentum = static_cast<T>(0.1), bool bAffine = true);
			virtual ~CComputationalGraphInvolution();

			// Graph execution entry points (overrides of CComputationalGraph<T>).
			virtual CTensor<T>& Forward() override;
			virtual CTensor<T>* Backward() override;
			virtual CComputationalBase<T>* Clone() const override;

			virtual const std::vector<int64_t>& GetEstimatedShape(bool bRecursive = true) const override;
			virtual const CResult PrintNodeParamInfo() const override;

			// Serialization.
			virtual const CResult GetBinaryData(Base::CFLData& fldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
			virtual const CResult GetBinaryData(Base::CFLData* pFldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;

			virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr) override;
			virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr) override;

			// Memory-requirement queries.
			virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0) const override;
			virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const override;

			DeclareGetClassType();
			SupportToDuplicateObjectWithoutCreateNewObject(CComputationalGraphInvolution, *this);

		protected:
			const CResult CheckShape();
			const CResult GetCommonValues(std::vector<int64_t>& vctPooledShape, std::vector<int64_t>& vctReduceOutputShape, std::vector<int64_t>& vctSpatialOutputShape, bool& bPoolSkip) const;
			// NOTE(review): the shape vectors below are taken by value (one copy per call).
			// Switching them to const references must be done together with the
			// out-of-line definitions, so it is intentionally left unchanged here.
			const CResult InvolutionMainForward(const CTensor<T>* pTsrOperand, const CTensor<T>* pTsrKernel, CTensor<T>* pTsrResult, const std::vector<int64_t> vctResultShape);
			const CResult InvolutionMainBackwardImage(const CTensor<T>* pTsrDy, const CTensor<T>* pTsrW, CTensor<T>* pTsrDx, const std::vector<int64_t> vctXShape, bool bAddGradient, CTensor<T>* pTsrDerivativeTemp);
			const CResult InvolutionMainBackwardKernel(const CTensor<T>* pTsrDy, const CTensor<T>* pTsrX, CTensor<T>* pTsrDw, std::vector<int64_t> vctWShape);

			// Layer hyper-parameters (mirroring the constructor arguments).
			int64_t m_i64KernelX;
			int64_t m_i64KernelY;
			int64_t m_i64StrideX;
			int64_t m_i64StrideY;
			int64_t m_i64GroupCount;
			T m_tEpsilon;
			T m_tMomentum;
			bool m_bAffine;

			// Intermediate / cached tensors for the individual stages.
			CTensor<T> m_tsrPoolVal;
			CTensor<T> m_tsrKernelBuffer;
			CTensor<T> m_tsrReduceVal;
			// NOTE(review): "Varience" is a misspelling of "variance"; the name is kept
			// because the out-of-line implementation refers to it.
			CTensor<T> m_tsrBNVarienceBuffer;
			CTensor<T> m_tsrBNMeanBuffer;
			CTensor<T> m_tsrBNSaveWeightGradient;
			CTensor<T> m_tsrBNSaveBiasGradient;
			CTensor<T> m_tsrReLUVal;
			CTensor<T> m_tsrMulKernel;

			// Compute backends used by the stages.
			CBackendAdaptiveAvgPool2D<T> m_backendAdaptiveAvgPool2D;
			CBackendConv2D<T> m_backendConv2DReduce;
			CBackendConv2D<T> m_backendConv2DSpatial;
			CBackendBatchNorm2D<T> m_backendBatchNorm2D;
			CBackendReLU<T> m_backendReLU;

		public:
			DeclareGetSignletonObject(CComputationalGraphInvolution);
		};

		// Convenience macros: allocate a node with `new` and yield a reference to it.
		// NOTE(review): nothing here frees the object; ownership presumably passes to
		// the enclosing graph — confirm before using these outside a graph expression.
		#define CCGFInvolution(...) (*(new CComputationalGraphInvolution<float>(__VA_ARGS__)))
		#define CCGDInvolution(...) (*(new CComputationalGraphInvolution<double>(__VA_ARGS__)))

		#define CCGTInvolution(T, ...) (*(new CComputationalGraphInvolution<T>(__VA_ARGS__)))
	}
}

#endif