move reading to goroutine

This commit is contained in:
Zack Scholl 2018-10-22 19:11:43 -07:00
parent db42e96b7e
commit ab5df93d10
3 changed files with 44 additions and 18 deletions

View File

@ -67,9 +67,9 @@ func TestSendReceiveLocalWebsockets(t *testing.T) {
sendAndReceive(t, 1, true) sendAndReceive(t, 1, true)
} }
// func TestSendReceiveLocalTCP(t *testing.T) { func TestSendReceiveLocalTCP(t *testing.T) {
// sendAndReceive(t, 2, true) sendAndReceive(t, 2, true)
// } }
func generateRandomFile(megabytes int) (fname string) { func generateRandomFile(megabytes int) (fname string) {
// generate a random file // generate a random file

View File

@ -82,9 +82,26 @@ func (cr *Croc) receive(forceSend int, serverAddress string, tcpPorts []string,
// both parties should have a weak key // both parties should have a weak key
pw := []byte(codephrase) pw := []byte(codephrase)
// start the reader
websocketMessages := make(chan WebSocketMessage, 1024)
go func() {
defer func() {
if r := recover(); r != nil {
log.Debugf("recovered from %s", r)
}
}()
for {
messageType, message, err := c.ReadMessage()
websocketMessages <- WebSocketMessage{messageType, message, err}
}
}()
step := 0 step := 0
for { for {
messageType, message, err := c.ReadMessage() websocketMessage := <-websocketMessages
messageType := websocketMessage.messageType
message := websocketMessage.message
err := websocketMessage.err
if err != nil { if err != nil {
return err return err
} }
@ -151,6 +168,7 @@ func (cr *Croc) receive(forceSend int, serverAddress string, tcpPorts []string,
// initialize TCP connections if using (possible, but unlikely, race condition) // initialize TCP connections if using (possible, but unlikely, race condition)
go func() { go func() {
log.Debug("initializing TCP connections")
if !useWebsockets { if !useWebsockets {
log.Debugf("connecting to server") log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts)) tcpConnections = make([]comm.Comm, len(tcpPorts))
@ -406,29 +424,20 @@ func (cr *Croc) receive(forceSend int, serverAddress string, tcpPorts []string,
startTime := time.Now() startTime := time.Now()
if useWebsockets { if useWebsockets {
for { for {
var messageType int
// read from websockets // read from websockets
messageType, message, err = c.ReadMessage() websocketMessageData := <-websocketMessages
if messageType != websocket.BinaryMessage { if websocketMessageData.messageType != websocket.BinaryMessage {
continue continue
} }
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return err return err
} }
if bytes.Equal(message, []byte("magic")) { if bytes.Equal(websocketMessageData.message, []byte("magic")) {
log.Debug("got magic") log.Debug("got magic")
break break
} }
dataChan <- message dataChan <- websocketMessageData.message
// select {
// case dataChan <- message:
// default:
// log.Debug("blocked")
// // no message sent
// // block
// dataChan <- message
// }
} }
} else { } else {
log.Debugf("starting listening with tcp with %d connections", len(tcpConnections)) log.Debugf("starting listening with tcp with %d connections", len(tcpConnections))

View File

@ -102,9 +102,26 @@ func (cr *Croc) send(forceSend int, serverAddress string, tcpPorts []string, isL
return return
} }
// start the reader
websocketMessages := make(chan WebSocketMessage, 1024)
go func() {
defer func() {
if r := recover(); r != nil {
log.Debugf("recovered from %s", r)
}
}()
for {
messageType, message, err := c.ReadMessage()
websocketMessages <- WebSocketMessage{messageType, message, err}
}
}()
step := 0 step := 0
for { for {
messageType, message, errRead := c.ReadMessage() websocketMessage := <-websocketMessages
messageType := websocketMessage.messageType
message := websocketMessage.message
errRead := websocketMessage.err
if errRead != nil { if errRead != nil {
return errRead return errRead
} }