mirror of https://github.com/schollz/croc.git
tcp communication better/simpler
This commit is contained in:
parent
7c731a90dc
commit
14dd892377
|
@ -2,10 +2,9 @@ package comm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -16,6 +15,19 @@ type Comm struct {
|
||||||
connection net.Conn
|
connection net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewConnection gets a new comm to a tcp address
|
||||||
|
func NewConnection(address string) (c Comm, err error) {
|
||||||
|
connection, err := net.DialTimeout("tcp", address, 3*time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
|
||||||
|
connection.SetDeadline(time.Now().Add(3 * time.Hour))
|
||||||
|
connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
|
||||||
|
c = New(connection)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// New returns a new comm
|
// New returns a new comm
|
||||||
func New(c net.Conn) Comm {
|
func New(c net.Conn) Comm {
|
||||||
c.SetReadDeadline(time.Now().Add(3 * time.Hour))
|
c.SetReadDeadline(time.Now().Add(3 * time.Hour))
|
||||||
|
@ -35,10 +47,12 @@ func (c Comm) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Comm) Write(b []byte) (int, error) {
|
func (c Comm) Write(b []byte) (int, error) {
|
||||||
tmpCopy := make([]byte, len(b)+5)
|
header := new(bytes.Buffer)
|
||||||
// Copy the buffer so it doesn't get changed while read by the recipient.
|
err := binary.Write(header, binary.LittleEndian, uint32(len(b)))
|
||||||
copy(tmpCopy[:5], []byte(fmt.Sprintf("%0.5d", len(b))))
|
if err != nil {
|
||||||
copy(tmpCopy[5:], b)
|
fmt.Println("binary.Write failed:", err)
|
||||||
|
}
|
||||||
|
tmpCopy := append(header.Bytes(), b...)
|
||||||
n, err := c.connection.Write(tmpCopy)
|
n, err := c.connection.Write(tmpCopy)
|
||||||
if n != len(tmpCopy) {
|
if n != len(tmpCopy) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -53,68 +67,48 @@ func (c Comm) Write(b []byte) (int, error) {
|
||||||
|
|
||||||
func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
|
func (c Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
|
||||||
// read until we get 5 bytes
|
// read until we get 5 bytes
|
||||||
tmp := make([]byte, 5)
|
header := make([]byte, 4)
|
||||||
n, err := c.connection.Read(tmp)
|
n, err := c.connection.Read(header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tmpCopy := make([]byte, n)
|
if n < 4 {
|
||||||
// Copy the buffer so it doesn't get changed while read by the recipient.
|
err = fmt.Errorf("not enough bytes: %d", n)
|
||||||
copy(tmpCopy, tmp[:n])
|
return
|
||||||
bs = tmpCopy
|
|
||||||
|
|
||||||
tmp = make([]byte, 1)
|
|
||||||
for {
|
|
||||||
// see if we have enough bytes
|
|
||||||
bs = bytes.Trim(bs, "\x00")
|
|
||||||
if len(bs) == 5 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
n, err := c.connection.Read(tmp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, nil, err
|
|
||||||
}
|
|
||||||
tmpCopy = make([]byte, n)
|
|
||||||
// Copy the buffer so it doesn't get changed while read by the recipient.
|
|
||||||
copy(tmpCopy, tmp[:n])
|
|
||||||
bs = append(bs, tmpCopy...)
|
|
||||||
}
|
}
|
||||||
|
// make it so it won't change
|
||||||
|
header = append([]byte(nil), header...)
|
||||||
|
|
||||||
numBytes, err = strconv.Atoi(strings.TrimLeft(string(bs), "0"))
|
var numBytesUint32 uint32
|
||||||
|
rbuf := bytes.NewReader(header)
|
||||||
|
err = binary.Read(rbuf, binary.LittleEndian, &numBytesUint32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, nil, err
|
fmt.Println("binary.Read failed:", err)
|
||||||
}
|
}
|
||||||
buf = []byte{}
|
numBytes = int(numBytesUint32)
|
||||||
tmp = make([]byte, numBytes)
|
|
||||||
for {
|
for {
|
||||||
n, err := c.connection.Read(tmp)
|
tmp := make([]byte, numBytes)
|
||||||
if err != nil {
|
n, errRead := c.connection.Read(tmp)
|
||||||
return nil, 0, nil, err
|
if errRead != nil {
|
||||||
|
err = errRead
|
||||||
|
return
|
||||||
}
|
}
|
||||||
tmpCopy = make([]byte, n)
|
buf = append(buf, tmp[:n]...)
|
||||||
// Copy the buffer so it doesn't get changed while read by the recipient.
|
if numBytes == len(buf) {
|
||||||
copy(tmpCopy, tmp[:n])
|
|
||||||
buf = append(buf, bytes.TrimRight(tmpCopy, "\x00")...)
|
|
||||||
if len(buf) < numBytes {
|
|
||||||
// shrink the amount we need to read
|
|
||||||
tmp = tmp[:numBytes-len(buf)]
|
|
||||||
} else {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// log.Printf("wanted %d and got %d", numBytes, len(buf))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a message
|
// Send a message
|
||||||
func (c Comm) Send(message string) (err error) {
|
func (c Comm) Send(message []byte) (err error) {
|
||||||
_, err = c.Write([]byte(message))
|
_, err = c.Write(message)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Receive a message
|
// Receive a message
|
||||||
func (c Comm) Receive() (s string, err error) {
|
func (c Comm) Receive() (b []byte, err error) {
|
||||||
b, _, _, err := c.Read()
|
b, _, _, err = c.Read()
|
||||||
s = string(b)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
package comm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/cihub/seelog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestComm(t *testing.T) {
|
||||||
|
defer log.Flush()
|
||||||
|
|
||||||
|
port := "8001"
|
||||||
|
go func() {
|
||||||
|
log.Debugf("starting TCP server on " + port)
|
||||||
|
server, err := net.Listen("tcp", "0.0.0.0:"+port)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
defer server.Close()
|
||||||
|
// spawn a new goroutine whenever a client connects
|
||||||
|
for {
|
||||||
|
connection, err := server.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
log.Debugf("client %s connected", connection.RemoteAddr().String())
|
||||||
|
go func(port string, connection net.Conn) {
|
||||||
|
c := New(connection)
|
||||||
|
err = c.Send([]byte("hello, world"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
data, err := c.Receive()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, []byte("hello, computer"), data)
|
||||||
|
data, err = c.Receive()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, []byte{'\x00'}, data)
|
||||||
|
}(port, connection)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
a, err := NewConnection("localhost:" + port)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
data, err := a.Receive()
|
||||||
|
assert.Equal(t, []byte("hello, world"), data)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Nil(t, a.Send([]byte("hello, computer")))
|
||||||
|
assert.Nil(t, a.Send([]byte{'\x00'}))
|
||||||
|
}
|
|
@ -13,8 +13,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type roomInfo struct {
|
type roomInfo struct {
|
||||||
receiver comm.Comm
|
first comm.Comm
|
||||||
opened time.Time
|
second comm.Comm
|
||||||
|
opened time.Time
|
||||||
|
full bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type roomMap struct {
|
type roomMap struct {
|
||||||
|
@ -77,37 +79,53 @@ func run(port string) (err error) {
|
||||||
func clientCommuncation(port string, c comm.Comm) (err error) {
|
func clientCommuncation(port string, c comm.Comm) (err error) {
|
||||||
// send ok to tell client they are connected
|
// send ok to tell client they are connected
|
||||||
log.Debug("sending ok")
|
log.Debug("sending ok")
|
||||||
err = c.Send("ok")
|
err = c.Send([]byte("ok"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for client to tell me which room they want
|
// wait for client to tell me which room they want
|
||||||
log.Debug("waiting for answer")
|
log.Debug("waiting for answer")
|
||||||
room, err := c.Receive()
|
roomBytes, err := c.Receive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
room := string(roomBytes)
|
||||||
|
|
||||||
rooms.Lock()
|
rooms.Lock()
|
||||||
// first connection is always the receiver
|
// create the room if it is new
|
||||||
if _, ok := rooms.rooms[room]; !ok {
|
if _, ok := rooms.rooms[room]; !ok {
|
||||||
rooms.rooms[room] = roomInfo{
|
rooms.rooms[room] = roomInfo{
|
||||||
receiver: c,
|
first: c,
|
||||||
opened: time.Now(),
|
opened: time.Now(),
|
||||||
}
|
}
|
||||||
rooms.Unlock()
|
rooms.Unlock()
|
||||||
// tell the client that they got the room
|
// tell the client that they got the room
|
||||||
err = c.Send("recipient")
|
err = c.Send([]byte("ok"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debug("recipient connected")
|
log.Debugf("room %s has 1", room)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
log.Debug("sender connected")
|
if rooms.rooms[room].full {
|
||||||
receiver := rooms.rooms[room].receiver
|
rooms.Unlock()
|
||||||
|
err = c.Send([]byte("room full"))
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debugf("room %s has 2", room)
|
||||||
|
rooms.rooms[room] = roomInfo{
|
||||||
|
first: rooms.rooms[room].first,
|
||||||
|
second: c,
|
||||||
|
opened: rooms.rooms[room].opened,
|
||||||
|
full: true,
|
||||||
|
}
|
||||||
|
otherConnection := rooms.rooms[room].first
|
||||||
rooms.Unlock()
|
rooms.Unlock()
|
||||||
|
|
||||||
// second connection is the sender, time to staple connections
|
// second connection is the sender, time to staple connections
|
||||||
|
@ -120,10 +138,10 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
|
||||||
pipe(com1.Connection(), com2.Connection())
|
pipe(com1.Connection(), com2.Connection())
|
||||||
wg.Done()
|
wg.Done()
|
||||||
log.Debug("done piping")
|
log.Debug("done piping")
|
||||||
}(c, receiver, &wg)
|
}(otherConnection, c, &wg)
|
||||||
|
|
||||||
// tell the sender everything is ready
|
// tell the sender everything is ready
|
||||||
err = c.Send("sender")
|
err = c.Send([]byte("ok"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -132,6 +150,8 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
|
||||||
// delete room
|
// delete room
|
||||||
rooms.Lock()
|
rooms.Lock()
|
||||||
log.Debugf("deleting room: %s", room)
|
log.Debugf("deleting room: %s", room)
|
||||||
|
rooms.rooms[room].first.Close()
|
||||||
|
rooms.rooms[room].second.Close()
|
||||||
delete(rooms.rooms, room)
|
delete(rooms.rooms, room)
|
||||||
rooms.Unlock()
|
rooms.Unlock()
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/schollz/croc/src/comm"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTCP(t *testing.T) {
|
||||||
|
go Run("debug", "8081")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
c1, err := ConnectToTCPServer("localhost:8081", "testRoom")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
c2, err := ConnectToTCPServer("localhost:8081", "testRoom")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
_, err = ConnectToTCPServer("localhost:8081", "testRoom")
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
// try sending data
|
||||||
|
assert.Nil(t, c1.Send([]byte("hello, c2")))
|
||||||
|
data, err := c2.Receive()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, []byte("hello, c2"), data)
|
||||||
|
|
||||||
|
assert.Nil(t, c2.Send([]byte("hello, c1")))
|
||||||
|
data, err = c1.Receive()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, []byte("hello, c1"), data)
|
||||||
|
|
||||||
|
c1.Close()
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
err = c2.Send([]byte("test"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
_, err = c2.Receive()
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectToTCPServer(address, room string) (c comm.Comm, err error) {
|
||||||
|
c, err = comm.NewConnection("localhost:8081")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := c.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, []byte("ok")) {
|
||||||
|
err = fmt.Errorf("got bad response: %s", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = c.Send([]byte(room))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err = c.Receive()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, []byte("ok")) {
|
||||||
|
err = fmt.Errorf("got bad response: %s", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
Loading…
Reference in New Issue