// FLImaging 6.12.9.2
// ComputationalGraphCoordConv2D.h
#pragma once

// Everything below is compiled only for 64-bit MSVC builds with _MSC_VER >= 1900
// (Visual Studio 2015 or newer). On any other compiler _MSC_VER is not defined,
// evaluates to 0 in this #if, and the header contributes no declarations.
#if _MSC_VER >= 1900 && defined(_M_X64)

#include "ComputationalGraphConv2D.h"

#include <cstdint> // int64_t / int32_t used in the declarations below
#include <vector>  // std::vector returned by GetEstimatedShape()

namespace FLImaging
{
	namespace AI
	{
		/// @brief Computational-graph node for a "CoordConv" 2D convolution.
		///
		/// Extends CComputationalGraphConv2D<T> with coordinate handling
		/// (MakeCoordTensor / ExtractImageDer) and two working tensors: one for a
		/// concatenated input (m_tsrConcatenateBuffer) and one for its derivative
		/// (m_tsrConcatenateDerivativeTemp).
		///
		/// NOTE(review): the names suggest the CoordConv scheme of Liu et al. (2018).
		/// In that scheme, coordinate channels are concatenated to Operand1 before the
		/// convolution, plus an optional radius channel when bEnableRadius is true.
		/// The implementation is not in this header; confirm against the .cpp.
		///
		/// NOTE(review): this is an FL_EXPORT-declared polymorphic class. Its object
		/// layout and vtable layout follow declaration order, so do not reorder or
		/// insert members/virtual functions without rebuilding every consumer.
		///
		/// @tparam T  Element type; the factory macros below instantiate float and double.
		template <typename T>
		class FL_EXPORT CComputationalGraphCoordConv2D : public CComputationalGraphConv2D<T>
		{
		private:
			// Default construction is not available to users.
			// NOTE(review): presumably needed by the class-type / singleton macros below.
			CComputationalGraphCoordConv2D();

		protected:
			// Copying is limited to this class and its subclasses.
			// NOTE(review): presumably used to implement Clone().
			CComputationalGraphCoordConv2D(const CComputationalGraphCoordConv2D<T>& cg);

		public:
			/// @brief Creates a CoordConv2D node.
			///
			/// Operand1 layout:
			///   4th dim : Batch size
			///   3rd dim : Channels
			///   2nd dim : Height
			///   1st dim : Width
			///
			/// Kernel layout:
			///   4th dim : Num of kernels
			///   3rd dim : Channels
			///   2nd dim : Height
			///   1st dim : Width
			///
			/// @param cbOperand1     Input node (Operand1).
			/// @param tsrKernel      Convolution kernel tensor.
			/// @param i64StrideX     Stride along X (default 1).
			/// @param i64StrideY     Stride along Y (default 1).
			/// @param i64PaddingX    Padding along X (default 0).
			/// @param i64PaddingY    Padding along Y (default 0).
			/// @param i64DilationX   Dilation along X (default 1).
			/// @param i64DilationY   Dilation along Y (default 1).
			/// @param i64GroupCount  Number of convolution groups (default 1).
			/// @param bEnableRadius  Default false. NOTE(review): presumably stored in
			///                       m_bEnableRadius to enable an extra radius
			///                       coordinate channel; confirm.
			CComputationalGraphCoordConv2D(const CComputationalBase<T>& cbOperand1, const CTensor<T>& tsrKernel, int64_t i64StrideX = 1, int64_t i64StrideY = 1, int64_t i64PaddingX = 0, int64_t i64PaddingY = 0, int64_t i64DilationX = 1, int64_t i64DilationY = 1, int64_t i64GroupCount = 1, bool bEnableRadius = false);
			virtual ~CComputationalGraphCoordConv2D();

			// NOTE(review): presumably builds the coordinate tensor that is concatenated
			// with Operand1 (into m_tsrConcatenateBuffer); confirm in the implementation.
			virtual const CResult MakeCoordTensor();

			// NOTE(review): presumably takes the derivative w.r.t. the coordinate-augmented
			// input (tsrDerCoordPlus) and extracts the part belonging to the original image
			// channels. bAddGradient presumably selects accumulate vs. overwrite; confirm.
			virtual const CResult ExtractImageDer(CTensor<T>& tsrDerCoordPlus, bool bAddGradient);

			// Graph-node overrides: forward / backward pass and polymorphic copy.
			virtual CTensor<T>& Forward() override;
			virtual CTensor<T>* Backward() override;
			// NOTE(review): returns a raw pointer. Ownership of the copy is defined by the
			// base-class Clone() contract; confirm before deleting it.
			virtual CComputationalBase<T>* Clone() const override;

			virtual const std::vector<int64_t>& GetEstimatedShape(bool bRecursive = true) const override;
			virtual const CResult PrintNodeParamInfo() const override;

			// Serialization overrides (by-reference and by-pointer overloads).
			// Default arguments on virtual functions are bound by the static type of the
			// call expression. Keep these defaults identical to the base class's so a call
			// behaves the same whether it goes through a base or a derived reference.
			virtual const CResult GetBinaryData(Base::CFLData& fldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
			virtual const CResult GetBinaryData(Base::CFLData* pFldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;

			virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr) override;
			virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr) override;

			// NOTE(review): presumably an estimate of the memory this node needs for the
			// given mode (bTraining) and batch size; confirm against the base-class contract.
			virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const override;

			DeclareGetClassType();
			SupportToDuplicateObjectWithoutCreateNewObject(CComputationalGraphCoordConv2D, *this);

		protected:
			// Same default as the public constructor's bEnableRadius parameter. The in-class
			// initializer keeps the flag determinate for any constructor that does not set it
			// (e.g. the private default constructor, whose definition is not visible here).
			bool m_bEnableRadius = false;

			// NOTE(review): presumably Operand1 concatenated with the coordinate channels.
			CTensor<T> m_tsrConcatenateBuffer;
			// NOTE(review): presumably scratch space for the derivative with respect to
			// m_tsrConcatenateBuffer during Backward().
			CTensor<T> m_tsrConcatenateDerivativeTemp;

		public:
			DeclareGetSignletonObject(CComputationalGraphCoordConv2D);
		};

		// Factory helpers: each allocates a node with `new` and yields a reference to it.
		// The macros never delete the object. NOTE(review): ownership presumably passes to
		// the computational graph; confirm before using these outside graph construction.
		// The class name is fully qualified because macros ignore namespaces: the
		// expansion must resolve even where FLImaging::AI is not the current scope and
		// no using-directive is in effect.
		#define CCGFCoordConv2D(...) (*(new ::FLImaging::AI::CComputationalGraphCoordConv2D<float>(__VA_ARGS__)))
		#define CCGDCoordConv2D(...) (*(new ::FLImaging::AI::CComputationalGraphCoordConv2D<double>(__VA_ARGS__)))

		#define CCGTCoordConv2D(T, ...) (*(new ::FLImaging::AI::CComputationalGraphCoordConv2D<T>(__VA_ARGS__)))
	}
}
#endif
// (Doxygen listing cross-reference) Definition AlgorithmAIBase.h:18