fixed spelling connectionMap.receiver
Added method to detect if sender and receivers are already connected.
Added client code to correctly action "no" returned by the code being in use.
This commit is contained in:
lummie 2017-10-20 21:51:03 +01:00
parent 17a1f097c3
commit e2faa87b59
2 changed files with 720 additions and 673 deletions

View File

@ -1,425 +1,440 @@
package main
import (
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gosuri/uiprogress"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
type Connection struct {
Server string
File FileMetaData
NumberOfConnections int
Code string
HashedCode string
IsSender bool
Debug bool
DontEncrypt bool
bars []*uiprogress.Bar
rate int
}
type FileMetaData struct {
Name string
Size int
Hash string
}
func NewConnection(flags *Flags) *Connection {
c := new(Connection)
c.Debug = flags.Debug
c.DontEncrypt = flags.DontEncrypt
c.Server = flags.Server
c.Code = flags.Code
c.NumberOfConnections = flags.NumberOfConnections
c.rate = flags.Rate
if len(flags.File) > 0 {
c.File.Name = flags.File
c.IsSender = true
} else {
c.IsSender = false
}
log.SetFormatter(&log.TextFormatter{})
if c.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.WarnLevel)
}
return c
}
func (c *Connection) Run() error {
forceSingleThreaded := false
if c.IsSender {
fsize, err := FileSize(c.File.Name)
if err != nil {
return err
}
if fsize < MAX_NUMBER_THREADS*BUFFERSIZE {
forceSingleThreaded = true
log.Debug("forcing single thread")
}
}
log.Debug("checking code validity")
for {
// check code
goodCode := true
m := strings.Split(c.Code, "-")
numThreads, errParse := strconv.Atoi(m[0])
if len(m) < 2 {
goodCode = false
} else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) {
c.NumberOfConnections = MAX_NUMBER_THREADS
goodCode = false
} else if errParse != nil {
goodCode = false
}
log.Debug(m)
if !goodCode {
if c.IsSender {
if forceSingleThreaded {
c.NumberOfConnections = 1
}
c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName()
} else {
if len(c.Code) != 0 {
fmt.Println("Code must begin with number of threads (e.g. 3-some-code)")
}
c.Code = getInput("Enter receive code: ")
}
} else {
break
}
}
// assign number of connections
c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0])
if c.IsSender {
if c.DontEncrypt {
// don't encrypt
CopyFile(c.File.Name, c.File.Name+".enc")
} else {
// encrypt
log.Debug("encrypting...")
if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil {
return err
}
}
// get file hash
var err error
c.File.Hash, err = HashFile(c.File.Name)
if err != nil {
return err
}
// get file size
c.File.Size, err = FileSize(c.File.Name + ".enc")
if err != nil {
return err
}
}
return c.runClient()
}
// runClient spawns threads for parallel uplink/downlink via TCP
func (c *Connection) runClient() error {
logger := log.WithFields(log.Fields{
"code": c.Code,
"sender?": c.IsSender,
})
c.HashedCode = Hash(c.Code)
var wg sync.WaitGroup
wg.Add(c.NumberOfConnections)
uiprogress.Start()
if !c.Debug {
c.bars = make([]*uiprogress.Bar, c.NumberOfConnections)
}
gotOK := false
gotResponse := false
for id := 0; id < c.NumberOfConnections; id++ {
go func(id int) {
defer wg.Done()
port := strconv.Itoa(27001 + id)
connection, err := net.Dial("tcp", c.Server+":"+port)
if err != nil {
panic(err)
}
defer connection.Close()
message := receiveMessage(connection)
logger.Debugf("relay says: %s", message)
if c.IsSender {
logger.Debugf("telling relay: %s", "s."+c.Code)
metaData, err := json.Marshal(c.File)
if err != nil {
log.Error(err)
}
encryptedMetaData, salt, iv := Encrypt(metaData, c.Code)
sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection)
} else {
logger.Debugf("telling relay: %s", "r."+c.Code)
sendMessage("r."+c.HashedCode+".0.0.0", connection)
}
if c.IsSender { // this is a sender
logger.Debug("waiting for ok from relay")
message = receiveMessage(connection)
logger.Debug("got ok from relay")
if id == 0 {
fmt.Printf("\nSending (->%s)..\n", message)
}
// wait for pipe to be made
time.Sleep(100 * time.Millisecond)
// Write data from file
logger.Debug("send file")
c.sendFile(id, connection)
} else { // this is a receiver
logger.Debug("waiting for meta data from sender")
message = receiveMessage(connection)
m := strings.Split(message, "-")
encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3]
encryptedBytes, err := hex.DecodeString(encryptedData)
if err != nil {
log.Error(err)
return
}
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt)
err = json.Unmarshal(decryptedBytes, &c.File)
if err != nil {
log.Error(err)
return
}
log.Debugf("meta data received: %v", c.File)
// have the main thread ask for the okay
if id == 0 {
fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name)
var sentFileNames []string
if fileAlreadyExists(sentFileNames, c.File.Name) {
fmt.Printf("Will not overwrite file!")
os.Exit(1)
}
getOK := getInput("ok? (y/n): ")
if getOK == "y" {
gotOK = true
sentFileNames = append(sentFileNames, c.File.Name)
}
gotResponse = true
}
// wait for the main thread to get the okay
for limit := 0; limit < 1000; limit++ {
if gotResponse {
break
}
time.Sleep(10 * time.Millisecond)
}
if !gotOK {
sendMessage("not ok", connection)
} else {
sendMessage("ok", connection)
logger.Debug("receive file")
fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress)
c.receiveFile(id, connection)
}
}
}(id)
}
wg.Wait()
if !c.IsSender {
if !gotOK {
return errors.New("Transfer interrupted")
}
c.catFile(c.File.Name)
log.Debugf("Code: [%s]", c.Code)
if c.DontEncrypt {
if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil {
return err
}
} else {
if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil {
return errors.Wrap(err, "Problem decrypting file")
}
}
if !c.Debug {
os.Remove(c.File.Name + ".enc")
}
fileHash, err := HashFile(c.File.Name)
if err != nil {
log.Error(err)
}
log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash)
log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash)
if c.File.Hash != fileHash {
return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
} else {
fmt.Printf("\nReceived file written to %s", c.File.Name)
}
} else {
fmt.Println("File sent.")
// TODO: Add confirmation
}
return nil
}
func fileAlreadyExists(s []string, f string) bool {
for _, a := range s {
if a == f {
return true
}
}
return false
}
func (c *Connection) catFile(fname string) {
// cat the file
os.Remove(fname)
finished, err := os.Create(fname + ".enc")
defer finished.Close()
if err != nil {
log.Fatal(err)
}
for id := 0; id < c.NumberOfConnections; id++ {
fh, err := os.Open(fname + "." + strconv.Itoa(id))
if err != nil {
log.Fatal(err)
}
_, err = io.Copy(finished, fh)
if err != nil {
log.Fatal(err)
}
fh.Close()
os.Remove(fname + "." + strconv.Itoa(id))
}
}
func (c *Connection) receiveFile(id int, connection net.Conn) error {
logger := log.WithFields(log.Fields{
"function": "receiveFile #" + strconv.Itoa(id),
})
logger.Debug("waiting for chunk size from sender")
fileSizeBuffer := make([]byte, 10)
connection.Read(fileSizeBuffer)
fileDataString := strings.Trim(string(fileSizeBuffer), ":")
fileSizeInt, _ := strconv.Atoi(fileDataString)
chunkSize := int64(fileSizeInt)
logger.Debugf("chunk size: %d", chunkSize)
os.Remove(c.File.Name + "." + strconv.Itoa(id))
newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
if err != nil {
panic(err)
}
defer newFile.Close()
if !c.Debug {
c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed()
}
logger.Debug("waiting for file")
var receivedBytes int64
for {
if !c.Debug {
c.bars[id].Incr()
}
if (chunkSize - receivedBytes) < BUFFERSIZE {
logger.Debug("at the end")
io.CopyN(newFile, connection, (chunkSize - receivedBytes))
// Empty the remaining bytes that we don't need from the network buffer
if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
logger.Debug("empty remaining bytes from network buffer")
connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize))
}
break
}
io.CopyN(newFile, connection, BUFFERSIZE)
receivedBytes += BUFFERSIZE
}
logger.Debug("received file")
return nil
}
func (c *Connection) sendFile(id int, connection net.Conn) {
logger := log.WithFields(log.Fields{
"function": "sendFile #" + strconv.Itoa(id),
})
defer connection.Close()
var err error
numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections)))
chunkSize := int64(chunksPerWorker * BUFFERSIZE)
if id+1 == c.NumberOfConnections {
chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize
}
if id == 0 || id == c.NumberOfConnections-1 {
logger.Debugf("numChunks: %v", numChunks)
logger.Debugf("chunksPerWorker: %v", chunksPerWorker)
logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize)
}
logger.Debugf("sending chunk size: %d", chunkSize)
connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10)))
sendBuffer := make([]byte, BUFFERSIZE)
// open encrypted file
file, err := os.Open(c.File.Name + ".enc")
if err != nil {
log.Error(err)
return
}
defer file.Close()
chunkI := 0
if !c.Debug {
c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed()
}
bufferSizeInKilobytes := BUFFERSIZE / 1024
rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes)
throttle := time.NewTicker(time.Second / time.Duration(rate))
defer throttle.Stop()
for range throttle.C {
_, err = file.Read(sendBuffer)
if err == io.EOF {
//End of file reached, break out of for loop
logger.Debug("EOF")
break
}
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) {
connection.Write(sendBuffer)
if !c.Debug {
c.bars[id].Incr()
}
}
chunkI++
}
logger.Debug("file is sent")
return
}
package main
import (
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gosuri/uiprogress"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
type Connection struct {
Server string
File FileMetaData
NumberOfConnections int
Code string
HashedCode string
IsSender bool
Debug bool
DontEncrypt bool
bars []*uiprogress.Bar
rate int
}
type FileMetaData struct {
Name string
Size int
Hash string
}
func NewConnection(flags *Flags) *Connection {
c := new(Connection)
c.Debug = flags.Debug
c.DontEncrypt = flags.DontEncrypt
c.Server = flags.Server
c.Code = flags.Code
c.NumberOfConnections = flags.NumberOfConnections
c.rate = flags.Rate
if len(flags.File) > 0 {
c.File.Name = flags.File
c.IsSender = true
} else {
c.IsSender = false
}
log.SetFormatter(&log.TextFormatter{})
if c.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.WarnLevel)
}
return c
}
func (c *Connection) Run() error {
forceSingleThreaded := false
if c.IsSender {
fsize, err := FileSize(c.File.Name)
if err != nil {
return err
}
if fsize < MAX_NUMBER_THREADS*BUFFERSIZE {
forceSingleThreaded = true
log.Debug("forcing single thread")
}
}
log.Debug("checking code validity")
for {
// check code
goodCode := true
m := strings.Split(c.Code, "-")
numThreads, errParse := strconv.Atoi(m[0])
if len(m) < 2 {
goodCode = false
} else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) {
c.NumberOfConnections = MAX_NUMBER_THREADS
goodCode = false
} else if errParse != nil {
goodCode = false
}
log.Debug(m)
if !goodCode {
if c.IsSender {
if forceSingleThreaded {
c.NumberOfConnections = 1
}
c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName()
} else {
if len(c.Code) != 0 {
fmt.Println("Code must begin with number of threads (e.g. 3-some-code)")
}
c.Code = getInput("Enter receive code: ")
}
} else {
break
}
}
// assign number of connections
c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0])
if c.IsSender {
if c.DontEncrypt {
// don't encrypt
CopyFile(c.File.Name, c.File.Name+".enc")
} else {
// encrypt
log.Debug("encrypting...")
if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil {
return err
}
}
// get file hash
var err error
c.File.Hash, err = HashFile(c.File.Name)
if err != nil {
return err
}
// get file size
c.File.Size, err = FileSize(c.File.Name + ".enc")
if err != nil {
return err
}
}
return c.runClient()
}
// runClient spawns threads for parallel uplink/downlink via TCP
func (c *Connection) runClient() error {
logger := log.WithFields(log.Fields{
"code": c.Code,
"sender?": c.IsSender,
})
c.HashedCode = Hash(c.Code)
var wg sync.WaitGroup
wg.Add(c.NumberOfConnections)
uiprogress.Start()
if !c.Debug {
c.bars = make([]*uiprogress.Bar, c.NumberOfConnections)
}
gotOK := false
gotResponse := false
gotConnectionInUse := false
for id := 0; id < c.NumberOfConnections; id++ {
go func(id int) {
defer wg.Done()
port := strconv.Itoa(27001 + id)
connection, err := net.Dial("tcp", c.Server+":"+port)
if err != nil {
panic(err)
}
defer connection.Close()
message := receiveMessage(connection)
logger.Debugf("relay says: %s", message)
if c.IsSender {
logger.Debugf("telling relay: %s", "s."+c.Code)
metaData, err := json.Marshal(c.File)
if err != nil {
log.Error(err)
}
encryptedMetaData, salt, iv := Encrypt(metaData, c.Code)
sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection)
} else {
logger.Debugf("telling relay: %s", "r."+c.Code)
sendMessage("r."+c.HashedCode+".0.0.0", connection)
}
if c.IsSender { // this is a sender
logger.Debug("waiting for ok from relay")
message = receiveMessage(connection)
if message == "no" {
fmt.Println("The specifed code is already in use by a sender.")
gotConnectionInUse = true
} else {
logger.Debug("got ok from relay")
if id == 0 {
fmt.Printf("\nSending (->%s)..\n", message)
}
// wait for pipe to be made
time.Sleep(100 * time.Millisecond)
// Write data from file
logger.Debug("send file")
c.sendFile(id, connection)
fmt.Println("File sent.")
}
} else { // this is a receiver
logger.Debug("waiting for meta data from sender")
message = receiveMessage(connection)
if message == "no" {
fmt.Println("The specifed code is already in use by a receiver.")
gotConnectionInUse = true
} else {
m := strings.Split(message, "-")
encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3]
encryptedBytes, err := hex.DecodeString(encryptedData)
if err != nil {
log.Error(err)
return
}
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt)
err = json.Unmarshal(decryptedBytes, &c.File)
if err != nil {
log.Error(err)
return
}
log.Debugf("meta data received: %v", c.File)
// have the main thread ask for the okay
if id == 0 {
fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name)
var sentFileNames []string
if fileAlreadyExists(sentFileNames, c.File.Name) {
fmt.Printf("Will not overwrite file!")
os.Exit(1)
}
getOK := getInput("ok? (y/n): ")
if getOK == "y" {
gotOK = true
sentFileNames = append(sentFileNames, c.File.Name)
}
gotResponse = true
}
// wait for the main thread to get the okay
for limit := 0; limit < 1000; limit++ {
if gotResponse {
break
}
time.Sleep(10 * time.Millisecond)
}
if !gotOK {
sendMessage("not ok", connection)
} else {
sendMessage("ok", connection)
logger.Debug("receive file")
fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress)
c.receiveFile(id, connection)
}
}
}
}(id)
}
wg.Wait()
if gotConnectionInUse {
return nil // connection was in use, just quit cleanly
}
if c.IsSender {
// TODO: Add confirmation
} else { // Is a Receiver
if !gotOK {
return errors.New("Transfer interrupted")
}
c.catFile(c.File.Name)
log.Debugf("Code: [%s]", c.Code)
if c.DontEncrypt {
if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil {
return err
}
} else {
if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil {
return errors.Wrap(err, "Problem decrypting file")
}
}
if !c.Debug {
os.Remove(c.File.Name + ".enc")
}
fileHash, err := HashFile(c.File.Name)
if err != nil {
log.Error(err)
}
log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash)
log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash)
if c.File.Hash != fileHash {
return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
} else {
fmt.Printf("\nReceived file written to %s", c.File.Name)
}
}
return nil
}
func fileAlreadyExists(s []string, f string) bool {
for _, a := range s {
if a == f {
return true
}
}
return false
}
func (c *Connection) catFile(fname string) {
// cat the file
os.Remove(fname)
finished, err := os.Create(fname + ".enc")
defer finished.Close()
if err != nil {
log.Fatal(err)
}
for id := 0; id < c.NumberOfConnections; id++ {
fh, err := os.Open(fname + "." + strconv.Itoa(id))
if err != nil {
log.Fatal(err)
}
_, err = io.Copy(finished, fh)
if err != nil {
log.Fatal(err)
}
fh.Close()
os.Remove(fname + "." + strconv.Itoa(id))
}
}
func (c *Connection) receiveFile(id int, connection net.Conn) error {
logger := log.WithFields(log.Fields{
"function": "receiveFile #" + strconv.Itoa(id),
})
logger.Debug("waiting for chunk size from sender")
fileSizeBuffer := make([]byte, 10)
connection.Read(fileSizeBuffer)
fileDataString := strings.Trim(string(fileSizeBuffer), ":")
fileSizeInt, _ := strconv.Atoi(fileDataString)
chunkSize := int64(fileSizeInt)
logger.Debugf("chunk size: %d", chunkSize)
os.Remove(c.File.Name + "." + strconv.Itoa(id))
newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
if err != nil {
panic(err)
}
defer newFile.Close()
if !c.Debug {
c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed()
}
logger.Debug("waiting for file")
var receivedBytes int64
for {
if !c.Debug {
c.bars[id].Incr()
}
if (chunkSize - receivedBytes) < BUFFERSIZE {
logger.Debug("at the end")
io.CopyN(newFile, connection, (chunkSize - receivedBytes))
// Empty the remaining bytes that we don't need from the network buffer
if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
logger.Debug("empty remaining bytes from network buffer")
connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize))
}
break
}
io.CopyN(newFile, connection, BUFFERSIZE)
receivedBytes += BUFFERSIZE
}
logger.Debug("received file")
return nil
}
func (c *Connection) sendFile(id int, connection net.Conn) {
logger := log.WithFields(log.Fields{
"function": "sendFile #" + strconv.Itoa(id),
})
defer connection.Close()
var err error
numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections)))
chunkSize := int64(chunksPerWorker * BUFFERSIZE)
if id+1 == c.NumberOfConnections {
chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize
}
if id == 0 || id == c.NumberOfConnections-1 {
logger.Debugf("numChunks: %v", numChunks)
logger.Debugf("chunksPerWorker: %v", chunksPerWorker)
logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize)
}
logger.Debugf("sending chunk size: %d", chunkSize)
connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10)))
sendBuffer := make([]byte, BUFFERSIZE)
// open encrypted file
file, err := os.Open(c.File.Name + ".enc")
if err != nil {
log.Error(err)
return
}
defer file.Close()
chunkI := 0
if !c.Debug {
c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed()
}
bufferSizeInKilobytes := BUFFERSIZE / 1024
rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes)
throttle := time.NewTicker(time.Second / time.Duration(rate))
defer throttle.Stop()
for range throttle.C {
_, err = file.Read(sendBuffer)
if err == io.EOF {
//End of file reached, break out of for loop
logger.Debug("EOF")
break
}
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) {
connection.Write(sendBuffer)
if !c.Debug {
c.bars[id].Incr()
}
}
chunkI++
}
logger.Debug("file is sent")
return
}

528
relay.go
View File

@ -1,248 +1,280 @@
package main
import (
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
const MAX_NUMBER_THREADS = 8
type connectionMap struct {
reciever map[string]net.Conn
sender map[string]net.Conn
metadata map[string]string
sync.RWMutex
}
type Relay struct {
connections connectionMap
Debug bool
NumberOfConnections int
}
func NewRelay(flags *Flags) *Relay {
r := new(Relay)
r.Debug = flags.Debug
r.NumberOfConnections = MAX_NUMBER_THREADS
log.SetFormatter(&log.TextFormatter{})
if r.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.WarnLevel)
}
return r
}
func (r *Relay) Run() {
r.connections = connectionMap{}
r.connections.Lock()
r.connections.reciever = make(map[string]net.Conn)
r.connections.sender = make(map[string]net.Conn)
r.connections.metadata = make(map[string]string)
r.connections.Unlock()
r.runServer()
}
func (r *Relay) runServer() {
logger := log.WithFields(log.Fields{
"function": "main",
})
logger.Debug("Initializing")
var wg sync.WaitGroup
wg.Add(r.NumberOfConnections)
for id := 0; id < r.NumberOfConnections; id++ {
go r.listenerThread(id, &wg)
}
wg.Wait()
}
func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) {
logger := log.WithFields(log.Fields{
"function": "listenerThread:" + strconv.Itoa(27000+id),
})
defer wg.Done()
err := r.listener(id)
if err != nil {
logger.Error(err)
}
}
func (r *Relay) listener(id int) (err error) {
port := strconv.Itoa(27001 + id)
logger := log.WithFields(log.Fields{
"function": "listener" + ":" + port,
})
server, err := net.Listen("tcp", "0.0.0.0:"+port)
if err != nil {
return errors.Wrap(err, "Error listening on "+":"+port)
}
defer server.Close()
logger.Debug("waiting for connections")
//Spawn a new goroutine whenever a client connects
for {
connection, err := server.Accept()
if err != nil {
return errors.Wrap(err, "problem accepting connection")
}
logger.Debugf("Client %s connected", connection.RemoteAddr().String())
go r.clientCommuncation(id, connection)
}
}
func (r *Relay) clientCommuncation(id int, connection net.Conn) {
sendMessage("who?", connection)
m := strings.Split(receiveMessage(connection), ".")
connectionType, codePhrase, metaData := m[0], m[1], m[2]
key := codePhrase + "-" + strconv.Itoa(id)
logger := log.WithFields(log.Fields{
"id": id,
"codePhrase": codePhrase,
})
if connectionType == "s" {
logger.Debug("got sender")
r.connections.Lock()
r.connections.metadata[key] = metaData
r.connections.sender[key] = connection
r.connections.Unlock()
// wait for receiver
receiversAddress := ""
for {
r.connections.RLock()
if _, ok := r.connections.reciever[key]; ok {
receiversAddress = r.connections.reciever[key].RemoteAddr().String()
logger.Debug("got reciever")
r.connections.RUnlock()
break
}
r.connections.RUnlock()
time.Sleep(100 * time.Millisecond)
}
logger.Debug("telling sender ok")
sendMessage(receiversAddress, connection)
logger.Debug("preparing pipe")
r.connections.Lock()
con1 := r.connections.sender[key]
con2 := r.connections.reciever[key]
r.connections.Unlock()
logger.Debug("piping connections")
Pipe(con1, con2)
logger.Debug("done piping")
r.connections.Lock()
delete(r.connections.sender, key)
delete(r.connections.reciever, key)
delete(r.connections.metadata, key)
r.connections.Unlock()
logger.Debug("deleted sender and receiver")
} else {
// wait for sender's metadata
sendersAddress := ""
for {
r.connections.RLock()
if _, ok := r.connections.metadata[key]; ok {
if _, ok2 := r.connections.sender[key]; ok2 {
sendersAddress = r.connections.sender[key].RemoteAddr().String()
logger.Debug("got sender meta data")
r.connections.RUnlock()
break
}
}
r.connections.RUnlock()
time.Sleep(100 * time.Millisecond)
}
// send meta data
r.connections.RLock()
sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection)
r.connections.RUnlock()
// check for receiver's consent
consent := receiveMessage(connection)
logger.Debugf("consent: %s", consent)
if consent == "ok" {
logger.Debug("got consent")
r.connections.Lock()
r.connections.reciever[key] = connection
r.connections.Unlock()
}
}
return
}
func sendMessage(message string, connection net.Conn) {
message = fillString(message, BUFFERSIZE)
connection.Write([]byte(message))
}
func receiveMessage(connection net.Conn) string {
messageByte := make([]byte, BUFFERSIZE)
connection.Read(messageByte)
return strings.Replace(string(messageByte), ":", "", -1)
}
func fillString(retunString string, toLength int) string {
for {
lengthString := len(retunString)
if lengthString < toLength {
retunString = retunString + ":"
continue
}
break
}
return retunString
}
// chanFromConn creates a channel from a Conn object, and sends everything it
// Read()s from the socket to the channel.
func chanFromConn(conn net.Conn) chan []byte {
c := make(chan []byte)
go func() {
b := make([]byte, BUFFERSIZE)
for {
n, err := conn.Read(b)
if n > 0 {
res := make([]byte, n)
// Copy the buffer so it doesn't get changed while read by the recipient.
copy(res, b[:n])
c <- res
}
if err != nil {
c <- nil
break
}
}
}()
return c
}
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other.
func Pipe(conn1 net.Conn, conn2 net.Conn) {
chan1 := chanFromConn(conn1)
chan2 := chanFromConn(conn2)
for {
select {
case b1 := <-chan1:
if b1 == nil {
return
} else {
conn2.Write(b1)
}
case b2 := <-chan2:
if b2 == nil {
return
} else {
conn1.Write(b2)
}
}
}
}
package main
import (
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
const MAX_NUMBER_THREADS = 8
type connectionMap struct {
receiver map[string]net.Conn
sender map[string]net.Conn
metadata map[string]string
potentialReceivers map[string]struct{}
sync.RWMutex
}
func (c *connectionMap) IsSenderConnected(key string) (found bool) {
c.RLock()
defer c.RUnlock()
_, found = c.sender[key]
return
}
func (c *connectionMap) IsPotentialReceiverConnected(key string) (found bool) {
c.RLock()
defer c.RUnlock()
_, found = c.potentialReceivers[key]
return
}
type Relay struct {
connections connectionMap
Debug bool
NumberOfConnections int
}
func NewRelay(flags *Flags) *Relay {
r := new(Relay)
r.Debug = flags.Debug
r.NumberOfConnections = MAX_NUMBER_THREADS
log.SetFormatter(&log.TextFormatter{})
if r.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.WarnLevel)
}
return r
}
func (r *Relay) Run() {
r.connections = connectionMap{}
r.connections.Lock()
r.connections.receiver = make(map[string]net.Conn)
r.connections.sender = make(map[string]net.Conn)
r.connections.metadata = make(map[string]string)
r.connections.potentialReceivers = make(map[string]struct{})
r.connections.Unlock()
r.runServer()
}
func (r *Relay) runServer() {
logger := log.WithFields(log.Fields{
"function": "main",
})
logger.Debug("Initializing")
var wg sync.WaitGroup
wg.Add(r.NumberOfConnections)
for id := 0; id < r.NumberOfConnections; id++ {
go r.listenerThread(id, &wg)
}
wg.Wait()
}
func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) {
logger := log.WithFields(log.Fields{
"function": "listenerThread:" + strconv.Itoa(27000+id),
})
defer wg.Done()
err := r.listener(id)
if err != nil {
logger.Error(err)
}
}
func (r *Relay) listener(id int) (err error) {
port := strconv.Itoa(27001 + id)
logger := log.WithFields(log.Fields{
"function": "listener" + ":" + port,
})
server, err := net.Listen("tcp", "0.0.0.0:"+port)
if err != nil {
return errors.Wrap(err, "Error listening on "+":"+port)
}
defer server.Close()
logger.Debug("waiting for connections")
//Spawn a new goroutine whenever a client connects
for {
connection, err := server.Accept()
if err != nil {
return errors.Wrap(err, "problem accepting connection")
}
logger.Debugf("Client %s connected", connection.RemoteAddr().String())
go r.clientCommuncation(id, connection)
}
}
func (r *Relay) clientCommuncation(id int, connection net.Conn) {
sendMessage("who?", connection)
m := strings.Split(receiveMessage(connection), ".")
connectionType, codePhrase, metaData := m[0], m[1], m[2]
key := codePhrase + "-" + strconv.Itoa(id)
logger := log.WithFields(log.Fields{
"id": id,
"codePhrase": codePhrase,
})
if connectionType == "s" { // sender connection
if r.connections.IsSenderConnected(key) {
sendMessage("no", connection)
return
}
logger.Debug("got sender")
r.connections.Lock()
r.connections.metadata[key] = metaData
r.connections.sender[key] = connection
r.connections.Unlock()
// wait for receiver
receiversAddress := ""
for {
r.connections.RLock()
if _, ok := r.connections.receiver[key]; ok {
receiversAddress = r.connections.receiver[key].RemoteAddr().String()
logger.Debug("got receiver")
r.connections.RUnlock()
break
}
r.connections.RUnlock()
time.Sleep(100 * time.Millisecond)
}
logger.Debug("telling sender ok")
sendMessage(receiversAddress, connection)
logger.Debug("preparing pipe")
r.connections.Lock()
con1 := r.connections.sender[key]
con2 := r.connections.receiver[key]
r.connections.Unlock()
logger.Debug("piping connections")
Pipe(con1, con2)
logger.Debug("done piping")
r.connections.Lock()
delete(r.connections.sender, key)
delete(r.connections.receiver, key)
delete(r.connections.metadata, key)
delete(r.connections.potentialReceivers, key)
r.connections.Unlock()
logger.Debug("deleted sender and receiver")
} else { //receiver connection "r"
if r.connections.IsPotentialReceiverConnected(key) {
sendMessage("no", connection)
return
}
// add as a potential receiver
r.connections.Lock()
r.connections.potentialReceivers[key] = struct{}{}
r.connections.Unlock()
// wait for sender's metadata
sendersAddress := ""
for {
r.connections.RLock()
if _, ok := r.connections.metadata[key]; ok {
if _, ok2 := r.connections.sender[key]; ok2 {
sendersAddress = r.connections.sender[key].RemoteAddr().String()
logger.Debug("got sender meta data")
r.connections.RUnlock()
break
}
}
r.connections.RUnlock()
time.Sleep(100 * time.Millisecond)
}
// send meta data
r.connections.RLock()
sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection)
r.connections.RUnlock()
// check for receiver's consent
consent := receiveMessage(connection)
logger.Debugf("consent: %s", consent)
if consent == "ok" {
logger.Debug("got consent")
r.connections.Lock()
r.connections.receiver[key] = connection
r.connections.Unlock()
}
}
return
}
func sendMessage(message string, connection net.Conn) {
message = fillString(message, BUFFERSIZE)
connection.Write([]byte(message))
}
func receiveMessage(connection net.Conn) string {
messageByte := make([]byte, BUFFERSIZE)
connection.Read(messageByte)
return strings.Replace(string(messageByte), ":", "", -1)
}
func fillString(retunString string, toLength int) string {
for {
lengthString := len(retunString)
if lengthString < toLength {
retunString = retunString + ":"
continue
}
break
}
return retunString
}
// chanFromConn creates a channel from a Conn object, and sends everything it
// Read()s from the socket to the channel.
func chanFromConn(conn net.Conn) chan []byte {
c := make(chan []byte)
go func() {
b := make([]byte, BUFFERSIZE)
for {
n, err := conn.Read(b)
if n > 0 {
res := make([]byte, n)
// Copy the buffer so it doesn't get changed while read by the recipient.
copy(res, b[:n])
c <- res
}
if err != nil {
c <- nil
break
}
}
}()
return c
}
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other.
func Pipe(conn1 net.Conn, conn2 net.Conn) {
chan1 := chanFromConn(conn1)
chan2 := chanFromConn(conn2)
for {
select {
case b1 := <-chan1:
if b1 == nil {
return
} else {
conn2.Write(b1)
}
case b2 := <-chan2:
if b2 == nil {
return
} else {
conn1.Write(b2)
}
}
}
}