summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Buland <mike@xagasoft.com>2012-07-09 13:57:37 -0600
committerMike Buland <mike@xagasoft.com>2012-07-09 13:57:37 -0600
commit673252f7eadc3aa0cfae3d826e1c7bbe2400df17 (patch)
treefdb3a7f602ca9d91794f35e880017756d8487fa9
parent87dc10690035b02485067f2b0b77bcb0459da42b (diff)
downloadlibneural-673252f7eadc3aa0cfae3d826e1c7bbe2400df17.tar.gz
libneural-673252f7eadc3aa0cfae3d826e1c7bbe2400df17.tar.bz2
libneural-673252f7eadc3aa0cfae3d826e1c7bbe2400df17.tar.xz
libneural-673252f7eadc3aa0cfae3d826e1c7bbe2400df17.zip
It generates pngs just like the java version.
Maybe even prettier.
-rw-r--r--default.bld4
-rw-r--r--src/column.h3
-rw-r--r--src/container.h24
-rw-r--r--src/neuron.h20
-rw-r--r--src/node.h2
-rw-r--r--src/slopestd.h2
-rw-r--r--src/tests/pic.cpp113
7 files changed, 165 insertions, 3 deletions
diff --git a/default.bld b/default.bld
index f2fb0cc..8b2947a 100644
--- a/default.bld
+++ b/default.bld
@@ -32,3 +32,7 @@ for dir in dirs("src/tests/*") do
32 } 32 }
33} 33}
34 34
35target "tests/pic"
36{
37 LDFLAGS += "-lpng -lz";
38}
diff --git a/src/column.h b/src/column.h
index 8c07f62..d1b670d 100644
--- a/src/column.h
+++ b/src/column.h
@@ -48,7 +48,8 @@ namespace Neural
48 { 48 {
49 (*i)->process( pBuffer, pNextBuffer ); 49 (*i)->process( pBuffer, pNextBuffer );
50 pBuffer = pNextBuffer; 50 pBuffer = pNextBuffer;
51 iBuf++; 51 if( iBuf )
52 iBuf++;
52 if( iBuf ) 53 if( iBuf )
53 pNextBuffer = *iBuf; 54 pNextBuffer = *iBuf;
54 else 55 else
diff --git a/src/container.h b/src/container.h
index d9eeffd..f341bf8 100644
--- a/src/container.h
+++ b/src/container.h
@@ -22,6 +22,30 @@ namespace Neural
22 delete *i; 22 delete *i;
23 } 23 }
24 24
25 virtual int setWeights( const sigtype *pWeights )
26 {
27 int iOffset = 0;
28 for( typename Container<sigtype>::NodeList::iterator i =
29 Container<sigtype>::getNodeList().begin(); i; i++ )
30 {
31 iOffset += (*i)->setWeights( &pWeights[iOffset] );
32 }
33
34 return iOffset;
35 }
36
37 virtual int setBiases( const sigtype *pBiases )
38 {
39 int iOffset = 0;
40 for( typename Container<sigtype>::NodeList::iterator i =
41 Container<sigtype>::getNodeList().begin(); i; i++ )
42 {
43 iOffset += (*i)->setBiases( &pBiases[iOffset] );
44 }
45
46 return iOffset;
47 }
48
25 virtual void addNode( Node<sigtype> *pNode ) 49 virtual void addNode( Node<sigtype> *pNode )
26 { 50 {
27 lNodes.append( pNode ); 51 lNodes.append( pNode );
diff --git a/src/neuron.h b/src/neuron.h
index dc30471..2ad5cfb 100644
--- a/src/neuron.h
+++ b/src/neuron.h
@@ -3,6 +3,9 @@
3 3
4#include "neural/node.h" 4#include "neural/node.h"
5#include "neural/slope.h" 5#include "neural/slope.h"
6#include "neural/slopestd.h"
7
8#include <bu/sio.h>
6 9
7namespace Neural 10namespace Neural
8{ 11{
@@ -14,7 +17,7 @@ namespace Neural
14 iInputs( 0 ), 17 iInputs( 0 ),
15 aWeights( 0 ), 18 aWeights( 0 ),
16 sBias( 0.0 ), 19 sBias( 0.0 ),
17 pSlope( 0 ) 20 pSlope( new Neural::SlopeStd<sigtype>() )
18 { 21 {
19 } 22 }
20 23
@@ -30,6 +33,21 @@ namespace Neural
30 aWeights = new sigtype[iInputs]; 33 aWeights = new sigtype[iInputs];
31 } 34 }
32 35
36 virtual int setWeights( const sigtype *pWeights )
37 {
38 for( int j = 0; j < iInputs; j++ )
39 aWeights[j] = pWeights[j];
40
41 return iInputs;
42 }
43
44 virtual int setBiases( const sigtype *pBiases )
45 {
46 sBias = *pBiases;
47
48 return 1;
49 }
50
33 virtual void process( sigtype *aInput, sigtype *aOutput ) 51 virtual void process( sigtype *aInput, sigtype *aOutput )
34 { 52 {
35 sigtype sOutput = sBias; 53 sigtype sOutput = sBias;
diff --git a/src/node.h b/src/node.h
index fe2b720..1b82327 100644
--- a/src/node.h
+++ b/src/node.h
@@ -16,6 +16,8 @@ namespace Neural
16 } 16 }
17 17
18 virtual void finalize( int iNumInputs )=0; 18 virtual void finalize( int iNumInputs )=0;
19 virtual int setWeights( const sigtype *pWeights )=0;
20 virtual int setBiases( const sigtype *pBiases )=0;
19 virtual void process( sigtype *aInput, sigtype *aOutput )=0; 21 virtual void process( sigtype *aInput, sigtype *aOutput )=0;
20 22
21 virtual int getNumInputs() const=0; 23 virtual int getNumInputs() const=0;
diff --git a/src/slopestd.h b/src/slopestd.h
index 40ceef2..116ce1c 100644
--- a/src/slopestd.h
+++ b/src/slopestd.h
@@ -43,7 +43,7 @@ namespace Neural
43 43
44 virtual sigtype operator()( sigtype sInput ) 44 virtual sigtype operator()( sigtype sInput )
45 { 45 {
46 return (tpltanh<sigtype>(2.0*sSlope*sInput) + 1.0)/2.0; 46 return tpltanh<sigtype>(2.0*sSlope*sInput);
47 } 47 }
48 48
49 private: 49 private:
diff --git a/src/tests/pic.cpp b/src/tests/pic.cpp
index 95bb523..4f4139f 100644
--- a/src/tests/pic.cpp
+++ b/src/tests/pic.cpp
@@ -3,8 +3,20 @@
3#include "neural/row.h" 3#include "neural/row.h"
4#include "neural/neuron.h" 4#include "neural/neuron.h"
5 5
6#include <time.h>
7
8#include <bu/random.h>
9#include <bu/sio.h>
10using namespace Bu;
11
12#include <stdio.h>
13#include <png.h>
14#include <zlib.h>
15
6int main( int argc, char *argv[] ) 16int main( int argc, char *argv[] )
7{ 17{
18 Bu::Random::seed( time( NULL ) );
19
8 Neural::Column<float> *c = new Neural::Column<float>(); 20 Neural::Column<float> *c = new Neural::Column<float>();
9 Neural::Row<float> *r1 = new Neural::Row<float>(); 21 Neural::Row<float> *r1 = new Neural::Row<float>();
10 r1->addNode( new Neural::Neuron<float>() ); 22 r1->addNode( new Neural::Neuron<float>() );
@@ -27,6 +39,107 @@ int main( int argc, char *argv[] )
27 r3->addNode( new Neural::Neuron<float>() ); 39 r3->addNode( new Neural::Neuron<float>() );
28 r3->addNode( new Neural::Neuron<float>() ); 40 r3->addNode( new Neural::Neuron<float>() );
29 c->addNode( r3 ); 41 c->addNode( r3 );
42
43 c->finalize( 2 );
44 sio << "Total weights: " << c->getNumWeights() << sio.nl;
45 sio << "Total biases: " << c->getNumBiases() << sio.nl;
46 sio << "Network inputs: " << c->getNumInputs() << sio.nl;
47 sio << "Network outputs: " << c->getNumOutputs() << sio.nl;
48
49 float *pWeights = new float[c->getNumWeights()];
50 float *pBiases = new float[c->getNumBiases()];
51
52 for( int j = 0; j < c->getNumWeights(); j++ )
53 pWeights[j] = (Bu::Random::randNorm()*2.0)-1.0;
54 for( int j = 0; j < c->getNumBiases(); j++ )
55 pBiases[j] = (Bu::Random::randNorm()*2.0)-1.0;
56
57 c->setWeights( pWeights );
58 c->setBiases( pBiases );
59 delete pWeights;
60 delete pBiases;
61
62 float *pIn = new float[c->getNumInputs()];
63 float *pOut = new float[c->getNumOutputs()];
64
65 FILE *fp = fopen("test.png", "wb");
66
67 if (!fp)
68 return 1;
69
70 png_structp png_ptr = png_create_write_struct
71 (PNG_LIBPNG_VER_STRING, NULL, NULL, NULL );
72
73 if (!png_ptr)
74 return 1;
75
76 png_infop info_ptr = png_create_info_struct(png_ptr);
77 if (!info_ptr)
78 {
79 png_destroy_write_struct(&png_ptr,
80 (png_infopp)NULL);
81 return 1;
82 }
83
84 png_set_IHDR(png_ptr, info_ptr, 500, 500, 8, PNG_COLOR_TYPE_RGB,
85 PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
86 PNG_FILTER_TYPE_DEFAULT
87 );
88
89 /* Set the zlib compression level */
90 png_set_compression_level(png_ptr,
91 Z_BEST_COMPRESSION);
92
93 /* Set other zlib parameters for compressing IDAT */
94 png_set_compression_mem_level(png_ptr, 8);
95 png_set_compression_strategy(png_ptr,
96 Z_DEFAULT_STRATEGY);
97 png_set_compression_window_bits(png_ptr, 15);
98 png_set_compression_method(png_ptr, 8);
99 png_set_compression_buffer_size(png_ptr, 8192);
100
101 /* Set zlib parameters for text compression
102 * If you don't call these, the parameters
103 * fall back on those defined for IDAT chunks
104 */
105 png_set_text_compression_mem_level(png_ptr, 8);
106 png_set_text_compression_strategy(png_ptr,
107 Z_DEFAULT_STRATEGY);
108 png_set_text_compression_window_bits(png_ptr, 15);
109 png_set_text_compression_method(png_ptr, 8);
110
111 png_bytep *row_pointers;
112 row_pointers = new png_bytep[500];
113 for( int y = 0; y < 500; y++ )
114 {
115 row_pointers[y] = new png_byte[500*3];
116 for( int x = 0; x < 500; x++ )
117 {
118 pIn[0] = (x/499.0)*2.0-1.0;
119 pIn[1] = (y/499.0)*2.0-1.0;
120 c->process( pIn, pOut );
121 row_pointers[y][x*3+0] = pOut[0]*127+127;
122 row_pointers[y][x*3+1] = pOut[1]*127+127;
123 row_pointers[y][x*3+2] = pOut[2]*127+127;
124 }
125 }
126
127 png_init_io( png_ptr, fp );
128
129 png_set_rows( png_ptr, info_ptr, row_pointers );
130
131 png_write_png(png_ptr, info_ptr, PNG_TRANSFORM_IDENTITY, NULL);
132
133 if (setjmp(png_jmpbuf(png_ptr)))
134 {
135 png_destroy_write_struct(&png_ptr, &info_ptr);
136 fclose(fp);
137 return 1;
138 }
139 png_destroy_write_struct( &png_ptr, &info_ptr );
140 fclose( fp );
141
142 delete c;
30 143
31 return 0; 144 return 0;
32} 145}