diff --git a/connect.go b/connect.go index c3fedcb..3e3e68b 100644 --- a/connect.go +++ b/connect.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "math" "net" "os" @@ -15,6 +14,7 @@ import ( "time" "github.com/gosuri/uiprogress" + "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -32,12 +32,9 @@ type Connection struct { } type FileMetaData struct { - Name string - Size int - Hash string - IV string - Salt string - bytes []byte + Name string + Size int + Hash string } func NewConnection(flags *Flags) *Connection { @@ -65,15 +62,14 @@ func NewConnection(flags *Flags) *Connection { return c } -func (c *Connection) Run() { +func (c *Connection) Run() error { forceSingleThreaded := false if c.IsSender { - fdata, err := ioutil.ReadFile(c.File.Name) + fsize, err := FileSize(c.File.Name) if err != nil { - log.Fatal(err) - return + return err } - if len(fdata) < MAX_NUMBER_THREADS*BUFFERSIZE { + if fsize < MAX_NUMBER_THREADS*BUFFERSIZE { forceSingleThreaded = true log.Debug("forcing single thread") } @@ -113,29 +109,34 @@ func (c *Connection) Run() { c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0]) if c.IsSender { - // encrypt the file - log.Debug("encrypting...") - fdata, err := ioutil.ReadFile(c.File.Name) + if c.DontEncrypt { + // don't encrypt + CopyFile(c.File.Name, c.File.Name+".enc") + } else { + // encrypt + log.Debug("encrypting...") + if err := EncryptFile(c.File.Name, c.File.Name+".enc", c.Code); err != nil { + return err + } + } + // get file hash + var err error + c.File.Hash, err = HashFile(c.File.Name) if err != nil { - log.Fatal(err) - return + return err } - c.File.bytes, c.File.Salt, c.File.IV = Encrypt(fdata, c.Code, c.DontEncrypt) - log.Debug("...finished encryption") - c.File.Hash = HashBytes(fdata) - c.File.Size = len(c.File.bytes) - if c.Debug { - ioutil.WriteFile(c.File.Name+".encrypted", c.File.bytes, 0644) + // get file size + c.File.Size, err = FileSize(c.File.Name + ".enc") + if err != nil { + return err } - fmt.Printf("Sending %d byte file named '%s'\n", c.File.Size, c.File.Name) - fmt.Printf("Code is: %s\n", c.Code) } - c.runClient() + return c.runClient() } // runClient spawns threads for parallel uplink/downlink via TCP -func (c *Connection) runClient() { +func (c *Connection) runClient() error { logger := log.WithFields(log.Fields{ "code": c.Code, "sender?": c.IsSender, @@ -177,12 +178,12 @@ func (c *Connection) runClient() { sendMessage("r."+c.HashedCode+".0.0.0", connection) } if c.IsSender { // this is a sender - if id == 0 { - fmt.Printf("\nSending (<-%s)..\n", connection.RemoteAddr().String()) - } logger.Debug("waiting for ok from relay") message = receiveMessage(connection) logger.Debug("got ok from relay") + if id == 0 { + fmt.Printf("\nSending (->%s)..\n", message) + } // wait for pipe to be made time.Sleep(100 * time.Millisecond) // Write data from file @@ -192,7 +193,7 @@ func (c *Connection) runClient() { logger.Debug("waiting for meta data from sender") message = receiveMessage(connection) m := strings.Split(message, "-") - encryptedData, salt, iv := m[0], m[1], m[2] + encryptedData, salt, iv, sendersAddress := m[0], m[1], m[2], m[3] encryptedBytes, err := hex.DecodeString(encryptedData) if err != nil { log.Error(err) @@ -233,6 +234,7 @@ func (c *Connection) runClient() { } else { sendMessage("ok", connection) logger.Debug("receive file") + fmt.Printf("\n\nReceiving (<-%s)..\n", sendersAddress) c.receiveFile(id, connection) } } @@ -242,36 +244,32 @@ func (c *Connection) runClient() { if !c.IsSender { if !gotOK { - return + return errors.New("Transfer interrupted") } c.catFile(c.File.Name) - encrypted, err := ioutil.ReadFile(c.File.Name + ".encrypted") - if err != nil { - log.Error(err) - return - } - fmt.Println("\n\ndecrypting...") log.Debugf("Code: [%s]", c.Code) - log.Debugf("Salt: [%s]", c.File.Salt) - log.Debugf("IV: [%s]", c.File.IV) - decrypted, err := Decrypt(encrypted, c.Code, c.File.Salt, c.File.IV, c.DontEncrypt) - if err != nil { - log.Error(err) - return - } - log.Debugf("writing %d bytes to %s", len(decrypted), c.File.Name) - err = ioutil.WriteFile(c.File.Name, decrypted, 0644) - if err != nil { - log.Error(err) + if c.DontEncrypt { + if err := CopyFile(c.File.Name+".enc", c.File.Name); err != nil { + return err + } + } else { + if err := DecryptFile(c.File.Name+".enc", c.File.Name, c.Code); err != nil { + return errors.Wrap(err, "Problem decrypting file") + } } if !c.Debug { - os.Remove(c.File.Name + ".encrypted") + os.Remove(c.File.Name + ".enc") } - log.Debugf("\n\n\ndownloaded hash: [%s]", HashBytes(decrypted)) + + fileHash, err := HashFile(c.File.Name) + if err != nil { + log.Error(err) + } + log.Debugf("\n\n\ndownloaded hash: [%s]", fileHash) log.Debugf("\n\n\nrelayed hash: [%s]", c.File.Hash) - if c.File.Hash != HashBytes(decrypted) { - fmt.Printf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name) + if c.File.Hash != fileHash { + return fmt.Errorf("\nUh oh! %s is corrupted! Sorry, try again.\n", c.File.Name) } else { fmt.Printf("\nReceived file written to %s", c.File.Name) } @@ -279,6 +277,7 @@ func (c *Connection) runClient() { fmt.Println("File sent.") // TODO: Add confirmation } + return nil } func fileAlreadyExists(s []string, f string) bool { @@ -293,7 +292,7 @@ func fileAlreadyExists(s []string, f string) bool { func (c *Connection) catFile(fname string) { // cat the file os.Remove(fname) - finished, err := os.Create(fname + ".encrypted") + finished, err := os.Create(fname + ".enc") defer finished.Close() if err != nil { log.Fatal(err) diff --git a/main.go b/main.go index 178d1b2..fbc719e 100644 --- a/main.go +++ b/main.go @@ -32,7 +32,7 @@ func main() { /~____ =ΓΈ= / (______)__m_m) -croc version `+version+` +croc version ` + version + ` `) flags := new(Flags) flag.BoolVar(&flags.Relay, "relay", false, "run as relay") @@ -50,7 +50,10 @@ croc version `+version+` r.Run() } else { c := NewConnection(flags) - c.Run() + err := c.Run() + if err != nil { + fmt.Printf("Error! Please submit the following error to https://github.com/schollz/croc/issues:\n\n'%s'\n\n", err.Error()) + } } } diff --git a/relay.go b/relay.go index f558fa6..ef35c20 100644 --- a/relay.go +++ b/relay.go @@ -114,9 +114,11 @@ func (r *Relay) clientCommuncation(id int, connection net.Conn) { r.connections.sender[key] = connection r.connections.Unlock() // wait for receiver + receiversAddress := "" for { r.connections.RLock() if _, ok := r.connections.reciever[key]; ok { + receiversAddress = r.connections.reciever[key].RemoteAddr().String() logger.Debug("got reciever") r.connections.RUnlock() break @@ -125,7 +127,7 @@ func (r *Relay) clientCommuncation(id int, connection net.Conn) { time.Sleep(100 * time.Millisecond) } logger.Debug("telling sender ok") - sendMessage("ok", connection) + sendMessage(receiversAddress, connection) logger.Debug("preparing pipe") r.connections.Lock() con1 := r.connections.sender[key] @@ -142,19 +144,23 @@ func (r *Relay) clientCommuncation(id int, connection net.Conn) { logger.Debug("deleted sender and receiver") } else { // wait for sender's metadata + sendersAddress := "" for { r.connections.RLock() if _, ok := r.connections.metadata[key]; ok { - logger.Debug("got sender meta data") - r.connections.RUnlock() - break + if _, ok2 := r.connections.sender[key]; ok2 { + sendersAddress = r.connections.sender[key].RemoteAddr().String() + logger.Debug("got sender meta data") + r.connections.RUnlock() + break + } } r.connections.RUnlock() time.Sleep(100 * time.Millisecond) } // send meta data r.connections.RLock() - sendMessage(r.connections.metadata[key], connection) + sendMessage(r.connections.metadata[key]+"-"+sendersAddress, connection) r.connections.RUnlock() // check for receiver's consent consent := receiveMessage(connection) diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..f40c4a1 --- /dev/null +++ b/utils.go @@ -0,0 +1,101 @@ +package main + +import ( + "crypto/md5" + "fmt" + "io" + "os" +) + +// CopyFile copies a file from src to dst. If src and dst files exist, and are +// the same, then return success. Otherise, attempt to create a hard link +// between the two files. If that fail, copy the file contents from src to dst. +func CopyFile(src, dst string) (err error) { + sfi, err := os.Stat(src) + if err != nil { + return + } + if !sfi.Mode().IsRegular() { + // cannot copy non-regular files (e.g., directories, + // symlinks, devices, etc.) + return fmt.Errorf("CopyFile: non-regular source file %s (%q)", sfi.Name(), sfi.Mode().String()) + } + dfi, err := os.Stat(dst) + if err != nil { + if !os.IsNotExist(err) { + return + } + } else { + if !(dfi.Mode().IsRegular()) { + return fmt.Errorf("CopyFile: non-regular destination file %s (%q)", dfi.Name(), dfi.Mode().String()) + } + if os.SameFile(sfi, dfi) { + return + } + } + if err = os.Link(src, dst); err == nil { + return + } + err = copyFileContents(src, dst) + return +} + +// copyFileContents copies the contents of the file named src to the file named +// by dst. The file will be created if it does not already exist. If the +// destination file exists, all it's contents will be replaced by the contents +// of the source file. +func copyFileContents(src, dst string) (err error) { + in, err := os.Open(src) + if err != nil { + return + } + defer in.Close() + out, err := os.Create(dst) + if err != nil { + return + } + defer func() { + cerr := out.Close() + if err == nil { + err = cerr + } + }() + if _, err = io.Copy(out, in); err != nil { + return + } + err = out.Sync() + return +} + +// HashFile does a md5 hash on the file +// from https://golang.org/pkg/crypto/md5/#example_New_file +func HashFile(filename string) (hash string, err error) { + f, err := os.Open(filename) + if err != nil { + return + } + defer f.Close() + + h := md5.New() + if _, err = io.Copy(h, f); err != nil { + return + } + hash = fmt.Sprintf("%x", h.Sum(nil)) + return +} + +// FileSize returns the size of a file +func FileSize(filename string) (int, error) { + f, err := os.Open(filename) + if err != nil { + return -1, err + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return -1, err + } + size := int(fi.Size()) + return size, nil +}