299 lines
7.1 KiB
Go
299 lines
7.1 KiB
Go
package ddl
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"testing"
|
|
|
|
"git.rpjosh.de/RPJosh/go-logger"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
"github.com/google/go-cmp/cmp"
|
|
)
|
|
|
|
// TestGetTableSimple tests the construction of a Table struct
|
|
// with all supported data types and fields
|
|
func TestGetTableSimple(t *testing.T) {
|
|
db := ConnectToMariadb(t)
|
|
mDb := NewMariaDb(db)
|
|
|
|
// Create test table
|
|
tableName, err := createTable(db,
|
|
`
|
|
id INT(10) PRIMARY KEY NOT NULL AUTO_INCREMENT,
|
|
txt VARCHAR(100) DEFAULT 'Ich bins, der Tim!',
|
|
dte DATETIME NOT NULL
|
|
COMMENT 'Hallo ihr da!\nZeilenumbrüche'
|
|
`,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %s", err)
|
|
}
|
|
defer dropTable(db, tableName)
|
|
|
|
// Get columns
|
|
table, err := mDb.GetTable(RequireEnvString("MARIADB_DB", t), tableName)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get columns: %s", err)
|
|
}
|
|
|
|
expected := &Table{
|
|
Name: tableName,
|
|
Schema: RequireEnvString("MARIADB_DB", t),
|
|
}
|
|
columns := []*MariadbColumn{
|
|
{
|
|
Column: &Column{
|
|
Name: "id",
|
|
PrimaryKey: true,
|
|
CanBeNull: false,
|
|
Type: IntType,
|
|
InternalType: "int(10)",
|
|
},
|
|
AutoIncrement: true,
|
|
DataTypeLenght: 10,
|
|
KeyType: MariadbKeyPrimary,
|
|
},
|
|
{
|
|
Column: &Column{
|
|
Name: "txt",
|
|
PrimaryKey: false,
|
|
CanBeNull: true,
|
|
Type: StringType,
|
|
InternalType: "varchar(100)",
|
|
DefaultValue: sql.NullString{
|
|
Valid: true,
|
|
String: "Ich bins, der Tim!",
|
|
},
|
|
},
|
|
AutoIncrement: false,
|
|
DataTypeLenght: 100,
|
|
},
|
|
{
|
|
Column: &Column{
|
|
Name: "dte",
|
|
PrimaryKey: false,
|
|
CanBeNull: false,
|
|
Type: DateType,
|
|
InternalType: "datetime",
|
|
Comment: "Hallo ihr da!\nZeilenumbrüche",
|
|
},
|
|
AutoIncrement: false,
|
|
DataTypeLenght: 0,
|
|
},
|
|
}
|
|
for _, c := range columns {
|
|
c.Extras = c
|
|
expected.Columns = append(expected.Columns, c.Column)
|
|
}
|
|
|
|
// Compare struct
|
|
if diff := cmp.Diff(table, expected); diff != "" {
|
|
t.Errorf("TestGetTable() mismatch (-want +got):\n%s", diff)
|
|
}
|
|
}
|
|
|
|
// TestGetTableSimple tests the construction of a Table struct
|
|
// that references another table
|
|
func TestGetTableFK(t *testing.T) {
|
|
db := ConnectToMariadb(t)
|
|
mDb := NewMariaDb(db)
|
|
|
|
// Create table we reference to
|
|
referenceTableName, err := createTable(db, `
|
|
id_to_ref INT(10) PRIMARY KEY NOT NULL AUTO_INCREMENT,
|
|
rand VARCHAR(10) NOT NULL`,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %s", err)
|
|
}
|
|
defer dropTable(db, referenceTableName)
|
|
|
|
// Create table with reference
|
|
tableName, err := createTable(db, `
|
|
id INT(10) PRIMARY KEY NOT NULL AUTO_INCREMENT,
|
|
other_id INT(10) NOT NULL,
|
|
CONSTRAINT fk_test_constraint_for_you FOREIGN KEY(other_id) REFERENCES `+referenceTableName+`(id_to_ref)
|
|
`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %s", err)
|
|
}
|
|
defer dropTable(db, tableName)
|
|
|
|
table, err := mDb.GetTable(RequireEnvString("MARIADB_DB", t), tableName)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get columns: %s", err)
|
|
}
|
|
|
|
expected := &Table{
|
|
Name: tableName,
|
|
Schema: RequireEnvString("MARIADB_DB", t),
|
|
}
|
|
columns := []*MariadbColumn{
|
|
{
|
|
Column: &Column{
|
|
Name: "id",
|
|
PrimaryKey: true,
|
|
CanBeNull: false,
|
|
Type: IntType,
|
|
InternalType: "int(10)",
|
|
},
|
|
AutoIncrement: true,
|
|
DataTypeLenght: 10,
|
|
KeyType: MariadbKeyPrimary,
|
|
},
|
|
{
|
|
Column: &Column{
|
|
Name: "other_id",
|
|
PrimaryKey: false,
|
|
CanBeNull: false,
|
|
Type: IntType,
|
|
InternalType: "int(10)",
|
|
ForeignKey: true,
|
|
ForeignKeyColumn: ForeignColumn{
|
|
Name: referenceTableName,
|
|
Schema: RequireEnvString("MARIADB_DB", t),
|
|
Column: "id_to_ref",
|
|
},
|
|
},
|
|
DataTypeLenght: 10,
|
|
KeyType: MariadbKeyMultipleIndex,
|
|
},
|
|
}
|
|
for _, c := range columns {
|
|
c.Extras = c
|
|
expected.Columns = append(expected.Columns, c.Column)
|
|
}
|
|
|
|
// Compare struct
|
|
if diff := cmp.Diff(table, expected); diff != "" {
|
|
t.Errorf("TestGetTableFK() mismatch (-want +got):\n%s", diff)
|
|
}
|
|
|
|
}
|
|
|
|
// TestGetTableSimple tests the selecting of multiple tables to a []Table array
|
|
func TestGetTables(t *testing.T) {
|
|
db := ConnectToMariadb(t)
|
|
mDb := NewMariaDb(db)
|
|
|
|
// Create two simple tables
|
|
tableName1, err := createTable(db, `idTab1 INT(10) NOT NULL`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %s", err)
|
|
}
|
|
defer dropTable(db, tableName1)
|
|
|
|
tableName2, err := createTable(db, `idTab2 INT(10) NOT NULL`)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create table: %s", err)
|
|
}
|
|
defer dropTable(db, tableName2)
|
|
|
|
tables, err := mDb.GetTables(RequireEnvString("MARIADB_DB", t))
|
|
if err != nil {
|
|
t.Fatalf("Failed to get tables: %s", err)
|
|
}
|
|
|
|
found1 := 0
|
|
found2 := 0
|
|
for _, tt := range tables {
|
|
if tt.Name == tableName1 {
|
|
found1 = found1 + 1
|
|
|
|
// Compare table
|
|
expected := &Table{
|
|
Name: tt.Name,
|
|
Schema: RequireEnvString("MARIADB_DB", t),
|
|
}
|
|
columns := []*MariadbColumn{
|
|
{
|
|
Column: &Column{
|
|
Name: "idTab1",
|
|
CanBeNull: false,
|
|
Type: IntType,
|
|
InternalType: "int(10)",
|
|
},
|
|
DataTypeLenght: 10,
|
|
},
|
|
}
|
|
for _, c := range columns {
|
|
c.Extras = c
|
|
expected.Columns = append(expected.Columns, c.Column)
|
|
}
|
|
|
|
// Compare struct
|
|
if diff := cmp.Diff(tt, expected); diff != "" {
|
|
t.Errorf("TestGetTables() mismatch of tab1: (-want +got):\n%s", diff)
|
|
}
|
|
}
|
|
|
|
if tt.Name == tableName2 {
|
|
found2 = found2 + 1
|
|
|
|
// Compare table
|
|
expected := &Table{
|
|
Name: tt.Name,
|
|
Schema: RequireEnvString("MARIADB_DB", t),
|
|
}
|
|
columns := []*MariadbColumn{
|
|
{
|
|
Column: &Column{
|
|
Name: "idTab2",
|
|
CanBeNull: false,
|
|
Type: IntType,
|
|
InternalType: "int(10)",
|
|
},
|
|
DataTypeLenght: 10,
|
|
},
|
|
}
|
|
for _, c := range columns {
|
|
c.Extras = c
|
|
expected.Columns = append(expected.Columns, c.Column)
|
|
}
|
|
|
|
// Compare struct
|
|
if diff := cmp.Diff(tt, expected); diff != "" {
|
|
t.Errorf("TestGetTables() mismatch of tab2: (-want +got):\n%s", diff)
|
|
}
|
|
}
|
|
}
|
|
|
|
// We expected to find exactly one single table
|
|
if found1 != 1 {
|
|
t.Errorf("Found %d instances of tab1. Expected 1 (len(rtc) = %d)", found1, len(tables))
|
|
}
|
|
if found2 != 1 {
|
|
t.Errorf("Found %d instances of tab2. Expected 1 (len(rtc) = %d)", found1, len(tables))
|
|
}
|
|
}
|
|
|
|
func ConnectToMariadb(t *testing.T) *sql.DB {
|
|
db, err := sql.Open("mysql", fmt.Sprintf(
|
|
"%s:%s@tcp(%s)/%s",
|
|
RequireEnvString("MARIADB_USER", t), RequireEnvString("MARIADB_PASSWORD", t), RequireEnvString("MARIADB_ADDRESS", t), RequireEnvString("MARIADB_DB", t),
|
|
))
|
|
if err != nil {
|
|
panic(fmt.Sprintf("Failed to open DB connection: %s", err))
|
|
}
|
|
|
|
return db
|
|
}
|
|
|
|
// createTable creates a table with the provided column configuration
|
|
// in statement and returns the created table name
|
|
func createTable(db *sql.DB, statement string) (string, error) {
|
|
name, _ := GenerateRandomString(8)
|
|
name = "ddl_test_" + name
|
|
sql := fmt.Sprintf("CREATE TABLE %s (%s)", name, statement)
|
|
|
|
_, err := db.Exec(sql)
|
|
if err != nil {
|
|
logger.Debug("Create statement: %s", sql)
|
|
}
|
|
return name, err
|
|
}
|
|
// dropTable executes a DROP TABLE statement for tableName on db and returns
// the error of that execution, if any. The table name is appended to the
// statement as is.
func dropTable(db *sql.DB, tableName string) error {
	query := "DROP TABLE " + tableName
	if _, execErr := db.Exec(query); execErr != nil {
		return execErr
	}
	return nil
}
|