Edit on GitHub

sqlglot.transforms

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import expressions as exp
  6from sqlglot.helper import find_new_name, name_sequence
  7
  8if t.TYPE_CHECKING:
  9    from sqlglot.generator import Generator
 10
 11
 12def preprocess(
 13    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
 14) -> t.Callable[[Generator, exp.Expression], str]:
 15    """
 16    Creates a new transform by chaining a sequence of transformations and converts the resulting
 17    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
 18    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
 19
 20    Args:
 21        transforms: sequence of transform functions. These will be called in order.
 22
 23    Returns:
 24        Function that can be used as a generator transform.
 25    """
 26
 27    def _to_sql(self, expression: exp.Expression) -> str:
 28        expression_type = type(expression)
 29
 30        expression = transforms[0](expression)
 31        for transform in transforms[1:]:
 32            expression = transform(expression)
 33
 34        _sql_handler = getattr(self, expression.key + "_sql", None)
 35        if _sql_handler:
 36            return _sql_handler(expression)
 37
 38        transforms_handler = self.TRANSFORMS.get(type(expression))
 39        if transforms_handler:
 40            if expression_type is type(expression):
 41                if isinstance(expression, exp.Func):
 42                    return self.function_fallback_sql(expression)
 43
 44                # Ensures we don't enter an infinite loop. This can happen when the original expression
 45                # has the same type as the final expression and there's no _sql method available for it,
 46                # because then it'd re-enter _to_sql.
 47                raise ValueError(
 48                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
 49                )
 50
 51            return transforms_handler(self, expression)
 52
 53        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 54
 55    return _to_sql
 56
 57
 58def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 59    if isinstance(expression, exp.Select):
 60        count = 0
 61        recursive_ctes = []
 62
 63        for unnest in expression.find_all(exp.Unnest):
 64            if (
 65                not isinstance(unnest.parent, (exp.From, exp.Join))
 66                or len(unnest.expressions) != 1
 67                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 68            ):
 69                continue
 70
 71            generate_date_array = unnest.expressions[0]
 72            start = generate_date_array.args.get("start")
 73            end = generate_date_array.args.get("end")
 74            step = generate_date_array.args.get("step")
 75
 76            if not start or not end or not isinstance(step, exp.Interval):
 77                continue
 78
 79            alias = unnest.args.get("alias")
 80            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 81
 82            start = exp.cast(start, "date")
 83            date_add = exp.func(
 84                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 85            )
 86            cast_date_add = exp.cast(date_add, "date")
 87
 88            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 89
 90            base_query = exp.select(start.as_(column_name))
 91            recursive_query = (
 92                exp.select(cast_date_add)
 93                .from_(cte_name)
 94                .where(cast_date_add <= exp.cast(end, "date"))
 95            )
 96            cte_query = base_query.union(recursive_query, distinct=False)
 97
 98            generate_dates_query = exp.select(column_name).from_(cte_name)
 99            unnest.replace(generate_dates_query.subquery(cte_name))
100
101            recursive_ctes.append(
102                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
103            )
104            count += 1
105
106        if recursive_ctes:
107            with_expression = expression.args.get("with") or exp.With()
108            with_expression.set("recursive", True)
109            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
110            expression.set("with", with_expression)
111
112    return expression
113
114
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a GENERATE_SERIES table reference in an UNNEST.

    If the table has an alias, the UNNEST is aliased as "_u" with the original table alias
    as its single column. Any other expression is returned unchanged.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])
    if not expression.alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[expression.alias], copy=False)
126
127
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses with their 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    # Map each aliased projection's alias to its 1-based position in the SELECT list.
    alias_positions = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            alias_positions[projection.alias] = position

    for grouped in expression.expressions:
        # Only unqualified columns can be references to a select alias.
        if not isinstance(grouped, exp.Column) or grouped.table:
            continue

        if grouped.name in alias_positions:
            grouped.replace(exp.Literal.number(alias_positions.get(grouped.name)))

    return expression
159
160
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as a subquery filtered on a ROW_NUMBER() window.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
    The window is partitioned by the DISTINCT ON columns; it is ordered by the query's ORDER BY
    when there is one (the ORDER BY is moved into the window), otherwise by the DISTINCT ON
    columns themselves.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    # Detach DISTINCT from the query before reusing its ON columns for the partition.
    distinct.pop()
    partition_cols = on.expressions

    # Capture the projections before the row-number column is added to the query.
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    ranking = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    existing_order = expression.args.get("order")
    if existing_order:
        ranking.set("order", existing_order.pop())
    else:
        ranking.set("order", exp.Order(expressions=[col.copy() for col in partition_cols]))

    expression.select(exp.alias_(ranking, row_number_alias), copy=False)

    outer_query = exp.select(*projections, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number_alias).eq(1), copy=False)
200
201
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so they are added as
    projections of the subquery and referenced by alias in the outer filter. A column that is
    referenced in the QUALIFY clause but not selected is added to the subquery as well, so the
    outer WHERE clause can refer to it. Finally, references inside a window to a projection
    alias are replaced by the aliased expression, to avoid creating invalid column references.
    """
    if not (isinstance(expression, exp.Select) and expression.args.get("qualify")):
        return expression

    # Give every unnamed projection an alias so the outer query can refer to it.
    reserved_names = set(expression.named_selects)
    for projection in expression.selects:
        if projection.alias_or_name:
            continue

        generated = find_new_name(reserved_names, "_c")
        projection.replace(exp.alias_(projection, generated))
        reserved_names.add(generated)

    def _outer_projection(select: exp.Expression) -> str | exp.Column:
        # Carry the "quoted" flag of the alias / identifier over to the outer reference.
        name = select.alias_or_name
        identifier = select.args.get("alias") or select.this
        if not isinstance(identifier, exp.Identifier):
            return name
        return exp.column(name, quoted=identifier.args.get("quoted"))

    outer_selects = exp.select(*[_outer_projection(s) for s in expression.selects])
    qualify_filters = expression.args["qualify"].pop().this
    aliased_expressions = {
        projection.alias: projection.this
        for projection in expression.selects
        if isinstance(projection, exp.Alias)
    }

    # With a star projection only windows are inspected; otherwise columns are inspected too.
    candidate_types = exp.Window if expression.is_star else (exp.Window, exp.Column)
    for candidate in qualify_filters.find_all(candidate_types):
        if not isinstance(candidate, exp.Window):
            if candidate.name not in expression.named_selects:
                expression.select(candidate.copy(), copy=False)
            continue

        if aliased_expressions:
            for referenced in candidate.find_all(exp.Column):
                replacement = aliased_expressions.get(referenced.name)
                if replacement:
                    referenced.replace(replacement)

        window_alias = find_new_name(expression.named_selects, "_w")
        expression.select(exp.alias_(candidate, window_alias), copy=False)
        window_ref = exp.column(window_alias)

        if isinstance(candidate.parent, exp.Qualify):
            # The window is the entire QUALIFY condition, so the filter becomes the reference.
            qualify_filters = window_ref
        else:
            candidate.replace(window_ref)

    return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
        qualify_filters, copy=False
    )
264
265
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions,
    by dropping every `DataTypeParam` from each `DataType` node found in the tree.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = [
            arg for arg in data_type.expressions if not isinstance(arg, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_args)

    return expression
277
278
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only UNNESTs that sit directly under a FROM or JOIN in the SELECT's own scope are
    considered. A column whose table (or, failing that, db) part matches one of their
    aliases has that part cleared.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    source_aliases = {
        source.alias
        for source in find_all_in_scope(expression, exp.Unnest)
        if isinstance(source.parent, (exp.From, exp.Join))
    }
    if not source_aliases:
        return expression

    for col in expression.find_all(exp.Column):
        if col.table in source_aliases:
            col.set("table", None)
        elif col.db in source_aliases:
            col.set("db", None)

    return expression
297
298
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    For a SELECT, an UNNEST in the FROM clause is replaced by a table wrapping EXPLODE
    (or POSEXPLODE when the UNNEST has an offset). Each joined UNNEST is removed from the
    joins and replaced by one LATERAL VIEW per (unnested expression, alias column) pair.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: joins are removed from the underlying list inside the loop,
        # and mutating a list while iterating it would skip the element after each removal.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # zip() pairs each unnested expression with an alias column; with no alias
                # there are no columns, so no lateral is produced.
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
339
340
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Build a transform that converts explode/posexplode projections into unnest joins.

    Args:
        index_offset: start value of the generated position series; it also determines how
            array sizes are adjusted when computing the end of the series.

    Returns:
        A function that rewrites SELECT expressions in place and returns them; any other
        expression is returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if not isinstance(expression, exp.Select):
            return expression

        from sqlglot.optimizer.scope import Scope

        used_select_names = set(expression.named_selects)
        used_source_names = {name for name, _ in Scope(expression).references}

        def reserve_name(pool: t.Set[str], base: str) -> str:
            # Pick a name not yet in the pool and record it as taken.
            fresh = find_new_name(pool, base)
            pool.add(fresh)
            return fresh

        array_sizes: t.List[exp.Condition] = []
        series_alias = reserve_name(used_select_names, "pos")
        series = exp.alias_(
            exp.Unnest(
                expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
            ),
            reserve_name(used_source_names, "_u"),
            table=[series_alias],
        )

        # we use list here because expression.selects is mutated inside the loop
        for projection in list(expression.selects):
            explode = projection.find(exp.Explode)
            if not explode:
                continue

            pos_alias = ""
            explode_alias = ""

            if isinstance(projection, exp.Alias):
                explode_alias = projection.args["alias"]
                aliased = projection
            elif isinstance(projection, exp.Aliases):
                pos_alias = projection.aliases[0]
                explode_alias = projection.aliases[1]
                aliased = projection.replace(exp.alias_(projection.this, "", copy=False))
            else:
                aliased = projection.replace(exp.alias_(projection, ""))
                explode = aliased.find(exp.Explode)
                assert explode

            is_posexplode = isinstance(explode, exp.Posexplode)
            explode_arg = explode.this

            if isinstance(explode, exp.ExplodeOuter):
                bracket = explode_arg[0]
                bracket.set("safe", True)
                bracket.set("offset", True)
                explode_arg = exp.func(
                    "IF",
                    exp.func(
                        "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                    ).eq(0),
                    exp.array(bracket, copy=False),
                    explode_arg,
                )

            # This ensures that we won't use [POS]EXPLODE's argument as a new selection
            if isinstance(explode_arg, exp.Column):
                used_select_names.add(explode_arg.output_name)

            unnest_source_alias = reserve_name(used_source_names, "_u")

            if not explode_alias:
                explode_alias = reserve_name(used_select_names, "col")

                if is_posexplode:
                    pos_alias = reserve_name(used_select_names, "pos")

            if not pos_alias:
                pos_alias = reserve_name(used_select_names, "pos")

            aliased.set("alias", exp.to_identifier(explode_alias))

            series_table_alias = series.args["alias"].this
            exploded_value = exp.If(
                this=exp.column(series_alias, table=series_table_alias).eq(
                    exp.column(pos_alias, table=unnest_source_alias)
                ),
                true=exp.column(explode_alias, table=unnest_source_alias),
            )

            explode.replace(exploded_value)

            if is_posexplode:
                # Insert the position projection right after the exploded value.
                projections = expression.expressions
                insert_at = projections.index(aliased) + 1
                position_value = exp.If(
                    this=exp.column(series_alias, table=series_table_alias).eq(
                        exp.column(pos_alias, table=unnest_source_alias)
                    ),
                    true=exp.column(pos_alias, table=unnest_source_alias),
                ).as_(pos_alias)
                projections.insert(insert_at, position_value)
                expression.set("expressions", projections)

            # The position series is attached once, when the first explode is handled.
            if not array_sizes:
                if expression.args.get("from"):
                    expression.join(series, copy=False, join_type="CROSS")
                else:
                    expression.from_(series, copy=False)

            size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
            array_sizes.append(size)

            # trino doesn't support left join unnest with on conditions
            # if it did, this would be much simpler
            expression.join(
                exp.alias_(
                    exp.Unnest(
                        expressions=[explode_arg.copy()],
                        offset=exp.to_identifier(pos_alias),
                    ),
                    unnest_source_alias,
                    table=[explode_alias],
                ),
                join_type="CROSS",
                copy=False,
            )

            if index_offset != 1:
                size = size - 1

            expression.where(
                exp.column(series_alias, table=series_table_alias)
                .eq(exp.column(pos_alias, table=unnest_source_alias))
                .or_(
                    (exp.column(series_alias, table=series_table_alias) > size).and_(
                        exp.column(pos_alias, table=unnest_source_alias).eq(size)
                    )
                ),
                copy=False,
            )

        if array_sizes:
            # The series must run as far as the largest of the exploded arrays.
            end: exp.Condition = exp.Greatest(this=array_sizes[0], expressions=array_sizes[1:])

            if index_offset != 1:
                end = end - (1 - index_offset)
            series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
490
491
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's first argument becomes the ORDER BY key of the WITHIN GROUP clause and
    its second argument becomes the percentile's only argument. Percentiles that are already
    inside a WITHIN GROUP, or that have no second argument, are left untouched.
    """
    is_bare_percentile = isinstance(expression, exp.PERCENTILES) and not isinstance(
        expression.parent, exp.WithinGroup
    )
    if not is_bare_percentile or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])

    return exp.WithinGroup(this=expression, expression=ordering)
505
506
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    A WITHIN GROUP wrapping a percentile and an ORDER BY is replaced by an `ApproxQuantile`
    built from the ordered value and the percentile's argument.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression
    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))
519
520
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Only CTEs without an explicit column list are touched. For a set operation, the names
    come from its left-hand query. Unnamed projections get names from a "_c_" sequence.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
538
539
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Applies to CAST / TRY_CAST nodes whose operand is named "epoch" (case-insensitive) and
    whose target type is one of `exp.DataType.TEMPORAL_TYPES`.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    if expression.name.lower() == "epoch" and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
550
551
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.

    Each SEMI / ANTI join that has an ON condition is removed from the SELECT and replaced
    by an `EXISTS (SELECT 1 FROM <joined source> WHERE <on>)` predicate (negated for ANTI)
    that is added to the WHERE clause.

    Args:
        expression: the expression that will be transformed (mutated in place).

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: join.pop() detaches the join from the query, and if that
        # shrinks the underlying list in place, iterating it directly would skip the join
        # that follows each removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
567
568
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.

    The LEFT-join side has its LIMIT cleared, and the RIGHT-join copy has its CTEs removed.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_joins = [
        (position, candidate)
        for position, candidate in enumerate(expression.args.get("joins") or [])
        if candidate.side == "FULL"
    ]
    if len(full_joins) != 1:
        return expression

    # Copy before mutating, so the copy still carries the original LIMIT.
    right_query = expression.copy()
    expression.set("limit", None)

    position, full_join = full_joins[0]
    full_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
593
594
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    outer_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not outer_with:
            # No top-level WITH yet: the first nested one is promoted as-is.
            outer_with = nested_with.pop()
            expression.set("with", outer_with)
            continue

        if nested_with.recursive:
            outer_with.set("recursive", True)

        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            # CTEs nested inside another CTE are placed just before that CTE.
            hoisted = outer_with.expressions
            position = hoisted.index(enclosing_cte)
            hoisted[position:position] = nested_with.expressions
            outer_with.set("expressions", hoisted)
        else:
            outer_with.set("expressions", outer_with.expressions + nested_with.expressions)

    return expression
632
633
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `ensure_bools` helper together with
    a callback that rewrites a qualifying node `x` as `x <> 0`.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _to_explicit_bool(node: exp.Expression) -> None:
        if node.is_number:
            needs_comparison = True
        elif not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            needs_comparison = True
        else:
            # Untyped columns are treated like numeric values.
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for visited in expression.walk():
        _canonicalize_ensure_bools(visited, _to_explicit_bool)

    return expression
653
654
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Strip every qualifier from the columns in the tree, keeping only the last part."""
    for col in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        qualifiers = col.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
662
663
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Remove, from a CREATE statement, the parent node of every unique column constraint."""
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            owner.pop()

    return expression
671
672
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map CREATE TEMPORARY TABLE statements to their temporary-view equivalent.

    Args:
        expression: the CREATE statement to transform.
        tmp_storage_provider: called with the statement when it creates a temporary table
            without a source query; its result is returned.

    Returns:
        A `CREATE TEMPORARY VIEW` for a temporary CTAS, the provider's result for a temporary
        table without a query, and the unchanged expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if not expression.expression:
        return tmp_storage_provider(expression)

    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
695
696
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Column names are matched case-insensitively. Only TABLE and VIEW creations that have a
    schema are affected.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {item.name.upper() for item in prop.this.expressions}
    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression
718
719
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When PARTITIONED BY holds a schema made only of typed column definitions, those
    definitions are appended to the table's schema and the property is reduced to a tuple
    of the column identifiers.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(item, exp.ColumnDef) and item.kind for item in partition_defs):
        return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(item.this) for item in prop.this.expressions]
    )
    table_schema = expression.this
    for column_def in prop.this.expressions:
        table_schema.append("expressions", column_def)
    prop.set("this", partition_names)

    return expression
743
744
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Each `PropertyEQ` argument of a struct becomes an alias of its value named after its
    key; other arguments are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    converted = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            converted.append(exp.alias_(arg.expression, arg.this))
        else:
            converted.append(arg)

    expression.set("expressions", converted)
    return expression
757
758
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.

    Raises:
        AssertionError: if a marked column is not part of a binary predicate, is not
            qualified with a table, if marks appear on both sides of a predicate or span
            more than one table, or if no table is left to serve as the new FROM source.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        # Only queries that have both a WHERE clause and joins are rewritten
        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from its parent; it becomes the ON condition of the new join
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            # Clear the marks and collect the tables that the marked columns belong to
            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column of the loop above; the assertion guarantees that
            # all marked columns share its table. The FROM source is used as a fallback when
            # that table is not one of the existing joins.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # The FROM source itself became a LEFT JOIN target, so one of the original joins
        # that was not rebuilt has to take its place in the FROM clause
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        # The join list is replaced wholesale with the rebuilt joins
        query.set("joins", list(new_joins.values()))

        # Drop the WHERE clause when nothing is left in it
        if not where.this:
            where.pop()

    return expression
def preprocess( transforms: List[Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]]) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.Expression], str]:
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated unchanged.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: (from the returned function) if the resulting expression has no "_sql"
            method and either has no `Generator.TRANSFORMS` entry, or has one but kept its
            original type without being a function expression.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        # A single loop applies the transforms in order and, unlike indexing the first
        # element explicitly, does not fail when no transforms were supplied.
        for transform in transforms:
            expression = transform(expression)

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 59def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
 60    if isinstance(expression, exp.Select):
 61        count = 0
 62        recursive_ctes = []
 63
 64        for unnest in expression.find_all(exp.Unnest):
 65            if (
 66                not isinstance(unnest.parent, (exp.From, exp.Join))
 67                or len(unnest.expressions) != 1
 68                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 69            ):
 70                continue
 71
 72            generate_date_array = unnest.expressions[0]
 73            start = generate_date_array.args.get("start")
 74            end = generate_date_array.args.get("end")
 75            step = generate_date_array.args.get("step")
 76
 77            if not start or not end or not isinstance(step, exp.Interval):
 78                continue
 79
 80            alias = unnest.args.get("alias")
 81            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 82
 83            start = exp.cast(start, "date")
 84            date_add = exp.func(
 85                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
 86            )
 87            cast_date_add = exp.cast(date_add, "date")
 88
 89            cte_name = "_generated_dates" + (f"_{count}" if count else "")
 90
 91            base_query = exp.select(start.as_(column_name))
 92            recursive_query = (
 93                exp.select(cast_date_add)
 94                .from_(cte_name)
 95                .where(cast_date_add <= exp.cast(end, "date"))
 96            )
 97            cte_query = base_query.union(recursive_query, distinct=False)
 98
 99            generate_dates_query = exp.select(column_name).from_(cte_name)
100            unnest.replace(generate_dates_query.subquery(cte_name))
101
102            recursive_ctes.append(
103                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
104            )
105            count += 1
106
107        if recursive_ctes:
108            with_expression = expression.args.get("with") or exp.With()
109            with_expression.set("recursive", True)
110            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
111            expression.set("with", with_expression)
112
113    return expression
def unnest_generate_series( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """
    Unnests GENERATE_SERIES or SEQUENCE table references.

    Args:
        expression: the expression to transform; only `exp.Table` nodes wrapping an
            `exp.GenerateSeries` are rewritten.

    Returns:
        An UNNEST of the series (aliased as "_u" with the table's alias as its column, when
        the table has an alias), or the original expression.
    """
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])

    table_alias = expression.alias
    if not table_alias:
        return unnested

    return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)

Unnests GENERATE_SERIES or SEQUENCE table references.

def unalias_group( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Each unqualified GROUP BY column whose name matches a projection alias is swapped for
    that projection's 1-based position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not (isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select)):
        return expression

    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for group_by in expression.expressions:
        if not isinstance(group_by, exp.Column) or group_by.table:
            continue

        if group_by.name in position_by_alias:
            group_by.replace(exp.Literal.number(position_by_alias.get(group_by.name)))

    return expression

Replace references to select aliases in GROUP BY clauses.

Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_distinct_on( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    The DISTINCT ON columns become the partition of a ROW_NUMBER() window that is added to the
    query as an extra projection; the query is then wrapped in a subquery named "_t" and the
    outer query keeps only the rows whose row number equals 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not on or not isinstance(on, exp.Tuple):
        return expression

    distinct.pop()
    distinct_cols = on.expressions

    # Captured before the window projection is added to the inner query
    outer_selects = expression.selects
    row_number = find_new_name(expression.named_selects, "_row_number")

    window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

    # Reuse the query's ORDER BY for the window, falling back to the DISTINCT ON columns
    order = expression.args.get("order")
    window_order = (
        order.pop() if order else exp.Order(expressions=[c.copy() for c in distinct_cols])
    )
    window.set("order", window_order)

    expression.select(exp.alias_(window, row_number), copy=False)

    outer_query = exp.select(*outer_selects, copy=False)
    outer_query = outer_query.from_(expression.subquery("_t", copy=False), copy=False)
    return outer_query.where(exp.column(row_number).eq(1), copy=False)

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression to transform; only `exp.Select` nodes with a QUALIFY
            clause are rewritten.

    Returns:
        A new outer SELECT over the original query (wrapped as subquery "_t") whose WHERE
        clause holds the former QUALIFY condition, or the original expression.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a fresh "_c" alias so that the outer query can refer to it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Returns the projection's output name; when it stems from an Identifier, a Column
            # is built instead so that the identifier's "quoted" flag carries over
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are built before any helper projection is added below
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # For star queries only windows are looked up; otherwise referenced columns are too
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Swap alias references inside the window for the aliased expressions
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window in the subquery under a fresh "_w" alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # A window that is the whole QUALIFY condition becomes the filter itself;
                # otherwise it is replaced in place inside the condition
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used by the filter that isn't projected yet is added to the subquery
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.

    Args:
        expression: the expression whose nested data types will be stripped of their parameters.

    Returns:
        The same expression, modified in place.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_args = []
        for arg in data_type.expressions:
            if isinstance(arg, exp.DataTypeParam):
                continue
            kept_args.append(arg)

        data_type.set("expressions", kept_args)

    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

    Args:
        expression: the expression to transform; only `exp.Select` nodes are modified.

    Returns:
        The same expression, with the matching column qualifiers cleared.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Aliases of the UNNESTs that act as sources (directly under FROM or JOIN) in this scope
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if not unnest_aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in unnest_aliases:
            column.set("table", None)
        elif column.db in unnest_aliases:
            column.set("db", None)

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    An UNNEST used as the FROM source is replaced by a table wrapping EXPLODE (or POSEXPLODE
    when the UNNEST has an offset). Every joined UNNEST is removed from the joins and turned
    into one LATERAL VIEW per (unnested expression, alias column) pair.

    Args:
        expression: the expression to transform; only `exp.Select` nodes are modified.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
            this, *expressions = unnest.expressions
            unnest.replace(
                exp.Table(
                    this=udtf(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        joins = expression.args.get("joins") or []

        # Iterate over a snapshot: joins are removed from the live list inside the loop, and
        # removing while iterating the same list would skip the join right after each removal.
        for join in list(joins):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                joins.remove(join)

                # zip stops at the shorter input, so an UNNEST without alias columns
                # contributes no lateral views
                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression]:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series. It also determines
            whether the array sizes and the series end are adjusted (see the `!= 1` checks).

    Returns:
        A transform that rewrites the EXPLODE / POSEXPLODE projections of `exp.Select` nodes
        into cross-joined UNNESTs driven by a shared position series, and leaves every other
        expression untouched.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Names already in use, so that generated aliases don't clash with them
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Picks a name that isn't in `names` yet and reserves it
                name = find_new_name(names, name)
                names.add(name)
                return name

            # Sizes of the exploded arrays; a non-empty list means a rewrite took place
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")

            # The position series shared by all explodes; its "end" is only set at the bottom,
            # once all the array sizes are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    # Determine the user-provided aliases (if any) and make sure that the
                    # projection is wrapped in an Alias node that can be renamed below
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # exp.alias_ copied the projection here, so the explode is looked up again
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(arg[0]), arg)
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    # Generate whichever aliases the user didn't provide
                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # The exploded value is only produced on the row where the series position
                    # matches the unnest's own offset
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Add the position as an extra projection right after the exploded value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The series source is attached once, when the first explode is found
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    # Only the filter below uses the adjusted size; `arrays` keeps the raw one
                    if index_offset != 1:
                        size = size - 1

                    # Keep the rows where the positions line up, plus the rows past the end of
                    # this array that are paired with its last element
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series has to run as far as the largest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest

Convert explode/posexplode into unnest.

def add_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by adding a WITHIN GROUP clause to them.

    The percentile's first argument moves into the ORDER BY of the new WITHIN GROUP clause
    and its second argument takes the first argument's place.

    Args:
        expression: the expression to transform; only two-argument percentiles that aren't
            already inside a WITHIN GROUP are rewritten.

    Returns:
        The wrapping `exp.WithinGroup`, or the original expression.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_value = expression.this.pop()
    quantile = expression.expression.pop()
    expression.set("this", quantile)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_value)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

    Args:
        expression: the expression to transform; only `exp.WithinGroup` nodes that wrap a
            percentile and are ordered by an `exp.Order` are rewritten.

    Returns:
        The `exp.ApproxQuantile` that replaced the WITHIN GROUP, or the original expression.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    percentile = expression.this
    if not isinstance(percentile, exp.PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = percentile.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)

    return expression.replace(approx_quantile)

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Uses projection output names in recursive CTE definitions to define the CTEs' columns.

    Projections without an output name receive a generated one with the "_c_" prefix. For
    set operations, the names are taken from the left-hand query.

    Args:
        expression: the expression to transform; only recursive `exp.With` nodes are modified.

    Returns:
        The same expression, modified in place.
    """
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            # Explicit column names are left alone
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        cte_alias.set("columns", column_names)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace 'epoch' in casts by the equivalent date literal.

    Args:
        expression: the expression to transform; only casts of 'epoch' (case-insensitive) to
            one of `exp.DataType.TEMPORAL_TYPES` are modified.

    Returns:
        The same expression, with the cast operand replaced in place when applicable.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    is_epoch = expression.name.lower() == "epoch"
    if is_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join that has an ON condition is removed from the SELECT and
    replaced by a `WHERE EXISTS (SELECT 1 FROM <joined table> WHERE <on>)` predicate
    (negated with NOT for ANTI joins). Joins without an ON condition are left as-is.

    Args:
        expression: The expression to transform; only `exp.Select` nodes are touched.

    Returns:
        The same expression, modified in place.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot of the joins: `join.pop()` detaches the join from
        # the SELECT, and if that shrinks the `joins` list in place, iterating the
        # live list would skip the join that follows every removed one.
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query containing a FULL OUTER join as a union of two copies of itself,
    one using a LEFT join and the other a RIGHT join in place of the FULL one.

    Only queries with exactly one FULL OUTER join are rewritten; anything else is
    returned untouched. The left-hand query loses its LIMIT and the right-hand copy
    loses its WITH clause.

    Args:
        expression: The expression to transform; only `exp.Select` nodes are touched.

    Returns:
        The union of the two rewritten queries, or the original expression when the
        rewrite does not apply.
    """
    if not isinstance(expression, exp.Select):
        return expression

    full_outer_joins = [
        (position, candidate)
        for position, candidate in enumerate(expression.args.get("joins") or [])
        if candidate.side == "FULL"
    ]
    if len(full_outer_joins) != 1:
        return expression

    position, left_join = full_outer_joins[0]

    # The copy has to be taken before the original is mutated, so that it still
    # carries the FULL join and the LIMIT.
    right_query = expression.copy()
    expression.set("limit", None)

    left_join.set("side", "left")
    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of the outermost expression.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    CTEs found inside another CTE are inserted right before that enclosing CTE; all
    others are appended at the end. A recursive inner WITH marks the top-level WITH
    as recursive too.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        # The WITH clause that already sits on the root needs no moving
        if nested_with.parent is expression:
            continue

        # No top-level WITH yet: the first nested one found is promoted as-is
        if not top_level_with:
            top_level_with = nested_with.pop()
            expression.set("with", top_level_with)
            continue

        if nested_with.recursive:
            top_level_with.set("recursive", True)

        # Must be looked up before popping, while the node is still attached
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            merged = top_level_with.expressions
            position = merged.index(enclosing_cte)
            merged[position:position] = nested_with.expressions
        else:
            merged = top_level_with.expressions + nested_with.expressions

        top_level_with.set("expressions", merged)

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Converts numeric values used in conditions into explicit boolean expressions.

    Every node of the tree is handed to the optimizer's `canonicalize.ensure_bools`
    together with a callback that rewrites a node `x` as `x <> 0` when it is a number,
    has an unknown or numeric type (subquery predicates excluded), or is a column
    without a type.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, modified in place.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _compare_with_zero(node: exp.Expression) -> None:
        needs_comparison = node.is_number

        if not needs_comparison and not isinstance(node, exp.SubqueryPredicate):
            needs_comparison = node.is_type(
                exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
            )

        if not needs_comparison:
            needs_comparison = isinstance(node, exp.Column) and not node.type

        if needs_comparison:
            node.replace(node.neq(0))

    for current in expression.walk():
        _canonicalize_ensure_bools(current, _compare_with_zero)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """
    Strip the qualifiers from every column found in the expression.

    Args:
        expression: The expression to transform.

    Returns:
        The same expression, modified in place.
    """
    for column in expression.find_all(exp.Column):
        # Everything but the last part is a qualifier (table, db, catalog);
        # the last part is kept.
        qualifiers = column.parts[:-1]
        for qualifier in qualifiers:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """
    Remove the parent node of every `UniqueColumnConstraint` found in a CREATE expression.

    Args:
        expression: The expression to transform; must be an `exp.Create`.

    Returns:
        The same expression, modified in place.
    """
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        owner = unique.parent
        if owner:
            owner.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.Expression, tmp_storage_provider: Callable[[sqlglot.expressions.Expression], sqlglot.expressions.Expression] = <function <lambda>>) -> sqlglot.expressions.Expression:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Turn a temporary-table CTAS into a CREATE TEMPORARY VIEW statement.

    Args:
        expression: The expression to transform; must be an `exp.Create`.
        tmp_storage_provider: Called with the expression when it creates a temporary
            table that has no source query; its result is returned. Defaults to the
            identity function.

    Returns:
        A new `exp.Create` of kind "TEMPORARY VIEW" for a temporary CTAS, the result of
        `tmp_storage_provider` for a temporary table without a query, and the original
        expression otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_nodes = properties.expressions if properties else []
    is_temporary = any(isinstance(node, exp.TemporaryProperty) for node in property_nodes)

    if not (expression.kind == "TABLE" and is_temporary):
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move partition columns out of a table's schema and into its PARTITIONED BY property.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.

    Only applies to CREATE TABLE / CREATE VIEW statements that have a schema; column
    names are matched case-insensitively.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    partition_names = {value.name.upper() for value in prop.this.expressions}
    partitions = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining = [column for column in schema.expressions if column not in partitions]

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the table's schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY value is a schema made up solely of column definitions
    that have a type, those definitions are appended to the CREATE statement's schema
    and the property is reduced to a tuple of the column identifiers.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(
        isinstance(definition, exp.ColumnDef) and definition.kind
        for definition in partition_defs
    ):
        return expression

    # Build the identifier tuple first, while the definitions still belong to the property
    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(definition.this) for definition in partition_defs]
    )

    schema = expression.this
    for definition in partition_defs:
        schema.append("expressions", definition)

    prop.set("this", partition_names)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

    Every `PropertyEQ` argument of a `Struct` is replaced by an alias whose value is
    the argument's right-hand side and whose name is its left-hand side; other
    arguments are kept unchanged.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    aliased_args = []
    for arg in expression.expressions:
        if isinstance(arg, exp.PropertyEQ):
            arg = exp.alias_(arg.expression, arg.this)
        aliased_args.append(arg)

    expression.set("expressions", aliased_args)

    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Each scope that has both a WHERE clause and joins is processed independently: every
    binary predicate containing a marked column is lifted out of the WHERE clause and
    becomes (part of) the ON condition of a LEFT JOIN on the marked table.

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.

    Raises:
        AssertionError: If a marked column is not part of a binary predicate, is not
            qualified with a table, if marks appear on both sides of a predicate or on
            columns of more than one table, or if no table is left for the FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            # Columns whose mark was already cleared by an earlier predicate are skipped too
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Remember the parent before detaching, so it can be collapsed afterwards
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # Exactly one table is marked at this point; read it from the set rather
            # than from the loop variable that leaked out of the loop above.
            marked_table = next(iter(marked_column_tables))

            # The marked table is either one of the old joins or the FROM table itself
            join_this = old_joins.get(marked_table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            # The FROM table became a LEFT JOIN target, so another table has to take its
            # place. Pick the first remaining old join in its original order: taking an
            # arbitrary element of a set difference would make the generated query vary
            # between runs, because the iteration order of a set of strings is not stable.
            only_old_joins = [name for name in old_joins if name not in new_joins]
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = only_old_joins[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        # Every predicate was moved into a join, so the WHERE clause is now empty
        if not where.this:
            where.pop()

    return expression

Remove join marks from an AST. This rule assumes that all marked columns are qualified. If this does not hold for a query, consider running sqlglot.optimizer.qualify first.

For example,

    SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
    SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

Arguments:
  • expression: The AST to remove join marks from.
Returns:

The AST with join marks removed.