diff --git a/src/cli/cli.go b/src/cli/cli.go index f4f544b..49ca294 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -135,28 +135,6 @@ func Run() (err error) { return app.Run(os.Args) } -func getConfigDir() (homedir string, err error) { - homedir, err = os.UserHomeDir() - if err != nil { - log.Error(err) - return - } - - if envHomedir, isSet := os.LookupEnv("CROC_CONFIG_DIR"); isSet { - homedir = envHomedir - } else if xdgConfigHome, isSet := os.LookupEnv("XDG_CONFIG_HOME"); isSet { - homedir = path.Join(xdgConfigHome, "croc") - } else { - homedir = path.Join(homedir, ".config", "croc") - } - - if _, err = os.Stat(homedir); os.IsNotExist(err) { - log.Debugf("creating home directory %s", homedir) - err = os.MkdirAll(homedir, 0700) - } - return -} - func setDebugLevel(c *cli.Context) { if c.Bool("debug") { log.SetLevel("debug") @@ -167,7 +145,7 @@ func setDebugLevel(c *cli.Context) { } func getConfigFile() string { - configFile, err := getConfigDir() + configFile, err := utils.GetConfigDir() if err != nil { log.Error(err) return "" @@ -446,7 +424,7 @@ func receive(c *cli.Context) (err error) { // load options here setDebugLevel(c) - configFile, err := getConfigDir() + configFile, err := utils.GetConfigDir() if err != nil { log.Error(err) return diff --git a/src/models/constants.go b/src/models/constants.go index ca2543c..a6a0066 100644 --- a/src/models/constants.go +++ b/src/models/constants.go @@ -5,6 +5,9 @@ import ( "fmt" "net" "os" + "path" + + "github.com/schollz/croc/v9/src/utils" ) // TCP_BUFFER_SIZE is the maximum packet size @@ -41,12 +44,39 @@ var publicDns = []string{ "[2620:119:53::53]", // Cisco OpenDNS } +func getConfigFile() (fname string, err error) { + configFile, err := utils.GetConfigDir() + if err != nil { + return + } + fname = path.Join(configFile, "internal-dns") + return +} + func init() { + doRemember := false for _, flag := range os.Args { if flag == "--internal-dns" { INTERNAL_DNS = true break } + if flag == "--remember" { + doRemember = true + } + } + if doRemember { + // save in config file + fname, err := getConfigFile() + if err == nil { + f, _ := os.Create(fname) + f.Close() + } + } + if !INTERNAL_DNS { + fname, err := getConfigFile() + if err == nil { + INTERNAL_DNS = utils.Exists(fname) + } } var err error DEFAULT_RELAY, err = lookup(DEFAULT_RELAY) diff --git a/src/utils/utils.go b/src/utils/utils.go index 42a1520..dad783f 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "os" + "path" "strings" "time" @@ -23,6 +24,27 @@ import ( "github.com/schollz/mnemonicode" ) +// Get or create home directory +func GetConfigDir() (homedir string, err error) { + homedir, err = os.UserHomeDir() + if err != nil { + return + } + + if envHomedir, isSet := os.LookupEnv("CROC_CONFIG_DIR"); isSet { + homedir = envHomedir + } else if xdgConfigHome, isSet := os.LookupEnv("XDG_CONFIG_HOME"); isSet { + homedir = path.Join(xdgConfigHome, "croc") + } else { + homedir = path.Join(homedir, ".config", "croc") + } + + if _, err = os.Stat(homedir); os.IsNotExist(err) { + err = os.MkdirAll(homedir, 0700) + } + return +} + // Exists reports whether the named file or directory exists. func Exists(name string) bool { if _, err := os.Stat(name); err != nil {