diff --git a/src/croc/croc.go b/src/croc/croc.go index 2c983ee..61add9f 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -261,7 +261,9 @@ func (c *Client) Send(options TransferOptions) (err error) { go func() { time.Sleep(500 * time.Millisecond) log.Debug("establishing connection") - conn, err := tcp.ConnectToTCPServer("localhost:"+c.Options.RelayPorts[0], c.Options.SharedSecret) + var banner string + conn, banner, err := tcp.ConnectToTCPServer("localhost:"+c.Options.RelayPorts[0], c.Options.SharedSecret) + log.Debugf("banner: %s", banner) if err != nil { err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress)) return @@ -282,7 +284,9 @@ func (c *Client) Send(options TransferOptions) (err error) { go func() { log.Debug("establishing connection") - conn, err := tcp.ConnectToTCPServer(c.Options.RelayAddress+":"+c.Options.RelayPorts[0], c.Options.SharedSecret) + var banner string + conn, banner, err := tcp.ConnectToTCPServer(c.Options.RelayAddress+":"+c.Options.RelayPorts[0], c.Options.SharedSecret) + log.Debugf("banner: %s", banner) if err != nil { err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress)) return @@ -318,7 +322,9 @@ func (c *Client) Receive() (err error) { log.Debugf("discoveries: %+v", discoveries) log.Debug("establishing connection") } - c.conn[0], err = tcp.ConnectToTCPServer(c.Options.RelayAddress+":"+c.Options.RelayPorts[0], c.Options.SharedSecret) + var banner string + c.conn[0], banner, err = tcp.ConnectToTCPServer(c.Options.RelayAddress+":"+c.Options.RelayPorts[0], c.Options.SharedSecret) + log.Debugf("banner: %s", banner) if err != nil { err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress)) return @@ -412,7 +418,7 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) { for i := 1; i < len(c.Options.RelayPorts); i++ { go func(j int) { defer wg.Done() - c.conn[j], err = tcp.ConnectToTCPServer( + c.conn[j], _, err = tcp.ConnectToTCPServer( fmt.Sprintf("%s:%s", c.Options.RelayAddress, c.Options.RelayPorts[j]), fmt.Sprintf("%s-%d", utils.SHA256(c.Options.SharedSecret)[:7], j), ) @@ -640,7 +646,7 @@ func (c *Client) updateState() (err error) { } func (c *Client) setBar() { - description := fmt.Sprintf("%28s", c.FilesToTransfer[c.FilesToTransferCurrentNum].Name) + description := fmt.Sprintf("%-28s", c.FilesToTransfer[c.FilesToTransferCurrentNum].Name) if len(c.FilesToTransfer) == 1 { description = c.FilesToTransfer[c.FilesToTransferCurrentNum].Name } diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index 02be442..6294820 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -14,10 +14,10 @@ import ( "github.com/schollz/croc/v6/src/models" ) - type server struct { port string debugLevel string + banner string rooms roomMap } @@ -34,10 +34,13 @@ type roomMap struct { } // Run starts a tcp listener, run async -func Run(debugLevel, port string) (err error) { +func Run(debugLevel, port string, banner ...string) (err error) { s := new(server) s.port = port s.debugLevel = debugLevel + if len(banner) > 0 { + s.banner = banner[0] + } return s.start() } @@ -64,8 +67,8 @@ func (s *server) start() (err error) { err = s.run() if err != nil { log.Error(err) - } - return + } + return } func (s *server) run() (err error) { @@ -94,7 +97,7 @@ func (s *server) run() (err error) { func (s *server) clientCommuncation(port string, c *comm.Comm) (err error) { // send ok to tell client they are connected log.Debug("sending ok") - err = c.Send([]byte("ok")) + err = c.Send([]byte(s.banner)) if err != nil { return } @@ -224,7 +227,7 @@ func pipe(conn1 net.Conn, conn2 net.Conn) { } } -func ConnectToTCPServer(address, room string) (c *comm.Comm, err error) { +func ConnectToTCPServer(address, room string) (c *comm.Comm, banner string, err error) { c, err = comm.NewConnection(address) if err != nil { return @@ -233,10 +236,7 @@ func ConnectToTCPServer(address, room string) (c *comm.Comm, err error) { if err != nil { return } - if !bytes.Equal(data, []byte("ok")) { - err = fmt.Errorf("got bad response: %s", data) - return - } + banner = string(data) err = c.Send([]byte(room)) if err != nil { return diff --git a/src/tcp/tcp_test.go b/src/tcp/tcp_test.go index 14482fd..f048b46 100644 --- a/src/tcp/tcp_test.go +++ b/src/tcp/tcp_test.go @@ -10,11 +10,11 @@ import ( func TestTCP(t *testing.T) { go Run("debug", "8081") time.Sleep(100 * time.Millisecond) - c1, err := ConnectToTCPServer("localhost:8081", "testRoom") + c1, _, err := ConnectToTCPServer("localhost:8081", "testRoom") assert.Nil(t, err) - c2, err := ConnectToTCPServer("localhost:8081", "testRoom") + c2, _, err := ConnectToTCPServer("localhost:8081", "testRoom") assert.Nil(t, err) - _, err = ConnectToTCPServer("localhost:8081", "testRoom") + _, _, err = ConnectToTCPServer("localhost:8081", "testRoom") assert.NotNil(t, err) // try sending data