mirror of https://github.com/schollz/croc.git
Issue #25
fixed spelling connectionMap.receiver Added method to detect if sender and receivers are already connected. Added client code to correctly action "no" returned by the code being in use.
This commit is contained in:
parent
17a1f097c3
commit
e2faa87b59
865
connect.go
865
connect.go
|
@ -1,425 +1,440 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gosuri/uiprogress"
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
Server string
|
||||
File FileMetaData
|
||||
NumberOfConnections int
|
||||
Code string
|
||||
HashedCode string
|
||||
IsSender bool
|
||||
Debug bool
|
||||
DontEncrypt bool
|
||||
bars []*uiprogress.Bar
|
||||
rate int
|
||||
}
|
||||
|
||||
type FileMetaData struct {
|
||||
Name string
|
||||
Size int
|
||||
Hash string
|
||||
}
|
||||
|
||||
func NewConnection(flags *Flags) *Connection {
|
||||
c := new(Connection)
|
||||
c.Debug = flags.Debug
|
||||
c.DontEncrypt = flags.DontEncrypt
|
||||
c.Server = flags.Server
|
||||
c.Code = flags.Code
|
||||
c.NumberOfConnections = flags.NumberOfConnections
|
||||
c.rate = flags.Rate
|
||||
if len(flags.File) > 0 {
|
||||
c.File.Name = flags.File
|
||||
c.IsSender = true
|
||||
} else {
|
||||
c.IsSender = false
|
||||
}
|
||||
|
||||
log.SetFormatter(&log.TextFormatter{})
|
||||
if c.Debug {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
} else {
|
||||
log.SetLevel(log.WarnLevel)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Connection) Run() error {
|
||||
forceSingleThreaded := false
|
||||
if c.IsSender {
|
||||
fsize, err := FileSize(c.File.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fsize < MAX_NUMBER_THREADS*BUFFERSIZE {
|
||||
forceSingleThreaded = true
|
||||
log.Debug("forcing single thread")
|
||||
}
|
||||
}
|
||||
log.Debug("checking code validity")
|
||||
for {
|
||||
// check code
|
||||
goodCode := true
|
||||
m := strings.Split(c.Code, "-")
|
||||
numThreads, errParse := strconv.Atoi(m[0])
|
||||
if len(m) < 2 {
|
||||
goodCode = false
|
||||
} else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) {
|
||||
c.NumberOfConnections = MAX_NUMBER_THREADS
|
||||
goodCode = false
|
||||
} else if errParse != nil {
|
||||
goodCode = false
|
||||
}
|
||||
log.Debug(m)
|
||||
if !goodCode {
|
||||
if c.IsSender {
|
||||
if forceSingleThreaded {
|
||||
c.NumberOfConnections = 1
|
||||
}
|
||||
c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName()
|
||||
} else {
|
||||
if len(c.Code) != 0 {
|
||||
fmt.Println("Code must begin with number of threads (e.g. 3-some-code)")
|
||||
}
|
||||
c.Code = getInput("Enter receive code: ")
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
// assign number of connections
|
||||
c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0])
|
||||
|
||||
if c.IsSender {
|
||||
if c.DontEncrypt {
|
||||
// don't encrypt
|
||||
CopyFile(c.File.Name, c.File.Name+".enc")
|
||||
} else {
|
||||
// encrypt
|
||||
log.Debug("encrypting...")
|
||||
if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// get file hash
|
||||
var err error
|
||||
c.File.Hash, err = HashFile(c.File.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// get file size
|
||||
c.File.Size, err = FileSize(c.File.Name + ".enc")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return c.runClient()
|
||||
}
|
||||
|
||||
// runClient spawns threads for parallel uplink/downlink via TCP
|
||||
func (c *Connection) runClient() error {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"code": c.Code,
|
||||
"sender?": c.IsSender,
|
||||
})
|
||||
|
||||
c.HashedCode = Hash(c.Code)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(c.NumberOfConnections)
|
||||
|
||||
uiprogress.Start()
|
||||
if !c.Debug {
|
||||
c.bars = make([]*uiprogress.Bar, c.NumberOfConnections)
|
||||
}
|
||||
gotOK := false
|
||||
gotResponse := false
|
||||
for id := 0; id < c.NumberOfConnections; id++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
port := strconv.Itoa(27001 + id)
|
||||
connection, err := net.Dial("tcp", c.Server+":"+port)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer connection.Close()
|
||||
|
||||
message := receiveMessage(connection)
|
||||
logger.Debugf("relay says: %s", message)
|
||||
if c.IsSender {
|
||||
logger.Debugf("telling relay: %s", "s."+c.Code)
|
||||
metaData, err := json.Marshal(c.File)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
encryptedMetaData, salt, iv := Encrypt(metaData, c.Code)
|
||||
sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection)
|
||||
} else {
|
||||
logger.Debugf("telling relay: %s", "r."+c.Code)
|
||||
sendMessage("r."+c.HashedCode+".0.0.0", connection)
|
||||
}
|
||||
if c.IsSender { // this is a sender
|
||||
logger.Debug("waiting for ok from relay")
|
||||
message = receiveMessage(connection)
|
||||
logger.Debug("got ok from relay")
|
||||
if id == 0 {
|
||||
fmt.Printf("\nSending (->%s)..\n", message)
|
||||
}
|
||||
// wait for pipe to be made
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Write data from file
|
||||
logger.Debug("send file")
|
||||
c.sendFile(id, connection)
|
||||
} else { // this is a receiver
|
||||
logger.Debug("waiting for meta data from sender")
|
||||
message = receiveMessage(connection)
|
||||
m := strings.Split(message, "-")
|
||||
encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3]
|
||||
encryptedBytes, err := hex.DecodeString(encryptedData)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt)
|
||||
err = json.Unmarshal(decryptedBytes, &c.File)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
log.Debugf("meta data received: %v", c.File)
|
||||
// have the main thread ask for the okay
|
||||
if id == 0 {
|
||||
fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name)
|
||||
var sentFileNames []string
|
||||
|
||||
if fileAlreadyExists(sentFileNames, c.File.Name) {
|
||||
fmt.Printf("Will not overwrite file!")
|
||||
os.Exit(1)
|
||||
}
|
||||
getOK := getInput("ok? (y/n): ")
|
||||
if getOK == "y" {
|
||||
gotOK = true
|
||||
sentFileNames = append(sentFileNames, c.File.Name)
|
||||
}
|
||||
gotResponse = true
|
||||
}
|
||||
// wait for the main thread to get the okay
|
||||
for limit := 0; limit < 1000; limit++ {
|
||||
if gotResponse {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if !gotOK {
|
||||
sendMessage("not ok", connection)
|
||||
} else {
|
||||
sendMessage("ok", connection)
|
||||
logger.Debug("receive file")
|
||||
fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress)
|
||||
c.receiveFile(id, connection)
|
||||
}
|
||||
}
|
||||
}(id)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if !c.IsSender {
|
||||
if !gotOK {
|
||||
return errors.New("Transfer interrupted")
|
||||
}
|
||||
c.catFile(c.File.Name)
|
||||
log.Debugf("Code: [%s]", c.Code)
|
||||
if c.DontEncrypt {
|
||||
if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil {
|
||||
return errors.Wrap(err, "Problem decrypting file")
|
||||
}
|
||||
}
|
||||
if !c.Debug {
|
||||
os.Remove(c.File.Name + ".enc")
|
||||
}
|
||||
|
||||
fileHash, err := HashFile(c.File.Name)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash)
|
||||
log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash)
|
||||
|
||||
if c.File.Hash != fileHash {
|
||||
return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
|
||||
} else {
|
||||
fmt.Printf("\nReceived file written to %s", c.File.Name)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("File sent.")
|
||||
// TODO: Add confirmation
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fileAlreadyExists(s []string, f string) bool {
|
||||
for _, a := range s {
|
||||
if a == f {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Connection) catFile(fname string) {
|
||||
// cat the file
|
||||
os.Remove(fname)
|
||||
finished, err := os.Create(fname + ".enc")
|
||||
defer finished.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
for id := 0; id < c.NumberOfConnections; id++ {
|
||||
fh, err := os.Open(fname + "." + strconv.Itoa(id))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = io.Copy(finished, fh)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fh.Close()
|
||||
os.Remove(fname + "." + strconv.Itoa(id))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) receiveFile(id int, connection net.Conn) error {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "receiveFile #" + strconv.Itoa(id),
|
||||
})
|
||||
|
||||
logger.Debug("waiting for chunk size from sender")
|
||||
fileSizeBuffer := make([]byte, 10)
|
||||
connection.Read(fileSizeBuffer)
|
||||
fileDataString := strings.Trim(string(fileSizeBuffer), ":")
|
||||
fileSizeInt, _ := strconv.Atoi(fileDataString)
|
||||
chunkSize := int64(fileSizeInt)
|
||||
logger.Debugf("chunk size: %d", chunkSize)
|
||||
|
||||
os.Remove(c.File.Name + "." + strconv.Itoa(id))
|
||||
newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer newFile.Close()
|
||||
|
||||
if !c.Debug {
|
||||
c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed()
|
||||
}
|
||||
|
||||
logger.Debug("waiting for file")
|
||||
var receivedBytes int64
|
||||
for {
|
||||
if !c.Debug {
|
||||
c.bars[id].Incr()
|
||||
}
|
||||
if (chunkSize - receivedBytes) < BUFFERSIZE {
|
||||
logger.Debug("at the end")
|
||||
io.CopyN(newFile, connection, (chunkSize - receivedBytes))
|
||||
// Empty the remaining bytes that we don't need from the network buffer
|
||||
if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
|
||||
logger.Debug("empty remaining bytes from network buffer")
|
||||
connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize))
|
||||
}
|
||||
break
|
||||
}
|
||||
io.CopyN(newFile, connection, BUFFERSIZE)
|
||||
receivedBytes += BUFFERSIZE
|
||||
}
|
||||
logger.Debug("received file")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) sendFile(id int, connection net.Conn) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "sendFile #" + strconv.Itoa(id),
|
||||
})
|
||||
defer connection.Close()
|
||||
|
||||
var err error
|
||||
|
||||
numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
|
||||
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections)))
|
||||
|
||||
chunkSize := int64(chunksPerWorker * BUFFERSIZE)
|
||||
if id+1 == c.NumberOfConnections {
|
||||
chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize
|
||||
}
|
||||
|
||||
if id == 0 || id == c.NumberOfConnections-1 {
|
||||
logger.Debugf("numChunks: %v", numChunks)
|
||||
logger.Debugf("chunksPerWorker: %v", chunksPerWorker)
|
||||
logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize)
|
||||
}
|
||||
|
||||
logger.Debugf("sending chunk size: %d", chunkSize)
|
||||
connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10)))
|
||||
|
||||
sendBuffer := make([]byte, BUFFERSIZE)
|
||||
|
||||
// open encrypted file
|
||||
file, err := os.Open(c.File.Name + ".enc")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
chunkI := 0
|
||||
if !c.Debug {
|
||||
c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed()
|
||||
}
|
||||
|
||||
bufferSizeInKilobytes := BUFFERSIZE / 1024
|
||||
rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes)
|
||||
throttle := time.NewTicker(time.Second / time.Duration(rate))
|
||||
defer throttle.Stop()
|
||||
|
||||
for range throttle.C {
|
||||
_, err = file.Read(sendBuffer)
|
||||
if err == io.EOF {
|
||||
//End of file reached, break out of for loop
|
||||
logger.Debug("EOF")
|
||||
break
|
||||
}
|
||||
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) {
|
||||
connection.Write(sendBuffer)
|
||||
if !c.Debug {
|
||||
c.bars[id].Incr()
|
||||
}
|
||||
}
|
||||
chunkI++
|
||||
}
|
||||
logger.Debug("file is sent")
|
||||
return
|
||||
}
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gosuri/uiprogress"
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
Server string
|
||||
File FileMetaData
|
||||
NumberOfConnections int
|
||||
Code string
|
||||
HashedCode string
|
||||
IsSender bool
|
||||
Debug bool
|
||||
DontEncrypt bool
|
||||
bars []*uiprogress.Bar
|
||||
rate int
|
||||
}
|
||||
|
||||
type FileMetaData struct {
|
||||
Name string
|
||||
Size int
|
||||
Hash string
|
||||
}
|
||||
|
||||
func NewConnection(flags *Flags) *Connection {
|
||||
c := new(Connection)
|
||||
c.Debug = flags.Debug
|
||||
c.DontEncrypt = flags.DontEncrypt
|
||||
c.Server = flags.Server
|
||||
c.Code = flags.Code
|
||||
c.NumberOfConnections = flags.NumberOfConnections
|
||||
c.rate = flags.Rate
|
||||
if len(flags.File) > 0 {
|
||||
c.File.Name = flags.File
|
||||
c.IsSender = true
|
||||
} else {
|
||||
c.IsSender = false
|
||||
}
|
||||
|
||||
log.SetFormatter(&log.TextFormatter{})
|
||||
if c.Debug {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
} else {
|
||||
log.SetLevel(log.WarnLevel)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Connection) Run() error {
|
||||
forceSingleThreaded := false
|
||||
if c.IsSender {
|
||||
fsize, err := FileSize(c.File.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fsize < MAX_NUMBER_THREADS*BUFFERSIZE {
|
||||
forceSingleThreaded = true
|
||||
log.Debug("forcing single thread")
|
||||
}
|
||||
}
|
||||
log.Debug("checking code validity")
|
||||
for {
|
||||
// check code
|
||||
goodCode := true
|
||||
m := strings.Split(c.Code, "-")
|
||||
numThreads, errParse := strconv.Atoi(m[0])
|
||||
if len(m) < 2 {
|
||||
goodCode = false
|
||||
} else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 || (forceSingleThreaded && numThreads != 1) {
|
||||
c.NumberOfConnections = MAX_NUMBER_THREADS
|
||||
goodCode = false
|
||||
} else if errParse != nil {
|
||||
goodCode = false
|
||||
}
|
||||
log.Debug(m)
|
||||
if !goodCode {
|
||||
if c.IsSender {
|
||||
if forceSingleThreaded {
|
||||
c.NumberOfConnections = 1
|
||||
}
|
||||
c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName()
|
||||
} else {
|
||||
if len(c.Code) != 0 {
|
||||
fmt.Println("Code must begin with number of threads (e.g. 3-some-code)")
|
||||
}
|
||||
c.Code = getInput("Enter receive code: ")
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
// assign number of connections
|
||||
c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0])
|
||||
|
||||
if c.IsSender {
|
||||
if c.DontEncrypt {
|
||||
// don't encrypt
|
||||
CopyFile(c.File.Name, c.File.Name+".enc")
|
||||
} else {
|
||||
// encrypt
|
||||
log.Debug("encrypting...")
|
||||
if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// get file hash
|
||||
var err error
|
||||
c.File.Hash, err = HashFile(c.File.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// get file size
|
||||
c.File.Size, err = FileSize(c.File.Name + ".enc")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return c.runClient()
|
||||
}
|
||||
|
||||
// runClient spawns threads for parallel uplink/downlink via TCP
|
||||
func (c *Connection) runClient() error {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"code": c.Code,
|
||||
"sender?": c.IsSender,
|
||||
})
|
||||
|
||||
c.HashedCode = Hash(c.Code)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(c.NumberOfConnections)
|
||||
|
||||
uiprogress.Start()
|
||||
if !c.Debug {
|
||||
c.bars = make([]*uiprogress.Bar, c.NumberOfConnections)
|
||||
}
|
||||
gotOK := false
|
||||
gotResponse := false
|
||||
gotConnectionInUse := false
|
||||
for id := 0; id < c.NumberOfConnections; id++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
port := strconv.Itoa(27001 + id)
|
||||
connection, err := net.Dial("tcp", c.Server+":"+port)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer connection.Close()
|
||||
|
||||
message := receiveMessage(connection)
|
||||
logger.Debugf("relay says: %s", message)
|
||||
if c.IsSender {
|
||||
logger.Debugf("telling relay: %s", "s."+c.Code)
|
||||
metaData, err := json.Marshal(c.File)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
encryptedMetaData, salt, iv := Encrypt(metaData, c.Code)
|
||||
sendMessage("s."+c.HashedCode+"."+hex.EncodeToString(encryptedMetaData)+"-"+salt+"-"+iv, connection)
|
||||
} else {
|
||||
logger.Debugf("telling relay: %s", "r."+c.Code)
|
||||
sendMessage("r."+c.HashedCode+".0.0.0", connection)
|
||||
}
|
||||
if c.IsSender { // this is a sender
|
||||
logger.Debug("waiting for ok from relay")
|
||||
message = receiveMessage(connection)
|
||||
if message == "no" {
|
||||
fmt.Println("The specifed code is already in use by a sender.")
|
||||
gotConnectionInUse = true
|
||||
} else {
|
||||
logger.Debug("got ok from relay")
|
||||
if id == 0 {
|
||||
fmt.Printf("\nSending (->%s)..\n", message)
|
||||
}
|
||||
// wait for pipe to be made
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Write data from file
|
||||
logger.Debug("send file")
|
||||
c.sendFile(id, connection)
|
||||
fmt.Println("File sent.")
|
||||
}
|
||||
} else { // this is a receiver
|
||||
logger.Debug("waiting for meta data from sender")
|
||||
message = receiveMessage(connection)
|
||||
if message == "no" {
|
||||
fmt.Println("The specifed code is already in use by a receiver.")
|
||||
gotConnectionInUse = true
|
||||
} else {
|
||||
m := strings.Split(message, "-")
|
||||
encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3]
|
||||
encryptedBytes, err := hex.DecodeString(encryptedData)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
decryptedBytes, _ := Decrypt(encryptedBytes, c.Code, salt, iv, c.DontEncrypt)
|
||||
err = json.Unmarshal(decryptedBytes, &c.File)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
log.Debugf("meta data received: %v", c.File)
|
||||
// have the main thread ask for the okay
|
||||
if id == 0 {
|
||||
fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name)
|
||||
var sentFileNames []string
|
||||
|
||||
if fileAlreadyExists(sentFileNames, c.File.Name) {
|
||||
fmt.Printf("Will not overwrite file!")
|
||||
os.Exit(1)
|
||||
}
|
||||
getOK := getInput("ok? (y/n): ")
|
||||
if getOK == "y" {
|
||||
gotOK = true
|
||||
sentFileNames = append(sentFileNames, c.File.Name)
|
||||
}
|
||||
gotResponse = true
|
||||
}
|
||||
// wait for the main thread to get the okay
|
||||
for limit := 0; limit < 1000; limit++ {
|
||||
if gotResponse {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if !gotOK {
|
||||
sendMessage("not ok", connection)
|
||||
} else {
|
||||
sendMessage("ok", connection)
|
||||
logger.Debug("receive file")
|
||||
fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress)
|
||||
c.receiveFile(id, connection)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(id)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if gotConnectionInUse {
|
||||
return nil // connection was in use, just quit cleanly
|
||||
}
|
||||
|
||||
if c.IsSender {
|
||||
// TODO: Add confirmation
|
||||
} else { // Is a Receiver
|
||||
if !gotOK {
|
||||
return errors.New("Transfer interrupted")
|
||||
}
|
||||
c.catFile(c.File.Name)
|
||||
log.Debugf("Code: [%s]", c.Code)
|
||||
if c.DontEncrypt {
|
||||
if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil {
|
||||
return errors.Wrap(err, "Problem decrypting file")
|
||||
}
|
||||
}
|
||||
if !c.Debug {
|
||||
os.Remove(c.File.Name + ".enc")
|
||||
}
|
||||
|
||||
fileHash, err := HashFile(c.File.Name)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash)
|
||||
log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash)
|
||||
|
||||
if c.File.Hash != fileHash {
|
||||
return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name)
|
||||
} else {
|
||||
fmt.Printf("\nReceived file written to %s", c.File.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fileAlreadyExists(s []string, f string) bool {
|
||||
for _, a := range s {
|
||||
if a == f {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Connection) catFile(fname string) {
|
||||
// cat the file
|
||||
os.Remove(fname)
|
||||
finished, err := os.Create(fname + ".enc")
|
||||
defer finished.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
for id := 0; id < c.NumberOfConnections; id++ {
|
||||
fh, err := os.Open(fname + "." + strconv.Itoa(id))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = io.Copy(finished, fh)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fh.Close()
|
||||
os.Remove(fname + "." + strconv.Itoa(id))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) receiveFile(id int, connection net.Conn) error {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "receiveFile #" + strconv.Itoa(id),
|
||||
})
|
||||
|
||||
logger.Debug("waiting for chunk size from sender")
|
||||
fileSizeBuffer := make([]byte, 10)
|
||||
connection.Read(fileSizeBuffer)
|
||||
fileDataString := strings.Trim(string(fileSizeBuffer), ":")
|
||||
fileSizeInt, _ := strconv.Atoi(fileDataString)
|
||||
chunkSize := int64(fileSizeInt)
|
||||
logger.Debugf("chunk size: %d", chunkSize)
|
||||
|
||||
os.Remove(c.File.Name + "." + strconv.Itoa(id))
|
||||
newFile, err := os.Create(c.File.Name + "." + strconv.Itoa(id))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer newFile.Close()
|
||||
|
||||
if !c.Debug {
|
||||
c.bars[id] = uiprogress.AddBar(int(chunkSize)/1024 + 1).AppendCompleted().PrependElapsed()
|
||||
}
|
||||
|
||||
logger.Debug("waiting for file")
|
||||
var receivedBytes int64
|
||||
for {
|
||||
if !c.Debug {
|
||||
c.bars[id].Incr()
|
||||
}
|
||||
if (chunkSize - receivedBytes) < BUFFERSIZE {
|
||||
logger.Debug("at the end")
|
||||
io.CopyN(newFile, connection, (chunkSize - receivedBytes))
|
||||
// Empty the remaining bytes that we don't need from the network buffer
|
||||
if (receivedBytes+BUFFERSIZE)-chunkSize < BUFFERSIZE {
|
||||
logger.Debug("empty remaining bytes from network buffer")
|
||||
connection.Read(make([]byte, (receivedBytes+BUFFERSIZE)-chunkSize))
|
||||
}
|
||||
break
|
||||
}
|
||||
io.CopyN(newFile, connection, BUFFERSIZE)
|
||||
receivedBytes += BUFFERSIZE
|
||||
}
|
||||
logger.Debug("received file")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) sendFile(id int, connection net.Conn) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "sendFile #" + strconv.Itoa(id),
|
||||
})
|
||||
defer connection.Close()
|
||||
|
||||
var err error
|
||||
|
||||
numChunks := math.Ceil(float64(c.File.Size) / float64(BUFFERSIZE))
|
||||
chunksPerWorker := int(math.Ceil(numChunks / float64(c.NumberOfConnections)))
|
||||
|
||||
chunkSize := int64(chunksPerWorker * BUFFERSIZE)
|
||||
if id+1 == c.NumberOfConnections {
|
||||
chunkSize = int64(c.File.Size) - int64(c.NumberOfConnections-1)*chunkSize
|
||||
}
|
||||
|
||||
if id == 0 || id == c.NumberOfConnections-1 {
|
||||
logger.Debugf("numChunks: %v", numChunks)
|
||||
logger.Debugf("chunksPerWorker: %v", chunksPerWorker)
|
||||
logger.Debugf("bytesPerchunkSizeConnection: %v", chunkSize)
|
||||
}
|
||||
|
||||
logger.Debugf("sending chunk size: %d", chunkSize)
|
||||
connection.Write([]byte(fillString(strconv.FormatInt(int64(chunkSize), 10), 10)))
|
||||
|
||||
sendBuffer := make([]byte, BUFFERSIZE)
|
||||
|
||||
// open encrypted file
|
||||
file, err := os.Open(c.File.Name + ".enc")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
chunkI := 0
|
||||
if !c.Debug {
|
||||
c.bars[id] = uiprogress.AddBar(chunksPerWorker).AppendCompleted().PrependElapsed()
|
||||
}
|
||||
|
||||
bufferSizeInKilobytes := BUFFERSIZE / 1024
|
||||
rate := float64(c.rate) / float64(c.NumberOfConnections*bufferSizeInKilobytes)
|
||||
throttle := time.NewTicker(time.Second / time.Duration(rate))
|
||||
defer throttle.Stop()
|
||||
|
||||
for range throttle.C {
|
||||
_, err = file.Read(sendBuffer)
|
||||
if err == io.EOF {
|
||||
//End of file reached, break out of for loop
|
||||
logger.Debug("EOF")
|
||||
break
|
||||
}
|
||||
if (chunkI >= chunksPerWorker*id && chunkI < chunksPerWorker*id+chunksPerWorker) || (id == c.NumberOfConnections-1 && chunkI >= chunksPerWorker*id) {
|
||||
connection.Write(sendBuffer)
|
||||
if !c.Debug {
|
||||
c.bars[id].Incr()
|
||||
}
|
||||
}
|
||||
chunkI++
|
||||
}
|
||||
logger.Debug("file is sent")
|
||||
return
|
||||
}
|
||||
|
|
528
relay.go
528
relay.go
|
@ -1,248 +1,280 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const MAX_NUMBER_THREADS = 8
|
||||
|
||||
type connectionMap struct {
|
||||
reciever map[string]net.Conn
|
||||
sender map[string]net.Conn
|
||||
metadata map[string]string
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
type Relay struct {
|
||||
connections connectionMap
|
||||
Debug bool
|
||||
NumberOfConnections int
|
||||
}
|
||||
|
||||
func NewRelay(flags *Flags) *Relay {
|
||||
r := new(Relay)
|
||||
r.Debug = flags.Debug
|
||||
r.NumberOfConnections = MAX_NUMBER_THREADS
|
||||
log.SetFormatter(&log.TextFormatter{})
|
||||
if r.Debug {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
} else {
|
||||
log.SetLevel(log.WarnLevel)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Relay) Run() {
|
||||
r.connections = connectionMap{}
|
||||
r.connections.Lock()
|
||||
r.connections.reciever = make(map[string]net.Conn)
|
||||
r.connections.sender = make(map[string]net.Conn)
|
||||
r.connections.metadata = make(map[string]string)
|
||||
r.connections.Unlock()
|
||||
r.runServer()
|
||||
}
|
||||
|
||||
func (r *Relay) runServer() {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "main",
|
||||
})
|
||||
logger.Debug("Initializing")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(r.NumberOfConnections)
|
||||
for id := 0; id < r.NumberOfConnections; id++ {
|
||||
go r.listenerThread(id, &wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "listenerThread:" + strconv.Itoa(27000+id),
|
||||
})
|
||||
|
||||
defer wg.Done()
|
||||
err := r.listener(id)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) listener(id int) (err error) {
|
||||
port := strconv.Itoa(27001 + id)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "listener" + ":" + port,
|
||||
})
|
||||
server, err := net.Listen("tcp", "0.0.0.0:"+port)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error listening on "+":"+port)
|
||||
}
|
||||
defer server.Close()
|
||||
logger.Debug("waiting for connections")
|
||||
//Spawn a new goroutine whenever a client connects
|
||||
for {
|
||||
connection, err := server.Accept()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "problem accepting connection")
|
||||
}
|
||||
logger.Debugf("Client %s connected", connection.RemoteAddr().String())
|
||||
go r.clientCommuncation(id, connection)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) clientCommuncation(id int, connection net.Conn) {
|
||||
sendMessage("who?", connection)
|
||||
|
||||
m := strings.Split(receiveMessage(connection), ".")
|
||||
connectionType, codePhrase, metaData := m[0], m[1], m[2]
|
||||
key := codePhrase + "-" + strconv.Itoa(id)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"id": id,
|
||||
"codePhrase": codePhrase,
|
||||
})
|
||||
|
||||
if connectionType == "s" {
|
||||
logger.Debug("got sender")
|
||||
r.connections.Lock()
|
||||
r.connections.metadata[key] = metaData
|
||||
r.connections.sender[key] = connection
|
||||
r.connections.Unlock()
|
||||
// wait for receiver
|
||||
receiversAddress := ""
|
||||
for {
|
||||
r.connections.RLock()
|
||||
if _, ok := r.connections.reciever[key]; ok {
|
||||
receiversAddress = r.connections.reciever[key].RemoteAddr().String()
|
||||
logger.Debug("got reciever")
|
||||
r.connections.RUnlock()
|
||||
break
|
||||
}
|
||||
r.connections.RUnlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
logger.Debug("telling sender ok")
|
||||
sendMessage(receiversAddress, connection)
|
||||
logger.Debug("preparing pipe")
|
||||
r.connections.Lock()
|
||||
con1 := r.connections.sender[key]
|
||||
con2 := r.connections.reciever[key]
|
||||
r.connections.Unlock()
|
||||
logger.Debug("piping connections")
|
||||
Pipe(con1, con2)
|
||||
logger.Debug("done piping")
|
||||
r.connections.Lock()
|
||||
delete(r.connections.sender, key)
|
||||
delete(r.connections.reciever, key)
|
||||
delete(r.connections.metadata, key)
|
||||
r.connections.Unlock()
|
||||
logger.Debug("deleted sender and receiver")
|
||||
} else {
|
||||
// wait for sender's metadata
|
||||
sendersAddress := ""
|
||||
for {
|
||||
r.connections.RLock()
|
||||
if _, ok := r.connections.metadata[key]; ok {
|
||||
if _, ok2 := r.connections.sender[key]; ok2 {
|
||||
sendersAddress = r.connections.sender[key].RemoteAddr().String()
|
||||
logger.Debug("got sender meta data")
|
||||
r.connections.RUnlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
r.connections.RUnlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
// send meta data
|
||||
r.connections.RLock()
|
||||
sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection)
|
||||
r.connections.RUnlock()
|
||||
// check for receiver's consent
|
||||
consent := receiveMessage(connection)
|
||||
logger.Debugf("consent: %s", consent)
|
||||
if consent == "ok" {
|
||||
logger.Debug("got consent")
|
||||
r.connections.Lock()
|
||||
r.connections.reciever[key] = connection
|
||||
r.connections.Unlock()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func sendMessage(message string, connection net.Conn) {
|
||||
message = fillString(message, BUFFERSIZE)
|
||||
connection.Write([]byte(message))
|
||||
}
|
||||
|
||||
func receiveMessage(connection net.Conn) string {
|
||||
messageByte := make([]byte, BUFFERSIZE)
|
||||
connection.Read(messageByte)
|
||||
return strings.Replace(string(messageByte), ":", "", -1)
|
||||
}
|
||||
|
||||
func fillString(retunString string, toLength int) string {
|
||||
for {
|
||||
lengthString := len(retunString)
|
||||
if lengthString < toLength {
|
||||
retunString = retunString + ":"
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return retunString
|
||||
}
|
||||
|
||||
// chanFromConn creates a channel from a Conn object, and sends everything it
|
||||
// Read()s from the socket to the channel.
|
||||
func chanFromConn(conn net.Conn) chan []byte {
|
||||
c := make(chan []byte)
|
||||
|
||||
go func() {
|
||||
b := make([]byte, BUFFERSIZE)
|
||||
|
||||
for {
|
||||
n, err := conn.Read(b)
|
||||
if n > 0 {
|
||||
res := make([]byte, n)
|
||||
// Copy the buffer so it doesn't get changed while read by the recipient.
|
||||
copy(res, b[:n])
|
||||
c <- res
|
||||
}
|
||||
if err != nil {
|
||||
c <- nil
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other.
|
||||
func Pipe(conn1 net.Conn, conn2 net.Conn) {
|
||||
chan1 := chanFromConn(conn1)
|
||||
chan2 := chanFromConn(conn2)
|
||||
|
||||
for {
|
||||
select {
|
||||
case b1 := <-chan1:
|
||||
if b1 == nil {
|
||||
return
|
||||
} else {
|
||||
conn2.Write(b1)
|
||||
}
|
||||
case b2 := <-chan2:
|
||||
if b2 == nil {
|
||||
return
|
||||
} else {
|
||||
conn1.Write(b2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const MAX_NUMBER_THREADS = 8
|
||||
|
||||
type connectionMap struct {
|
||||
receiver map[string]net.Conn
|
||||
sender map[string]net.Conn
|
||||
metadata map[string]string
|
||||
potentialReceivers map[string]struct{}
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *connectionMap) IsSenderConnected(key string) (found bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
_, found = c.sender[key]
|
||||
return
|
||||
}
|
||||
|
||||
func (c *connectionMap) IsPotentialReceiverConnected(key string) (found bool) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
_, found = c.potentialReceivers[key]
|
||||
return
|
||||
}
|
||||
|
||||
type Relay struct {
|
||||
connections connectionMap
|
||||
Debug bool
|
||||
NumberOfConnections int
|
||||
}
|
||||
|
||||
func NewRelay(flags *Flags) *Relay {
|
||||
r := new(Relay)
|
||||
r.Debug = flags.Debug
|
||||
r.NumberOfConnections = MAX_NUMBER_THREADS
|
||||
log.SetFormatter(&log.TextFormatter{})
|
||||
if r.Debug {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
} else {
|
||||
log.SetLevel(log.WarnLevel)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Relay) Run() {
|
||||
r.connections = connectionMap{}
|
||||
r.connections.Lock()
|
||||
r.connections.receiver = make(map[string]net.Conn)
|
||||
r.connections.sender = make(map[string]net.Conn)
|
||||
r.connections.metadata = make(map[string]string)
|
||||
r.connections.potentialReceivers = make(map[string]struct{})
|
||||
r.connections.Unlock()
|
||||
r.runServer()
|
||||
}
|
||||
|
||||
func (r *Relay) runServer() {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "main",
|
||||
})
|
||||
logger.Debug("Initializing")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(r.NumberOfConnections)
|
||||
for id := 0; id < r.NumberOfConnections; id++ {
|
||||
go r.listenerThread(id, &wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *Relay) listenerThread(id int, wg *sync.WaitGroup) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "listenerThread:" + strconv.Itoa(27000+id),
|
||||
})
|
||||
|
||||
defer wg.Done()
|
||||
err := r.listener(id)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) listener(id int) (err error) {
|
||||
port := strconv.Itoa(27001 + id)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"function": "listener" + ":" + port,
|
||||
})
|
||||
server, err := net.Listen("tcp", "0.0.0.0:"+port)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error listening on "+":"+port)
|
||||
}
|
||||
defer server.Close()
|
||||
logger.Debug("waiting for connections")
|
||||
//Spawn a new goroutine whenever a client connects
|
||||
for {
|
||||
connection, err := server.Accept()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "problem accepting connection")
|
||||
}
|
||||
logger.Debugf("Client %s connected", connection.RemoteAddr().String())
|
||||
go r.clientCommuncation(id, connection)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) clientCommuncation(id int, connection net.Conn) {
|
||||
sendMessage("who?", connection)
|
||||
|
||||
m := strings.Split(receiveMessage(connection), ".")
|
||||
connectionType, codePhrase, metaData := m[0], m[1], m[2]
|
||||
key := codePhrase + "-" + strconv.Itoa(id)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"id": id,
|
||||
"codePhrase": codePhrase,
|
||||
})
|
||||
|
||||
if connectionType == "s" { // sender connection
|
||||
if r.connections.IsSenderConnected(key) {
|
||||
sendMessage("no", connection)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("got sender")
|
||||
r.connections.Lock()
|
||||
r.connections.metadata[key] = metaData
|
||||
r.connections.sender[key] = connection
|
||||
r.connections.Unlock()
|
||||
// wait for receiver
|
||||
receiversAddress := ""
|
||||
for {
|
||||
r.connections.RLock()
|
||||
if _, ok := r.connections.receiver[key]; ok {
|
||||
receiversAddress = r.connections.receiver[key].RemoteAddr().String()
|
||||
logger.Debug("got receiver")
|
||||
r.connections.RUnlock()
|
||||
break
|
||||
}
|
||||
r.connections.RUnlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
logger.Debug("telling sender ok")
|
||||
sendMessage(receiversAddress, connection)
|
||||
logger.Debug("preparing pipe")
|
||||
r.connections.Lock()
|
||||
con1 := r.connections.sender[key]
|
||||
con2 := r.connections.receiver[key]
|
||||
r.connections.Unlock()
|
||||
logger.Debug("piping connections")
|
||||
Pipe(con1, con2)
|
||||
logger.Debug("done piping")
|
||||
r.connections.Lock()
|
||||
delete(r.connections.sender, key)
|
||||
delete(r.connections.receiver, key)
|
||||
delete(r.connections.metadata, key)
|
||||
delete(r.connections.potentialReceivers, key)
|
||||
r.connections.Unlock()
|
||||
logger.Debug("deleted sender and receiver")
|
||||
} else { //receiver connection "r"
|
||||
if r.connections.IsPotentialReceiverConnected(key) {
|
||||
sendMessage("no", connection)
|
||||
return
|
||||
}
|
||||
|
||||
// add as a potential receiver
|
||||
r.connections.Lock()
|
||||
r.connections.potentialReceivers[key] = struct{}{}
|
||||
r.connections.Unlock()
|
||||
|
||||
// wait for sender's metadata
|
||||
sendersAddress := ""
|
||||
for {
|
||||
r.connections.RLock()
|
||||
if _, ok := r.connections.metadata[key]; ok {
|
||||
if _, ok2 := r.connections.sender[key]; ok2 {
|
||||
sendersAddress = r.connections.sender[key].RemoteAddr().String()
|
||||
logger.Debug("got sender meta data")
|
||||
r.connections.RUnlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
r.connections.RUnlock()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
// send meta data
|
||||
r.connections.RLock()
|
||||
sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection)
|
||||
r.connections.RUnlock()
|
||||
// check for receiver's consent
|
||||
consent := receiveMessage(connection)
|
||||
logger.Debugf("consent: %s", consent)
|
||||
if consent == "ok" {
|
||||
logger.Debug("got consent")
|
||||
r.connections.Lock()
|
||||
r.connections.receiver[key] = connection
|
||||
r.connections.Unlock()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func sendMessage(message string, connection net.Conn) {
|
||||
message = fillString(message, BUFFERSIZE)
|
||||
connection.Write([]byte(message))
|
||||
}
|
||||
|
||||
func receiveMessage(connection net.Conn) string {
|
||||
messageByte := make([]byte, BUFFERSIZE)
|
||||
connection.Read(messageByte)
|
||||
return strings.Replace(string(messageByte), ":", "", -1)
|
||||
}
|
||||
|
||||
func fillString(retunString string, toLength int) string {
|
||||
for {
|
||||
lengthString := len(retunString)
|
||||
if lengthString < toLength {
|
||||
retunString = retunString + ":"
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return retunString
|
||||
}
|
||||
|
||||
// chanFromConn creates a channel from a Conn object, and sends everything it
|
||||
// Read()s from the socket to the channel.
|
||||
func chanFromConn(conn net.Conn) chan []byte {
|
||||
c := make(chan []byte)
|
||||
|
||||
go func() {
|
||||
b := make([]byte, BUFFERSIZE)
|
||||
|
||||
for {
|
||||
n, err := conn.Read(b)
|
||||
if n > 0 {
|
||||
res := make([]byte, n)
|
||||
// Copy the buffer so it doesn't get changed while read by the recipient.
|
||||
copy(res, b[:n])
|
||||
c <- res
|
||||
}
|
||||
if err != nil {
|
||||
c <- nil
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Pipe creates a full-duplex pipe between the two sockets and transfers data from one to the other.
|
||||
func Pipe(conn1 net.Conn, conn2 net.Conn) {
|
||||
chan1 := chanFromConn(conn1)
|
||||
chan2 := chanFromConn(conn2)
|
||||
|
||||
for {
|
||||
select {
|
||||
case b1 := <-chan1:
|
||||
if b1 == nil {
|
||||
return
|
||||
} else {
|
||||
conn2.Write(b1)
|
||||
}
|
||||
case b2 := <-chan2:
|
||||
if b2 == nil {
|
||||
return
|
||||
} else {
|
||||
conn1.Write(b2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue