mirror of https://github.com/schollz/croc.git
allow only one sender to connect
This commit is contained in:
parent
e5ba8c5580
commit
ee927bdc55
|
@ -606,7 +606,7 @@ func (c *Client) transferOverLocalRelay(errchan chan error) {
|
|||
// not really an error because it will try to connect over the actual relay
|
||||
return
|
||||
}
|
||||
log.Debugf("local connection established: %+v", conn)
|
||||
log.Debugf("local sender connection established: %+v", conn)
|
||||
err = nil
|
||||
for {
|
||||
data, errConn := conn.Receive()
|
||||
|
@ -800,7 +800,7 @@ func (c *Client) Send(filesInfo []FileInfo, emptyFoldersToTransfer []FileInfo, t
|
|||
errchan <- err
|
||||
} else {
|
||||
log.Debugf("banner: %s", banner)
|
||||
log.Debugf("connection established: %+v", conn)
|
||||
log.Debugf("sender connection established: %+v", conn)
|
||||
|
||||
c.listenToMainConn(conn, ipaddr, banner, errchan)
|
||||
}
|
||||
|
@ -939,6 +939,7 @@ func (c *Client) Receive() (err error) {
|
|||
c.conn[0], banner, c.ExternalIP, err = tcp.ConnectToTCPServer(address, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, true, 1, time.Duration(c.Options.TimeLimit)*time.Second)
|
||||
if err == nil {
|
||||
c.Options.RelayAddress = address
|
||||
log.Debug("receiver connection established")
|
||||
break
|
||||
}
|
||||
log.Debugf("could not establish '%s'", address)
|
||||
|
|
|
@ -277,7 +277,7 @@ func TestCrocLocal(t *testing.T) {
|
|||
IsSender: true,
|
||||
TimeLimit: 30,
|
||||
MaxTransfers: 1,
|
||||
SharedSecret: "8123-testingthecroc",
|
||||
SharedSecret: "2813-testingthecroc",
|
||||
Debug: true,
|
||||
RelayAddress: "127.0.0.1:8181",
|
||||
RelayPorts: []string{"8181", "8182"},
|
||||
|
@ -297,7 +297,7 @@ func TestCrocLocal(t *testing.T) {
|
|||
log.Debug("setting up receiver")
|
||||
receiver, err := New(Options{
|
||||
IsSender: false,
|
||||
SharedSecret: "8123-testingthecroc",
|
||||
SharedSecret: "2813-testingthecroc",
|
||||
Debug: true,
|
||||
RelayAddress: "127.0.0.1:8181",
|
||||
RelayPassword: "pass123",
|
||||
|
@ -325,7 +325,7 @@ func TestCrocLocal(t *testing.T) {
|
|||
}
|
||||
wg.Done()
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
go func() {
|
||||
err := receiver.Receive()
|
||||
if err != nil {
|
||||
|
|
104
src/tcp/tcp.go
104
src/tcp/tcp.go
|
@ -156,8 +156,8 @@ func (s *server) run() (err error) {
|
|||
log.Debugf("checking connection of room %s for %+v", room, c)
|
||||
deleteIt := false
|
||||
s.rooms.Lock()
|
||||
if _, ok := s.rooms.rooms[room]; !ok {
|
||||
log.Debug("room is gone")
|
||||
if _, ok := s.rooms.rooms[room]; !ok || (s.rooms.rooms[room].sender == nil && s.rooms.rooms[room].receiver == nil) {
|
||||
log.Debugf("room %s is gone", room)
|
||||
s.rooms.Unlock()
|
||||
return
|
||||
}
|
||||
|
@ -312,6 +312,12 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er
|
|||
_, roomExists := s.rooms.rooms[room]
|
||||
// create the room if it is new
|
||||
if !roomExists || isSender {
|
||||
if roomExists && isSender && s.rooms.rooms[room].sender != nil {
|
||||
// if the room exists and the sender is already connected
|
||||
// then signal to the client that the room is full
|
||||
err = s.sendRoomIsFull(c, strongKeyForEncryption)
|
||||
return
|
||||
}
|
||||
err = s.createOrUpdateRoom(c, room, strongKeyForEncryption, isMainRoom, isSender, roomExists)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
|
@ -378,6 +384,12 @@ func (s *server) sendRoomIsFull(c *comm.Comm, strongKeyForEncryption []byte) (er
|
|||
func (s *server) createOrUpdateRoom(c *comm.Comm, room string, strongKeyForEncryption []byte, isMainRoom, isSender, updateRoom bool) (err error) {
|
||||
var enc, data, bSend []byte
|
||||
|
||||
if !updateRoom {
|
||||
log.Debugf("Creating room %s", room)
|
||||
} else {
|
||||
log.Debugf("Updating room %s", room)
|
||||
}
|
||||
|
||||
maxTransfers := 1
|
||||
if isMainRoom && isSender {
|
||||
log.Debug("Wait for maxTransfers")
|
||||
|
@ -447,13 +459,6 @@ func (s *server) createOrUpdateRoom(c *comm.Comm, room string, strongKeyForEncry
|
|||
return
|
||||
}
|
||||
|
||||
func removeReceiverFromQueue(queue *list.List, c *comm.Comm) {
|
||||
var rmElem *list.Element
|
||||
for rmElem = queue.Front(); rmElem.Value != c.ID(); rmElem = rmElem.Next() {
|
||||
}
|
||||
queue.Remove(rmElem)
|
||||
}
|
||||
|
||||
func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error, keepGoing bool) {
|
||||
var bSend []byte
|
||||
log.Debugf("room %s is full, adding to queue", room)
|
||||
|
@ -477,34 +482,7 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
|
|||
for {
|
||||
s.rooms.roomLocks[room].Lock()
|
||||
|
||||
if s.rooms.rooms[room].doneTransfers >= s.rooms.rooms[room].maxTransfers {
|
||||
// remove the client from the queue
|
||||
newQueue := s.rooms.rooms[room].queue
|
||||
removeReceiverFromQueue(newQueue, c)
|
||||
s.rooms.Lock()
|
||||
s.rooms.rooms[room] = roomInfo{
|
||||
sender: s.rooms.rooms[room].sender,
|
||||
receiver: s.rooms.rooms[room].receiver,
|
||||
isMainRoom: s.rooms.rooms[room].isMainRoom,
|
||||
opened: s.rooms.rooms[room].opened,
|
||||
maxTransfers: s.rooms.rooms[room].maxTransfers,
|
||||
doneTransfers: s.rooms.rooms[room].doneTransfers,
|
||||
queue: newQueue,
|
||||
}
|
||||
s.rooms.Unlock()
|
||||
|
||||
// tell the client that the sender is no longer available
|
||||
bSend, err = crypt.Encrypt([]byte(senderGone), strongKeyForEncryption)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = c.Send(bSend)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
break
|
||||
} else if s.rooms.rooms[room].receiver != nil || s.rooms.rooms[room].queue.Front().Value.(string) != c.ID() {
|
||||
if s.rooms.rooms[room].receiver != nil || s.rooms.rooms[room].queue.Front().Value.(string) != c.ID() {
|
||||
time.Sleep(1 * time.Second)
|
||||
// tell the client that they need to wait
|
||||
bSend, err = crypt.Encrypt([]byte("wait"), strongKeyForEncryption)
|
||||
|
@ -517,11 +495,23 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
|
|||
return
|
||||
}
|
||||
s.rooms.roomLocks[room].Unlock()
|
||||
} else if s.rooms.rooms[room].doneTransfers >= s.rooms.rooms[room].maxTransfers {
|
||||
// tell the client that the sender is no longer available
|
||||
bSend, err = crypt.Encrypt([]byte(senderGone), strongKeyForEncryption)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = c.Send(bSend)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
break
|
||||
} else {
|
||||
s.rooms.Lock()
|
||||
// remove the client from the queue
|
||||
newQueue := s.rooms.rooms[room].queue
|
||||
removeReceiverFromQueue(newQueue, c)
|
||||
newQueue.Remove(newQueue.Front())
|
||||
s.rooms.rooms[room] = roomInfo{
|
||||
sender: s.rooms.rooms[room].sender,
|
||||
receiver: c,
|
||||
|
@ -541,6 +531,12 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
|
|||
func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error) {
|
||||
s.rooms.Unlock()
|
||||
|
||||
// safety check (it should never happen)
|
||||
if s.rooms.rooms[room].sender == nil || s.rooms.rooms[room].receiver == nil {
|
||||
err = fmt.Errorf("sender or receiver is nil")
|
||||
return
|
||||
}
|
||||
|
||||
// second connection is the sender, time to staple connections
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
@ -587,8 +583,8 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption
|
|||
s.rooms.Unlock()
|
||||
|
||||
// delete the room if it is done or unlock it if it is not
|
||||
if newDoneTransfers == s.rooms.rooms[room].maxTransfers {
|
||||
log.Debugf("room %s is done", room)
|
||||
if newDoneTransfers >= s.rooms.rooms[room].maxTransfers {
|
||||
log.Debugf("room %s is done, deleting it", room)
|
||||
s.deleteRoom(room)
|
||||
} else {
|
||||
log.Debugf("room %s has %d done", room, newDoneTransfers)
|
||||
|
@ -598,6 +594,8 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption
|
|||
}
|
||||
|
||||
func (s *server) deleteRoom(room string) {
|
||||
s.rooms.Lock()
|
||||
defer s.rooms.Unlock()
|
||||
if _, ok := s.rooms.rooms[room]; !ok {
|
||||
return
|
||||
}
|
||||
|
@ -609,11 +607,21 @@ func (s *server) deleteRoom(room string) {
|
|||
break
|
||||
}
|
||||
s.rooms.roomLocks[room].Lock()
|
||||
// remove the client from the queue
|
||||
newQueue := s.rooms.rooms[room].queue
|
||||
newQueue.Remove(newQueue.Front())
|
||||
s.rooms.rooms[room] = roomInfo{
|
||||
sender: s.rooms.rooms[room].sender,
|
||||
receiver: s.rooms.rooms[room].receiver,
|
||||
isMainRoom: s.rooms.rooms[room].isMainRoom,
|
||||
opened: s.rooms.rooms[room].opened,
|
||||
maxTransfers: s.rooms.rooms[room].maxTransfers,
|
||||
doneTransfers: s.rooms.rooms[room].doneTransfers,
|
||||
queue: newQueue,
|
||||
}
|
||||
}
|
||||
delete(s.rooms.roomLocks, room)
|
||||
}
|
||||
s.rooms.Lock()
|
||||
defer s.rooms.Unlock()
|
||||
log.Debugf("deleting room: %s", room)
|
||||
if s.rooms.rooms[room].sender != nil {
|
||||
s.rooms.rooms[room].sender.Close()
|
||||
|
@ -871,6 +879,13 @@ func ConnectToTCPServer(address, password, room string, isSender, isMainRoom boo
|
|||
log.Debug(err)
|
||||
return
|
||||
}
|
||||
|
||||
if bytes.Equal(data, []byte(fullRoom)) {
|
||||
err = fmt.Errorf("room is full")
|
||||
c = nil
|
||||
return
|
||||
}
|
||||
|
||||
if !isSender {
|
||||
if bytes.Equal(data, []byte("wait")) {
|
||||
log.Debug("waiting for sender to be free")
|
||||
|
@ -880,16 +895,13 @@ func ConnectToTCPServer(address, password, room string, isSender, isMainRoom boo
|
|||
err = fmt.Errorf("sender is gone")
|
||||
c = nil
|
||||
return
|
||||
} else if bytes.Equal(data, []byte(fullRoom)) {
|
||||
err = fmt.Errorf("room is full")
|
||||
c = nil
|
||||
return
|
||||
} else if bytes.Equal(data, []byte(noRoom)) {
|
||||
err = fmt.Errorf("room does not exist")
|
||||
c = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !bytes.Equal(data, []byte("ok")) {
|
||||
err = fmt.Errorf("got bad response: %s", data)
|
||||
log.Debug(err)
|
||||
|
|
|
@ -33,6 +33,22 @@ func TestTCPServerPing(t *testing.T) {
|
|||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestOnlyOneSenderPerRoom(t *testing.T) {
|
||||
log.SetLevel("error")
|
||||
go Run("debug", "127.0.0.1", "8381", "pass123", "8382")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", true, true, 1, 1*time.Minute)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, c1)
|
||||
assert.Equal(t, banner, "8382")
|
||||
|
||||
c2, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", true, true, 1, 1*time.Minute)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), "room is full"))
|
||||
assert.Nil(t, c2)
|
||||
}
|
||||
|
||||
// This is helper function to test that a mocks a transfer
|
||||
// between two clients connected to the server,
|
||||
// and checks that the data is transferred correctly
|
||||
|
|
Loading…
Reference in New Issue