// FLImaging 6.5.16.1
// ComputationalGraphAtrousConv2D.h
1#pragma once
2
3#if _MSC_VER >= 1900 && defined(_M_X64)
4
5#include "ComputationalGraph.h"
6
7#include <vector>
8
9namespace FLImaging
10{
11 namespace AI
12 {
13 #ifdef CUDNN_MODE
14 template <typename T>
15 class CCuda_Conv2D_Cudnn;
16 #endif
17
18 template <typename T>
19 class FL_EXPORT CComputationalGraphAtrousConv2D : public CComputationalGraph<T>
20 {
21 private:
22 CComputationalGraphAtrousConv2D();
23
24 protected:
25 CComputationalGraphAtrousConv2D(const CComputationalGraphAtrousConv2D<T>& cg);
26
27 public:
28 // Operand1
29 // 4th dim : Batch size
30 // 3rd dim : Channels
31 // 2nd dim : Height
32 // 1st dim : Width
33
34 // Kernel
35 // 4th dim : Num of kernels
36 // 3rd dim : Channels
37 // 2nd dim : Height
38 // 1st dim : Width
39
40 // Bias
41 // 1st dim : Num of kernels
42 CComputationalGraphAtrousConv2D(const CComputationalBase<T>& cbOperand1, const CTensor<T>& tsrKernel, int64_t i64DilationX = 1, int64_t i64DilationY = 1, int64_t i64StrideX = 1, int64_t i64StrideY = 1, int64_t i64PaddingX = 0, int64_t i64PaddingY = 0);
43 virtual ~CComputationalGraphAtrousConv2D();
44
45 virtual CTensor<T>& Forward() override;
46 virtual CTensor<T>* Backward() override;
47 virtual CComputationalBase<T>* Clone() const override;
48 virtual const CResult PrintNodeParamInfo() const override;
49
50 virtual const CResult GetBinaryData(Base::CFLData& fldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
51 virtual const CResult GetBinaryData(Base::CFLData* pFldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
52
53 virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr) override;
54 virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr) override;
55
56 virtual const std::vector<int64_t>& GetEstimatedShape(bool bRecursive = true) const override;
57
58 virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0) const override;
59 virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const override;
60
61
62 DeclareGetClassType();
63 SupportToDuplicateObjectWithoutCreateNewObject(CComputationalGraphAtrousConv2D, *this);
64
65 protected:
66 virtual const CResult Convolve2D(CTensor<T>* pTsrOperand, CTensor<T>* pTsrKernel, CTensor<T>* pTsrResult);
67 virtual const CResult Convolve2D_CUDNN(CTensor<T>* pTsrOperand, CTensor<T>* pTsrKernel, CTensor<T>* pTsrResult);
68 virtual const CResult DerivativeImage();
69 virtual const CResult DerivativeImage_CUDNN(bool bAddGradient);
70 virtual const CResult DerivativeKernel();
71 virtual const CResult DerivativeKernel_CUDNN(bool bAddGradient);
72
73 int64_t m_i64StrideX;
74 int64_t m_i64StrideY;
75 int64_t m_i64PaddingX;
76 int64_t m_i64PaddingY;
77 int64_t m_i64DilationX;
78 int64_t m_i64DilationY;
79 #ifdef CUDNN_MODE
80 CCuda_Conv2D_Cudnn<T>* m_pCudnn;
81 #endif
82
83 public:
84 DeclareGetSignletonObject(CComputationalGraphAtrousConv2D);
85 };
86
87 #define CCGFAtrousConv2D(...) (*(new CComputationalGraphAtrousConv2D<float>(__VA_ARGS__)))
88 #define CCGDAtrousConv2D(...) (*(new CComputationalGraphAtrousConv2D<double>(__VA_ARGS__)))
89
90 #define CCGTAtrousConv2D(T, ...) (*(new CComputationalGraphAtrousConv2D<T>(__VA_ARGS__)))
91 }
92}
93
94#endif