3#if _MSC_VER >= 1900 && defined(_M_X64)
5#include "ComputationalBase.h"
// NOTE(review): this text looks like a lossy extraction of the real header.
// Original line numbers are fused into the code (e.g. "22 class ..."), and the
// template header, namespace, braces, access specifiers and the enum head are
// missing, so it will not compile as-is. Restore from the source header before
// making code changes here.
//
// CComputationalGraph<T>: a graph node type derived from CComputationalBase<T>.
// It is abstract: Forward(), Backward(), Clone() and GetClassType() are
// declared pure virtual below and must be implemented by derived classes.
// Exported through the FL_EXPORT macro.
22 class FL_EXPORT CComputationalGraph :
public CComputationalBase<T>
// Default constructor.
26 CComputationalGraph();
// Copy constructor from another graph of the same element type T.
28 CComputationalGraph(
const CComputationalGraph<T>& cg);
// Virtual destructor.
// NOTE(review): the class declares a copy ctor and destructor and holds raw
// pointer members (m_ppMemberBackend, m_ppTsrMemberTensor, m_pCbOperands,
// m_pArrBIntrinsicOperand). No copy assignment operator is visible in this
// extraction. Confirm the rule of three/five is satisfied, possibly via the
// SupportToDuplicateAbstractObject macro below.
31 virtual ~CComputationalGraph();
// Overrides of the CComputationalBase<T> binary-data accessors, in reference
// and pointer overloads. Presumably these serialize this object into the
// CFLData argument — confirm against CComputationalBase<T>.
// Defaults: bSuperClass = false, i32Version = -1, bDumpMode = false.
// NOTE(review): default arguments on virtual functions are taken from the
// static type of the call expression, not from the overriding function. Keep
// these defaults identical to the base declaration.
33 virtual const CResult GetBinaryData(Base::CFLData& fldBinary,
bool bSuperClass =
false, int32_t i32Version = -1,
bool bDumpMode =
false)
const override;
34 virtual const CResult GetBinaryData(Base::CFLData* pFldBinary,
bool bSuperClass =
false, int32_t i32Version = -1,
bool bDumpMode =
false)
const override;
// Overrides of the CComputationalBase<T> binary-data setters (reference and
// pointer overloads). pI64Offset defaults to nullptr. Presumably it is an
// optional in/out read offset into the buffer — TODO confirm.
36 virtual const CResult SetBinaryData(
const Base::CFLData& fldBinary, int64_t* pI64Offset =
nullptr)
override;
37 virtual const CResult SetBinaryData(
const Base::CFLData* pFldBinary, int64_t* pI64Offset =
nullptr)
override;
// Assign another graph of the same element type to this one (non-virtual).
// Returns a CResult status.
39 const CResult
Assign(
const CComputationalGraph<T>& cg);
// Override of CComputationalBase<T>::Clear().
40 virtual const CResult
Clear()
override;
// Override of CComputationalBase<T>::Evaluate(); returns a reference to a tensor.
42 virtual CTensor<T>& Evaluate()
override;
// Pure virtual: each derived node type supplies Forward() and Backward().
// Forward() returns a tensor reference. Backward() returns a tensor pointer,
// so callers should not assume it is non-null without checking the
// implementation — NOTE(review).
43 virtual CTensor<T>& Forward() = 0;
44 virtual CTensor<T>* Backward() = 0;
// Overrides of the base class's recursive derivative clear/reset hooks.
46 virtual void ClearDerivativesRecursive()
override;
47 virtual void ResetDerivativesRecursive()
override;
// Pure virtual copy hook returning a CComputationalBase<T>*. Presumably it
// returns a newly allocated copy owned by the caller — confirm.
// NOTE(review): not marked override; add it if the base declares Clone().
49 virtual CComputationalBase<T>* Clone()
const = 0;
// Override: swap state with another CComputationalBase<T>.
51 virtual const CResult Swap(CComputationalBase<T>& cbSwap)
override;
// Override: set the device index used by this node.
53 virtual const CResult SetDeviceIndex(int32_t i32DeviceIndex)
override;
// Override: look up a node by wide-string name (returns a const pointer).
55 virtual const CComputationalBase<T>* GetAt(
const wchar_t* pWcsName)
const override;
// Override: collect nodes matching eValueAttribute into vctResult.
56 virtual const CResult FindByValueAttribute(EValueAttribute eValueAttribute, std::vector<
const CComputationalBase<T>*>& vctResult)
const override;
// Collect nodes matching eNodeOperator into vctResult.
// NOTE(review): unlike FindByValueAttribute this is not marked override. If
// CComputationalBase<T> declares the same signature, add override so that a
// signature drift is caught at compile time.
57 virtual const CResult FindByNodeOperator(ENodeOperator eNodeOperator, std::vector<
const CComputationalBase<T>*>& vctResult)
const;
// Set one operand by index, or replace all operands from a vector.
// bSetSourceReferenceType defaults to false; its exact effect is not visible
// here — TODO confirm.
59 virtual const CResult SetOperand(int64_t i64Index,
const CComputationalBase<T>* pCbOperand,
bool bSetSourceReferenceType =
false);
60 virtual const CResult SetOperand(
const std::vector<
const CComputationalBase<T>*>& cbOperands,
bool bSetSourceReferenceType =
false);
// Operand accessors: operand by index, whether that operand is intrinsic,
// and the operand count (the count is an override of the base class).
61 virtual CComputationalBase<T>* GetOperand(int64_t i64Index)
const;
62 virtual bool IsIntrinsicOperand(int64_t i64Index)
const;
63 virtual int64_t GetOperandCount()
const override;
// Member-tensor accessors (by index, and count).
65 virtual CTensor<T>* GetMemberTensor(int64_t i64Index)
const;
66 virtual int64_t GetMemberTensorCount()
const;
// Member-backend accessors (by index, and count).
68 virtual CBackendBase<T>* GetMemberBackend(int64_t i64Index)
const;
69 virtual int64_t GetMemberBackendCount()
const;
// Override: compute the next batch size from i64BatchSize.
71 virtual int64_t GetNextBatchSize(int64_t i64BatchSize)
const override;
// Override: enable or disable training mode. bRecursively defaults to true.
72 virtual const CResult EnableTrainingMode(
bool bMode,
bool bRecursively =
true)
override;
// Override: print graph information.
// Defaults: bRecursively = true, bShapeOrderAsc = false, bIncludeTensors = false.
74 virtual const CResult PrintGraphInfo(
bool bRecursively =
true,
bool bShapeOrderAsc =
false,
bool bIncludeTensors =
false)
const override;
// Parent-optimizer setter (bRecursive defaults to true) and getter.
// The setter takes a non-const pointer; the getter returns a const pointer.
62 virtual const CResult SetParentOptimizer(COptimizer<T>* pParentOptimizer,
bool bRecursive =
true);
77 virtual const COptimizer<T>* GetParentOptimizer()
const;
// Pure virtual: derived classes report their class-type identifier.
79 virtual int64_t GetClassType()
const = 0;
// Index of the operand consulted for the next batch size (name-based reading — confirm).
81 virtual int64_t GetOperandIndexForNextBatchSize()
const;
// Overrides: memory requirement queries.
// Defaults: bTraining = false, bRecursively = true, i64BatchSize = 1, and
// i64MemoryIndex = 0 for the temporary-memory query.
82 virtual int64_t GetRequiredDedicatedMemory(
bool bTraining =
false,
bool bRecursively =
true, int64_t i64BatchSize = 1)
const override;
83 virtual int64_t GetRequiredTemporaryMemory(
bool bTraining =
false,
bool bRecursively =
true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0)
const override;
// Read-only access to a derivative scratch tensor (presumably m_tsrDerivativeTemp — confirm).
85 virtual const CTensor<T>& GetDerivativeTemp()
const;
// Returns the shared-memory count (presumably m_i64SharedMemoryCount — confirm).
87 virtual int64_t GetSharedMemoryCount()
const;
// Project macro; its expansion is not visible here.
89 SupportToDuplicateAbstractObject(CComputationalGraph);
// Internal assignment helper used with another graph of the same type.
91 virtual const CResult InternalAssign(
const CComputationalGraph<T>& cg);
// Internal operand-wiring overloads: one, two or three operands, or an
// arbitrary list (initializer_list / vector of const pointers).
93 virtual const CResult InternalConnectNode(
const CComputationalBase<T>& cbOperand1);
94 virtual const CResult InternalConnectNode(
const CComputationalBase<T>& cbOperand1,
const CComputationalBase<T>& cbOperand2);
95 virtual const CResult InternalConnectNode(
const CComputationalBase<T>& cbOperand1,
const CComputationalBase<T>& cbOperand2,
const CComputationalBase<T>& cbOperand3);
96 virtual const CResult InternalConnectNode(
const std::initializer_list<
const CComputationalBase<T>*>& ilOperands);
97 virtual const CResult InternalConnectNode(
const std::vector<
const CComputationalBase<T>*>& vctOperands);
// Internal: create operand storage for i64OperandsCount operands.
99 virtual const CResult InternalCreateOperandsNode(int64_t i64OperandsCount);
// Parent-node bookkeeping hooks.
101 virtual const CResult ClearParentNode();
102 virtual const CResult RemoveCurrentParentNode();
103 virtual const CResult UpdateParentNode();
// Clear operands, and update the generation (bRecursive defaults to false).
105 virtual const CResult ClearOperand();
106 virtual const CResult UpdateGeneration(
bool bRecursive =
false);
// Enumerators of an enum whose declaration line is missing from this extraction.
110 EMaxCount_Operands = 3,
111 EMaxCount_Derivatives = 3,
// Validity checks. AreOperatorsValid is const; AreOperatorsAndTheirValuesValid is not.
114 virtual bool AreOperatorsValid()
const;
115 virtual bool AreOperatorsAndTheirValuesValid();
// Derivative-variable initialization hook; returns a success flag.
117 virtual bool InitializeDerivativeVariable();
// Member backend and tensor pointer arrays with their element counts.
// Ownership of the pointees is not visible here — TODO confirm.
119 CBackendBase<T>** m_ppMemberBackend;
120 CTensor<T>** m_ppTsrMemberTensor;
121 int64_t m_i64MemberTensorCount;
122 int64_t m_i64MemberBackendCount;
124 int64_t m_i64SharedMemoryCount;
// Operand pointer array, plus a per-operand "intrinsic" flag array.
// Presumably both have m_i64OperandsCount elements — confirm.
126 CComputationalBase<T>** m_pCbOperands;
127 bool* m_pArrBIntrinsicOperand;
128 int64_t m_i64OperandsCount;
// Derivative scratch tensor, and a non-owning (presumably) link to the parent optimizer.
129 CTensor<T> m_tsrDerivativeTemp;
130 COptimizer<T>* m_pParentOptimizer;
// These classes get access to the non-public members above.
132 friend class CComputationalGraphUtilities<T>;
133 friend class COptimizer<T>;
136 typedef CComputationalGraph<float> CComputationalGraphF;
137 typedef CComputationalGraph<double> CComputationalGraphD;
139 typedef CComputationalGraph<float> CCGF;
140 typedef CComputationalGraph<double> CCGD;
142 template <
typename T>
143 using CCG = CComputationalGraph<T>;
@ Assign
Assign the contents of the specified computational graph (`cg`) to this graph.
@ Clear
Clear the contents of this computational graph.