Qifan Zhang
committed on
Commit
·
8d09ff7
1
Parent(s):
f2c67c4
update
Browse files
- .gitignore +2 -1
- utils/pipeline.py +14 -2
.gitignore
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
.idea
|
|
|
|
| 2 |
data/example
|
| 3 |
data/tmp
|
| 4 |
|
| 5 |
-
output.csv
|
|
|
|
| 1 |
.idea
|
| 2 |
+
flagged
|
| 3 |
data/example
|
| 4 |
data/tmp
|
| 5 |
|
| 6 |
+
output.csv
|
utils/pipeline.py
CHANGED
|
@@ -7,6 +7,12 @@ from utils.models import SBert
|
|
| 7 |
|
| 8 |
|
| 9 |
def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
assert 'prompt' in df.columns
|
| 11 |
assert 'response' in df.columns
|
| 12 |
model = SBert(model_name)
|
|
@@ -22,12 +28,18 @@ def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
|
| 22 |
|
| 23 |
|
| 24 |
def p1_flexibility(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
assert 'prompt' in df.columns
|
| 26 |
assert 'response' in df.columns
|
| 27 |
assert 'id' in df.columns
|
| 28 |
model = SBert(model_name)
|
| 29 |
|
| 30 |
-
def
|
| 31 |
responses_vec = [model(_) for _ in responses]
|
| 32 |
count = 0
|
| 33 |
score = 0
|
|
@@ -40,7 +52,7 @@ def p1_flexibility(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
|
| 40 |
return score / count
|
| 41 |
|
| 42 |
df_out = df.groupby(by=['id', 'prompt']) \
|
| 43 |
-
.agg({'id': 'first', 'prompt': 'first', 'response':
|
| 44 |
.rename(columns={'response': 'flexibility'}) \
|
| 45 |
.reset_index(drop=True)
|
| 46 |
return df_out
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def p0_originality(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
| 10 |
+
"""
|
| 11 |
+
row-wise
|
| 12 |
+
:param df:
|
| 13 |
+
:param model_name:
|
| 14 |
+
:return:
|
| 15 |
+
"""
|
| 16 |
assert 'prompt' in df.columns
|
| 17 |
assert 'response' in df.columns
|
| 18 |
model = SBert(model_name)
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
def p1_flexibility(df: pd.DataFrame, model_name: str) -> pd.DataFrame:
|
| 31 |
+
"""
|
| 32 |
+
group-wise
|
| 33 |
+
:param df:
|
| 34 |
+
:param model_name:
|
| 35 |
+
:return:
|
| 36 |
+
"""
|
| 37 |
assert 'prompt' in df.columns
|
| 38 |
assert 'response' in df.columns
|
| 39 |
assert 'id' in df.columns
|
| 40 |
model = SBert(model_name)
|
| 41 |
|
| 42 |
+
def get_flexibility(responses: List[str]) -> float:
|
| 43 |
responses_vec = [model(_) for _ in responses]
|
| 44 |
count = 0
|
| 45 |
score = 0
|
|
|
|
| 52 |
return score / count
|
| 53 |
|
| 54 |
df_out = df.groupby(by=['id', 'prompt']) \
|
| 55 |
+
.agg({'id': 'first', 'prompt': 'first', 'response': get_flexibility}) \
|
| 56 |
.rename(columns={'response': 'flexibility'}) \
|
| 57 |
.reset_index(drop=True)
|
| 58 |
return df_out
|