FLImaging 6.6.27.1
ComputationalBase.h
#pragma once

// ComputationalBase.h (FLImaging 6.6.27.1)
//
// Declares CComputationalBase<T>, the abstract base class for nodes of an
// FLImaging AI computational graph (see CComputationalGraph<T>), together with
// float/double convenience aliases.
//
// The entire header is guarded so that it is compiled only with MSVC 2015 or
// later (_MSC_VER >= 1900) targeting x64 (_M_X64); with any other compiler or
// architecture it declares nothing.
#if _MSC_VER >= 1900 && defined(_M_X64)

// Used directly below (int64_t/int32_t, std::vector); included explicitly
// rather than relying on transitive includes from the project headers.
#include <cstdint>
#include <vector>

#include "DefinitionsAI.h"
#include "ComputationalGraphUtilities.h"

namespace FLImaging
{
	namespace AI
	{
		template <typename T>
		class CTensor;

		template <typename T>
		class CComputationalGraph;

		/// @brief Abstract base class for computational-graph nodes whose tensors
		///        (CTensor<T>) hold elements of type T.
		///
		/// Concrete node types must implement the pure virtual functions:
		/// Evaluate(), Forward(), Backward(), Clone(), GetAt(),
		/// FindByValueAttribute(), GetRequiredDedicatedMemory() and
		/// GetRequiredTemporaryMemory().
		///
		/// CComputationalGraph<T> and CComputationalGraphUtilities<T> are friends
		/// and may therefore access the protected and private members directly.
		///
		/// NOTE(review): the class is exported from the library (FL_EXPORT,
		/// presumably __declspec(dllexport/dllimport)). With MSVC the declaration
		/// order of the virtual functions and data members below determines the
		/// vtable and object layout, so reordering, inserting or retyping
		/// declarations breaks binary compatibility with the prebuilt library.
		///
		/// @tparam T Element type; this header provides aliases for float and double.
		template <typename T>
		class FL_EXPORT CComputationalBase : public CAlgorithmAIBase
		{

		public:
			CComputationalBase();
			virtual ~CComputationalBase();

			// Initialization flag accessors.
			virtual bool IsInitialized() const;
			virtual const CResult SetInitialized(bool bSet);

			// File persistence; the path is a wide-character string.
			virtual const CResult Load(const wchar_t* pWcsFileName);
			virtual const CResult Save(const wchar_t* pWcsFileName);

			// Serialization into a Base::CFLData buffer (reference and pointer overloads).
			// NOTE(review): the meaning of bSuperClass, i32Version (default -1,
			// presumably "current version") and bDumpMode is defined by the
			// implementation, which is not visible here - confirm.
			virtual const CResult GetBinaryData(Base::CFLData& fldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const;
			virtual const CResult GetBinaryData(Base::CFLData* pFldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const;

			// Deserialization from a Base::CFLData buffer (reference and pointer overloads).
			// pI64Offset is optional (nullptr by default); presumably it supplies/receives
			// the read position inside the buffer - confirm against the implementation.
			virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr);
			virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr);

			virtual const CResult Clear();

			// Presumably returns m_i64ObjectNumber (a per-object serial number) - confirm.
			const int64_t GetObjectNumber() const;

			// Node name (ID). SetID and ID return a reference to this node type,
			// presumably *this so that calls can be chained; ID() looks like a short
			// alias of SetID() - confirm. The name is presumably stored in m_pWcsName.
			CComputationalBase<T>& SetID(const wchar_t* pWcsName);
			CComputationalBase<T>& ID(const wchar_t* pWcsName);
			const wchar_t* GetID() const;

			// Overrides the base-class Throw(); takes a result code and an optional extra message.
			virtual void Throw(const CResult& res, const wchar_t* pWcsExtraMessage = nullptr) const override;

			// "Reference" flag; EnableReference returns a reference to this node type
			// (presumably *this, for chaining).
			CComputationalBase<T>& EnableReference(bool bReference);
			bool IsReference() const;

			// Value attribute (EValueAttribute) of this node.
			const CResult SetValueAttribute(EValueAttribute eType);
			EValueAttribute GetValueAttribute() const;

			// Node metadata.
			ENodeType GetNodeType() const;
			EDataType GetDataType() const;
			ENodeOperator GetNodeOperator() const;

			// Shape queries. GetShapeAsc presumably returns the same shape with the
			// axes listed in ascending order (compare PrintGraphInfo's bShapeOrderAsc)
			// - confirm.
			virtual const std::vector<int64_t>& GetEstimatedShape(bool bRecursive = true) const;
			virtual const std::vector<int64_t>& GetShape() const;
			virtual const std::vector<int64_t>& GetShapeAsc() const;
			virtual const CResult ClearShape(bool bRecursive = false);
			virtual int64_t GetNextBatchSize(int64_t i64BatchSize) const;

			// Computation entry points, implemented by every concrete node.
			// NOTE(review): Backward() returns a pointer while Evaluate()/Forward()
			// return references; confirm whether nullptr is a valid Backward() result.
			virtual CTensor<T>& Evaluate() = 0;
			virtual CTensor<T>& Forward() = 0;
			virtual CTensor<T>* Backward() = 0;

			// Value and derivative tensors.
			// NOTE(review): GetValue() is a const member but returns a mutable
			// CTensor<T>&, so a const node's value can be modified through it.
			virtual CTensor<T>& GetValue() const;
			virtual CTensor<T>* GetDerivative() const;
			virtual void ClearDerivativesRecursive();
			virtual void ResetDerivativesRecursive();

			// Polymorphic copy. NOTE(review): ownership of the returned object is not
			// visible here (presumably the caller owns it) - confirm.
			virtual CComputationalBase<T>* Clone() const = 0;

			virtual const CResult Swap(CComputationalBase<T>& cbSwap);

			// Lookup by node name, and by value attribute (results returned through vctResult).
			virtual const CComputationalBase<T>* GetAt(const wchar_t* pWcsName) const = 0;
			virtual const CResult FindByValueAttribute(EValueAttribute eValueAttribute, std::vector<const CComputationalBase<T>*>& vctResult) const = 0;

			// Operand (input node) access by index.
			virtual CComputationalBase<T>* GetOperand(int64_t i64Index) const;
			virtual bool IsIntrinsicOperand(int64_t i64Index) const;
			virtual int64_t GetOperandCount() const;

			virtual int64_t GetGeneration() const;

			// Training mode; EnableTrainingMode applies recursively by default.
			virtual bool IsTrainingModeEnabled() const;
			virtual const CResult EnableTrainingMode(bool bMode, bool bRecursively = true);

			// Diagnostic output of the graph / node parameters.
			virtual const CResult PrintGraphInfo(bool bRecursively = true, bool bShapeOrderAsc = false, bool bIncludeTensors = false) const;
			virtual const CResult PrintNodeParamInfo() const;

			virtual bool IsTensorCoreAvailable() const;

			// Add-gradient counter accessors.
			virtual int64_t GetAddGradientCount() const;
			virtual const CResult SetAddGradientCount(int64_t i64AddGradientCount);

			// Memory requirements, implemented by every concrete node.
			// NOTE(review): the unit is not stated here (presumably bytes) - confirm.
			virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const = 0;
			virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0) const = 0;

			// "Retained value" flag.
			virtual const CResult EnableRetainedValue(bool bRetained);
			virtual bool IsRetainedEnabled() const;

			virtual bool IsInplace() const;

			// Parent node access by index.
			virtual int64_t GetParentNodeCount() const;
			virtual CComputationalBase<T>* GetParentNode(int64_t i64Index) const;

			// Device index selection.
			virtual const CResult SetDeviceIndex(int32_t i32DeviceIndex);
			virtual int32_t GetDeviceIndex() const;

			// Creation/release of the value tensor.
			virtual const CResult CreateValue();
			virtual const CResult ClearValue();

			// Creation/release of the derivative tensor.
			virtual const CResult CreateDerivative();
			virtual const CResult ClearDerivative();

			// Replace the value/derivative tensor pointers.
			// NOTE(review): whether the node takes ownership of the passed tensor is
			// not visible here - confirm before passing stack or shared tensors.
			const CResult SetValuePtr(CTensor<T>* pTsrValue);
			const CResult SetDerivativePtr(CTensor<T>* pTsrDerivative);

			virtual CTensor<T>* GetValuePtr();

			static CComputationalBase<T>* GetSingletonObject();

			// Processing-unit selection (reference and pointer overloads), overriding the base class.
			virtual const CResult SetProcessingUnit(const Base::CProcessingUnitBase& puBase) override;
			virtual const CResult SetProcessingUnit(const Base::CProcessingUnitBase* pPuBase) override;

			// Project macro (expanded in another header); presumably declares the
			// object-duplication support expected by the FLImaging framework.
			SupportToDuplicateAbstractObject(CComputationalBase);
		protected:
			// Presumably copies the base-class state from cb - confirm in the implementation.
			const CResult InternalAssign(const CComputationalBase<T>& cb);

			// Performance-measurement hooks around forward/backward passes
			// (see m_perfForward / m_perfBackward below).
			virtual const CResult BeginForwardPerformanceCheck();
			virtual const CResult EndForwardPerformanceCheck();

			virtual const CResult BeginBackwardPerformanceCheck();
			virtual const CResult EndBackwardPerformanceCheck();

			// Node state. Declaration order is part of the exported object layout.
			bool m_bInitialized;
			bool m_bReference;
			ENodeType m_eNodeType;
			EDataType m_eDataType;
			EValueAttribute m_eValueAttribute;
			ENodeOperator m_eNodeOperator;
			wchar_t* m_pWcsName;
			bool m_bTrainingMode;
			int64_t m_i64ObjectNumber;
			bool m_bRetained;
			bool m_bInplace;
			int64_t m_i64AddGradientCount;
			int64_t m_i64Generation;
			int32_t m_i32DeviceIndex;

			// Value and derivative tensors (raw pointers; ownership handled in the implementation).
			CTensor<T>* m_pTsrValue;
			CTensor<T>* m_pTsrDerivative;

			Base::CPerformanceCounter m_perfForward;
			Base::CPerformanceCounter m_perfBackward;

			// NOTE(review): these are reference members; the objects they bind to are
			// chosen in the constructor, which is not visible here. Reference members
			// make the implicitly-declared copy assignment operator deleted.
			std::vector<int64_t>& m_vctShape;
			std::vector<int64_t>& m_vctShapeAsc;

			std::vector<CComputationalBase<T>*>& m_vctParentNodes;

		private:
			friend class CComputationalGraph<T>;
			friend class CComputationalGraphUtilities<T>;
		};

		// Aliases for the float and double instantiations.
		typedef CComputationalBase<float> CComputationalBaseF;
		typedef CComputationalBase<double> CComputationalBaseD;

		// Short forms of the aliases above.
		typedef CComputationalBase<float> CCBF;
		typedef CComputationalBase<double> CCBD;

		// Short alias template for any element type.
		template <typename T>
		using CCB = CComputationalBase<T>;
	}
}

#endif // _MSC_VER >= 1900 && defined(_M_X64)
알고리즘에서 필요한 프로세싱 유닛 AI 클래스 (AI class for the processing unit required by an algorithm)
Definition: AlgorithmAIBase.h:25
Definition: AlgorithmAIBase.h:16
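@ Clear
도형 정리 메뉴 (figure-clearing menu)
Definition: DefinitionsGUI.h:2058
@ Save
저장 메뉴 (save menu)
Definition: DefinitionsGUI.h:303
@ Load
불러오기 (load)
Definition: DefinitionsGUI.h:50
Note: these three entries are GUI enumerators from DefinitionsGUI.h that the documentation generator linked by name. They are unrelated to the CComputationalBase member functions Clear(), Save() and Load() declared in this header.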