FLImaging 6.5.16.1
BackendRNN.h
1#pragma once
2
3#if _MSC_VER >= 1900 && defined(_M_X64)
4
5#include "BackendBase.h"
6
7namespace FLImaging
8{
9 namespace AI
10 {
// Forward declaration of the tensor class template. CTensor<T> is used in the
// method signatures and data members of CBackendRNN<T> declared in this header.
11 template <typename T>
12 class FL_EXPORT CTensor;
13
// Forward declaration of the backend base class template that CBackendRNN<T> derives from.
// NOTE(review): "BackendBase.h" is already included by this header and presumably
// defines this template, which would make this declaration redundant - confirm
// before removing it.
15 template <typename T>
16 class FL_EXPORT CBackendBase;
16
// Backend class template for recurrent layers, derived from CBackendBase<T>.
// It declares forward/backward entry points for a plain RNN (Forward/Backward)
// and for an LSTM variant (Forward_LSTM/Backward_LSTM), the matching
// memory-requirement queries, and protected per-step / CPU helper routines.
// Only declarations are visible in this header; the behavioral notes below that
// are not evident from the signatures are marked as assumptions to be confirmed.
17 template <typename T>
18 class FL_EXPORT CBackendRNN : public CBackendBase<T>
19 {
20
21 public:
// Default constructor, copy constructor and virtual destructor.
22 CBackendRNN();
23 CBackendRNN(const CBackendRNN<T>& bgn);
24 virtual ~CBackendRNN();
25
// Option setters and getters. Presumably these write/read m_bBatchFirst,
// m_bBias and m_eNonlinearity respectively - TODO confirm in the implementation.
// NOTE(review): the getters are not const-qualified, so they cannot be called
// through a const CBackendRNN; changing that would alter the virtual interface.
// NOTE(review): "ENonlinearlity" is the spelling of the type as declared elsewhere
// in the project; it is part of the public interface and cannot be renamed here.
26 virtual const CResult EnableBatchFirst(bool bBatchFirst);
27 virtual bool IsBatchFirstEnabled();
28 virtual const CResult EnableBias(bool bBias);
29 virtual bool IsBiasEnabled();
30 virtual const CResult SetNonlinearity(ENonlinearlity eNonlinearity);
31 virtual ENonlinearlity GetNonlinearity();
32
// Plain RNN forward and backward passes.
// Forward: takes the input tensor, the input-side weight/bias and the
// hidden-side weight/bias as const pointers; pTsrY is the only non-const tensor
// argument (presumably the output).
// Backward: takes pTsrDy, pTsrX, pTsrY and both weights as const pointers and
// five non-const tensors (pTsrDx, pTsrDWeightX, pTsrDBiasX, pTsrDWeightH,
// pTsrDBiasH). The b*AddGradient flags presumably choose between accumulating
// into and overwriting the matching gradient tensor - TODO confirm.
// bReverse defaults to false; presumably it selects reverse time-step order
// (e.g. for a bidirectional layer) - TODO confirm.
// NOTE(review): default arguments on virtual functions are resolved from the
// static type at the call site, so any override must repeat the same default.
33 virtual const CResult Forward(const CTensor<T>* pTsrX, const CTensor<T>* pTsrWeightX, const CTensor<T>* pTsrBiasX, const CTensor<T>* pTsrWeightH, const CTensor<T>* pTsrBiasH, CTensor<T>* pTsrY, bool bReverse = false);
34 virtual const CResult Backward(const CTensor<T>* pTsrDy, const CTensor<T>* pTsrX, const CTensor<T>* pTsrY, const CTensor<T>* pTsrWeightX, const CTensor<T>* pTsrWeightH, CTensor<T>* pTsrDx, CTensor<T>* pTsrDWeightX, CTensor<T>* pTsrDBiasX, CTensor<T>* pTsrDWeightH, CTensor<T>* pTsrDBiasH, bool bAddGradient, bool bWeightXAddGradient, bool bBiasXAddGradient, bool bWeightHAddGradient, bool bBiasHAddGradient, bool bReverse = false);
35
// LSTM forward and backward passes. Compared with Forward/Backward, the forward
// pass has two extra non-const tensors (pTsrAllGateResult, pTsrCell) and the
// backward pass takes those same two tensors as const inputs.
36 virtual const CResult Forward_LSTM(const CTensor<T>* pTsrX, const CTensor<T>* pTsrWeightX, const CTensor<T>* pTsrBiasX, const CTensor<T>* pTsrWeightH, const CTensor<T>* pTsrBiasH, CTensor<T>* pTsrAllGateResult, CTensor<T>* pTsrCell, CTensor<T>* pTsrY, bool bReverse = false);
37 virtual const CResult Backward_LSTM(const CTensor<T>* pTsrDy, const CTensor<T>* pTsrX, const CTensor<T>* pTsrY, const CTensor<T>* pTsrAllGateResult, const CTensor<T>* pTsrCell, const CTensor<T>* pTsrWeightX, const CTensor<T>* pTsrWeightH, CTensor<T>* pTsrDx, CTensor<T>* pTsrDWeightX, CTensor<T>* pTsrDBiasX, CTensor<T>* pTsrDWeightH, CTensor<T>* pTsrDBiasH, bool bAddGradient, bool bWeightXAddGradient, bool bBiasXAddGradient, bool bWeightHAddGradient, bool bBiasHAddGradient, bool bReverse = false);
38
// Memory-requirement queries for the plain RNN passes. Each takes the input
// shape and the hidden size; the "Temporary" variants additionally take a memory
// index. The unit of the returned value is not visible here (presumably bytes -
// TODO confirm).
// NOTE(review): the top-level const on the by-value int64_t return has no effect
// for callers, but it is part of the declared signature that overrides must match.
// NOTE(review): vctInputShape is a non-const reference although the methods are
// const; whether the vector is modified cannot be determined from this header.
39 virtual const int64_t GetRequiredDedicatedMemory_Forward(std::vector<int64_t>& vctInputShape, int64_t i64HiddenSize) const;
40 virtual const int64_t GetRequiredDedicatedMemory_Backward(std::vector<int64_t>& vctInputShape, int64_t i64HiddenSize) const;
41 virtual const int64_t GetRequiredTemporaryMemory_Forward(std::vector<int64_t>& vctInputShape, int64_t i64HiddenSize, int64_t i64MemoryIndex) const;
42 virtual const int64_t GetRequiredTemporaryMemory_Backward(std::vector<int64_t>& vctInputShape, int64_t i64HiddenSize, int64_t i64MemoryIndex) const;
43
// LSTM counterparts of the four memory-requirement queries (same parameter lists).
44 virtual const int64_t GetRequiredDedicatedMemory_Forward_LSTM(std::vector<int64_t>& vctInputShape, int64_t i64HiddenSize) const;
45 virtual const int64_t GetRequiredDedicatedMemory_Backward_LSTM(std::vector<int64_t>& vctInputShape, int64_t i64HiddenSize) const;
46 virtual const int64_t GetRequiredTemporaryMemory_Forward_LSTM(std::vector<int64_t>& vctInputShape, int64_t i64HiddenSize, int64_t i64MemoryIndex) const;
47 virtual const int64_t GetRequiredTemporaryMemory_Backward_LSTM(std::vector<int64_t>& vctInputShape, int64_t i64HiddenSize, int64_t i64MemoryIndex) const;
48
49 protected:
// Protected helpers for the plain RNN. Judging by the names and the i64Step
// parameter these handle one time step of the forward pass and of the gradient
// computation respectively - TODO confirm against the implementation.
50 virtual const CResult RNN_Step(CTensor<T>* pTsrXResult, CTensor<T>* pTsrHBefore, CTensor<T>* pTsrWeightH, CTensor<T>* pTsrBiasH, CTensor<T>* pTsrH, int64_t i64TotalElement, int64_t i64HiddenSize, int64_t i64Step);
51 virtual const CResult RNN_Derivative_Step(CTensor<T>* pTsrDy, CTensor<T>* pTsrDhBefore, CTensor<T>* pTsrXTransposed, CTensor<T>* pTsrHTransposed, CTensor<T>* pTsrWeightXTransposed, CTensor<T>* pTsrWeightHTransposed, CTensor<T>* pTsrOneMinusYSquare, CTensor<T>* pTsrDx, CTensor<T>* pTsrDh, CTensor<T>* pTsrDWeightX, CTensor<T>* pTsrDBiasX, CTensor<T>* pTsrDWeightH, CTensor<T>* pTsrDBiasH, int64_t i64TotalElement);
52
// Protected helpers for the LSTM. Unlike the RNN helpers, these mix raw T*
// buffers with CTensor<T>* arguments; the expected buffer sizes are not visible
// in this header (presumably tied to the i64*Element counts - TODO confirm).
53 virtual const CResult LSTM_Step(T* pTXResult, const T* pTCellBefore, T* pTCell, T* pTH, CTensor<T>* pTsrHBefore, CTensor<T>* pTsrWeightH, CTensor<T>* pTsrBiasH, int64_t i64TotalElement, int64_t i64GateElement, int64_t i64HiddenSize, int64_t i64AllGateElement, int64_t i64Step);
54 virtual const CResult LSTM_Derivative_Step(const T* pTCell, const T* pTCellBefore, const T* pTAllGateResult, const T* pTDy, T* pTDs, CTensor<T>* pTsrDCellBefore, CTensor<T>* pTsrDhBefore, CTensor<T>* pTsrXTransposed, CTensor<T>* pTsrHTransposed, CTensor<T>* pTsrWeightXTransposed, CTensor<T>* pTsrWeightHTransposed, CTensor<T>* pTsrDx, CTensor<T>* pTsrDCell, CTensor<T>* pTsrDh, CTensor<T>* pTsrDWeightX, CTensor<T>* pTsrDBiasX, CTensor<T>* pTsrDWeightH, CTensor<T>* pTsrDBiasH, int64_t i64TotalElement, int64_t i64GateResult);
55
// Element-wise helpers operating on raw T* buffers. The "_CPU" suffix presumably
// marks host-side implementations that derived backends may override - TODO confirm.
// NOTE(review): "Caculate" in CaculateDCellAndAllGateDerivative_CPU looks like a
// misspelling of "Calculate"; the name is part of the virtual interface, so it is
// left as declared.
56 virtual const CResult ReluBackward_CPU(const T* pTDy, const T* pTY, T* pTDx, int64_t i64TotalElement);
57 virtual const CResult OneMinusSquareTanhAndMultiplyOutputGateResult_CPU(const T* pTCell, const T* pTAllGateResult, T* pTResult, int64_t i64TotalElement, int64_t i64AllGateTimeStepElement, int64_t i64TimeStepElement);
58 virtual const CResult CaculateDCellAndAllGateDerivative_CPU(T* pTDs, T* pTDCell, T* pTDAllGate, const T* pTDy, const T* pTCell, const T* pTCellBefore, const T* pTDhBefore, const T* pTDCellBefore, const T* pTAllGateResult, int64_t i64TotalElement, int64_t i64GateElement);
59 virtual const CResult ReduceAllGateAndAdd_CPU(T* pTDx, const T* pTDa, int64_t i64TotalElement, bool bAddGradient);
60
// Project macros defined outside this header. Going by their names they add
// class-type identification and object-duplication support - TODO confirm what
// members they expand to and which access level they leave active.
61 DeclareGetClassType();
62 SupportToDuplicateObjectWithoutCreateNewObject(CBackendRNN<T>, *this);
63
64 protected:
65
// Option state; presumably the backing fields for the Enable*/Set*/Is*/Get*
// accessors declared above - TODO confirm.
66 bool m_bBatchFirst;
67 bool m_bBias;
68 ENonlinearlity m_eNonlinearity;
69
// Tensor members held by value. Their names suggest intermediate matmul results
// for the gates and the hidden state - TODO confirm.
// NOTE(review): by-value CTensor<T> members need the complete CTensor definition
// when this template is instantiated; the forward declaration in this header is
// not enough, so the definition presumably arrives through "BackendBase.h".
70 CTensor<T> m_tsrGateMatmulResult;
71 CTensor<T> m_tsrForgetGateMatmulResult;
72 CTensor<T> m_tsrCellGateMatmulResult;
73 CTensor<T> m_tsrOutputGateMatmulResult;
74 CTensor<T> m_tsrHMatmulResult;
75
// Further tensor members; the "Step" suffix and D* prefixes suggest per-step
// scratch tensors for the backward pass - TODO confirm.
76 CTensor<T> m_tsrOneMinusYSquare;
77 CTensor<T> m_tsrDyStep;
78 CTensor<T> m_tsrXTransposedStep;
79 CTensor<T> m_tsrHTransposedStep;
80 CTensor<T> m_tsrDsStep;
81 CTensor<T> m_tsrFinalDxStep;
82 CTensor<T> m_tsrDxStep;
83 CTensor<T> m_tsrDhStep;
84 CTensor<T> m_tsrDwXStep;
85 CTensor<T> m_tsrDwHStep;
86 CTensor<T> m_tsrDbStep;
87 };
88 }
89}
90
91#endif