From 270241207e1778e34f0bfd3406ff24112b830501 Mon Sep 17 00:00:00 2001 From: Jonas Letzbor Date: Fri, 12 Apr 2024 12:00:55 +0200 Subject: [PATCH] Add oracle support --- README.md | 1 + go.mod | 1 + go.sum | 2 + mariadb_test.go | 26 ++--- oracle.go | 228 ++++++++++++++++++++++++++++++++++++ oracle_test.go | 300 ++++++++++++++++++++++++++++++++++++++++++++++++ utils_test.go | 19 ++- 7 files changed, 563 insertions(+), 14 deletions(-) create mode 100644 oracle.go create mode 100644 oracle_test.go diff --git a/README.md b/README.md index 574158c..7691676 100644 --- a/README.md +++ b/README.md @@ -5,3 +5,4 @@ DDL-Parser is a simple go module to parse all table columns from a running datab The following SQL databases are supported and tested: - MariaDB *10.6* +- Oracle *21* \ No newline at end of file diff --git a/go.mod b/go.mod index 9bbf4f0..76d7a56 100644 --- a/go.mod +++ b/go.mod @@ -11,5 +11,6 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/sijms/go-ora/v2 v2.8.11 // indirect golang.org/x/sys v0.19.0 // indirect ) diff --git a/go.sum b/go.sum index 36ef7f8..7597106 100644 --- a/go.sum +++ b/go.sum @@ -8,5 +8,7 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/sijms/go-ora/v2 v2.8.11 h1:oQtSX145kCYSjnrmWdtqp2LON9wOQW09wPJ5pIEn5Tg= +github.com/sijms/go-ora/v2 v2.8.11/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/mariadb_test.go b/mariadb_test.go index 0b16e1f..e32a065 100644 --- a/mariadb_test.go +++ b/mariadb_test.go @@ -17,7 +17,7 @@ func TestGetTableSimple(t *testing.T) { mDb := NewMariaDb(db) // Create test table - tableName, err := createMariadbTable(db, + tableName, err := createTable(db, ` id INT(10) PRIMARY KEY NOT NULL AUTO_INCREMENT, txt VARCHAR(100) DEFAULT 'Ich bins, der Tim!', @@ -28,7 +28,7 @@ func TestGetTableSimple(t *testing.T) { if err != nil { t.Fatalf("Failed to create table: %s", err) } - defer dropMariadbTable(db, tableName) + defer dropTable(db, tableName) // Get columns table, err := mDb.GetTable(RequireEnvString("MARIADB_DB", t), tableName) @@ -99,17 +99,17 @@ func TestGetTableFK(t *testing.T) { mDb := NewMariaDb(db) // Create table we reference to - referenceTableName, err := createMariadbTable(db, ` + referenceTableName, err := createTable(db, ` id_to_ref INT(10) PRIMARY KEY NOT NULL AUTO_INCREMENT, rand VARCHAR(10) NOT NULL`, ) if err != nil { t.Fatalf("Failed to create table: %s", err) } - defer dropMariadbTable(db, referenceTableName) + defer dropTable(db, referenceTableName) // Create table with reference - tableName, err := createMariadbTable(db, ` + tableName, err := createTable(db, ` id INT(10) PRIMARY KEY NOT NULL AUTO_INCREMENT, other_id INT(10) NOT NULL, CONSTRAINT fk_test_constraint_for_you FOREIGN KEY(other_id) REFERENCES `+referenceTableName+`(id_to_ref) @@ -117,7 +117,7 @@ func TestGetTableFK(t *testing.T) { if err != nil { t.Fatalf("Failed to create table: %s", err) } - defer dropMariadbTable(db, tableName) + defer dropTable(db, tableName) table, err := mDb.GetTable(RequireEnvString("MARIADB_DB", t), tableName) if err != nil { @@ -177,17 +177,17 @@ func TestGetTables(t *testing.T) { mDb := NewMariaDb(db) // Create two simple tables - tableName1, err := createMariadbTable(db, `idTab1 INT(10) NOT NULL`) + tableName1, err := createTable(db, `idTab1 INT(10) NOT NULL`) if err != nil { t.Fatalf("Failed to create table: %s", err) } - defer dropMariadbTable(db, tableName1) + defer dropTable(db, tableName1) - tableName2, err := createMariadbTable(db, `idTab2 INT(10) NOT NULL`) + tableName2, err := createTable(db, `idTab2 INT(10) NOT NULL`) if err != nil { t.Fatalf("Failed to create table: %s", err) } - defer dropMariadbTable(db, tableName2) + defer dropTable(db, tableName2) tables, err := mDb.GetTables(RequireEnvString("MARIADB_DB", t)) if err != nil { @@ -279,9 +279,9 @@ func ConnectToMariadb(t *testing.T) *sql.DB { return db } -// createMariadbTable creates a table with the provided column configuration +// createTable creates a table with the provided column configuration // in statement and returns the created table name -func createMariadbTable(db *sql.DB, statement string) (string, error) { +func createTable(db *sql.DB, statement string) (string, error) { name, _ := GenerateRandomString(8) name = "ddl_test_" + name sql := fmt.Sprintf("CREATE TABLE %s (%s)", name, statement) @@ -292,7 +292,7 @@ func createMariadbTable(db *sql.DB, statement string) (string, error) { } return name, err } -func dropMariadbTable(db *sql.DB, tableName string) error { +func dropTable(db *sql.DB, tableName string) error { _, err := db.Exec("DROP TABLE " + tableName) return err } diff --git a/oracle.go b/oracle.go new file mode 100644 index 0000000..49fff39 --- /dev/null +++ b/oracle.go @@ -0,0 +1,228 @@ +package ddl + +import ( + "database/sql" + "fmt" + "strings" + + "git.rpjosh.de/RPJosh/go-logger" +) + +type OracleTableType string + +const ( + OracleTable OracleTableType = "TABLE" + OracleView OracleTableType = "VIEW" +) + +var _ DbSystem = &OracleDb{} +var _ Columner = &OracleColumn{} + +// OracleDb implements "DbSystem" for an oracle database +type OracleDb struct { + db *sql.DB +} + +type OracleColumn struct { + *Column + + // Weather this column has the auto_increment flag + AutoIncrement bool + + // Character lenght or numeric precision on the LEFT side + // of the dot + DataTypeLenght int + + // Decimal precision on the RIGHT side of the dot + Scale int +} + +func (c *OracleColumn) GetExtraInfos() string { + return "Oracle!" +} +func (c *OracleColumn) GetSpecificInfos() any { + return c +} +func (s *OracleDb) newColumn() *OracleColumn { + c := &OracleColumn{} + c.Column = &Column{} + c.Column.Extras = c + return c +} + +// NewMariaDb initializes a new database parser for an oracle database +func NewOracleDb(db *sql.DB) *OracleDb { + return &OracleDb{ + db: db, + } +} + +func (s *OracleDb) GetTable(schema, name string) (*Table, error) { + ssql := ` + SELECT + col.OWNER, + col.table_name, + col.COLUMN_NAME, + col.DATA_DEFAULT, + col.NULLABLE, + col.DATA_TYPE, + COALESCE(col.DATA_PRECISION, col.DATA_LENGTH, 0), COALESCE(col.DATA_SCALE, 0), + col.IDENTITY_COLUMN, con.CONSTRAINT_TYPE, + coms.COMMENTS, + -- Foreign key data + act.OWNER, act.table_name, act.COLUMN_NAME + FROM all_tab_columns col + LEFT JOIN all_cons_columns cc ON cc.TABLE_NAME = col.TABLE_NAME AND col.COLUMN_NAME = cc.COLUMN_NAME + LEFT JOIN all_constraints con ON cc.CONSTRAINT_NAME = con.CONSTRAINT_NAME + LEFT JOIN all_cons_columns act ON con.r_owner = act.owner + AND con.r_constraint_name = act.constraint_name + LEFT JOIN dba_col_comments coms ON coms.OWNER = col.OWNER AND coms.TABLE_NAME = col.TABLE_NAME + AND coms.COLUMN_NAME = col.COLUMN_NAME + WHERE col.table_name = UPPER(:0) + AND col.OWNER = UPPER(:1) + ORDER BY col.column_id + ` + rows, err := s.db.Query(ssql, name, schema) + if err != nil { + return nil, fmt.Errorf("failed to query all_tab_columns: %s", err) + } + defer rows.Close() + + lastColumnName := "" + table := &Table{} + count := 0 + for rows.Next() { + var tableSchema, tableName, isNullable, identity string + var fkOwner, fkTable, fkColumn, keyType, comment sql.NullString + column := s.newColumn() + + if err := rows.Scan( + &tableSchema, &tableName, + &column.Name, &column.DefaultValue, &isNullable, + &column.InternalType, &column.DataTypeLenght, &column.Scale, + &identity, &keyType, &comment, + &fkOwner, &fkTable, &fkColumn, + ); err != nil { + return nil, fmt.Errorf("failed to scan row: %s", err) + } + + // Apply data + column.CanBeNull = isNullable == "Y" + column.Type = s.GetDataType(column.InternalType, column) + column.PrimaryKey = identity == "YES" || (keyType.Valid && keyType.String == "P") + if fkColumn.Valid { + column.ForeignKey = true + column.ForeignKeyColumn.Column = fkColumn.String + column.ForeignKeyColumn.Name = fkTable.String + column.ForeignKeyColumn.Schema = fkOwner.String + } + + // The default value contains the raw single quotes of the create statement + if column.DefaultValue.Valid { + column.DefaultValue.String = strings.TrimPrefix(column.DefaultValue.String, "'") + column.DefaultValue.String = strings.TrimSuffix(column.DefaultValue.String, "'") + } + + // Set comment + if comment.Valid { + column.Comment = strings.ReplaceAll(comment.String, "\\n", "\n") + } + + // Initialize new table metadata + if count == 0 { + table.Schema = tableSchema + table.Name = tableName + } + + // It's possible that we get the same column twice for different keyTypes. + // Always prefer the primary or foreign key constraint + if lastColumnName == column.Name { + if column.ForeignKey || column.PrimaryKey { + // Don't skip, but remove the last one + table.Columns = table.Columns[:len(table.Columns)-1] + } else { + // We use the primary key or foreign key + continue + } + } + lastColumnName = column.Name + + table.Columns = append(table.Columns, column.Column) + count += 1 + } + + // We got no data + if count == 0 { + return nil, fmt.Errorf("%s.%s was not found", schema, name) + } + + return table, nil +} + +func (s *OracleDb) GetTables(schema string) ([]*Table, error) { + return s.GetTablesByType(schema, OracleTable) +} + +func (s *OracleDb) GetTablesByType(schema string, typ OracleTableType) ([]*Table, error) { + sql := ` + SELECT DISTINCT + OWNER, + OBJECT_NAME, + OBJECT_TYPE + FROM ALL_OBJECTS + WHERE OBJECT_TYPE = :0 + AND OWNER <> 'SYS' + AND OWNER = :1 + ORDER BY OBJECT_NAME ASC + ` + rows, err := s.db.Query(sql, string(typ), schema) + if err != nil { + return nil, fmt.Errorf("failed to query all_objects: %s", err) + } + defer rows.Close() + + rtc := []*Table{} + for rows.Next() { + var tableSchema, tableName, tableType string + if err := rows.Scan(&tableSchema, &tableName, &tableType); err != nil { + return rtc, fmt.Errorf("failed to scan row: %s", err) + } + + t, err := s.GetTable(tableSchema, tableName) + if err != nil { + return rtc, fmt.Errorf("failed to get data for %s.%s: %s", tableSchema, tableName, err) + } + rtc = append(rtc, t) + } + + return rtc, nil +} + +func (s *OracleDb) GetDataType(internalType string, col *OracleColumn) DataType { + internalType = strings.ToLower(internalType) + + // Remove any data type length (for some datatypes they are returned...) + if lastBracket := strings.Index(internalType, "("); lastBracket != -1 { + internalType = internalType[:lastBracket] + } + + switch internalType { + case "varchar", "varchar2", "nvarchar", "nvarchar2": + return StringType + case "double": + return DoubleType + case "date", "timestamp", "timestamptz": + return DateType + default: + // A number can either be a double or a int + if internalType == "number" { + if col.Scale == 0 { + return IntType + } else { + return DoubleType + } + } + logger.Warning("OracleDb: received unknown data type column: %s", internalType) + return UnknownType + } +} diff --git a/oracle_test.go b/oracle_test.go new file mode 100644 index 0000000..b62f840 --- /dev/null +++ b/oracle_test.go @@ -0,0 +1,300 @@ +package ddl + +import ( + "database/sql" + "fmt" + "strings" + "testing" + + "git.rpjosh.de/RPJosh/go-logger" + "github.com/google/go-cmp/cmp" + goOra "github.com/sijms/go-ora/v2" +) + +// TestGetTableSimple tests the construction of a Table struct +// with all supported data types and fields +func TestGetTableSimpleOracle(t *testing.T) { + db := ConnectToOracle(t) + oDb := NewOracleDb(db) + + // Create test table + tableName, err := createTable(db, + ` + id NUMERIC(10,0) PRIMARY KEY NOT NULL, + txt VARCHAR2(100) DEFAULT 'Ich bins, der Tim!', + dte DATE NOT NULL + `, + ) + if err != nil { + t.Fatalf("Failed to create table: %s", err) + } + tableName = strings.ToUpper(tableName) + defer dropTable(db, tableName) + + // Comment table + if err := addOracleComment(db, tableName, "DTE", `Hallo ihr da!\nZeilenumbrüche`); err != nil { + t.Fatalf("Failed to comment table: %s", err) + } + + // Get columns + table, err := oDb.GetTable(RequireEnvString("ORACLE_USER", t), tableName) + if err != nil { + t.Fatalf("Failed to get columns: %s", err) + } + + expected := &Table{ + Name: strings.ToUpper(tableName), + Schema: RequireEnvString("ORACLE_USER", t), + } + columns := []*OracleColumn{ + { + Column: &Column{ + Name: "ID", + PrimaryKey: true, + CanBeNull: false, + Type: IntType, + InternalType: "NUMBER", + }, + DataTypeLenght: 10, + Scale: 0, + }, + { + Column: &Column{ + Name: "TXT", + PrimaryKey: false, + CanBeNull: true, + Type: StringType, + InternalType: "VARCHAR2", + DefaultValue: sql.NullString{ + Valid: true, + String: "Ich bins, der Tim!", + }, + }, + AutoIncrement: false, + DataTypeLenght: 100, + }, + { + Column: &Column{ + Name: "DTE", + PrimaryKey: false, + CanBeNull: false, + Type: DateType, + InternalType: "DATE", + Comment: "Hallo ihr da!\nZeilenumbrüche", + }, + AutoIncrement: false, + DataTypeLenght: 7, + }, + } + for _, c := range columns { + c.Extras = c + expected.Columns = append(expected.Columns, c.Column) + } + + // Compare struct + if diff := cmp.Diff(table, expected); diff != "" { + t.Errorf("Mismatch of columns (-want +got):\n%s", diff) + } +} + +// TestGetTableSimple tests the construction of a Table struct +// that references another table +func TestGetTableOracleFK(t *testing.T) { + db := ConnectToOracle(t) + mDb := NewOracleDb(db) + + // Create table we reference to + referenceTableName, err := createTable(db, ` + id_to_ref NUMBER(10,0) PRIMARY KEY NOT NULL, + rand VARCHAR2(10) NOT NULL`, + ) + if err != nil { + t.Fatalf("Failed to create table: %s", err) + } + defer dropTable(db, referenceTableName) + + // Create table with reference + tableName, err := createTable(db, ` + id NUMBER(10,0) PRIMARY KEY NOT NULL, + other_id NUMBER(10,0) NOT NULL, + CONSTRAINT fk_test_constraint_for_you FOREIGN KEY(other_id) REFERENCES `+referenceTableName+`(id_to_ref) + `) + if err != nil { + t.Fatalf("Failed to create table: %s", err) + } + defer dropTable(db, tableName) + + table, err := mDb.GetTable(RequireEnvString("ORACLE_USER", t), tableName) + if err != nil { + t.Fatalf("Failed to get columns: %s", err) + } + + expected := &Table{ + Name: strings.ToUpper(tableName), + Schema: RequireEnvString("ORACLE_USER", t), + } + columns := []*OracleColumn{ + { + Column: &Column{ + Name: "ID", + PrimaryKey: true, + CanBeNull: false, + Type: IntType, + InternalType: "NUMBER", + }, + DataTypeLenght: 10, + }, + { + Column: &Column{ + Name: "OTHER_ID", + PrimaryKey: false, + CanBeNull: false, + Type: IntType, + InternalType: "NUMBER", + ForeignKey: true, + ForeignKeyColumn: ForeignColumn{ + Name: strings.ToUpper(referenceTableName), + Schema: RequireEnvString("ORACLE_USER", t), + Column: "ID_TO_REF", + }, + }, + DataTypeLenght: 10, + }, + } + for _, c := range columns { + c.Extras = c + expected.Columns = append(expected.Columns, c.Column) + } + + // Compare struct + if diff := cmp.Diff(table, expected); diff != "" { + t.Errorf("Mismatch (-want +got):\n%s", diff) + } + +} + +// TestGetTableSimple tests the selecting of multiple tables to a []Table array +func TestGetTablesOracle(t *testing.T) { + db := ConnectToOracle(t) + mDb := NewOracleDb(db) + + // Create two simple tables + tableName1, err := createTable(db, `idTab1 NUMBER(10,0) NOT NULL`) + if err != nil { + t.Fatalf("Failed to create table: %s", err) + } + tableName1 = strings.ToUpper(tableName1) + defer dropTable(db, tableName1) + + tableName2, err := createTable(db, `idTab2 NUMBER(10,0) NOT NULL`) + if err != nil { + t.Fatalf("Failed to create table: %s", err) + } + tableName2 = strings.ToUpper(tableName2) + defer dropTable(db, tableName2) + + tables, err := mDb.GetTables(RequireEnvString("ORACLE_USER", t)) + if err != nil { + t.Fatalf("Failed to get tables: %s", err) + } + + found1 := 0 + found2 := 0 + for _, tt := range tables { + if tt.Name == tableName1 { + found1 = found1 + 1 + + // Compare table + expected := &Table{ + Name: tt.Name, + Schema: RequireEnvString("ORACLE_USER", t), + } + columns := []*OracleColumn{ + { + Column: &Column{ + Name: "IDTAB1", + CanBeNull: false, + Type: IntType, + InternalType: "NUMBER", + }, + DataTypeLenght: 10, + }, + } + for _, c := range columns { + c.Extras = c + expected.Columns = append(expected.Columns, c.Column) + } + + // Compare struct + if diff := cmp.Diff(tt, expected); diff != "" { + t.Errorf("TestGetTables() mismatch of tab1: (-want +got):\n%s", diff) + } + } + + if tt.Name == tableName2 { + found2 = found2 + 1 + + // Compare table + expected := &Table{ + Name: tt.Name, + Schema: RequireEnvString("ORACLE_USER", t), + } + columns := []*OracleColumn{ + { + Column: &Column{ + Name: "IDTAB2", + CanBeNull: false, + Type: IntType, + InternalType: "NUMBER", + }, + DataTypeLenght: 10, + }, + } + for _, c := range columns { + c.Extras = c + expected.Columns = append(expected.Columns, c.Column) + } + + // Compare struct + if diff := cmp.Diff(tt, expected); diff != "" { + t.Errorf("Mismatch of tab2: (-want +got):\n%s", diff) + } + } + } + + // We expected to find exactly one single table + if found1 != 1 { + t.Errorf("Found %d instances of tab1. Expected 1 (len(rtc) = %d)", found1, len(tables)) + } + if found2 != 1 { + t.Errorf("Found %d instances of tab2. Expected 1 (len(rtc) = %d)", found1, len(tables)) + } +} + +func addOracleComment(db *sql.DB, tbl string, column string, comment string) error { + comment = strings.ReplaceAll(comment, "\n", `'||char(10)||'`) + sql := fmt.Sprintf("COMMENT ON COLUMN \"%s\".\"%s\" IS '%s'", tbl, column, comment) + _, err := db.Exec(sql) + if err != nil { + logger.Debug("Statement for create comment:\n%s", sql) + } + return err +} + +func ConnectToOracle(t *testing.T) *sql.DB { + conString := goOra.BuildUrl( + RequireEnvString("ORACLE_SERVER", t), + RequireEnvInt("ORACLE_PORT", t), + RequireEnvString("ORACLE_SERVICE", t), + RequireEnvString("ORACLE_USER", t), + RequireEnvString("ORACLE_PASSWORD", t), + map[string]string{}, + ) + + db, err := sql.Open("oracle", conString) + if err != nil { + panic(fmt.Sprintf("Failed to open DB connection: %s", err)) + } + + return db +} diff --git a/utils_test.go b/utils_test.go index e5a622a..26cb3fd 100644 --- a/utils_test.go +++ b/utils_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "math/big" "os" + "strconv" "strings" "testing" @@ -14,11 +15,27 @@ func RequireEnvString(name string, t *testing.T) string { if strVal, isSet := os.LookupEnv(name); isSet { return strVal } else { - t.Errorf("Required environment variable %q not set", name) + t.Fatalf("Required environment variable %q not set", name) return "" } } +func RequireEnvInt(name string, t *testing.T) int { + if strVal, isSet := os.LookupEnv(name); isSet { + if intVal, err := strconv.Atoi(strVal); err != nil { + t.Fatalf("Invalid number value given for the environment variable %q: %s", name, strVal) + } else if intVal < 1 { + t.Fatalf("Environment variable %q has to be greater than 0", name) + } else { + return intVal + } + } else { + t.Fatalf("Required environment variable %q not set", name) + } + + return 0 +} + func DumpStruct(a ...interface{}) string { dump := spew.Sdump(a...)