sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
    from sqlglot.generator import Generator


def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the incoming type so we can tell whether the transforms changed it.
        expression_type = type(expression)

        expression = transforms[0](expression)
        for transform in transforms[1:]:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method, if it has one.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql


def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        # Map each aliased projection to its 1-based position in the SELECT list.
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression


def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        # The query's ORDER BY (if any) is moved into the window; otherwise order by the
        # DISTINCT ON columns themselves.
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression


def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a generated "_c" alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # When the window is the entire QUALIFY condition there is no parent
                # to replace it in, so the filter itself becomes the alias column.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression


def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        # Only UNNESTs that sit directly in a FROM or JOIN are considered.
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                # An UNNEST with an offset maps to POSEXPLODE, otherwise to EXPLODE.
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression


def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                # Reserve the generated name so later calls can't hand it out again.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # The generated series source is attached once, on the first explode found.
                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run up to the size of the largest exploded array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest


def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression


def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                # For a set operation, the column names come from its left-hand query.
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression


def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # The copy is taken before the limit is cleared, so only the copy keeps it.
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]
            full_outer_join.set("side", "left")
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side

            return exp.union(expression, expression_copy, copy=False)

    return expression


def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the nested CTEs right before the CTE that contained them.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression


def unqualify_columns(expression: exp.Expression) -> exp.Expression:
    """Remove every qualifier from the columns in the expression, keeping only the last part."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expression) -> exp.Expression:
    """Pop the parent node of every UniqueColumnConstraint found in a CREATE expression."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Rewrite a temporary CREATE TABLE that has a query into a CREATE TEMPORARY VIEW.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with the expression when it is a temporary table
            that has no query; its result is returned. Defaults to the identity function.

    Returns:
        The transformed expression, or the original one if it is not a temporary table.
    """
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            # Column names are matched case-insensitively.
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression


def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression


def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression


def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
       The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # If the FROM table itself became a LEFT JOIN target, promote one of the
        # untouched joined tables to be the new FROM.
        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order. An empty
            sequence is allowed, in which case the expression is generated unchanged.

    Returns:
        Function that can be used as a generator transform.

    Raises:
        ValueError: (from the returned function) if the final expression has neither a "_sql"
            method nor a usable `Generator.TRANSFORMS` entry.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        # Remember the incoming type so we can tell whether the transforms changed it.
        expression_type = type(expression)

        # Iterate over the whole sequence instead of indexing `transforms[0]`, so that an
        # empty sequence is a no-op rather than an IndexError.
        for transform in transforms:
            expression = transform(expression)

        # Prefer the generator's dedicated "<key>_sql" method, if it has one.
        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS
function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite GROUP BY terms that name a SELECT alias as that projection's ordinal position.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    is_select_group = isinstance(expression, exp.Group) and isinstance(
        expression.parent, exp.Select
    )
    if not is_select_group:
        return expression

    # 1-based position of every aliased projection, keyed by alias name.
    position_by_alias = {}
    for position, projection in enumerate(expression.parent.expressions, start=1):
        if isinstance(projection, exp.Alias):
            position_by_alias[projection.alias] = position

    for term in expression.expressions:
        # Only bare (table-less) column references can be alias references.
        if not isinstance(term, exp.Column) or term.table:
            continue

        if term.name in position_by_alias:
            ordinal = position_by_alias.get(term.name)
            term.replace(exp.Literal.number(ordinal))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite SELECT DISTINCT ON (...) as a ROW_NUMBER()-filtered subquery.

    The original query gains a `ROW_NUMBER() OVER (PARTITION BY <distinct on columns> ...)`
    projection and is wrapped in a subquery named "_t"; the outer query keeps only the rows
    whose row number equals 1. This is useful for dialects that don't support SELECT DISTINCT ON
    but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not (on and isinstance(on, exp.Tuple)):
        return expression

    # Detach the DISTINCT node; its ON columns become the window's partition.
    distinct.pop()
    partition_cols = on.expressions

    # NOTE(review): the projections captured here are reused as-is for the outer query
    # (copy=False below) while also remaining in the inner query -- confirm this is intended.
    projections = expression.selects
    row_number_alias = find_new_name(expression.named_selects, "_row_number")
    row_number_window = exp.Window(this=exp.RowNumber(), partition_by=partition_cols)

    # The query's ORDER BY (if any) is moved into the window; otherwise order by the
    # DISTINCT ON columns themselves.
    query_order = expression.args.get("order")
    if query_order:
        window_order = query_order.pop()
    else:
        window_order = exp.Order(expressions=[col.copy() for col in partition_cols])
    row_number_window.set("order", window_order)

    expression.select(exp.alias_(row_number_window, row_number_alias), copy=False)

    inner = expression.subquery("_t", copy=False)
    first_row_only = exp.column(row_number_alias).eq(1)
    outer = exp.select(*projections, copy=False)
    outer = outer.from_(inner, copy=False)
    return outer.where(first_row_only, copy=False)
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.

    Args:
        expression: the expression that will be transformed.

    Returns:
        A new outer SELECT over the original query (wrapped as subquery "_t") when the input is
        a SELECT with a QUALIFY clause; otherwise the input expression unchanged.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every unnamed projection a generated "_c" alias so the outer query can refer to it.
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            # Build the outer-query reference for a projection, carrying over the
            # identifier's "quoted" flag when there is an identifier to read it from.
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        # The outer projections are computed before any helper columns are added below,
        # so those helpers stay hidden inside the subquery.
        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection, plain columns are not added as extra projections.
        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Replace alias references inside the window with the aliased expression.
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                # When the window is the entire QUALIFY condition there is no parent
                # to replace it in, so the filter itself becomes the alias column.
                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A filtered column that isn't projected must be added to the subquery.
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Drop the `DataTypeParam` arguments from every `DataType` node found in `expression`.

    Some dialects only allow the precision for parameterized types to be defined in the
    DDL and not in other expressions, hence this transform.
    """
    for data_type in expression.find_all(exp.DataType):
        kept_arguments = [
            argument
            for argument in data_type.expressions
            if not isinstance(argument, exp.DataTypeParam)
        ]
        data_type.set("expressions", kept_arguments)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """
    Strip references to unnest table aliases, added by the optimizer's qualify_columns step.

    Only UNNESTs sitting directly under a FROM or a JOIN of the given SELECT's scope are
    considered. For each column, the `table` part is cleared if it matches one of their
    aliases; otherwise the `db` part is cleared if it matches.
    """
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if not aliases:
        return expression

    for column in expression.find_all(exp.Column):
        if column.table in aliases:
            column.set("table", None)
        elif column.db in aliases:
            column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
    """
    Convert cross join unnest into lateral view explode.

    Every join of a SELECT whose target is an `exp.Unnest` is removed from the query's
    joins. For each (unnested expression, alias column) pair, an `exp.Lateral` with
    `view=True` is appended to the query's "laterals", wrapping `exp.Posexplode` when the
    UNNEST has an offset and `exp.Explode` otherwise.

    Note that an UNNEST without a table alias produces no laterals; its join is still
    removed.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: the joins list is mutated inside the loop, and removing
        # an item from a list that is being iterated makes the iterator skip the next item
        # (e.g. the second of two consecutive UNNEST joins would be left untouched).
        for join in list(expression.args.get("joins") or []):
            unnest = join.this

            if isinstance(unnest, exp.Unnest):
                alias = unnest.args.get("alias")
                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

                expression.args["joins"].remove(join)

                for e, column in zip(unnest.expressions, alias.columns if alias else []):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=udtf(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """
    Convert explode/posexplode into unnest.

    Args:
        index_offset: the start value of the generated position series that the
            unnested arrays are aligned against.

    Returns:
        A transform function that rewrites SELECT expressions in place and returns them;
        other expressions are returned unchanged.
    """

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        """Rewrite the projections of a SELECT that contain an Explode node."""
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                """Pick a name that is not in `names` and register it as taken."""
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One ArraySize expression per exploded array, used to compute the series' end
            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # The shared position series; its "end" is only filled in after the loop
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        # The first alias names the position, the second names the value
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # alias_ is called without copy=False here, so the explode node
                        # is looked up again inside the new alias
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        # IF(ARRAY_SIZE(COALESCE(arg, ARRAY())) = 0, ARRAY(bracket), arg)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    # IF(<series position> = <unnest position>, <unnested value>)
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Add the position as its own projection, right after the value
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # First explode found: add the series as a source, either cross
                        # joined to the existing FROM or as the FROM itself
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where the positions match, or where the series position
                    # is greater than `size` and the unnest position equals `size`
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series ends at the greatest of the collected array sizes
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Wrap a two-argument percentile function in a WITHIN GROUP clause.

    The function's first argument is moved into the ORDER BY of the new WITHIN GROUP node
    and its second argument becomes its only argument. Percentiles that already sit
    under a WITHIN GROUP, or that have no second argument, are returned unchanged.
    """
    if not isinstance(expression, exp.PERCENTILES):
        return expression

    if isinstance(expression.parent, exp.WithinGroup) or not expression.expression:
        return expression

    ordered_operand = expression.this.pop()
    remaining_operand = expression.expression.pop()
    expression.set("this", remaining_operand)

    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_operand)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """
    Replace `<percentile> WITHIN GROUP (ORDER BY ...)` with an `exp.ApproxQuantile` node.

    The percentile function's argument becomes the quantile, and the first ordered
    expression found under the WITHIN GROUP node becomes the input value. Any other
    expression is returned unchanged.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    if not isinstance(expression.this, exp.PERCENTILES):
        return expression

    if not isinstance(expression.expression, exp.Order):
        return expression

    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    approx_quantile = exp.ApproxQuantile(this=ordered.this, quantile=quantile)
    return expression.replace(approx_quantile)
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """
    Give the CTEs of a recursive WITH clause explicit column names when they have none.

    The names are taken from the projections of the CTE's query (for a set operation,
    from its `this` operand); projections without a name get one generated from the
    "_c_" name sequence.
    """
    if not isinstance(expression, exp.With) or not expression.recursive:
        return expression

    next_name = name_sequence("_c_")

    for cte in expression.expressions:
        table_alias = cte.args["alias"]
        if table_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            query = query.this

        column_names = [
            exp.to_identifier(projection.alias_or_name or next_name())
            for projection in query.selects
        ]
        table_alias.set("columns", column_names)

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """
    Replace the operand of `CAST('epoch' AS <temporal type>)` by a date literal.

    Applies to CAST and TRY_CAST nodes whose operand is named "epoch" (compared
    case-insensitively) and whose target type is one of `exp.DataType.TEMPORAL_TYPES`;
    the operand is replaced with the string literal '1970-01-01 00:00:00'.
    """
    if not isinstance(expression, (exp.Cast, exp.TryCast)):
        return expression

    casts_epoch = expression.name.lower() == "epoch"
    if casts_epoch and expression.to.this in exp.DataType.TEMPORAL_TYPES:
        epoch_literal = exp.Literal.string("1970-01-01 00:00:00")
        expression.this.replace(epoch_literal)

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """
    Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

    Each SEMI/ANTI join of a SELECT that has an ON condition is removed, and the filter
    `EXISTS(SELECT 1 FROM <joined source> WHERE <on condition>)` is added to the query's
    WHERE clause (negated for ANTI joins). Joins without an ON condition are left as is.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Select):
        # Iterate over a snapshot: join.pop() detaches the join from the query, and if
        # that shrinks the joins list while it's being iterated, the iterator skips the
        # next join (e.g. the second of two consecutive SEMI joins).
        for join in list(expression.args.get("joins") or []):
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a query with a FULL OUTER join as a union of two copies of it, one using a
    LEFT join and the other a RIGHT join instead.

    Only SELECTs that contain exactly one FULL join are transformed; anything else is
    returned unchanged. The LIMIT is removed from the left operand of the union and the
    WITH clause is removed from the right one.
    """
    if not isinstance(expression, exp.Select):
        return expression

    joins = expression.args.get("joins") or []
    full_join_positions = [
        position for position, join in enumerate(joins) if join.side == "FULL"
    ]

    if len(full_join_positions) != 1:
        return expression

    position = full_join_positions[0]

    # The copy must be taken before the original is modified below
    right_query = expression.copy()

    expression.set("limit", None)
    joins[position].set("side", "left")

    right_query.args["joins"][position].set("side", "right")
    right_query.args.pop("with", None)  # remove CTEs from RIGHT side

    return exp.union(expression, right_query, copy=False)
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Hoist every nested WITH clause into the WITH clause of `expression` itself.

    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    CTEs that were nested inside another CTE are inserted right before that CTE; all
    others are appended at the end.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    hoisted_with = expression.args.get("with")

    for nested_with in expression.find_all(exp.With):
        if nested_with.parent is expression:
            continue

        if not hoisted_with:
            # No top-level WITH yet: the first nested one takes its place
            hoisted_with = nested_with.pop()
            expression.set("with", hoisted_with)
            continue

        if nested_with.recursive:
            hoisted_with.set("recursive", True)

        # The enclosing CTE has to be looked up before the WITH is detached
        enclosing_cte = nested_with.find_ancestor(exp.CTE)
        nested_with.pop()

        if enclosing_cte:
            position = hoisted_with.expressions.index(enclosing_cte)
            hoisted_with.expressions[position:position] = nested_with.expressions
            merged_ctes = hoisted_with.expressions
        else:
            merged_ctes = hoisted_with.expressions + nested_with.expressions

        hoisted_with.set("expressions", merged_ctes)

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """
    Turn numeric values used in conditions into explicit boolean expressions.

    Every node of `expression` is handed to the optimizer's canonicalize `ensure_bools`
    together with a callback that replaces a node `x` by `x <> 0` when `x` is a number,
    is of an unknown or numeric type (subquery predicates excluded), or is a column
    without a type.
    """
    from sqlglot.optimizer.canonicalize import ensure_bools as _canonicalize_ensure_bools

    def _needs_comparison(node: exp.Expression) -> bool:
        if node.is_number:
            return True

        if not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            return True

        return isinstance(node, exp.Column) and not node.type

    def _ensure_bool(node: exp.Expression) -> None:
        if _needs_comparison(node):
            node.replace(node.neq(0))

    for node in expression.walk():
        _canonicalize_ensure_bools(node, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    """
    Map `CREATE TABLE` statements that carry a TEMPORARY property to temporary views.

    Args:
        expression: the CREATE expression that will be transformed.
        tmp_storage_provider: called with `expression` when it creates a temporary table
            without a defining query (i.e. it isn't a CTAS); its result is returned.
            Defaults to returning the expression unchanged.

    Returns:
        A new `CREATE TEMPORARY VIEW` expression for a temporary CTAS, the result of
        `tmp_storage_provider` for other temporary tables, and `expression` otherwise.
    """
    assert isinstance(expression, exp.Create)

    properties = expression.args.get("properties")
    property_list = properties.expressions if properties else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in property_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    if not expression.expression:
        return tmp_storage_provider(expression)

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    return exp.Create(
        kind="TEMPORARY VIEW",
        this=expression.this,
        expression=expression.expression,
    )
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    Move the partition columns of a CREATE statement from its schema into PARTITIONED BY.

    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is a list of column names rather than a schema, the schema columns
    with those names (compared case-insensitively) are removed from the create statement's
    schema and become the schema of the PARTITIONED BY property.

    Only TABLE and VIEW creations that have a schema are affected.
    """
    assert isinstance(expression, exp.Create)

    if not isinstance(expression.this, exp.Schema):
        return expression

    if expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or isinstance(prop.this, exp.Schema):
        return expression

    schema = expression.this
    partition_names = {value.name.upper() for value in prop.this.expressions}

    partition_columns = [
        column for column in schema.expressions if column.name.upper() in partition_names
    ]
    remaining_columns = [
        column for column in schema.expressions if column not in partition_columns
    ]

    schema.set("expressions", remaining_columns)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partition_columns)))
    expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Move typed PARTITIONED BY column definitions into the CREATE statement's schema.

    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
    Currently, SQLGlot uses the DATASOURCE format for Spark 3.

    When the PARTITIONED BY property holds a schema made only of column definitions that
    have a kind, those definitions are appended to the statement's schema and the property
    is reduced to a tuple of the column identifiers.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not prop or not prop.this or not isinstance(prop.this, exp.Schema):
        return expression

    partition_defs = prop.this.expressions
    for column_def in partition_defs:
        if not (isinstance(column_def, exp.ColumnDef) and column_def.kind):
            return expression

    partition_names = exp.Tuple(
        expressions=[exp.to_identifier(column_def.this) for column_def in partition_defs]
    )

    schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", partition_names)
    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite the key-value arguments of a struct as aliases, e.g. STRUCT(1 AS y).

    Each `exp.PropertyEQ` argument is replaced by its value aliased to its key; all other
    arguments are kept as they are.
    """
    if not isinstance(expression, exp.Struct):
        return expression

    rewritten = []
    for argument in expression.expressions:
        if isinstance(argument, exp.PropertyEQ):
            argument = exp.alias_(argument.expression, argument.this)
        rewritten.append(argument)

    expression.set("expressions", rewritten)
    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Scopes in which no column is marked with (+) are left untouched, and joins that aren't
    rewritten into LEFT joins are preserved after the rewritten ones.

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.

    Raises:
        AssertionError: if a marked column isn't part of a binary predicate or isn't
            qualified with a table, if both sides of a predicate are marked, if the marked
            columns of a predicate belong to more than one table, or if no table is left
            to be used in the new FROM clause.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            # Marks are cleared as predicates are processed, so a predicate is handled once
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            # Detach the predicate from the WHERE clause; it becomes the join's ON condition
            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            # `col` is the last marked column visited above; all of them share one table.
            # If that table isn't among the joins, the FROM clause's source is used.
            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        # Only touch FROM and the joins if something was marked in this scope, otherwise
        # the query's existing joins would be replaced by an empty list
        if new_joins:
            if query_from.alias_or_name in new_joins:
                # The FROM source became the target of a LEFT join, so another one is needed
                only_old_joins = old_joins.keys() - new_joins.keys()
                assert (
                    len(only_old_joins) >= 1
                ), "Cannot determine which table to use in the new FROM clause"

                new_from_name = list(only_old_joins)[0]
                query.set("from", exp.From(this=old_joins[new_from_name].this))

            # Preserve the joins that weren't rewritten and didn't become the FROM source
            from_name = query.args["from"].alias_or_name
            for old_name, old_join in old_joins.items():
                if old_name not in new_joins and old_name != from_name:
                    new_joins[old_name] = old_join

            query.set("joins", list(new_joins.values()))

        # All of the WHERE clause's predicates may have been moved into joins
        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify
first.
For example, SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.