FLImaging 7.3.20.1
InternalOptimizerBase.h
1#pragma once
2
3#if _MSC_VER >= 1900 && defined(_M_X64)
4
#include "DefinitionsAI.h"
#include "Tensor.h"
#include <tuple>
#include <utility>
#include <vector>
8
9namespace FLImaging
10{
11 namespace AI
12 {
// Forward declarations of the class templates this header names. Declaring
// them here lets the header mention these types without including their
// definitions. The class below only refers to them through pointers and
// references. NOTE(review): CLearningRateSchedulerBase, CClassEqualizerBase
// and CComputationalGraphAugmentationBase are not referenced in this header's
// visible code; other includers may rely on these declarations, so they stay.
template <class T> class CComputationalGraph;
template <class T> class CComputationalGraphPlaceholder;
template <class T> class CLearningRateSchedulerBase;
template <class T> class CClassEqualizerBase;
template <class T> class CComputationalGraphAugmentationBase;
template <class T> class CValidatorBase;
template <class T> class COptimizer;
33
34 template <typename T>
35 class FL_EXPORT CInternalOptimizerBase : public CAlgorithmAIBase
36 {
37 public:
38
39 CInternalOptimizerBase();
40 virtual ~CInternalOptimizerBase();
41
42 virtual const CResult Assign(const CInternalOptimizerBase<T>& ob);
43 virtual const CResult Assign(const CInternalOptimizerBase<T>* pOb);
44
45 virtual CInternalOptimizerBase<T>* Clone() const = 0;
46
47 virtual const CResult Initialize();
48 virtual const CResult InitializeBatch();
49 virtual bool IsInitialized() const;
50
51 virtual const CResult SetFunction(CComputationalGraph<T>& cgFunction);
52 virtual CComputationalGraph<T>& GetFunction();
53
54 virtual const CResult SetAugmentation(CComputationalGraph<T>* pCgAugmentation);
55 virtual CComputationalGraph<T>* GetAugmentation();
56
57 virtual const CResult Fit(T* pClippingThreshold= nullptr, int32_t* pI32IterationFeedbackFeedback = nullptr) = 0;
58 virtual const CResult BackPropagation();
59 virtual T Validate(int32_t* pI32IterationFeedback = nullptr);
60
61 virtual const CTensor<T>& GetResult() const;
62
63 virtual int64_t GetRequiredDedicatedMemory(int64_t i64BatchSize) const;
64
65 virtual const CResult SetValidator(CValidatorBase<T>& vali);
66 virtual const CResult SetValidator(CValidatorBase<T>* pVali);
67 virtual const CValidatorBase<T>* GetValidator() const;
68
69 virtual const std::vector<int64_t>& GetValidationIndices() const;
70
71 virtual const CResult ResetDerivatives();
72
73 virtual const CResult ClipGradient(T* pClippingThreshold = nullptr);
74 virtual const CResult UpdateWeights() = 0;
75
76 virtual const CResult EnableResetDerivatives(bool bResetDerivatives);
77 virtual bool IsResetDerivativesEnabled() const;
78
79 virtual int64_t GetTotalIteration() const;
80
81 virtual const CResult InitializeMiniBatch();
82
83 virtual const CResult SetDeviceIndex(int32_t i32DeviceIndex);
84 virtual int32_t GetDeviceIndex() const;
85
86 virtual const CResult SetOptimizerIndex(int32_t i32OptimizerIndex = 0);
87 virtual int32_t GetOptimizerIndex() const;
88
89 virtual const CResult SetMemoryLimitRatio(double f64MemoryLimitRatio = 1.);
90 virtual double GetMemoryLimitRatio() const;
91
92 virtual const COptimizer<T>* GetParentOptimizer() const;
93
94 protected:
95 virtual void Throw(const CResult& res, const wchar_t* pWcsExtraMessage = nullptr) const override;
96 virtual const CResult InitializeFit();
97 virtual const CResult TerminateFit();
98 virtual const CResult TerminateMiniBatch();
99
100 virtual const CResult MakeBatchTensor();
101
102 void ClearVariablesAndDerivatives();
103
104 virtual const CResult InitializeCurrentIteration(int64_t i64SubdivisionIndex = 0);
105 virtual const CResult LoadTensor(int64_t i64SubdivisionIndex = 0);
106 protected:
107 std::vector<CComputationalBase<T>*>& m_vctInputList;
108 std::vector<CComputationalGraphPlaceholder<T>*>& m_vctPlaceholders;
109 std::vector<CTensor<T>*>& m_vctOrgTensors;
110 std::vector<int64_t>& m_vctLearnOrders;
111 std::vector<int64_t>& m_vctValidationOrders;
112 std::vector<CTensor<T>*>& m_vctBatchTensors;
113
114 CComputationalGraph<T>* m_pCgFunction;
115 CComputationalGraph<T>* m_pCgAugmentation;
116 std::vector<std::pair<CTensor<T>*, CTensor<T>*>>* m_pVctVariablesAndDerivatives;
117 bool m_bInitialized;
118 bool m_bObjectAugmentationPrefeched;
119 bool m_bFitTerminated;
120 CResult m_resLoadTensorResult;
121
122 CTensor<T> m_tsrResult;
123 int64_t m_i64ResultCount;
124
125 bool m_bResetDerivatives;
126
127 T m_tLearningRate;
128
129 //
130 bool m_bDifferent;
131 int64_t m_i64PrevInputSize;
132 int64_t m_i64PrevPlaceholderSize;
133 std::vector<std::tuple<ENodeType, EDataType, EValueAttribute, ENodeOperator, std::vector<int64_t>>>& m_vctPrevInputInfo;
134 std::vector<std::vector<int64_t>>& m_vctPrevPlaceholderShape;
135
136 bool m_bActualUpdate;
137
138 int32_t m_i32DeviceCount;
139 int32_t m_i32DeviceIndex;
140 int32_t m_i32OptimizerIndex;
141 CValidatorBase<T>* m_pValidator;
142
143 std::vector<CTensor<T>*>& m_vctDelayedLoader;
144
145 int64_t m_i64MiniBatchSize;
146 int64_t m_i64Subdivision;
147 int64_t m_i64SubMiniBatchSize;
148 int64_t m_i64ActualMiniBatchSize;
149 int64_t m_i64TotalIteration;
150 int64_t m_i64CurrentIteration;
151
152 double m_f64MemoryLimitRatio;
153
154 volatile int64_t m_i64LoadCount;
155 volatile bool m_bLoadNextIteration;
156 volatile bool m_bEndSubdivision;
157 volatile bool m_bEndIteration;
158 COptimizer<T>* m_pParentOptimizer;
159 private:
160 template<typename T>
161 friend class CValidatorBase;
162
163 template<typename T>
164 friend class CValidatorForClassifier;
165
166 template<typename T>
167 friend class CValidatorForSemanticSegmentation;
168
169 template<typename T>
170 friend class CValidatorForObjectDetection;
171
172 template<typename T>
173 friend class CValidatorForAnomalyDetection;
174
175 template<typename T>
176 friend class CValidatorForSuperResolution;
177
178 template<typename T>
179 friend class CValidatorForInstanceSegmentation;
180
181 template<typename T>
182 friend class CValidatorForStringBasedOCR;
183
184 template<typename T>
185 friend class CValidatorForGAN;
186
187 template<typename T>
188 friend class COptimizer;
189 };
190
191 typedef CInternalOptimizerBase<float> CInternalOptimizerBaseF;
192 typedef CInternalOptimizerBase<double> CInternalOptimizerBaseD;
193 }
194}
195
196#endif
Processing unit AI class required by algorithm.
Definition AlgorithmAIBase.h:27
Definition AlgorithmAIBase.h:18
@ Assign
Set the value of CGUIPropertyItemView3DFigure to the specified figure.
Definition DefinitionsGUIView3D.h:2930