diff --git a/src/croc/croc.go b/src/croc/croc.go index af8bb2d..8c130e7 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -614,7 +614,12 @@ func (c *Client) transferOverLocalRelay(errchan chan error) { log.Debugf("[%+v] had error: %s", conn, errConn.Error()) break } - if bytes.Equal(data, handshakeRequest) { + if bytes.Equal(data, ipRequest) { + log.Debug("Got ip request, sending nil since we are local") + if err = conn.Send(nil); err != nil { + log.Errorf("error sending: %v", err) + } + } else if bytes.Equal(data, handshakeRequest) { wgTransfer.Add(1) go c.makeLocalTransfer(conn, ipaddr, banner, errchan) wgTransfer.Wait() @@ -989,7 +994,7 @@ func (c *Client) Receive() (err error) { } serverTry := net.JoinHostPort(ip, port) - conn, banner2, externalIP, errConn := tcp.ConnectToTCPServer(serverTry, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, true, 1, 500*time.Millisecond) + conn, banner2, externalIP, errConn := tcp.ConnectToTCPServer(serverTry, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, false, 1, 500*time.Millisecond) if errConn != nil { log.Debug(errConn) log.Debugf("could not connect to " + serverTry) @@ -1009,6 +1014,7 @@ func (c *Client) Receive() (err error) { } } + log.Debug("sending handshake message") if err = c.conn[0].Send(handshakeRequest); err != nil { log.Errorf("handshake send error: %v", err) } @@ -1037,6 +1043,7 @@ func (c *Client) transfer() (err error) { // if recipient, initialize with sending pake information log.Debug("ready") if !c.Options.IsSender && !c.Step1ChannelSecured { + log.Debug("sending pake information") err = message.Send(c.conn[0], c.Key, message.Message{ Type: message.TypePAKE, Bytes: c.Pake.Bytes(), diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index f0eeb00..ad3d5cc 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -35,7 +35,6 @@ type roomInfo struct { maxTransfers int doneTransfers int opened time.Time - transfering bool } type roomMap struct { @@ -293,39 +292,36 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er } room = string(roomBytes) - s.rooms.Lock() - if isSender { - if _, ok := s.rooms.rooms[room]; !ok { - // create the room if it is new - err = s.createRoom(c, room, strongKeyForEncryption) - if err != nil { - log.Error(err) - } - // sender is done - return - } else { - // if the room already exists, then tell the client that the room is full - err = s.sendRoomIsFull(c, strongKeyForEncryption) - return - } + log.Debug("Check if this is a main room") + enc, err = c.Receive() + if err != nil { + return } + data, err = crypt.Decrypt(enc, strongKeyForEncryption) + if err != nil { + return + } + if !bytes.Equal(data, []byte("main")) && !bytes.Equal(data, []byte("secondary")) { + err = fmt.Errorf("got bad response: %s", data) + return + } + isMainRoom := bytes.Equal(data, []byte("main")) + log.Debugf("isMainRoom: %v", isMainRoom) - if _, ok := s.rooms.rooms[room]; !ok { - // if the room does not exist and the client is a receiver, then tell them - // that the room does not exist - s.rooms.Unlock() - bSend, err = crypt.Encrypt([]byte(noRoom), strongKeyForEncryption) - if err != nil { - return - } - err = c.Send(bSend) + s.rooms.Lock() + _, roomExists := s.rooms.rooms[room] + // create the room if it is new + if !roomExists || isSender { + err = s.createOrUpdateRoom(c, room, strongKeyForEncryption, isMainRoom, isSender, roomExists) if err != nil { log.Error(err) - return } - return "", fmt.Errorf("reciever tried to connect to room that does not exist") - } else if s.rooms.rooms[room].transfering { + // if the room is new then return + if !roomExists { + return + } + } else if s.rooms.rooms[room].receiver != nil { // if the room has a transfer going on if s.rooms.rooms[room].maxTransfers > 1 { // if the room is a multi-transfer room then add to queue @@ -349,7 +345,6 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er maxTransfers: s.rooms.rooms[room].maxTransfers, doneTransfers: s.rooms.rooms[room].doneTransfers, opened: s.rooms.rooms[room].opened, - transfering: true, } s.rooms.roomLocks[room].Lock() } @@ -376,26 +371,11 @@ func (s *server) sendRoomIsFull(c *comm.Comm, strongKeyForEncryption []byte) (er return } -func (s *server) createRoom(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error) { +func (s *server) createOrUpdateRoom(c *comm.Comm, room string, strongKeyForEncryption []byte, isMainRoom, isSender, updateRoom bool) (err error) { var enc, data, bSend []byte - log.Debug("Check if this is a main room") - enc, err = c.Receive() - if err != nil { - return - } - data, err = crypt.Decrypt(enc, strongKeyForEncryption) - if err != nil { - return - } - if !bytes.Equal(data, []byte("main")) && !bytes.Equal(data, []byte("secondary")) { - err = fmt.Errorf("got bad response: %s", data) - return - } - isMainRoom := bytes.Equal(data, []byte("main")) - log.Debugf("isMainRoom: %v", isMainRoom) maxTransfers := 1 - if isMainRoom { + if isMainRoom && isSender { log.Debug("Wait for maxTransfers") enc, err = c.Receive() if err != nil { @@ -413,29 +393,53 @@ func (s *server) createRoom(c *comm.Comm, room string, strongKeyForEncryption [] log.Debugf("maxTransfers: %v", maxTransfers) } + var sender, receiver *comm.Comm + var queue *list.List + opened := time.Now() + if isSender { + sender = c + if updateRoom { + receiver = s.rooms.rooms[room].receiver + queue = s.rooms.rooms[room].queue + opened = s.rooms.rooms[room].opened + } + } else { + receiver = c + if updateRoom { + sender = s.rooms.rooms[room].sender + queue = s.rooms.rooms[room].queue + opened = s.rooms.rooms[room].opened + } + } + s.rooms.rooms[room] = roomInfo{ - sender: c, - receiver: nil, + sender: sender, + receiver: receiver, + queue: queue, isMainRoom: isMainRoom, maxTransfers: maxTransfers, doneTransfers: 0, - opened: time.Now(), + opened: opened, } - s.rooms.roomLocks[room] = &sync.Mutex{} - s.rooms.Unlock() - // tell the client that they got the room - bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption) - if err != nil { - return + if !updateRoom { + log.Debugf("Client crated main room %s, %v", room, isSender) + s.rooms.roomLocks[room] = &sync.Mutex{} + // tell the client that they got the room + bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption) + if err != nil { + return + } + err = c.Send(bSend) + if err != nil { + log.Error(err) + s.deleteRoom(room) + return + } + log.Debugf("room %s has 1", room) + s.rooms.Unlock() } - err = c.Send(bSend) - if err != nil { - log.Error(err) - s.deleteRoom(room) - return - } - log.Debugf("room %s has 1", room) + return } @@ -455,7 +459,6 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong maxTransfers: s.rooms.rooms[room].maxTransfers, doneTransfers: s.rooms.rooms[room].doneTransfers, queue: queue, - transfering: true, } s.rooms.Unlock() @@ -500,7 +503,6 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong maxTransfers: s.rooms.rooms[room].maxTransfers, doneTransfers: s.rooms.rooms[room].doneTransfers, opened: s.rooms.rooms[room].opened, - transfering: true, } break } @@ -509,7 +511,6 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong } func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error) { - otherConnection := s.rooms.rooms[room].sender s.rooms.Unlock() // second connection is the sender, time to staple connections @@ -522,9 +523,10 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption pipe(com1.Connection(), com2.Connection()) wg.Done() log.Debug("done piping") - }(otherConnection, c, &wg) + }(s.rooms.rooms[room].sender, s.rooms.rooms[room].receiver, &wg) - // tell the receiver everything is ready + // tell the client everything is ready + log.Debug("sending ok to client") bSend, err := crypt.Encrypt([]byte("ok"), strongKeyForEncryption) if err != nil { return @@ -544,6 +546,7 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption if s.rooms.rooms[room].queue != nil { lengthOfQueue = s.rooms.rooms[room].queue.Len() } + log.Debugf("room %s has %d left in queue", room, lengthOfQueue) s.rooms.rooms[room] = roomInfo{ sender: s.rooms.rooms[room].sender, receiver: nil, @@ -552,7 +555,6 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption maxTransfers: s.rooms.rooms[room].maxTransfers, doneTransfers: newDoneTransfers, opened: s.rooms.rooms[room].opened, - transfering: lengthOfQueue > 0 && newDoneTransfers < s.rooms.rooms[room].maxTransfers, } s.rooms.Unlock() @@ -799,13 +801,25 @@ func ConnectToTCPServer(address, password, room string, isSender, isMainRoom boo return } - if isSender { - log.Debug("tell server if this is a main room") - roomType := "secondary" - if isMainRoom { - roomType = "main" - } - bSend, err = crypt.Encrypt([]byte(roomType), strongKeyForEncryption) + log.Debug("tell server if this is a main room") + roomType := "secondary" + if isMainRoom { + roomType = "main" + } + bSend, err = crypt.Encrypt([]byte(roomType), strongKeyForEncryption) + if err != nil { + log.Debug(err) + return + } + err = c.Send(bSend) + if err != nil { + log.Debug(err) + return + } + + if isMainRoom && isSender { + log.Debug("tell server maxTransfers") + bSend, err = crypt.Encrypt([]byte(strconv.Itoa(maxTransfers)), strongKeyForEncryption) if err != nil { log.Debug(err) return @@ -815,20 +829,6 @@ func ConnectToTCPServer(address, password, room string, isSender, isMainRoom boo log.Debug(err) return } - - if isMainRoom { - log.Debug("tell server maxTransfers") - bSend, err = crypt.Encrypt([]byte(strconv.Itoa(maxTransfers)), strongKeyForEncryption) - if err != nil { - log.Debug(err) - return - } - err = c.Send(bSend) - if err != nil { - log.Debug(err) - return - } - } } log.Debug("waiting for room confirmation") diff --git a/src/tcp/tcp_test.go b/src/tcp/tcp_test.go index 6fd4813..456216f 100644 --- a/src/tcp/tcp_test.go +++ b/src/tcp/tcp_test.go @@ -33,18 +33,6 @@ func TestTCPServerPing(t *testing.T) { assert.NotNil(t, err) } -// Test that a reeciver cannot connect to a non-existent room -func TestTCPServerNonExistentRoom(t *testing.T) { - log.SetLevel("error") - go Run("debug", "127.0.0.1", "8381", "pass123", "8382") - time.Sleep(100 * time.Millisecond) - - c1, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1) - assert.NotNil(t, err) - assert.True(t, strings.Contains(err.Error(), "room does not exist")) - assert.Nil(t, c1) -} - // This is helper function to test that a mocks a transfer // between two clients connected to the server, // and checks that the data is transferred correctly @@ -97,6 +85,28 @@ func TestTCPServerSingleConnectionTransfer(t *testing.T) { time.Sleep(300 * time.Millisecond) } +// Test that a receiver can connect before a sender +func TestTCPRecieverFirst(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8381", "pass123", "8382") + time.Sleep(100 * time.Millisecond) + + receiver, banner, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1, 1*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, receiver) + assert.Equal(t, banner, "8382") + + sender, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", true, true, 1, 1*time.Minute) + assert.Nil(t, err) + assert.NotNil(t, sender) + + mockTransfer(receiver, sender, t) + + receiver.Close() + sender.Close() + time.Sleep(300 * time.Millisecond) +} + // Test that a third client cannot connect // to a room that already has two clients // connected to it with maxTransfers=1 @@ -237,26 +247,15 @@ func TestTCPMultipleConnectionWaitingRoomCloses(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, c2) - // we need to run this transfer in a goroutine because - // otherwise connections will be idle and the server will - // close them when we try to connect a third client go func() { - counter := 1 - time.Sleep(100 * time.Millisecond) - for { - mockTransfer(c1, c2, t) - if counter == 5 { - c2.Close() - // tell c1 to close pipe listener - c1.Send([]byte("finished")) - break - } - counter++ - } + c3, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1, 5*time.Minute) + assert.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "sender is gone")) + assert.Nil(t, c3) }() - c3, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", false, true, 1, 5*time.Minute) - assert.NotNil(t, err) - assert.True(t, strings.Contains(err.Error(), "sender is gone")) - assert.Nil(t, c3) + time.Sleep(100 * time.Millisecond) + c2.Close() + // tell c1 to close pipe listener + c1.Send([]byte("finished")) }