fork-server-process-and-introduce-broadcast-api.patch

Signed-off-by: Andy Green <andy@warmcat.com>
diff --git a/lib/libwebsockets.c b/lib/libwebsockets.c
index fe6408d..bfc5615 100644
--- a/lib/libwebsockets.c
+++ b/lib/libwebsockets.c
@@ -54,7 +54,7 @@
  * 
  * 	LWS_CALLBACK_CLOSED: when the websocket session ends
  *
- * 	LWS_CALLBACK_SEND: opportunity to send to client (you would use
+ * 	LWS_CALLBACK_BROADCAST: signal to send to client (you would use
  * 				libwebsocket_write() taking care about the
  * 				special buffer requirements
  * 	LWS_CALLBACK_RECEIVE: data has appeared for the server, it can be
@@ -77,7 +77,12 @@
 void 
 libwebsocket_close_and_free_session(struct libwebsocket *wsi)
 {
-	int n = wsi->state;
+	int n;
+	/* broadcast socket slots hold a protocol index (or NULL), not a wsi */
+	if ((unsigned long)wsi < LWS_MAX_PROTOCOLS)
+		return;
+
+	n = wsi->state;
 
 	wsi->state = WSI_STATE_DEAD_SOCKET;
 
@@ -110,12 +115,136 @@
 	free(wsi);
 }
 
+/*
+ * libwebsocket_poll_connections() - service poll() results on the
+ *	per-protocol broadcast sockets and the websocket connection sockets
+ * @this:	server context whose fds[] / wsi[] arrays were just polled
+ *
+ *	A wsi[] entry below LWS_MAX_PROTOCOLS is the protocol index of an
+ *	accepted broadcast socket, not a pointer; data read from it is given
+ *	to every established connection of that protocol as a
+ *	LWS_CALLBACK_BROADCAST.  At most one dead entry is removed per call;
+ *	anything else still pending is reported again by the next poll().
+ */
+static int
+libwebsocket_poll_connections(struct libwebsocket_context *this)
+{
+	unsigned char buf[LWS_SEND_BUFFER_PRE_PADDING + MAX_BROADCAST_PAYLOAD +
+						  LWS_SEND_BUFFER_POST_PADDING];
+	int client;
+	int n;
+	ssize_t len;	/* signed so that read() errors (-1) are detectable */
+
+	/* check for activity on client sockets */
+
+	for (client = this->count_protocols + 1; client < this->fds_count;
+								     client++) {
+
+		/* handle session socket closed */
+
+		if (this->fds[client].revents & (POLLERR | POLLHUP)) {
+			debug("Session Socket %d %p (fd=%d) dead\n",
+			      client, this->wsi[client], this->fds[client].fd);
+			if ((unsigned long)this->wsi[client] < LWS_MAX_PROTOCOLS)
+				close(this->fds[client].fd); /* no wsi to free */
+			else
+				libwebsocket_close_and_free_session(this->wsi[client]);
+			goto nuke_this;
+		}
+
+		/* any incoming data ready? */
+
+		if (!(this->fds[client].revents & POLLIN))
+			continue;
+
+		/* broadcast? */
+
+		if ((unsigned long)this->wsi[client] < LWS_MAX_PROTOCOLS) {
+			len = read(this->fds[client].fd,
+				   buf + LWS_SEND_BUFFER_PRE_PADDING,
+				   MAX_BROADCAST_PAYLOAD);
+			if (len < 0) {
+				fprintf(stderr, "Error receiving broadcast payload\n");
+				continue;
+			}
+			if (!len) { /* sender closed its end: stop polling it */
+				close(this->fds[client].fd);
+				goto nuke_this;
+			}
+
+			/* broadcast it to all guys with this protocol index */
+
+			for (n = this->count_protocols + 1;
+						     n < this->fds_count; n++) {
+				if ((unsigned long)this->wsi[n] <
+							      LWS_MAX_PROTOCOLS)
+					continue;
+				/* never broadcast to non-established connection */
+				if (this->wsi[n]->state != WSI_STATE_ESTABLISHED)
+					continue;
+				/* only broadcast to users of the requested protocol */
+				if (this->wsi[n]->protocol->protocol_index !=
+					       (unsigned long)this->wsi[client])
+					continue;
+
+				this->wsi[n]->protocol->callback(this->wsi[n],
+						 LWS_CALLBACK_BROADCAST,
+						 this->wsi[n]->user_space,
+						 buf + LWS_SEND_BUFFER_PRE_PADDING, len);
+			}
+
+			continue;
+		}
+
+#ifdef LWS_OPENSSL_SUPPORT
+		if (this->use_ssl)
+			n = SSL_read(this->wsi[client]->ssl, buf, sizeof buf);
+		else
+#endif
+			n = recv(this->fds[client].fd, buf, sizeof buf, 0);
+
+		if (n < 0) {
+			fprintf(stderr, "Socket read returned %d\n", n);
+			continue;
+		}
+		if (!n) {
+			/* orderly shutdown by the peer */
+			libwebsocket_close_and_free_session(this->wsi[client]);
+			goto nuke_this;
+		}
+
+		/* service incoming data */
+
+		if (libwebsocket_read(this->wsi[client], buf, n) >= 0)
+			continue;
+
+		/*
+		 * it closed and nuked wsi[client], so remove the
+		 * socket handle and wsi from our service list
+		 */
+nuke_this:
+
+		debug("nuking wsi %p, fsd_count = %d\n",
+					this->wsi[client], this->fds_count - 1);
+
+		this->fds_count--;
+		for (n = client; n < this->fds_count; n++) {
+			this->fds[n] = this->fds[n + 1];
+			this->wsi[n] = this->wsi[n + 1];
+		}
+		break;
+	}
+
+	return 0;
+}
+
 /**
  * libwebsocket_create_server() - Create the listening websockets server
  * @port:	Port to listen on
  * @protocols:	Array of structures listing supported protocols and a protocol-
  * 		specific callback for each one.  The list is ended with an
  * 		entry that has a NULL callback pointer.
+ * 		It's not const because we write owning_server, protocol_index etc.
  * @ssl_cert_filepath:	If libwebsockets was compiled to use ssl, and you want
  * 			to listen using SSL, set to the filepath to fetch the
  * 			server cert from, otherwise NULL for unencrypted
@@ -127,10 +256,10 @@
  * 	This function creates the listening socket and takes care
  * 	of all initialization in one step.
  *
- * 	It does not return since it sits in a service loop and operates via the
- * 	callbacks given in @protocol.  User code should fork before calling
- * 	libwebsocket_create_server() if it wants to do other things in
- * 	parallel other than serve websockets.
+ * 	After initialization it fork()s a child process that sits in the service
+ * 	loop, while the original process returns to the caller.  The actual
+ * 	service actions are performed by user code in the per-protocol callback
+ * 	selected by the client from the list in @protocols.
  * 
  * 	The protocol callback functions are called for a handful of events
  * 	including http requests coming in, websocket connections becoming
@@ -150,7 +279,7 @@
  */
 
 int libwebsocket_create_server(int port,
-			       const struct libwebsocket_protocols *protocols,
+			       struct libwebsocket_protocols *protocols,
 			       const char * ssl_cert_filepath,
 			       const char * ssl_private_key_filepath,
 			       int gid, int uid)
@@ -161,11 +290,9 @@
 	int fd;
 	unsigned int clilen;
 	struct sockaddr_in serv_addr, cli_addr;
-	struct libwebsocket *wsi[MAX_CLIENTS + 1];
-	struct pollfd fds[MAX_CLIENTS + 1];
-	int fds_count = 0;
-	unsigned char buf[1024];
 	int opt = 1;
+	struct libwebsocket_context * this = NULL;
+	socklen_t slen;		/* getsockname() takes a socklen_t * */
 
 #ifdef LWS_OPENSSL_SUPPORT
 	SSL_METHOD *method;
@@ -195,7 +322,7 @@
 			// Firefox insists on SSLv23 not SSLv3
 			// Konq disables SSLv2 by default now, SSLv23 works
 
-		method = SSLv23_server_method();   // create server instance
+		method = (SSL_METHOD *)SSLv23_server_method();
 		if (!method) {
 			fprintf(stderr, "problem creating ssl method: %s\n",
 				ERR_error_string(ERR_get_error(), ssl_err_buf));
@@ -234,6 +361,10 @@
 		/* SSL is happy and has a cert it's content with */
 	}
 #endif
+	this = malloc(sizeof *this);
+	if (!this)
+		return -1;	/* out of memory */
+	/* set up our external listening socket we serve on */
   
 	sockfd = socket(AF_INET, SOCK_STREAM, 0);
 	if (sockfd < 0) {
@@ -255,7 +386,7 @@
 									 errno);
               return -1;
         }
- 
+
 	/* drop any root privs for this process */
 
 	if (gid != -1)
@@ -266,65 +397,198 @@
 			fprintf(stderr, "setuid: %s\n", strerror(errno));
 
  	/*
-	 * sit there listening for connects, accept and service connections
-	 * in a poll loop, without any forking
+	 * prepare the poll() fd array... it's like this
+	 *
+	 * [0] = external listening socket
+	 * [1 .. this->count_protocols] = per-protocol broadcast sockets
+	 * [this->count_protocols + 1 ... this->fds_count-1] = connection skts
 	 */
 
+	this->fds_count = 1;
+	this->fds[0].fd = sockfd;
+	this->fds[0].events = POLLIN;
+	this->count_protocols = 0;
+#ifdef LWS_OPENSSL_SUPPORT
+	this->use_ssl = use_ssl;
+#endif
+
 	listen(sockfd, 5);
 	fprintf(stderr, " Listening on port %d\n", port);
- 	
-	fds[0].fd = sockfd;
-	fds_count = 1;
-	fds[0].events = POLLIN;
+
+	/* set up our internal broadcast trigger sockets per-protocol */
+
+	for (; protocols[this->count_protocols].callback;
+						      this->count_protocols++) {
+		protocols[this->count_protocols].owning_server = this;
+		protocols[this->count_protocols].protocol_index =
+							  this->count_protocols;
+
+		fd = socket(AF_INET, SOCK_STREAM, 0);
+		if (fd < 0) {
+			fprintf(stderr, "ERROR opening socket");
+			return -1;
+		}
+		
+		/* allow us to restart even if old sockets in TIME_WAIT */
+		setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
+
+		bzero((char *) &serv_addr, sizeof(serv_addr));
+		serv_addr.sin_family = AF_INET;
+		serv_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
+		serv_addr.sin_port = 0; /* pick the port for us */
+
+		n = bind(fd, (struct sockaddr *) &serv_addr, sizeof(serv_addr));
+		if (n < 0) {
+		      fprintf(stderr, "ERROR on binding to port %d (%d %d)\n",
+								port, n, errno);
+		      return -1;
+		}
+
+		slen = sizeof cli_addr;
+		n = getsockname(fd, (struct sockaddr *)&cli_addr, &slen);
+		if (n < 0) {
+			fprintf(stderr, "getsockname failed\n");
+			return -1;
+		}
+		protocols[this->count_protocols].broadcast_socket_port =
+						       ntohs(cli_addr.sin_port);
+		listen(fd, 5);
+
+		debug("  Protocol %s broadcast socket %d\n",
+				protocols[this->count_protocols].name,
+						      ntohs(cli_addr.sin_port));
+
+		this->fds[this->fds_count].fd = fd;
+		this->fds[this->fds_count].events = POLLIN;
+		/* wsi only exists for connections, not broadcast listener */
+		this->wsi[this->fds_count] = NULL;
+		this->fds_count++;
+	}
+
+
+	/*
+	 * We will enter out poll and service loop now, just before that
+	 * fork and return to caller for the main thread of execution
+	 */
+
+	n = fork();
+	if (n < 0) {
+		fprintf(stderr, "Failed to fork websocket poll loop\n");
+		return -1;
+	}
+	if (n) {
+		/* original process context */
+
+		/*
+		 * before we return to caller, we set up per-protocol
+		 * broadcast sockets connected to the server ready to use.
+		 * No startup delay is needed: each broadcast socket was
+		 * already listen()ing before fork(), so connect() below is
+		 * queued by the kernel even if the child hasn't run yet.
+		 */
+
+		for (client = 1; client < this->count_protocols + 1; client++) {
+			fd = socket(AF_INET, SOCK_STREAM, 0);
+			if (fd < 0) {
+				fprintf(stderr,"Unable to create socket\n");
+				return -1;
+			}
+			cli_addr.sin_family = AF_INET;
+			cli_addr.sin_port = htons(
+				   protocols[client - 1].broadcast_socket_port);
+			cli_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
+			n = connect(fd, (struct sockaddr *)&cli_addr,
+							       sizeof cli_addr);
+			if (n < 0) {
+				fprintf(stderr, "Unable to connect to "
+						"broadcast socket %d, %s\n",
+						client, strerror(errno));
+				return -1;
+			}
+
+			protocols[client - 1].broadcast_socket_user_fd = fd;
+		}
+
+		fprintf(stderr, "libwebsocket poll process forked\n");
+		
+		return 0;
+	}
+
+	/* we want a SIGHUP when our parent goes down */
+	prctl(PR_SET_PDEATHSIG, SIGHUP);
+
+	/* in this forked process, sit and service websocket connections */
     
 	while (1) {
 
- 		n = poll(fds, fds_count, 50);
-		if (n < 0 || fds[0].revents & (POLLERR | POLLHUP)) {
+		n = poll(this->fds, this->fds_count, 1000);
+
+		if (n < 0 || this->fds[0].revents & (POLLERR | POLLHUP)) {
 			fprintf(stderr, "Listen Socket dead\n");
 			goto fatal;
 		}
 		if (n == 0) /* poll timeout */
-			goto poll_out;
+			continue;
 
-		if (fds[0].revents & POLLIN) {
+		/* accept on the main listener or a per-protocol broadcast one */
+
+		for (client = 0; client < this->count_protocols + 1; client++) {
+
+			if (!(this->fds[client].revents & POLLIN))
+				continue;
 
 			/* listen socket got an unencrypted connection... */
 
 			clilen = sizeof(cli_addr);
-			fd  = accept(sockfd,
-				     (struct sockaddr *)&cli_addr,
-							       &clilen);
+			fd  = accept(this->fds[client].fd,
+				     (struct sockaddr *)&cli_addr, &clilen);
 			if (fd < 0) {
 				fprintf(stderr, "ERROR on accept");
 				continue;
 			}
 
-			if (fds_count >= MAX_CLIENTS) {
+			if (this->fds_count >= MAX_CLIENTS) {
 				fprintf(stderr, "too busy");
 				close(fd);
 				continue;
 			}
 
-			wsi[fds_count] = malloc(sizeof(struct libwebsocket));
-			if (!wsi[fds_count])
+			if (client) {
+				/*
+				 * accepting a connection to broadcast socket
+				 * set wsi to be protocol index not pointer
+				 */
+
+				this->wsi[this->fds_count] =
+				      (struct libwebsocket *)(long)(client - 1);
+
+				goto fill_in_fds;
+			}
+
+			/* accepting connection to main listener */
+
+			this->wsi[this->fds_count] =
+					    malloc(sizeof(struct libwebsocket));
+			if (!this->wsi[this->fds_count])
 				return -1;
 
-#ifdef LWS_OPENSSL_SUPPORT
-			if (use_ssl) {
+#ifdef LWS_OPENSSL_SUPPORT
+			if (this->use_ssl) {
 
-				wsi[fds_count]->ssl = SSL_new(ssl_ctx);
-				if (wsi[fds_count]->ssl == NULL) {
+				this->wsi[this->fds_count]->ssl =
+							       SSL_new(ssl_ctx);
+				if (this->wsi[this->fds_count]->ssl == NULL) {
 					fprintf(stderr, "SSL_new failed: %s\n",
 					    ERR_error_string(SSL_get_error(
-					        wsi[fds_count]->ssl, 0), NULL));
-					free(wsi[fds_count]);
+					    this->wsi[this->fds_count]->ssl, 0),
+									 NULL));
+					free(this->wsi[this->fds_count]);
 					continue;
 				}
 
-				SSL_set_fd(wsi[fds_count]->ssl, fd);
+				SSL_set_fd(this->wsi[this->fds_count]->ssl, fd);
 
-				n = SSL_accept(wsi[fds_count]->ssl);
+				n = SSL_accept(this->wsi[this->fds_count]->ssl);
 				if (n != 1) {
 					/*
 					 * browsers seem to probe with various
@@ -332,32 +596,37 @@
 					 * and succeed
 					 */
 					debug("SSL_accept failed skt %u: %s\n",
-						fd,
-						ERR_error_string(SSL_get_error(
-						wsi[fds_count]->ssl, n), NULL));
-					SSL_free(wsi[fds_count]->ssl);
-					free(wsi[fds_count]);
+					      fd,
+					      ERR_error_string(SSL_get_error(
+					      this->wsi[this->fds_count]->ssl,
+								     n), NULL));
+					SSL_free(
+					       this->wsi[this->fds_count]->ssl);
+					free(this->wsi[this->fds_count]);
 					continue;
 				}
 				debug("accepted new SSL conn  "
 				      "port %u on fd=%d SSL ver %s\n",
 					ntohs(cli_addr.sin_port), fd,
-					  SSL_get_version(wsi[fds_count]->ssl));
+					  SSL_get_version(this->wsi[
+							this->fds_count]->ssl));
 				
 			} else
-#endif
+#endif
 				debug("accepted new conn  port %u on fd=%d\n",
 						  ntohs(cli_addr.sin_port), fd);
-			
+				
 			/* intialize the instance struct */
 
-			wsi[fds_count]->sock = fd;
-			wsi[fds_count]->state = WSI_STATE_HTTP;
-			wsi[fds_count]->name_buffer_pos = 0;
+			this->wsi[this->fds_count]->sock = fd;
+			this->wsi[this->fds_count]->state = WSI_STATE_HTTP;
+			this->wsi[this->fds_count]->name_buffer_pos = 0;
 
 			for (n = 0; n < WSI_TOKEN_COUNT; n++) {
-				wsi[fds_count]->utf8_token[n].token = NULL;
-				wsi[fds_count]->utf8_token[n].token_len = 0;
+				this->wsi[this->fds_count]->
+						     utf8_token[n].token = NULL;
+				this->wsi[this->fds_count]->
+						    utf8_token[n].token_len = 0;
 			}
 
 			/*
@@ -366,8 +635,8 @@
 			 * to the start of the supported list, so it can look
 			 * for matching ones during the handshake
 			 */
-			wsi[fds_count]->protocol = protocols;
-			wsi[fds_count]->user_space = NULL;
+			this->wsi[this->fds_count]->protocol = protocols;
+			this->wsi[this->fds_count]->user_space = NULL;
 
 			/*
 			 * Default protocol is 76
@@ -375,107 +644,127 @@
 			 * draft the client wants, when that's seen we modify
 			 * the individual connection's spec revision accordingly
 			 */
-			wsi[fds_count]->ietf_spec_revision = 76;
+			this->wsi[this->fds_count]->ietf_spec_revision = 76;
 
-			fds[fds_count].events = POLLIN;
-			fds[fds_count++].fd = fd;
+fill_in_fds:
 
 			/*
 			 * make sure NO events are seen yet on this new socket
 			 * (otherwise we inherit old fds[client].revents from
 			 * previous socket there and die mysteriously! )
 			 */
-			fds[client].revents = 0;
-		}
-		
-		/* check for activity on client sockets */
-		
-		for (client = 1; client < fds_count; client++) {
-			
-			/* handle session socket closed */
-			
-			if (fds[client].revents & (POLLERR | POLLHUP)) {
-				
-				debug("Session Socket %d %p (fd=%d) dead\n",
-					      client, wsi[client], fds[client]);
+			this->fds[this->fds_count].revents = 0;
 
-				libwebsocket_close_and_free_session(
-								   wsi[client]);
-				goto nuke_this;
-			}
-			
-			/* any incoming data ready? */
+			this->fds[this->fds_count].events = POLLIN;
+			this->fds[this->fds_count++].fd = fd;
 
-			if (!(fds[client].revents & POLLIN))
-				continue;
-
-#ifdef LWS_OPENSSL_SUPPORT
-			if (use_ssl)
-				n = SSL_read(wsi[client]->ssl, buf, sizeof buf);
-			else
-#endif
-				n = recv(fds[client].fd, buf, sizeof(buf), 0);
-
-			if (n < 0) {
-				fprintf(stderr, "Socket read returned %d\n", n);
-				continue;
-			}
-			if (!n) {
-//				fprintf(stderr, "POLLIN with 0 len waiting\n");
-				libwebsocket_close_and_free_session(
-								   wsi[client]);
-				goto nuke_this;
-			}
-			
-			/* service incoming data */
-
-			if (libwebsocket_read(wsi[client], buf, n) >= 0)
-				continue;
-			
-			/*
-			 * it closed and nuked wsi[client], so remove the
-			 * socket handle and wsi from our service list
-			 */
-nuke_this:
-
-			debug("nuking wsi %p, fsd_count = %d\n",
-						   wsi[client], fds_count - 1);
-
-			fds_count--;
-			for (n = client; n < fds_count; n++) {
-				fds[n] = fds[n + 1];
-				wsi[n] = wsi[n + 1];
-			}
-			break;
 		}
 
-poll_out:		
-		for (client = 1; client < fds_count; client++) {
 
-			if (wsi[client]->state != WSI_STATE_ESTABLISHED)
-				continue;
+		/* service anything incoming on websocket connection */
 
-			wsi[client]->protocol->callback(wsi[client],
-							LWS_CALLBACK_SEND, 
-							wsi[client]->user_space,
-								       NULL, 0);
-		}
-		
-		continue;		
+		libwebsocket_poll_connections(this);
 	}
 	
 fatal:
-	/* listening socket */
-	close(fds[0].fd);
-	for (client = 1; client < fds_count; client++)
-		libwebsocket_close_and_free_session(wsi[client]);
+
+	/* close the listening, broadcast and connection sockets */
+	for (client = 0; client < this->fds_count; client++)
+		close(this->fds[client].fd);
 
 #ifdef LWS_OPENSSL_SUPPORT
 	SSL_CTX_free(ssl_ctx);
 #endif
 	kill(0, SIGTERM);
+
+	if (this)
+		free(this);
 	
 	return 0;
 }
 
+/**
+ * libwebsockets_get_protocol() - Returns a protocol pointer from a websocket
+ * 				  connection.
+ * @wsi:	pointer to the struct libwebsocket whose protocol you want
+ *
+ *	This is useful inside a callback to find the protocol to pass to
+ *	libwebsockets_broadcast() when broadcasting back to every connection
+ *	using that protocol.
+ */
 
+const struct libwebsocket_protocols *
+libwebsockets_get_protocol(struct libwebsocket *wsi)
+{
+	return wsi->protocol;
+}
+
+/**
+ * libwebsockets_broadcast() - Sends a buffer to the callback for all active
+ * 				  connections of the given protocol.
+ * @protocol:	pointer to the protocol you will broadcast to all members of
+ * @buf:  buffer containing the data to be broadcast.  NOTE: this has to be
+ * 		allocated with LWS_SEND_BUFFER_PRE_PADDING valid bytes before
+ * 		the pointer and LWS_SEND_BUFFER_POST_PADDING afterwards in the
+ * 		case you are calling this function from callback context.
+ * @len:	length of payload data in buf, starting from buf.  Outside
+ *		callback context it must not exceed MAX_BROADCAST_PAYLOAD.
+ *
+ * 	This function allows bulk sending of a packet to every connection using
+ * the given protocol.  It does not send the data directly; instead it calls
+ * the callback with a reason type of LWS_CALLBACK_BROADCAST.  If the callback
+ * wants to actually send the data for that connection, the callback itself
+ * should call libwebsocket_write().
+ *
+ * Outside callback context (ie, in the process that called
+ * libwebsocket_create_server()) the data goes to the poll process over a
+ * local stream socket, so back-to-back broadcasts may be delivered merged.
+ * Returns send()'s result there (-1 if @len is too big), else 0.
+ */
+int
+libwebsockets_broadcast(const struct libwebsocket_protocols * protocol,
+						 unsigned char *buf, size_t len)
+{
+	struct libwebsocket_context * this = protocol->owning_server;
+	int n;
+
+	if (!protocol->broadcast_socket_user_fd) {
+		/*
+		 * broadcast_socket_user_fd is only filled in by the parent
+		 * after fork(), so zero means we are in the forked poll
+		 * process (eg, called from a callback), which has no
+		 * connected broadcast socket: deliver directly instead.
+		 * (Assumes the caller's protocols[] left this member zeroed.)
+		 *
+		 * No locking is needed: the poll process services everything,
+		 * including the callback that called us, from one loop, so
+		 * this is serialized with the rest of the service actions.
+		 */
+		for (n = this->count_protocols + 1; n < this->fds_count; n++) {
+			/* skip broadcast sockets: their wsi is only an index */
+			if ((unsigned long)this->wsi[n] < LWS_MAX_PROTOCOLS)
+				continue;
+
+			/* never broadcast to non-established connection */
+			if (this->wsi[n]->state != WSI_STATE_ESTABLISHED)
+				continue;
+
+			/* only broadcast to guys using requested protocol */
+			if (this->wsi[n]->protocol != protocol)
+				continue;
+
+			this->wsi[n]->protocol->callback(this->wsi[n],
+					 LWS_CALLBACK_BROADCAST,
+					 this->wsi[n]->user_space,
+					 buf, len);
+		}
+		return 0;
+	}
+
+	/* the poll process reads at most MAX_BROADCAST_PAYLOAD per wakeup */
+	if (len > MAX_BROADCAST_PAYLOAD)
+		return -1;
+
+	n = send(protocol->broadcast_socket_user_fd, buf, len, 0);
+	return n;
+}