FLImaging 6.5.16.1
BackendCTCLoss.h
1#pragma once
2
3#if _MSC_VER >= 1900 && defined(_M_X64)
4
5#include "BackendBase.h"
6
7namespace FLImaging
8{
9 namespace AI
10 {
11 template <typename T>
12 class FL_EXPORT CTensor;
13
14 template <typename T>
15 class FL_EXPORT CBackendCTCLoss : public CBackendBase<T>
16 {
17 public:
18 CBackendCTCLoss();
19 CBackendCTCLoss(const CBackendCTCLoss<T>& bl);
20 virtual ~CBackendCTCLoss();
21
22 virtual const CResult SetReductionType(EReductionType eReductionType);
23 virtual EReductionType GetReductionType() const;
24
25 virtual const CResult SetBlankLabel(int64_t i64BlankLabel);
26 virtual int64_t GetBlankLabel() const;
27
28 virtual const CResult EnableZeroInfinity(bool bZeroInfinity);
29 virtual bool IsZeroInfinityEnable() const;
30
31 virtual const CResult Forward(const CTensor<T>* pTsrX, const CTensor<T>* pTsrTarget, CTensor<T>* pTsrProbablity, CTensor<T>* pTsrBatchLoss, CTensor<T>* pTsrResult);
32 virtual const CResult Backward(const CTensor<T>* pTsrDy, const CTensor<T>* pTsrX, const CTensor<T>* pTsrTarget, const CTensor<T>* pTsrForwardProbablity, const CTensor<T>* pTsrBatchLoss, CTensor<T>* pTsrDx, bool bAddGradient);
33
34 DeclareGetClassType();
35 SupportToDuplicateObjectWithoutCreateNewObject(CBackendCTCLoss<T>, *this);
36
37 protected:
38
39 inline T GetLogInnerAdd(T tValue1, T tValue2) const;
40 inline T GetLogInnerAdd_DynamicType(T tValue1, T tValue2) const;
41
42 EReductionType m_eReductionType;
43 int64_t m_i64BlankLabel;
44 bool m_bZeroInfinity;
45 };
46 }
47}
48
49#endif