Refactor works

This commit is contained in:
Zack Scholl 2017-10-18 07:05:48 -06:00
parent 86b12a3770
commit eae21303f2
4 changed files with 314 additions and 217 deletions

View File

@ -2,13 +2,14 @@ package main
import ( import (
"bytes" "bytes"
"encoding/hex"
"encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math" "math"
"net" "net"
"os" "os"
"path"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -18,26 +19,104 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var bars []*uiprogress.Bar type Connection struct {
Server string
File FileMetaData
NumberOfConnections int
Code string
HashedCode string
IsSender bool
Debug bool
DontEncrypt bool
bars []*uiprogress.Bar
}
type FileMetaData struct {
Name string
Size int
Hash string
IV string
Salt string
bytes []byte
}
func NewConnection(flags *Flags) *Connection {
c := new(Connection)
c.Debug = flags.Debug
c.DontEncrypt = flags.DontEncrypt
c.Server = flags.Server
c.Code = flags.Code
c.NumberOfConnections = flags.NumberOfConnections
if len(flags.File) > 0 {
c.File.Name = flags.File
c.IsSender = true
} else {
c.IsSender = false
}
return c
}
func (c *Connection) Run() {
if len(c.Code) == 0 {
if !c.IsSender {
c.Code = getInput("Enter receive code: ")
}
if len(c.Code) < 5 {
c.Code = GetRandomName()
}
}
log.SetFormatter(&log.TextFormatter{})
if c.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.WarnLevel)
}
if c.IsSender {
// encrypt the file
log.Debug("encrypting...")
fdata, err := ioutil.ReadFile(c.File.Name)
if err != nil {
log.Fatal(err)
return
}
c.File.bytes, c.File.Salt, c.File.IV = Encrypt(fdata, c.Code, c.DontEncrypt)
log.Debug("...finished encryption")
c.File.Hash = HashBytes(fdata)
c.File.Size = len(c.File.bytes)
if c.Debug {
ioutil.WriteFile(c.File.Name+".encrypted", c.File.bytes, 0644)
}
fmt.Printf("Sending %d byte file named '%s'\n", c.File.Size, c.File.Name)
fmt.Printf("Code is: %s\n", c.Code)
}
c.runClient()
}
// runClient spawns threads for parallel uplink/downlink via TCP // runClient spawns threads for parallel uplink/downlink via TCP
func runClient(connectionType string, codePhrase string) { func (c *Connection) runClient() {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"codePhrase": codePhrase, "code": c.Code,
"connection": connectionType, "sender?": c.IsSender,
}) })
c.HashedCode = Hash(c.Code)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numberConnections) wg.Add(c.NumberOfConnections)
uiprogress.Start() uiprogress.Start()
if !debugFlag { if !c.Debug {
bars = make([]*uiprogress.Bar, numberConnections) c.bars = make([]*uiprogress.Bar, c.NumberOfConnections)
} }
for id := 0; id < numberConnections; id++ { gotOK := false
for id := 0; id < c.NumberOfConnections; id++ {
go func(id int) { go func(id int) {
defer wg.Done() defer wg.Done()
port := strconv.Itoa(27001 + id) port := strconv.Itoa(27001 + id)
connection, err := net.Dial("tcp", serverAddress+":"+port) connection, err := net.Dial("tcp", c.Server+":"+port)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -45,12 +124,21 @@ func runClient(connectionType string, codePhrase string) {
message := receiveMessage(connection) message := receiveMessage(connection)
logger.Debugf("relay says: %s", message) logger.Debugf("relay says: %s", message)
logger.Debugf("telling relay: %s", connectionType+"."+codePhrase) if c.IsSender {
logger.Debugf("telling relay: %s", "s."+c.Code)
sendMessage(connectionType+"."+Hash(codePhrase), connection) metaData, err := json.Marshal(c.File)
if connectionType == "s" { // this is a sender if err != nil {
log.Error(err)
}
encryptedMetaData, salt, iv := Encrypt(metaData, c.Code)
sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection)
} else {
logger.Debugf("telling relay: %s", "r."+c.Code)
sendMessage("r."+c.HashedCode+".0.0.0", connection)
}
if c.IsSender { // this is a sender
if id == 0 { if id == 0 {
fmt.Println("waiting for other to connect") fmt.Printf("\nSending (<-%s)..\n", connection.RemoteAddr().String())
} }
logger.Debug("waiting for ok from relay") logger.Debug("waiting for ok from relay")
message = receiveMessage(connection) message = receiveMessage(connection)
@ -59,62 +147,97 @@ func runClient(connectionType string, codePhrase string) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
// Write data from file // Write data from file
logger.Debug("send file") logger.Debug("send file")
sendFile(id, connection, codePhrase) c.sendFile(id, connection)
} else { // this is a receiver } else { // this is a receiver
// receive file logger.Debug("waiting for meta data from sender")
message = receiveMessage(connection)
m := strings.Split(message, "-")
encryptedData, salt, iv := m[0], m[1], m[2]
encryptedBytes, err := hex.DecodeString(encryptedData)
if err != nil {
log.Error(err)
return
}
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt)
err = json.Unmarshal(decryptedBytes, &c.File)
if err != nil {
log.Error(err)
return
}
log.Debugf("meta data received: %v", c.File)
// have the main thread ask for the okay
if id == 0 {
fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name)
getOk := getInput("ok? (y/n): ")
if getOk == "y" {
gotOK = true
} else {
return
}
}
// wait for the main thread to get the okay
for limit := 0; limit < 1000; limit++ {
if gotOK {
break
}
time.Sleep(10 * time.Millisecond)
}
if !gotOK {
return
}
sendMessage("ok", connection)
logger.Debug("receive file") logger.Debug("receive file")
fileName, fileIV, fileSalt, fileHash = receiveFile(id, connection, codePhrase) c.receiveFile(id, connection)
} }
}(id) }(id)
} }
wg.Wait() wg.Wait()
if connectionType == "r" { if !c.IsSender {
catFile(fileName) c.catFile(c.File.Name)
encrypted, err := ioutil.ReadFile(fileName + ".encrypted") encrypted, err := ioutil.ReadFile(c.File.Name + ".encrypted")
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return
} }
fmt.Println("\n\ndecrypting...") fmt.Println("\n\ndecrypting...")
log.Debugf("codePhrase: [%s]", codePhrase) log.Debugf("Code: [%s]", c.Code)
log.Debugf("fileSalt: [%s]", fileSalt) log.Debugf("Salt: [%s]", c.File.Salt)
log.Debugf("fileIV: [%s]", fileIV) log.Debugf("IV: [%s]", c.File.IV)
decrypted, err := Decrypt(encrypted, codePhrase, fileSalt, fileIV) decrypted, err := Decrypt(encrypted, c.Code, c.File.Salt, c.File.IV, c.DontEncrypt)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return return
} }
log.Debugf("writing %d bytes to %s", len(decrypted), fileName) log.Debugf("writing %d bytes to %s", len(decrypted), c.File.Name)
err = ioutil.WriteFile(fileName, decrypted, 0644) err = ioutil.WriteFile(c.File.Name, decrypted, 0644)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }
if !debugFlag { if !c.Debug {
os.Remove(fileName + ".encrypted") os.Remove(c.File.Name + ".encrypted")
} }
log.Debugf("\n\n\ndownloaded hash: [%s]", HashBytes(decrypted)) log.Debugf("\n\n\ndownloaded hash: [%s]", HashBytes(decrypted))
log.Debugf("\n\n\nrelayed hash: [%s]", fileHash) log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash)
if fileHash != HashBytes(decrypted) { if c.File.Hash != HashBytes(decrypted) {
fmt.Printf("\nUh oh! %s is corrupted! Sorry, try again.\n", fileName) fmt.Printf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
} else { } else {
fmt.Printf("\nDownloaded %s!", fileName) fmt.Printf("\nDownloaded %s!", c.File.Name)
} }
} }
} }
func catFile(fileNameToReceive string) { func (c *Connection) catFile(fname string) {
// cat the file // cat the file
os.Remove(fileNameToReceive) os.Remove(fname)
finished, err := os.Create(fileNameToReceive + ".encrypted") finished, err := os.Create(fname + ".encrypted")
defer finished.Close() defer finished.Close()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
for id := 0; id < numberConnections; id++ { for id := 0; id < c.NumberOfConnections; id++ {
fh, err := os.Open(fileNameToReceive + "." + strconv.Itoa(id)) fh, err := os.Open(fname + "." + strconv.Itoa(id))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -124,74 +247,59 @@ func catFile(fileNameToReceive string) {
log.Fatal(err) log.Fatal(err)
} }
fh.Close() fh.Close()
os.Remove(fileNameToReceive + "." + strconv.Itoa(id)) os.Remove(fname + "." + strconv.Itoa(id))
} }
} }
func receiveFile(id int, connection net.Conn, codePhrase string) (fileNameToReceive string, iv string, salt string, hashOfFile string) { func (c *Connection) receiveFile(id int, connection net.Conn) error {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "receiveFile #" + strconv.Itoa(id), "function": "receiveFile #" + strconv.Itoa(id),
}) })
logger.Debug("waiting for file data") logger.Debug("waiting for chunk size from sender")
fileSizeBuffer := make([]byte, 10)
connection.Read(fileSizeBuffer)
fileDataString := strings.Trim(string(fileSizeBuffer), ":")
fileSizeInt, _ := strconv.Atoi(fileDataString)
chunkSize := int64(fileSizeInt)
logger.Debugf("chunk size: %d", chunkSize)
fileDataBuffer := make([]byte, BUFFERSIZE) os.Remove(c.File.Name + "." + strconv.Itoa(id))
connection.Read(fileDataBuffer) newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
fileDataString := strings.Trim(string(fileDataBuffer), ":")
pieces := strings.Split(fileDataString, "-")
fileSizeInt, _ := strconv.Atoi(pieces[0])
fileSize := int64(fileSizeInt)
logger.Debugf("filesize: %d", fileSize)
fileNameToReceive = pieces[1]
logger.Debugf("fileName: [%s]", fileNameToReceive)
iv = pieces[2]
logger.Debugf("iv: [%s]", iv)
salt = pieces[3]
logger.Debugf("salt: [%s]", salt)
hashOfFile = pieces[4]
logger.Debugf("hashOfFile: [%s]", hashOfFile)
os.Remove(fileNameToReceive + "." + strconv.Itoa(id))
newFile, err := os.Create(fileNameToReceive + "." + strconv.Itoa(id))
if err != nil { if err != nil {
panic(err) panic(err)
} }
defer newFile.Close() defer newFile.Close()
if !debugFlag { if !c.Debug {
bars[id] = uiprogress.AddBar(int(fileSize)/1024 + 1).AppendCompleted().PrependElapsed() c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed()
} }
logger.Debug("waiting for file") logger.Debug("waiting for file")
var receivedBytes int64 var receivedBytes int64
for { for {
if !debugFlag { if !c.Debug {
bars[id].Incr() c.bars[id].Incr()
} }
if (fileSize - receivedBytes) < BUFFERSIZE { if (chunkSize - receivedBytes) < BUFFERSIZE {
logger.Debug("at the end") logger.Debug("at the end")
io.CopyN(newFile, connection, (fileSize - receivedBytes)) io.CopyN(newFile, connection, (chunkSize - receivedBytes))
// Empty the remaining bytes that we don't need from the network buffer // Empty the remaining bytes that we don't need from the network buffer
if (receivedBytes+BUFFERSIZE)-fileSize < BUFFERSIZE { if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
logger.Debug("empty remaining bytes from network buffer") logger.Debug("empty remaining bytes from network buffer")
connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-fileSize)) connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize))
} }
break break
} }
io.CopyN(newFile, connection, BUFFERSIZE) io.CopyN(newFile, connection, BUFFERSIZE)
//Increment the counter
receivedBytes += BUFFERSIZE receivedBytes += BUFFERSIZE
} }
logger.Debug("received file") logger.Debug("received file")
return return nil
} }
func sendFile(id int, connection net.Conn, codePhrase string) { func (c *Connection) sendFile(id int, connection net.Conn) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "sendFile #" + strconv.Itoa(id), "function": "sendFile #" + strconv.Itoa(id),
}) })
@ -199,41 +307,29 @@ func sendFile(id int, connection net.Conn, codePhrase string) {
var err error var err error
numChunks := math.Ceil(float64(len(fileBytes)) / float64(BUFFERSIZE)) numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
chunksPerWorker := int(math.Ceil(numChunks / float64(numberConnections))) chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections)))
bytesPerConnection := int64(chunksPerWorker * BUFFERSIZE) chunkSize := int64(chunksPerWorker * BUFFERSIZE)
if id+1 == numberConnections { if id+1 == c.NumberOfConnections {
bytesPerConnection = int64(len(fileBytes)) - (numberConnections-1)*bytesPerConnection chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize
} }
if id == 0 || id == numberConnections-1 { if id == 0 || id == c.NumberOfConnections-1 {
logger.Debugf("numChunks: %v", numChunks) logger.Debugf("numChunks: %v", numChunks)
logger.Debugf("chunksPerWorker: %v", chunksPerWorker) logger.Debugf("chunksPerWorker: %v", chunksPerWorker)
logger.Debugf("bytesPerConnection: %v", bytesPerConnection) logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize)
logger.Debugf("fileNameToSend: %v", path.Base(fileName))
} }
payload := strings.Join([]string{ logger.Debugf("sending chunk size: %d", chunkSize)
strconv.FormatInt(int64(bytesPerConnection), 10), // filesize connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10)))
path.Base(fileName),
fileIV,
fileSalt,
fileHash,
}, "-")
logger.Debugf("sending fileSize: %d", bytesPerConnection)
logger.Debugf("sending fileName: %s", path.Base(fileName))
logger.Debugf("sending iv: %s", fileIV)
logger.Debugf("sending salt: %s", fileSalt)
logger.Debugf("sending sha256sum: %s", fileHash)
logger.Debugf("payload is %d bytes", len(payload))
connection.Write([]byte(fillString(payload, BUFFERSIZE)))
sendBuffer := make([]byte, BUFFERSIZE) sendBuffer := make([]byte, BUFFERSIZE)
file := bytes.NewBuffer(fileBytes) file := bytes.NewBuffer(c.File.bytes)
chunkI := 0 chunkI := 0
if !c.Debug {
c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed()
}
for { for {
_, err = file.Read(sendBuffer) _, err = file.Read(sendBuffer)
if err == io.EOF { if err == io.EOF {
@ -241,8 +337,11 @@ func sendFile(id int, connection net.Conn, codePhrase string) {
logger.Debug("EOF") logger.Debug("EOF")
break break
} }
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == numberConnections-1 && chunkI >= chunksPerWorker*id) { if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) {
connection.Write(sendBuffer) connection.Write(sendBuffer)
if !c.Debug {
c.bars[id].Incr()
}
} }
chunkI++ chunkI++
} }

View File

@ -28,23 +28,25 @@ func GetRandomName() string {
return strings.Join(result, "-") return strings.Join(result, "-")
} }
func Encrypt(plaintext []byte, passphrase string) ([]byte, string, string) { func Encrypt(plaintext []byte, passphrase string, dontencrypt ...bool) (encrypted []byte, salt string, iv string) {
if dontEncrypt { if len(dontencrypt) > 0 && dontencrypt[0] {
return plaintext, "salt", "iv" return plaintext, "salt", "iv"
} }
key, salt := deriveKey(passphrase, nil) key, saltBytes := deriveKey(passphrase, nil)
iv := make([]byte, 12) ivBytes := make([]byte, 12)
// http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf // http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf
// Section 8.2 // Section 8.2
rand.Read(iv) rand.Read(ivBytes)
b, _ := aes.NewCipher(key) b, _ := aes.NewCipher(key)
aesgcm, _ := cipher.NewGCM(b) aesgcm, _ := cipher.NewGCM(b)
data := aesgcm.Seal(nil, iv, plaintext, nil) encrypted = aesgcm.Seal(nil, ivBytes, plaintext, nil)
return data, hex.EncodeToString(salt), hex.EncodeToString(iv) salt = hex.EncodeToString(saltBytes)
iv = hex.EncodeToString(ivBytes)
return
} }
func Decrypt(data []byte, passphrase string, salt string, iv string) (plaintext []byte, err error) { func Decrypt(data []byte, passphrase string, salt string, iv string, dontencrypt ...bool) (plaintext []byte, err error) {
if dontEncrypt { if len(dontencrypt) > 0 && dontencrypt[0] {
return data, nil return data, nil
} }
saltBytes, _ := hex.DecodeString(salt) saltBytes, _ := hex.DecodeString(salt)

92
main.go
View File

@ -4,87 +4,39 @@ import (
"bufio" "bufio"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"strings" "strings"
log "github.com/sirupsen/logrus"
) )
const BUFFERSIZE = 1024 const BUFFERSIZE = 1024
const numberConnections = 4
// Build flags type Flags struct {
var server, file string Relay bool
Debug bool
// Global varaibles DontEncrypt bool
var serverAddress, fileName, codePhraseFlag, connectionTypeFlag string Server string
var runAsRelay, debugFlag, dontEncrypt bool File string
var fileSalt, fileIV, fileHash string Code string
var fileBytes []byte NumberOfConnections int
}
func main() { func main() {
flag.BoolVar(&runAsRelay, "relay", false, "run as relay") flags := new(Flags)
flag.BoolVar(&debugFlag, "debug", false, "debug mode") flag.BoolVar(&flags.Relay, "relay", false, "run as relay")
flag.StringVar(&serverAddress, "server", "cowyo.com", "address of relay server") flag.BoolVar(&flags.Debug, "debug", false, "debug mode")
flag.StringVar(&fileName, "send", "", "file to send") flag.StringVar(&flags.Server, "server", "cowyo.com", "address of relay server")
flag.StringVar(&codePhraseFlag, "code", "", "use your own code phrase") flag.StringVar(&flags.File, "send", "", "file to send")
flag.BoolVar(&dontEncrypt, "no-encrypt", false, "turn off encryption") flag.StringVar(&flags.Code, "code", "", "use your own code phrase")
flag.BoolVar(&flags.DontEncrypt, "no-encrypt", false, "turn off encryption")
flag.IntVar(&flags.NumberOfConnections, "threads", 4, "number of threads to use")
flag.Parse() flag.Parse()
// Check build flags too, which take precedent
if server != "" {
serverAddress = server
}
if file != "" {
fileName = file
}
if len(fileName) > 0 { if flags.Relay {
connectionTypeFlag = "s" // sender r := NewRelay(flags)
r.Run()
} else { } else {
connectionTypeFlag = "r" //receiver c := NewConnection(flags)
} c.Run()
if !runAsRelay {
if len(codePhraseFlag) == 0 {
if connectionTypeFlag == "r" {
codePhraseFlag = getInput("What is your code phrase? ")
}
if len(codePhraseFlag) < 5 {
codePhraseFlag = GetRandomName()
fmt.Println("Your code phrase is now " + codePhraseFlag)
}
}
}
if connectionTypeFlag == "s" {
// encrypt the file
fmt.Println("encrypting...")
fdata, err := ioutil.ReadFile(fileName)
if err != nil {
log.Fatal(err)
return
}
fileBytes, fileSalt, fileIV = Encrypt(fdata, codePhraseFlag)
fileHash = HashBytes(fdata)
if debugFlag {
ioutil.WriteFile(fileName+".encrypted", fileBytes, 0644)
}
}
log.SetFormatter(&log.TextFormatter{})
if debugFlag {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.WarnLevel)
}
if runAsRelay {
runServer()
} else if len(serverAddress) != 0 {
runClient(connectionTypeFlag, codePhraseFlag)
} else {
fmt.Println("You must specify either -file (for running as a server) or -server (for running as a client)")
} }
} }

118
relay.go
View File

@ -14,44 +14,65 @@ import (
type connectionMap struct { type connectionMap struct {
reciever map[string]net.Conn reciever map[string]net.Conn
sender map[string]net.Conn sender map[string]net.Conn
metadata map[string]string
sync.RWMutex sync.RWMutex
} }
var connections connectionMap type Relay struct {
connections connectionMap
func init() { Debug bool
connections.Lock() NumberOfConnections int
connections.reciever = make(map[string]net.Conn)
connections.sender = make(map[string]net.Conn)
connections.Unlock()
} }
func runServer() { func NewRelay(flags *Flags) *Relay {
r := new(Relay)
r.Debug = flags.Debug
r.NumberOfConnections = flags.NumberOfConnections
log.SetFormatter(&log.TextFormatter{})
if r.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.WarnLevel)
}
return r
}
func (r *Relay) Run() {
r.connections = connectionMap{}
r.connections.Lock()
r.connections.reciever = make(map[string]net.Conn)
r.connections.sender = make(map[string]net.Conn)
r.connections.metadata = make(map[string]string)
r.connections.Unlock()
r.runServer()
}
func (r *Relay) runServer() {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "main", "function": "main",
}) })
logger.Debug("Initializing") logger.Debug("Initializing")
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(numberConnections) wg.Add(r.NumberOfConnections)
for id := 0; id < numberConnections; id++ { for id := 0; id < r.NumberOfConnections; id++ {
go listenerThread(id, &wg) go r.listenerThread(id, &wg)
} }
wg.Wait() wg.Wait()
} }
func listenerThread(id int, wg *sync.WaitGroup) { func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) {
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "listenerThread:" + strconv.Itoa(27000+id), "function": "listenerThread:" + strconv.Itoa(27000+id),
}) })
defer wg.Done() defer wg.Done()
err := listener(id) err := r.listener(id)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
} }
func listener(id int) (err error) { func (r *Relay) listener(id int) (err error) {
port := strconv.Itoa(27001 + id) port := strconv.Itoa(27001 + id)
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"function": "listener" + ":" + port, "function": "listener" + ":" + port,
@ -69,15 +90,16 @@ func listener(id int) (err error) {
return errors.Wrap(err, "problem accepting connection") return errors.Wrap(err, "problem accepting connection")
} }
logger.Debugf("Client %s connected", connection.RemoteAddr().String()) logger.Debugf("Client %s connected", connection.RemoteAddr().String())
go clientCommuncation(id, connection) go r.clientCommuncation(id, connection)
} }
} }
func clientCommuncation(id int, connection net.Conn) { func (r *Relay) clientCommuncation(id int, connection net.Conn) {
sendMessage("who?", connection) sendMessage("who?", connection)
message := receiveMessage(connection)
connectionType := strings.Split(message, ".")[0] m := strings.Split(receiveMessage(connection), ".")
codePhrase := strings.Split(message, ".")[1] + "-" + strconv.Itoa(id) connectionType, codePhrase, metaData := m[0], m[1], m[2]
key := codePhrase + "-" + strconv.Itoa(id)
logger := log.WithFields(log.Fields{ logger := log.WithFields(log.Fields{
"id": id, "id": id,
"codePhrase": codePhrase, "codePhrase": codePhrase,
@ -85,39 +107,61 @@ func clientCommuncation(id int, connection net.Conn) {
if connectionType == "s" { if connectionType == "s" {
logger.Debug("got sender") logger.Debug("got sender")
connections.Lock() r.connections.Lock()
connections.sender[codePhrase] = connection r.connections.metadata[key] = metaData
connections.Unlock() r.connections.sender[key] = connection
r.connections.Unlock()
// wait for receiver
for { for {
connections.RLock() r.connections.RLock()
if _, ok := connections.reciever[codePhrase]; ok { if _, ok := r.connections.reciever[key]; ok {
logger.Debug("got reciever") logger.Debug("got reciever")
connections.RUnlock() r.connections.RUnlock()
break break
} }
connections.RUnlock() r.connections.RUnlock()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
logger.Debug("telling sender ok") logger.Debug("telling sender ok")
sendMessage("ok", connection) sendMessage("ok", connection)
logger.Debug("preparing pipe") logger.Debug("preparing pipe")
connections.Lock() r.connections.Lock()
con1 := connections.sender[codePhrase] con1 := r.connections.sender[key]
con2 := connections.reciever[codePhrase] con2 := r.connections.reciever[key]
connections.Unlock() r.connections.Unlock()
logger.Debug("piping connections") logger.Debug("piping connections")
Pipe(con1, con2) Pipe(con1, con2)
logger.Debug("done piping") logger.Debug("done piping")
connections.Lock() r.connections.Lock()
delete(connections.sender, codePhrase) delete(r.connections.sender, key)
delete(connections.reciever, codePhrase) delete(r.connections.reciever, key)
connections.Unlock() delete(r.connections.metadata, key)
r.connections.Unlock()
logger.Debug("deleted sender and receiver") logger.Debug("deleted sender and receiver")
} else { } else {
// wait for sender's metadata
for {
r.connections.RLock()
if _, ok := r.connections.metadata[key]; ok {
logger.Debug("got sender meta data")
r.connections.RUnlock()
break
}
r.connections.RUnlock()
time.Sleep(100 * time.Millisecond)
logger.Debug("waiting for metadata")
}
// send meta data
r.connections.RLock()
sendMessage(r.connections.metadata[key], connection)
r.connections.RUnlock()
// check for receiver's consent
consent := receiveMessage(connection)
logger.Debug("consent: %s", consent)
logger.Debug("got reciever") logger.Debug("got reciever")
connections.Lock() r.connections.Lock()
connections.reciever[codePhrase] = connection r.connections.reciever[key] = connection
connections.Unlock() r.connections.Unlock()
} }
return return
} }