diff --git a/main.go b/main.go index 5d001db..72fcccd 100644 --- a/main.go +++ b/main.go @@ -61,13 +61,16 @@ func main() { }, } app.Flags = []cli.Flag{ - cli.StringFlag{Name: "relay", Value: "ws://198.199.67.130:8153"}, + cli.StringFlag{Name: "addr", Value: "198.199.67.130", Usage: "address of the public relay"}, + cli.StringFlag{Name: "addr-ws", Value: "8153", Usage: "port of the public relay websocket server to connect"}, + cli.StringFlag{Name: "addr-tcp", Value: "8154", Usage: "tcp port of the public relay serer to connect"}, cli.BoolFlag{Name: "no-local", Usage: "disable local mode"}, cli.BoolFlag{Name: "local", Usage: "use only local mode"}, cli.BoolFlag{Name: "debug", Usage: "increase verbosity (a lot)"}, cli.BoolFlag{Name: "yes", Usage: "automatically agree to all prompts"}, cli.BoolFlag{Name: "stdout", Usage: "redirect file to stdout"}, cli.StringFlag{Name: "port", Value: "8153", Usage: "port that the websocket listens on"}, + cli.StringFlag{Name: "tcp-port", Value: "8154", Usage: "port that the tcp server listens on"}, cli.StringFlag{Name: "curve", Value: "siec", Usage: "specify elliptic curve to use (p224, p256, p384, p521, siec)"}, } app.EnableBashCompletion = true @@ -82,13 +85,16 @@ func main() { app.Before = func(c *cli.Context) error { cr = croc.Init(c.GlobalBool("debug")) cr.AllowLocalDiscovery = true - cr.WebsocketAddress = c.GlobalString("relay") + cr.Address = c.GlobalString("addr") + cr.AddressTCPPort = c.GlobalString("addr-tcp") + cr.AddressWebsocketPort = c.GlobalString("addr-ws") cr.NoRecipientPrompt = c.GlobalBool("yes") cr.Stdout = c.GlobalBool("stdout") cr.LocalOnly = c.GlobalBool("local") cr.NoLocal = c.GlobalBool("no-local") cr.ShowText = true - cr.ServerPort = c.String("port") + cr.RelayWebsocketPort = c.String("port") + cr.RelayTCPPort = c.String("tcp-port") cr.CurveType = c.String("curve") return nil } diff --git a/src/comm/comm.go b/src/comm/comm.go new file mode 100644 index 0000000..afb6785 --- /dev/null +++ b/src/comm/comm.go @@ -0,0 +1,77 @@ +package comm + +import ( + "net" + "strings" + "time" + + "github.com/schollz/croc/src/models" +) + +// Comm is some basic TCP communication +type Comm struct { + connection net.Conn +} + +// New returns a new comm +func New(c net.Conn) Comm { + return Comm{c} +} + +// Connection returns the net.Conn connection +func (c Comm) Connection() net.Conn { + return c.connection +} + +func (c Comm) Write(b []byte) (int, error) { + return c.connection.Write(b) +} + +func (c Comm) Read() (buf []byte, err error) { + buf = make([]byte, models.WEBSOCKET_BUFFER_SIZE) + n, err := c.connection.Read(buf) + buf = buf[:n] + return +} + +// Send a message +func (c Comm) Send(message string) (err error) { + message = fillString(message, models.TCP_BUFFER_SIZE) + _, err = c.connection.Write([]byte(message)) + return +} + +// Receive a message +func (c Comm) Receive() (s string, err error) { + messageByte := make([]byte, models.TCP_BUFFER_SIZE) + err = c.connection.SetReadDeadline(time.Now().Add(60 * time.Minute)) + if err != nil { + return + } + err = c.connection.SetDeadline(time.Now().Add(60 * time.Minute)) + if err != nil { + return + } + err = c.connection.SetWriteDeadline(time.Now().Add(60 * time.Minute)) + if err != nil { + return + } + _, err = c.connection.Read(messageByte) + if err != nil { + return + } + s = strings.TrimRight(string(messageByte), ":") + return +} + +func fillString(returnString string, toLength int) string { + for { + lengthString := len(returnString) + if lengthString < toLength { + returnString = returnString + ":" + continue + } + break + } + return returnString +} diff --git a/src/croc/croc.go b/src/croc/croc.go index bbd7027..cda398b 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -18,14 +18,17 @@ type Croc struct { ShowText bool // Options for relay - ServerPort string - CurveType string + RelayWebsocketPort string + RelayTCPPort string + CurveType string // Options for connecting to server - WebsocketAddress string - Timeout time.Duration - LocalOnly bool - NoLocal bool + Address string + AddressTCPPort string + AddressWebsocketPort string + Timeout time.Duration + LocalOnly bool + NoLocal bool // Options for file transfering UseEncryption bool @@ -48,7 +51,6 @@ type Croc struct { // Init will initiate with the default parameters func Init(debug bool) (c *Croc) { c = new(Croc) - c.ServerPort = "8152" c.CurveType = "siec" c.UseCompression = true c.UseEncryption = true diff --git a/src/croc/sending.go b/src/croc/sending.go index 1b8f392..f044d68 100644 --- a/src/croc/sending.go +++ b/src/croc/sending.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "os/signal" + "strings" "time" log "github.com/cihub/seelog" @@ -28,7 +29,7 @@ func (c *Croc) Send(fname, codephrase string) (err error) { if !c.LocalOnly { go func() { // atttempt to connect to public relay - errChan <- c.sendReceive(c.WebsocketAddress, fname, codephrase, true, false) + errChan <- c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPort, fname, codephrase, true, false) }() } else { waitingFor = 1 @@ -38,7 +39,7 @@ func (c *Croc) Send(fname, codephrase string) (err error) { if !c.NoLocal { go func() { // start own relay and connect to it - go relay.Run(c.ServerPort) + go relay.Run(c.RelayWebsocketPort, "") time.Sleep(250 * time.Millisecond) // race condition here, but this should work most of the time :( // broadcast for peer discovery @@ -48,12 +49,12 @@ func (c *Croc) Send(fname, codephrase string) (err error) { Limit: 1, TimeLimit: 600 * time.Second, Delay: 50 * time.Millisecond, - Payload: []byte(c.ServerPort), + Payload: []byte(c.RelayWebsocketPort + "-" + c.RelayTCPPort), }) }() // connect to own relay - errChan <- c.sendReceive("ws://localhost:"+c.ServerPort, fname, codephrase, true, true) + errChan <- c.sendReceive("localhost", c.RelayWebsocketPort, c.RelayTCPPort, fname, codephrase, true, true) }() } else { waitingFor = 1 @@ -95,7 +96,11 @@ func (c *Croc) Receive(codephrase string) (err error) { if err == nil { if resp.StatusCode == http.StatusOK { // we connected, so use this - return c.sendReceive(fmt.Sprintf("ws://%s:%s", discovered[0].Address, discovered[0].Payload), "", codephrase, false, true) + ports := strings.Split(string(discovered[0].Payload), "-") + if len(ports) != 2 { + return errors.New("bad payload") + } + return c.sendReceive(discovered[0].Address, ports[0], ports[1], "", codephrase, false, true) } } else { log.Debugf("could not connect: %s", err.Error()) @@ -108,13 +113,13 @@ func (c *Croc) Receive(codephrase string) (err error) { // use public relay if !c.LocalOnly { log.Debug("using public relay") - return c.sendReceive(c.WebsocketAddress, "", codephrase, false, false) + return c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPort, "", codephrase, false, false) } return errors.New("must use local or public relay") } -func (c *Croc) sendReceive(websocketAddress, fname, codephrase string, isSender bool, isLocal bool) (err error) { +func (c *Croc) sendReceive(address, websocketPort, tcpPort, fname, codephrase string, isSender bool, isLocal bool) (err error) { defer log.Flush() if len(codephrase) < 4 { return fmt.Errorf("codephrase is too short") @@ -126,8 +131,11 @@ func (c *Croc) sendReceive(websocketAddress, fname, codephrase string, isSender done := make(chan struct{}) // connect to server - log.Debugf("connecting to %s", websocketAddress+"/ws?room="+codephrase[:3]) - sock, _, err := websocket.DefaultDialer.Dial(websocketAddress+"/ws?room="+codephrase[:3], nil) + log.Debugf("connecting to %s", address+"/ws?room="+codephrase[:3]) + if len(websocketPort) > 0 { + address += ":" + websocketPort + } + sock, _, err := websocket.DefaultDialer.Dial("ws://"+address+"/ws?room="+codephrase[:3], nil) if err != nil { return } @@ -176,5 +184,5 @@ func (c *Croc) sendReceive(websocketAddress, fname, codephrase string, isSender // Relay will start a relay on the specified port func (c *Croc) Relay() (err error) { - return relay.Run(c.ServerPort) + return relay.Run(c.RelayWebsocketPort, c.RelayTCPPort) } diff --git a/src/models/constants.go b/src/models/constants.go index 04bcd8b..3196ca4 100644 --- a/src/models/constants.go +++ b/src/models/constants.go @@ -1,3 +1,4 @@ package models const WEBSOCKET_BUFFER_SIZE = 1024 * 1024 * 32 +const TCP_BUFFER_SIZE = 1024 diff --git a/src/recipient/recipient.go b/src/recipient/recipient.go index 4f5bbb9..5c97d86 100644 --- a/src/recipient/recipient.go +++ b/src/recipient/recipient.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io/ioutil" + "net" "os" "strings" "time" @@ -14,6 +15,7 @@ import ( log "github.com/cihub/seelog" "github.com/gorilla/websocket" + "github.com/schollz/croc/src/comm" "github.com/schollz/croc/src/compress" "github.com/schollz/croc/src/crypt" "github.com/schollz/croc/src/logger" @@ -46,6 +48,7 @@ func receive(isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, var transferTime time.Duration var hash256 []byte var otherIP string + var tcpConnection comm.Comm // start a spinner spin := spinner.New(spinner.CharSets[9], 100*time.Millisecond) @@ -156,6 +159,15 @@ func receive(isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, } } + // connect to TCP to receive file + if !isLocal { + tcpConnection, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%x", sessionKey)), "localhost:8154") + if err != nil { + log.Error(err) + return err + } + } + // await file f, err := os.Create(fstats.SentName) if err != nil { @@ -173,12 +185,23 @@ func receive(isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, c.WriteMessage(websocket.BinaryMessage, []byte("ready")) startTime := time.Now() for { - messageType, message, err := c.ReadMessage() - if err != nil { - return err + if isLocal { + var messageType int + // read from websockets + messageType, message, err = c.ReadMessage() + if messageType != websocket.BinaryMessage { + continue + } + } else { + // read from TCP connection + message, err = tcpConnection.Read() + if bytes.Equal(message, []byte("end")) { + break + } } - if messageType != websocket.BinaryMessage { - continue + if err != nil { + log.Error(err) + return err } // // tell the sender that we recieved this packet @@ -289,3 +312,34 @@ func receive(isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool, step++ } } + +func connectToTCPServer(room string, address string) (com comm.Comm, err error) { + connection, err := net.Dial("tcp", address) + if err != nil { + return + } + connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) + connection.SetDeadline(time.Now().Add(3 * time.Hour)) + connection.SetWriteDeadline(time.Now().Add(3 * time.Hour)) + + com = comm.New(connection) + ok, err := com.Receive() + if err != nil { + return + } + log.Debugf("server says: %s", ok) + + err = com.Send(room) + if err != nil { + return + } + ok, err = com.Receive() + log.Debugf("server says: %s", ok) + if err != nil { + return + } + if ok != "recipient" { + err = errors.New(ok) + } + return +} diff --git a/src/relay/relay.go b/src/relay/relay.go index 19830c0..42aa596 100644 --- a/src/relay/relay.go +++ b/src/relay/relay.go @@ -6,14 +6,19 @@ import ( log "github.com/cihub/seelog" "github.com/schollz/croc/src/logger" + "github.com/schollz/croc/src/tcp" ) var DebugLevel string // Run is the async operation for running a server -func Run(port string) (err error) { +func Run(port string, tcpPort string) (err error) { logger.SetLogLevel(DebugLevel) + if tcpPort != "" { + go tcp.Run(DebugLevel, tcpPort) + } + go h.run() log.Debug("running relay on " + port) http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { diff --git a/src/sender/sender.go b/src/sender/sender.go index 2a946a7..d54932c 100644 --- a/src/sender/sender.go +++ b/src/sender/sender.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "net" "os" "path/filepath" "strings" @@ -13,6 +14,7 @@ import ( log "github.com/cihub/seelog" "github.com/gorilla/websocket" "github.com/pkg/errors" + "github.com/schollz/croc/src/comm" "github.com/schollz/croc/src/compress" "github.com/schollz/croc/src/crypt" "github.com/schollz/croc/src/logger" @@ -48,6 +50,8 @@ func send(isLocal bool, c *websocket.Conn, fname string, codephrase string, useC var fileHash []byte var otherIP string var startTransfer time.Time + var tcpConnection comm.Comm + fileReady := make(chan error) // normalize the file name @@ -191,6 +195,15 @@ func send(isLocal bool, c *websocket.Conn, fname string, codephrase string, useC return errors.New("recipient refused file") } + if !isLocal { + // connection to TCP + tcpConnection, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%x", sessionKey)), "localhost:8154") + if err != nil { + log.Error(err) + return + } + } + fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", otherIP) // send file, compure hash simultaneously startTransfer = time.Now() @@ -220,19 +233,25 @@ func send(isLocal bool, c *websocket.Conn, fname string, codephrase string, useC return err } - // send message - err = c.WriteMessage(websocket.BinaryMessage, encBytes) + if isLocal { + // write data to websockets + err = c.WriteMessage(websocket.BinaryMessage, encBytes) + } else { + // write data to tcp connection + _, err = tcpConnection.Write(encBytes) + } if err != nil { err = errors.Wrap(err, "problem writing message") return err } - // // wait for ok - // c.ReadMessage() } if err != nil { if err != io.EOF { log.Error(err) } + if !isLocal { + tcpConnection.Write([]byte("end")) + } break } } @@ -271,3 +290,34 @@ func send(isLocal bool, c *websocket.Conn, fname string, codephrase string, useC step++ } } + +func connectToTCPServer(room string, address string) (com comm.Comm, err error) { + connection, err := net.Dial("tcp", address) + if err != nil { + return + } + connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) + connection.SetDeadline(time.Now().Add(3 * time.Hour)) + connection.SetWriteDeadline(time.Now().Add(3 * time.Hour)) + + com = comm.New(connection) + ok, err := com.Receive() + if err != nil { + return + } + log.Debugf("server says: %s", ok) + + err = com.Send(room) + if err != nil { + return + } + ok, err = com.Receive() + log.Debugf("server says: %s", ok) + if err != nil { + return + } + if ok != "sender" { + err = errors.New(ok) + } + return +} diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go new file mode 100644 index 0000000..34e568a --- /dev/null +++ b/src/tcp/tcp.go @@ -0,0 +1,167 @@ +package tcp + +import ( + "net" + "sync" + "time" + + log "github.com/cihub/seelog" + "github.com/pkg/errors" + "github.com/schollz/croc/src/comm" + "github.com/schollz/croc/src/logger" + "github.com/schollz/croc/src/models" +) + +type roomInfo struct { + receiver comm.Comm + opened time.Time +} + +type roomMap struct { + rooms map[string]roomInfo + sync.Mutex +} + +var rooms roomMap + +// Run starts a tcp listener, run async +func Run(debugLevel, port string) { + logger.SetLogLevel(debugLevel) + rooms.Lock() + rooms.rooms = make(map[string]roomInfo) + rooms.Unlock() + err := run(port) + if err != nil { + log.Error(err) + } +} + +func run(port string) (err error) { + log.Debugf("starting TCP server on " + port) + server, err := net.Listen("tcp", "0.0.0.0:"+port) + if err != nil { + return errors.Wrap(err, "Error listening on :"+port) + } + defer server.Close() + // spawn a new goroutine whenever a client connects + for { + connection, err := server.Accept() + if err != nil { + return errors.Wrap(err, "problem accepting connection") + } + log.Debugf("client %s connected", connection.RemoteAddr().String()) + go func(port string, connection net.Conn) { + errCommunication := clientCommuncation(port, comm.New(connection)) + if errCommunication != nil { + log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error()) + } + }(port, connection) + } +} + +func clientCommuncation(port string, c comm.Comm) (err error) { + // send ok to tell client they are connected + err = c.Send("ok") + if err != nil { + return + } + + // wait for client to tell me which room they want + room, err := c.Receive() + if err != nil { + return + } + + rooms.Lock() + // first connection is always the receiver + if _, ok := rooms.rooms[room]; !ok { + rooms.rooms[room] = roomInfo{ + receiver: c, + opened: time.Now(), + } + rooms.Unlock() + // tell the client that they got the room + err = c.Send("recipient") + if err != nil { + return + } + return nil + } + receiver := rooms.rooms[room].receiver + rooms.Unlock() + + // second connection is the sender, time to staple connections + var wg sync.WaitGroup + wg.Add(1) + + // start piping + go func(com1, com2 comm.Comm, wg *sync.WaitGroup) { + log.Debug("starting pipes") + pipe(com1.Connection(), com2.Connection()) + wg.Done() + log.Debug("done piping") + }(c, receiver, &wg) + + // tell the sender everything is ready + err = c.Send("sender") + if err != nil { + return + } + wg.Wait() + + // delete room + rooms.Lock() + log.Debugf("deleting room: %s", room) + delete(rooms.rooms, room) + rooms.Unlock() + return nil +} + +// chanFromConn creates a channel from a Conn object, and sends everything it +// Read()s from the socket to the channel. +func chanFromConn(conn net.Conn) chan []byte { + c := make(chan []byte) + + go func() { + b := make([]byte, models.WEBSOCKET_BUFFER_SIZE) + + for { + n, err := conn.Read(b) + if n > 0 { + res := make([]byte, n) + // Copy the buffer so it doesn't get changed while read by the recipient. + copy(res, b[:n]) + c <- res + } + if err != nil { + c <- nil + break + } + } + }() + + return c +} + +// pipe creates a full-duplex pipe between the two sockets and +// transfers data from one to the other. +func pipe(conn1 net.Conn, conn2 net.Conn) { + chan1 := chanFromConn(conn1) + chan2 := chanFromConn(conn2) + + for { + select { + case b1 := <-chan1: + if b1 == nil { + return + } + conn2.Write(b1) + + case b2 := <-chan2: + if b2 == nil { + return + } + conn1.Write(b2) + } + } +} diff --git a/src/utils/hash.go b/src/utils/hash.go index 83d512b..34d1c0d 100644 --- a/src/utils/hash.go +++ b/src/utils/hash.go @@ -2,6 +2,8 @@ package utils import ( "crypto/md5" + "crypto/sha256" + "fmt" "io" "os" ) @@ -22,3 +24,10 @@ func HashFile(fname string) (hash256 []byte, err error) { hash256 = h.Sum(nil) return } + +// SHA256 returns sha256 sum +func SHA256(s string) string { + sha := sha256.New() + sha.Write([]byte(s)) + return fmt.Sprintf("%x", sha.Sum(nil)) +}