From 673252f7eadc3aa0cfae3d826e1c7bbe2400df17 Mon Sep 17 00:00:00 2001
From: Mike Buland <mike@xagasoft.com>
Date: Mon, 9 Jul 2012 13:57:37 -0600
Subject: It generates pngs just like the java version.

Maybe even prettier.
---
 src/column.h      |   3 +-
 src/container.h   |  24 ++++++++++++
 src/neuron.h      |  20 +++++++++-
 src/node.h        |   2 +
 src/slopestd.h    |   2 +-
 src/tests/pic.cpp | 113 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 161 insertions(+), 3 deletions(-)

(limited to 'src')

diff --git a/src/column.h b/src/column.h
index 8c07f62..d1b670d 100644
--- a/src/column.h
+++ b/src/column.h
@@ -48,7 +48,8 @@ namespace Neural
 			{
 				(*i)->process( pBuffer, pNextBuffer );
 				pBuffer = pNextBuffer;
-				iBuf++;
+				if( iBuf )
+					iBuf++;
 				if( iBuf )
 					pNextBuffer = *iBuf;
 				else
diff --git a/src/container.h b/src/container.h
index d9eeffd..f341bf8 100644
--- a/src/container.h
+++ b/src/container.h
@@ -22,6 +22,30 @@ namespace Neural
 				delete *i;
 		}
 
+		virtual int setWeights( const sigtype *pWeights )
+		{
+			int iOffset = 0;
+			for( typename Container<sigtype>::NodeList::iterator i =
+				Container<sigtype>::getNodeList().begin(); i; i++ )
+			{
+				iOffset += (*i)->setWeights( &pWeights[iOffset] );
+			}
+
+			return iOffset;
+		}
+
+		virtual int setBiases( const sigtype *pBiases )
+		{
+			int iOffset = 0;
+			for( typename Container<sigtype>::NodeList::iterator i =
+				Container<sigtype>::getNodeList().begin(); i; i++ )
+			{
+				iOffset += (*i)->setBiases( &pBiases[iOffset] );
+			}
+
+			return iOffset;
+		}
+
 		virtual void addNode( Node<sigtype> *pNode )
 		{
 			lNodes.append( pNode );
diff --git a/src/neuron.h b/src/neuron.h
index dc30471..2ad5cfb 100644
--- a/src/neuron.h
+++ b/src/neuron.h
@@ -3,6 +3,9 @@
 
 #include "neural/node.h"
 #include "neural/slope.h"
+#include "neural/slopestd.h"
+
+#include <bu/sio.h>
 
 namespace Neural
 {
@@ -14,7 +17,7 @@ namespace Neural
 			iInputs( 0 ),
 			aWeights( 0 ),
 			sBias( 0.0 ),
-			pSlope( 0 )
+			pSlope( new Neural::SlopeStd<sigtype>() )
 		{
 		}
 
@@ -30,6 +33,21 @@ namespace Neural
 			aWeights = new sigtype[iInputs];
 		}
 
+		virtual int setWeights( const sigtype *pWeights )
+		{
+			for( int j = 0; j < iInputs; j++ )
+				aWeights[j] = pWeights[j];
+
+			return iInputs;
+		}
+
+		virtual int setBiases( const sigtype *pBiases )
+		{
+			sBias = *pBiases;
+
+			return 1;
+		}
+
 		virtual void process( sigtype *aInput, sigtype *aOutput )
 		{
 			sigtype sOutput = sBias;
diff --git a/src/node.h b/src/node.h
index fe2b720..1b82327 100644
--- a/src/node.h
+++ b/src/node.h
@@ -16,6 +16,8 @@ namespace Neural
 		}
 
 		virtual void finalize( int iNumInputs )=0;
+		virtual int setWeights( const sigtype *pWeights )=0;
+		virtual int setBiases( const sigtype *pBiases )=0;
 		virtual void process( sigtype *aInput, sigtype *aOutput )=0;
 
 		virtual int getNumInputs() const=0;
diff --git a/src/slopestd.h b/src/slopestd.h
index 40ceef2..116ce1c 100644
--- a/src/slopestd.h
+++ b/src/slopestd.h
@@ -43,7 +43,7 @@ namespace Neural
 
 		virtual sigtype operator()( sigtype sInput )
 		{
-			return (tpltanh<sigtype>(2.0*sSlope*sInput) + 1.0)/2.0;
+			return tpltanh<sigtype>(2.0*sSlope*sInput);
 		}
 
 	private:
diff --git a/src/tests/pic.cpp b/src/tests/pic.cpp
index 95bb523..4f4139f 100644
--- a/src/tests/pic.cpp
+++ b/src/tests/pic.cpp
@@ -3,8 +3,20 @@
 #include "neural/row.h"
 #include "neural/neuron.h"
 
+#include <time.h>
+
+#include <bu/random.h>
+#include <bu/sio.h>
+using namespace Bu;
+
+#include <stdio.h>
+#include <png.h>
+#include <zlib.h>
+
 int main( int argc, char *argv[] )
 {
+	Bu::Random::seed( time( NULL ) );
+
 	Neural::Column<float> *c = new Neural::Column<float>();
 	Neural::Row<float> *r1 = new Neural::Row<float>();
 	r1->addNode( new Neural::Neuron<float>() );
@@ -27,6 +39,107 @@ int main( int argc, char *argv[] )
 	r3->addNode( new Neural::Neuron<float>() );
 	r3->addNode( new Neural::Neuron<float>() );
 	c->addNode( r3 );
+
+	c->finalize( 2 );
+	sio << "Total weights:   " << c->getNumWeights() << sio.nl;
+	sio << "Total biases:    " << c->getNumBiases() << sio.nl;
+	sio << "Network inputs:  " << c->getNumInputs() << sio.nl;
+	sio << "Network outputs: " << c->getNumOutputs() << sio.nl;
+
+	float *pWeights = new float[c->getNumWeights()];
+	float *pBiases = new float[c->getNumBiases()];
+
+	for( int j = 0; j < c->getNumWeights(); j++ )
+		pWeights[j] = (Bu::Random::randNorm()*2.0)-1.0;
+	for( int j = 0; j < c->getNumBiases(); j++ )
+		pBiases[j] = (Bu::Random::randNorm()*2.0)-1.0;
+
+	c->setWeights( pWeights );
+	c->setBiases( pBiases );
+	delete pWeights;
+	delete pBiases;
+
+	float *pIn = new float[c->getNumInputs()];
+	float *pOut = new float[c->getNumOutputs()];
+
+	FILE *fp = fopen("test.png", "wb");
+
+    if (!fp)
+       return 1;
+
+    png_structp png_ptr = png_create_write_struct
+       (PNG_LIBPNG_VER_STRING, NULL, NULL, NULL );
+
+    if (!png_ptr)
+       return 1;
+
+    png_infop info_ptr = png_create_info_struct(png_ptr);
+    if (!info_ptr)
+    {
+       png_destroy_write_struct(&png_ptr,
+           (png_infopp)NULL);
+       return 1;
+    }
+
+    png_set_IHDR(png_ptr, info_ptr, 500, 500, 8, PNG_COLOR_TYPE_RGB,
+		PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
+		PNG_FILTER_TYPE_DEFAULT
+		);
+   
+    /* Set the zlib compression level */
+    png_set_compression_level(png_ptr,
+        Z_BEST_COMPRESSION);
+
+    /* Set other zlib parameters for compressing IDAT */
+    png_set_compression_mem_level(png_ptr, 8);
+    png_set_compression_strategy(png_ptr,
+        Z_DEFAULT_STRATEGY);
+    png_set_compression_window_bits(png_ptr, 15);
+    png_set_compression_method(png_ptr, 8);
+    png_set_compression_buffer_size(png_ptr, 8192);
+
+    /* Set zlib parameters for text compression
+     * If you don't call these, the parameters
+     * fall back on those defined for IDAT chunks
+     */
+    png_set_text_compression_mem_level(png_ptr, 8);
+    png_set_text_compression_strategy(png_ptr,
+        Z_DEFAULT_STRATEGY);
+    png_set_text_compression_window_bits(png_ptr, 15);
+    png_set_text_compression_method(png_ptr, 8);
+
+	png_bytep *row_pointers;
+	row_pointers = new png_bytep[500];
+	for( int y = 0; y < 500; y++ )
+	{
+		row_pointers[y] = new png_byte[500*3];
+		for( int x = 0; x < 500; x++ )
+		{
+			pIn[0] = (x/499.0)*2.0-1.0;
+			pIn[1] = (y/499.0)*2.0-1.0;
+			c->process( pIn, pOut );
+			row_pointers[y][x*3+0] = pOut[0]*127+127;
+			row_pointers[y][x*3+1] = pOut[1]*127+127;
+			row_pointers[y][x*3+2] = pOut[2]*127+127;
+		}
+	}
+
+	png_init_io( png_ptr, fp );
+
+	png_set_rows( png_ptr, info_ptr, row_pointers );
+
+	png_write_png(png_ptr, info_ptr, PNG_TRANSFORM_IDENTITY, NULL);
+
+	if (setjmp(png_jmpbuf(png_ptr)))
+    {
+    	png_destroy_write_struct(&png_ptr, &info_ptr);
+	       fclose(fp);
+    	   return 1;
+    }
+	png_destroy_write_struct( &png_ptr, &info_ptr );
+	fclose( fp );
+
+	delete c;
 	
 	return 0;
 }
-- 
cgit v1.2.3