1#include "All.h"
2#include "GlobalFunctions.h"
3#include "NNFilter.h"
4#include "Assembly.h"
5
6CNNFilter::CNNFilter(int nOrder, int nShift, int nVersion)
7{
8    if ((nOrder <= 0) || ((nOrder % 16) != 0)) throw(1);
9    m_nOrder = nOrder;
10    m_nShift = nShift;
11    m_nVersion = nVersion;
12
13    m_bMMXAvailable = GetMMXAvailable();
14
15    m_rbInput.Create(NN_WINDOW_ELEMENTS, m_nOrder);
16    m_rbDeltaM.Create(NN_WINDOW_ELEMENTS, m_nOrder);
17    m_paryM = new short [m_nOrder];
18
19#ifdef NN_TEST_MMX
20    srand(GetTickCount());
21#endif
22}
23
24CNNFilter::~CNNFilter()
25{
26    SAFE_ARRAY_DELETE(m_paryM)
27}
28
29void CNNFilter::Flush()
30{
31    memset(&m_paryM[0], 0, m_nOrder * sizeof(short));
32    m_rbInput.Flush();
33    m_rbDeltaM.Flush();
34    m_nRunningAverage = 0;
35}
36
37int CNNFilter::Compress(int nInput)
38{
39    // convert the input to a short and store it
40    m_rbInput[0] = GetSaturatedShortFromInt(nInput);
41
42    // figure a dot product
43    int nDotProduct;
44    if (m_bMMXAvailable)
45        nDotProduct = CalculateDotProduct(&m_rbInput[-m_nOrder], &m_paryM[0], m_nOrder);
46    else
47        nDotProduct = CalculateDotProductNoMMX(&m_rbInput[-m_nOrder], &m_paryM[0], m_nOrder);
48
49    // calculate the output
50    int nOutput = nInput - ((nDotProduct + (1 << (m_nShift - 1))) >> m_nShift);
51
52    // adapt
53    if (m_bMMXAvailable)
54        Adapt(&m_paryM[0], &m_rbDeltaM[-m_nOrder], -nOutput, m_nOrder);
55    else
56        AdaptNoMMX(&m_paryM[0], &m_rbDeltaM[-m_nOrder], nOutput, m_nOrder);
57
58    int nTempABS = abs(nInput);
59
60    if (nTempABS > (m_nRunningAverage * 3))
61        m_rbDeltaM[0] = ((nInput >> 25) & 64) - 32;
62    else if (nTempABS > (m_nRunningAverage * 4) / 3)
63        m_rbDeltaM[0] = ((nInput >> 26) & 32) - 16;
64    else if (nTempABS > 0)
65        m_rbDeltaM[0] = ((nInput >> 27) & 16) - 8;
66    else
67        m_rbDeltaM[0] = 0;
68
69    m_nRunningAverage += (nTempABS - m_nRunningAverage) / 16;
70
71    m_rbDeltaM[-1] >>= 1;
72    m_rbDeltaM[-2] >>= 1;
73    m_rbDeltaM[-8] >>= 1;
74
75    // increment and roll if necessary
76    m_rbInput.IncrementSafe();
77    m_rbDeltaM.IncrementSafe();
78
79    return nOutput;
80}
81
82int CNNFilter::Decompress(int nInput)
83{
84    // figure a dot product
85    int nDotProduct;
86
87    if (m_bMMXAvailable)
88        nDotProduct = CalculateDotProduct(&m_rbInput[-m_nOrder], &m_paryM[0], m_nOrder);
89    else
90        nDotProduct = CalculateDotProductNoMMX(&m_rbInput[-m_nOrder], &m_paryM[0], m_nOrder);
91
92    // adapt
93    if (m_bMMXAvailable)
94        Adapt(&m_paryM[0], &m_rbDeltaM[-m_nOrder], -nInput, m_nOrder);
95    else
96        AdaptNoMMX(&m_paryM[0], &m_rbDeltaM[-m_nOrder], nInput, m_nOrder);
97
98    // store the output value
99    int nOutput = nInput + ((nDotProduct + (1 << (m_nShift - 1))) >> m_nShift);
100
101    // update the input buffer
102    m_rbInput[0] = GetSaturatedShortFromInt(nOutput);
103
104    if (m_nVersion >= 3980)
105    {
106        int nTempABS = abs(nOutput);
107
108        if (nTempABS > (m_nRunningAverage * 3))
109            m_rbDeltaM[0] = ((nOutput >> 25) & 64) - 32;
110        else if (nTempABS > (m_nRunningAverage * 4) / 3)
111            m_rbDeltaM[0] = ((nOutput >> 26) & 32) - 16;
112        else if (nTempABS > 0)
113            m_rbDeltaM[0] = ((nOutput >> 27) & 16) - 8;
114        else
115            m_rbDeltaM[0] = 0;
116
117        m_nRunningAverage += (nTempABS - m_nRunningAverage) / 16;
118
119        m_rbDeltaM[-1] >>= 1;
120        m_rbDeltaM[-2] >>= 1;
121        m_rbDeltaM[-8] >>= 1;
122    }
123    else
124    {
125        m_rbDeltaM[0] = (nOutput == 0) ? 0 : ((nOutput >> 28) & 8) - 4;
126        m_rbDeltaM[-4] >>= 1;
127        m_rbDeltaM[-8] >>= 1;
128    }
129
130    // increment and roll if necessary
131    m_rbInput.IncrementSafe();
132    m_rbDeltaM.IncrementSafe();
133
134    return nOutput;
135}
136
137void CNNFilter::AdaptNoMMX(short * pM, short * pAdapt, int nDirection, int nOrder)
138{
139    nOrder >>= 4;
140
141    if (nDirection < 0)
142    {
143        while (nOrder--)
144        {
145            EXPAND_16_TIMES(*pM++ += *pAdapt++;)
146        }
147    }
148    else if (nDirection > 0)
149    {
150        while (nOrder--)
151        {
152            EXPAND_16_TIMES(*pM++ -= *pAdapt++;)
153        }
154    }
155}
156
157int CNNFilter::CalculateDotProductNoMMX(short * pA, short * pB, int nOrder)
158{
159    int nDotProduct = 0;
160    nOrder >>= 4;
161
162    while (nOrder--)
163    {
164        EXPAND_16_TIMES(nDotProduct += *pA++ * *pB++;)
165    }
166
167    return nDotProduct;
168}
169