// FLImaging 7.2.4.2
// File: ComputationalGraphUpsample3D.h
#pragma once

// This header only declares anything for 64-bit MSVC builds with _MSC_VER >= 1900
// (Visual Studio 2015 or later); on every other toolchain it is empty.
#if _MSC_VER >= 1900 && defined(_M_X64)

#include "ComputationalGraph.h"

namespace FLImaging
{
	namespace AI
	{
		/// @brief Interpolation mode selector for CComputationalGraphUpsample3D.
		///
		/// The underlying type is fixed to int32_t. Values are written out explicitly
		/// (identical to the implicit 0..3 numbering) so that reordering or inserting
		/// enumerators can never silently renumber the existing ones.
		/// NOTE(review): presumably m_eMode is persisted by Get/SetBinaryData, which
		/// would make these integers part of the stored format — confirm before
		/// changing any value.
		enum EUpsample3DMode : int32_t
		{
			EUpsample3DMode_Nearest = 0,
			EUpsample3DMode_Trilinear = 1,
			EUpsample3DMode_Area = 2,
			EUpsample3DMode_NearestExact = 3,
		};

		/// @brief Computational-graph node that upsamples its operand along three axes.
		///
		/// Two public constructors exist: one takes per-axis factors of type T
		/// (tScaleX/Y/Z), the other per-axis int64_t extents (i64Width/Height/Depth).
		/// eMode selects one of the Upsample3D_* / Derivative_* routines declared below.
		///
		/// @tparam T Element type; the factory macros at the end of this header
		///           instantiate float and double.
		///
		/// @warning With T = float or double, calling a public constructor with plain
		///          int literals, e.g. `(x, 2, 2, 2)`, is ambiguous: int -> T and
		///          int -> int64_t are standard conversions of the same rank, so neither
		///          overload wins. Write `2.f` / `2.0` for factors, or `int64_t(2)` for
		///          explicit extents.
		///
		/// @note The class is marked FL_EXPORT (presumably a DLL export/import macro),
		///       so the order of virtual functions and data members below is part of
		///       the binary interface — do not reorder them.
		template <typename T>
		class FL_EXPORT CComputationalGraphUpsample3D : public CComputationalGraph<T>
		{
		private:
			// Default construction is not available to users.
			CComputationalGraphUpsample3D();

		protected:
			// Copy construction is restricted to this class and its subclasses
			// (presumably used when the node is cloned/duplicated — TODO confirm).
			CComputationalGraphUpsample3D(const CComputationalGraphUpsample3D<T>& cg);

		public:
			/// @brief Creates the node from per-axis factors.
			/// @param cbOperand Input of this node.
			/// @param tScaleX Factor for the X axis (presumably output = input * factor — confirm).
			/// @param tScaleY Factor for the Y axis.
			/// @param tScaleZ Factor for the Z axis.
			/// @param eMode   Interpolation mode; defaults to EUpsample3DMode_Nearest.
			CComputationalGraphUpsample3D(const CComputationalBase<T>& cbOperand,
			                              T tScaleX, T tScaleY, T tScaleZ,
			                              EUpsample3DMode eMode = EUpsample3DMode_Nearest);

			/// @brief Creates the node from explicit per-axis extents.
			/// @param cbOperand Input of this node.
			/// @param i64Width  Extent along X (presumably the output width — confirm).
			/// @param i64Height Extent along Y.
			/// @param i64Depth  Extent along Z.
			/// @param eMode     Interpolation mode; defaults to EUpsample3DMode_Nearest.
			CComputationalGraphUpsample3D(const CComputationalBase<T>& cbOperand,
			                              int64_t i64Width, int64_t i64Height, int64_t i64Depth,
			                              EUpsample3DMode eMode = EUpsample3DMode_Nearest);

			virtual ~CComputationalGraphUpsample3D();

			// Overrides of the base computational-graph interface; their contracts are
			// defined by the base class declarations.
			virtual CTensor<T>& Forward() override;
			virtual CTensor<T>* Backward() override;
			virtual CComputationalBase<T>* Clone() const override;
			virtual const std::vector<int64_t>& GetEstimatedShape(bool bRecursive = true) const override;

			virtual const CResult PrintNodeParamInfo() const override;

			// Serialization: reference and pointer overloads of the same operation.
			virtual const CResult GetBinaryData(Base::CFLData& fldBinary,
			                                    bool bSuperClass = false,
			                                    int32_t i32Version = -1,
			                                    bool bDumpMode = false) const override;
			virtual const CResult GetBinaryData(Base::CFLData* pFldBinary,
			                                    bool bSuperClass = false,
			                                    int32_t i32Version = -1,
			                                    bool bDumpMode = false) const override;

			virtual const CResult SetBinaryData(const Base::CFLData& fldBinary, int64_t* pI64Offset = nullptr) override;
			virtual const CResult SetBinaryData(const Base::CFLData* pFldBinary, int64_t* pI64Offset = nullptr) override;

			virtual int64_t GetRequiredTemporaryMemory(bool bTraining = false,
			                                           bool bRecursively = true,
			                                           int64_t i64BatchSize = 1,
			                                           int64_t i64MemoryIndex = 0) const override;

			// Project-wide helper macros (defined outside this header).
			DeclareGetClassType();
			SupportToDuplicateObjectWithoutCreateNewObject(CComputationalGraphUpsample3D, *this);

		protected:
			// Mode-dispatching entry points for the forward computation and the gradient.
			virtual const CResult Upsample3D(const CTensor<T>* pTsrInput, CTensor<T>* pTsrResult, EUpsample3DMode eMode);
			// NOTE(review): unlike the methods above this is not marked `override`;
			// add it if the base class declares the same virtual.
			virtual const CResult Derivative(bool bAddGradients);

		private:
			// One forward kernel per EUpsample3DMode value. Private but virtual, so
			// subclasses can still override them.
			virtual const CResult Upsample3D_Nearest(const CTensor<T>* pTsrX, CTensor<T>* pTsrY);
			virtual const CResult Upsample3D_Trilinear(const CTensor<T>* pTsrX, CTensor<T>* pTsrY);
			virtual const CResult Upsample3D_Area(const CTensor<T>* pTsrX, CTensor<T>* pTsrY);
			virtual const CResult Upsample3D_NearestExact(const CTensor<T>* pTsrX, CTensor<T>* pTsrY);

			// One gradient kernel per EUpsample3DMode value (dY in, dX out).
			virtual const CResult Derivative_Nearest(const CTensor<T>* pTsrDy, CTensor<T>* pTsrDx);
			virtual const CResult Derivative_Trilinear(const CTensor<T>* pTsrDy, CTensor<T>* pTsrDx);
			virtual const CResult Derivative_Area(const CTensor<T>* pTsrDy, CTensor<T>* pTsrDx);
			virtual const CResult Derivative_NearestExact(const CTensor<T>* pTsrDy, CTensor<T>* pTsrDx);

		protected:
			// Per-axis factors (from the factor constructor).
			T m_tScaleX;
			T m_tScaleY;
			T m_tScaleZ;
			// Per-axis extents (from the extent constructor).
			// NOTE(review): presumably only one of the two sets is meaningful per node — confirm.
			int64_t m_i64Width;
			int64_t m_i64Height;
			int64_t m_i64Depth;
			// Interpolation mode chosen at construction.
			EUpsample3DMode m_eMode;

		public:
			DeclareGetSignletonObject(CComputationalGraphUpsample3D);
		};

		// Factory shorthands. Each expands to `*(new CComputationalGraphUpsample3D<...>(args))`,
		// i.e. an lvalue reference to a heap-allocated node.
		// NOTE(review): nothing here releases that allocation — presumably the graph
		// framework takes ownership; confirm before using these outside graph building.
		// The class name is fully qualified because the preprocessor ignores namespaces:
		// the expansion must resolve even where FLImaging::AI is not in scope.
		#define CCGFUpsample3D(...) (*(new ::FLImaging::AI::CComputationalGraphUpsample3D<float>(__VA_ARGS__)))
		#define CCGDUpsample3D(...) (*(new ::FLImaging::AI::CComputationalGraphUpsample3D<double>(__VA_ARGS__)))

		// Same as above for an arbitrary element type T.
		#define CCGTUpsample3D(T, ...) (*(new ::FLImaging::AI::CComputationalGraphUpsample3D<T>(__VA_ARGS__)))
	} // namespace AI
} // namespace FLImaging

#endif // _MSC_VER >= 1900 && defined(_M_X64)
// Documentation cross-reference: the referenced symbol is defined in AlgorithmAIBase.h, line 18.