3#if _MSC_VER >= 1900 && defined(_M_X64)
5#include "ComputationalGraph.h"
6#include "ComputationalGraphRNN.h"
// CComputationalGraphLSTM: LSTM node of the computational graph, derived from
// CComputationalGraphRNN<T>.
// NOTE(review): this excerpt appears to be extraction residue. The `template`
// header that introduces T, the class braces, and the access specifiers are
// not visible here, and each line carries a stray source line number. Confirm
// against the original header before making code changes.
17 class FL_EXPORT CComputationalGraphLSTM :
public CComputationalGraphRNN<T>
// Default constructor.
20 CComputationalGraphLSTM();
// Copy constructor from another LSTM node with the same element type T.
21 CComputationalGraphLSTM(
const CComputationalGraphLSTM<T>& cg);
// Constructs an LSTM node whose input is cbOperand.
//   cbOperand      - input node; the expected tensor shape is not visible
//                    here (TODO confirm).
//   i64HiddenSize  - presumably the width of the hidden state.
//   bBatchFirst    - default false; presumably selects a batch-major input
//                    layout (TODO confirm).
//   bBias          - default true; presumably enables bias terms.
//   bBidirectional - default false; presumably enables a reverse-direction
//                    pass (cf. the m_pTsrReverse* members below).
25 CComputationalGraphLSTM(
const CComputationalBase<T>& cbOperand, int64_t i64HiddenSize,
bool bBatchFirst =
false,
bool bBias =
true,
bool bBidirectional =
false);
// Virtual destructor.
26 virtual ~CComputationalGraphLSTM();
// Forward pass override; returns a reference to a tensor (presumably this
// node's output; confirm).
28 virtual CTensor<T>& Forward()
override;
// Backward pass override; returns a tensor pointer. What it points to and
// whether it may be null is not visible here (presumably the gradient; confirm).
29 virtual CTensor<T>* Backward()
override;
// Polymorphic clone returning a CComputationalBase<T>*. Ownership of the
// returned object is not visible here; confirm who is responsible for deleting it.
30 virtual CComputationalBase<T>* Clone()
const override;
// Memory-requirement queries overriding the base class. The units of the
// returned int64_t are not visible here (presumably bytes; confirm).
// NOTE(review): default arguments on virtual functions are bound by the
// static type of the call expression. A call through a base pointer/reference
// uses the base class's defaults, so keep these defaults identical to the base
// declarations.
32 virtual int64_t GetRequiredDedicatedMemory(
bool bTraining =
false,
bool bRecursively =
true, int64_t i64BatchSize = 1)
const override;
// Same defaults as above, plus i64MemoryIndex (default 0), which presumably
// selects which temporary buffer is being queried (TODO confirm).
33 virtual int64_t GetRequiredTemporaryMemory(
bool bTraining =
false,
bool bRecursively =
true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0)
const override;
// Project macros defined elsewhere. They presumably declare class-type
// identification and in-place duplication support; check their definitions.
35 DeclareGetClassType();
36 SupportToDuplicateObjectWithoutCreateNewObject(CComputationalGraphLSTM, *
this);
// Raw tensor pointers; their ownership and lifetime are not visible here.
// The names suggest cached gate results and cell state for the forward
// direction, and counterparts for the reverse direction when bidirectional
// (TODO confirm).
40 CTensor<T>* m_pTsrAllGateResult;
41 CTensor<T>* m_pTsrReverseAllGateResult;
42 CTensor<T>* m_pTsrCell;
43 CTensor<T>* m_pTsrReverseCell;
// Project macro defined elsewhere; it presumably declares a singleton accessor
// for CComputationalGraphLSTM. "Signleton" is the macro's own spelling; do not
// "fix" it here without renaming the macro itself.
46 DeclareGetSignletonObject(CComputationalGraphLSTM);
// Convenience factories. Each one forwards its arguments to a
// CComputationalGraphLSTM constructor via `new`, then dereferences the result,
// so the expression is a reference to a heap-allocated node:
//   CCGFLSTM(...)    - element type float
//   CCGDLSTM(...)    - element type double
//   CCGTLSTM(T, ...) - caller-supplied element type T
// NOTE(review): no matching delete is visible here. The graph framework
// presumably takes ownership of nodes created this way; confirm, otherwise
// each use leaks.
49 #define CCGFLSTM(...) (*(new CComputationalGraphLSTM<float>(__VA_ARGS__)))
50 #define CCGDLSTM(...) (*(new CComputationalGraphLSTM<double>(__VA_ARGS__)))
52 #define CCGTLSTM(T, ...) (*(new CComputationalGraphLSTM<T>(__VA_ARGS__)))