3#if _MSC_VER >= 1900 && defined(_M_X64)
5#include "ComputationalBase.h"
// NOTE(review): this text appears to be a scraped source listing. Each line
// carries a fused original line number (e.g. "22 class ..."), and several lines
// of the real header are not present here: the template header (presumably
// `template <typename T>`), the braces, the access specifiers, and the head of
// the enum that holds EMaxCount_*. The comments below describe only what the
// visible declarations show.
//
// CComputationalGraph<T>: class template publicly derived from
// CComputationalBase<T>. It is abstract, because Forward(), Backward(), Clone()
// and GetClassType() are declared pure virtual (= 0) below.
22 class FL_EXPORT CComputationalGraph :
public CComputationalBase<T>
// Construction / destruction. The copy constructor takes another graph by const
// reference, and the destructor is virtual.
// NOTE(review): no copy-assignment operator is visible in this listing, even
// though the class holds raw pointer members (see below). It may be supplied by
// the SupportToDuplicateAbstractObject macro or by Assign() — TODO confirm
// (rule of three/five).
26 CComputationalGraph();
28 CComputationalGraph(
const CComputationalGraph<T>& cg);
31 virtual ~CComputationalGraph();
// Binary (de)serialization overrides. The CFLData buffer can be passed by
// reference or by pointer.
// Defaults: bSuperClass = false, i32Version = -1, bDumpMode = false.
// The meaning of i32Version == -1 is not visible here; presumably it means
// "current version" — TODO confirm.
33 virtual const CResult GetBinaryData(Base::CFLData& fldBinary,
bool bSuperClass =
false, int32_t i32Version = -1,
bool bDumpMode =
false)
const override;
34 virtual const CResult GetBinaryData(Base::CFLData* pFldBinary,
bool bSuperClass =
false, int32_t i32Version = -1,
bool bDumpMode =
false)
const override;
// SetBinaryData takes an optional int64_t* offset (default nullptr). Presumably
// it is an in/out read position within the buffer — verify in the implementation.
36 virtual const CResult SetBinaryData(
const Base::CFLData& fldBinary, int64_t* pI64Offset =
nullptr)
override;
37 virtual const CResult SetBinaryData(
const Base::CFLData* pFldBinary, int64_t* pI64Offset =
nullptr)
override;
// Assign is non-virtual and takes another graph by const reference.
// Clear overrides the base-class version.
39 const CResult
Assign(
const CComputationalGraph<T>& cg);
40 virtual const CResult
Clear()
override;
// Evaluation entry points:
// - Evaluate overrides the base and returns a tensor reference.
// - Forward (returns a tensor reference) and Backward (returns a tensor pointer)
//   are pure virtual and must be implemented by concrete node types.
42 virtual CTensor<T>& Evaluate()
override;
43 virtual CTensor<T>& Forward() = 0;
44 virtual CTensor<T>* Backward() = 0;
// Recursive derivative maintenance (both override the base class).
46 virtual void ClearDerivativesRecursive()
override;
47 virtual void ResetDerivativesRecursive()
override;
// Polymorphic copy, pure virtual.
// NOTE(review): ownership of the returned raw pointer is not visible here —
// presumably the caller owns it; TODO confirm.
49 virtual CComputationalBase<T>* Clone()
const = 0;
51 virtual const CResult Swap(CComputationalBase<T>& cbSwap)
override;
53 virtual const CResult SetDeviceIndex(int32_t i32DeviceIndex)
override;
// Lookup helpers:
// - GetAt looks up by wide-string name.
// - The Find* functions fill a vector of const node pointers.
// NOTE(review): FindByNodeOperator is declared virtual without `override`,
// unlike FindByValueAttribute. Whether the base class declares it cannot be
// seen from here.
55 virtual const CComputationalBase<T>* GetAt(
const wchar_t* pWcsName)
const override;
56 virtual const CResult FindByValueAttribute(EValueAttribute eValueAttribute, std::vector<
const CComputationalBase<T>*>& vctResult)
const override;
57 virtual const CResult FindByNodeOperator(ENodeOperator eNodeOperator, std::vector<
const CComputationalBase<T>*>& vctResult)
const;
// Operand management:
// - SetOperand sets one operand by index, or all operands from a vector.
// - bSetSourceReferenceType defaults to false.
// - GetOperand is const but returns a non-const pointer.
// The operand pointers/flags/count are presumably stored in m_pCbOperands,
// m_pArrBIntrinsicOperand and m_i64OperandsCount — TODO confirm.
59 virtual const CResult SetOperand(int64_t i64Index,
const CComputationalBase<T>* pCbOperand,
bool bSetSourceReferenceType =
false);
60 virtual const CResult SetOperand(
const std::vector<
const CComputationalBase<T>*>& cbOperands,
bool bSetSourceReferenceType =
false);
61 virtual CComputationalBase<T>* GetOperand(int64_t i64Index)
const;
62 virtual bool IsIntrinsicOperand(int64_t i64Index)
const;
63 virtual int64_t GetOperandCount()
const override;
// Indexed access to member tensors and member backends, with their counts
// (presumably backed by m_ppTsrMemberTensor / m_ppMemberBackend — TODO confirm).
65 virtual CTensor<T>* GetMemberTensor(int64_t i64Index)
const;
66 virtual int64_t GetMemberTensorCount()
const;
68 virtual CBackendBase<T>* GetMemberBackend(int64_t i64Index)
const;
69 virtual int64_t GetMemberBackendCount()
const;
// Batch-size query and training-mode switch (bRecursively defaults to true).
71 virtual int64_t GetNextBatchSize(int64_t i64BatchSize)
const override;
72 virtual const CResult EnableTrainingMode(
bool bMode,
bool bRecursively =
true)
override;
// Diagnostic graph dump.
// Defaults: bRecursively = true, bShapeOrderAsc = false, bIncludeTensors = false.
74 virtual const CResult PrintGraphInfo(
bool bRecursively =
true,
bool bShapeOrderAsc =
false,
bool bIncludeTensors =
false)
const override;
// Parent optimizer association (bRecursive defaults to true). COptimizer<T> is
// also declared a friend at the end of this class.
76 virtual const CResult SetParentOptimizer(COptimizer<T>* pParentOptimizer,
bool bRecursive =
true);
77 virtual const COptimizer<T>* GetParentOptimizer()
const;
// Concrete class identifier; pure virtual.
79 virtual int64_t GetClassType()
const = 0;
// Memory-requirement queries.
// Defaults: bTraining = false, bRecursively = true, i64BatchSize = 1, and
// i64MemoryIndex = 0 for the temporary-memory query.
81 virtual int64_t GetOperandIndexForNextBatchSize()
const;
82 virtual int64_t GetRequiredDedicatedMemory(
bool bTraining =
false,
bool bRecursively =
true, int64_t i64BatchSize = 1)
const override;
83 virtual int64_t GetRequiredTemporaryMemory(
bool bTraining =
false,
bool bRecursively =
true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0)
const override;
// Presumably returns m_tsrDerivativeTemp — TODO confirm.
85 virtual const CTensor<T>& GetDerivativeTemp()
const;
87 virtual int64_t GetSharedMemoryCount()
const;
// bTensorCore defaults to true.
88 virtual const CResult EnableTensorCore(
bool bTensorCore =
true)
override;
// Project macro defined elsewhere. Given its name, it presumably injects
// duplication support for abstract classes — verify against its definition.
90 SupportToDuplicateAbstractObject(CComputationalGraph);
// Internal helpers (the access specifier for this section is not visible in
// this listing).
// - InternalConnectNode has overloads for 1, 2 or 3 operand references, and for
//   an initializer_list or vector of operand pointers.
// NOTE(review): std::initializer_list is normally passed by value; taking it by
// const& is harmless but unusual.
92 virtual const CResult InternalAssign(
const CComputationalGraph<T>& cg);
94 virtual const CResult InternalConnectNode(
const CComputationalBase<T>& cbOperand1);
95 virtual const CResult InternalConnectNode(
const CComputationalBase<T>& cbOperand1,
const CComputationalBase<T>& cbOperand2);
96 virtual const CResult InternalConnectNode(
const CComputationalBase<T>& cbOperand1,
const CComputationalBase<T>& cbOperand2,
const CComputationalBase<T>& cbOperand3);
97 virtual const CResult InternalConnectNode(
const std::initializer_list<
const CComputationalBase<T>*>& ilOperands);
98 virtual const CResult InternalConnectNode(
const std::vector<
const CComputationalBase<T>*>& vctOperands);
100 virtual const CResult InternalCreateOperandsNode(int64_t i64OperandsCount);
// Parent-node bookkeeping and operand clearing. UpdateGeneration's bRecursive
// defaults to false.
102 virtual const CResult ClearParentNode();
103 virtual const CResult RemoveCurrentParentNode();
104 virtual const CResult UpdateParentNode();
106 virtual const CResult ClearOperand();
107 virtual const CResult UpdateGeneration(
bool bRecursive =
false);
// Enumerators of an enum whose head is not visible here; both equal 3.
// Presumably capacity limits (cf. the 3-operand InternalConnectNode overload) —
// TODO confirm.
111 EMaxCount_Operands = 3,
112 EMaxCount_Derivatives = 3,
// Validation / initialization hooks. AreOperatorsValid is const; the other two
// are not.
115 virtual bool AreOperatorsValid()
const;
116 virtual bool AreOperatorsAndTheirValuesValid();
118 virtual bool InitializeDerivativeVariable();
// Data members.
// - Arrays of raw pointers, each paired with a separate int64_t count.
// - m_pArrBIntrinsicOperand is a bool array alongside m_pCbOperands.
// NOTE(review): ownership and lifetime of these arrays are handled in the (not
// visible) implementation. If this header is editable, consider std::vector /
// std::unique_ptr so that copy/destroy cannot get out of sync with the counts.
120 CBackendBase<T>** m_ppMemberBackend;
121 CTensor<T>** m_ppTsrMemberTensor;
122 int64_t m_i64MemberTensorCount;
123 int64_t m_i64MemberBackendCount;
125 int64_t m_i64SharedMemoryCount;
127 CComputationalBase<T>** m_pCbOperands;
128 bool* m_pArrBIntrinsicOperand;
129 int64_t m_i64OperandsCount;
130 CTensor<T> m_tsrDerivativeTemp;
131 COptimizer<T>* m_pParentOptimizer;
// Classes granted access to the non-public members above.
133 friend class CComputationalGraphUtilities<T>;
134 friend class COptimizer<T>;
137 typedef CComputationalGraph<float> CComputationalGraphF;
138 typedef CComputationalGraph<double> CComputationalGraphD;
140 typedef CComputationalGraph<float> CCGF;
141 typedef CComputationalGraph<double> CCGD;
143 template <
typename T>
144 using CCG = CComputationalGraph<T>;
Definition AlgorithmAIBase.h:18
@ Assign
CGUIPropertyItemView3DFigure 의 값을 해당 도형으로 설정하는 함수
Definition DefinitionsGUIView3D.h:2930
@ Clear
도형 정리 메뉴
Definition DefinitionsGUI.h:2110