Merge pull request #90 from schollz/resume

Resume
This commit is contained in:
Zack 2018-10-09 06:59:06 -07:00 committed by GitHub
commit 5b5c05d694
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 200 additions and 96 deletions

View File

@ -1,6 +1,7 @@
package recipient
import (
"bufio"
"bytes"
"encoding/json"
"errors"
@ -50,8 +51,11 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
var transferTime time.Duration
var hash256 []byte
var otherIP string
var progressFile string
var resumeFile bool
var tcpConnections []comm.Comm
dataChan := make(chan []byte, 1024*1024)
blocks := []string{}
useWebsockets := true
switch forceSend {
@ -129,6 +133,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
return err
}
log.Debugf("%x\n", sessionKey)
c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
case 3:
spin.Stop()
@ -151,9 +156,14 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
log.Debugf("got file stats: %+v", fstats)
// prompt user if its okay to receive file
progressFile = fmt.Sprintf("%s.progress", fstats.SentName)
overwritingOrReceiving := "Receiving"
if utils.Exists(fstats.Name) {
if utils.Exists(fstats.Name) || utils.Exists(fstats.SentName) {
overwritingOrReceiving = "Overwriting"
if utils.Exists(progressFile) {
overwritingOrReceiving = "Resume receiving"
resumeFile = true
}
}
fileOrFolder := "file"
if fstats.IsDir {
@ -189,15 +199,50 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
}
// await file
f, err := os.Create(fstats.SentName)
if err != nil {
log.Error(err)
return err
var f *os.File
if utils.Exists(fstats.SentName) && resumeFile {
if !useWebsockets {
f, err = os.OpenFile(fstats.SentName, os.O_WRONLY, 0644)
} else {
f, err = os.OpenFile(fstats.SentName, os.O_APPEND, 0644)
}
if err != nil {
log.Error(err)
return err
}
} else {
f, err = os.Create(fstats.SentName)
if err != nil {
log.Error(err)
return err
}
if !useWebsockets {
if err = f.Truncate(fstats.Size); err != nil {
log.Error(err)
return err
}
}
}
if err = f.Truncate(fstats.Size); err != nil {
log.Error(err)
return err
// append the previous blocks if there was progress previously
if resumeFile {
file, _ := os.Open(progressFile)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
blocks = append(blocks, strings.TrimSpace(scanner.Text()))
}
file.Close()
}
blocksBytes, _ := json.Marshal(blocks)
blockSize := 0
if useWebsockets {
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
} else {
blockSize = models.TCP_BUFFER_SIZE / 2
}
// start the ui for pgoress
bytesWritten := 0
fmt.Fprintf(os.Stderr, "\nReceiving (<-%s)...\n", otherIP)
bar := progressbar.NewOptions(
@ -206,9 +251,32 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
progressbar.OptionSetBytes(int(fstats.Size)),
progressbar.OptionSetWriter(os.Stderr),
)
bar.Add((len(blocks) * blockSize))
finished := make(chan bool)
go func(finished chan bool, dataChan chan []byte) (err error) {
// remove previous progress
var fProgress *os.File
var progressErr error
if resumeFile {
fProgress, progressErr = os.OpenFile(progressFile, os.O_APPEND, 0644)
bytesWritten = len(blocks) * blockSize
} else {
os.Remove(progressFile)
fProgress, progressErr = os.Create(progressFile)
}
if progressErr != nil {
panic(progressErr)
}
defer fProgress.Close()
blocksWritten := 0.0
blocksToWrite := float64(fstats.Size)
if useWebsockets {
blocksToWrite = blocksToWrite/float64(models.WEBSOCKET_BUFFER_SIZE/8) - float64(len(blocks))
} else {
blocksToWrite = blocksToWrite/float64(models.TCP_BUFFER_SIZE/2) - float64(len(blocks))
}
for {
message := <-dataChan
// do decryption
@ -245,19 +313,25 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
return err
}
n, err = f.WriteAt(decrypted, int64(locationToWrite))
fProgress.WriteString(fmt.Sprintf("%d\n", locationToWrite))
log.Debugf("writing to location %d (%2.0f/%2.0f)", locationToWrite, blocksWritten, blocksToWrite)
} else {
// write to file
n, err = f.Write(decrypted)
log.Debugf("writing to location %d (%2.0f/%2.0f)", bytesWritten, blocksWritten, blocksToWrite)
fProgress.WriteString(fmt.Sprintf("%d\n", bytesWritten))
}
if err != nil {
log.Error(err)
return err
}
// update the bytes written
bytesWritten += n
blocksWritten += 1.0
// update the progress bar
bar.Add(n)
if int64(bytesWritten) == fstats.Size {
if int64(bytesWritten) == fstats.Size || blocksWritten >= blocksToWrite {
log.Debug("finished")
break
}
@ -267,7 +341,8 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
}(finished, dataChan)
log.Debug("telling sender i'm ready")
c.WriteMessage(websocket.BinaryMessage, []byte("ready"))
c.WriteMessage(websocket.BinaryMessage, append([]byte("ready"), blocksBytes...))
startTime := time.Now()
if useWebsockets {
for {
@ -388,6 +463,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
fstats.Name = "stdout"
}
fmt.Fprintf(os.Stderr, "\nReceived %s written to %s (%2.1f %s)\n", folderOrFile, fstats.Name, transferRate, transferType)
os.Remove(progressFile)
}
return err
} else {
@ -397,7 +473,6 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
}
return errors.New("file corrupted")
}
default:
return fmt.Errorf("unknown step")
}

View File

@ -8,6 +8,7 @@ import (
"net"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
@ -52,6 +53,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
var otherIP string
var startTransfer time.Time
var tcpConnections []comm.Comm
blocksToSkip := make(map[int64]struct{})
type DataChan struct {
b []byte
@ -169,87 +171,6 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
}
fileReady <- nil
// start streaming encryption/compression
go func(dataChan chan DataChan) {
var buffer []byte
if useWebsockets {
buffer = make([]byte, models.WEBSOCKET_BUFFER_SIZE/8)
} else {
buffer = make([]byte, models.TCP_BUFFER_SIZE/2)
}
currentPostition := int64(0)
for {
bytesread, err := f.Read(buffer)
if bytesread > 0 {
// do compression
var compressedBytes []byte
if useCompression && !fstats.IsDir {
compressedBytes = compress.Compress(buffer[:bytesread])
} else {
compressedBytes = buffer[:bytesread]
}
// if using TCP, prepend the location to write the data to in the resulting file
if !useWebsockets {
compressedBytes = append([]byte(fmt.Sprintf("%d-", currentPostition)), compressedBytes...)
}
// do encryption
enc := crypt.Encrypt(compressedBytes, sessionKey, !useEncryption)
encBytes, err := json.Marshal(enc)
if err != nil {
dataChan <- DataChan{
b: nil,
bytesRead: 0,
err: err,
}
return
}
select {
case dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}
}
currentPostition += int64(bytesread)
}
if err != nil {
if err != io.EOF {
log.Error(err)
}
break
}
}
// finish
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
if !useWebsockets {
log.Debug("sending extra magic to %d others", len(tcpPorts)-1)
for i := 0; i < len(tcpPorts)-1; i++ {
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
}
}
}(dataChan)
}()
// send pake data
@ -275,9 +196,10 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
spin.Start()
case 3:
log.Debugf("[%d] recipient declares readiness for file info", step)
if !bytes.Equal(message, []byte("ready")) {
if !bytes.HasPrefix(message, []byte("ready")) {
return errors.New("recipient refused file")
}
err = <-fileReady // block until file is ready
if err != nil {
return err
@ -295,10 +217,111 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
spin.Stop()
log.Debugf("[%d] recipient declares readiness for file data", step)
if !bytes.Equal(message, []byte("ready")) {
if !bytes.HasPrefix(message, []byte("ready")) {
return errors.New("recipient refused file")
}
// determine if any blocks were sent to skip
var blocks []string
errBlocks := json.Unmarshal(message[5:], &blocks)
if errBlocks == nil {
for _, block := range blocks {
blockInt64, errBlock := strconv.Atoi(block)
if errBlock == nil {
blocksToSkip[int64(blockInt64)] = struct{}{}
}
}
}
log.Debugf("found blocks: %+v", blocksToSkip)
// start streaming encryption/compression
go func(dataChan chan DataChan) {
var buffer []byte
if useWebsockets {
buffer = make([]byte, models.WEBSOCKET_BUFFER_SIZE/8)
} else {
buffer = make([]byte, models.TCP_BUFFER_SIZE/2)
}
currentPostition := int64(0)
for {
bytesread, err := f.Read(buffer)
if bytesread > 0 {
if _, ok := blocksToSkip[currentPostition]; ok {
log.Debugf("skipping the sending of block %d", currentPostition)
currentPostition += int64(bytesread)
continue
}
// do compression
var compressedBytes []byte
if useCompression && !fstats.IsDir {
compressedBytes = compress.Compress(buffer[:bytesread])
} else {
compressedBytes = buffer[:bytesread]
}
// if using TCP, prepend the location to write the data to in the resulting file
if !useWebsockets {
compressedBytes = append([]byte(fmt.Sprintf("%d-", currentPostition)), compressedBytes...)
}
// do encryption
enc := crypt.Encrypt(compressedBytes, sessionKey, !useEncryption)
encBytes, err := json.Marshal(enc)
if err != nil {
dataChan <- DataChan{
b: nil,
bytesRead: 0,
err: err,
}
return
}
select {
case dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}:
default:
log.Debug("blocked")
// no message sent
// block
dataChan <- DataChan{
b: encBytes,
bytesRead: bytesread,
err: nil,
}
}
currentPostition += int64(bytesread)
}
if err != nil {
if err != io.EOF {
log.Error(err)
}
break
}
}
// finish
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
if !useWebsockets {
log.Debug("sending extra magic to %d others", len(tcpPorts)-1)
for i := 0; i < len(tcpPorts)-1; i++ {
log.Debug("sending magic")
dataChan <- DataChan{
b: []byte("magic"),
bytesRead: 0,
err: nil,
}
}
}
}(dataChan)
// connect to TCP to receive file
if !useWebsockets {
log.Debugf("connecting to server")
@ -318,12 +341,19 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
// send file, compure hash simultaneously
startTransfer = time.Now()
blockSize := 0
if useWebsockets {
blockSize = models.WEBSOCKET_BUFFER_SIZE / 8
} else {
blockSize = models.TCP_BUFFER_SIZE / 2
}
bar := progressbar.NewOptions(
int(fstats.Size),
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetBytes(int(fstats.Size)),
progressbar.OptionSetWriter(os.Stderr),
)
bar.Add(blockSize * len(blocksToSkip))
if useWebsockets {
for {

View File

@ -4,7 +4,6 @@ import (
"archive/zip"
"compress/flate"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
@ -97,7 +96,7 @@ func ZipFile(fname string, compress bool) (writtenFilename string, err error) {
return
}
log.Debugf("current directory: %s", curdir)
newfile, err := ioutil.TempFile(curdir, filename+".")
newfile, err := os.Create(fname + ".croc.zip")
if err != nil {
log.Error(err)
return