"""
Labeler base class and implementations.
This module is part of the internal implementation and should not be imported directly.
Use the public API in the root package instead.
"""
from abc import abstractmethod, ABC
import time
from tabulate import tabulate
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql.functions import col
import shutil
import textwrap
import pandas as pd
import os
import threading
import logging
import tempfile
import subprocess
from flask import Flask, jsonify, request
from .utils import get_logger
# Module-level logger for this file, obtained from the package's get_logger helper.
log = get_logger(__name__)
class Labeler(ABC):
    """
    Abstract interface shared by all labelers.

    A labeler is a callable that takes the ids of two records and produces
    a numeric label for that pair. Concrete labelers must override
    ``__call__``; the base class cannot be instantiated.
    """

    @abstractmethod
    def __call__(self, id1: int, id2: int):
        """
        Produce the label for the pair (id1, id2).

        Parameters
        ----------
        id1 : int
            id of the first record of the pair
        id2 : int
            id of the second record of the pair

        Returns
        -------
        float
            the label for the pair
        """
        ...
class GoldLabeler(Labeler):
    """
    Labeler that answers from a table of known matching pairs.

    Parameters
    ----------
    gold : Union[pd.DataFrame, SparkDataFrame]
        the gold dataframe, should contain columns 'id1' and 'id2'; a Spark
        dataframe is converted to pandas when the labeler is constructed
    """

    def __init__(self, gold):
        # Spark dataframes are pulled to the driver once, up front
        pairs = gold.toPandas() if isinstance(gold, SparkDataFrame) else gold
        # Keep the (id1, id2) tuples in a set for constant-time lookups
        self._gold = set(zip(pairs['id1'], pairs['id2']))

    def __call__(self, id1, id2):
        """
        Return 1.0 if (id1, id2) is listed in the gold data, otherwise 0.0.
        """
        if (id1, id2) in self._gold:
            return 1.0
        return 0.0
class DelayedGoldLabeler(Labeler):
    """
    Gold labeler that pauses before answering, to simulate human labeling.

    Parameters
    ----------
    gold : Union[pd.DataFrame, SparkDataFrame]
        the gold dataframe, should contain columns 'id1' and 'id2'; a Spark
        dataframe is converted to pandas when the labeler is constructed
    delay_secs : int
        the number of seconds that the labeler waits until it outputs the label
    """

    def __init__(self, gold, delay_secs):
        pairs = gold.toPandas() if isinstance(gold, SparkDataFrame) else gold
        self._gold = set(zip(pairs['id1'], pairs['id2']))
        # How long every call sleeps before the label is returned
        self._delay_secs = delay_secs

    def __call__(self, id1, id2):
        """
        Sleep for ``delay_secs`` seconds, then return 1.0 if (id1, id2) is
        listed in the gold data and 0.0 otherwise.
        """
        time.sleep(self._delay_secs)
        if (id1, id2) in self._gold:
            return 1.0
        return 0.0
class CLILabeler(Labeler):
    """
    CLI for labeling pairs of records.

    Every call prints the two records side by side as a box-drawn table and
    prompts on standard input until the user labels the pair or stops. A
    help screen lets the user choose which fields are displayed; that
    selection is kept for subsequent calls.

    Parameters
    ----------
    a_df : Union[pd.DataFrame, SparkDataFrame]
        the first dataframe
    b_df : Union[pd.DataFrame, SparkDataFrame]
        the second dataframe
    id_col : str, default '_id'
        the column name of the id column
    """

    # Accepted spellings of each answer, so that the "y[es]" style prompts
    # accept both the single letter and the full word.
    _ANSWERS = {
        'y': 'y', 'yes': 'y',
        'n': 'n', 'no': 'n',
        'u': 'u', 'unsure': 'u',
    }
    # Narrowest width a value column may shrink to; also keeps the width
    # handed to textwrap.wrap positive on very narrow terminals.
    _MIN_VALUE_WIDTH = 10

    def __init__(self, a_df, b_df, id_col: str = '_id'):
        self._a_df = a_df
        self._b_df = b_df
        self._id_col = id_col
        self._all_fields = None  # Will be set on first use
        self._current_fields = None

    def __call__(self, id1, id2):
        """
        Show the pair (id1, id2) and ask the user for a label.

        Returns
        -------
        float
            1.0 for yes, 0.0 for no, 2.0 for unsure, and -1.0 if the user
            chose to stop and confirmed it

        Raises
        ------
        KeyError
            if id1 is not found in a_df or id2 is not found in b_df
        """
        # Fetch each row as a dict
        row1 = self._get_row(self._a_df, id1)
        row2 = self._get_row(self._b_df, id2)
        # The field list comes from the first A-record that is labeled
        if self._all_fields is None:
            self._all_fields = list(row1.keys())
        if self._current_fields is None:
            self._current_fields = set(self._all_fields)

        print("Do these refer to the same concept?")
        self._print_pair(row1, row2)

        label = None
        while label is None:
            answer = input('Enter y[es], n[o], u[nsure], h[elp], or s[top]: ').strip().lower()
            if answer.startswith('h'):
                self._help_interactive(row1, row2)
                # The field selection may have changed, so redraw the pair
                self._print_pair(row1, row2)
            elif answer.startswith('s'):
                confirm = input('Are you sure? Enter y[es] or n[o]: ').strip().lower()
                if self._ANSWERS.get(confirm) == 'y':
                    return -1.0
            else:
                # Unrecognised input maps to None and simply re-prompts
                label = self._ANSWERS.get(answer)
        print('-----------------------------------------------------------------------------------------------')
        return 1.0 if label == 'y' else 0.0 if label == 'n' else 2.0

    def _print_pair(self, row1, row2):
        """Print the currently selected fields of both rows between separator lines."""
        print("=" * 80)  # Separator line
        self._print_row(row1, row2, fields=self._current_fields)
        print("-" * 80)  # Separator line

    @staticmethod
    def _print_table(header, rows, widths):
        """
        Print a box-drawn table whose cells are wrapped to fixed column widths.

        Parameters
        ----------
        header : sequence of str
            the header cells
        rows : sequence of sequence of str
            the data rows, each with one cell per column
        widths : sequence of int
            the text width of every column (excluding borders and padding)
        """
        def border(left, mid, right):
            return left + mid.join("─" * (width + 2) for width in widths) + right

        def wrap_row(cells):
            # Wrap each cell on its own; an empty cell still takes one line
            columns = [
                textwrap.wrap(cell, width=width, break_long_words=True) or ['']
                for cell, width in zip(cells, widths)
            ]
            height = max(map(len, columns))
            lines = []
            for line_idx in range(height):
                # Columns with fewer physical lines are padded with blanks
                texts = [c[line_idx] if line_idx < len(c) else '' for c in columns]
                padded = (f"{text:<{width}}" for text, width in zip(texts, widths))
                lines.append("│ " + " │ ".join(padded) + " │")
            return lines

        sep = border("├", "┼", "┤")
        print(border("┌", "┬", "┐"))
        for line in wrap_row(header):
            print(line)
        print(sep)
        for idx, row in enumerate(rows):
            for line in wrap_row(row):
                print(line)
            # No separator after the last row; the bottom border follows
            if idx < len(rows) - 1:
                print(sep)
        print(border("└", "┴", "┘"))

    def _read_field_indices(self, prompt):
        """
        Prompt for comma-separated 1-based field indices and return the
        matching field names; every invalid entry is reported and skipped.
        """
        chosen = []
        for token in input(prompt).split(','):
            token = token.strip()
            if token.isdecimal() and 1 <= int(token) <= len(self._all_fields):
                chosen.append(self._all_fields[int(token) - 1])
            else:
                print(f"Bad index: {token}")
        return chosen

    def _help_interactive(self, row1, row2):
        """
        Show every field with its values in both rows and let the user add
        fields to, or remove fields from, the displayed selection. Loops
        until the user exits.
        """
        headers = ['id', 'all fields', 'current', 'A', 'B']
        while True:
            # ---------- 1. Build the raw table data ----------
            table = [
                [
                    str(idx),
                    field,
                    'x' if field in self._current_fields else '',
                    str(row1.get(field, '')),
                    str(row2.get(field, '')),
                ]
                for idx, field in enumerate(self._all_fields, 1)
            ]
            # ---------- 2. Decide column widths ----------
            term_w = shutil.get_terminal_size((120, 20)).columns
            id_w = max(len(headers[0]), len(str(len(self._all_fields)))) + 1
            longest_field = max((len(r[1]) for r in table), default=0)
            field_w = min(30, max(len(headers[1]), longest_field))
            current_w = max(len(headers[2]), 1) + 1
            n_cols = 5
            padding = 4 + 3 * (n_cols - 1)  # 16 characters of overhead
            available = term_w - (id_w + field_w + current_w + padding)
            colA_w = max(self._MIN_VALUE_WIDTH, available // 2)
            colB_w = max(self._MIN_VALUE_WIDTH, available - colA_w)
            # ---------- 3. Print the table ----------
            self._print_table(headers, table, (id_w, field_w, current_w, colA_w, colB_w))
            # ---------- 4. Interaction ----------
            cmd = input("Enter a[dd], r[emove], or e[xit]: ").strip().lower()
            if cmd.startswith('a'):
                for field in self._read_field_indices("Comma-separated indices to add: "):
                    self._current_fields.add(field)
            elif cmd.startswith('r'):
                for field in self._read_field_indices("Comma-separated indices to remove: "):
                    self._current_fields.discard(field)
            elif cmd.startswith('e'):
                break
            else:
                print("Unknown command. Use a, r, or e.")

    def _print_row(self, row1, row2, fields):
        """
        Print the given fields of row1 and row2 side by side.

        Only keys of row1 that are in ``fields`` are shown; a key missing
        from row2 is shown with an empty value.
        """
        row1_dict = {k: str(v) for k, v in row1.items() if k in fields}
        row2_dict = {k: str(v) for k, v in row2.items() if k in fields}
        term_w = shutil.get_terminal_size((120, 20)).columns
        # default=0 keeps this working when every field has been removed
        field_w = max(len("Field"), max(map(len, row1_dict), default=0)) + 2
        remaining = term_w - (field_w + 10)
        # Clamp so that narrow terminals cannot produce a non-positive wrap width
        colA_w = max(self._MIN_VALUE_WIDTH, remaining // 2)
        colB_w = max(self._MIN_VALUE_WIDTH, remaining - colA_w)
        rows = [(key, a_val, row2_dict.get(key, '')) for key, a_val in row1_dict.items()]
        self._print_table(("Field", "From A", "From B"), rows, (field_w, colA_w, colB_w))

    def _get_row(self, df, row_id):
        """
        Fetch a single row from a DataFrame as a dict.

        Raises
        ------
        KeyError
            if no row has ``row_id`` in the id column
        """
        if isinstance(df, pd.DataFrame):
            rows = df[df[self._id_col] == row_id]
            if len(rows) == 0:
                raise KeyError(f"No row with {self._id_col}={row_id}")
            return rows.iloc[0].to_dict()
        else:  # Spark DataFrame
            rows = df.filter(col(self._id_col) == row_id).limit(1).collect()
            if not rows:
                raise KeyError(f"No row with {self._id_col}={row_id}")
            return rows[0].asDict()

    @staticmethod
    def _print_dict(d):
        """Tabulate the key/value pairs of a dict."""
        table = list(d.items())
        print(tabulate(table, headers=('field', 'value'), tablefmt="github"))
class CustomLabeler(Labeler):
    """
    Base class for user-defined labelers that work on full records.

    ``__call__`` looks up the two records by id and hands them, as dicts,
    to ``label_pair``, which subclasses must implement.

    Parameters
    ----------
    a_df : Union[pd.DataFrame, SparkDataFrame]
        the first dataframe
    b_df : Union[pd.DataFrame, SparkDataFrame]
        the second dataframe
    id_col : str, default '_id'
        the column name of the id column
    """

    def __init__(self, a_df, b_df, id_col: str = '_id'):
        self._id_col = id_col
        self._a_df = a_df
        self._b_df = b_df

    def __call__(self, id1, id2):
        """
        Look up both records and return whatever ``label_pair`` returns for them.

        Raises
        ------
        KeyError
            if id1 is not found in a_df or id2 is not found in b_df
        """
        # The A-side row is fetched first, then the B-side row
        row1 = self._get_row(self._a_df, id1)
        row2 = self._get_row(self._b_df, id2)
        return self.label_pair(row1, row2)

    @abstractmethod
    def label_pair(self, row1, row2):
        """
        Label a pair of records.

        Parameters
        ----------
        row1 : dict
            the record from the first dataframe
        row2 : dict
            the record from the second dataframe

        Returns
        -------
        float
            the label for the pair
        """
        ...

    def _get_row(self, df, row_id):
        """Return the first row of df whose id column equals row_id, as a dict."""
        if not isinstance(df, pd.DataFrame):
            # Anything that is not a pandas dataframe is treated as a Spark DataFrame
            matches = df.filter(col(self._id_col) == row_id).limit(1).collect()
            if matches:
                return matches[0].asDict()
        else:
            matches = df[df[self._id_col] == row_id]
            if len(matches) != 0:
                return matches.iloc[0].to_dict()
        raise KeyError(f"No row with {self._id_col}={row_id}")

    @staticmethod
    def _print_dict(d):
        """Print the key/value pairs of a dict as a two-column table."""
        pairs = list(d.items())
        print(tabulate(pairs, headers=('field', 'value'), tablefmt="github"))
class WebUILabeler(Labeler):
    """
    Web interface for labeling pairs of records.

    The labeler runs a small Flask API in a background thread and launches
    a Streamlit app as a subprocess. Each call publishes one pair through
    the API and blocks until the Streamlit app posts a label back. Both
    servers are started lazily on the first call.

    Parameters
    ----------
    a_df : Union[pd.DataFrame, SparkDataFrame]
        the first dataframe
    b_df : Union[pd.DataFrame, SparkDataFrame]
        the second dataframe
    id_col : str, default '_id'
        the column name of the id column
    flask_port : int, default 5005
        the port the Flask API listens on
    streamlit_port : int, default 8501
        the port passed to Streamlit as ``--server.port``
    flask_host : str, default '127.0.0.1'
        the host the Flask API binds to; also used by the Streamlit app to
        reach the API
    """

    def __init__(self, a_df, b_df, id_col: str = '_id', flask_port: int = 5005,
                 streamlit_port: int = 8501, flask_host: str = '127.0.0.1'):
        self._a_df = a_df
        self._b_df = b_df
        self._id_col = id_col
        self._all_fields = None  # Will be set on first use
        self._current_fields = None
        self._flask_port = flask_port
        self._streamlit_port = streamlit_port
        self._flask_host = flask_host
        # Guards the state shared between the labeling thread and Flask handlers
        self._lock = threading.Lock()
        self._current_pair = None
        self._current_fields_mem = None
        self._label = None
        self._flask_app = Flask(__name__)
        # Store the original DataFrame column order
        self._column_order = list(a_df.columns)
        self._setup_flask_routes()
        self._flask_thread = None
        self._streamlit_proc = None
        # Path of the generated Streamlit script, kept so it can be deleted at exit
        self._streamlit_app_path = None
        self._server_started = False
        # Do NOT start servers here

    def _setup_flask_routes(self):
        """Register the three API endpoints the Streamlit app talks to."""
        @self._flask_app.route('/get_pair', methods=['GET'])
        def get_pair():
            # The pending pair and its visible fields, or 204 if nothing is pending
            with self._lock:
                if self._current_pair is None:
                    return jsonify({'status': 'waiting'}), 204
                row1, row2 = self._current_pair
                return jsonify({
                    'row1': row1,
                    'row2': row2,
                    'fields': list(self._current_fields_mem)
                })

        @self._flask_app.route('/submit_label', methods=['POST'])
        def submit_label():
            # Stores the label; __call__ polls for it
            data = request.get_json()
            with self._lock:
                self._label = data.get('label')
            return jsonify({'status': 'ok'})

        @self._flask_app.route('/update_fields', methods=['POST'])
        def update_fields():
            data = request.get_json()
            with self._lock:
                self._current_fields = set(data.get('fields', []))
                # Also update the memory version for immediate use
                self._current_fields_mem = self._current_fields
            return jsonify({'status': 'ok'})

    def _ensure_server_started(self):
        """
        Start the Flask thread and the Streamlit subprocess on first use.

        The Streamlit app is written to a temporary script. A cleanup hook
        is registered so the subprocess and the script do not outlive the
        interpreter.
        """
        if self._server_started:
            return
        import atexit

        # Silence per-request logging from the Flask development server
        logging.getLogger('werkzeug').setLevel(logging.ERROR)
        self._flask_thread = threading.Thread(
            target=self._flask_app.run,
            kwargs={'host': self._flask_host, 'port': self._flask_port, 'debug': False, 'use_reloader': False},
            daemon=True
        )
        self._flask_thread.start()
        # Launch Streamlit UI as a subprocess
        app_code = self._streamlit_app_code()
        # Python reads source files as UTF-8 by default, so write the script that way
        with tempfile.NamedTemporaryFile('w', delete=False, suffix='.py', encoding='utf-8') as f:
            f.write(app_code)
            self._streamlit_app_path = f.name
        # Registered before Popen so the script is removed even if the launch fails
        atexit.register(self._shutdown)
        streamlit_cmd = [
            "streamlit", "run", self._streamlit_app_path,
            "--server.port", str(self._streamlit_port),
            "--server.headless", "true",
        ]
        self._streamlit_proc = subprocess.Popen(streamlit_cmd, env=os.environ.copy())
        self._server_started = True
        # Give both servers a moment to come up before the first pair is published
        time.sleep(2)

    def _shutdown(self):
        """Stop the Streamlit subprocess and delete the generated app script."""
        proc = self._streamlit_proc
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
        path = self._streamlit_app_path
        if path is not None:
            self._streamlit_app_path = None
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup at interpreter exit; the file may already be gone
                pass

    def __call__(self, id1, id2):
        """
        Publish the pair (id1, id2) to the web UI and block until it is labeled.

        Returns
        -------
        the label posted by the UI; the generated Streamlit app sends 1.0
        for yes, 0.0 for no, 2.0 for unsure and -1.0 for stop

        Raises
        ------
        KeyError
            if id1 is not found in a_df or id2 is not found in b_df
        """
        self._ensure_server_started()
        row1 = self._get_row(self._a_df, id1)
        row2 = self._get_row(self._b_df, id2)
        if self._all_fields is None:
            self._all_fields = list(row1.keys())
        if self._current_fields is None:
            self._current_fields = set(self._all_fields)

        # Block if a previous pair is still waiting for a label, then publish ours
        while True:
            with self._lock:
                if self._current_pair is None:
                    self._current_pair = (row1, row2)
                    # Use the persisted field selection, fallback to all fields if not set
                    self._current_fields_mem = (
                        self._current_fields if self._current_fields else set(self._all_fields)
                    )
                    self._label = None
                    break
            time.sleep(0.1)

        # Wait for label to be set by UI
        while True:
            with self._lock:
                label = self._label
            if label is not None:
                break
            time.sleep(0.2)

        with self._lock:
            self._current_pair = None
            # Don't clear _current_fields_mem - preserve the field selection
            self._label = None
        return label

    def _get_row(self, df, row_id):
        """
        Fetch a single row from a DataFrame as a dict.

        Raises
        ------
        KeyError
            if no row has ``row_id`` in the id column
        """
        if isinstance(df, pd.DataFrame):
            rows = df[df[self._id_col] == row_id]
            if len(rows) == 0:
                raise KeyError(f"No row with {self._id_col}={row_id}")
            return rows.iloc[0].to_dict()
        else:  # Spark DataFrame
            rows = df.filter(col(self._id_col) == row_id).limit(1).collect()
            if not rows:
                raise KeyError(f"No row with {self._id_col}={row_id}")
            return rows[0].asDict()

    def _streamlit_app_code(self):
        """
        Return the source of the Streamlit app as a string.

        The Flask URL and the column order are baked into the generated
        code. Field names and record values are HTML-escaped before they
        are placed in the table, because the table is rendered with
        ``unsafe_allow_html=True``. The generated code starts at column 0
        so that it is a valid standalone script.
        """
        column_order = self._column_order
        return f'''
import html
import streamlit as st
import requests
import time
import os
import json
import tempfile
import textwrap
import pandas as pd
st.title("Active Matcher Web Labeler")
FLASK_URL = "http://{self._flask_host}:{self._flask_port}"
if 'last_pair' not in st.session_state:
    st.session_state['last_pair'] = None
if 'selected_fields' not in st.session_state:
    st.session_state['selected_fields'] = {column_order}
if st.session_state.get('stopped', False):
    st.write("Labeling stopped.")
else:
    pair = None
    try:
        resp = requests.get(FLASK_URL + '/get_pair', timeout=10)
        if resp.status_code == 200:
            data = resp.json()
            row1 = data['row1']
            row2 = data['row2']
            fields = data['fields']
            st.session_state['last_pair'] = (row1, row2, fields)
            pair = (row1, row2, fields)
        if resp.status_code == 204:
            pair = None
    except Exception as e:
        pair = st.session_state['last_pair']
    if pair is not None:
        row1, row2, fields = pair
        st.subheader("Do these two records refer to the same entity?")
        # Interactive field selection
        all_fields = {column_order}
        # Function to send field updates to backend
        def update_backend_fields(new_fields):
            try:
                requests.post(FLASK_URL + '/update_fields', json={{'fields': new_fields}}, timeout=10)
            except Exception as e:
                st.error(f"Failed to update fields: {{e}}")
        # Use the fields from the backend to ensure consistency
        current_selection = fields
        # Create multiselect with callback
        selected_fields = st.multiselect(
            "Fields to display:",
            options=all_fields,
            default=current_selection,
            key='field_selector',
        )
        # Send updates to backend when selection changes
        # Convert to sets for proper comparison
        if set(selected_fields) != set(current_selection):
            update_backend_fields(selected_fields)
        def wrap(val, width=40):
            # Escape after wrapping so that HTML entities are never split across lines
            return html.escape(textwrap.fill(str(val), width=width, break_long_words=False))
        table_data = []
        # Maintain original column order by filtering all_fields to only include selected_fields
        ordered_fields = [f for f in all_fields if f in selected_fields]
        for f in ordered_fields:
            a_val = wrap(row1.get(f, ''), width=40)
            b_val = wrap(row2.get(f, ''), width=40)
            table_data.append((html.escape(str(f)), a_val, b_val))  # Do NOT wrap f
        # Render as HTML table for perfect wrapping
        table_html = """<table>
<thead>
<tr>
<th style='white-space:nowrap'>Field</th>
<th style='max-width:300px;word-break:break-word;white-space:pre-wrap;'>From A</th>
<th style='max-width:300px;word-break:break-word;white-space:pre-wrap;'>From B</th>
</tr>
</thead>
<tbody>
"""
        for f, a_val, b_val in table_data:
            table_html += f"<tr><td style='white-space:nowrap'>{{f}}</td><td style='max-width:300px;word-break:break-word;white-space:pre-wrap;'>{{a_val}}</td><td style='max-width:300px;word-break:break-word;white-space:pre-wrap;'>{{b_val}}</td></tr>"
        table_html += "</tbody></table>"
        st.markdown(table_html, unsafe_allow_html=True)
        # Add light blue button styling
        st.markdown("""
<style>
.stButton > button {{
background-color: #87CEEB !important;
color: #000000 !important;
border: 2px solid #4682B4 !important;
font-weight: bold !important;
border-radius: 8px !important;
padding: 8px 16px !important;
}}
.stButton > button:hover {{
background-color: #4682B4 !important;
color: white !important;
transform: translateY(-1px) !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
}}
</style>
""", unsafe_allow_html=True)
        col1, col2, col3, col4 = st.columns(4)
        if 'label_sent' not in st.session_state:
            st.session_state['label_sent'] = False
        def send_label(label):
            try:
                requests.post(FLASK_URL + '/submit_label', json={{'label': label}}, timeout=10)
                st.session_state['label_sent'] = True
                if label == -1.0:
                    st.session_state['stopped'] = True
            except Exception as e:
                st.error(f"Failed to send label: {{e}}")
        with col1:
            if st.button("Yes (y)", key="yes_btn"):
                send_label(1.0)
        with col2:
            if st.button("No (n)", key="no_btn"):
                send_label(0.0)
        with col3:
            if st.button("Unsure (u)", key="unsure_btn"):
                send_label(2.0)
        with col4:
            if st.button("Stop (s)", key="stop_btn"):
                send_label(-1.0)
        if st.session_state['label_sent']:
            if st.session_state.get('stopped', False):
                st.write("Labeling stopped.")
            else:
                st.success("Label sent! Waiting for next pair...")
                st.session_state['label_sent'] = False
                time.sleep(0.3)  # Give backend time to update
                st.rerun()
    else:
        st.write("No pair to label. Waiting...")
        time.sleep(0.2)
        st.rerun()
'''