add configuration file

This commit is contained in:
Zack Scholl 2018-10-21 08:21:58 -07:00
parent 1ae117166a
commit 81bc06eabb
3 changed files with 154 additions and 10 deletions

3
go.mod
View File

@ -1,7 +1,7 @@
module github.com/schollz/croc
require (
github.com/BurntSushi/toml v0.3.1 // indirect
github.com/BurntSushi/toml v0.3.1
github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575
github.com/dustin/go-humanize v1.0.0
github.com/fatih/color v1.7.0 // indirect
@ -9,6 +9,7 @@ require (
github.com/gorilla/websocket v1.4.0
github.com/mattn/go-colorable v0.0.9 // indirect
github.com/mattn/go-isatty v0.0.4 // indirect
github.com/mitchellh/go-homedir v1.0.0
github.com/pkg/errors v0.8.0
github.com/schollz/mnemonicode v1.0.1
github.com/schollz/pake v1.1.0

View File

@ -60,6 +60,16 @@ func Run() {
return relay(c)
},
},
{
Name: "config",
Usage: "generates a config file",
Description: "the croc config can be used to set static parameters",
Flags: []cli.Flag{},
HelpName: "croc config",
Action: func(c *cli.Context) error {
return saveDefaultConfig(c)
},
},
}
app.Flags = []cli.Flag{
cli.StringFlag{Name: "addr", Value: "croc4.schollz.com", Usage: "address of the public relay"},
@ -115,6 +125,10 @@ func Run() {
}
}
func saveDefaultConfig(c *cli.Context) error {
return croc.SaveDefaultConfig()
}
func send(c *cli.Context) error {
stat, _ := os.Stdin.Stat()
var fname string
@ -147,11 +161,12 @@ func send(c *cli.Context) error {
cr.UseCompression = !c.Bool("no-compress")
cr.UseEncryption = !c.Bool("no-encrypt")
if c.String("code") != "" {
codePhrase = c.String("code")
cr.Codephrase = c.String("code")
}
if len(codePhrase) == 0 {
cr.LoadConfig()
if len(cr.Codephrase) == 0 {
// generate code phrase
codePhrase = utils.GetRandomName()
cr.Codephrase = utils.GetRandomName()
}
// print the text
@ -176,10 +191,10 @@ func send(c *cli.Context) error {
humanize.Bytes(uint64(fsize)),
fileOrFolder,
filename,
codePhrase,
codePhrase,
cr.Codephrase,
cr.Codephrase,
)
return cr.Send(fname, codePhrase)
return cr.Send(fname, cr.Codephrase)
}
func receive(c *cli.Context) error {

View File

@ -2,9 +2,16 @@ package croc
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
"time"
"github.com/BurntSushi/toml"
homedir "github.com/mitchellh/go-homedir"
"github.com/schollz/croc/src/utils"
)
type Config struct {
@ -33,8 +40,7 @@ type Config struct {
Codephrase string
}
// DefaultConfig returns the default config
func DefaultConfig() string {
func defaultConfig() Config {
c := Config{}
cr := Init(false)
c.RelayWebsocketPort = cr.RelayWebsocketPort
@ -53,7 +59,129 @@ func DefaultConfig() string {
c.ForceTCP = false
c.ForceWebsockets = false
c.Codephrase = ""
return c
}
func SaveDefaultConfig() error {
homedir, err := homedir.Dir()
if err != nil {
return err
}
os.MkdirAll(path.Join(homedir, ".config", "croc"), 0644)
c := defaultConfig()
buf := new(bytes.Buffer)
toml.NewEncoder(buf).Encode(c)
return buf.String()
confTOML := buf.String()
err = ioutil.WriteFile(path.Join(homedir, ".config", "croc", "config.toml"), []byte(confTOML), 0644)
if err == nil {
fmt.Printf("Default config file written at '%s'", filepath.Clean(path.Join(homedir, ".config", "croc", "config.toml")))
}
return err
}
// LoadConfig will override parameters
func (cr *Croc) LoadConfig() (err error) {
homedir, err := homedir.Dir()
if err != nil {
return err
}
pathToConfig := path.Join(homedir, ".config", "croc", "config.toml")
if !utils.Exists(pathToConfig) {
// ignore if doesn't exist
return nil
}
var c Config
_, err = toml.DecodeFile(pathToConfig, &c)
if err != nil {
return
}
cDefault := defaultConfig()
// only load if things are different than defaults
// just in case the CLI parameters are used
if c.RelayWebsocketPort != cDefault.RelayWebsocketPort && cr.RelayWebsocketPort == cDefault.RelayWebsocketPort {
cr.RelayWebsocketPort = c.RelayWebsocketPort
fmt.Printf("loaded RelayWebsocketPort from config\n")
}
if !slicesEqual(c.RelayTCPPorts, cDefault.RelayTCPPorts) && slicesEqual(cr.RelayTCPPorts, cDefault.RelayTCPPorts) {
cr.RelayTCPPorts = c.RelayTCPPorts
fmt.Printf("loaded RelayTCPPorts from config\n")
}
if c.CurveType != cDefault.CurveType && cr.CurveType == cDefault.CurveType {
cr.CurveType = c.CurveType
fmt.Printf("loaded CurveType from config\n")
}
if c.PublicServerIP != cDefault.PublicServerIP && cr.Address == cDefault.PublicServerIP {
cr.Address = c.PublicServerIP
fmt.Printf("loaded Address from config\n")
}
if !slicesEqual(c.AddressTCPPorts, cDefault.AddressTCPPorts) {
cr.AddressTCPPorts = c.AddressTCPPorts
fmt.Printf("loaded AddressTCPPorts from config\n")
}
if c.AddressWebsocketPort != cDefault.AddressWebsocketPort && cr.AddressWebsocketPort == cDefault.AddressWebsocketPort {
cr.AddressWebsocketPort = c.AddressWebsocketPort
fmt.Printf("loaded AddressWebsocketPort from config\n")
}
if c.Timeout != cDefault.Timeout && cr.Timeout == cDefault.Timeout {
cr.Timeout = c.Timeout
fmt.Printf("loaded Timeout from config\n")
}
if c.LocalOnly != cDefault.LocalOnly && cr.LocalOnly == cDefault.LocalOnly {
cr.LocalOnly = c.LocalOnly
fmt.Printf("loaded LocalOnly from config\n")
}
if c.NoLocal != cDefault.NoLocal && cr.NoLocal == cDefault.NoLocal {
cr.NoLocal = c.NoLocal
fmt.Printf("loaded NoLocal from config\n")
}
if c.UseEncryption != cDefault.UseEncryption && cr.UseEncryption == cDefault.UseEncryption {
cr.UseEncryption = c.UseEncryption
fmt.Printf("loaded UseEncryption from config\n")
}
if c.UseCompression != cDefault.UseCompression && cr.UseCompression == cDefault.UseCompression {
cr.UseCompression = c.UseCompression
fmt.Printf("loaded UseCompression from config\n")
}
if c.AllowLocalDiscovery != cDefault.AllowLocalDiscovery && cr.AllowLocalDiscovery == cDefault.AllowLocalDiscovery {
cr.AllowLocalDiscovery = c.AllowLocalDiscovery
fmt.Printf("loaded AllowLocalDiscovery from config\n")
}
if c.NoRecipientPrompt != cDefault.NoRecipientPrompt && cr.NoRecipientPrompt == cDefault.NoRecipientPrompt {
cr.NoRecipientPrompt = c.NoRecipientPrompt
fmt.Printf("loaded NoRecipientPrompt from config\n")
}
if c.ForceWebsockets {
cr.ForceSend = 1
}
if c.ForceTCP {
cr.ForceSend = 2
}
if c.Codephrase != cDefault.Codephrase && cr.Codephrase == cDefault.Codephrase {
cr.Codephrase = c.Codephrase
fmt.Printf("loaded Codephrase from config\n")
}
return
}
// slicesEqual checcks if two slices are equal
// from https://stackoverflow.com/a/15312097
func slicesEqual(a, b []string) bool {
// If one is nil, the other must also be nil.
if (a == nil) != (b == nil) {
return false
}
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}