go-ddl-parser/mariadb.go

184 lines
4.7 KiB
Go
Raw Normal View History

2024-04-07 18:55:50 +00:00
package ddl
import (
"database/sql"
"fmt"
"strings"
"git.rpjosh.de/RPJosh/go-logger"
)
// MariadbKeyType is the raw value MariaDB reports in the COLUMN_KEY column of
// INFORMATION_SCHEMA.COLUMNS, such as "PRI" or "UNI".
type MariadbKeyType string

// Known MariadbKeyType values.
const (
	MariadbKeyPrimary       MariadbKeyType = "PRI" // primary key
	MariadbKeyUnique        MariadbKeyType = "UNI" // unique index
	MariadbKeyMultipleIndex MariadbKeyType = "MUL" // non-unique index
)
// Compile-time checks that the MariaDB implementations satisfy the
// package-level interfaces they are used as.
var (
	_ DbSystem = (*Mariadb)(nil)
	_ Columner = (*MariadbColumn)(nil)
)
// Mariadb is the DbSystem implementation for MariaDB servers. It reads table
// metadata from the server's INFORMATION_SCHEMA views.
type Mariadb struct {
	db *sql.DB // handle used for every metadata query
}
// MariadbColumn wraps the generic Column with MariaDB-specific metadata that
// GetTable reads from INFORMATION_SCHEMA.COLUMNS.
type MariadbColumn struct {
	*Column
	// Whether this column has the auto_increment flag (set from the EXTRA column)
	AutoIncrement bool
	// The character length, numeric precision or datetime precision of the
	// column, whichever INFORMATION_SCHEMA reports first (non-NULL).
	// NOTE: the field name is misspelled but exported, so it is kept as-is.
	DataTypeLenght int
	// The internal column key like 'UNI' or 'PRI'
	KeyType MariadbKeyType
}
// GetExtraInfos returns the fixed marker string "MariaDB!" for every column
// produced by the MariaDB parser.
func (c *MariadbColumn) GetExtraInfos() string {
	const backendMarker = "MariaDB!"
	return backendMarker
}
// GetSpecificInfos exposes the MariaDB-specific column data. The returned
// value is always the receiver itself, as a *MariadbColumn.
func (c *MariadbColumn) GetSpecificInfos() any {
	var specific any = c
	return specific
}
// newColumn allocates a MariadbColumn together with its embedded generic
// Column and links them both ways. GetTable stores only the embedded *Column
// in Table.Columns; Extras is the link back to the MariaDB-specific data.
func (s *Mariadb) newColumn() *MariadbColumn {
	base := &Column{}
	col := &MariadbColumn{Column: base}
	base.Extras = col
	return col
}
// NewMariaDb returns a DbSystem that parses table definitions from the
// MariaDB server reachable through db. The handle is stored as-is.
func NewMariaDb(db *sql.DB) DbSystem {
	parser := new(Mariadb)
	parser.db = db
	return parser
}
// GetTable loads the column definitions of schema.name from
// INFORMATION_SCHEMA.COLUMNS, in ordinal order.
//
// Foreign key targets are resolved by joining KEY_COLUMN_USAGE, restricted to
// constraints of type FOREIGN KEY. An error is returned in these cases:
//   - the query fails
//   - a row cannot be scanned
//   - iterating the result set fails
//   - no columns are found, which is how a missing table shows up
func (s *Mariadb) GetTable(schema, name string) (*Table, error) {
	// Named "query" rather than "sql" so the database/sql package is not shadowed.
	// The trailing 0 in the length COALESCE keeps the Scan into an int from
	// failing for types that report no length or precision at all.
	// NOTE(review): a column that belongs to more than one FOREIGN KEY
	// constraint matches several KEY_COLUMN_USAGE rows and is returned once per match.
	query := `
SELECT
	c.TABLE_SCHEMA,
	c.TABLE_NAME,
	c.COLUMN_NAME,
	c.COLUMN_DEFAULT,
	c.IS_NULLABLE,
	c.DATA_TYPE,
	c.COLUMN_TYPE,
	COALESCE(c.CHARACTER_MAXIMUM_LENGTH, c.NUMERIC_PRECISION, c.DATETIME_PRECISION, 0),
	c.COLUMN_KEY,
	c.COLUMN_COMMENT,
	c.extra,
	-- Foreign key data
	COALESCE(con.REFERENCED_TABLE_NAME, ''), COALESCE(con.REFERENCED_TABLE_SCHEMA, ''), COALESCE(con.REFERENCED_COLUMN_NAME, '')
FROM INFORMATION_SCHEMA.COLUMNS c
LEFT JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE con ON
	con.TABLE_NAME = c.TABLE_NAME AND con.TABLE_SCHEMA = c.TABLE_SCHEMA AND con.COLUMN_NAME = c.COLUMN_NAME
	AND con.CONSTRAINT_NAME IN ( SELECT cc.CONSTRAINT_NAME FROM information_schema.TABLE_CONSTRAINTS cc WHERE cc.TABLE_SCHEMA = c.TABLE_SCHEMA AND cc.TABLE_NAME = c.TABLE_NAME AND cc.CONSTRAINT_TYPE = 'FOREIGN KEY' )
WHERE c.TABLE_SCHEMA = ? AND c.TABLE_NAME = ?
ORDER BY c.ordinal_position
`
	rows, err := s.db.Query(query, schema, name)
	if err != nil {
		return nil, fmt.Errorf("failed to query information_schema: %w", err)
	}
	defer rows.Close()

	table := &Table{}
	count := 0
	for rows.Next() {
		var tableSchema, tableName, isNullable, dataType, extra string
		column := s.newColumn()
		if err := rows.Scan(
			&tableSchema, &tableName,
			&column.Name, &column.DefaultValue, &isNullable,
			&dataType, &column.InternalType, &column.DataTypeLenght,
			&column.KeyType, &column.Comment, &extra,
			&column.ForeignKeyColumn.Name, &column.ForeignKeyColumn.Schema, &column.ForeignKeyColumn.Column,
		); err != nil {
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}

		// Derive the generic flags from the raw INFORMATION_SCHEMA values.
		column.CanBeNull = isNullable == "YES"
		column.Type = s.GetDataType(dataType)
		column.AutoIncrement = strings.Contains(extra, "auto_increment")
		column.PrimaryKey = column.KeyType == MariadbKeyPrimary
		column.ForeignKey = column.ForeignKeyColumn.Column != ""

		// The default value contains the raw single quotes of the create
		// statement. Strip them only when they enclose the whole value, so
		// expressions that merely start or end with a quote stay intact.
		if column.DefaultValue.Valid {
			if v := column.DefaultValue.String; len(v) >= 2 && v[0] == '\'' && v[len(v)-1] == '\'' {
				column.DefaultValue.String = v[1 : len(v)-1]
			}
		}

		// Every row carries the same schema/table; take it from the first one.
		if count == 0 {
			table.Schema = tableSchema
			table.Name = tableName
		}
		table.Columns = append(table.Columns, column.Column)
		count++
	}
	// rows.Next returns false on both end-of-data and error; distinguish them.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate columns: %w", err)
	}

	if count == 0 {
		return nil, fmt.Errorf("table %s.%s was not found", schema, name)
	}
	return table, nil
}
// GetTables returns the definitions of every table listed in
// INFORMATION_SCHEMA.TABLES for the given schema, ordered by table name.
//
// All table names are read and the result set is closed before the
// per-table GetTable queries run. Holding the first result set open while
// issuing further queries would need a second pooled connection, and would
// block forever when the pool is limited to one (db.SetMaxOpenConns(1)).
//
// If loading an individual table fails, the tables loaded so far are
// returned together with the error. Other errors return a nil slice.
func (s *Mariadb) GetTables(schema string) ([]*Table, error) {
	query := `
SELECT
	t.TABLE_SCHEMA,
	t.TABLE_NAME
FROM information_schema.tables t
WHERE t.table_schema = ?
ORDER BY t.TABLE_NAME
`
	rows, err := s.db.Query(query, schema)
	if err != nil {
		return nil, fmt.Errorf("failed to query information_schema: %w", err)
	}
	defer rows.Close()

	type tableRef struct{ schema, name string }
	var refs []tableRef
	for rows.Next() {
		var ref tableRef
		if err := rows.Scan(&ref.schema, &ref.name); err != nil {
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}
		refs = append(refs, ref)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate tables: %w", err)
	}
	// Release the connection before GetTable issues its own queries.
	// Rows.Close is idempotent, so the deferred Close above stays harmless.
	if err := rows.Close(); err != nil {
		return nil, fmt.Errorf("failed to close table list: %w", err)
	}

	tables := make([]*Table, 0, len(refs))
	for _, ref := range refs {
		t, err := s.GetTable(ref.schema, ref.name)
		if err != nil {
			return tables, fmt.Errorf("failed to get data for %s.%s: %w", ref.schema, ref.name, err)
		}
		tables = append(tables, t)
	}
	return tables, nil
}
// GetDataType maps a MariaDB DATA_TYPE name (as read from
// INFORMATION_SCHEMA.COLUMNS) to the generic DataType. Matching is
// case-insensitive. Unmapped names are logged as a warning and reported as
// UnknownType.
func (s *Mariadb) GetDataType(internalType string) DataType {
	switch strings.ToLower(internalType) {
	case "varchar", "char", "text", "tinytext", "mediumtext", "longtext":
		return StringType
	case "int", "tinyint", "smallint", "mediumint", "bigint":
		return IntType
	case "decimal", "number", "float", "double":
		return DoubleType
	case "datetime", "date", "timestamp":
		return DateType
	default:
		logger.Warning("MariaDb: received unknown data type column: %s", internalType)
		return UnknownType
	}
}