From ee36d8d116071a042caa3497403d70f2e4eb8660 Mon Sep 17 00:00:00 2001 From: RPJosh Date: Thu, 16 May 2024 20:10:16 +0200 Subject: [PATCH] Add option to specify custom null types --- structt/struct.go | 40 +++++++++++++++++++++--- structt/struct_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 5 deletions(-) diff --git a/structt/struct.go b/structt/struct.go index 8766e56..525ac9e 100644 --- a/structt/struct.go +++ b/structt/struct.go @@ -1,6 +1,7 @@ package structt import ( + "database/sql" "errors" "fmt" "os" @@ -37,6 +38,9 @@ type StructConfig struct { // The key of this map is either the table name (for any schema) // or a combination of "schema.tableName" Tableconfig map[string]*TableConfig `yaml:"tableConfig"` + + // Configuration of how to handle nullable columns + NullConfig NullConfig } // TableConfig contains options for a specific table @@ -66,6 +70,21 @@ type TableConfig struct { Suffix string `yaml:"suffix"` } +// NullConfig configures how to transform nullable columns into a go struct +type NullConfig struct { + + // Disable the use of nullable datatypes + Disable bool + + // Name of the package to import the types from. + // Defaulting to [database/sql] + Package string + + // Prefix to use in front of name like "String" or "Int64". + // Defaulting to "sql.Null" + Prefix sql.NullString +} + type constructor struct { config *StructConfig tables []*ddl.Table @@ -314,16 +333,27 @@ func (c *constructor) getDataType(column *ddl.Column, tblConfig *TableConfig, _ } // Try to use sql null strings - if column.CanBeNull { + if column.CanBeNull && !c.config.NullConfig.Disable { + + // Get nullable data types + prefix := "sql.Null" + imp := "database/sql" + if c.config.NullConfig.Package != "" { + imp = c.config.NullConfig.Package + } + if c.config.NullConfig.Prefix.Valid { + prefix = c.config.NullConfig.Prefix.String + } + switch column.Type { case ddl.StringType: - return "sql.NullString", "database/sql" + return prefix + "String", imp case ddl.IntType: - return "sql.NullInt64", "database/sql" + return prefix + "Int64", imp case ddl.DoubleType: - return "sql.NullFloat64", "database/sql" + return prefix + "Float64", imp case ddl.DateType: - return "sql.NullTime", "database/sql" + return prefix + "Time", imp case ddl.GeoType: return "ddl.Location", PackageName } diff --git a/structt/struct_test.go b/structt/struct_test.go index 4de5ffc..87f10a7 100644 --- a/structt/struct_test.go +++ b/structt/struct_test.go @@ -1,6 +1,7 @@ package structt import ( + "database/sql" "fmt" "regexp" "strings" @@ -413,6 +414,74 @@ func TestPatchFilePatchSelf(t *testing.T) { } } +func TestNullableConfig(t *testing.T) { + c := &constructor{ + config: &StructConfig{ + NullConfig: NullConfig{ + Disable: true, + }, + }, + } + + table := &ddl.Table{ + Name: "my_table_name", + Schema: "here_is_me", + Columns: []*ddl.Column{ + { + Name: "id", + Type: ddl.IntType, + CanBeNull: true, + }, + }, + } + tableConfig := &TableConfig{ + PackageName: "olaf", + Suffix: "Tab", + } + + expected := + `package olaf +%s + +type MyTableNameTab struct { + Id %s ` + getStructTag(table.Columns[0]) + ` + ` + MetadataFieldName + ` any ` + getMetadataTag(table) + ` +} +// MyTableNameTab +const ( + MyTableNameTab_Id string = "Id|here_is_me.my_table_name.id" +) +` + goFile := c.getGoFile("", table, tableConfig) + + // Compare structs + if diff := cmp.Diff( + replaceWhitespaces(fmt.Sprintf(expected, "", "int")), + replaceWhitespaces(goFile), + ); diff != "" { + t.Errorf("Mismatch of disabled null types (-want +got):\n%s", diff) + t.Logf("Expected:\n%s", expected) + t.Logf("Actual:\n%s", goFile) + } + + // ==== Test with custom data type ==== // + c.config.NullConfig = NullConfig{ + Package: "git.rpjosh.de/MyCustom", + Prefix: sql.NullString{Valid: true, String: "olaf.Null"}, + } + goFile = c.getGoFile("", table, tableConfig) + + if diff := cmp.Diff( + replaceWhitespaces(fmt.Sprintf(expected, "import (\n\t\"git.rpjosh.de/MyCustom\"\n)", "olaf.NullInt64")), + replaceWhitespaces(goFile), + ); diff != "" { + t.Errorf("Mismatch of custom null types (-want +got):\n%s", diff) + t.Logf("Expected:\n%s", expected) + t.Logf("Actual:\n%s", goFile) + } + +} + // replaceWhitespaces replaces any space, newline or a squecne of // spaces with a single space func replaceWhitespaces(val string) string {