From 673252f7eadc3aa0cfae3d826e1c7bbe2400df17 Mon Sep 17 00:00:00 2001 From: Mike Buland Date: Mon, 9 Jul 2012 13:57:37 -0600 Subject: It generates pngs just like the java version. Maybe even prettier. --- src/column.h | 3 +- src/container.h | 24 ++++++++++++ src/neuron.h | 20 +++++++++- src/node.h | 2 + src/slopestd.h | 2 +- src/tests/pic.cpp | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 161 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/column.h b/src/column.h index 8c07f62..d1b670d 100644 --- a/src/column.h +++ b/src/column.h @@ -48,7 +48,8 @@ namespace Neural { (*i)->process( pBuffer, pNextBuffer ); pBuffer = pNextBuffer; - iBuf++; + if( iBuf ) + iBuf++; if( iBuf ) pNextBuffer = *iBuf; else diff --git a/src/container.h b/src/container.h index d9eeffd..f341bf8 100644 --- a/src/container.h +++ b/src/container.h @@ -22,6 +22,30 @@ namespace Neural delete *i; } + virtual int setWeights( const sigtype *pWeights ) + { + int iOffset = 0; + for( typename Container::NodeList::iterator i = + Container::getNodeList().begin(); i; i++ ) + { + iOffset += (*i)->setWeights( &pWeights[iOffset] ); + } + + return iOffset; + } + + virtual int setBiases( const sigtype *pBiases ) + { + int iOffset = 0; + for( typename Container::NodeList::iterator i = + Container::getNodeList().begin(); i; i++ ) + { + iOffset += (*i)->setBiases( &pBiases[iOffset] ); + } + + return iOffset; + } + virtual void addNode( Node *pNode ) { lNodes.append( pNode ); diff --git a/src/neuron.h b/src/neuron.h index dc30471..2ad5cfb 100644 --- a/src/neuron.h +++ b/src/neuron.h @@ -3,6 +3,9 @@ #include "neural/node.h" #include "neural/slope.h" +#include "neural/slopestd.h" + +#include namespace Neural { @@ -14,7 +17,7 @@ namespace Neural iInputs( 0 ), aWeights( 0 ), sBias( 0.0 ), - pSlope( 0 ) + pSlope( new Neural::SlopeStd() ) { } @@ -30,6 +33,21 @@ namespace Neural aWeights = new sigtype[iInputs]; } + virtual int setWeights( const sigtype *pWeights ) + { + for( int j = 0; j < iInputs; j++ ) + aWeights[j] = pWeights[j]; + + return iInputs; + } + + virtual int setBiases( const sigtype *pBiases ) + { + sBias = *pBiases; + + return 1; + } + virtual void process( sigtype *aInput, sigtype *aOutput ) { sigtype sOutput = sBias; diff --git a/src/node.h b/src/node.h index fe2b720..1b82327 100644 --- a/src/node.h +++ b/src/node.h @@ -16,6 +16,8 @@ namespace Neural } virtual void finalize( int iNumInputs )=0; + virtual int setWeights( const sigtype *pWeights )=0; + virtual int setBiases( const sigtype *pBiases )=0; virtual void process( sigtype *aInput, sigtype *aOutput )=0; virtual int getNumInputs() const=0; diff --git a/src/slopestd.h b/src/slopestd.h index 40ceef2..116ce1c 100644 --- a/src/slopestd.h +++ b/src/slopestd.h @@ -43,7 +43,7 @@ namespace Neural virtual sigtype operator()( sigtype sInput ) { - return (tpltanh(2.0*sSlope*sInput) + 1.0)/2.0; + return tpltanh(2.0*sSlope*sInput); } private: diff --git a/src/tests/pic.cpp b/src/tests/pic.cpp index 95bb523..4f4139f 100644 --- a/src/tests/pic.cpp +++ b/src/tests/pic.cpp @@ -3,8 +3,20 @@ #include "neural/row.h" #include "neural/neuron.h" +#include + +#include +#include +using namespace Bu; + +#include +#include +#include + int main( int argc, char *argv[] ) { + Bu::Random::seed( time( NULL ) ); + Neural::Column *c = new Neural::Column(); Neural::Row *r1 = new Neural::Row(); r1->addNode( new Neural::Neuron() ); @@ -27,6 +39,107 @@ int main( int argc, char *argv[] ) r3->addNode( new Neural::Neuron() ); r3->addNode( new Neural::Neuron() ); c->addNode( r3 ); + + c->finalize( 2 ); + sio << "Total weights: " << c->getNumWeights() << sio.nl; + sio << "Total biases: " << c->getNumBiases() << sio.nl; + sio << "Network inputs: " << c->getNumInputs() << sio.nl; + sio << "Network outputs: " << c->getNumOutputs() << sio.nl; + + float *pWeights = new float[c->getNumWeights()]; + float *pBiases = new float[c->getNumBiases()]; + + for( int j = 0; j < c->getNumWeights(); j++ ) + pWeights[j] = (Bu::Random::randNorm()*2.0)-1.0; + for( int j = 0; j < c->getNumBiases(); j++ ) + pBiases[j] = (Bu::Random::randNorm()*2.0)-1.0; + + c->setWeights( pWeights ); + c->setBiases( pBiases ); + delete pWeights; + delete pBiases; + + float *pIn = new float[c->getNumInputs()]; + float *pOut = new float[c->getNumOutputs()]; + + FILE *fp = fopen("test.png", "wb"); + + if (!fp) + return 1; + + png_structp png_ptr = png_create_write_struct + (PNG_LIBPNG_VER_STRING, NULL, NULL, NULL ); + + if (!png_ptr) + return 1; + + png_infop info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) + { + png_destroy_write_struct(&png_ptr, + (png_infopp)NULL); + return 1; + } + + png_set_IHDR(png_ptr, info_ptr, 500, 500, 8, PNG_COLOR_TYPE_RGB, + PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT, + PNG_FILTER_TYPE_DEFAULT + ); + + /* Set the zlib compression level */ + png_set_compression_level(png_ptr, + Z_BEST_COMPRESSION); + + /* Set other zlib parameters for compressing IDAT */ + png_set_compression_mem_level(png_ptr, 8); + png_set_compression_strategy(png_ptr, + Z_DEFAULT_STRATEGY); + png_set_compression_window_bits(png_ptr, 15); + png_set_compression_method(png_ptr, 8); + png_set_compression_buffer_size(png_ptr, 8192); + + /* Set zlib parameters for text compression + * If you don't call these, the parameters + * fall back on those defined for IDAT chunks + */ + png_set_text_compression_mem_level(png_ptr, 8); + png_set_text_compression_strategy(png_ptr, + Z_DEFAULT_STRATEGY); + png_set_text_compression_window_bits(png_ptr, 15); + png_set_text_compression_method(png_ptr, 8); + + png_bytep *row_pointers; + row_pointers = new png_bytep[500]; + for( int y = 0; y < 500; y++ ) + { + row_pointers[y] = new png_byte[500*3]; + for( int x = 0; x < 500; x++ ) + { + pIn[0] = (x/499.0)*2.0-1.0; + pIn[1] = (y/499.0)*2.0-1.0; + c->process( pIn, pOut ); + row_pointers[y][x*3+0] = pOut[0]*127+127; + row_pointers[y][x*3+1] = pOut[1]*127+127; + row_pointers[y][x*3+2] = pOut[2]*127+127; + } + } + + png_init_io( png_ptr, fp ); + + png_set_rows( png_ptr, info_ptr, row_pointers ); + + png_write_png(png_ptr, info_ptr, PNG_TRANSFORM_IDENTITY, NULL); + + if (setjmp(png_jmpbuf(png_ptr))) + { + png_destroy_write_struct(&png_ptr, &info_ptr); + fclose(fp); + return 1; + } + png_destroy_write_struct( &png_ptr, &info_ptr ); + fclose( fp ); + + delete c; return 0; } -- cgit v1.2.3