package clickhouse

import (
	"bytes"
	"context"
	"database/sql/driver"
	"unicode"

	"github.com/ClickHouse/clickhouse-go/lib/data"
)

// stmt implements the database/sql/driver statement methods (NumInput,
// Exec/ExecContext, Query/QueryContext, Close) for a query prepared on ch.
type stmt struct {
	ch       *clickhouse // connection the statement runs on
	query    string      // SQL text; placeholders are substituted by bind
	counter  int         // rows appended via Exec for insert statements
	numInput int         // placeholder count reported by NumInput (negatives reported as 0)
	isInsert bool        // when true, Exec appends rows to ch.block instead of sending the query
}

// emptyResult is returned by every successful Exec call; it carries no
// per-call data.
var emptyResult = &result{}

// key is an unexported context-key type so values stored by this package
// cannot collide with context keys defined in other packages.
type key string

// queryIDKey is the context key under which WithQueryID stores the query ID.
var queryIDKey key

// WithQueryID returns a copy of ctx carrying queryID, for use with
// ExecContext or QueryContext.
// NOTE(review): the value is presumably read by sendQuery (not visible in
// this file) and sent as the query ID — confirm.
func WithQueryID(ctx context.Context, queryID string) context.Context {
	return context.WithValue(ctx, queryIDKey, queryID)
}

// NumInput reports how many arguments the statement expects.
//
// While an insert block is pending (ch.block != nil) this is the block's
// column count, because each Exec call supplies one full row. Otherwise it
// is numInput, with negative values reported as 0.
func (stmt *stmt) NumInput() int {
	if stmt.ch.block != nil {
		return len(stmt.ch.block.Columns)
	}
	if stmt.numInput < 0 {
		return 0
	}
	return stmt.numInput
}

// Exec executes the statement with positional arguments and a background
// context.
func (stmt *stmt) Exec(args []driver.Value) (driver.Result, error) {
	return stmt.execContext(context.Background(), args)
}

// ExecContext executes the statement under ctx.
//
// For insert statements only the argument values are used; they form one
// row of the pending block. For all other statements the arguments are
// bound with their names intact, so @name parameters bind exactly as they
// do in QueryContext.
func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	if !stmt.isInsert {
		return stmt.execQuery(ctx, args)
	}
	values := make([]driver.Value, len(args))
	for i, nv := range args {
		values[i] = nv.Value
	}
	return stmt.execContext(ctx, values)
}

// execContext executes the statement with positional (unnamed) arguments.
// Its []driver.Value signature is kept for backward compatibility with
// existing callers.
func (stmt *stmt) execContext(ctx context.Context, args []driver.Value) (driver.Result, error) {
	if stmt.isInsert {
		return stmt.appendInsertRow(args)
	}
	return stmt.execQuery(ctx, convertOldArgs(args))
}

// appendInsertRow appends one row to the pending insert block. Every
// ch.blockSize rows it writes the block with ch.writeBlock and flushes the
// encoder.
// NOTE(review): assumes ch.block is non-nil and ch.blockSize > 0 for insert
// statements (a zero blockSize would panic on the modulo) — confirm.
func (stmt *stmt) appendInsertRow(args []driver.Value) (driver.Result, error) {
	stmt.counter++
	if err := stmt.ch.block.AppendRow(args); err != nil {
		return nil, err
	}
	if stmt.counter%stmt.ch.blockSize == 0 {
		stmt.ch.logf("[exec] flush block")
		if err := stmt.ch.writeBlock(stmt.ch.block, ""); err != nil {
			return nil, err
		}
		if err := stmt.ch.encoder.Flush(); err != nil {
			return nil, err
		}
	}
	return emptyResult, nil
}

// execQuery substitutes args into the query, sends it under ctx, and then
// runs ch.process.
func (stmt *stmt) execQuery(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	query, externalTables := stmt.bind(args)
	if err := stmt.ch.sendQuery(ctx, query, externalTables); err != nil {
		return nil, err
	}
	if err := stmt.ch.process(); err != nil {
		return nil, err
	}
	return emptyResult, nil
}

// Query runs the statement with positional arguments and a background
// context.
func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) {
	return stmt.queryContext(context.Background(), convertOldArgs(args))
}

// QueryContext runs the statement under ctx; named arguments bind to @name
// parameters.
func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	return stmt.queryContext(ctx, args)
}

// queryContext sends the bound query, reads the result metadata, and returns
// rows whose data is received by a goroutine running rows.receiveData.
// The finish func returned by ch.watchCancel is called on error; otherwise
// it is handed to the returned rows.
func (stmt *stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	finish := stmt.ch.watchCancel(ctx)
	query, externalTables := stmt.bind(args)
	if err := stmt.ch.sendQuery(ctx, query, externalTables); err != nil {
		finish()
		return nil, err
	}
	meta, err := stmt.ch.readMeta()
	if err != nil {
		finish()
		return nil, err
	}
	r := rows{
		ch:           stmt.ch,
		finish:       finish,
		stream:       make(chan *data.Block, 50), // buffered channel of up to 50 blocks
		columns:      meta.ColumnNames(),
		blockColumns: meta.Columns,
	}
	go r.receiveData()
	return &r, nil
}

// Close only logs the call. It does not close the connection or flush any
// pending insert block.
func (stmt *stmt) Close() error {
	stmt.ch.logf("[stmt] close")
	return nil
}

// bind returns the query with its placeholders replaced by args. It also
// returns any ExternalTable arguments encountered; these are written into
// the query by name.
//
// When NumInput reports 0, the query is returned unchanged. Otherwise:
//   - @name is replaced by every argument whose Name equals name;
//   - ? is replaced by the argument at the current position when that
//     argument is unnamed and the ? follows (possibly after whitespace) one
//     of = < > ( , + - * / [ or a keyword recognized by the like/limit/
//     offset/in/from/join/select/between matchers, or the AND of a BETWEEN.
//     Any other ? is copied through literally.
//
// NOTE(review): quoted string literals are not tracked, so '?' and '@'
// inside them are treated the same as anywhere else in the query.
func (stmt *stmt) bind(args []driver.NamedValue) (string, []ExternalTable) {
	externalTables := make([]ExternalTable, 0)
	if stmt.NumInput() == 0 {
		return stmt.query, externalTables
	}

	var (
		buf       bytes.Buffer
		index     int  // next positional argument to consume for '?'
		keyword   bool // true when a '?' at this point should be substituted
		inBetween bool // inside BETWEEN ... AND, so the next AND also enables substitution
		like      = newMatcher("like")
		limit     = newMatcher("limit")
		offset    = newMatcher("offset")
		between   = newMatcher("between")
		and       = newMatcher("and")
		in        = newMatcher("in")
		from      = newMatcher("from")
		join      = newMatcher("join")
		subSelect = newMatcher("select")
	)

	reader := bytes.NewReader([]byte(stmt.query))
	for {
		char, _, err := reader.ReadRune()
		if err != nil {
			break
		}
		switch char {
		case '@':
			// NOTE(review): when no argument carries this name (or the parsed
			// name is empty), nothing is written. The '@' and anything
			// paramParser consumed are dropped from the query — confirm
			// this is intended.
			if param := paramParser(reader); len(param) != 0 {
				for _, v := range args {
					if len(v.Name) != 0 && v.Name == param {
						externalTables = writeBindValue(&buf, v.Value, externalTables)
					}
				}
			}
		case '?':
			if keyword && index < len(args) && len(args[index].Name) == 0 {
				externalTables = writeBindValue(&buf, args[index].Value, externalTables)
				index++
			} else {
				buf.WriteRune(char)
			}
		default:
			// Case order matters: matchers later in the chain are only fed
			// this rune when earlier cases did not match.
			switch {
			case char == '=', char == '<', char == '>', char == '(', char == ',', char == '+', char == '-', char == '*', char == '/', char == '[':
				keyword = true
			case limit.matchRune(char) || offset.matchRune(char) || like.matchRune(char) || in.matchRune(char) ||
				from.matchRune(char) || join.matchRune(char) || subSelect.matchRune(char):
				keyword = true
			case between.matchRune(char):
				keyword = true
				inBetween = true
			case inBetween && and.matchRune(char):
				keyword = true
				inBetween = false
			default:
				// Whitespace keeps a pending keyword alive; anything else clears it.
				keyword = keyword && unicode.IsSpace(char)
			}
			buf.WriteRune(char)
		}
	}
	return buf.String(), externalTables
}

// writeBindValue writes the SQL text for one bound argument to buf. An
// ExternalTable is written by name and appended to externalTables; any
// other value is written as quote(value). It returns the possibly extended
// externalTables slice.
func writeBindValue(buf *bytes.Buffer, value driver.Value, externalTables []ExternalTable) []ExternalTable {
	if table, ok := value.(ExternalTable); ok {
		buf.WriteString(table.Name)
		return append(externalTables, table)
	}
	buf.WriteString(quote(value))
	return externalTables
}

// convertOldArgs wraps positional driver.Values as unnamed NamedValues with
// 1-based ordinals, matching the driver.NamedValue convention.
func convertOldArgs(args []driver.Value) []driver.NamedValue {
	dargs := make([]driver.NamedValue, len(args))
	for i, v := range args {
		dargs[i] = driver.NamedValue{
			Ordinal: i + 1,
			Value:   v,
		}
	}
	return dargs
}