#!/usr/bin/env python
# -*- coding: utf-8 -*- --------------------------------------------------===#
#
# Copyright 2022-2024 Trovares Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#===----------------------------------------------------------------------===#
import struct
import sys
import xgt
import pyarrow as pa
import pyarrow.flight as pf
from collections.abc import Iterable, Mapping, Sequence
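# Short aliases for the abstract container types used in the annotations below.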
Iter, Map, Seq, List, Dict = Iterable, Mapping, Sequence, list, dict
from arrow_odbc import read_arrow_batches_from_odbc
from arrow_odbc import insert_into_table
from typing import Optional, Union, TYPE_CHECKING
from xgt import SchemaMessages_pb2 as sch_proto
from .common import ProgressDisplay
# Convert the pyarrow type to an xgt type.
def _pyarrow_type_to_xgt_type(pyarrow_type):
if pa.types.is_boolean(pyarrow_type):
return xgt.BOOLEAN
elif pa.types.is_timestamp(pyarrow_type) or pa.types.is_date64(pyarrow_type):
return xgt.DATETIME
elif pa.types.is_date(pyarrow_type):
return xgt.DATE
elif pa.types.is_time(pyarrow_type):
return xgt.TIME
elif pa.types.is_integer(pyarrow_type):
return xgt.INT
elif pa.types.is_float32(pyarrow_type) or \
pa.types.is_float64(pyarrow_type) or \
pa.types.is_decimal(pyarrow_type):
return xgt.FLOAT
elif pa.types.is_string(pyarrow_type):
return xgt.TEXT
else:
raise xgt.XgtTypeError("Cannot convert pyarrow type " + str(pyarrow_type) + " to xGT type.")
def _infer_xgt_schema_from_pyarrow_schema(pyarrow_schema, conversions):
schema = []
for field in pyarrow_schema:
if field.type in conversions:
field = pa.field(field.name, conversions[field.type], field.nullable, field.metadata)
schema.append(field)
schema = pa.schema(schema)
return [[c.name, _pyarrow_type_to_xgt_type(c.type)] for c in schema]
class SQLODBCDriver(object):
def __init__(self, connection_string : str):
"""
Initializes the driver class.
Parameters
----------
connection_string : str
Standard ODBC connection string used for connecting to the ODBC application.
Example:
'Driver={MariaDB};Server=127.0.0.1;Port=3306;Database=test;Uid=test;Pwd=foo;'
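Examples
--------
A minimal sketch, assuming the class is imported from the trovares_connector
package and a MariaDB instance is reachable with the connection string below:
>>> from trovares_connector import SQLODBCDriver
>>> driver = SQLODBCDriver('Driver={MariaDB};Server=127.0.0.1;Port=3306;'
...                        'Database=test;Uid=test;Pwd=foo;')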
"""
self._connection_string = connection_string
self._schema_query = "SELECT * FROM {0} LIMIT 1;"
self._data_query = "SELECT * FROM {0};"
self._estimate_query="SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{0}';"
def _get_data_query(self, table, arrow_schema):
return self._data_query.format(table)
def _conversions(self):
return { }
def _get_record_batch_schema(self, table, max_text_size, max_binary_size):
reader = read_arrow_batches_from_odbc(
query=self._schema_query.format(table),
connection_string=self._connection_string,
batch_size=1,
max_text_size=max_text_size,
max_binary_size=max_binary_size,
)
return reader.schema
class MongoODBCDriver(object):
def __init__(self, connection_string : str , include_id : bool = False):
"""
Initializes the driver class.
Parameters
----------
connection_string : str
Standard ODBC connection string used for connecting to MongoDB.
Example:
'DSN=MongoDB;Database=test;Uid=test;Pwd=foo;'
include_id : boolean
Include the MongoDB id field when transferring from MongoDB.
If the id field is included, writing data back to the database will update the columns
instead of inserting new rows.
By default false.
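Examples
--------
A sketch, assuming a MongoDB ODBC data source named 'MongoDB'; keeping the id
column lets later writes update existing documents instead of inserting new rows:
>>> from trovares_connector import MongoODBCDriver
>>> driver = MongoODBCDriver('DSN=MongoDB;Database=test;Uid=test;Pwd=foo;',
...                          include_id=True)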
"""
self._connection_string = connection_string
self._schema_query = "SELECT * FROM {0} LIMIT 1;"
self._data_query = "SELECT {0} FROM {1};"
self._estimate_query = "SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{0}';"
self._include_id = include_id
def _get_data_query(self, table, arrow_schema):
cols = ','.join([x.name for x in arrow_schema])
return self._data_query.format(cols, table)
def _conversions(self):
return { }
def _get_record_batch_schema(self, table, max_text_size, max_binary_size):
reader = read_arrow_batches_from_odbc(
query=self._schema_query.format(table),
connection_string=self._connection_string,
batch_size=1,
max_text_size=max_text_size,
max_binary_size=max_binary_size,
)
schema = reader.schema
if not self._include_id:
# Remove the _id column.
return pa.schema([field for field in schema if field.name != '_id'])
return schema
class OracleODBCDriver(object):
def __init__(self, connection_string : str, upper_case_names : bool = False, ansi_conversion : bool = True):
"""
Initializes the driver class.
Parameters
----------
connection_string : str
Standard ODBC connection string used for connecting to Oracle.
Example:
'DSN={OracleODBC-19};Server=127.0.0.1;Port=1521;Uid=c##test;Pwd=test;DBQ=XE;'
upper_case_names : bool
Convert table names to uppercase, similar to the behavior of unquoted names in Oracle queries.
By default false.
ansi_conversion : bool
Convert NUMBER(38,0) columns into int64 values in xGT if true. Otherwise, they are stored as floats.
This reverses the ANSI integer conversion that Oracle applies. By default true.
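Examples
--------
A sketch, assuming an Oracle data source named 'OracleODBC-19'; disabling
ansi_conversion keeps NUMBER(38,0) columns as floats in xGT:
>>> from trovares_connector import OracleODBCDriver
>>> driver = OracleODBCDriver('DSN={OracleODBC-19};Server=127.0.0.1;Port=1521;'
...                           'Uid=c##test;Pwd=test;DBQ=XE;',
...                           ansi_conversion=False)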
"""
self._connection_string = connection_string
if upper_case_names:
self._schema_query = "SELECT * FROM {0} WHERE ROWNUM <= 1"
self._data_query = "SELECT * FROM {0}"
else:
self._schema_query = "SELECT * FROM \"{0}\" WHERE ROWNUM <= 1"
self._data_query = "SELECT * FROM \"{0}\""
self._estimate_query="SELECT NUM_ROWS FROM ALL_TABLES WHERE TABLE_NAME = '{0}'"
self._ansi_conversion = ansi_conversion
def _get_data_query(self, table, arrow_schema):
return self._data_query.format(table)
def _conversions(self):
if self._ansi_conversion:
return { pa.decimal128(38, 0) : pa.int64() }
else:
return { }
def _get_record_batch_schema(self, table, max_text_size, max_binary_size):
reader = read_arrow_batches_from_odbc(
query=self._schema_query.format(table),
connection_string=self._connection_string,
batch_size=1,
max_text_size=max_text_size,
max_binary_size=max_binary_size,
)
return reader.schema
class SAPODBCDriver(object):
def __init__(self, connection_string : str):
"""
Initializes the driver class.
Parameters
----------
connection_string : str
Standard ODBC connection string used for connecting to the ODBC application.
Example:
'Driver={AES};Server=127.0.0.1;Port=3306;Database=test;Uid=test;Pwd=foo;'
"""
self._connection_string = connection_string
self._schema_query = "SELECT TOP 1 * FROM {0};"
self._data_query = "SELECT * FROM {0};"
self._estimate_query="SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{0}';"
def _get_data_query(self, table, arrow_schema):
return self._data_query.format(table)
def _conversions(self):
return { }
def _get_record_batch_schema(self, table, max_text_size, max_binary_size):
reader = read_arrow_batches_from_odbc(
query=self._schema_query.format(table),
connection_string=self._connection_string,
batch_size=1,
max_text_size=max_text_size,
max_binary_size=max_binary_size,
)
return reader.schema
class SnowflakeODBCDriver(object):
def __init__(self, connection_string : str, ansi_conversion : bool = True):
"""
Initializes the driver class.
Parameters
----------
connection_string : str
Standard ODBC connection string used for connecting to Snowflake.
Example:
'DSN=snowflake;Database=test;Warehouse=test;Uid=test;Pwd=test;'
ansi_conversion : bool
Convert NUMBER(38,0) columns into int64 values in xGT if true. Otherwise, they are stored as floats.
This reverses the ANSI integer conversion that Snowflake applies. By default true.
"""
self._connection_string = connection_string
self._schema_query = "SELECT * FROM {0} LIMIT 1;"
self._data_query = "SELECT * FROM {0};"
self._estimate_query="SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{0}';"
self._ansi_conversion = ansi_conversion
def _get_data_query(self, table, arrow_schema):
return self._data_query.format(table)
def _conversions(self):
if self._ansi_conversion:
return { pa.decimal128(38, 0) : pa.int64() }
else:
return { }
def _get_record_batch_schema(self, table, max_text_size, max_binary_size):
reader = read_arrow_batches_from_odbc(
query=self._schema_query.format(table),
connection_string=self._connection_string,
batch_size=1,
max_text_size=max_text_size,
max_binary_size=max_binary_size,
)
return reader.schema
ODBCDriverTypes = Union[SQLODBCDriver, MongoODBCDriver, OracleODBCDriver,
SAPODBCDriver, SnowflakeODBCDriver]
class ODBCConnector(object):
def __init__(self, xgt_server : xgt.Connection, odbc_driver : ODBCDriverTypes):
"""
Initializes the connector class.
Parameters
----------
xgt_server : xgt.Connection
Connection object to xGT.
odbc_driver : SQLODBCDriver, MongoODBCDriver, OracleODBCDriver, SAPODBCDriver, or SnowflakeODBCDriver
Driver object describing the ODBC connection to the database.
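Examples
--------
A sketch, assuming an xGT server running locally and the MariaDB connection
string shown for SQLODBCDriver above:
>>> import xgt
>>> from trovares_connector import ODBCConnector, SQLODBCDriver
>>> driver = SQLODBCDriver('Driver={MariaDB};Server=127.0.0.1;Port=3306;'
...                        'Database=test;Uid=test;Pwd=foo;')
>>> connector = ODBCConnector(xgt.Connection(), driver)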
"""
self._xgt_server = xgt_server
self._default_namespace = xgt_server.get_default_namespace()
self._driver = odbc_driver
def get_xgt_schemas(self, tables : Iter[str] = None, max_text_size : int = None,
max_binary_size : int = None) -> Dict:
"""
Retrieve a dictionary containing the schema information for all of
the tables requested and their mappings.
Parameters
----------
tables : iterable
List of requested tables.
max_text_size : int
The upper limit on the buffers used when transferring ODBC variable-length text fields.
When using VARCHAR from a database without a declared length limit such as VARCHAR(255),
the size reported to ODBC for each string entry can be whatever maximum the database
allows. For instance, each string in Snowflake has an upper limit of 16MB, so the
buffers allocated for an ODBC batch would be 16MB multiplied by the batch_size.
This parameter imposes a limit on each string length when transferring.
Default is determined by the database.
max_binary_size : int
The upper limit on the buffers used when transferring ODBC variable-length binary fields.
When using VARBINARY from a database without a declared length limit such as VARBINARY(255),
the size reported to ODBC for each binary entry can be whatever maximum the database
allows. This parameter imposes a limit on each binary field length when transferring.
Default is determined by the database.
Returns
-------
dict
Dictionary containing the schema information of the tables,
vertices, and edges requested.
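Examples
--------
A sketch, assuming a connector built as in the class example above and a table
named 'Person' in the connected database:
>>> schemas = connector.get_xgt_schemas(['Person'])
>>> list(schemas.keys())
['vertices', 'edges', 'tables']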
"""
if tables is None:
tables = [ ]
mapping_vertices = { }
mapping_edges = { }
mapping_tables = { }
result = {'vertices' : dict(), 'edges' : dict(), 'tables' : dict()}
for val in tables:
self.__get_mapping(val, mapping_tables, mapping_vertices, mapping_edges)
for table in mapping_tables:
schema = self.__extract_xgt_table_schema(table, mapping_tables, max_text_size, max_binary_size)
result['tables'][table] = schema
for table in mapping_vertices:
schema = self.__extract_xgt_table_schema(table, mapping_vertices, max_text_size, max_binary_size)
result['vertices'][table] = schema
for table in mapping_edges:
schema = self.__extract_xgt_table_schema(table, mapping_edges, max_text_size, max_binary_size)
result['edges'][table] = schema
return result
def create_xgt_schemas(self, xgt_schemas : Map, append : bool = False,
force : bool = False, easy_edges : bool = False) -> None:
"""
Creates table, vertex and/or edge frames in Rocketgraph xGT.
This function first infers the schemas for all of the needed frames in xGT to
store the requested data.
Then those frames are created in xGT.
Parameters
----------
xgt_schemas : dict
Dictionary containing schema information for table, vertex, and edge frames
to create in xGT.
This dictionary can be the value returned from the
:py:meth:`~ODBCConnector.get_xgt_schemas` method.
append : boolean
Set to true when the xGT frames are already created and holding data
that should be appended to.
Set to false when the xGT frames are to be newly created (removing
any existing frames with the same names prior to creation).
force : boolean
Set to true to force xGT to drop edges when a vertex frame has dependencies.
easy_edges : boolean
Set to true to create a basic vertex frame with a key column for any edges
without corresponding vertex frames.
Returns
-------
None
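Examples
--------
A sketch of the usual two-step flow, assuming a table named 'Person' exists in
the connected database and a connector built as in the class example above:
>>> schemas = connector.get_xgt_schemas(['Person'])
>>> connector.create_xgt_schemas(schemas, force=True)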
"""
if not append:
if easy_edges:
for edge, schema in xgt_schemas['edges'].items():
src = schema['mapping']['source']
trg = schema['mapping']['target']
if src not in xgt_schemas['vertices']:
v_key = schema['mapping']['source_key']
v_type = 'int'
for element in schema['xgt_schema']:
if v_key == element[0]:
v_type = element[1]
xgt_schemas['vertices'][src] = { 'xgt_schema': [['key', v_type]], 'temp_creation' : True, 'mapping' : { 'frame' : src, 'key' : 'key' } }
if trg not in xgt_schemas['vertices']:
v_key = schema['mapping']['target_key']
v_type = 'int'
for element in schema['xgt_schema']:
if v_key == element[0]:
v_type = element[1]
xgt_schemas['vertices'][trg] = { 'xgt_schema': [['key', v_type]], 'temp_creation' : True, 'mapping' : { 'frame' : trg, 'key' : 'key' } }
for _, schema in xgt_schemas['tables'].items():
self._xgt_server.drop_frame(schema['mapping']['frame'])
for _, schema in xgt_schemas['edges'].items():
self._xgt_server.drop_frame(schema['mapping']['frame'])
for _, schema in xgt_schemas['vertices'].items():
try:
self._xgt_server.drop_frame(schema['mapping']['frame'])
except xgt.XgtFrameDependencyError as e:
if force:
# Workaround: recover the dependent edge frame names by parsing the error message.
edge_frames = str(e).split(':')[-1].split(' ')[1:]
for edge in edge_frames:
self._xgt_server.drop_frame(edge)
self._xgt_server.drop_frame(schema['mapping']['frame'])
else:
raise e
for table, schema in xgt_schemas['tables'].items():
self._xgt_server.create_table_frame(name = schema['mapping']['frame'], schema = schema['xgt_schema'])
remove_list = []
for vertex, schema in xgt_schemas['vertices'].items():
key = schema['mapping']['key']
if isinstance(key, int):
key = schema['xgt_schema'][key][0]
self._xgt_server.create_vertex_frame(name = schema['mapping']['frame'], schema = schema['xgt_schema'], key = key)
if 'temp_creation' in schema:
remove_list.append(vertex)
for vertex in remove_list:
xgt_schemas['vertices'].pop(vertex)
for edge, schema in xgt_schemas['edges'].items():
src = schema['mapping']['source']
trg = schema['mapping']['target']
src_key = schema['mapping']['source_key']
trg_key = schema['mapping']['target_key']
if isinstance(src_key, int):
src_key = schema['xgt_schema'][src_key][0]
if isinstance(trg_key, int):
trg_key = schema['xgt_schema'][trg_key][0]
self._xgt_server.create_edge_frame(name = schema['mapping']['frame'], schema = schema['xgt_schema'],
source = src, target = trg, source_key = src_key, target_key = trg_key)
def transfer_to_xgt(self, tables : Iter = None, append : bool = False, force : bool = False,
easy_edges : bool = False, batch_size : int = 10000, transaction_size : int = 0,
max_text_size : int = None, max_binary_size : int = None,
column_mapping : Optional[Map[str, Union[str, int]]] = None,
suppress_errors : bool = False, row_filter : str = None, on_duplicate_keys : str = "error") -> None:
"""
Copies data from the ODBC application to Rocketgraph xGT.
This function first infers the schemas for all of the needed frames in xGT to
store the requested data.
Then those frames are created in xGT.
Finally, all of the tables, vertices, and edges are copied,
one frame at a time, from the ODBC application to xGT.
Parameters
----------
tables : Iterable
List of requested table names.
May be a tuple specifying a mapping to xGT types. See documentation: :ref:`mapping-sql-label` or `Web Docs <https://trovares.github.io/trovares_connector/odbc/index.html#mapping-sql-tables-to-graphs>`_.
append : boolean
Set to true when the xGT frames are already created and holding data
that should be appended to.
Set to false when the xGT frames are to be newly created (removing
any existing frames with the same names prior to creation).
force : boolean
Set to true to force xGT to drop edges when a vertex frame has dependencies.
easy_edges : boolean
Set to true to create a basic vertex frame with a key column for any edges
without corresponding vertex frames.
batch_size : int
Number of rows to transfer at once. Defaults to 10000.
transaction_size : int
Number of rows to treat as a single transaction to xGT. Defaults to 0.
Must be a multiple of the batch size and greater than or equal to the batch size.
0 means treat all rows as a single transaction.
max_text_size : int
The upper limit on the buffers used when transferring ODBC variable-length text fields.
When using VARCHAR from a database without a declared length limit such as VARCHAR(255),
the size reported to ODBC for each string entry can be whatever maximum the database
allows. For instance, each string in Snowflake has an upper limit of 16MB, so the
buffers allocated for an ODBC batch would be 16MB multiplied by the batch_size.
This parameter imposes a limit on each string length when transferring.
Default is determined by the database.
max_binary_size : int
The upper limit on the buffers used when transferring ODBC variable-length binary fields.
When using VARBINARY from a database without a declared length limit such as VARBINARY(255),
the size reported to ODBC for each binary entry can be whatever maximum the database
allows. This parameter imposes a limit on each binary field length when transferring.
Default is determined by the database.
column_mapping : dictionary
Maps the frame column names to SQL columns for the ingest. The key of
each element is a frame column name. The value is either the name of the
SQL column (from the table) or the table column index.
suppress_errors : bool
If true, will continue to insert data if an ingest error is encountered,
placing the first 1000 errors in the job history. If false, stops on
first error and raises. Defaults to False.
row_filter : str
TQL fragment used to filter, modify and parameterize the raw data from
the input to produce the row data fed to the frame.
on_duplicate_keys : {'error', 'skip', 'skip_same'}, default 'error'
Specifies what to do upon encountering a duplicate vertex key.
Only works for vertex frames. Is ignored for table and edge frames.
Allowed values are :
- 'error', raise an Exception when a duplicate key is found.
- 'skip', skip duplicate keys without raising.
- 'skip_same', skip duplicate keys if the row is exactly the same without raising.
Returns
-------
None
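Examples
--------
A sketch, assuming tables named 'Person' and 'Knows' in the connected database;
the column names 'id', 'src', and 'trg' are illustrative. The one-element tuple
maps 'Person' to a vertex frame keyed on 'id', and the four-element tuple maps
'Knows' to an edge frame between 'Person' vertices:
>>> connector.transfer_to_xgt([
...     ('Person', ('id',)),
...     ('Knows', ('Person', 'Person', 'src', 'trg'))])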
"""
if transaction_size > 0 and (transaction_size < batch_size or transaction_size % batch_size != 0):
raise ValueError("Transaction size needs to be a multiple of the batch size and >= the batch size of " + str(batch_size))
xgt_schema = self.get_xgt_schemas(tables, max_text_size, max_binary_size)
self.create_xgt_schemas(xgt_schema, append, force, easy_edges)
self.copy_data_to_xgt(xgt_schema, batch_size, transaction_size,
max_text_size, max_binary_size, column_mapping,
suppress_errors, row_filter, on_duplicate_keys)
def transfer_query_to_xgt(self, query : str = None, mapping : Union[Map, tuple] = None, append : bool = False,
force : bool = False, easy_edges : bool = False, batch_size : int = 10000,
transaction_size : int = 0, max_text_size : int = None,
max_binary_size : int = None, column_mapping : Optional[Map[str, Union[str, int]]] = None,
suppress_errors : bool = False,
row_filter : str = None, on_duplicate_keys : str = "error") -> tuple:
"""
Copies data from the ODBC application to Rocketgraph xGT.
This function first infers the schemas for the query.
Then it maps the result to the types specified in mapping.
Finally, the data is copied from the ODBC application to xGT.
Parameters
----------
query : string
SQL query to execute and insert into xGT. The syntax depends on the database you are connecting to.
mapping :
May be a tuple or dictionary specifying a mapping to xGT types. See documentation: :ref:`mapping-sql-label` or `Web Docs <https://trovares.github.io/trovares_connector/odbc/index.html#mapping-sql-tables-to-graphs>`_.
append : boolean
Set to true when the xGT frames are already created and holding data
that should be appended to.
Set to false when the xGT frames are to be newly created (removing
any existing frames with the same names prior to creation).
force : boolean
Set to true to force xGT to drop edges when a vertex frame has dependencies.
easy_edges : boolean
Set to true to create a basic vertex frame with a key column for any edges
without corresponding vertex frames.
batch_size : int
Number of rows to transfer at once. Defaults to 10000.
transaction_size : int
Number of rows to treat as a single transaction to xGT. Defaults to 0.
Must be a multiple of the batch size and greater than or equal to the batch size.
0 means treat all rows as a single transaction.
max_text_size : int
The upper limit on the buffers used when transferring ODBC variable-length text fields.
When using VARCHAR from a database without a declared length limit such as VARCHAR(255),
the size reported to ODBC for each string entry can be whatever maximum the database
allows. For instance, each string in Snowflake has an upper limit of 16MB, so the
buffers allocated for an ODBC batch would be 16MB multiplied by the batch_size.
This parameter imposes a limit on each string length when transferring.
Default is determined by the database.
max_binary_size : int
The upper limit on the buffers used when transferring ODBC variable-length binary fields.
When using VARBINARY from a database without a declared length limit such as VARBINARY(255),
the size reported to ODBC for each binary entry can be whatever maximum the database
allows. This parameter imposes a limit on each binary field length when transferring.
Default is determined by the database.
column_mapping : dictionary
Maps the frame column names to SQL columns for the ingest. The key of
each element is a frame column name. The value is either the name of the
SQL column (from the table) or the table column index.
suppress_errors : bool
If true, will continue to insert data if an ingest error is encountered,
placing the first 1000 errors in the job history. If false, stops on
first error and raises. Defaults to False.
row_filter : str
TQL fragment used to filter, modify and parameterize the raw data from
the input to produce the row data fed to the frame.
on_duplicate_keys : {'error', 'skip', 'skip_same'}, default 'error'
Specifies what to do upon encountering a duplicate vertex key.
Only works for vertex frames. Is ignored for table and edge frames.
Allowed values are :
- 'error', raise an Exception when a duplicate key is found.
- 'skip', skip duplicate keys without raising.
- 'skip_same', skip duplicate keys if the row is exactly the same without raising.
Returns
-------
tuple
A pair of (row count, byte count) describing the data transferred.
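Examples
--------
A sketch, assuming a table named 'Person' with an 'id' column; the query result
is loaded into a vertex frame named 'People' keyed on 'id':
>>> rows, nbytes = connector.transfer_query_to_xgt(
...     'SELECT * FROM Person', mapping=('People', ('id',)))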
"""
if transaction_size > 0 and (transaction_size < batch_size or transaction_size % batch_size != 0):
raise ValueError("Transaction size needs to be a multiple of batch size and >= the batch size of " + str(batch_size))
return self.__copy_query_data_to_xgt(query, mapping, append, force, easy_edges,
batch_size, transaction_size, max_text_size, max_binary_size,
column_mapping, suppress_errors, row_filter, on_duplicate_keys)
def copy_data_to_xgt(self, xgt_schemas : Map, batch_size : int = 10000, transaction_size : int = 0,
max_text_size : int = None, max_binary_size : int = None,
column_mapping : Optional[Map[str, Union[str, int]]] = None,
suppress_errors : bool = False, row_filter : str = None,
on_duplicate_keys : str = "error") -> None:
"""
Copies data from the ODBC application to the requested table, vertex and/or edge frames
in Rocketgraph xGT.
This function copies data from the ODBC application to xGT for all of the tables, vertices
and edges, one frame at a time.
Parameters
----------
xgt_schemas : dict
Dictionary containing schema information for table, vertex and edge frames
to create in xGT.
This dictionary can be the value returned from the
:py:meth:`~ODBCConnector.get_xgt_schemas` method.
batch_size : int
Number of rows to transfer at once. Defaults to 10000.
transaction_size : int
Number of rows to treat as a single transaction to xGT. Defaults to 0.
Must be a multiple of the batch size and greater than or equal to the batch size.
0 means treat all rows as a single transaction.
max_text_size : int
The upper limit on the buffers used when transferring ODBC variable-length text fields.
When using VARCHAR from a database without a declared length limit such as VARCHAR(255),
the size reported to ODBC for each string entry can be whatever maximum the database
allows. For instance, each string in Snowflake has an upper limit of 16MB, so the
buffers allocated for an ODBC batch would be 16MB multiplied by the batch_size.
This parameter imposes a limit on each string length when transferring.
Default is determined by the database.
max_binary_size : int
The upper limit on the buffers used when transferring ODBC variable-length binary fields.
When using VARBINARY from a database without a declared length limit such as VARBINARY(255),
the size reported to ODBC for each binary entry can be whatever maximum the database
allows. This parameter imposes a limit on each binary field length when transferring.
Default is determined by the database.
column_mapping : dictionary
Maps the frame column names to SQL columns for the ingest. The key of
each element is a frame column name. The value is either the name of the
SQL column (from the table) or the table column index.
suppress_errors : bool
If true, will continue to insert data if an ingest error is encountered,
placing the first 1000 errors in the job history. If false, stops on
first error and raises. Defaults to False.
row_filter : str
TQL fragment used to filter, modify and parameterize the raw data from
the input to produce the row data fed to the frame.
on_duplicate_keys : {'error', 'skip', 'skip_same'}, default 'error'
Specifies what to do upon encountering a duplicate vertex key.
Only works for vertex frames. Is ignored for table and edge frames.
Allowed values are :
- 'error', raise an Exception when a duplicate key is found.
- 'skip', skip duplicate keys without raising.
- 'skip_same', skip duplicate keys if the row is exactly the same without raising.
Returns
-------
None
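Examples
--------
A sketch of the full three-step flow this method completes, assuming a table
named 'Person' in the connected database:
>>> schemas = connector.get_xgt_schemas(['Person'])
>>> connector.create_xgt_schemas(schemas)
>>> connector.copy_data_to_xgt(schemas)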
"""
estimate = 0
def estimate_size(table):
estimate = 0
reader = read_arrow_batches_from_odbc(
query=self._driver._estimate_query.format(table),
connection_string=self._driver._connection_string,
batch_size=batch_size,
max_text_size=max_text_size,
max_binary_size=max_binary_size,
)
for batch in reader:
for _, row in batch.to_pydict().items():
for item in row:
if isinstance(item, int):
estimate += item
return estimate
try:
for table, schema in xgt_schemas['tables'].items():
estimate += estimate_size(table)
for table, schema in xgt_schemas['vertices'].items():
estimate += estimate_size(table)
for table, schema in xgt_schemas['edges'].items():
estimate += estimate_size(table)
except Exception:
# The row-count estimate only drives the progress bar, so failures are ignored.
pass
with ProgressDisplay(estimate) as progress_bar:
for table, schema in xgt_schemas['tables'].items():
self.__copy_data(self._driver._get_data_query(
table, schema['arrow_schema']), schema['mapping']['frame'],
schema['arrow_schema'], progress_bar, batch_size,
transaction_size, max_text_size, max_binary_size,
column_mapping, suppress_errors, row_filter,
on_duplicate_keys)
for table, schema in xgt_schemas['vertices'].items():
self.__copy_data(self._driver._get_data_query(
table, schema['arrow_schema']), schema['mapping']['frame'],
schema['arrow_schema'], progress_bar, batch_size,
transaction_size, max_text_size, max_binary_size,
column_mapping, suppress_errors, row_filter,
on_duplicate_keys)
for table, schema in xgt_schemas['edges'].items():
self.__copy_data(self._driver._get_data_query(
table, schema['arrow_schema']), schema['mapping']['frame'],
schema['arrow_schema'], progress_bar, batch_size,
transaction_size, max_text_size, max_binary_size,
column_mapping, suppress_errors, row_filter,
on_duplicate_keys)
def transfer_to_odbc(self, vertices : Iter[str] = None,
edges : Iter[str] = None,
tables : Iter[str] = None, namespace : str = None,
batch_size : int = 10000) -> None:
"""
Copies data from Rocketgraph xGT to an ODBC application.
Parameters
----------
vertices : iterable
List of requested vertex frame names.
May be a tuple specifying: (xgt_frame_name, database_table_name).
edges : iterable
List of requested edge frame names.
May be a tuple specifying: (xgt_frame_name, database_table_name).
tables : iterable
List of requested table frame names.
May be a tuple specifying: (xgt_frame_name, database_table_name).
namespace : str
Namespace for the selected frames.
If None, the default namespace is used.
batch_size : int
Number of rows to transfer at once. Defaults to 10000.
Returns
-------
None
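Examples
--------
A sketch, assuming a vertex frame named 'Person' exists in xGT and a matching
'Person' table has already been created in the target database:
>>> connector.transfer_to_odbc(vertices=['Person'])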
"""
if isinstance(self._driver, OracleODBCDriver):
raise xgt.XgtNotImplementedError("Transferring to Oracle is not supported.")
xgt_server = self._xgt_server
if namespace is None:
namespace = self._default_namespace
if vertices is None and edges is None and tables is None:
vertices = [(frame.name, frame.name) for frame in xgt_server.get_frames(namespace=namespace, frame_type='vertex')]
edges = [(frame.name, frame.name) for frame in xgt_server.get_frames(namespace=namespace, frame_type='edge')]
tables = [(frame.name, frame.name) for frame in xgt_server.get_frames(namespace=namespace, frame_type='table')]
namespace = None
if vertices is None:
vertices = []
if edges is None:
edges = []
if tables is None:
tables = []
final_vertices = []
final_edges = []
final_tables = []
for vertex in vertices:
if isinstance(vertex, str):
final_vertices.append((vertex, vertex))
else:
final_vertices.append(vertex)
for edge in edges:
if isinstance(edge, str):
final_edges.append((edge, edge))
else:
final_edges.append(edge)
for table in tables:
if isinstance(table, str):
final_tables.append((table, table))
else:
final_tables.append(table)
estimate = 0
for vertex in final_vertices:
estimate += xgt_server.get_frame(vertex[0]).num_rows
for edge in final_edges:
estimate += xgt_server.get_frame(edge[0]).num_rows
for table in final_tables:
estimate += xgt_server.get_frame(table[0]).num_rows
with ProgressDisplay(estimate) as progress_bar:
for table in final_vertices + final_edges + final_tables:
frame, table = table
reader = self.__arrow_reader(frame)
batch_reader = reader.to_reader()
_, target_schema = self.__get_xgt_schema(table)
schema = reader.schema
final_schema = [xgt_field.with_name(database_field.name) for database_field, xgt_field in zip(target_schema, schema)]
final_schema = pa.schema(final_schema)
schema = final_schema
final_names = [database_field.name for database_field in target_schema]
def iter_record_batches():
for batch in batch_reader:
table = pa.Table.from_pandas(batch.to_pandas(integer_object_nulls=True, date_as_object=True, timestamp_as_object=True))
table = table.rename_columns(final_names).to_batches()
for batch in table:
yield batch
progress_bar.show_progress(batch.num_rows)
final_reader = pa.ipc.RecordBatchReader.from_batches(schema, iter_record_batches())
insert_into_table(
connection_string=self._driver._connection_string,
chunk_size=batch_size,
table=table,
reader=final_reader,
)
def __build_flight_path(self, frame_name, column_mapping = None,
suppress_errors = False, row_filter = None,
on_duplicate_keys = 'error'):
if '__' in frame_name:
# Split by '__' and use the first part as the namespace
namespace, name = frame_name.split('__', 1)
else:
# Use the default namespace if no '__' is found
namespace, name = self._default_namespace, frame_name
path = (namespace, name)
self.__validate_column_mapping(column_mapping)
if row_filter is None and column_mapping is not None:
map_values = ".map_column_names=[" + \
','.join(f"{key}:{value}" for key, value in column_mapping.items()
if isinstance(value, str)) + "]"
path += (map_values,)
map_values = ".map_column_ids=[" + \
','.join(f"{key}:{value}" for key, value in column_mapping.items()
if isinstance(value, int)) + "]"
path += (map_values,)
suppress_errors_option = ".suppress_errors=" + str(suppress_errors).lower()
on_duplicate_keys_option = ".on_duplicate_keys=" + str(on_duplicate_keys).lower()
path += (suppress_errors_option,)
path += (on_duplicate_keys_option,)
if row_filter is not None:
row_filter_value = f'.row_filter="{row_filter}"'
path += (row_filter_value,)
return path
def __arrow_writer(self, frame_name, schema, column_mapping, suppress_errors, row_filter, on_duplicate_keys):
arrow_conn = self._xgt_server.arrow_conn
flight_path = self.__build_flight_path(frame_name, column_mapping, suppress_errors, row_filter, on_duplicate_keys)
writer, metadata = arrow_conn.do_put(
pf.FlightDescriptor.for_path(*flight_path),
schema)
return (writer, metadata)
def __arrow_reader(self, frame_name):
arrow_conn = self._xgt_server.arrow_conn
return arrow_conn.do_get(pf.Ticket(self._default_namespace + '__' + frame_name))
def __copy_data(self, query_for_extract, frame, schema, progress_bar, batch_size,
transaction_size, max_text_size, max_binary_size, column_mapping,
suppress_errors, row_filter, on_duplicate_keys):
reader = read_arrow_batches_from_odbc(
query=query_for_extract,
connection_string=self._driver._connection_string,
batch_size=batch_size,
max_text_size=max_text_size,
max_binary_size=max_binary_size,
)
count = 0
writer, metadata = self.__arrow_writer(frame, schema, column_mapping, suppress_errors, row_filter, on_duplicate_keys)
for batch in reader:
# Process arrow batches
writer.write(batch)
progress_bar.show_progress(batch.num_rows)
count += batch.num_rows
# Start a new transaction
if transaction_size > 0 and count >= transaction_size:
count = 0
if (suppress_errors):
self.__check_for_error(frame, schema, writer, metadata)
writer.close()
writer, metadata = self.__arrow_writer(frame, schema, column_mapping, suppress_errors, row_filter, on_duplicate_keys)
if (suppress_errors):
self.__check_for_error(frame, schema, writer, metadata)
writer.close()
def __check_for_error(self, frame, schema, writer, metadata):
# Write an empty batch with metadata to indicate we are done.
empty = [[]] * len(schema)
empty_batch = pa.RecordBatch.from_arrays(empty, schema = schema)
metadata_end = struct.pack('<i', 0)
writer.write_with_metadata(empty_batch, metadata_end)
buf = metadata.read()
job_proto = sch_proto.JobStatus()
if buf is not None:
job_proto.ParseFromString(buf.to_pybytes())
job = xgt.Job(self._xgt_server, job_proto)
job_data = job.get_ingest_errors()
if job_data is not None and len(job_data) > 0:
raise xgt.XgtIOError(self.__create_ingest_error_message(frame, job), job = job)
def __create_ingest_error_message(self, name, job):
num_errors = job.total_ingest_errors
error_string = ('Errors occurred when inserting data into frame '
f'{name}.\n')
error_string += f' {num_errors} line'
if num_errors > 1:
error_string += 's'
error_string += (' had insertion errors.\n'
' Lines without errors were inserted into the frame.\n'
' To see the number of rows in the frame, run "'
f'{name}.num_rows".\n'
' To see the data in the frame, run "'
f'{name}.get_data()".\n'
' To see additional errors:\n'
' try:\n'
' foo.transfer_to_xgt(...)\n'
' except xgt.XgtIOError as e:\n'
' error_rows = e.job.get_ingest_errors()\n'
' print(error_rows)\n')
extra_text = ''
if num_errors > 10:
extra_text = ' first 10'
error_string += (f'Errors associated with the{extra_text} lines '
'that could not be inserted are shown below:')
# Only print the first 10 messages.
for error in job.get_ingest_errors(0, 10):
delim = ','
# Errors may arrive as a comma-separated string or a list; normalize to a list of column strings.
if isinstance(error, str):
error_cols = error.split(delim)
elif isinstance(error, list):
error_cols = [str(elem) for elem in error]
else:
raise xgt.XgtIOError("Error processing ingest error message.")
# The first comma separated fields of the error string are the error
# description, file name, and line number.
error_explanation = "" if len(error_cols) < 1 else error_cols[0]
error_file_name = "odbc_transfer"
error_line_number = "" if len(error_cols) < 3 else error_cols[2]
# The second part of the error string contains the line that caused the
# error. The line contains comma separated fields so we need to re-join
# these comma separated portions to get back the line.
line_with_error = \
"" if len(error_cols) < 4 else delim.join(error_cols[3:])
error_string += f"\n {error_explanation}"
return error_string
def __get_xgt_schema(self, table, max_text_size = None, max_binary_size = None):
schema = self._driver._get_record_batch_schema(table, max_text_size, max_binary_size)
return (_infer_xgt_schema_from_pyarrow_schema(schema, self._driver._conversions()), schema)
def __extract_xgt_table_schema(self, table, mapping, max_text_size, max_binary_size):
xgt_schema, arrow_schema = self.__get_xgt_schema(table, max_text_size, max_binary_size)
return {'xgt_schema' : xgt_schema, 'arrow_schema' : arrow_schema, 'mapping' : mapping[table]}
def __get_mapping(self, val, mapping_tables, mapping_vertices, mapping_edges):
if isinstance(val, str):
mapping_tables[val] = {'frame' : val}
elif isinstance(val, tuple):
# ('table', 'frame', ...): the second element names the xGT frame.
if isinstance(val[1], str):
# ('table', 'frame', key_spec): an optional third element gives the vertex or edge keys.
if len(val) == 3:
# ('table', 'frame', (key,)): a single key creates a vertex frame.
if len(val[2]) == 1:
mapping_vertices[val[0]] = {'frame' : val[1], 'key' : val[2][0]}
# ('table', 'frame', (source, target, source_key, target_key)): four entries create an edge frame.
elif len(val[2]) == 4:
mapping_edges[val[0]] = {'frame' : val[1], 'source' : val[2][0],
'target' : val[2][1], 'source_key' : val[2][2],
'target_key' : val[2][3]}
else:
mapping_tables[val[0]] = {'frame' : val[1]}
elif isinstance(val[1], tuple) and len(val[1]) == 1:
mapping_vertices[val[0]] = {'frame' : val[0], 'key' : val[1][0]}
elif isinstance(val[1], tuple) and len(val[1]) == 4:
mapping_edges[val[0]] = {'frame' : val[0], 'source' : val[1][0],
'target' : val[1][1], 'source_key' : val[1][2],
'target_key' : val[1][3]}
elif isinstance(val[1], dict):
if len(val[1]) == 1:
mapping_tables[val[0]] = val[1]
elif len(val[1]) == 2:
mapping_vertices[val[0]] = val[1]
elif len(val[1]) == 5:
mapping_edges[val[0]] = val[1]
else:
raise ValueError("Dictionary format incorrect for " + str(val[0]))
else:
raise ValueError("Argument format incorrect for " + str(val))
def __copy_query_data_to_xgt(self, query, mapping, append, force, easy_edges,
batch_size, transaction_size, max_text_size, max_binary_size,
column_mapping, suppress_errors, row_filter, on_duplicate_keys):
estimate = 0
mapping_vertices = { }
mapping_edges = { }
mapping_tables = { }
result = {'vertices' : dict(), 'edges' : dict(), 'tables' : dict()}
self.__get_mapping(mapping, mapping_tables, mapping_vertices, mapping_edges)
with ProgressDisplay(estimate) as progress_bar:
reader = read_arrow_batches_from_odbc(
query=query,
connection_string=self._driver._connection_string,
batch_size=batch_size,
max_text_size=max_text_size,
max_binary_size=max_binary_size,
)
arrow_schema = reader.schema
xgt_schema = _infer_xgt_schema_from_pyarrow_schema(arrow_schema, self._driver._conversions())
for table in mapping_tables:
schema = {'xgt_schema' : xgt_schema, 'arrow_schema' : arrow_schema, 'mapping' : mapping_tables[table]}
result['tables'][table] = schema
frame = schema['mapping']['frame']
for table in mapping_vertices:
schema = {'xgt_schema' : xgt_schema, 'arrow_schema' : arrow_schema, 'mapping' : mapping_vertices[table]}
result['vertices'][table] = schema
frame = schema['mapping']['frame']
for table in mapping_edges:
schema = {'xgt_schema' : xgt_schema, 'arrow_schema' : arrow_schema, 'mapping' : mapping_edges[table]}
result['edges'][table] = schema
frame = schema['mapping']['frame']
self.create_xgt_schemas(result, append, force, easy_edges)
writer, metadata = self.__arrow_writer(frame, arrow_schema, column_mapping, suppress_errors, row_filter, on_duplicate_keys)
count = 0
bytes_transferred = 0
row_count = 0
for batch in reader:
bytes_transferred += sum(column.nbytes for column in batch)
# Process arrow batches
writer.write(batch)
progress_bar.show_progress(batch.num_rows)
count += batch.num_rows
row_count += batch.num_rows
# Start a new transaction
if transaction_size > 0 and count >= transaction_size:
count = 0
if (suppress_errors):
self.__check_for_error(frame, arrow_schema, writer, metadata)
writer.close()
writer, metadata = self.__arrow_writer(frame, arrow_schema, column_mapping, suppress_errors, row_filter, on_duplicate_keys)
if (suppress_errors):
self.__check_for_error(frame, arrow_schema, writer, metadata)
writer.close()
return row_count, bytes_transferred
def __validate_column_mapping(self, column_mapping):
error_msg = ('The data type of "column_mapping" is incorrect. '
'Expects a dictionary with string keys and string '
'or integer values.')
if column_mapping is not None:
if not isinstance(column_mapping, Mapping):
raise TypeError(error_msg)
for frame_col, file_col in column_mapping.items():
if not isinstance(file_col, (str, int)):
raise TypeError(error_msg)
if not isinstance(frame_col, str):
raise TypeError(error_msg)