add banner

This commit is contained in:
Zack Scholl 2019-04-30 17:05:19 -06:00
parent 23c9a9cff8
commit 63ec16f7fb
3 changed files with 24 additions and 18 deletions

View File

@ -261,7 +261,9 @@ func (c *Client) Send(options TransferOptions) (err error) {
go func() { go func() {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
log.Debug("establishing connection") log.Debug("establishing connection")
conn, err := tcp.ConnectToTCPServer("localhost:"+c.Options.RelayPorts[0], c.Options.SharedSecret) var banner string
conn, banner, err := tcp.ConnectToTCPServer("localhost:"+c.Options.RelayPorts[0], c.Options.SharedSecret)
log.Debugf("banner: %s", banner)
if err != nil { if err != nil {
err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress)) err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress))
return return
@ -282,7 +284,9 @@ func (c *Client) Send(options TransferOptions) (err error) {
go func() { go func() {
log.Debug("establishing connection") log.Debug("establishing connection")
conn, err := tcp.ConnectToTCPServer(c.Options.RelayAddress+":"+c.Options.RelayPorts[0], c.Options.SharedSecret) var banner string
conn, banner, err := tcp.ConnectToTCPServer(c.Options.RelayAddress+":"+c.Options.RelayPorts[0], c.Options.SharedSecret)
log.Debugf("banner: %s", banner)
if err != nil { if err != nil {
err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress)) err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress))
return return
@ -318,7 +322,9 @@ func (c *Client) Receive() (err error) {
log.Debugf("discoveries: %+v", discoveries) log.Debugf("discoveries: %+v", discoveries)
log.Debug("establishing connection") log.Debug("establishing connection")
} }
c.conn[0], err = tcp.ConnectToTCPServer(c.Options.RelayAddress+":"+c.Options.RelayPorts[0], c.Options.SharedSecret) var banner string
c.conn[0], banner, err = tcp.ConnectToTCPServer(c.Options.RelayAddress+":"+c.Options.RelayPorts[0], c.Options.SharedSecret)
log.Debugf("banner: %s", banner)
if err != nil { if err != nil {
err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress)) err = errors.Wrap(err, fmt.Sprintf("could not connect to %s", c.Options.RelayAddress))
return return
@ -412,7 +418,7 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) {
for i := 1; i < len(c.Options.RelayPorts); i++ { for i := 1; i < len(c.Options.RelayPorts); i++ {
go func(j int) { go func(j int) {
defer wg.Done() defer wg.Done()
c.conn[j], err = tcp.ConnectToTCPServer( c.conn[j], _, err = tcp.ConnectToTCPServer(
fmt.Sprintf("%s:%s", c.Options.RelayAddress, c.Options.RelayPorts[j]), fmt.Sprintf("%s:%s", c.Options.RelayAddress, c.Options.RelayPorts[j]),
fmt.Sprintf("%s-%d", utils.SHA256(c.Options.SharedSecret)[:7], j), fmt.Sprintf("%s-%d", utils.SHA256(c.Options.SharedSecret)[:7], j),
) )
@ -640,7 +646,7 @@ func (c *Client) updateState() (err error) {
} }
func (c *Client) setBar() { func (c *Client) setBar() {
description := fmt.Sprintf("%28s", c.FilesToTransfer[c.FilesToTransferCurrentNum].Name) description := fmt.Sprintf("%-28s", c.FilesToTransfer[c.FilesToTransferCurrentNum].Name)
if len(c.FilesToTransfer) == 1 { if len(c.FilesToTransfer) == 1 {
description = c.FilesToTransfer[c.FilesToTransferCurrentNum].Name description = c.FilesToTransfer[c.FilesToTransferCurrentNum].Name
} }

View File

@ -14,10 +14,10 @@ import (
"github.com/schollz/croc/v6/src/models" "github.com/schollz/croc/v6/src/models"
) )
type server struct { type server struct {
port string port string
debugLevel string debugLevel string
banner string
rooms roomMap rooms roomMap
} }
@ -34,10 +34,13 @@ type roomMap struct {
} }
// Run starts a tcp listener, run async // Run starts a tcp listener, run async
func Run(debugLevel, port string) (err error) { func Run(debugLevel, port string, banner ...string) (err error) {
s := new(server) s := new(server)
s.port = port s.port = port
s.debugLevel = debugLevel s.debugLevel = debugLevel
if len(banner) > 0 {
s.banner = banner[0]
}
return s.start() return s.start()
} }
@ -64,8 +67,8 @@ func (s *server) start() (err error) {
err = s.run() err = s.run()
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
return return
} }
func (s *server) run() (err error) { func (s *server) run() (err error) {
@ -94,7 +97,7 @@ func (s *server) run() (err error) {
func (s *server) clientCommuncation(port string, c *comm.Comm) (err error) { func (s *server) clientCommuncation(port string, c *comm.Comm) (err error) {
// send ok to tell client they are connected // send ok to tell client they are connected
log.Debug("sending ok") log.Debug("sending ok")
err = c.Send([]byte("ok")) err = c.Send([]byte(s.banner))
if err != nil { if err != nil {
return return
} }
@ -224,7 +227,7 @@ func pipe(conn1 net.Conn, conn2 net.Conn) {
} }
} }
func ConnectToTCPServer(address, room string) (c *comm.Comm, err error) { func ConnectToTCPServer(address, room string) (c *comm.Comm, banner string, err error) {
c, err = comm.NewConnection(address) c, err = comm.NewConnection(address)
if err != nil { if err != nil {
return return
@ -233,10 +236,7 @@ func ConnectToTCPServer(address, room string) (c *comm.Comm, err error) {
if err != nil { if err != nil {
return return
} }
if !bytes.Equal(data, []byte("ok")) { banner = string(data)
err = fmt.Errorf("got bad response: %s", data)
return
}
err = c.Send([]byte(room)) err = c.Send([]byte(room))
if err != nil { if err != nil {
return return

View File

@ -10,11 +10,11 @@ import (
func TestTCP(t *testing.T) { func TestTCP(t *testing.T) {
go Run("debug", "8081") go Run("debug", "8081")
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
c1, err := ConnectToTCPServer("localhost:8081", "testRoom") c1, _, err := ConnectToTCPServer("localhost:8081", "testRoom")
assert.Nil(t, err) assert.Nil(t, err)
c2, err := ConnectToTCPServer("localhost:8081", "testRoom") c2, _, err := ConnectToTCPServer("localhost:8081", "testRoom")
assert.Nil(t, err) assert.Nil(t, err)
_, err = ConnectToTCPServer("localhost:8081", "testRoom") _, _, err = ConnectToTCPServer("localhost:8081", "testRoom")
assert.NotNil(t, err) assert.NotNil(t, err)
// try sending data // try sending data