3#if _MSC_VER >= 1900 && defined(_M_X64)
5#include "DefinitionsAI.h"
6#include "ComputationalGraphUtilities.h"
// NOTE(review): this text appears to be a Doxygen source-view export. The
// integer at the start of most lines is the original header's line number
// fused onto the code; it is not part of the C++ declaration.
// Forward declaration. It is later named as a template
// (`friend class CComputationalGraph<T>;`), so a `template <typename T>` line
// seems to be missing from this listing -- confirm against the real header.
16 class CComputationalGraph;
/// @brief Virtual destructor, so deleting a derived node through a
///        CComputationalBase<T>* is well-defined.
24 virtual ~CComputationalBase();
/// @brief Initialization-state accessors. The implementations are not in this
///        listing; the descriptions follow the member names and should be confirmed.
/// IsInitialized: const query returning bool.
26 virtual bool IsInitialized()
const;
/// SetInitialized: presumably stores `bSet` as the new state; returns a CResult status.
27 virtual const CResult SetInitialized(
bool bSet);
/// @brief Restore from / persist to a file named by a wide-character path.
/// Both return a CResult status.
/// NOTE(review): the file format and error codes are not visible here. Presumably
/// this is the same serialized form as GetBinaryData/SetBinaryData -- confirm.
29 virtual const CResult
Load(
const wchar_t* pWcsFileName);
30 virtual const CResult
Save(
const wchar_t* pWcsFileName);
/// @brief Write this object's binary representation into a Base::CFLData
///        buffer. There are two overloads: one takes a reference, one a pointer.
/// @param bSuperClass Defaults to false. NOTE(review): presumably restricts the
///        output to the base-class portion -- confirm.
/// @param i32Version  Defaults to -1. NOTE(review): -1 likely selects the latest
///        format version -- confirm.
/// @param bDumpMode   Defaults to false; its meaning is not visible here.
/// @return CResult status. The member is declared `const`.
32 virtual const CResult GetBinaryData(Base::CFLData& fldBinary,
bool bSuperClass =
false, int32_t i32Version = -1,
bool bDumpMode =
false)
const;
33 virtual const CResult GetBinaryData(Base::CFLData* pFldBinary,
bool bSuperClass =
false, int32_t i32Version = -1,
bool bDumpMode =
false)
const;
/// @brief Restore this object from a Base::CFLData buffer. There are reference
///        and pointer overloads.
/// @param pI64Offset Optional; defaults to nullptr. NOTE(review): presumably an
///        in/out read position within the buffer -- confirm.
/// @return CResult status.
35 virtual const CResult SetBinaryData(
const Base::CFLData& fldBinary, int64_t* pI64Offset =
nullptr);
36 virtual const CResult SetBinaryData(
const Base::CFLData* pFldBinary, int64_t* pI64Offset =
nullptr);
/// @brief Virtual reset hook returning a CResult. What it releases is not
///        visible in this listing.
38 virtual const CResult
Clear();
/// @brief Object number (non-virtual); presumably returns m_i64ObjectNumber.
/// NOTE(review): the top-level `const` on a by-value int64_t return has no effect
/// for callers. However, MSVC encodes return-type qualifiers in decorated names,
/// so removing it may change the exported symbol. Review the ABI before touching it.
40 const int64_t GetObjectNumber()
const;
/// @brief Node ID accessors (non-virtual).
/// SetID and ID have identical signatures; ID is presumably a fluent alias of
/// SetID. Both return CComputationalBase<T>&, presumably *this, so calls can be
/// chained -- confirm.
42 CComputationalBase<T>& SetID(
const wchar_t* pWcsName);
43 CComputationalBase<T>& ID(
const wchar_t* pWcsName);
/// GetID returns the stored ID as a `const wchar_t*`.
/// NOTE(review): the pointer presumably refers to storage owned by this object.
/// Do not free it, and do not keep it past the next SetID/ID call -- confirm.
44 const wchar_t* GetID()
const;
/// @brief Error-reporting hook that overrides a base-class virtual (`override`).
///        It takes a CResult and an optional extra wide-string message.
/// NOTE(review): the name suggests it throws; the exception type is not visible.
/// Default arguments of virtual functions come from the static type at the call
/// site, so every override should keep `pWcsExtraMessage = nullptr`.
46 virtual void Throw(
const CResult& res,
const wchar_t* pWcsExtraMessage =
nullptr)
const override;
/// @brief Reference flag (non-virtual).
/// EnableReference sets it and returns CComputationalBase<T>&, presumably
/// *this for chaining. IsReference queries it. What "reference" means for a
/// node is not visible here -- confirm.
48 CComputationalBase<T>& EnableReference(
bool bReference);
49 bool IsReference()
const;
/// @brief Value-attribute setter/getter (presumably backed by m_eValueAttribute).
/// The same EValueAttribute enum is the search key of FindByValueAttribute.
51 const CResult SetValueAttribute(EValueAttribute eType);
52 EValueAttribute GetValueAttribute()
const;
/// @brief Non-virtual type queries. They presumably return m_eNodeType,
///        m_eDataType and m_eNodeOperator respectively.
54 ENodeType GetNodeType()
const;
55 EDataType GetDataType()
const;
56 ENodeOperator GetNodeOperator()
const;
/// @brief Shape queries. Each returns a const reference, not a copy. Keep the
///        object alive while using it; the reference presumably also goes stale
///        when the shape changes.
/// GetEstimatedShape is recursive by default (bRecursive = true).
57 virtual const std::vector<int64_t>& GetEstimatedShape(
bool bRecursive =
true)
const;
/// GetShape / GetShapeAsc presumably give the same dimensions in two orderings
/// (compare bShapeOrderAsc in PrintGraphInfo) -- confirm which is which.
58 virtual const std::vector<int64_t>& GetShape()
const;
59 virtual const std::vector<int64_t>& GetShapeAsc()
const;
/// ClearShape is non-recursive by default (bRecursive = false), which is the
/// opposite default of GetEstimatedShape.
60 virtual const CResult ClearShape(
bool bRecursive =
false);
/// GetNextBatchSize maps a requested batch size to another int64_t. The exact
/// rule is not visible here.
61 virtual int64_t GetNextBatchSize(int64_t i64BatchSize)
const;
/// @brief Computation entry points. All three are pure virtual (`= 0`), so every
///        concrete node type must implement them.
/// Evaluate and Forward return the result tensor by reference.
/// Backward returns a pointer. NOTE(review): presumably the derivative tensor,
/// possibly null when no gradient is produced -- confirm.
63 virtual CTensor<T>& Evaluate() = 0;
64 virtual CTensor<T>& Forward() = 0;
65 virtual CTensor<T>* Backward() = 0;
/// @brief Tensor accessors.
/// GetValue is `const` yet returns a non-const CTensor<T>&, so callers can
/// mutate the value through a const object.
/// GetDerivative returns a pointer; it is presumably null when no derivative
/// exists -- confirm.
67 virtual CTensor<T>& GetValue()
const;
68 virtual CTensor<T>* GetDerivative()
const;
/// @brief Recursive derivative maintenance. "Recursive" presumably walks the
///        operand subgraph. The difference between Clear and Reset is not
///        visible here.
69 virtual void ClearDerivativesRecursive();
70 virtual void ResetDerivativesRecursive();
/// @brief Polymorphic copy (pure virtual) returning a raw pointer.
/// NOTE(review): the result is presumably heap-allocated and owned by the
/// caller -- confirm.
72 virtual CComputationalBase<T>* Clone()
const = 0;
/// @brief Exchanges state with `cbSwap`; returns a CResult status.
74 virtual const CResult Swap(CComputationalBase<T>& cbSwap);
/// @brief Name lookup (pure virtual). Returns a const pointer, presumably null
///        when no node with that name is found.
76 virtual const CComputationalBase<T>* GetAt(
const wchar_t* pWcsName)
const = 0;
/// @brief Collects nodes carrying `eValueAttribute` into `vctResult` (pure
///        virtual). NOTE(review): whether vctResult is cleared first is not
///        visible -- confirm.
77 virtual const CResult FindByValueAttribute(EValueAttribute eValueAttribute, std::vector<
const CComputationalBase<T>*>& vctResult)
const = 0;
/// @brief Operand access by index.
/// GetOperand returns a raw pointer, presumably non-owning.
/// IsIntrinsicOperand reports a per-index flag whose meaning is not visible here.
/// GetOperandCount gives the number of operands.
/// NOTE(review): behaviour for out-of-range indices is not visible here.
79 virtual CComputationalBase<T>* GetOperand(int64_t i64Index)
const;
80 virtual bool IsIntrinsicOperand(int64_t i64Index)
const;
81 virtual int64_t GetOperandCount()
const;
/// @brief Generation number; presumably returns m_i64Generation.
83 virtual int64_t GetGeneration()
const;
/// @brief Training-mode flag, presumably backed by m_bTrainingMode.
/// Judging by the parameter name, EnableTrainingMode also applies to connected
/// nodes by default (bRecursively = true).
85 virtual bool IsTrainingModeEnabled()
const;
86 virtual const CResult EnableTrainingMode(
bool bMode,
bool bRecursively =
true);
/// @brief Diagnostic output; the output destination is not visible here.
/// PrintGraphInfo has these defaults:
/// - recursive (bRecursively = true);
/// - non-ascending shape order (bShapeOrderAsc = false);
/// - tensor contents excluded (bIncludeTensors = false).
88 virtual const CResult PrintGraphInfo(
bool bRecursively =
true,
bool bShapeOrderAsc =
false,
bool bIncludeTensors =
false)
const;
89 virtual const CResult PrintNodeParamInfo()
const;
/// @brief Boolean capability query. NOTE(review): presumably reports GPU
///        tensor-core support for this node -- confirm.
91 virtual bool IsTensorCoreAvailable()
const;
/// @brief Add-gradient counter accessors, presumably backed by m_i64AddGradientCount.
93 virtual int64_t GetAddGradientCount()
const;
94 virtual const CResult SetAddGradientCount(int64_t i64AddGradientCount);
/// @brief Memory-requirement estimates (both pure virtual).
/// Defaults: non-training (bTraining = false), recursive (bRecursively = true),
/// batch size 1. The temporary-memory variant also takes a memory index
/// (default 0). NOTE(review): the units are presumably bytes -- confirm.
96 virtual int64_t GetRequiredDedicatedMemory(
bool bTraining =
false,
bool bRecursively =
true, int64_t i64BatchSize = 1)
const = 0;
97 virtual int64_t GetRequiredTemporaryMemory(
bool bTraining =
false,
bool bRecursively =
true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0)
const = 0;
/// @brief Retained-value flag: EnableRetainedValue sets it and
///        IsRetainedEnabled queries it. Its meaning is not visible here
///        (presumably: keep the value tensor after use) -- confirm.
99 virtual const CResult EnableRetainedValue(
bool bRetained);
100 virtual bool IsRetainedEnabled()
const;
/// @brief Per the name, reports whether this node computes in place.
102 virtual bool IsInplace()
const;
/// @brief Parent-node access: a count plus an indexed raw pointer, presumably
///        backed by m_vctParentNodes.
104 virtual int64_t GetParentNodeCount()
const;
105 virtual CComputationalBase<T>* GetParentNode(int64_t i64Index)
const;
/// @brief Device index (int32_t), presumably stored in m_i32DeviceIndex.
107 virtual const CResult SetDeviceIndex(int32_t i32DeviceIndex);
108 virtual int32_t GetDeviceIndex()
const;
/// @brief Virtual hooks returning CResult. They presumably allocate and release
///        m_pTsrValue and m_pTsrDerivative -- confirm.
110 virtual const CResult CreateValue();
111 virtual const CResult ClearValue();
113 virtual const CResult CreateDerivative();
114 virtual const CResult ClearDerivative();
/// @brief Non-virtual raw setters for the value and derivative tensor pointers.
/// NOTE(review): ownership transfer is not visible here. Check whether the
/// previous tensor is released and who deletes the new one.
116 const CResult SetValuePtr(CTensor<T>* pTsrValue);
117 const CResult SetDerivativePtr(CTensor<T>* pTsrDerivative);
/// @brief Virtual, non-const accessor for the value tensor pointer.
119 virtual CTensor<T>* GetValuePtr();
/// @brief Static accessor returning a CComputationalBase<T>*. Its lifetime and
///        ownership are not visible here.
121 static CComputationalBase<T>* GetSingletonObject();
/// @brief Processing-unit setters (reference and pointer overloads). Both
///        override base-class virtuals.
123 virtual const CResult SetProcessingUnit(
const Base::CProcessingUnitBase& puBase)
override;
124 virtual const CResult SetProcessingUnit(
const Base::CProcessingUnitBase* pPuBase)
override;
// Project macro; its expansion is not visible here. NOTE(review): it presumably
// declares duplication/clone boilerplate for an abstract class -- confirm.
126 SupportToDuplicateAbstractObject(CComputationalBase);
/// @brief Copies state from `cb` into this object and returns a CResult.
///        Non-virtual; presumably the shared implementation behind assignment.
128 const CResult InternalAssign(
const CComputationalBase<T>& cb);
/// @brief Forward and backward timing brackets, each returning a CResult.
///        They presumably start and stop m_perfForward / m_perfBackward.
130 virtual const CResult BeginForwardPerformanceCheck();
131 virtual const CResult EndForwardPerformanceCheck();
133 virtual const CResult BeginBackwardPerformanceCheck();
134 virtual const CResult EndBackwardPerformanceCheck();
// Node metadata, presumably backing GetNodeType / GetDataType /
// GetValueAttribute / GetNodeOperator.
138 ENodeType m_eNodeType;
139 EDataType m_eDataType;
140 EValueAttribute m_eValueAttribute;
141 ENodeOperator m_eNodeOperator;
// Flags and counters.
143 bool m_bTrainingMode;
144 int64_t m_i64ObjectNumber;
147 int64_t m_i64AddGradientCount;
148 int64_t m_i64Generation;
149 int32_t m_i32DeviceIndex;
// Value and derivative tensors, held by raw pointer. Ownership is not visible
// in this listing; check the Create/Clear and Set*Ptr members.
151 CTensor<T>* m_pTsrValue;
152 CTensor<T>* m_pTsrDerivative;
// Timers, presumably used by the Begin/End*PerformanceCheck members.
154 Base::CPerformanceCounter m_perfForward;
155 Base::CPerformanceCounter m_perfBackward;
// NOTE: these are reference members. They must be bound in every constructor's
// member-initializer list, cannot be reseated, and make the implicitly-declared
// copy-assignment operator deleted. Where the referenced vectors live is not
// visible here -- confirm.
157 std::vector<int64_t>& m_vctShape;
158 std::vector<int64_t>& m_vctShapeAsc;
160 std::vector<CComputationalBase<T>*>& m_vctParentNodes;
// Grants the graph and graph-utility class templates, for the same T, access to
// this class's private and protected members.
163 friend class CComputationalGraph<T>;
164 friend class CComputationalGraphUtilities<T>;
167 typedef CComputationalBase<float> CComputationalBaseF;
168 typedef CComputationalBase<double> CComputationalBaseD;
170 typedef CComputationalBase<float> CCBF;
171 typedef CComputationalBase<double> CCBD;
173 template <
typename T>
174 using CCB = CComputationalBase<T>;
알고리즘에서 필요한 프로세싱 유닛 AI 클래스
Definition AlgorithmAIBase.h:25
Definition AlgorithmAIBase.h:16
Clear (enumerator)
도형 정리 메뉴 (shape-clearing menu)
Definition DefinitionsGUI.h:2058
Save (enumerator)
저장 메뉴 (save menu)
Definition DefinitionsGUI.h:303
Load (enumerator)
불러오기 (load / open)
Definition DefinitionsGUI.h:50