authelia/internal/commands/util.go

362 lines
8.9 KiB
Go
Raw Normal View History

package commands
import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"syscall"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"golang.org/x/term"
"github.com/authelia/authelia/v4/internal/configuration"
"github.com/authelia/authelia/v4/internal/utils"
)
// recoverErr turns a value obtained from recover() into an error.
//
// A nil value yields a nil error. A string is embedded in the message and an
// error is wrapped with %w so callers can still inspect it with errors.Is or
// errors.As. Any other value is rendered with %v.
func recoverErr(i any) error {
	if i == nil {
		return nil
	}

	if msg, ok := i.(string); ok {
		return fmt.Errorf("recovered panic: %s", msg)
	}

	if cause, ok := i.(error); ok {
		return fmt.Errorf("recovered panic: %w", cause)
	}

	return fmt.Errorf("recovered panic with unknown type: %v", i)
}
// flagsGetUserIdentifiersGenerateOptions reads the users, services, and sectors
// string slice flags from the given flag set, in that order. The first lookup
// that fails aborts the read, and all three slices are returned as nil
// alongside the error.
func flagsGetUserIdentifiersGenerateOptions(flags *pflag.FlagSet) (users, services, sectors []string, err error) {
	targets := []struct {
		name string
		dst  *[]string
	}{
		{cmdFlagNameUsers, &users},
		{cmdFlagNameServices, &services},
		{cmdFlagNameSectors, &sectors},
	}

	for _, target := range targets {
		if *target.dst, err = flags.GetStringSlice(target.name); err != nil {
			return nil, nil, nil, err
		}
	}

	return users, services, sectors, nil
}
// flagsGetTOTPExportOptions reads and validates the TOTP export format and the
// output directory flags.
//
// For the csv and uri formats the directory value is returned as given and
// nothing is done on disk. For the png format an 8 character alphanumeric
// directory name is generated when none was supplied. The directory must not
// already exist and is created with mode 0700. Any other format is rejected.
func flagsGetTOTPExportOptions(flags *pflag.FlagSet) (format, dir string, err error) {
	if format, err = flags.GetString(cmdFlagNameFormat); err != nil {
		return "", "", err
	}

	if dir, err = flags.GetString("dir"); err != nil {
		return "", "", err
	}

	switch format {
	case storageTOTPExportFormatCSV, storageTOTPExportFormatURI:
		// No directory preparation is performed for these formats.
	case storageTOTPExportFormatPNG:
		if dir == "" {
			dir = utils.RandomString(8, utils.CharSetAlphaNumeric, false)
		}

		// Only a confirmed "does not exist" result lets us continue. Any other stat
		// failure (e.g. permission denied) is reported as-is rather than being
		// misreported as the directory already existing.
		switch _, err = os.Stat(dir); {
		case err == nil:
			return "", "", errors.New("output directory must not exist")
		case !errors.Is(err, os.ErrNotExist):
			return "", "", fmt.Errorf("failed to check output directory '%s': %w", dir, err)
		}

		if err = os.MkdirAll(dir, 0700); err != nil {
			return "", "", err
		}
	default:
		return "", "", errors.New("format must be csv, uri, or png")
	}

	return format, dir, nil
}
// flagsGetRandomCharacters generates a random string whose length and alphabet
// come from the flags named by flagNameLength, flagNameCharSet, and
// flagNameCharacters.
//
// The length flag must be at least 1. The charset and characters flags are
// mutually exclusive:
//   - When the characters flag is set, its value is used verbatim as the set of
//     characters to choose from, and it must not be empty.
//   - Otherwise the charset flag (explicitly set or its default) selects one of
//     the named character sets: ascii, alphanumeric[-lower|-upper],
//     alphabetic[-lower|-upper], numeric, numeric-hex, or rfc3986[-lower|-upper].
//
//nolint:gocyclo
func flagsGetRandomCharacters(flags *pflag.FlagSet, flagNameLength, flagNameCharSet, flagNameCharacters string) (r string, err error) {
	var (
		n       int
		charset string
	)

	if n, err = flags.GetInt(flagNameLength); err != nil {
		return "", err
	}

	if n < 1 {
		return "", fmt.Errorf("flag --%s with value '%d' is invalid: must be at least 1", flagNameLength, n)
	}

	useCharSet, useCharacters := flags.Changed(flagNameCharSet), flags.Changed(flagNameCharacters)

	if useCharSet && useCharacters {
		return "", fmt.Errorf("flag --%s and flag --%s are mutually exclusive, only one may be used", flagNameCharSet, flagNameCharacters)
	}

	if useCharacters {
		if charset, err = flags.GetString(flagNameCharacters); err != nil {
			return "", err
		}

		// An explicitly empty value (e.g. --characters "") leaves nothing to choose
		// from, so a string of length n >= 1 cannot be produced. Reject it here
		// rather than handing an empty alphabet to utils.RandomString.
		if charset == "" {
			return "", fmt.Errorf("flag --%s with value '' is invalid: must contain at least 1 character", flagNameCharacters)
		}
	} else {
		// Either the charset flag was set explicitly or neither flag was set, in
		// which case the charset flag's default value applies.
		var c string

		if c, err = flags.GetString(flagNameCharSet); err != nil {
			return "", err
		}

		switch c {
		case "ascii":
			charset = utils.CharSetASCII
		case "alphanumeric":
			charset = utils.CharSetAlphaNumeric
		case "alphanumeric-lower":
			charset = utils.CharSetAlphabeticLower + utils.CharSetNumeric
		case "alphanumeric-upper":
			charset = utils.CharSetAlphabeticUpper + utils.CharSetNumeric
		case "alphabetic":
			charset = utils.CharSetAlphabetic
		case "alphabetic-lower":
			charset = utils.CharSetAlphabeticLower
		case "alphabetic-upper":
			charset = utils.CharSetAlphabeticUpper
		case "numeric-hex":
			charset = utils.CharSetNumericHex
		case "numeric":
			charset = utils.CharSetNumeric
		case "rfc3986":
			charset = utils.CharSetRFC3986Unreserved
		case "rfc3986-lower":
			charset = utils.CharSetAlphabeticLower + utils.CharSetNumeric + utils.CharSetSymbolicRFC3986Unreserved
		case "rfc3986-upper":
			charset = utils.CharSetAlphabeticUpper + utils.CharSetNumeric + utils.CharSetSymbolicRFC3986Unreserved
		default:
			return "", fmt.Errorf("flag '--%s' with value '%s' is invalid, must be one of 'ascii', 'alphanumeric', 'alphabetic', 'numeric', 'numeric-hex', or 'rfc3986'", flagNameCharSet, c)
		}
	}

	return utils.RandomString(n, charset, true), nil
}
// termReadConfirmation reports whether the user confirmed an action.
//
// If the boolean flag identified by name is true, the action counts as
// confirmed and the terminal is not used. Otherwise prompt is shown on the
// interactive terminal, and the action is confirmed only when the entered line
// is exactly equal to confirmation. The terminal state is restored before
// returning.
func termReadConfirmation(flags *pflag.FlagSet, name, prompt, confirmation string) (confirmed bool, err error) {
	// Any error from the flag lookup is ignored; only a true value skips the prompt.
	if skip, _ := flags.GetBool(name); skip {
		return true, nil
	}

	terminal, fd, state, err := getTerminal(prompt)
	if err != nil {
		return false, err
	}

	defer func() {
		_ = term.Restore(fd, state)
	}()

	line, err := terminal.ReadLine()
	if err != nil {
		return false, fmt.Errorf("failed to read from the terminal: %w", err)
	}

	return line == confirmation, nil
}
// getTerminal prepares stdin for interactive input. It returns a term.Terminal
// that reads from os.Stdin, writes to os.Stdout, and shows the given prompt.
//
// ErrStdinIsNotTerminal is returned when stdin is not a terminal. On success
// stdin has been switched to raw mode, and the caller must pass the returned
// fd and state to term.Restore once input has been read. Every error return
// carries fd -1 and a nil state. If raw mode was already entered when a later
// step fails, the previous state is restored before returning, because the
// caller cannot restore a state it never received.
func getTerminal(prompt string) (terminal *term.Terminal, fd int, state *term.State, err error) {
	fd = int(syscall.Stdin) //nolint:unconvert,nolintlint

	if !term.IsTerminal(fd) {
		return nil, -1, nil, ErrStdinIsNotTerminal
	}

	var width, height int

	// The size is read from stdout, which is where the terminal writes its output.
	if width, height, err = term.GetSize(int(syscall.Stdout)); err != nil { //nolint:unconvert,nolintlint
		return nil, -1, nil, fmt.Errorf("failed to get terminal size: %w", err)
	}

	if state, err = term.MakeRaw(fd); err != nil {
		return nil, -1, nil, fmt.Errorf("failed to get terminal state: %w", err)
	}

	rw := struct {
		io.Reader
		io.Writer
	}{
		os.Stdin,
		os.Stdout,
	}

	terminal = term.NewTerminal(rw, prompt)

	if err = terminal.SetSize(width, height); err != nil {
		// Undo MakeRaw: the state is not handed back on error, so nobody else can.
		_ = term.Restore(fd, state)

		return nil, -1, nil, fmt.Errorf("failed to set terminal size: %w", err)
	}

	return terminal, fd, state, nil
}
// termReadPasswordWithPrompt reads a password from the interactive terminal
// using term.Terminal.ReadPassword with the given prompt.
//
// When stdin is not a terminal, the error points the user at the named flag
// instead:
//   - a one-character flag name is shown with a single dash;
//   - a longer name is shown with two dashes;
//   - an empty flag name returns the underlying error unchanged.
//
// Other terminal setup errors are returned as-is. The terminal state is
// restored before returning.
func termReadPasswordWithPrompt(prompt, flag string) (password string, err error) {
	terminal, fd, state, err := getTerminal("")
	if err != nil {
		switch {
		case !errors.Is(err, ErrStdinIsNotTerminal), flag == "":
			return "", err
		case len(flag) == 1:
			return "", fmt.Errorf("you must either use an interactive terminal or use the -%s flag", flag)
		default:
			return "", fmt.Errorf("you must either use an interactive terminal or use the --%s flag", flag)
		}
	}

	defer func() {
		_ = term.Restore(fd, state)
	}()

	if password, err = terminal.ReadPassword(prompt); err != nil {
		return "", fmt.Errorf("failed to read the input from the terminal: %w", err)
	}

	return password, nil
}
// XEnvCLIResult identifies which source a value was resolved from, for example
// the result reported by loadXEnvCLIStringSliceValue.
type XEnvCLIResult int

// Sources a value may be resolved from.
const (
	// XEnvCLIResultCLIExplicit means the flag was explicitly changed on the command line.
	XEnvCLIResultCLIExplicit XEnvCLIResult = iota

	// XEnvCLIResultCLIImplicit means the flag was not changed and no usable
	// environment value was found, so the unchanged flag value was used.
	XEnvCLIResultCLIImplicit

	// XEnvCLIResultEnvironment means the value came from a non-empty environment variable.
	XEnvCLIResultEnvironment
)
// loadXEnvCLIConfigValues resolves the configuration paths and the
// configuration file filters from the command's flags.
//
// An empty environment key is passed for both lookups, so only the flags are
// consulted. The config paths are normalized with loadXNormalizedPaths, and the
// filter names are converted with configuration.NewFileFilters. On any error
// both returned slices are nil.
func loadXEnvCLIConfigValues(cmd *cobra.Command) (configs []string, filters []configuration.FileFilter, err error) {
	var rawConfigs, filterNames []string

	if rawConfigs, _, err = loadXEnvCLIStringSliceValue(cmd, "", cmdFlagNameConfig); err != nil {
		return nil, nil, err
	}

	if configs, err = loadXNormalizedPaths(rawConfigs); err != nil {
		return nil, nil, err
	}

	if filterNames, _, err = loadXEnvCLIStringSliceValue(cmd, "", cmdFlagNameConfigExpFilters); err != nil {
		return nil, nil, err
	}

	if filters, err = configuration.NewFileFilters(filterNames); err != nil {
		return nil, nil, fmt.Errorf("error occurred loading configuration: flag '--%s' is invalid: %w", cmdFlagNameConfigExpFilters, err)
	}

	return configs, filters, nil
}
// loadXNormalizedPaths converts each configuration path to an absolute path
// and validates the resulting set.
//
// Existing directories, existing files, and paths that do not exist yet are all
// returned as absolute paths in input order. A path that does not exist is
// treated as a file. Any other stat failure is returned as an error. It is also
// an error for a file path to sit directly inside one of the given directory
// paths, which is not supported.
func loadXNormalizedPaths(paths []string) ([]string, error) {
	var (
		configs, files []string
		dirs           = map[string]struct{}{}
	)

	for _, raw := range paths {
		path, err := filepath.Abs(raw)
		if err != nil {
			// Report the caller's input: filepath.Abs does not return a usable path on error.
			return nil, fmt.Errorf("failed to determine absolute path for '%s': %w", raw, err)
		}

		stat, err := os.Stat(path)

		switch {
		case err == nil && stat.IsDir():
			dirs[path] = struct{}{}
		case err == nil, os.IsNotExist(err):
			// Existing files and not-yet-existing paths are both treated as files.
			files = append(files, path)
		default:
			return nil, fmt.Errorf("error occurred stating file at path '%s': %w", path, err)
		}

		configs = append(configs, path)
	}

	// Every entry in files is already absolute and clean, so its parent directory
	// can be compared directly against the directory set.
	for _, file := range files {
		dir := filepath.Dir(file)

		if _, ok := dirs[dir]; ok {
			return nil, fmt.Errorf("failed to load config directory '%s': the config file '%s' is in that directory which is not supported", dir, file)
		}
	}

	return configs, nil
}
// loadXEnvCLIStringSliceValue resolves a string slice value from three sources,
// in order of precedence:
//  1. the flag named flagName, if it was explicitly changed;
//  2. the environment variable envKey, if envKey is not empty and the variable
//     holds a non-empty value, split on commas;
//  3. the flag's unchanged value.
//
// The returned XEnvCLIResult reports which source supplied the value.
func loadXEnvCLIStringSliceValue(cmd *cobra.Command, envKey, flagName string) (value []string, result XEnvCLIResult, err error) {
	flags := cmd.Flags()

	if flags.Changed(flagName) {
		value, err = flags.GetStringSlice(flagName)

		return value, XEnvCLIResultCLIExplicit, err
	}

	if envKey != "" {
		if env, ok := os.LookupEnv(envKey); ok && env != "" {
			return strings.Split(env, ","), XEnvCLIResultEnvironment, nil
		}
	}

	value, err = flags.GetStringSlice(flagName)

	return value, XEnvCLIResultCLIImplicit, err
}
// newHelpTopic builds a command with no Run handler whose only purpose is to
// show a help topic.
//
// Its help output has two parts:
//   - the parent command's help, via cobra's Help(); any error from it is ignored;
//   - a "Help Topic" heading and the body text, written to os.Stdout via fmt.
//
// NOTE(review): c.Parent() is not nil-checked, so the topic is assumed to be
// attached to a parent command.
func newHelpTopic(topic, short, body string) (cmd *cobra.Command) {
	helpFn := func(c *cobra.Command, _ []string) {
		_ = c.Parent().Help()

		fmt.Println()
		fmt.Printf("Help Topic: %s\n\n", topic)
		fmt.Print(body)
		fmt.Print("\n\n")
	}

	cmd = &cobra.Command{
		Use:   topic,
		Short: short,
	}

	cmd.SetHelpFunc(helpFn)

	return cmd
}