go-ddl-parser/structt/struct.go

617 lines
18 KiB
Go
Raw Permalink Normal View History

2024-04-07 18:55:50 +00:00
package structt
import (
"database/sql"
2024-04-07 18:55:50 +00:00
"errors"
"fmt"
"os"
2024-04-09 17:10:16 +00:00
"os/exec"
2024-04-07 18:55:50 +00:00
"regexp"
"sort"
"strings"
2024-04-09 17:10:16 +00:00
"unicode"
"unicode/utf8"
2024-04-07 18:55:50 +00:00
2024-04-07 18:59:05 +00:00
"git.rpjosh.de/RPJosh/go-ddl-parser"
2024-04-07 18:55:50 +00:00
"git.rpjosh.de/RPJosh/go-logger"
)
// PackageName is the import path of the root go-ddl-parser package. Generated
// files import it whenever a column is mapped to ddl.Location.
const PackageName = "git.rpjosh.de/RPJosh/go-ddl-parser"
2024-04-07 18:55:50 +00:00
// StructConfig contains options used to customize the behaviour of the conversion
// from a database column to a struct
type StructConfig struct {
// Absolute or relative base path to write all files to:
// '/internal/modules/'
GenericOutputPath string `yaml:"genericOutputPath"`
// Name of the Go package used for new files
PackgeName string `yaml:"packageName"`
// Suffix to add to the struct name for every table
Suffix string `yaml:"suffix"`
// Configuration options for a specific table.
// The key of this map is either the table name (for any schema)
// or a combination of "schema.tableName"
Tableconfig map[string]*TableConfig `yaml:"tableConfig"`
// Configuration of how to handle nullable columns
NullConfig NullConfig
2024-04-07 18:55:50 +00:00
}
// TableConfig contains the options for a single table. An empty Path,
// PackageName or Suffix falls back to a default derived from StructConfig.
type TableConfig struct {
	// Path is the absolute or relative path of the ".go" file this struct is
	// written to, e.g. '/internal/modules/file.go'.
	Path string `yaml:"path"`
	// PackageName is the Go package name used for this file.
	PackageName string `yaml:"packageName"`
	// IncludeReferencedStructs lists the foreign key columns (by database column
	// name or by Go field name) for which a pointer to the referenced struct is
	// generated instead of only the ID field.
	// This is used for "1:1" relationships.
	// Specify a single element "*" to enable it for every foreign key column.
	IncludeReferencedStructs []string `yaml:"includeReferencedStructs"`
	// IncludePointedStructs adds a slice field for every table whose foreign key
	// references this table.
	// This is used for "1:n" relationships.
	// To model an "n:m" relationship, add an extra config for the junction table
	// that only specifies IncludeReferencedStructs for the other column.
	// Note: all referenced tables have to be passed to CreateStructs.
	IncludePointedStructs bool `yaml:"includePointedStructs"`
	// Suffix is appended to the struct name. Use "<empty>" for no suffix at all,
	// overriding the generic StructConfig.Suffix.
	Suffix string `yaml:"suffix"`
}
// NullConfig controls which Go types are generated for columns that can be NULL.
type NullConfig struct {
	// Disable turns nullable types off: nullable columns then receive the same
	// plain Go types as NOT NULL columns.
	Disable bool
	// Package is the import path added to the file for the nullable types.
	// When empty, "database/sql" is used.
	Package string
	// Prefix is placed in front of the base type name ("String", "Int64",
	// "Float64" or "Time"). It is only used when Prefix.Valid is true;
	// otherwise "sql.Null" is used.
	Prefix sql.NullString
	// Custom, when non-nil, replaces Package and Prefix entirely. It receives
	// the column data type and the default base type name (empty for types
	// without a nullable counterpart) and returns the type name and the import
	// path to use. It is not consulted for geo columns.
	Custom func(typ ddl.DataType, defaultName string) (typeName, imp string)
}
2024-04-07 18:55:50 +00:00
// constructor carries the state shared by all helpers while generating the
// files of a single CreateStructs call.
type constructor struct {
	// Configuration supplied by the caller of CreateStructs
	config *StructConfig
	// All tables to generate structs for; also searched when resolving
	// foreign key relationships between tables
	tables []*ddl.Table
}
// CreateStructs creates all ".go" files with the structs based on the provided configuration
// and Tables.
// For any existing go files we have to patch, it's expected that it was created by this function
// or the file content was linted with golint
func CreateStructs(conf *StructConfig, tables []*ddl.Table) error {
c := &constructor{
config: conf,
tables: tables,
}
// Loop through all tables
for _, t := range c.tables {
// Get table configuration to use
tblConfig := c.getTableConfigForTable(t)
// Read (existing) file
fileExists := false
if _, err := os.Stat(tblConfig.Path); err == nil {
fileExists = true
} else if errors.Is(err, os.ErrNotExist) {
fileExists = false
} else {
return fmt.Errorf("failed to read properties of file %q: %s", tblConfig.Path, err)
}
// Read existing content
content := ""
if fileExists {
cnt, err := os.ReadFile(tblConfig.Path)
if err != nil {
return fmt.Errorf("failed to read file %q: %s", tblConfig.Path, err)
}
content = string(cnt)
}
// Get new file content and writeto file
newContent := c.getGoFile(content, t, tblConfig)
f, err := os.Create(tblConfig.Path)
if err != nil {
return fmt.Errorf("failed to open file %q: %s", tblConfig.Path, err)
}
_, err = f.WriteString(newContent)
if err != nil {
f.Close()
return fmt.Errorf("failed to write file %q: %s", tblConfig.Path, err)
}
f.Close()
2024-04-09 17:10:16 +00:00
// Lint go file
cmd := exec.Command("go", "fmt", tblConfig.Path)
if err := cmd.Run(); err != nil {
logger.Warning("Failed to run go fmt: %s", err)
}
cmd.Wait()
2024-04-07 18:55:50 +00:00
}
return nil
}
// GetGoFileName returns the base name (without the ".go" extension) of the file
// generated for tableName. Generated file names are always lower case.
func GetGoFileName(tableName string) string {
	lowered := strings.ToLower(tableName)
	return lowered
}
// GetFieldName converts a database identifier into a Go struct or field name.
//
// The name is lower-cased and split at underscores. The first character of
// every non-empty part is upper-cased and the parts are joined, e.g.
// "USER_ID" -> "UserId". Consecutive underscores are ignored.
// The first character is handled as a full rune, so words that start with a
// multi-byte UTF-8 character (e.g. "äpfel") are capitalized correctly instead
// of being cut in the middle of the rune.
func GetFieldName(fieldName string) string {
	var b strings.Builder
	b.Grow(len(fieldName))

	// Underscores are normally separators between words
	for _, w := range strings.Split(strings.ToLower(fieldName), "_") {
		if w == "" {
			continue
		}
		r, size := utf8.DecodeRuneInString(w)
		b.WriteRune(unicode.ToUpper(r))
		b.WriteString(w[size:])
	}
	return b.String()
}
2024-04-09 17:10:16 +00:00
// GetJsonName returns the JSON key used for the database column fieldName.
// It is the struct field name produced by GetFieldName with its first
// character lower-cased, e.g. "user_id" -> "userId".
func GetJsonName(fieldName string) string {
	name := GetFieldName(fieldName)

	first, width := utf8.DecodeRuneInString(name)
	if first == utf8.RuneError && width <= 1 {
		// Empty name or invalid leading byte: nothing to lower-case
		return name
	}

	if lower := unicode.ToLower(first); lower != first {
		return string(lower) + name[width:]
	}
	return name
}
2024-04-07 18:55:50 +00:00
// findTableConfig looks up the user supplied configuration for tbl.
// The key "schema.table" is tried before the plain table name. If neither is
// present, a new empty configuration is returned.
func (c *constructor) findTableConfig(tbl *ddl.Table) *TableConfig {
	candidates := [...]string{tbl.Schema + "." + tbl.Name, tbl.Name}
	for _, key := range candidates {
		if cfg, found := c.config.Tableconfig[key]; found {
			return cfg
		}
	}
	return &TableConfig{}
}
// getTableConfigForTable returns the configuration for tbl with all unset
// options filled in from the generic StructConfig:
//   - Path defaults to GenericOutputPath + GetGoFileName(tbl.Name) + ".go"
//   - PackageName defaults to StructConfig.PackgeName
//   - Suffix defaults to StructConfig.Suffix; "<empty>" means no suffix
//
// The returned value is a copy, so the entries of StructConfig.Tableconfig are
// never modified. This matters because the function is called several times
// for the same table. Filling defaults in place would turn an explicit
// "<empty>" suffix into "" on the first call and into the generic suffix on
// the next one, producing inconsistent struct names.
func (c *constructor) getTableConfigForTable(tbl *ddl.Table) *TableConfig {
	var tblConfig TableConfig
	if found := c.findTableConfig(tbl); found != nil {
		tblConfig = *found
	}

	if tblConfig.Path == "" {
		// No specific path provided -> derive it from the table name
		tblConfig.Path = c.config.GenericOutputPath + GetGoFileName(tbl.Name) + ".go"
	}
	if tblConfig.PackageName == "" {
		tblConfig.PackageName = c.config.PackgeName
	}
	switch tblConfig.Suffix {
	case "<empty>":
		tblConfig.Suffix = ""
	case "":
		tblConfig.Suffix = c.config.Suffix
	}

	return &tblConfig
}
// getGoFile returns the content of a go file for the specified table and configuration.
// If a existing go file exist, the struct will be updated with the new content
func (c *constructor) getGoFile(existingContent string, tbl *ddl.Table, tblConfig *TableConfig) string {
rtc := ""
columns := ""
// Add struct type header
tableName := GetFieldName(tbl.Name) + tblConfig.Suffix
rtc += fmt.Sprintf("type %s struct {\n", tableName)
columns += fmt.Sprintf("// %s\nconst (\n", tableName)
// Add columns
imports := make(map[string]bool, 0)
for _, col := range tbl.Columns {
// Add comments
if col.Comment != "" {
for _, comment := range strings.Split(col.Comment, "\n") {
rtc += fmt.Sprintf("\t// %s\n", comment)
}
}
// Initialize tags
tags := GetColumnTag(col)
// Get data type to use
dataType, imp := c.getDataType(col, tblConfig, tags)
if imp != "" {
if _, exists := imports[imp]; !exists {
imports[imp] = true
}
}
fieldName := GetFieldName(col.Name)
2024-04-09 17:10:16 +00:00
jsonName := GetJsonName(col.Name)
rtc += fmt.Sprintf("\t%s %s `json:\"%s\" %s:\"%s\"`\n", fieldName, dataType, jsonName, ColumnTagId, tags.ToTag())
// We also add the full reference to the column inside the string value.
// It's needed to reference it without information of the table (which we can't get
// with constants and no support for package reflection)
identifier := tbl.Name + "." + col.Name
if tbl.Schema != "" {
identifier = tbl.Schema + "." + identifier
}
columns += fmt.Sprintf("\t %s_%s string = \"%s|%s\"\n", tableName, fieldName, fieldName, identifier)
2024-04-07 18:55:50 +00:00
}
// Add foreign key columns
rtcAdd, columnsAdd, importsAdd := c.getOneToMany(tblConfig, tbl)
2024-04-07 18:55:50 +00:00
if rtcAdd != "" {
rtc += rtcAdd
for _, imp := range importsAdd {
if _, exists := imports[imp]; !exists {
imports[imp] = true
}
}
}
columns += columnsAdd
2024-04-07 18:55:50 +00:00
// Add metadata tag
metaData := &MetadataTag{
Schema: tbl.Schema,
Table: tbl.Name,
}
2024-04-09 17:10:16 +00:00
rtc += fmt.Sprintf("\t%s any `json:\"-\" %s:\"%s\"`\n", MetadataFieldName, MetadataTagId, metaData.ToTag())
2024-04-07 18:55:50 +00:00
// Add closing line
rtc += "}\n"
columns += ")\n"
// Add package header if no file exists already
if existingContent == "" {
2024-04-09 17:10:16 +00:00
header := fmt.Sprintf("package %s\n\n", tblConfig.PackageName)
2024-04-07 18:55:50 +00:00
importStr := ""
if len(imports) != 0 {
importStr = "import (\n"
for key := range imports {
importStr += "\t\"" + key + "\"\n"
}
importStr += ")\n"
}
2024-04-09 17:10:16 +00:00
rtc = header + importStr + "\n" + rtc + columns
2024-04-07 18:55:50 +00:00
} else {
2024-04-09 17:10:16 +00:00
rtc = c.patchFile(existingContent, rtc+columns, tbl, tblConfig, imports)
2024-04-07 18:55:50 +00:00
}
return rtc
}
// getDataType returns the data type to use for the column as a string expression
// and the extra imports required for this data type.
// The tags my be updated within this function
func (c *constructor) getDataType(column *ddl.Column, tblConfig *TableConfig, _ *ColumnTag) (name string, imp string) {
2024-04-07 18:55:50 +00:00
// Find 1:1 relationship
if oneToOne := c.findOneToOne(column, tblConfig); oneToOne != "" {
return oneToOne, ""
}
// Try to use sql null strings
if column.CanBeNull && !c.config.NullConfig.Disable {
// Get nullable data types
prefix := "sql.Null"
imp := "database/sql"
if c.config.NullConfig.Package != "" {
imp = c.config.NullConfig.Package
}
if c.config.NullConfig.Prefix.Valid {
prefix = c.config.NullConfig.Prefix.String
}
// The default data type name
typeName := ""
2024-04-07 18:55:50 +00:00
switch column.Type {
case ddl.StringType:
typeName = "String"
2024-04-07 18:55:50 +00:00
case ddl.IntType:
typeName = "Int64"
2024-04-07 18:55:50 +00:00
case ddl.DoubleType:
typeName = "Float64"
2024-04-07 18:55:50 +00:00
case ddl.DateType:
typeName = "Time"
case ddl.GeoType:
return "ddl.Location", PackageName
2024-04-07 18:55:50 +00:00
}
// Null types
if c.config.NullConfig.Custom != nil {
return c.config.NullConfig.Custom(column.Type, typeName)
}
return prefix + typeName, imp
2024-04-07 18:55:50 +00:00
}
switch column.Type {
case ddl.StringType:
return "string", ""
case ddl.IntType:
return "int", ""
case ddl.DoubleType:
2024-04-09 17:10:16 +00:00
return "float64", ""
2024-04-07 18:55:50 +00:00
case ddl.DateType:
return "time.Time", "time"
case ddl.GeoType:
return "ddl.Location", PackageName
2024-04-07 18:55:50 +00:00
}
return "any", ""
}
// findOneToOne returns the type expression (a pointer to the referenced struct,
// e.g. "*User") to use for a foreign key column in a 1:1 relationship.
//
// The relationship is only used when the column is a foreign key and
// tblConfig.IncludeReferencedStructs either is exactly ["*"] or contains the
// column's database name or Go field name. An empty string is returned when
// the relationship is not enabled or the referenced table is not among the
// tables passed to CreateStructs.
func (c *constructor) findOneToOne(column *ddl.Column, tblConfig *TableConfig) string {
	// Only foreign keys can reference another struct
	if !column.ForeignKey {
		return ""
	}

	// A single "*" element enables the reference for every column
	refs := tblConfig.IncludeReferencedStructs
	include := len(refs) == 1 && refs[0] == "*"
	if !include {
		fieldName := GetFieldName(column.Name)
		for _, name := range refs {
			if name == fieldName || name == column.Name {
				include = true
				break
			}
		}
	}
	if !include {
		return ""
	}

	// Find the table referenced by the foreign key
	ref := column.ForeignKeyColumn
	for _, t := range c.tables {
		if t.Schema == ref.Schema && t.Name == ref.Name {
			return "*" + GetFieldName(t.Name) + c.getTableConfigForTable(t).Suffix
		}
	}

	logger.Debug("Found no foreign key reference for '%s.%s'", ref.Schema, ref.Name)
	return ""
}
// getOneToMany tries to find a 1:n relationship by scanning the foreign keys of all
// other tables to this table.
// It returns an empty string if no relationship was found or it's disable in the config.
// Otherwise this function returns any additional fields to add to the struct with it's required imports
func (c *constructor) getOneToMany(tblConfig *TableConfig, tbl *ddl.Table) (rtc string, constValues string, imp []string) {
2024-04-07 18:55:50 +00:00
imports := []string{}
// The user explicity has to enable this feature
if !tblConfig.IncludePointedStructs {
return rtc, constValues, imports
2024-04-07 18:55:50 +00:00
}
// Loop through every table and column and find any foreign key to this table
for _, t := range c.tables {
// Get the table configuration
tblConfRef := c.getTableConfigForTable(t)
// Loop through all columns to find a foreign key
for _, c := range t.Columns {
if c.ForeignKey && c.ForeignKeyColumn.Schema == tbl.Schema && c.ForeignKeyColumn.Name == tbl.Name {
tblName := GetFieldName(t.Name) + tblConfRef.Suffix
tag := &ColumnTag{
PointedKeyReference: t.Schema + "." + t.Name + "." + c.Name,
}
2024-04-12 09:55:36 +00:00
rtc += fmt.Sprintf("\t%s []%s `%s:\"%s\"`\n", GetFieldName(t.Name), tblName, ColumnTagId, tag.ToTag())
// We also add the full reference to the column inside the string value.
fieldNameRoot := GetFieldName(t.Name)
identifier := tbl.Name + "." + fieldNameRoot
if tbl.Schema != "" {
identifier = tbl.Schema + "." + identifier
}
constValues += fmt.Sprintf("\t %s_%s string = \"%s|#%s\"\n", GetFieldName(tbl.Name)+tblConfig.Suffix, fieldNameRoot, fieldNameRoot, identifier)
2024-04-07 18:55:50 +00:00
}
}
}
// No relation found
return rtc, constValues, imports
2024-04-07 18:55:50 +00:00
}
// patchFile patches the content of an existing file with the new struct.
// Any existing struct with that name will be overwritten
func (c *constructor) patchFile(existingContent string, newStruct string, tbl *ddl.Table, tblConfig *TableConfig, imports map[string]bool) (newContent string) {
// Patch imports
existingContent, err := c.patchImports(existingContent, imports)
if err != nil {
logger.Error("Failed to patch imports for table %s.%s: %s", tbl.Schema, tbl.Name, err)
}
// Find existing struct config
tblName := GetFieldName(tbl.Name) + tblConfig.Suffix
reg := regexp.MustCompile(
fmt.Sprintf(
`type %s struct {(.|\n)*?\s*%s.*\n}((//.*)|(\s|\n)*)*const \((.|\n)*?\)\n`,
tblName, MetadataFieldName,
),
)
if reg.MatchString(existingContent) {
// Replace content
2024-04-09 17:10:16 +00:00
return reg.ReplaceAllString(existingContent, newStruct)
2024-04-07 18:55:50 +00:00
} else {
// Append content
return existingContent + "\n" + newStruct
}
}
func (c *constructor) patchImports(existingContent string, imports map[string]bool) (string, error) {
if len(imports) == 0 {
return existingContent, nil
}
// Regex to find an import statement
importRegex := regexp.MustCompile(`"([^"]+)"`)
2024-04-07 18:55:50 +00:00
var importStart, importEnd int
2024-04-09 17:10:16 +00:00
importFound := true
2024-04-07 18:55:50 +00:00
// Find any existing import clause within the first 5 lines
for i, line := range strings.Split(existingContent, "\n") {
// No import found
if i > 5 && importStart == 0 {
2024-04-09 17:10:16 +00:00
importFound = false
2024-04-07 18:55:50 +00:00
}
// Trim any whitespace for import
lineContent := strings.Trim(line, " \t\n")
// Search for import mode
if importStart == 0 {
if strings.HasPrefix(lineContent, "import") {
importStart = i
// The opening bracket HAS to stand on the same line (when linted with go)
if strings.Contains(lineContent, "(") {
// We have to parse it further
continue
} else {
// Extract imported package
matches := importRegex.FindStringSubmatch(lineContent)
2024-04-07 18:55:50 +00:00
if len(matches) >= 2 {
// We can only import ONE package without ()
if _, exists := imports[matches[1]]; !exists {
imports[matches[1]] = true
}
importEnd = i
break
} else {
return existingContent, fmt.Errorf("not a valid import statment: %s", lineContent)
}
}
} else {
// Nothing to do
continue
}
}
// Parse multiline ()
if lineContent == ")" {
importEnd = i
break
}
// Ignore any empty import lines. It's used as a seperator between std packages
// and "external" modules
if lineContent == "" {
continue
}
matches := importRegex.FindStringSubmatch(lineContent)
2024-04-07 18:55:50 +00:00
if len(matches) >= 2 {
if _, exists := imports[matches[1]]; !exists {
imports[matches[1]] = true
}
} else {
return existingContent, fmt.Errorf("not a valid multiline import statment: %s", lineContent)
}
}
// Build a new import string. We wan't to sort it
keys := make([]string, 0)
for k := range imports {
2024-04-07 18:55:50 +00:00
keys = append(keys, k)
}
sort.Strings(keys)
newImport := "import (\n"
for _, k := range keys {
newImport += fmt.Sprintf("\t"+`"%s"`+"\n", k)
}
newImport += ")"
2024-04-09 17:10:16 +00:00
// We didn't found an existing import statment yet that we could replace
if !importFound {
lines := strings.Split(existingContent, "\n")
// Only a single line -> add it directly below
if len(lines) <= 1 {
lines = append(lines, "")
lines = append(lines, strings.Split(newImport, "\n")...)
} else {
// Insert it into existing, empty line
return replaceLines(existingContent, 1, 1, strings.Split("\n"+newImport, "\n")), nil
}
return strings.Join(lines, "\n"), nil
}
2024-04-07 18:55:50 +00:00
// Replace import string
return replaceLines(existingContent, importStart, importEnd, strings.Split(newImport, "\n")), nil
}
// replaceLines replaces the lines startLine through endLine (0-based,
// inclusive) of content with newLines and returns the joined result.
// Lines are separated by "\n".
func replaceLines(content string, startLine int, endLine int, newLines []string) string {
	original := strings.Split(content, "\n")

	result := make([]string, 0, len(original)+len(newLines))
	result = append(result, original[:startLine]...)
	result = append(result, newLines...)
	result = append(result, original[endLine+1:]...)

	return strings.Join(result, "\n")
}