FLImaging 6.12.9.2
ComputationalGraph.h
1#pragma once
2
3#if _MSC_VER >= 1900 && defined(_M_X64)
4
5#include "ComputationalBase.h"
6#include "Tensor.h"
7
8namespace FLImaging
9{
10 namespace AI
11 {
// Forward declarations of collaborating class templates that CComputationalGraph<T>
// refers to only by pointer or reference (CTensor<T>&/*, CBackendBase<T>*, COptimizer<T>*),
// plus COptimizer<T>, which is also named as a friend below.
// NOTE(review): Tensor.h (included above) presumably already declares CTensor; if so,
// this re-declaration is redundant but harmless.
12 template <typename T>
13 class CTensor;
14
15 template <typename T>
16 class CBackendBase;
17
18 template <typename T>
19 class COptimizer;
20
// CComputationalGraph<T> -- abstract base for computational-graph nodes that have operands.
//
// Derives from CComputationalBase<T> and is marked FL_EXPORT (presumably a DLL
// export/import macro -- confirm). The class cannot be instantiated directly, because
// Forward(), Backward(), Clone() and GetClassType() are pure virtual. Concrete node types
// must implement them.
//
// The whole header is compiled only when _MSC_VER >= 1900 and _M_X64 is defined
// (see the #if guard at the top of the file).
//
// NOTE(review): the class is exported and has virtual members. Reordering, removing or
// renaming declarations would likely break binary compatibility with already-built
// clients. Add new members only with that in mind.
21 template <typename T>
22 class FL_EXPORT CComputationalGraph : public CComputationalBase<T>
23 {
24
25 public:
// Default constructor.
26 CComputationalGraph();
27 protected:
// Copy constructor is protected, so copies can only be made from derived classes
// (presumably from their Clone() implementations -- confirm in concrete nodes).
28 CComputationalGraph(const CComputationalGraph<T>& cg);
29
30 public:
31 virtual ~CComputationalGraph();
32
// Serialization into a Base::CFLData buffer; one overload takes the buffer by reference,
// the other by pointer.
// The meaning of bSuperClass / i32Version (-1 by default) / bDumpMode is defined in the
// implementation and is not visible here.
33 virtual const CResult GetBinaryData(Base::CFLData& fldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
34 virtual const CResult GetBinaryData(Base::CFLData* pFldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
35
// Deserialization counterparts of GetBinaryData.
// pI64Offset is optional (nullptr by default); presumably an in/out read offset into the
// buffer -- confirm.
36 virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr) override;
37 virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr) override;
38
// Non-virtual copy-from operation. The overridable hook is the protected
// InternalAssign() below; presumably Assign delegates to it -- confirm.
39 const CResult Assign(const CComputationalGraph<T>& cg);
40 virtual const CResult Clear() override;
41
// Evaluate() overrides the base class.
// Forward() / Backward() must be supplied by concrete nodes. Forward returns a tensor by
// reference; Backward returns a pointer (whether it may be null is not visible here).
42 virtual CTensor<T>& Evaluate() override;
43 virtual CTensor<T>& Forward() = 0;
44 virtual CTensor<T>* Backward() = 0;
45
46 virtual void ClearDerivativesRecursive() override;
47 virtual void ResetDerivativesRecursive() override;
48
// Polymorphic copy; returns a raw pointer.
// Presumably the caller takes ownership -- confirm.
49 virtual CComputationalBase<T>* Clone() const = 0;
50
51 virtual const CResult Swap(CComputationalBase<T>& cbSwap) override;
52
53 virtual const CResult SetDeviceIndex(int32_t i32DeviceIndex) override;
54
// Lookup helpers. GetAt looks a node up by a wide-string name. The Find* functions fill
// vctResult with matching nodes; their search scope is defined in the implementation.
// NOTE(review): FindByNodeOperator is not marked `override`, unlike FindByValueAttribute.
// If CComputationalBase<T> declares it virtual, add `override` for consistency.
55 virtual const CComputationalBase<T>* GetAt(const wchar_t* pWcsName) const override;
56 virtual const CResult FindByValueAttribute(EValueAttribute eValueAttribute, std::vector<const CComputationalBase<T>*>& vctResult) const override;
57 virtual const CResult FindByNodeOperator(ENodeOperator eNodeOperator, std::vector<const CComputationalBase<T>*>& vctResult) const;
58
// Operand access.
// Presumably backed by m_pCbOperands / m_pArrBIntrinsicOperand / m_i64OperandsCount
// below -- confirm.
// NOTE(review): GetOperand is const but returns a non-const pointer.
59 virtual const CResult SetOperand(int64_t i64Index, const CComputationalBase<T>* pCbOperand, bool bSetSourceReferenceType = false);
60 virtual const CResult SetOperand(const std::vector<const CComputationalBase<T>*>& cbOperands, bool bSetSourceReferenceType = false);
61 virtual CComputationalBase<T>* GetOperand(int64_t i64Index) const;
62 virtual bool IsIntrinsicOperand(int64_t i64Index) const;
63 virtual int64_t GetOperandCount() const override;
64
// Indexed access to member tensors.
// Presumably backed by m_ppTsrMemberTensor / m_i64MemberTensorCount -- confirm.
65 virtual CTensor<T>* GetMemberTensor(int64_t i64Index) const;
66 virtual int64_t GetMemberTensorCount() const;
67
// Indexed access to member backends.
// Presumably backed by m_ppMemberBackend / m_i64MemberBackendCount -- confirm.
68 virtual CBackendBase<T>* GetMemberBackend(int64_t i64Index) const;
69 virtual int64_t GetMemberBackendCount() const;
70
71 virtual int64_t GetNextBatchSize(int64_t i64BatchSize) const override;
// Recursive by default (bRecursively = true).
72 virtual const CResult EnableTrainingMode(bool bMode, bool bRecursively = true) override;
73
74 virtual const CResult PrintGraphInfo(bool bRecursively = true, bool bShapeOrderAsc = false, bool bIncludeTensors = false) const override;
75
// Parent optimizer link (presumably stored in m_pParentOptimizer -- confirm).
// The setter takes a mutable pointer and recurses by default; the getter returns a
// pointer-to-const.
76 virtual const CResult SetParentOptimizer(COptimizer<T>* pParentOptimizer, bool bRecursive = true);
77 virtual const COptimizer<T>* GetParentOptimizer() const;
78
// Concrete-type identifier; must be supplied by each concrete node.
79 virtual int64_t GetClassType() const = 0;
80
// Memory-requirement queries; results are int64_t (units presumably bytes -- confirm).
81 virtual int64_t GetOperandIndexForNextBatchSize() const;
82 virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const override;
83 virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0) const override;
84
// Read-only access to a temporary derivative tensor
// (presumably m_tsrDerivativeTemp -- confirm).
85 virtual const CTensor<T>& GetDerivativeTemp() const;
86
87 virtual int64_t GetSharedMemoryCount() const;
88 virtual const CResult EnableTensorCore(bool bTensorCore = true) override;
89
// Library macro defined outside this header.
// Presumably declares duplication support for abstract classes -- confirm its expansion.
// NOTE(review): no copy-assignment operator is declared in this header. Unless this macro
// or the base class suppresses it, the implicitly generated one would copy the raw pointer
// members below shallowly. Assign() appears to be the intended copy path.
90 SupportToDuplicateAbstractObject(CComputationalGraph);
91 protected:
// Overridable implementation hook for Assign().
92 virtual const CResult InternalAssign(const CComputationalGraph<T>& cg);
93
// Connect this node to its operand nodes. There are fixed-arity overloads for 1, 2 and
// 3 operands (compare EMaxCount_Operands below), plus initializer_list and vector forms.
94 virtual const CResult InternalConnectNode(const CComputationalBase<T>& cbOperand1);
95 virtual const CResult InternalConnectNode(const CComputationalBase<T>& cbOperand1, const CComputationalBase<T>& cbOperand2);
96 virtual const CResult InternalConnectNode(const CComputationalBase<T>& cbOperand1, const CComputationalBase<T>& cbOperand2, const CComputationalBase<T>& cbOperand3);
97 virtual const CResult InternalConnectNode(const std::initializer_list<const CComputationalBase<T>*>& ilOperands);
98 virtual const CResult InternalConnectNode(const std::vector<const CComputationalBase<T>*>& vctOperands);
99
100 virtual const CResult InternalCreateOperandsNode(int64_t i64OperandsCount);
101
// Parent-link maintenance.
102 virtual const CResult ClearParentNode();
103 virtual const CResult RemoveCurrentParentNode();
104 virtual const CResult UpdateParentNode();
105
106 virtual const CResult ClearOperand();
// Non-recursive by default (bRecursive = false).
107 virtual const CResult UpdateGeneration(bool bRecursive = false);
108
// Fixed capacity constants. EMaxCount_Operands (3) matches the largest fixed-arity
// InternalConnectNode overload. Whether these limits are enforced is decided in the
// implementation (not visible here).
109 enum EMaxCount
110 {
111 EMaxCount_Operands = 3,
112 EMaxCount_Derivatives = 3,
113 };
114
// Validity checks.
// Note: the "...AndTheirValuesValid" variant is non-const, unlike AreOperatorsValid.
115 virtual bool AreOperatorsValid() const;
116 virtual bool AreOperatorsAndTheirValuesValid();
117
118 virtual bool InitializeDerivativeVariable();
119
// Member backends and tensors, stored as arrays of pointers with explicit element counts.
// NOTE(review): ownership and lifetime of these raw arrays are managed by the
// implementation (not visible here). The user-declared copy constructor must handle
// them -- confirm deep vs. shallow copy.
120 CBackendBase<T>** m_ppMemberBackend;
121 CTensor<T>** m_ppTsrMemberTensor;
122 int64_t m_i64MemberTensorCount;
123 int64_t m_i64MemberBackendCount;
124
125 int64_t m_i64SharedMemoryCount;
126
// Operand nodes, a parallel per-operand "intrinsic" flag array (presumably one flag per
// entry of m_pCbOperands -- confirm), and the operand count.
127 CComputationalBase<T>** m_pCbOperands;
128 bool* m_pArrBIntrinsicOperand;
129 int64_t m_i64OperandsCount;
130 CTensor<T> m_tsrDerivativeTemp;
131 COptimizer<T>* m_pParentOptimizer;
132 private:
// A friend class template must already be declared at this point. CComputationalGraphUtilities
// is not declared in this header, so it must come from an included header
// (ComputationalBase.h or Tensor.h).
133 friend class CComputationalGraphUtilities<T>;
134 friend class COptimizer<T>;
135 };
136
// Named aliases of CComputationalGraph for float and double element types.
137 typedef CComputationalGraph<float> CComputationalGraphF;
138 typedef CComputationalGraph<double> CComputationalGraphD;
139
// Short-form aliases of the same two specializations (CCG = CComputationalGraph).
140 typedef CComputationalGraph<float> CCGF;
141 typedef CComputationalGraph<double> CCGD;
142
// Alias template: CCG<T> names CComputationalGraph<T> for any element type T.
143 template <typename T>
144 using CCG = CComputationalGraph<T>;
145 }
146}
147
148#endif
Definition AlgorithmAIBase.h:18
@ Assign
CGUIPropertyItemView3DFigure 의 값을 해당 도형으로 설정하는 함수
Definition DefinitionsGUIView3D.h:2930
@ Clear
도형 정리 메뉴
Definition DefinitionsGUI.h:2110