allow both sender and receiver to reserve a room

This commit is contained in:
RCL98 2023-11-11 23:31:06 +02:00
parent 20fb6cee36
commit cb5da0c7c2
3 changed files with 128 additions and 122 deletions

View File

@ -614,7 +614,12 @@ func (c *Client) transferOverLocalRelay(errchan chan error) {
log.Debugf("[%+v] had error: %s", conn, errConn.Error()) log.Debugf("[%+v] had error: %s", conn, errConn.Error())
break break
} }
if bytes.Equal(data, handshakeRequest) { if bytes.Equal(data, ipRequest) {
log.Debug("Got ip request, sending nil since we are local")
if err = conn.Send(nil); err != nil {
log.Errorf("error sending: %v", err)
}
} else if bytes.Equal(data, handshakeRequest) {
wgTransfer.Add(1) wgTransfer.Add(1)
go c.makeLocalTransfer(conn, ipaddr, banner, errchan) go c.makeLocalTransfer(conn, ipaddr, banner, errchan)
wgTransfer.Wait() wgTransfer.Wait()
@ -989,7 +994,7 @@ func (c *Client) Receive() (err error) {
} }
serverTry := net.JoinHostPort(ip, port) serverTry := net.JoinHostPort(ip, port)
conn, banner2, externalIP, errConn := tcp.ConnectToTCPServer(serverTry, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, true, 1, 500*time.Millisecond) conn, banner2, externalIP, errConn := tcp.ConnectToTCPServer(serverTry, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, false, 1, 500*time.Millisecond)
if errConn != nil { if errConn != nil {
log.Debug(errConn) log.Debug(errConn)
log.Debugf("could not connect to " + serverTry) log.Debugf("could not connect to " + serverTry)
@ -1009,6 +1014,7 @@ func (c *Client) Receive() (err error) {
} }
} }
log.Debug("sending handshake message")
if err = c.conn[0].Send(handshakeRequest); err != nil { if err = c.conn[0].Send(handshakeRequest); err != nil {
log.Errorf("handshake send error: %v", err) log.Errorf("handshake send error: %v", err)
} }
@ -1037,6 +1043,7 @@ func (c *Client) transfer() (err error) {
// if recipient, initialize with sending pake information // if recipient, initialize with sending pake information
log.Debug("ready") log.Debug("ready")
if !c.Options.IsSender && !c.Step1ChannelSecured { if !c.Options.IsSender && !c.Step1ChannelSecured {
log.Debug("sending pake information")
err = message.Send(c.conn[0], c.Key, message.Message{ err = message.Send(c.conn[0], c.Key, message.Message{
Type: message.TypePAKE, Type: message.TypePAKE,
Bytes: c.Pake.Bytes(), Bytes: c.Pake.Bytes(),

View File

@ -35,7 +35,6 @@ type roomInfo struct {
maxTransfers int maxTransfers int
doneTransfers int doneTransfers int
opened time.Time opened time.Time
transfering bool
} }
type roomMap struct { type roomMap struct {
@ -293,39 +292,36 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er
} }
room = string(roomBytes) room = string(roomBytes)
s.rooms.Lock() log.Debug("Check if this is a main room")
if isSender { enc, err = c.Receive()
if _, ok := s.rooms.rooms[room]; !ok { if err != nil {
// create the room if it is new return
err = s.createRoom(c, room, strongKeyForEncryption)
if err != nil {
log.Error(err)
}
// sender is done
return
} else {
// if the room already exists, then tell the client that the room is full
err = s.sendRoomIsFull(c, strongKeyForEncryption)
return
}
} }
data, err = crypt.Decrypt(enc, strongKeyForEncryption)
if err != nil {
return
}
if !bytes.Equal(data, []byte("main")) && !bytes.Equal(data, []byte("secondary")) {
err = fmt.Errorf("got bad response: %s", data)
return
}
isMainRoom := bytes.Equal(data, []byte("main"))
log.Debugf("isMainRoom: %v", isMainRoom)
if _, ok := s.rooms.rooms[room]; !ok { s.rooms.Lock()
// if the room does not exist and the client is a receiver, then tell them _, roomExists := s.rooms.rooms[room]
// that the room does not exist // create the room if it is new
s.rooms.Unlock() if !roomExists || isSender {
bSend, err = crypt.Encrypt([]byte(noRoom), strongKeyForEncryption) err = s.createOrUpdateRoom(c, room, strongKeyForEncryption, isMainRoom, isSender, roomExists)
if err != nil {
return
}
err = c.Send(bSend)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return
} }
return "", fmt.Errorf("reciever tried to connect to room that does not exist") // if the room is new then return
} else if s.rooms.rooms[room].transfering { if !roomExists {
return
}
} else if s.rooms.rooms[room].receiver != nil {
// if the room has a transfer going on // if the room has a transfer going on
if s.rooms.rooms[room].maxTransfers > 1 { if s.rooms.rooms[room].maxTransfers > 1 {
// if the room is a multi-transfer room then add to queue // if the room is a multi-transfer room then add to queue
@ -349,7 +345,6 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er
maxTransfers: s.rooms.rooms[room].maxTransfers, maxTransfers: s.rooms.rooms[room].maxTransfers,
doneTransfers: s.rooms.rooms[room].doneTransfers, doneTransfers: s.rooms.rooms[room].doneTransfers,
opened: s.rooms.rooms[room].opened, opened: s.rooms.rooms[room].opened,
transfering: true,
} }
s.rooms.roomLocks[room].Lock() s.rooms.roomLocks[room].Lock()
} }
@ -376,26 +371,11 @@ func (s *server) sendRoomIsFull(c *comm.Comm, strongKeyForEncryption []byte) (er
return return
} }
func (s *server) createRoom(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error) { func (s *server) createOrUpdateRoom(c *comm.Comm, room string, strongKeyForEncryption []byte, isMainRoom, isSender, updateRoom bool) (err error) {
var enc, data, bSend []byte var enc, data, bSend []byte
log.Debug("Check if this is a main room")
enc, err = c.Receive()
if err != nil {
return
}
data, err = crypt.Decrypt(enc, strongKeyForEncryption)
if err != nil {
return
}
if !bytes.Equal(data, []byte("main")) && !bytes.Equal(data, []byte("secondary")) {
err = fmt.Errorf("got bad response: %s", data)
return
}
isMainRoom := bytes.Equal(data, []byte("main"))
log.Debugf("isMainRoom: %v", isMainRoom)
maxTransfers := 1 maxTransfers := 1
if isMainRoom { if isMainRoom && isSender {
log.Debug("Wait for maxTransfers") log.Debug("Wait for maxTransfers")
enc, err = c.Receive() enc, err = c.Receive()
if err != nil { if err != nil {
@ -413,29 +393,53 @@ func (s *server) createRoom(c *comm.Comm, room string, strongKeyForEncryption []
log.Debugf("maxTransfers: %v", maxTransfers) log.Debugf("maxTransfers: %v", maxTransfers)
} }
var sender, receiver *comm.Comm
var queue *list.List
opened := time.Now()
if isSender {
sender = c
if updateRoom {
receiver = s.rooms.rooms[room].receiver
queue = s.rooms.rooms[room].queue
opened = s.rooms.rooms[room].opened
}
} else {
receiver = c
if updateRoom {
sender = s.rooms.rooms[room].sender
queue = s.rooms.rooms[room].queue
opened = s.rooms.rooms[room].opened
}
}
s.rooms.rooms[room] = roomInfo{ s.rooms.rooms[room] = roomInfo{
sender: c, sender: sender,
receiver: nil, receiver: receiver,
queue: queue,
isMainRoom: isMainRoom, isMainRoom: isMainRoom,
maxTransfers: maxTransfers, maxTransfers: maxTransfers,
doneTransfers: 0, doneTransfers: 0,
opened: time.Now(), opened: opened,
} }
s.rooms.roomLocks[room] = &sync.Mutex{}
s.rooms.Unlock()
// tell the client that they got the room if !updateRoom {
bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption) log.Debugf("Client crated main room %s, %v", room, isSender)
if err != nil { s.rooms.roomLocks[room] = &sync.Mutex{}
return // tell the client that they got the room
bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption)
if err != nil {
return
}
err = c.Send(bSend)
if err != nil {
log.Error(err)
s.deleteRoom(room)
return
}
log.Debugf("room %s has 1", room)
s.rooms.Unlock()
} }
err = c.Send(bSend)
if err != nil {
log.Error(err)
s.deleteRoom(room)
return
}
log.Debugf("room %s has 1", room)
return return
} }
@ -455,7 +459,6 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
maxTransfers: s.rooms.rooms[room].maxTransfers, maxTransfers: s.rooms.rooms[room].maxTransfers,
doneTransfers: s.rooms.rooms[room].doneTransfers, doneTransfers: s.rooms.rooms[room].doneTransfers,
queue: queue, queue: queue,
transfering: true,
} }
s.rooms.Unlock() s.rooms.Unlock()
@ -500,7 +503,6 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
maxTransfers: s.rooms.rooms[room].maxTransfers, maxTransfers: s.rooms.rooms[room].maxTransfers,
doneTransfers: s.rooms.rooms[room].doneTransfers, doneTransfers: s.rooms.rooms[room].doneTransfers,
opened: s.rooms.rooms[room].opened, opened: s.rooms.rooms[room].opened,
transfering: true,
} }
break break
} }
@ -509,7 +511,6 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
} }
func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error) { func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error) {
otherConnection := s.rooms.rooms[room].sender
s.rooms.Unlock() s.rooms.Unlock()
// second connection is the sender, time to staple connections // second connection is the sender, time to staple connections
@ -522,9 +523,10 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption
pipe(com1.Connection(), com2.Connection()) pipe(com1.Connection(), com2.Connection())
wg.Done() wg.Done()
log.Debug("done piping") log.Debug("done piping")
}(otherConnection, c, &wg) }(s.rooms.rooms[room].sender, s.rooms.rooms[room].receiver, &wg)
// tell the receiver everything is ready // tell the client everything is ready
log.Debug("sending ok to client")
bSend, err := crypt.Encrypt([]byte("ok"), strongKeyForEncryption) bSend, err := crypt.Encrypt([]byte("ok"), strongKeyForEncryption)
if err != nil { if err != nil {
return return
@ -544,6 +546,7 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption
if s.rooms.rooms[room].queue != nil { if s.rooms.rooms[room].queue != nil {
lengthOfQueue = s.rooms.rooms[room].queue.Len() lengthOfQueue = s.rooms.rooms[room].queue.Len()
} }
log.Debugf("room %s has %d left in queue", room, lengthOfQueue)
s.rooms.rooms[room] = roomInfo{ s.rooms.rooms[room] = roomInfo{
sender: s.rooms.rooms[room].sender, sender: s.rooms.rooms[room].sender,
receiver: nil, receiver: nil,
@ -552,7 +555,6 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption
maxTransfers: s.rooms.rooms[room].maxTransfers, maxTransfers: s.rooms.rooms[room].maxTransfers,
doneTransfers: newDoneTransfers, doneTransfers: newDoneTransfers,
opened: s.rooms.rooms[room].opened, opened: s.rooms.rooms[room].opened,
transfering: lengthOfQueue > 0 && newDoneTransfers < s.rooms.rooms[room].maxTransfers,
} }
s.rooms.Unlock() s.rooms.Unlock()
@ -799,13 +801,25 @@ func ConnectToTCPServer(address, password, room string, isSender, isMainRoom boo
return return
} }
if isSender { log.Debug("tell server if this is a main room")
log.Debug("tell server if this is a main room") roomType := "secondary"
roomType := "secondary" if isMainRoom {
if isMainRoom { roomType = "main"
roomType = "main" }
} bSend, err = crypt.Encrypt([]byte(roomType), strongKeyForEncryption)
bSend, err = crypt.Encrypt([]byte(roomType), strongKeyForEncryption) if err != nil {
log.Debug(err)
return
}
err = c.Send(bSend)
if err != nil {
log.Debug(err)
return
}
if isMainRoom && isSender {
log.Debug("tell server maxTransfers")
bSend, err = crypt.Encrypt([]byte(strconv.Itoa(maxTransfers)), strongKeyForEncryption)
if err != nil { if err != nil {
log.Debug(err) log.Debug(err)
return return
@ -815,20 +829,6 @@ func ConnectToTCPServer(address, password, room string, isSender, isMainRoom boo
log.Debug(err) log.Debug(err)
return return
} }
if isMainRoom {
log.Debug("tell server maxTransfers")
bSend, err = crypt.Encrypt([]byte(strconv.Itoa(maxTransfers)), strongKeyForEncryption)
if err != nil {
log.Debug(err)
return
}
err = c.Send(bSend)
if err != nil {
log.Debug(err)
return
}
}
} }
log.Debug("waiting for room confirmation") log.Debug("waiting for room confirmation")

View File

@ -33,18 +33,6 @@ func TestTCPServerPing(t *testing.T) {
assert.NotNil(t, err) assert.NotNil(t, err)
} }
// Test that a reeciver cannot connect to a non-existent room
func TestTCPServerNonExistentRoom(t *testing.T) {
log.SetLevel("error")
go Run("debug", "127.0.0.1", "8381", "pass123", "8382")
time.Sleep(100 * time.Millisecond)
c1, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1)
assert.NotNil(t, err)
assert.True(t, strings.Contains(err.Error(), "room does not exist"))
assert.Nil(t, c1)
}
// This is helper function to test that a mocks a transfer // This is helper function to test that a mocks a transfer
// between two clients connected to the server, // between two clients connected to the server,
// and checks that the data is transferred correctly // and checks that the data is transferred correctly
@ -97,6 +85,28 @@ func TestTCPServerSingleConnectionTransfer(t *testing.T) {
time.Sleep(300 * time.Millisecond) time.Sleep(300 * time.Millisecond)
} }
// Test that a receiver can connect before a sender
func TestTCPRecieverFirst(t *testing.T) {
log.SetLevel("error")
go Run("debug", "127.0.0.1", "8381", "pass123", "8382")
time.Sleep(100 * time.Millisecond)
receiver, banner, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1, 1*time.Minute)
assert.Nil(t, err)
assert.NotNil(t, receiver)
assert.Equal(t, banner, "8382")
sender, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", true, true, 1, 1*time.Minute)
assert.Nil(t, err)
assert.NotNil(t, sender)
mockTransfer(receiver, sender, t)
receiver.Close()
sender.Close()
time.Sleep(300 * time.Millisecond)
}
// Test that a third client cannot connect // Test that a third client cannot connect
// to a room that already has two clients // to a room that already has two clients
// connected to it with maxTransfers=1 // connected to it with maxTransfers=1
@ -237,26 +247,15 @@ func TestTCPMultipleConnectionWaitingRoomCloses(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, c2) assert.NotNil(t, c2)
// we need to run this transfer in a goroutine because
// otherwise connections will be idle and the server will
// close them when we try to connect a third client
go func() { go func() {
counter := 1 c3, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1, 5*time.Minute)
time.Sleep(100 * time.Millisecond) assert.NotNil(t, err)
for { assert.True(t, strings.Contains(err.Error(), "sender is gone"))
mockTransfer(c1, c2, t) assert.Nil(t, c3)
if counter == 5 {
c2.Close()
// tell c1 to close pipe listener
c1.Send([]byte("finished"))
break
}
counter++
}
}() }()
c3, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1, 5*time.Minute) time.Sleep(100 * time.Millisecond)
assert.NotNil(t, err) c2.Close()
assert.True(t, strings.Contains(err.Error(), "sender is gone")) // tell c1 to close pipe listener
assert.Nil(t, c3) c1.Send([]byte("finished"))
} }