fix: make sure that only pake messages are unencrypted

This commit is contained in:
Zack Scholl 2021-04-16 17:15:51 -07:00
parent babfd5f35f
commit c02b4f1256
1 changed files with 20 additions and 11 deletions

View File

@ -104,11 +104,11 @@ type Client struct {
longestFilename int
firstSend bool
mutex *sync.Mutex
fread *os.File
numfinished int
quit chan bool
finishedNum int
mutex *sync.Mutex
fread *os.File
numfinished int
quit chan bool
finishedNum int
numberOfTransferedFiles int
}
@ -678,7 +678,7 @@ func (c *Client) Receive() (err error) {
err = c.transfer(TransferOptions{})
if err == nil {
if c.numberOfTransferedFiles == 0 {
fmt.Fprintf(os.Stderr,"\rNo files need transfering.")
fmt.Fprintf(os.Stderr, "\rNo files need transfering.")
}
}
return
@ -931,6 +931,15 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) {
return
}
// only "pake" messages should be unencrypted
// if a non-"pake" message is received unencrypted something
// is weird
if m.Type != "pake" && c.Key == nil {
err = fmt.Errorf("unencrypted communication rejected")
done = true
return
}
switch m.Type {
case "finished":
err = message.Send(c.conn[0], c.Key, message.Message{
@ -1209,7 +1218,7 @@ func (c *Client) updateIfRecipientHasFileInfo() (err error) {
err = c.createEmptyFileAndFinish(fileInfo, i)
if err != nil {
return
} else{
} else {
c.numberOfTransferedFiles++
}
continue
@ -1217,12 +1226,12 @@ func (c *Client) updateIfRecipientHasFileInfo() (err error) {
log.Debugf("%s %+x %+x %+v", fileInfo.Name, fileHash, fileInfo.Hash, errHash)
if !bytes.Equal(fileHash, fileInfo.Hash) {
log.Debugf("hashes are not equal %x != %x", fileHash, fileInfo.Hash)
if errHash== nil && !c.Options.Overwrite {
ans := utils.GetInput(fmt.Sprintf("\rOverwrite '%s'? (y/n) ",path.Join(fileInfo.FolderRemote, fileInfo.Name)))
if errHash == nil && !c.Options.Overwrite {
ans := utils.GetInput(fmt.Sprintf("\rOverwrite '%s'? (y/n) ", path.Join(fileInfo.FolderRemote, fileInfo.Name)))
if strings.TrimSpace(strings.ToLower(ans)) != "y" {
fmt.Fprintf(os.Stderr,"skipping '%s'",path.Join(fileInfo.FolderRemote, fileInfo.Name))
fmt.Fprintf(os.Stderr, "skipping '%s'", path.Join(fileInfo.FolderRemote, fileInfo.Name))
continue
}
}
}
} else {
log.Debugf("hashes are equal %x == %x", fileHash, fileInfo.Hash)