FLImaging 6.5.16.1
ComputationalGraph.h
1#pragma once
2
3#if _MSC_VER >= 1900 && defined(_M_X64)
4
5#include "ComputationalBase.h"
6#include "Tensor.h"
7
8namespace FLImaging
9{
10 namespace AI
11 {
// Forward declarations of the class templates named by the declarations that follow in
// this namespace (CTensor<T>, CBackendBase<T>, COptimizer<T>).
// NOTE(review): CTensor<T> is also stored by value further down (m_tsrDerivativeTemp), so its
// full definition is required wherever the class is instantiated; presumably "Tensor.h"
// supplies it - confirm.
12 template <typename T>
13 class CTensor;
14
15 template <typename T>
16 class CBackendBase;
17
18 template <typename T>
19 class COptimizer;
20
// Class template CComputationalGraph<T>, declared with FL_EXPORT and publicly derived from
// CComputationalBase<T>. Forward(), Backward(), Clone() and GetClassType() are pure virtual
// here, so this class is abstract and concrete types must derive from it.
// NOTE(review): only declarations are visible in this header; every remark below that goes
// beyond what a signature shows is marked as an assumption to confirm against the implementation.
21 template <typename T>
22 class FL_EXPORT CComputationalGraph : public CComputationalBase<T>
23 {
24
25 public:
// Default constructor.
26 CComputationalGraph();
27 protected:
// Copy constructor. It is protected, so it is only reachable from this class, derived
// classes and the friends declared at the end of the class.
28 CComputationalGraph(const CComputationalGraph<T>& cg);
29
30 public:
// Virtual destructor.
31 virtual ~CComputationalGraph();
32
// Overrides of the base GetBinaryData(); one overload takes the CFLData buffer by reference,
// the other by pointer. Both are const member functions.
// NOTE(review): presumably these serialize the object into the buffer; the meaning of
// bSuperClass, i32Version (-1 by default) and bDumpMode is not visible here - confirm.
33 virtual const CResult GetBinaryData(Base::CFLData& fldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
34 virtual const CResult GetBinaryData(Base::CFLData* pFldBinary, bool bSuperClass = false, int32_t i32Version = -1, bool bDumpMode = false) const override;
35
// Overrides of the base SetBinaryData(); reference and pointer overloads of a const buffer.
// pI64Offset is optional (defaults to nullptr).
// NOTE(review): presumably restores state from the buffer and uses/updates the offset - confirm.
36 virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr) override;
37 virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr) override;
38
// Non-virtual; takes another CComputationalGraph<T> by const reference and returns a CResult.
// NOTE(review): presumably delegates to the virtual InternalAssign() declared below - confirm.
39 const CResult Assign(const CComputationalGraph<T>& cg);
// Override of the base Clear().
40 virtual const CResult Clear() override;
41
// Evaluate() overrides the base version. Forward() and Backward() are pure virtual and must be
// implemented by derived classes. Note the differing return forms: Forward() returns a
// CTensor<T> reference, Backward() returns a CTensor<T> pointer.
// NOTE(review): whether Backward() may return nullptr, and who owns the pointee, is not visible here.
42 virtual CTensor<T>& Evaluate() override;
43 virtual CTensor<T>& Forward() = 0;
44 virtual CTensor<T>* Backward() = 0;
45
// Overrides of the base derivative clear/reset entry points.
// NOTE(review): "Recursive" presumably means the call is propagated to operands - confirm.
46 virtual void ClearDerivativesRecursive() override;
47 virtual void ResetDerivativesRecursive() override;
48
// Pure virtual; returns a CComputationalBase<T> pointer.
// NOTE(review): presumably a newly created copy owned by the caller - confirm ownership.
49 virtual CComputationalBase<T>* Clone() const = 0;
50
// Override of the base Swap(); the other object is taken by non-const reference.
51 virtual const CResult Swap(CComputationalBase<T>& cbSwap) override;
52
// Override of the base SetDeviceIndex().
53 virtual const CResult SetDeviceIndex(int32_t i32DeviceIndex) override;
54
// Lookup functions. GetAt() and FindByValueAttribute() override base declarations;
// FindByNodeOperator() is declared virtual without `override`.
// The Find* functions report matches through the vctResult output vector.
// NOTE(review): whether vctResult is cleared or appended to is not visible here.
55 virtual const CComputationalBase<T>* GetAt(const wchar_t* pWcsName) const override;
56 virtual const CResult FindByValueAttribute(EValueAttribute eValueAttribute, std::vector<const CComputationalBase<T>*>& vctResult) const override;
57 virtual const CResult FindByNodeOperator(ENodeOperator eNodeOperator, std::vector<const CComputationalBase<T>*>& vctResult) const;
58
// Operand accessors: set one operand by index or a whole vector of operands, read an operand
// by index, query a per-index "intrinsic" flag, and get the operand count (base override).
// GetOperand() is a const member function but returns a pointer to a non-const object.
// NOTE(review): the effect of bSetSourceReferenceType and the valid index range are not visible here.
59 virtual const CResult SetOperand(int64_t i64Index, const CComputationalBase<T>* pCbOperand, bool bSetSourceReferenceType = false);
60 virtual const CResult SetOperand(const std::vector<const CComputationalBase<T>*>& cbOperands, bool bSetSourceReferenceType = false);
61 virtual CComputationalBase<T>* GetOperand(int64_t i64Index) const;
62 virtual bool IsIntrinsicOperand(int64_t i64Index) const;
63 virtual int64_t GetOperandCount() const override;
64
// Indexed access to member tensors and their count.
// NOTE(review): presumably backed by m_ppTsrMemberTensor / m_i64MemberTensorCount - confirm.
65 virtual CTensor<T>* GetMemberTensor(int64_t i64Index) const;
66 virtual int64_t GetMemberTensorCount() const;
67
// Indexed access to member backends and their count.
// NOTE(review): presumably backed by m_ppMemberBackend / m_i64MemberBackendCount - confirm.
68 virtual CBackendBase<T>* GetMemberBackend(int64_t i64Index) const;
69 virtual int64_t GetMemberBackendCount() const;
70
// Overrides of base declarations; bRecursively defaults to true for EnableTrainingMode().
71 virtual int64_t GetNextBatchSize(int64_t i64BatchSize) const override;
72 virtual const CResult EnableTrainingMode(bool bMode, bool bRecursively = true) override;
73
// Override of the base PrintGraphInfo(); all three flags have defaults.
74 virtual const CResult PrintGraphInfo(bool bRecursively = true, bool bShapeOrderAsc = false, bool bIncludeTensors = false) const override;
75
// Setter/getter for the parent optimizer pointer; the getter returns a pointer to const.
// NOTE(review): presumably reads/writes m_pParentOptimizer (non-owning) - confirm.
76 virtual const CResult SetParentOptimizer(COptimizer<T>* pParentOptimizer, bool bRecursive = true);
77 virtual const COptimizer<T>* GetParentOptimizer() const;
78
// Pure virtual; returns an int64_t. NOTE(review): the meaning of the returned value is defined elsewhere.
79 virtual int64_t GetClassType() const = 0;
80
// Size/index queries. The two GetRequired*Memory() functions override base declarations and
// default to bTraining = false, bRecursively = true, i64BatchSize = 1.
// NOTE(review): the unit of the returned sizes is not visible here - confirm.
81 virtual int64_t GetOperandIndexForNextBatchSize() const;
82 virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const override;
83 virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0) const override;
84
// Returns a const reference to a CTensor<T>.
// NOTE(review): presumably the m_tsrDerivativeTemp member - confirm.
85 virtual const CTensor<T>& GetDerivativeTemp() const;
86
// NOTE(review): presumably returns m_i64SharedMemoryCount - confirm.
87 virtual int64_t GetSharedMemoryCount() const;
88
// Project macro invoked with the class name; its expansion is defined outside this header.
89 SupportToDuplicateAbstractObject(CComputationalGraph);
90 protected:
// Virtual assignment hook taking the source graph by const reference.
91 virtual const CResult InternalAssign(const CComputationalGraph<T>& cg);
92
// InternalConnectNode() overloads: one, two or three operands by const reference, or an
// arbitrary number via an initializer_list / vector of const pointers.
// NOTE(review): how operand counts relate to EMaxCount_Operands is not visible here.
93 virtual const CResult InternalConnectNode(const CComputationalBase<T>& cbOperand1);
94 virtual const CResult InternalConnectNode(const CComputationalBase<T>& cbOperand1, const CComputationalBase<T>& cbOperand2);
95 virtual const CResult InternalConnectNode(const CComputationalBase<T>& cbOperand1, const CComputationalBase<T>& cbOperand2, const CComputationalBase<T>& cbOperand3);
96 virtual const CResult InternalConnectNode(const std::initializer_list<const CComputationalBase<T>*>& ilOperands);
97 virtual const CResult InternalConnectNode(const std::vector<const CComputationalBase<T>*>& vctOperands);
98
// NOTE(review): presumably allocates storage for i64OperandsCount operands - confirm.
99 virtual const CResult InternalCreateOperandsNode(int64_t i64OperandsCount);
100
// Parent-node maintenance hooks; each returns a CResult and takes no arguments.
101 virtual const CResult ClearParentNode();
102 virtual const CResult RemoveCurrentParentNode();
103 virtual const CResult UpdateParentNode();
104
// Operand reset and generation update; bRecursive defaults to false.
105 virtual const CResult ClearOperand();
106 virtual const CResult UpdateGeneration(bool bRecursive = false);
107
// Unscoped enum with two named limits, both equal to 3.
// NOTE(review): where these limits are enforced is not visible in this header.
108 enum EMaxCount
109 {
110 EMaxCount_Operands = 3,
111 EMaxCount_Derivatives = 3,
112 };
113
// Validation helpers returning bool. AreOperatorsValid() is const;
// AreOperatorsAndTheirValuesValid() is not const.
114 virtual bool AreOperatorsValid() const;
115 virtual bool AreOperatorsAndTheirValuesValid();
116
// Returns bool. NOTE(review): presumably true on success - confirm.
117 virtual bool InitializeDerivativeVariable();
118
// Pointer-to-pointer members for backends and tensors, with their int64_t counts.
// NOTE(review): presumably arrays of pointers sized by the matching count; ownership and
// allocation are not visible here - confirm.
119 CBackendBase<T>** m_ppMemberBackend;
120 CTensor<T>** m_ppTsrMemberTensor;
121 int64_t m_i64MemberTensorCount;
122 int64_t m_i64MemberBackendCount;
123
124 int64_t m_i64SharedMemoryCount;
125
// Operand storage: a pointer-to-pointer of operands, a bool pointer and an operand count.
// NOTE(review): presumably m_pArrBIntrinsicOperand is a parallel array of per-operand flags
// with m_i64OperandsCount entries - confirm.
126 CComputationalBase<T>** m_pCbOperands;
127 bool* m_pArrBIntrinsicOperand;
128 int64_t m_i64OperandsCount;
// CTensor<T> held by value (requires the complete CTensor<T> definition).
129 CTensor<T> m_tsrDerivativeTemp;
// Raw pointer to a COptimizer<T>. NOTE(review): presumably non-owning - confirm.
130 COptimizer<T>* m_pParentOptimizer;
131 private:
// Friend declarations: give CComputationalGraphUtilities<T> and COptimizer<T> access to the
// non-public members above. CComputationalGraphUtilities is not declared in this header, so
// its declaration must come from an included header.
132 friend class CComputationalGraphUtilities<T>;
133 friend class COptimizer<T>;
134 };
135
// Aliases for the float and double instantiations of CComputationalGraph.
136 typedef CComputationalGraph<float> CComputationalGraphF;
137 typedef CComputationalGraph<double> CComputationalGraphD;
138
// Shorter names for the same two instantiations.
139 typedef CComputationalGraph<float> CCGF;
140 typedef CComputationalGraph<double> CCGD;
141
// Alias template: CCG<T> names CComputationalGraph<T> for any T.
142 template <typename T>
143 using CCG = CComputationalGraph<T>;
144 }
145}
146
147#endif
@ Assign
Assigns the contents of the specified computational graph (`cg`) to this object.
@ Clear
Clears this computational graph object.