tcp communication better/simpler

This commit is contained in:
Zack Scholl 2019-04-27 09:20:03 -07:00
parent 7c731a90dc
commit 14dd892377
4 changed files with 196 additions and 62 deletions

View File

@ -2,10 +2,9 @@ package comm
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"strconv"
"strings"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -16,6 +15,19 @@ type Comm struct {
connection net.Conn connection net.Conn
} }
// NewConnection gets a new comm to a tcp address
func NewConnection(address string) (c Comm, err error) {
connection, err := net.DialTimeout("tcp", address, 3*time.Hour)
if err != nil {
return
}
connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
connection.SetDeadline(time.Now().Add(3 * time.Hour))
connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
c = New(connection)
return
}
// New returns a new comm // New returns a new comm
func New(c net.Conn) Comm { func New(c net.Conn) Comm {
c.SetReadDeadline(time.Now().Add(3 * time.Hour)) c.SetReadDeadline(time.Now().Add(3 * time.Hour))
@ -35,10 +47,12 @@ func (c Comm) Close() {
} }
func (c Comm) Write(b []byte) (int, error) { func (c Comm) Write(b []byte) (int, error) {
tmpCopy := make([]byte, len(b)+5) header := new(bytes.Buffer)
// Copy the buffer so it doesn't get changed while read by the recipient. err := binary.Write(header, binary.LittleEndian, uint32(len(b)))
copy(tmpCopy[:5], []byte(fmt.Sprintf("%0.5d", len(b)))) if err != nil {
copy(tmpCopy[5:], b) fmt.Println("binary.Write failed:", err)
}
tmpCopy := append(header.Bytes(), b...)
n, err := c.connection.Write(tmpCopy) n, err := c.connection.Write(tmpCopy)
if n != len(tmpCopy) { if n != len(tmpCopy) {
if err != nil { if err != nil {
@ -53,68 +67,48 @@ func (c Comm) Write(b []byte) (int, error) {
func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
// read until we get 5 bytes // read until we get 5 bytes
tmp := make([]byte, 5) header := make([]byte, 4)
n, err := c.connection.Read(tmp) n, err := c.connection.Read(header)
if err != nil { if err != nil {
return return
} }
tmpCopy := make([]byte, n) if n < 4 {
// Copy the buffer so it doesn't get changed while read by the recipient. err = fmt.Errorf("not enough bytes: %d", n)
copy(tmpCopy, tmp[:n]) return
bs = tmpCopy }
// make it so it won't change
header = append([]byte(nil), header...)
tmp = make([]byte, 1) var numBytesUint32 uint32
rbuf := bytes.NewReader(header)
err = binary.Read(rbuf, binary.LittleEndian, &numBytesUint32)
if err != nil {
fmt.Println("binary.Read failed:", err)
}
numBytes = int(numBytesUint32)
for { for {
// see if we have enough bytes tmp := make([]byte, numBytes)
bs = bytes.Trim(bs, "\x00") n, errRead := c.connection.Read(tmp)
if len(bs) == 5 { if errRead != nil {
break err = errRead
return
} }
n, err := c.connection.Read(tmp) buf = append(buf, tmp[:n]...)
if err != nil { if numBytes == len(buf) {
return nil, 0, nil, err
}
tmpCopy = make([]byte, n)
// Copy the buffer so it doesn't get changed while read by the recipient.
copy(tmpCopy, tmp[:n])
bs = append(bs, tmpCopy...)
}
numBytes, err = strconv.Atoi(strings.TrimLeft(string(bs), "0"))
if err != nil {
return nil, 0, nil, err
}
buf = []byte{}
tmp = make([]byte, numBytes)
for {
n, err := c.connection.Read(tmp)
if err != nil {
return nil, 0, nil, err
}
tmpCopy = make([]byte, n)
// Copy the buffer so it doesn't get changed while read by the recipient.
copy(tmpCopy, tmp[:n])
buf = append(buf, bytes.TrimRight(tmpCopy, "\x00")...)
if len(buf) < numBytes {
// shrink the amount we need to read
tmp = tmp[:numBytes-len(buf)]
} else {
break break
} }
} }
// log.Printf("wanted %d and got %d", numBytes, len(buf))
return return
} }
// Send a message // Send a message
func (c Comm) Send(message string) (err error) { func (c Comm) Send(message []byte) (err error) {
_, err = c.Write([]byte(message)) _, err = c.Write(message)
return return
} }
// Receive a message // Receive a message
func (c Comm) Receive() (s string, err error) { func (c Comm) Receive() (b []byte, err error) {
b, _, _, err := c.Read() b, _, _, err = c.Read()
s = string(b)
return return
} }

52
src/comm/comm_test.go Normal file
View File

@ -0,0 +1,52 @@
package comm
import (
"net"
"testing"
"time"
log "github.com/cihub/seelog"
"github.com/stretchr/testify/assert"
)
func TestComm(t *testing.T) {
defer log.Flush()
port := "8001"
go func() {
log.Debugf("starting TCP server on " + port)
server, err := net.Listen("tcp", "0.0.0.0:"+port)
if err != nil {
log.Error(err)
}
defer server.Close()
// spawn a new goroutine whenever a client connects
for {
connection, err := server.Accept()
if err != nil {
log.Error(err)
}
log.Debugf("client %s connected", connection.RemoteAddr().String())
go func(port string, connection net.Conn) {
c := New(connection)
err = c.Send([]byte("hello, world"))
assert.Nil(t, err)
data, err := c.Receive()
assert.Nil(t, err)
assert.Equal(t, []byte("hello, computer"), data)
data, err = c.Receive()
assert.Nil(t, err)
assert.Equal(t, []byte{'\x00'}, data)
}(port, connection)
}
}()
time.Sleep(100 * time.Millisecond)
a, err := NewConnection("localhost:" + port)
assert.Nil(t, err)
data, err := a.Receive()
assert.Equal(t, []byte("hello, world"), data)
assert.Nil(t, err)
assert.Nil(t, a.Send([]byte("hello, computer")))
assert.Nil(t, a.Send([]byte{'\x00'}))
}

View File

@ -13,8 +13,10 @@ import (
) )
type roomInfo struct { type roomInfo struct {
receiver comm.Comm first comm.Comm
second comm.Comm
opened time.Time opened time.Time
full bool
} }
type roomMap struct { type roomMap struct {
@ -77,37 +79,53 @@ func run(port string) (err error) {
func clientCommuncation(port string, c comm.Comm) (err error) { func clientCommuncation(port string, c comm.Comm) (err error) {
// send ok to tell client they are connected // send ok to tell client they are connected
log.Debug("sending ok") log.Debug("sending ok")
err = c.Send("ok") err = c.Send([]byte("ok"))
if err != nil { if err != nil {
return return
} }
// wait for client to tell me which room they want // wait for client to tell me which room they want
log.Debug("waiting for answer") log.Debug("waiting for answer")
room, err := c.Receive() roomBytes, err := c.Receive()
if err != nil { if err != nil {
return return
} }
room := string(roomBytes)
rooms.Lock() rooms.Lock()
// first connection is always the receiver // create the room if it is new
if _, ok := rooms.rooms[room]; !ok { if _, ok := rooms.rooms[room]; !ok {
rooms.rooms[room] = roomInfo{ rooms.rooms[room] = roomInfo{
receiver: c, first: c,
opened: time.Now(), opened: time.Now(),
} }
rooms.Unlock() rooms.Unlock()
// tell the client that they got the room // tell the client that they got the room
err = c.Send("recipient") err = c.Send([]byte("ok"))
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return
} }
log.Debug("recipient connected") log.Debugf("room %s has 1", room)
return nil return nil
} }
log.Debug("sender connected") if rooms.rooms[room].full {
receiver := rooms.rooms[room].receiver rooms.Unlock()
err = c.Send([]byte("room full"))
if err != nil {
log.Error(err)
return
}
return nil
}
log.Debugf("room %s has 2", room)
rooms.rooms[room] = roomInfo{
first: rooms.rooms[room].first,
second: c,
opened: rooms.rooms[room].opened,
full: true,
}
otherConnection := rooms.rooms[room].first
rooms.Unlock() rooms.Unlock()
// second connection is the sender, time to staple connections // second connection is the sender, time to staple connections
@ -120,10 +138,10 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
pipe(com1.Connection(), com2.Connection()) pipe(com1.Connection(), com2.Connection())
wg.Done() wg.Done()
log.Debug("done piping") log.Debug("done piping")
}(c, receiver, &wg) }(otherConnection, c, &wg)
// tell the sender everything is ready // tell the sender everything is ready
err = c.Send("sender") err = c.Send([]byte("ok"))
if err != nil { if err != nil {
return return
} }
@ -132,6 +150,8 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
// delete room // delete room
rooms.Lock() rooms.Lock()
log.Debugf("deleting room: %s", room) log.Debugf("deleting room: %s", room)
rooms.rooms[room].first.Close()
rooms.rooms[room].second.Close()
delete(rooms.rooms, room) delete(rooms.rooms, room)
rooms.Unlock() rooms.Unlock()
return nil return nil

68
src/tcp/tcp_test.go Normal file
View File

@ -0,0 +1,68 @@
package tcp
import (
"bytes"
"fmt"
"testing"
"time"
"github.com/schollz/croc/src/comm"
"github.com/stretchr/testify/assert"
)
func TestTCP(t *testing.T) {
go Run("debug", "8081")
time.Sleep(100 * time.Millisecond)
c1, err := ConnectToTCPServer("localhost:8081", "testRoom")
assert.Nil(t, err)
c2, err := ConnectToTCPServer("localhost:8081", "testRoom")
assert.Nil(t, err)
_, err = ConnectToTCPServer("localhost:8081", "testRoom")
assert.NotNil(t, err)
// try sending data
assert.Nil(t, c1.Send([]byte("hello, c2")))
data, err := c2.Receive()
assert.Nil(t, err)
assert.Equal(t, []byte("hello, c2"), data)
assert.Nil(t, c2.Send([]byte("hello, c1")))
data, err = c1.Receive()
assert.Nil(t, err)
assert.Equal(t, []byte("hello, c1"), data)
c1.Close()
time.Sleep(200 * time.Millisecond)
err = c2.Send([]byte("test"))
assert.Nil(t, err)
_, err = c2.Receive()
assert.NotNil(t, err)
}
func ConnectToTCPServer(address, room string) (c comm.Comm, err error) {
c, err = comm.NewConnection("localhost:8081")
if err != nil {
return
}
data, err := c.Receive()
if err != nil {
return
}
if !bytes.Equal(data, []byte("ok")) {
err = fmt.Errorf("got bad response: %s", data)
return
}
err = c.Send([]byte(room))
if err != nil {
return
}
data, err = c.Receive()
if err != nil {
return
}
if !bytes.Equal(data, []byte("ok")) {
err = fmt.Errorf("got bad response: %s", data)
return
}
return
}