mirror of https://github.com/schollz/croc.git
219 lines
4.9 KiB
Go
219 lines
4.9 KiB
Go
package croc
|
|
|
|
import (
	"io"
	"net"
	"strings"
	"sync"
	"time"

	log "github.com/cihub/seelog"
	"github.com/pkg/errors"
)
|
|
|
|
func (c *Croc) startRelay() {
|
|
ports := c.TcpPorts
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(ports))
|
|
for _, port := range ports {
|
|
go func(port string, wg *sync.WaitGroup) {
|
|
defer wg.Done()
|
|
log.Debugf("listening on port %s", port)
|
|
if err := c.listener(port); err != nil {
|
|
log.Error(err)
|
|
return
|
|
}
|
|
}(port, &wg)
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func (c *Croc) listener(port string) (err error) {
|
|
server, err := net.Listen("tcp", "0.0.0.0:"+port)
|
|
if err != nil {
|
|
return errors.Wrap(err, "Error listening on :"+port)
|
|
}
|
|
defer server.Close()
|
|
// spawn a new goroutine whenever a client connects
|
|
for {
|
|
connection, err := server.Accept()
|
|
if err != nil {
|
|
return errors.Wrap(err, "problem accepting connection")
|
|
}
|
|
log.Debugf("client %s connected", connection.RemoteAddr().String())
|
|
go func(port string, connection net.Conn) {
|
|
errCommunication := c.clientCommuncation(port, connection)
|
|
if errCommunication != nil {
|
|
log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error())
|
|
}
|
|
}(port, connection)
|
|
}
|
|
}
|
|
|
|
func (c *Croc) clientCommuncation(port string, connection net.Conn) (err error) {
|
|
var con1, con2 net.Conn
|
|
|
|
// get the channel and UUID from the client
|
|
err = sendMessage("channel and uuid?", connection)
|
|
if err != nil {
|
|
return
|
|
}
|
|
channel, err := receiveMessage(connection)
|
|
if err != nil {
|
|
return
|
|
}
|
|
uuid, err := receiveMessage(connection)
|
|
if err != nil {
|
|
return
|
|
}
|
|
log.Debugf("%s connected to port %s on channel %s and uuid %s", connection.RemoteAddr().String(), port, channel, uuid)
|
|
|
|
// validate channel and UUID
|
|
c.rs.Lock()
|
|
if _, ok := c.rs.channel[channel]; !ok {
|
|
c.rs.Unlock()
|
|
err = errors.Errorf("channel %s does not exist", channel)
|
|
return
|
|
}
|
|
if uuid != c.rs.channel[channel].uuids[0] &&
|
|
uuid != c.rs.channel[channel].uuids[1] {
|
|
c.rs.Unlock()
|
|
err = errors.Errorf("uuid '%s' is invalid", uuid)
|
|
return
|
|
}
|
|
role := 0
|
|
if uuid == c.rs.channel[channel].uuids[1] {
|
|
role = 1
|
|
}
|
|
|
|
if _, ok := c.rs.channel[channel].connection[port]; !ok {
|
|
c.rs.channel[channel].connection[port] = [2]net.Conn{nil, nil}
|
|
}
|
|
con1 = c.rs.channel[channel].connection[port][0]
|
|
con2 = c.rs.channel[channel].connection[port][1]
|
|
if role == 0 {
|
|
con1 = connection
|
|
} else {
|
|
con2 = connection
|
|
}
|
|
log.Debug(c.rs.channel[channel].connection[port])
|
|
c.rs.channel[channel].connection[port] = [2]net.Conn{con1, con2}
|
|
ports := c.rs.channel[channel].Ports
|
|
c.rs.Unlock()
|
|
|
|
if con1 != nil && con2 != nil {
|
|
log.Debugf("beginning the piping")
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
|
|
// start piping
|
|
go func(con1 net.Conn, con2 net.Conn, wg *sync.WaitGroup) {
|
|
pipe(con1, con2)
|
|
wg.Done()
|
|
log.Debug("done piping")
|
|
}(con1, con2, &wg)
|
|
|
|
if port == ports[0] {
|
|
// then set transfer ready
|
|
c.rs.Lock()
|
|
c.rs.channel[channel].TransferReady = true
|
|
c.rs.channel[channel].websocketConn[0].WriteJSON(c.rs.channel[channel])
|
|
c.rs.channel[channel].websocketConn[1].WriteJSON(c.rs.channel[channel])
|
|
c.rs.Unlock()
|
|
log.Debugf("sent ready signal")
|
|
}
|
|
wg.Wait()
|
|
log.Debugf("finished transfer")
|
|
}
|
|
log.Debug("finished client communication")
|
|
return
|
|
}
|
|
|
|
func sendMessage(message string, connection net.Conn) (err error) {
|
|
message = fillString(message, bufferSize)
|
|
_, err = connection.Write([]byte(message))
|
|
return
|
|
}
|
|
|
|
func receiveMessage(connection net.Conn) (s string, err error) {
|
|
messageByte := make([]byte, bufferSize)
|
|
err = connection.SetReadDeadline(time.Now().Add(60 * time.Minute))
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = connection.SetDeadline(time.Now().Add(60 * time.Minute))
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = connection.SetWriteDeadline(time.Now().Add(60 * time.Minute))
|
|
if err != nil {
|
|
return
|
|
}
|
|
_, err = connection.Read(messageByte)
|
|
if err != nil {
|
|
return
|
|
}
|
|
s = strings.TrimRight(string(messageByte), ":")
|
|
return
|
|
}
|
|
|
|
// fillString right-pads returnString with ':' until it is toLength bytes long.
// A string that is already toLength bytes or longer is returned unchanged
// (never truncated), and a non-positive toLength is a no-op.
//
// receiveMessage strips this ':' padding again, so it serves as the filler of
// the relay's fixed-size message frames. The padding is built in a single
// allocation rather than by repeated concatenation, which was quadratic.
func fillString(returnString string, toLength int) string {
	if missing := toLength - len(returnString); missing > 0 {
		return returnString + strings.Repeat(":", missing)
	}
	return returnString
}
|
|
|
|
// chanFromConn creates a channel from a Conn object, and sends everything it
|
|
// Read()s from the socket to the channel.
|
|
func chanFromConn(conn net.Conn) chan []byte {
|
|
c := make(chan []byte)
|
|
|
|
go func() {
|
|
b := make([]byte, bufferSize)
|
|
|
|
for {
|
|
n, err := conn.Read(b)
|
|
if n > 0 {
|
|
res := make([]byte, n)
|
|
// Copy the buffer so it doesn't get changed while read by the recipient.
|
|
copy(res, b[:n])
|
|
c <- res
|
|
}
|
|
if err != nil {
|
|
c <- nil
|
|
break
|
|
}
|
|
}
|
|
}()
|
|
|
|
return c
|
|
}
|
|
|
|
// pipe creates a full-duplex pipe between the two sockets and
|
|
// transfers data from one to the other.
|
|
func pipe(conn1 net.Conn, conn2 net.Conn) {
|
|
chan1 := chanFromConn(conn1)
|
|
chan2 := chanFromConn(conn2)
|
|
|
|
for {
|
|
select {
|
|
case b1 := <-chan1:
|
|
if b1 == nil {
|
|
return
|
|
}
|
|
conn2.Write(b1)
|
|
|
|
case b2 := <-chan2:
|
|
if b2 == nil {
|
|
return
|
|
}
|
|
conn1.Write(b2)
|
|
}
|
|
}
|
|
}
|