diff --git a/src/cli/cli.go b/src/cli/cli.go index 7c1d782..0a39dcf 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -94,6 +94,7 @@ func Run() (err error) { &cli.BoolFlag{Name: "local", Usage: "force to use only local connections"}, &cli.BoolFlag{Name: "ignore-stdin", Usage: "ignore piped stdin"}, &cli.BoolFlag{Name: "overwrite", Usage: "do not prompt to overwrite"}, + &cli.BoolFlag{Name: "zip", Usage: "zip folder if specified"}, &cli.StringFlag{Name: "curve", Value: "p256", Usage: "choose an encryption curve (" + strings.Join(pake.AvailableCurves(), ", ") + ")"}, &cli.StringFlag{Name: "ip", Value: "", Usage: "set sender ip if known e.g. 10.0.0.1:9009, [::1]:9009"}, &cli.StringFlag{Name: "relay", Value: models.DEFAULT_RELAY, Usage: "address of the relay", EnvVars: []string{"CROC_RELAY"}}, @@ -186,6 +187,7 @@ func send(c *cli.Context) (err error) { Curve: c.String("curve"), HashAlgorithm: c.String("hash"), ThrottleUpload: c.String("throttleUpload"), + ZipFolder: c.Bool("zip"), } if crocOptions.RelayAddress != models.DEFAULT_RELAY { crocOptions.RelayAddress6 = "" @@ -266,8 +268,7 @@ func send(c *cli.Context) (err error) { // generate code phrase crocOptions.SharedSecret = utils.GetRandomName() } - - minimalFileInfos, emptyFoldersToTransfer, totalNumberFolders, err := croc.GetFilesInfo(fnames) + minimalFileInfos, emptyFoldersToTransfer, totalNumberFolders, err := croc.GetFilesInfo(fnames, crocOptions.ZipFolder) if err != nil { return } diff --git a/src/croc/croc.go b/src/croc/croc.go index 0028753..2f18c74 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -75,6 +75,7 @@ type Options struct { Curve string HashAlgorithm string ThrottleUpload string + ZipFolder bool } // Client holds the state of the croc transfer @@ -146,6 +147,7 @@ type FileInfo struct { IsEncrypted bool `json:"e,omitempty"` Symlink string `json:"sy,omitempty"` Mode os.FileMode `json:"md,omitempty"` + TempFile bool `json:"tf,omitempty"` } // RemoteFileRequest requests specific bytes @@ -249,7 +251,7 @@ func isEmptyFolder(folderPath string) (bool, error) { // This function retrives the important file informations // for every file that will be transfered -func GetFilesInfo(fnames []string) (filesInfo []FileInfo, emptyFolders []FileInfo, totalNumberFolders int, err error) { +func GetFilesInfo(fnames []string, zipfolder bool) (filesInfo []FileInfo, emptyFolders []FileInfo, totalNumberFolders int, err error) { // fnames: the relativ/absolute paths of files/folders that will be transfered totalNumberFolders = 0 var paths []string @@ -283,6 +285,35 @@ func GetFilesInfo(fnames []string) (filesInfo []FileInfo, emptyFolders []FileInf return } + if stat.IsDir() && zipfolder { + if path[len(path)-1:] != "/" { + path += "/" + } + path := filepath.Dir(path) + dest := filepath.Base(path) + ".zip" + utils.ZipDirectory(dest, path) + stat, errStat = os.Lstat(dest) + if errStat != nil { + err = errStat + return + } + absPath, errAbs = filepath.Abs(dest) + if errAbs != nil { + err = errAbs + return + } + filesInfo = append(filesInfo, FileInfo{ + Name: stat.Name(), + FolderRemote: "./", + FolderSource: filepath.Dir(absPath), + Size: stat.Size(), + ModTime: stat.ModTime(), + Mode: stat.Mode(), + TempFile: true, + }) + continue + } + if stat.IsDir() { err = filepath.Walk(absPath, func(pathName string, info os.FileInfo, err error) error { @@ -299,6 +330,7 @@ func GetFilesInfo(fnames []string) (filesInfo []FileInfo, emptyFolders []FileInf Size: info.Size(), ModTime: info.ModTime(), Mode: info.Mode(), + TempFile: false, }) } else { totalNumberFolders++ @@ -316,6 +348,7 @@ func GetFilesInfo(fnames []string) (filesInfo []FileInfo, emptyFolders []FileInf if err != nil { return } + } else { filesInfo = append(filesInfo, FileInfo{ Name: stat.Name(), @@ -324,6 +357,7 @@ func GetFilesInfo(fnames []string) (filesInfo []FileInfo, emptyFolders []FileInf Size: stat.Size(), ModTime: stat.ModTime(), Mode: stat.Mode(), + TempFile: false, }) } @@ -869,6 +903,24 @@ func (c *Client) transfer() (err error) { } err = nil } + if c.Options.IsSender && c.SuccessfulTransfer { + for _, file := range c.FilesToTransfer { + if file.TempFile { + fmt.Println("Removing " + file.Name) + os.Remove(file.Name) + } + } + } + + if c.SuccessfulTransfer && !c.Options.IsSender { + for _, file := range c.FilesToTransfer { + if file.TempFile { + utils.UnzipDirectory(".", file.Name) + os.Remove(file.Name) + log.Debugf("Removing %s\n", file.Name) + } + } + } if c.Options.Stdout && !c.Options.IsSender { pathToFile := path.Join( @@ -1579,7 +1631,6 @@ func (c *Client) updateState() (err error) { c.FilesToTransfer[c.FilesToTransferCurrentNum].FolderSource, c.FilesToTransfer[c.FilesToTransferCurrentNum].Name, ) - c.fread, err = os.Open(pathToFile) c.numfinished = 0 if err != nil { diff --git a/src/croc/croc_test.go b/src/croc/croc_test.go index c825b9d..6b9847a 100644 --- a/src/croc/croc_test.go +++ b/src/croc/croc_test.go @@ -66,7 +66,7 @@ func TestCrocReadme(t *testing.T) { var wg sync.WaitGroup wg.Add(2) go func() { - filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{"../../README.md"}) + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{"../../README.md"}, false) if errGet != nil { t.Errorf("failed to get minimal info: %v", errGet) } @@ -132,7 +132,7 @@ func TestCrocEmptyFolder(t *testing.T) { var wg sync.WaitGroup wg.Add(2) go func() { - filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{pathName}) + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{pathName}, false) if errGet != nil { t.Errorf("failed to get minimal info: %v", errGet) } @@ -199,7 +199,7 @@ func TestCrocSymlink(t *testing.T) { var wg sync.WaitGroup wg.Add(2) go func() { - filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{pathName}) + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{pathName}, false) if errGet != nil { t.Errorf("failed to get minimal info: %v", errGet) } @@ -276,7 +276,7 @@ func TestCrocLocal(t *testing.T) { os.Create("touched") wg.Add(2) go func() { - filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{"../../LICENSE", "touched"}) + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{"../../LICENSE", "touched"}, false) if errGet != nil { t.Errorf("failed to get minimal info: %v", errGet) } @@ -329,7 +329,7 @@ func TestCrocError(t *testing.T) { Curve: "siec", Overwrite: true, }) - filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{tmpfile.Name()}) + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{tmpfile.Name()}, false) if errGet != nil { t.Errorf("failed to get minimal info: %v", errGet) } diff --git a/src/utils/utils.go b/src/utils/utils.go index dad783f..9dce321 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "archive/zip" "bufio" "bytes" "crypto/md5" @@ -16,6 +17,7 @@ import ( "net/http" "os" "path" + "path/filepath" "strings" "time" @@ -369,3 +371,87 @@ func IsLocalIP(ipaddress string) bool { } return false } + +func ZipDirectory(destination string, source string) (err error) { + if _, err := os.Stat(destination); err == nil { + log.Fatalf("%s file already exists!\n", destination) + } + fmt.Fprintf(os.Stderr, "Zipping %s to %s\n", source, destination) + file, err := os.Create(destination) + if err != nil { + log.Fatalln(err) + } + defer file.Close() + writer := zip.NewWriter(file) + defer writer.Close() + err = filepath.Walk(source, func(path string, info os.FileInfo, err error) error { + if err != nil { + log.Fatalln(err) + } + if info.Mode().IsRegular() { + f1, err := os.Open(path) + if err != nil { + log.Fatalln(err) + } + defer f1.Close() + zip_path := strings.ReplaceAll(path, source, strings.TrimSuffix(destination, ".zip")) + w1, err := writer.Create(zip_path) + if err != nil { + log.Fatalln(err) + } + if _, err := io.Copy(w1, f1); err != nil { + log.Fatalln(err) + } + fmt.Fprintf(os.Stderr, "\r\033[2K") + fmt.Fprintf(os.Stderr, "\rAdding %s", zip_path) + } + return nil + }) + if err != nil { + log.Fatalln(err) + } + fmt.Println() + return nil +} + +func UnzipDirectory(destination string, source string) error { + + archive, err := zip.OpenReader(source) + if err != nil { + log.Fatalln(err) + } + defer archive.Close() + + for _, f := range archive.File { + filePath := filepath.Join(destination, f.Name) + fmt.Fprintf(os.Stderr, "\r\033[2K") + fmt.Fprintf(os.Stderr, "\rUnzipping file %s", filePath) + if f.FileInfo().IsDir() { + os.MkdirAll(filePath, os.ModePerm) + continue + } + + if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil { + log.Fatalln(err) + } + + dstFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) + if err != nil { + log.Fatalln(err) + } + + fileInArchive, err := f.Open() + if err != nil { + log.Fatalln(err) + } + + if _, err := io.Copy(dstFile, fileInArchive); err != nil { + log.Fatalln(err) + } + + dstFile.Close() + fileInArchive.Close() + } + fmt.Println() + return nil +}