FLImaging 6.5.16.1
ComputationalGraphLSTM.h
1#pragma once
2
3#if _MSC_VER >= 1900 && defined(_M_X64)
4
5#include "ComputationalGraph.h"
6#include "ComputationalGraphRNN.h"
7#include "Tensor.h"
8#include "BackendRNN.h"
9
10#include <vector>
11
12namespace FLImaging
13{
14 namespace AI
15 {
 // CComputationalGraphLSTM<T>
 //
 // Computational-graph node declared as an LSTM layer over a single operand
 // node (cbOperand). It derives from CComputationalGraphRNN<T> and overrides
 // the forward/backward passes, cloning, and the memory-requirement queries.
 // NOTE(review): the recurrence itself (gate layout, gate order, activations)
 // is implemented outside this header - confirm against the implementation.
 //
 // The default and copy constructors are protected. Client code therefore
 // builds a node through the public operand constructor (or the CCG*LSTM
 // helper macros declared after the class). It obtains a copy through the
 // public Clone().
16 template <typename T>
17 class FL_EXPORT CComputationalGraphLSTM : public CComputationalGraphRNN<T>
18 {
19 protected:
 // Not callable by client code. Presumably used by Clone() and by the
 // duplication/singleton project macros below - TODO confirm in the .cpp.
20 CComputationalGraphLSTM();
21 CComputationalGraphLSTM(const CComputationalGraphLSTM<T>& cg);
22
23 public:
24
 // Creates an LSTM node whose input is cbOperand.
 //   cbOperand      : input node of the graph.
 //   i64HiddenSize  : size of the hidden state.
 //   bBatchFirst    : default false. Presumably selects a batch-major input
 //                    layout, i.e. (batch, sequence, feature) - TODO confirm.
 //   bBias          : default true. Presumably enables the bias terms.
 //   bBidirectional : default false. Presumably adds a reverse-direction pass
 //                    (compare the m_pTsrReverse* members) - TODO confirm.
 // NOTE(review): this parameter set mirrors torch.nn.LSTM (hidden_size,
 // batch_first, bias, bidirectional). The semantics are assumed analogous
 // but are not verified here.
25 CComputationalGraphLSTM(const CComputationalBase<T>& cbOperand, int64_t i64HiddenSize, bool bBatchFirst = false, bool bBias = true, bool bBidirectional = false);
26 virtual ~CComputationalGraphLSTM();
27
 // Forward-pass override. Returns a reference to a CTensor<T>, presumably
 // the node's output tensor.
28 virtual CTensor<T>& Forward() override;
 // Backward-pass override. Returns a CTensor<T> pointer, presumably the
 // gradient propagated toward the operand. Whether it may be null is not
 // visible in this header.
29 virtual CTensor<T>* Backward() override;
 // Returns a copy of this node through the base-class pointer type,
 // presumably newly allocated.
 // NOTE(review): ownership of the returned object is not expressed in the
 // return type - confirm who deletes it.
30 virtual CComputationalBase<T>* Clone() const override;
31
 // Memory-requirement queries for this node.
 //   bTraining    (default false) presumably includes training-only buffers.
 //   bRecursively (default true)  presumably also counts the operand subgraph.
 //   i64BatchSize (default 1)     batch size the estimate is computed for.
 // The unit of the int64_t result (bytes vs. elements) is not visible here -
 // TODO confirm.
32 virtual int64_t GetRequiredDedicatedMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1) const override;
 // Same as above, for temporary (scratch) memory. i64MemoryIndex (default 0)
 // presumably selects which temporary-memory slot is queried - TODO confirm.
33 virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false, bool bRecursively = true, int64_t i64BatchSize = 1, int64_t i64MemoryIndex = 0) const override;
34
 // Project macros defined elsewhere. Judging by their names, they add
 // class-type identification and duplication-without-allocation support.
 // See their definitions for the exact members they declare.
35 DeclareGetClassType();
36 SupportToDuplicateObjectWithoutCreateNewObject(CComputationalGraphLSTM, *this);
37
38 protected:
39
 // Intermediate tensors held by the node, presumably cached by Forward() for
 // use in Backward(). The names suggest:
 //   - "AllGateResult": the gate pre-activations;
 //   - "Cell": the cell state;
 //   - "Reverse": the backward direction of a bidirectional layer.
 // TODO confirm all three.
 // NOTE(review): these are raw pointers, and no copy-assignment operator is
 // declared. An implicitly generated operator= (if the base class permits
 // one) would copy the pointers, not the tensors. Verify which object owns
 // these tensors, and consider explicitly deleting operator=.
40 CTensor<T>* m_pTsrAllGateResult;
41 CTensor<T>* m_pTsrReverseAllGateResult;
42 CTensor<T>* m_pTsrCell;
43 CTensor<T>* m_pTsrReverseCell;
44
45 public:
 // Project macro, presumably declaring a per-type singleton accessor.
 // "Signleton" is the macro's actual spelling - do not "fix" it locally.
46 DeclareGetSignletonObject(CComputationalGraphLSTM);
47 };
48
 // Convenience factories. Each expands to an expression that allocates a node
 // with operator new and yields a reference to it: "*(new ...)".
 //   CCGFLSTM(...)    -> CComputationalGraphLSTM<float>
 //   CCGDLSTM(...)    -> CComputationalGraphLSTM<double>
 //   CCGTLSTM(T, ...) -> CComputationalGraphLSTM<T>
 // The remaining arguments are forwarded to the constructor unchanged.
 // Nothing in the expansion frees the object.
 // NOTE(review): presumably the graph takes ownership of nodes built this
 // way - confirm before using these outside a graph expression.
 // Macros are not namespace-scoped. The expansion names
 // CComputationalGraphLSTM unqualified, so that name must be visible where
 // the macro is used (e.g. inside FLImaging::AI or after a using-directive).
49 #define CCGFLSTM(...) (*(new CComputationalGraphLSTM<float>(__VA_ARGS__)))
50 #define CCGDLSTM(...) (*(new CComputationalGraphLSTM<double>(__VA_ARGS__)))
51
52 #define CCGTLSTM(T, ...) (*(new CComputationalGraphLSTM<T>(__VA_ARGS__)))
53 }
54}
55
56#endif