diff options
-rw-r--r-- | src/column.cpp | 1 | ||||
-rw-r--r-- | src/column.h | 83 |
2 files changed, 84 insertions, 0 deletions
diff --git a/src/column.cpp b/src/column.cpp index e69de29..10b010b 100644 --- a/src/column.cpp +++ b/src/column.cpp | |||
@@ -0,0 +1 @@ | |||
#include "neural/column.h" | |||
diff --git a/src/column.h b/src/column.h index e69de29..17ba899 100644 --- a/src/column.h +++ b/src/column.h | |||
@@ -0,0 +1,83 @@ | |||
1 | #ifndef NEURAL_COLUMN_H | ||
2 | #define NEURAL_COLUMN_H | ||
3 | |||
4 | #include "neural/container.h" | ||
5 | |||
6 | namespace Neural | ||
7 | { | ||
8 | template<typename sigtype> | ||
9 | class Column : public Container<sigtype> | ||
10 | { | ||
11 | public: | ||
12 | Column() | ||
13 | { | ||
14 | } | ||
15 | |||
16 | virtual ~Column() | ||
17 | { | ||
18 | } | ||
19 | |||
20 | virtual void finalize( int iNumInputs ) | ||
21 | { | ||
22 | iInputs = iNumInputs; | ||
23 | iWeights = 0; | ||
24 | iBiases = 0; | ||
25 | |||
26 | int iNextInputs = iInputs; | ||
27 | for( typename Container<sigtype>::NodeList::iterator i = | ||
28 | Container<sigtype>::getNodeList().begin(); i; i++ ) | ||
29 | { | ||
30 | (*i)->finalize( iNextInputs ); | ||
31 | iNextInputs = (*i)->getNumOutputs(); | ||
32 | if( (i+1) ) | ||
33 | { | ||
34 | lBuffer.append( new sigtype[iNextInputs] ); | ||
35 | } | ||
36 | iWeights += (*i)->getNumWeights(); | ||
37 | iBiases += (*i)->getNumBiases(); | ||
38 | } | ||
39 | } | ||
40 | |||
41 | virtual void process( sigtype *aInput, sigtype *aOutput ) | ||
42 | { | ||
43 | |||
44 | sigtype *pInput, *pOutput; | ||
45 | int iOutputOffset = 0; | ||
46 | for( typename Container<sigtype>::NodeList::iterator i = | ||
47 | Container<sigtype>::getNodeList().begin(); i; i++ ) | ||
48 | { | ||
49 | (*i)->process( aInput, aOutput+iOutputOffset ); | ||
50 | iOutputOffset += (*i)->getNumOutputs(); | ||
51 | } | ||
52 | } | ||
53 | |||
54 | virtual int getNumInputs() const | ||
55 | { | ||
56 | return iInputs; | ||
57 | } | ||
58 | |||
59 | virtual int getNumOutputs() const | ||
60 | { | ||
61 | return Container<sigtype>::getNodeList().last()->getNumOutputs(); | ||
62 | } | ||
63 | |||
64 | virtual int getNumWeights() const | ||
65 | { | ||
66 | return iWeights; | ||
67 | } | ||
68 | |||
69 | virtual int getNumBiases() const | ||
70 | { | ||
71 | return iBiases; | ||
72 | } | ||
73 | |||
74 | private: | ||
75 | int iInputs; | ||
76 | int iWeights; | ||
77 | int iBiases; | ||
78 | typedef Bu::List<sigtype *> BufferList; | ||
79 | BufferList lBuffer; | ||
80 | }; | ||
81 | }; | ||
82 | |||
83 | #endif | ||