package model

import (
	"database/sql"
	"database/sql/driver"
	"encoding/base64"
	"fmt"
	"net"

	"github.com/google/uuid"

	"github.com/authelia/authelia/v4/internal/utils"
)

// NullUUID is a nullable uuid.UUID.
type NullUUID struct {
	uuid.UUID
	Valid bool
}

// Value is the NullUUID implementation of the database/sql driver.Valuer.
//
// It returns nil (SQL NULL) when Valid is false, otherwise the value produced by the embedded uuid.UUID.
func (u NullUUID) Value() (value driver.Value, err error) {
	if !u.Valid {
		return nil, nil
	}

	return u.UUID.Value()
}

// Scan is the NullUUID implementation of the sql.Scanner.
//
// A nil src resets the NullUUID to the zero UUID and marks it invalid. Any other src is delegated to
// uuid.UUID.Scan; Valid is set to true only when that scan succeeds and false when it returns an error.
func (u *NullUUID) Scan(src interface{}) (err error) {
	if src == nil {
		u.UUID, u.Valid = uuid.UUID{}, false

		return nil
	}

	// Valid must reflect the outcome of the scan; otherwise a freshly scanned non-NULL value would be
	// written back as NULL by Value.
	if err = u.UUID.Scan(src); err != nil {
		u.Valid = false

		return err
	}

	u.Valid = true

	return nil
}

// NewIP easily constructs a new IP.
func NewIP(value net.IP) (ip IP) {
	return IP{IP: value}
}

// NewNullIP easily constructs a new NullIP.
func NewNullIP(value net.IP) (ip NullIP) {
	return NullIP{IP: value}
}

// NewNullIPFromString easily constructs a new NullIP from a string.
//
// An empty string yields a NullIP with a nil IP. A string that net.ParseIP cannot parse also yields a nil IP.
func NewNullIPFromString(value string) (ip NullIP) {
	if value == "" {
		return ip
	}

	return NullIP{IP: net.ParseIP(value)}
}

// NewBase64 returns a new Base64.
func NewBase64(data []byte) Base64 {
	return Base64{data: data}
}

// IP is a type specific for storage of a net.IP in the database which can't be NULL.
type IP struct {
	IP net.IP
}

// Value is the IP implementation of the database/sql driver.Valuer.
//
// It returns the textual form of the IP, or an error when the IP is nil since the column can't be NULL.
func (ip IP) Value() (value driver.Value, err error) {
	if ip.IP == nil {
		return nil, fmt.Errorf(errFmtValueNil, ip)
	}

	return ip.IP.String(), nil
}

// Scan is the IP implementation of the sql.Scanner.
//
// It accepts string or []byte sources and parses them with net.ParseIP. A nil src or any other
// source type is an error. Note that an unparsable string leaves ip.IP nil without returning an error.
func (ip *IP) Scan(src interface{}) (err error) {
	if src == nil {
		return fmt.Errorf(errFmtScanNil, ip)
	}

	var value string

	switch v := src.(type) {
	case string:
		value = v
	case []byte:
		value = string(v)
	default:
		return fmt.Errorf(errFmtScanInvalidType, ip, src, src)
	}

	ip.IP = net.ParseIP(value)

	return nil
}

// NullIP is a type specific for storage of a net.IP in the database which can also be NULL.
type NullIP struct {
	IP net.IP
}

// Value is the NullIP implementation of the database/sql driver.Valuer.
//
// A nil IP is stored as SQL NULL; otherwise the textual form of the IP is stored.
func (ip NullIP) Value() (value driver.Value, err error) {
	if ip.IP == nil {
		return nil, nil
	}

	return ip.IP.String(), nil
}

// Scan is the NullIP implementation of the sql.Scanner.
//
// A nil src clears the IP. A string or []byte src is parsed with net.ParseIP. Any other source type is an error.
func (ip *NullIP) Scan(src interface{}) (err error) {
	if src == nil {
		ip.IP = nil

		return nil
	}

	var value string

	switch v := src.(type) {
	case string:
		value = v
	case []byte:
		value = string(v)
	default:
		return fmt.Errorf(errFmtScanInvalidType, ip, src, src)
	}

	ip.IP = net.ParseIP(value)

	return nil
}

// Base64 saves bytes to the database as a base64 encoded string.
type Base64 struct {
	data []byte
}

// String returns the Base64 data encoded as a standard base64 string.
func (b Base64) String() string {
	return base64.StdEncoding.EncodeToString(b.data)
}

// Bytes returns the raw (unencoded) bytes held by the Base64.
func (b Base64) Bytes() []byte {
	return b.data
}

// Value is the Base64 implementation of the database/sql driver.Valuer.
//
// The data is always stored as its standard base64 string encoding.
func (b Base64) Value() (value driver.Value, err error) {
	return b.String(), nil
}

// Scan is the Base64 implementation of the sql.Scanner.
//
// A string src must be valid standard base64, otherwise an error is returned. A []byte src is decoded
// as base64 when possible; if decoding fails the bytes are kept verbatim. A nil src or any other source
// type is an error.
func (b *Base64) Scan(src interface{}) (err error) {
	if src == nil {
		return fmt.Errorf(errFmtScanNil, b)
	}

	switch v := src.(type) {
	case string:
		if b.data, err = base64.StdEncoding.DecodeString(v); err != nil {
			return fmt.Errorf(errFmtScanInvalidTypeErr, b, src, src, err)
		}
	case []byte:
		if b.data, err = base64.StdEncoding.DecodeString(string(v)); err != nil {
			// Non-base64 bytes are kept as-is. The driver owns v and may reuse it after Scan returns
			// (see sql.Scanner), so keep a private copy instead of retaining v itself.
			b.data = make([]byte, len(v))
			copy(b.data, v)
		}
	default:
		return fmt.Errorf(errFmtScanInvalidType, b, src, src)
	}

	return nil
}

// StartupCheck represents a provider that has a startup check.
type StartupCheck interface {
	StartupCheck() (err error)
}

// StringSlicePipeDelimited is a string slice that is stored in the database delimited by pipes.
type StringSlicePipeDelimited []string

// Scan is the StringSlicePipeDelimited implementation of the sql.Scanner.
func (s *StringSlicePipeDelimited) Scan(value interface{}) (err error) {
	var nullStr sql.NullString

	if err = nullStr.Scan(value); err != nil {
		return err
	}

	if !nullStr.Valid {
		// A NULL column clears any previously scanned value so a reused destination doesn't keep stale data.
		*s = nil

		return nil
	}

	*s = utils.StringSplitDelimitedEscaped(nullStr.String, '|')

	return nil
}

// Value is the StringSlicePipeDelimited implementation of the database/sql driver.Valuer.
//
// The elements are joined on '|' with utils.StringJoinDelimitedEscaped (presumably escaping any literal
// '|' inside elements — confirm against utils). The result is always a string and is never SQL NULL.
func (s StringSlicePipeDelimited) Value() (driver.Value, error) {
	return utils.StringJoinDelimitedEscaped(s, '|'), nil
}