allow TCP connections as alternative

This commit is contained in:
Zack Scholl 2018-09-23 12:34:29 -07:00
parent 177612f112
commit 384de31c5a
10 changed files with 409 additions and 30 deletions

12
main.go
View File

@ -61,13 +61,16 @@ func main() {
}, },
} }
app.Flags = []cli.Flag{ app.Flags = []cli.Flag{
cli.StringFlag{Name: "relay", Value: "ws://198.199.67.130:8153"}, cli.StringFlag{Name: "addr", Value: "198.199.67.130", Usage: "address of the public relay"},
cli.StringFlag{Name: "addr-ws", Value: "8153", Usage: "port of the public relay websocket server to connect"},
cli.StringFlag{Name: "addr-tcp", Value: "8154", Usage: "tcp port of the public relay serer to connect"},
cli.BoolFlag{Name: "no-local", Usage: "disable local mode"}, cli.BoolFlag{Name: "no-local", Usage: "disable local mode"},
cli.BoolFlag{Name: "local", Usage: "use only local mode"}, cli.BoolFlag{Name: "local", Usage: "use only local mode"},
cli.BoolFlag{Name: "debug", Usage: "increase verbosity (a lot)"}, cli.BoolFlag{Name: "debug", Usage: "increase verbosity (a lot)"},
cli.BoolFlag{Name: "yes", Usage: "automatically agree to all prompts"}, cli.BoolFlag{Name: "yes", Usage: "automatically agree to all prompts"},
cli.BoolFlag{Name: "stdout", Usage: "redirect file to stdout"}, cli.BoolFlag{Name: "stdout", Usage: "redirect file to stdout"},
cli.StringFlag{Name: "port", Value: "8153", Usage: "port that the websocket listens on"}, cli.StringFlag{Name: "port", Value: "8153", Usage: "port that the websocket listens on"},
cli.StringFlag{Name: "tcp-port", Value: "8154", Usage: "port that the tcp server listens on"},
cli.StringFlag{Name: "curve", Value: "siec", Usage: "specify elliptic curve to use (p224, p256, p384, p521, siec)"}, cli.StringFlag{Name: "curve", Value: "siec", Usage: "specify elliptic curve to use (p224, p256, p384, p521, siec)"},
} }
app.EnableBashCompletion = true app.EnableBashCompletion = true
@ -82,13 +85,16 @@ func main() {
app.Before = func(c *cli.Context) error { app.Before = func(c *cli.Context) error {
cr = croc.Init(c.GlobalBool("debug")) cr = croc.Init(c.GlobalBool("debug"))
cr.AllowLocalDiscovery = true cr.AllowLocalDiscovery = true
cr.WebsocketAddress = c.GlobalString("relay") cr.Address = c.GlobalString("addr")
cr.AddressTCPPort = c.GlobalString("addr-tcp")
cr.AddressWebsocketPort = c.GlobalString("addr-ws")
cr.NoRecipientPrompt = c.GlobalBool("yes") cr.NoRecipientPrompt = c.GlobalBool("yes")
cr.Stdout = c.GlobalBool("stdout") cr.Stdout = c.GlobalBool("stdout")
cr.LocalOnly = c.GlobalBool("local") cr.LocalOnly = c.GlobalBool("local")
cr.NoLocal = c.GlobalBool("no-local") cr.NoLocal = c.GlobalBool("no-local")
cr.ShowText = true cr.ShowText = true
cr.ServerPort = c.String("port") cr.RelayWebsocketPort = c.String("port")
cr.RelayTCPPort = c.String("tcp-port")
cr.CurveType = c.String("curve") cr.CurveType = c.String("curve")
return nil return nil
} }

77
src/comm/comm.go Normal file
View File

@ -0,0 +1,77 @@
package comm
import (
"net"
"strings"
"time"
"github.com/schollz/croc/src/models"
)
// Comm is some basic TCP communication
type Comm struct {
connection net.Conn
}
// New returns a new comm
func New(c net.Conn) Comm {
return Comm{c}
}
// Connection returns the net.Conn connection
func (c Comm) Connection() net.Conn {
return c.connection
}
func (c Comm) Write(b []byte) (int, error) {
return c.connection.Write(b)
}
func (c Comm) Read() (buf []byte, err error) {
buf = make([]byte, models.WEBSOCKET_BUFFER_SIZE)
n, err := c.connection.Read(buf)
buf = buf[:n]
return
}
// Send a message
func (c Comm) Send(message string) (err error) {
message = fillString(message, models.TCP_BUFFER_SIZE)
_, err = c.connection.Write([]byte(message))
return
}
// Receive a message
func (c Comm) Receive() (s string, err error) {
messageByte := make([]byte, models.TCP_BUFFER_SIZE)
err = c.connection.SetReadDeadline(time.Now().Add(60 * time.Minute))
if err != nil {
return
}
err = c.connection.SetDeadline(time.Now().Add(60 * time.Minute))
if err != nil {
return
}
err = c.connection.SetWriteDeadline(time.Now().Add(60 * time.Minute))
if err != nil {
return
}
_, err = c.connection.Read(messageByte)
if err != nil {
return
}
s = strings.TrimRight(string(messageByte), ":")
return
}
func fillString(returnString string, toLength int) string {
for {
lengthString := len(returnString)
if lengthString < toLength {
returnString = returnString + ":"
continue
}
break
}
return returnString
}

View File

@ -18,14 +18,17 @@ type Croc struct {
ShowText bool ShowText bool
// Options for relay // Options for relay
ServerPort string RelayWebsocketPort string
CurveType string RelayTCPPort string
CurveType string
// Options for connecting to server // Options for connecting to server
WebsocketAddress string Address string
Timeout time.Duration AddressTCPPort string
LocalOnly bool AddressWebsocketPort string
NoLocal bool Timeout time.Duration
LocalOnly bool
NoLocal bool
// Options for file transfering // Options for file transfering
UseEncryption bool UseEncryption bool
@ -48,7 +51,6 @@ type Croc struct {
// Init will initiate with the default parameters // Init will initiate with the default parameters
func Init(debug bool) (c *Croc) { func Init(debug bool) (c *Croc) {
c = new(Croc) c = new(Croc)
c.ServerPort = "8152"
c.CurveType = "siec" c.CurveType = "siec"
c.UseCompression = true c.UseCompression = true
c.UseEncryption = true c.UseEncryption = true

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strings"
"time" "time"
log "github.com/cihub/seelog" log "github.com/cihub/seelog"
@ -28,7 +29,7 @@ func (c *Croc) Send(fname, codephrase string) (err error) {
if !c.LocalOnly { if !c.LocalOnly {
go func() { go func() {
// atttempt to connect to public relay // atttempt to connect to public relay
errChan <- c.sendReceive(c.WebsocketAddress, fname, codephrase, true, false) errChan <- c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPort, fname, codephrase, true, false)
}() }()
} else { } else {
waitingFor = 1 waitingFor = 1
@ -38,7 +39,7 @@ func (c *Croc) Send(fname, codephrase string) (err error) {
if !c.NoLocal { if !c.NoLocal {
go func() { go func() {
// start own relay and connect to it // start own relay and connect to it
go relay.Run(c.ServerPort) go relay.Run(c.RelayWebsocketPort, "")
time.Sleep(250 * time.Millisecond) // race condition here, but this should work most of the time :( time.Sleep(250 * time.Millisecond) // race condition here, but this should work most of the time :(
// broadcast for peer discovery // broadcast for peer discovery
@ -48,12 +49,12 @@ func (c *Croc) Send(fname, codephrase string) (err error) {
Limit: 1, Limit: 1,
TimeLimit: 600 * time.Second, TimeLimit: 600 * time.Second,
Delay: 50 * time.Millisecond, Delay: 50 * time.Millisecond,
Payload: []byte(c.ServerPort), Payload: []byte(c.RelayWebsocketPort + "-" + c.RelayTCPPort),
}) })
}() }()
// connect to own relay // connect to own relay
errChan <- c.sendReceive("ws://localhost:"+c.ServerPort, fname, codephrase, true, true) errChan <- c.sendReceive("localhost", c.RelayWebsocketPort, c.RelayTCPPort, fname, codephrase, true, true)
}() }()
} else { } else {
waitingFor = 1 waitingFor = 1
@ -95,7 +96,11 @@ func (c *Croc) Receive(codephrase string) (err error) {
if err == nil { if err == nil {
if resp.StatusCode == http.StatusOK { if resp.StatusCode == http.StatusOK {
// we connected, so use this // we connected, so use this
return c.sendReceive(fmt.Sprintf("ws://%s:%s", discovered[0].Address, discovered[0].Payload), "", codephrase, false, true) ports := strings.Split(string(discovered[0].Payload), "-")
if len(ports) != 2 {
return errors.New("bad payload")
}
return c.sendReceive(discovered[0].Address, ports[0], ports[1], "", codephrase, false, true)
} }
} else { } else {
log.Debugf("could not connect: %s", err.Error()) log.Debugf("could not connect: %s", err.Error())
@ -108,13 +113,13 @@ func (c *Croc) Receive(codephrase string) (err error) {
// use public relay // use public relay
if !c.LocalOnly { if !c.LocalOnly {
log.Debug("using public relay") log.Debug("using public relay")
return c.sendReceive(c.WebsocketAddress, "", codephrase, false, false) return c.sendReceive(c.Address, c.AddressWebsocketPort, c.AddressTCPPort, "", codephrase, false, false)
} }
return errors.New("must use local or public relay") return errors.New("must use local or public relay")
} }
func (c *Croc) sendReceive(websocketAddress, fname, codephrase string, isSender bool, isLocal bool) (err error) { func (c *Croc) sendReceive(address, websocketPort, tcpPort, fname, codephrase string, isSender bool, isLocal bool) (err error) {
defer log.Flush() defer log.Flush()
if len(codephrase) < 4 { if len(codephrase) < 4 {
return fmt.Errorf("codephrase is too short") return fmt.Errorf("codephrase is too short")
@ -126,8 +131,11 @@ func (c *Croc) sendReceive(websocketAddress, fname, codephrase string, isSender
done := make(chan struct{}) done := make(chan struct{})
// connect to server // connect to server
log.Debugf("connecting to %s", websocketAddress+"/ws?room="+codephrase[:3]) log.Debugf("connecting to %s", address+"/ws?room="+codephrase[:3])
sock, _, err := websocket.DefaultDialer.Dial(websocketAddress+"/ws?room="+codephrase[:3], nil) if len(websocketPort) > 0 {
address += ":" + websocketPort
}
sock, _, err := websocket.DefaultDialer.Dial("ws://"+address+"/ws?room="+codephrase[:3], nil)
if err != nil { if err != nil {
return return
} }
@ -176,5 +184,5 @@ func (c *Croc) sendReceive(websocketAddress, fname, codephrase string, isSender
// Relay will start a relay on the specified port // Relay will start a relay on the specified port
func (c *Croc) Relay() (err error) { func (c *Croc) Relay() (err error) {
return relay.Run(c.ServerPort) return relay.Run(c.RelayWebsocketPort, c.RelayTCPPort)
} }

View File

@ -1,3 +1,4 @@
package models package models
const WEBSOCKET_BUFFER_SIZE = 1024 * 1024 * 32 const WEBSOCKET_BUFFER_SIZE = 1024 * 1024 * 32
const TCP_BUFFER_SIZE = 1024

View File

@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"os" "os"
"strings" "strings"
"time" "time"
@ -14,6 +15,7 @@ import (
log "github.com/cihub/seelog" log "github.com/cihub/seelog"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/schollz/croc/src/comm"
"github.com/schollz/croc/src/compress" "github.com/schollz/croc/src/compress"
"github.com/schollz/croc/src/crypt" "github.com/schollz/croc/src/crypt"
"github.com/schollz/croc/src/logger" "github.com/schollz/croc/src/logger"
@ -46,6 +48,7 @@ func receive(isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool,
var transferTime time.Duration var transferTime time.Duration
var hash256 []byte var hash256 []byte
var otherIP string var otherIP string
var tcpConnection comm.Comm
// start a spinner // start a spinner
spin := spinner.New(spinner.CharSets[9], 100*time.Millisecond) spin := spinner.New(spinner.CharSets[9], 100*time.Millisecond)
@ -156,6 +159,15 @@ func receive(isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool,
} }
} }
// connect to TCP to receive file
if !isLocal {
tcpConnection, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%x", sessionKey)), "localhost:8154")
if err != nil {
log.Error(err)
return err
}
}
// await file // await file
f, err := os.Create(fstats.SentName) f, err := os.Create(fstats.SentName)
if err != nil { if err != nil {
@ -173,12 +185,23 @@ func receive(isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool,
c.WriteMessage(websocket.BinaryMessage, []byte("ready")) c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
startTime := time.Now() startTime := time.Now()
for { for {
messageType, message, err := c.ReadMessage() if isLocal {
if err != nil { var messageType int
return err // read from websockets
messageType, message, err = c.ReadMessage()
if messageType != websocket.BinaryMessage {
continue
}
} else {
// read from TCP connection
message, err = tcpConnection.Read()
if bytes.Equal(message, []byte("end")) {
break
}
} }
if messageType != websocket.BinaryMessage { if err != nil {
continue log.Error(err)
return err
} }
// // tell the sender that we recieved this packet // // tell the sender that we recieved this packet
@ -289,3 +312,34 @@ func receive(isLocal bool, c *websocket.Conn, codephrase string, noPrompt bool,
step++ step++
} }
} }
func connectToTCPServer(room string, address string) (com comm.Comm, err error) {
connection, err := net.Dial("tcp", address)
if err != nil {
return
}
connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
connection.SetDeadline(time.Now().Add(3 * time.Hour))
connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
com = comm.New(connection)
ok, err := com.Receive()
if err != nil {
return
}
log.Debugf("server says: %s", ok)
err = com.Send(room)
if err != nil {
return
}
ok, err = com.Receive()
log.Debugf("server says: %s", ok)
if err != nil {
return
}
if ok != "recipient" {
err = errors.New(ok)
}
return
}

View File

@ -6,14 +6,19 @@ import (
log "github.com/cihub/seelog" log "github.com/cihub/seelog"
"github.com/schollz/croc/src/logger" "github.com/schollz/croc/src/logger"
"github.com/schollz/croc/src/tcp"
) )
var DebugLevel string var DebugLevel string
// Run is the async operation for running a server // Run is the async operation for running a server
func Run(port string) (err error) { func Run(port string, tcpPort string) (err error) {
logger.SetLogLevel(DebugLevel) logger.SetLogLevel(DebugLevel)
if tcpPort != "" {
go tcp.Run(DebugLevel, tcpPort)
}
go h.run() go h.run()
log.Debug("running relay on " + port) log.Debug("running relay on " + port)
http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -13,6 +14,7 @@ import (
log "github.com/cihub/seelog" log "github.com/cihub/seelog"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/schollz/croc/src/comm"
"github.com/schollz/croc/src/compress" "github.com/schollz/croc/src/compress"
"github.com/schollz/croc/src/crypt" "github.com/schollz/croc/src/crypt"
"github.com/schollz/croc/src/logger" "github.com/schollz/croc/src/logger"
@ -48,6 +50,8 @@ func send(isLocal bool, c *websocket.Conn, fname string, codephrase string, useC
var fileHash []byte var fileHash []byte
var otherIP string var otherIP string
var startTransfer time.Time var startTransfer time.Time
var tcpConnection comm.Comm
fileReady := make(chan error) fileReady := make(chan error)
// normalize the file name // normalize the file name
@ -191,6 +195,15 @@ func send(isLocal bool, c *websocket.Conn, fname string, codephrase string, useC
return errors.New("recipient refused file") return errors.New("recipient refused file")
} }
if !isLocal {
// connection to TCP
tcpConnection, err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%x", sessionKey)), "localhost:8154")
if err != nil {
log.Error(err)
return
}
}
fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", otherIP) fmt.Fprintf(os.Stderr, "\rSending (->%s)...\n", otherIP)
// send file, compure hash simultaneously // send file, compure hash simultaneously
startTransfer = time.Now() startTransfer = time.Now()
@ -220,19 +233,25 @@ func send(isLocal bool, c *websocket.Conn, fname string, codephrase string, useC
return err return err
} }
// send message if isLocal {
err = c.WriteMessage(websocket.BinaryMessage, encBytes) // write data to websockets
err = c.WriteMessage(websocket.BinaryMessage, encBytes)
} else {
// write data to tcp connection
_, err = tcpConnection.Write(encBytes)
}
if err != nil { if err != nil {
err = errors.Wrap(err, "problem writing message") err = errors.Wrap(err, "problem writing message")
return err return err
} }
// // wait for ok
// c.ReadMessage()
} }
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Error(err) log.Error(err)
} }
if !isLocal {
tcpConnection.Write([]byte("end"))
}
break break
} }
} }
@ -271,3 +290,34 @@ func send(isLocal bool, c *websocket.Conn, fname string, codephrase string, useC
step++ step++
} }
} }
func connectToTCPServer(room string, address string) (com comm.Comm, err error) {
connection, err := net.Dial("tcp", address)
if err != nil {
return
}
connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
connection.SetDeadline(time.Now().Add(3 * time.Hour))
connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
com = comm.New(connection)
ok, err := com.Receive()
if err != nil {
return
}
log.Debugf("server says: %s", ok)
err = com.Send(room)
if err != nil {
return
}
ok, err = com.Receive()
log.Debugf("server says: %s", ok)
if err != nil {
return
}
if ok != "sender" {
err = errors.New(ok)
}
return
}

167
src/tcp/tcp.go Normal file
View File

@ -0,0 +1,167 @@
package tcp
import (
"net"
"sync"
"time"
log "github.com/cihub/seelog"
"github.com/pkg/errors"
"github.com/schollz/croc/src/comm"
"github.com/schollz/croc/src/logger"
"github.com/schollz/croc/src/models"
)
type roomInfo struct {
receiver comm.Comm
opened time.Time
}
type roomMap struct {
rooms map[string]roomInfo
sync.Mutex
}
var rooms roomMap
// Run starts a tcp listener, run async
func Run(debugLevel, port string) {
logger.SetLogLevel(debugLevel)
rooms.Lock()
rooms.rooms = make(map[string]roomInfo)
rooms.Unlock()
err := run(port)
if err != nil {
log.Error(err)
}
}
func run(port string) (err error) {
log.Debugf("starting TCP server on " + port)
server, err := net.Listen("tcp", "0.0.0.0:"+port)
if err != nil {
return errors.Wrap(err, "Error listening on :"+port)
}
defer server.Close()
// spawn a new goroutine whenever a client connects
for {
connection, err := server.Accept()
if err != nil {
return errors.Wrap(err, "problem accepting connection")
}
log.Debugf("client %s connected", connection.RemoteAddr().String())
go func(port string, connection net.Conn) {
errCommunication := clientCommuncation(port, comm.New(connection))
if errCommunication != nil {
log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error())
}
}(port, connection)
}
}
func clientCommuncation(port string, c comm.Comm) (err error) {
// send ok to tell client they are connected
err = c.Send("ok")
if err != nil {
return
}
// wait for client to tell me which room they want
room, err := c.Receive()
if err != nil {
return
}
rooms.Lock()
// first connection is always the receiver
if _, ok := rooms.rooms[room]; !ok {
rooms.rooms[room] = roomInfo{
receiver: c,
opened: time.Now(),
}
rooms.Unlock()
// tell the client that they got the room
err = c.Send("recipient")
if err != nil {
return
}
return nil
}
receiver := rooms.rooms[room].receiver
rooms.Unlock()
// second connection is the sender, time to staple connections
var wg sync.WaitGroup
wg.Add(1)
// start piping
go func(com1, com2 comm.Comm, wg *sync.WaitGroup) {
log.Debug("starting pipes")
pipe(com1.Connection(), com2.Connection())
wg.Done()
log.Debug("done piping")
}(c, receiver, &wg)
// tell the sender everything is ready
err = c.Send("sender")
if err != nil {
return
}
wg.Wait()
// delete room
rooms.Lock()
log.Debugf("deleting room: %s", room)
delete(rooms.rooms, room)
rooms.Unlock()
return nil
}
// chanFromConn creates a channel from a Conn object, and sends everything it
// Read()s from the socket to the channel.
func chanFromConn(conn net.Conn) chan []byte {
c := make(chan []byte)
go func() {
b := make([]byte, models.WEBSOCKET_BUFFER_SIZE)
for {
n, err := conn.Read(b)
if n > 0 {
res := make([]byte, n)
// Copy the buffer so it doesn't get changed while read by the recipient.
copy(res, b[:n])
c <- res
}
if err != nil {
c <- nil
break
}
}
}()
return c
}
// pipe creates a full-duplex pipe between the two sockets and
// transfers data from one to the other.
func pipe(conn1 net.Conn, conn2 net.Conn) {
chan1 := chanFromConn(conn1)
chan2 := chanFromConn(conn2)
for {
select {
case b1 := <-chan1:
if b1 == nil {
return
}
conn2.Write(b1)
case b2 := <-chan2:
if b2 == nil {
return
}
conn1.Write(b2)
}
}
}

View File

@ -2,6 +2,8 @@ package utils
import ( import (
"crypto/md5" "crypto/md5"
"crypto/sha256"
"fmt"
"io" "io"
"os" "os"
) )
@ -22,3 +24,10 @@ func HashFile(fname string) (hash256 []byte, err error) {
hash256 = h.Sum(nil) hash256 = h.Sum(nil)
return return
} }
// SHA256 returns sha256 sum
func SHA256(s string) string {
sha := sha256.New()
sha.Write([]byte(s))
return fmt.Sprintf("%x", sha.Sum(nil))
}