3#if _MSC_VER >= 1900 && defined(_M_X64)
5#include "DefinitionsAI.h"
// Forward declarations. Of these, CComputationalGraph and
// CComputationalGraphPlaceholder are referenced below only through pointers
// and references; the other three are not used in the visible part of the file.
// NOTE(review): both graph types are used with <T> further down, so each line
// here was presumably preceded by a template header that is not visible in
// this listing — confirm against the original header.
14 class CComputationalGraph;
17 class CComputationalGraphPlaceholder;
20 class CLearningRateSchedulerBase;
23 class CClassEqualizerBase;
26 class CComputationalGraphAugmentationBase;
// Default constructor and virtual destructor. Because the destructor is
// virtual, deleting a derived optimizer through a CInternalOptimizerBase
// pointer is well-defined.
// NOTE(review): the class holds several reference members (declared further
// down), which the constructor must bind in its initializer list — what they
// are bound to is not visible here.
39 CInternalOptimizerBase();
40 virtual ~CInternalOptimizerBase();
// Assign overloads: one takes the source optimizer by const reference, the
// other by const pointer. Both are virtual and return a CResult.
// NOTE(review): presumably these copy the argument's state into this object;
// whether the pointer overload tolerates nullptr is not visible — confirm.
42 virtual const CResult
Assign(
const CInternalOptimizerBase<T>& ob);
43 virtual const CResult
Assign(
const CInternalOptimizerBase<T>* pOb);
// Pure virtual, const: every concrete optimizer must implement Clone() and
// return a CInternalOptimizerBase<T>*.
// NOTE(review): ownership of the returned raw pointer is not stated here —
// confirm who is responsible for deleting it.
45 virtual CInternalOptimizerBase<T>* Clone()
const = 0;
// Initialization entry points. Initialize() and InitializeBatch() report their
// outcome as a CResult; IsInitialized() is a const query returning bool.
// NOTE(review): the required call order between Initialize() and
// InitializeBatch() is not visible in this header — confirm in the .cpp.
47 virtual const CResult Initialize();
48 virtual const CResult InitializeBatch();
49 virtual bool IsInitialized()
const;
// Accessors for the two computational graphs attached to the optimizer.
// The function graph is passed and returned by reference; the augmentation
// graph is passed and returned by pointer.
// NOTE(review): the backing members are both raw pointers (m_pCgFunction,
// m_pCgAugmentation), so GetFunction() presumably dereferences
// m_pCgFunction — confirm it cannot be called before SetFunction().
51 virtual const CResult SetFunction(CComputationalGraph<T>& cgFunction);
52 virtual CComputationalGraph<T>& GetFunction();
54 virtual const CResult SetAugmentation(CComputationalGraph<T>* pCgAugmentation);
55 virtual CComputationalGraph<T>* GetAugmentation();
// Fit() is pure virtual and must be implemented by each concrete optimizer.
// Both parameters are optional pointers that default to nullptr.
// NOTE(review): "pI32IterationFeedbackFeedback" looks like a duplicated-word
// typo; Validate() below spells the equivalent parameter
// "pI32IterationFeedback". Renaming is source-compatible for callers but
// should be done together with the definition and any doc comments.
// NOTE(review): default arguments of virtual functions are chosen by the
// static type of the call expression, so overrides should repeat nullptr.
57 virtual const CResult Fit(T* pClippingThreshold=
nullptr, int32_t* pI32IterationFeedbackFeedback =
nullptr) = 0;
// BackPropagation() takes no arguments and returns a CResult.
58 virtual const CResult BackPropagation();
// Validate() returns a value of the element type T; the optional pointer
// defaults to nullptr.
// NOTE(review): the meaning of the returned T (loss, score, ...) and what is
// written through pI32IterationFeedback are not visible here — confirm.
59 virtual T Validate(int32_t* pI32IterationFeedback =
nullptr);
// GetResult() is a const accessor returning a const reference to a CTensor<T>.
// NOTE(review): presumably this exposes m_tsrResult; if so the reference is
// only valid while this optimizer is alive — confirm.
61 virtual const CTensor<T>& GetResult()
const;
// GetRequiredDedicatedMemory() is a const query that maps a batch size to an
// int64_t.
// NOTE(review): the unit of the returned value (bytes?) and the device it
// refers to are not stated in this header — confirm.
63 virtual int64_t GetRequiredDedicatedMemory(int64_t i64BatchSize)
const;
// Validator accessors. SetValidator() is overloaded for a reference and for a
// pointer to CValidatorBase<T>; GetValidator() is const and returns a pointer
// to a const validator.
// NOTE(review): the validator is held in the raw pointer m_pValidator, so it
// is presumably not owned by this class — confirm lifetime expectations and
// whether SetValidator(nullptr) is allowed.
65 virtual const CResult SetValidator(CValidatorBase<T>& vali);
66 virtual const CResult SetValidator(CValidatorBase<T>* pVali);
67 virtual const CValidatorBase<T>* GetValidator()
const;
// Const accessor returning a const reference to a vector of int64_t indices.
// NOTE(review): presumably backed by m_vctValidationOrders — confirm.
69 virtual const std::vector<int64_t>& GetValidationIndices()
const;
// ResetDerivatives() takes no arguments and returns a CResult.
71 virtual const CResult ResetDerivatives();
// ClipGradient() takes an optional threshold pointer that defaults to nullptr.
// NOTE(review): the behavior when the pointer is null (no clipping vs. a
// built-in threshold) is not visible here — confirm.
73 virtual const CResult ClipGradient(T* pClippingThreshold =
nullptr);
// Pure virtual: the weight-update rule is supplied by each concrete optimizer.
74 virtual const CResult UpdateWeights() = 0;
// Setter/getter pair for a boolean option.
// NOTE(review): presumably backed by m_bResetDerivatives — confirm.
76 virtual const CResult EnableResetDerivatives(
bool bResetDerivatives);
77 virtual bool IsResetDerivativesEnabled()
const;
// Const query returning an int64_t iteration count.
// NOTE(review): presumably returns m_i64TotalIteration — confirm.
79 virtual int64_t GetTotalIteration()
const;
// InitializeMiniBatch() takes no arguments and returns a CResult.
81 virtual const CResult InitializeMiniBatch();
// Device index setter/getter (int32_t).
// NOTE(review): the valid range and its relation to m_i32DeviceCount are not
// visible in this header — confirm.
83 virtual const CResult SetDeviceIndex(int32_t i32DeviceIndex);
84 virtual int32_t GetDeviceIndex()
const;
// Optimizer index setter/getter (int32_t). The setter's argument defaults to 0.
86 virtual const CResult SetOptimizerIndex(int32_t i32OptimizerIndex = 0);
87 virtual int32_t GetOptimizerIndex()
const;
// Memory limit ratio setter/getter (double). The setter's argument defaults
// to 1.
// NOTE(review): the accepted range (presumably a fraction in (0, 1]) is not
// stated here — confirm.
89 virtual const CResult SetMemoryLimitRatio(
double f64MemoryLimitRatio = 1.);
90 virtual double GetMemoryLimitRatio()
const;
// NOTE(review): the embedded numbering jumps from 90 to 93 before this
// declaration, which suggests an access specifier was dropped from the
// listing — confirm the visibility of everything from here on.
// Throw() is marked override, so it replaces a virtual declared in a base
// class that is not visible in this excerpt. It is const, takes the CResult
// by const reference and an optional wide-string message that defaults to
// nullptr.
93 virtual void Throw(
const CResult& res,
const wchar_t* pWcsExtraMessage =
nullptr)
const override;
// Fit and mini-batch lifecycle hooks; each takes no arguments and returns a
// CResult.
94 virtual const CResult InitializeFit();
95 virtual const CResult TerminateFit();
96 virtual const CResult TerminateMiniBatch();
98 virtual const CResult MakeBatchTensor();
// The only non-virtual member function in this group; returns nothing.
// NOTE(review): presumably operates on m_pVctVariablesAndDerivatives —
// confirm.
100 void ClearVariablesAndDerivatives();
// Per-iteration helpers parameterized by a subdivision index that defaults
// to 0.
102 virtual const CResult InitializeCurrentIteration(int64_t i64SubdivisionIndex = 0);
103 virtual const CResult LoadTensor(int64_t i64SubdivisionIndex = 0);
// Container members declared as references. Reference members must be bound
// in the constructor's initializer list and cannot be reseated afterwards.
// NOTE(review): what these references are bound to (and who owns the
// referenced vectors) is not visible in this header — confirm.
105 std::vector<CComputationalBase<T>*>& m_vctInputList;
106 std::vector<CComputationalGraphPlaceholder<T>*>& m_vctPlaceholders;
107 std::vector<CTensor<T>*>& m_vctOrgTensors;
108 std::vector<int64_t>& m_vctLearnOrders;
109 std::vector<int64_t>& m_vctValidationOrders;
110 std::vector<CTensor<T>*>& m_vctBatchTensors;
// Raw pointers to the function graph, the augmentation graph, and a vector of
// (tensor, tensor) pointer pairs.
// NOTE(review): by name the pairs are (variable, derivative); ownership of
// all three pointers is not visible here — confirm.
112 CComputationalGraph<T>* m_pCgFunction;
113 CComputationalGraph<T>* m_pCgAugmentation;
114 std::vector<std::pair<CTensor<T>*, CTensor<T>*>>* m_pVctVariablesAndDerivatives;
// Status flags and the stored CResult of a tensor load.
// NOTE(review): "Prefeched" is presumably a misspelling of "Prefetched"; the
// name is kept because friend classes and the implementation may reference it.
116 bool m_bObjectAugmentationPrefeched;
117 bool m_bFitTerminated;
118 CResult m_resLoadTensorResult;
// Result tensor held by value, plus an associated count.
120 CTensor<T> m_tsrResult;
121 int64_t m_i64ResultCount;
123 bool m_bResetDerivatives;
// Sizes, per-input info tuples and placeholder shapes tagged "Prev".
// NOTE(review): presumably a snapshot from an earlier initialization used to
// detect graph changes — confirm. The two vectors are reference members.
129 int64_t m_i64PrevInputSize;
130 int64_t m_i64PrevPlaceholderSize;
131 std::vector<std::tuple<ENodeType, EDataType, EValueAttribute, ENodeOperator, std::vector<int64_t>>>& m_vctPrevInputInfo;
132 std::vector<std::vector<int64_t>>& m_vctPrevPlaceholderShape;
134 bool m_bActualUpdate;
// Device and optimizer indices, and the validator pointer.
136 int32_t m_i32DeviceCount;
137 int32_t m_i32DeviceIndex;
138 int32_t m_i32OptimizerIndex;
139 CValidatorBase<T>* m_pValidator;
141 std::vector<CTensor<T>*>& m_vctDelayedLoader;
// Batch sizing and iteration counters, all int64_t.
// NOTE(review): the relationship between mini-batch size, subdivision and
// sub-mini-batch size is not visible in this header — confirm in the .cpp.
143 int64_t m_i64MiniBatchSize;
144 int64_t m_i64Subdivision;
145 int64_t m_i64SubMiniBatchSize;
146 int64_t m_i64ActualMiniBatchSize;
147 int64_t m_i64TotalIteration;
148 int64_t m_i64CurrentIteration;
150 double m_f64MemoryLimitRatio;
// volatile-qualified counter and flags.
// NOTE(review): if these are shared between threads, standard C++ volatile
// gives neither atomicity nor inter-thread ordering. This header is guarded
// for MSVC x64, where the default /volatile:ms mode adds acquire/release
// semantics — confirm the build relies on that, or consider std::atomic.
152 volatile int64_t m_i64LoadCount;
153 volatile bool m_bLoadNextIteration;
154 volatile bool m_bEndSubdivision;
155 volatile bool m_bEndIteration;
// Raw pointer to a COptimizer<T>.
// NOTE(review): by name this is a non-owning back-pointer to the owning
// optimizer — confirm.
156 COptimizer<T>* m_pParentOptimizer;
// Friend declarations: the validator classes and COptimizer are granted
// access to this class's non-public members.
// NOTE(review): CValidatorBase and COptimizer are used as templates elsewhere
// in this file (CValidatorBase<T>, COptimizer<T>), and the embedded numbering
// skips a line before each friend, so each declaration was presumably
// preceded by a template header dropped from this listing — confirm.
159 friend class CValidatorBase;
162 friend class CValidatorForClassifier;
165 friend class CValidatorForSemanticSegmentation;
168 friend class CValidatorForObjectDetection;
171 friend class CValidatorForAnomalyDetection;
174 friend class CValidatorForSuperResolution;
177 friend class CValidatorForInstanceSegmentation;
180 friend class CValidatorForStringBasedOCR;
183 friend class CValidatorForGAN;
186 friend class COptimizer;
// Convenience aliases for the single-precision (F) and double-precision (D)
// instantiations of CInternalOptimizerBase.
189 using CInternalOptimizerBaseF = CInternalOptimizerBase<float>;
190 using CInternalOptimizerBaseD = CInternalOptimizerBase<double>;
Processing unit AI class required by algorithm.
Definition AlgorithmAIBase.h:27
Definition AlgorithmAIBase.h:18
@ Assign
Set the value of CGUIPropertyItemView3DFigure to the specified figure.
Definition DefinitionsGUIView3D.h:2930