/*
 * Copyright (C) 2007-2019 Xagasoft, All rights reserved.
 *
 * This file is part of the libbu++ library and is released under the
 * terms of the license contained in the file LICENSE.
 */

#include "bu/config.h"
#include "bu/protocolwebsocket.h"

#include "bu/sio.h"
#include "bu/fmt.h"
#include "bu/client.h"
#include "bu/membuf.h"
#include "bu/base64.h"
#include "bu/sha1.h"
#include "bu/json.h"
#include "bu/mutexlocker.h"

#include <ctype.h>
#include <stddef.h>
#include <stdint.h>

#define DEBUG( X ) { } (void)0

// Returns true for HTTP "optional whitespace" characters (SP / HTAB).
static bool isHttpOws( char c )
{
	return c == ' ' || c == '\t';
}

// Returns true if the comma-separated list pList (iListLen bytes) contains
// pToken (iTokenLen bytes) as one of its elements.  Elements are compared
// case-insensitively with surrounding whitespace ignored, so a value such
// as "keep-alive, Upgrade" contains the token "upgrade".
static bool headerListHasToken( const char *pList, long iListLen,
		const char *pToken, long iTokenLen )
{
	long iStart = 0;
	while( iStart <= iListLen )
	{
		// Find the end of the current element.
		long iEnd = iStart;
		while( iEnd < iListLen && pList[iEnd] != ',' )
			iEnd++;

		// Trim whitespace around the element.
		long iTokStart = iStart;
		long iTokEnd = iEnd;
		while( iTokStart < iTokEnd && isHttpOws( pList[iTokStart] ) )
			iTokStart++;
		while( iTokEnd > iTokStart && isHttpOws( pList[iTokEnd-1] ) )
			iTokEnd--;

		if( iTokEnd-iTokStart == iTokenLen )
		{
			long j = 0;
			while( j < iTokenLen &&
				tolower( (unsigned char)pList[iTokStart+j] ) ==
				tolower( (unsigned char)pToken[j] ) )
				j++;
			if( j == iTokenLen )
				return true;
		}

		iStart = iEnd+1;
	}
	return false;
}

Bu::ProtocolWebSocket::ProtocolWebSocket() :
	eStatus( stProtoId )
{
	// pClient is only assigned by onNewConnection(); start it as NULL so
	// that the NULL check in writeMessage() is meaningful before then.
	pClient = NULL;
}

Bu::ProtocolWebSocket::~ProtocolWebSocket()
{
	mClient.lock();
	this->pClient = NULL;
	mClient.unlock();
}

void Bu::ProtocolWebSocket::onNewConnection( Bu::Client *pClient )
{
	mClient.lock();
	this->pClient = pClient;
	mClient.unlock();
}

/**
 * Drives the connection state machine: request line, then header lines,
 * then WebSocket frames.  Each state handler returns false when it cannot
 * make progress (not enough input, or the client was disconnected).
 */
void Bu::ProtocolWebSocket::onNewData( Bu::Client * /*pClient*/ )
{
	for(;;)
	{
		switch( eStatus )
		{
			case stProtoId:
				if( !stateProtoId() )
					return;
				break;

			case stHandshake:
				if( !stateHandshake() )
					return;
				break;

			case stReady:
				if( !parseMessage() )
					return;
				break;
		}
	}
}

/**
 * Hook for subclasses to inspect the request headers and add response
 * headers before the 101 response is sent.  Returning false rejects the
 * connection.  The default accepts every handshake.
 */
bool Bu::ProtocolWebSocket::onProcessHeaders(
		Bu::StringList & /*lHeadersOut*/ )
{
	return true;
}

/**
 * Sends sData to the client as a single, final (FIN bit set), unmasked
 * frame with opcode eOp.  Does nothing if there is no client attached.
 */
void Bu::ProtocolWebSocket::writeMessage( const Bu::String &sData,
		Bu::ProtocolWebSocket::Operation eOp )
{
	DEBUG( Bu::println("websocket: Writing message, %1 bytes").arg( sData.getSize() ) );

	// Largest frame header is 14 bytes (2 fixed + 8 length + 4 mask); a
	// server never masks, so at most 10 are used here.
	uint8_t cHeader[14] = { 0 };
	int idx = 2;
	cHeader[0] = ((uint8_t)(eOp&0x0f))|0x80;
	uint64_t iLen = sData.getSize();
	if( iLen < 126 )
	{
		DEBUG( Bu::println("websocket: --> Tiny header") );
		cHeader[1] = (uint8_t)iLen;
	}
	else if( iLen < 65536 )
	{
		DEBUG( Bu::println("websocket: --> Mid header") );
		cHeader[1] = (uint8_t)126;
		// 16-bit extended length, network (big-endian) byte order.
		cHeader[idx++] = (uint8_t)((iLen>>8)&0xff);
		cHeader[idx++] = (uint8_t)(iLen&0xff);
	}
	else
	{
		DEBUG( Bu::println("websocket: --> Big header?") );
		cHeader[1] = (uint8_t)127;
		// 64-bit extended length, network (big-endian) byte order.
		for( int iShift = 56; iShift >= 0; iShift -= 8 )
			cHeader[idx++] = (uint8_t)((iLen>>iShift)&0xff);
	}

	DEBUG( Bu::println("Message size: %1 (%2)").arg( iLen ).arg( iLen, Bu::Fmt::bin(4) ) );
	for( int j = 0; j < idx; j++ )
	{
		DEBUG( Bu::print(" %1").arg( cHeader[j], Bu::Fmt::bin(8) ) );
	}
	DEBUG( Bu::println("") );

	Bu::MutexLocker l( mClient );
	if( pClient == NULL )
		return;
	pClient->write( cHeader, idx );
	pClient->write( sData );
}

/**
 * Reads and validates the HTTP request line ("GET <path> HTTP/1.1").
 * Stores the path in sPath and advances to the header state; disconnects
 * the client on a malformed request line.
 */
bool Bu::ProtocolWebSocket::stateProtoId()
{
	Bu::String sLine;
	if( !readHttpHdrLine( sLine ) )
		return false;

	Bu::StringList lChunks = sLine.split(' ');
	if( lChunks.getSize() != 3 )
	{
		Bu::MutexLocker l( mClient );
		pClient->disconnect();
		return false;
	}

	Bu::StringList::iterator i = lChunks.begin();
	if( *i != "GET" )
	{
		Bu::MutexLocker l( mClient );
		pClient->disconnect();
		return false;
	}
	sPath = *(++i);
	if( *(++i) != "HTTP/1.1" )
	{
		Bu::MutexLocker l( mClient );
		pClient->disconnect();
		return false;
	}

	eStatus = stHandshake;
	return true;
}

/**
 * Reads one request header line.  Header names are stored lower-cased in
 * hHeader; repeated headers accumulate in the per-name list.  A blank line
 * ends the header block, triggers processHeaders(), and (on success)
 * onHandshakeComplete() and the transition to stReady.
 */
bool Bu::ProtocolWebSocket::stateHandshake()
{
	DEBUG( Bu::println("websocket: Begining handshake.") );
	Bu::String sLine;
	if( !readHttpHdrLine( sLine ) )
		return false;

	if( sLine.getSize() == 0 )
	{
		if( !processHeaders() )
			return false;
		onHandshakeComplete();
		eStatus = stReady;
		return true;
	}

	long iPos = sLine.findIdx(':');
	if( iPos < 0 )
	{
		Bu::MutexLocker l( mClient );
		pClient->disconnect();
		return false;
	}

	Bu::String sKey( sLine, iPos );

	// The field value may be surrounded by optional whitespace (RFC 7230
	// section 3.2); strip it so e.g. "Key:value" and "Key: value  " both
	// yield "value".
	const char *pLine = sLine.getStr();
	long iValStart = iPos+1;
	long iValEnd = sLine.getSize();
	while( iValStart < iValEnd && isHttpOws( pLine[iValStart] ) )
		iValStart++;
	while( iValEnd > iValStart && isHttpOws( pLine[iValEnd-1] ) )
		iValEnd--;
	Bu::String sValue;
	sValue.set( pLine+iValStart, iValEnd-iValStart );

	sKey = sKey.toLower();
	if( !hHeader.has( sKey ) )
	{
		hHeader.insert( sKey, Bu::StringList() );
	}
	hHeader.get( sKey ).append( sValue );

	DEBUG( Bu::println("Hdr: >>%1<<").arg( sLine ) );
	DEBUG( Bu::println("%1 = %2").arg( sKey ).arg( sValue ) );

	return true;
}

/**
 * Extracts one CRLF-terminated line from the client's input into sLine
 * (without the CRLF) and consumes it.  Returns false if no complete line
 * is buffered yet.  A line that cannot fit in the peek buffer would never
 * complete, so the client is disconnected in that case instead of stalling.
 */
bool Bu::ProtocolWebSocket::readHttpHdrLine( Bu::String &sLine )
{
	const int iBufSize = 1024;
	char buf[iBufSize];
	int iSize = pClient->peek( buf, iBufSize );
	for( int j = 0; j < iSize-1; j++ )
	{
		if( buf[j] == '\r' && buf[j+1] == '\n' )
		{
			pClient->seek( j+2 );
			sLine.set( buf, j );
			return true;
		}
	}

	if( iSize >= iBufSize )
	{
		// Buffer is full and still no CRLF: the line is too long.
		Bu::MutexLocker l( mClient );
		pClient->disconnect();
	}
	return false;
}

/**
 * Validates the upgrade request headers and, if acceptable, writes the
 * "101 Switching Protocols" response including Sec-WebSocket-Accept.
 * Disconnects the client and returns false on any failure.
 *
 * NOTE(review): mClient is held while the virtual onProcessHeaders() runs;
 * an override that calls writeMessage() would re-lock it.
 */
bool Bu::ProtocolWebSocket::processHeaders()
{
	Bu::MutexLocker l( mClient );
	if( !headerMatch("Connection", "Upgrade") ||
		!headerMatch("Upgrade", "websocket") ||
		!headerMatch("Sec-WebSocket-Version", "13") )
	{
		pClient->disconnect();
		return false;
	}

	// Header names are stored lower-cased, so "Sec-WebSocket-Key" is looked
	// up as "sec-websocket-key".
	if( !hHeader.has("sec-websocket-key") )
	{
		pClient->disconnect();
		return false;
	}
	Bu::String sNonce = hHeader.get("sec-websocket-key").first();

	// Sec-WebSocket-Accept = base64( sha1( key + GUID ) ), RFC 6455 4.2.2.
	Bu::String sGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
	Bu::Sha1 sum;
	sum.addData( sNonce+sGuid );
	Bu::MemBuf mbOut;
	Bu::Base64 bOut( mbOut );
	sum.writeResult( bOut );
	bOut.stop();

	DEBUG( Bu::println("accept: %1").arg( mbOut.getString() ) );

	Bu::StringList lHeadersOut;
	lHeadersOut.append("Upgrade: websocket");
	lHeadersOut.append("Connection: Upgrade");
	lHeadersOut.append("Sec-WebSocket-Accept: " + mbOut.getString());

	if( !onProcessHeaders( lHeadersOut ) )
	{
		pClient->disconnect();
		return false;
	}

	Bu::String sHeaderBlock("HTTP/1.1 101 Switching Protocols\r\n");
	for( Bu::StringList::iterator i = lHeadersOut.begin(); i; i++ )
	{
		sHeaderBlock += (*i) + "\r\n";
	}
	sHeaderBlock += "\r\n";
	pClient->write( sHeaderBlock );

	DEBUG( Bu::println("websocket: Switching protocols.") );

	return true;
}

/**
 * Returns true if any received instance of header sKey (name compared
 * case-insensitively) lists sValue as one of its comma-separated tokens,
 * compared case-insensitively.  Connection and Upgrade are token lists,
 * e.g. "Connection: keep-alive, Upgrade".
 */
bool Bu::ProtocolWebSocket::headerMatch( const Bu::String &sKey,
		const Bu::String &sValue )
{
	Bu::String sKeyLow = sKey.toLower();
	if( !hHeader.has( sKeyLow ) )
		return false;

	for( Bu::StringList::iterator i = hHeader.get( sKeyLow ).begin(); i; i++ )
	{
		if( headerListHasToken( (*i).getStr(), (*i).getSize(),
				sValue.getStr(), sValue.getSize() ) )
			return true;
	}

	return false;
}

/**
 * Parses one complete frame from the client's input, unmasks it if needed,
 * and hands the payload to onNewMessage().  Returns false (consuming
 * nothing) if the frame is not fully buffered yet, or if the frame is
 * invalid, in which case the client is disconnected.
 */
bool Bu::ProtocolWebSocket::parseMessage()
{
	DEBUG( Bu::println("websocket: Recieved message, input available: %1").arg( pClient->getInputSize() ) );

	if( pClient->getInputSize() < 2 )
		return false;

	uint8_t buf[8];
	int64_t iTgtLength = 2;
	pClient->peek( buf, 2, 0 );

	Operation eOp = (Operation)(buf[0]&0x0f);
	int64_t iLen = buf[1]&0x7f;
	bool bMasked = (buf[1]&0x80) == 0x80;

	if( iLen == 126 )
	{
		// 16-bit extended length, network (big-endian) byte order.
		if( pClient->getInputSize() < iTgtLength+2 )
			return false;
		pClient->peek( buf, 2, iTgtLength );
		iTgtLength += 2;
		iLen = (((int64_t)buf[0])<<8) | ((int64_t)buf[1]);
	}
	else if( iLen == 127 )
	{
		// 64-bit extended length, network (big-endian) byte order.
		if( pClient->getInputSize() < iTgtLength+8 )
			return false;
		pClient->peek( buf, 8, iTgtLength );
		iTgtLength += 8;
		// RFC 6455 5.2: the most significant bit must be 0.  Accepting it
		// would produce a negative length.
		if( buf[0]&0x80 )
		{
			Bu::MutexLocker l( mClient );
			pClient->disconnect();
			return false;
		}
		uint64_t uLen = 0;
		for( int j = 0; j < 8; j++ )
			uLen = (uLen<<8) | buf[j];
		iLen = (int64_t)uLen;
	}

	char cMask[4];
	if( bMasked )
	{
		if( pClient->getInputSize() < iTgtLength+4 )
			return false;
		pClient->peek( cMask, 4, iTgtLength );
		iTgtLength += 4;
	}

	// All checks above guarantee getInputSize() >= iTgtLength, so this
	// subtraction cannot underflow, and it avoids overflowing
	// iTgtLength+iLen for huge declared lengths.
	if( ((int64_t)pClient->getInputSize()) - iTgtLength < iLen )
		return false;

	pClient->seek( iTgtLength );
	Bu::String sData( iLen );
	char *pData = sData.getStr();
	int64_t iRead = 0;
	while( iRead < iLen )
	{
		int iChunk = pClient->read( pData+iRead, iLen-iRead );
		if( iChunk <= 0 )
		{
			// The data was reported available; a failed read would
			// otherwise spin here forever.
			Bu::MutexLocker l( mClient );
			pClient->disconnect();
			return false;
		}
		iRead += iChunk;
	}

	if( bMasked )
	{
		for( int64_t j = 0; j < iLen; j++ )
		{
			pData[j] = pData[j]^cMask[j%4];
		}
	}

	DEBUG( Bu::println("") );
	DEBUG( Bu::println("Data: >>%1<<").arg( sData ) );

	onNewMessage( sData, eOp );

	return true;
}