Skip to content

Commit

Permalink
[AIRFLOW-6347] BugFix: Can't get task logs when serialization is enab…
Browse files Browse the repository at this point in the history
…led (apache#7092)

(cherry picked from commit 257b571)
  • Loading branch information
kaxil committed Jan 8, 2020
1 parent be42205 commit 3d977e5
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 12 deletions.
1 change: 0 additions & 1 deletion airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
"is_paused_upon_creation": { "type": "boolean" }
},
"required": [
"params",
"_dag_id",
"fileloc",
"tasks"
Expand Down
14 changes: 8 additions & 6 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,13 @@ def _value_is_hardcoded_default(cls, attrname, value):
user explicitly specifies an attribute with the same "value" as the
default. (This is because ``"default" is "default"`` will be False as
they are different strings with the same characters.)
Also returns True if the value is an empty list or empty dict. This is done
to account for the case where the default value of the field is None but has the
``field = field or {}`` set.
"""
if attrname in cls._CONSTRUCTOR_PARAMS and cls._CONSTRUCTOR_PARAMS[attrname].default is value:
if attrname in cls._CONSTRUCTOR_PARAMS and \
(cls._CONSTRUCTOR_PARAMS[attrname].default is value or (value in [{}, []])):
return True
return False

Expand All @@ -283,7 +288,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):

_CONSTRUCTOR_PARAMS = {
k: v for k, v in signature(BaseOperator).parameters.items()
if v.default is not v.empty and v.default is not None
if v.default is not v.empty
}

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -389,9 +394,6 @@ def _is_excluded(cls, var, attrname, op):
dag_date = getattr(op.dag, attrname, None)
if var is dag_date or var == dag_date:
return True
if attrname in {"executor_config", "params"} and not var:
# Don't store empty executor config or params dicts.
return True
return super(SerializedBaseOperator, cls)._is_excluded(var, attrname, op)

@classmethod
Expand Down Expand Up @@ -493,7 +495,7 @@ def __get_constructor_defaults(): # pylint: disable=no-method-argument
}
return {
param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items()
if v.default is not v.empty and v.default is not None
if v.default is not v.empty
}

_CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore
Expand Down
4 changes: 1 addition & 3 deletions airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def wrapper(*args, **kwargs):
dag_args = copy(dag.default_args) or {}
dag_params = copy(dag.params) or {}

params = {}
if 'params' in kwargs:
params = kwargs['params']
params = kwargs.get('params', {}) or {}
dag_params.update(params)

default_args = {}
Expand Down
70 changes: 68 additions & 2 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
},
"start_date": 1564617600.0,
"is_paused_upon_creation": False,
"params": {},
"_dag_id": "simple_dag",
"fileloc": None,
"tasks": [
Expand Down Expand Up @@ -163,6 +162,9 @@ def collect_dags():
dags.update(make_user_defined_macro_filter_dag())
dags.update(make_example_dags(example_dags))
dags.update(make_example_dags(contrib_example_dags))

# Filter subdags as they are stored in same row in Serialized Dag table
dags = {dag_id: dag for dag_id, dag in dags.items() if not dag.is_subdag}
return dags


Expand Down Expand Up @@ -245,6 +247,9 @@ def test_deserialization(self):
self.assertTrue(set(stringified_dags.keys()) == set(dags.keys()))

# Verify deserialized DAGs.
for dag_id in stringified_dags:
self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])

example_skip_dag = stringified_dags['example_skip_dag']
skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1']
self.validate_deserialized_task(
Expand All @@ -264,6 +269,22 @@ def test_deserialization(self):
SubDagOperator.ui_fgcolor
)

def validate_deserialized_dag(self, serialized_dag, dag):
"""
Verify that all example DAGs work with DAG Serialization by
checking fields between Serialized Dags & non-Serialized Dags
"""
fields_to_check = [
"task_ids", "params", "fileloc", "max_active_runs", "concurrency",
"is_paused_upon_creation", "doc_md", "safe_dag_id", "is_subdag",
"catchup", "description", "start_date", "end_date", "parent_dag",
"template_searchpath"
]

# fields_to_check = dag.get_serialized_fields()
for field in fields_to_check:
self.assertEqual(getattr(serialized_dag, field), getattr(dag, field))

def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
"""Verify non-airflow operators are casted to BaseOperator."""
self.assertTrue(isinstance(task, SerializedBaseOperator))
Expand All @@ -279,6 +300,8 @@ def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
self.assertTrue(isinstance(task.subdag, DAG))
else:
self.assertIsNone(task.subdag)
self.assertEqual({}, task.params)
self.assertEqual({}, task.executor_config)

@parameterized.expand([
(datetime(2019, 8, 1), None, datetime(2019, 8, 1)),
Expand Down Expand Up @@ -339,7 +362,6 @@ def test_deserialization_schedule_interval(self, serialized_schedule_interval, e
"__version": 1,
"dag": {
"default_args": {"__type": "dict", "__var": {}},
"params": {},
"_dag_id": "simple_dag",
"fileloc": __file__,
"tasks": [],
Expand Down Expand Up @@ -369,6 +391,50 @@ def test_roundtrip_relativedelta(self, val, expected):
round_tripped = SerializedDAG._deserialize(serialized)
self.assertEqual(val, round_tripped)

@parameterized.expand([
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
])
def test_dag_params_roundtrip(self, val, expected_val):
"""
Test that params work both on Serialized DAGs & Tasks
"""
dag = DAG(dag_id='simple_dag', params=val)
BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

serialized_dag = SerializedDAG.to_dict(dag)
if val:
self.assertIn("params", serialized_dag["dag"])
else:
self.assertNotIn("params", serialized_dag["dag"])

deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
self.assertEqual(expected_val, deserialized_dag.params)
self.assertEqual(expected_val, deserialized_simple_task.params)

@parameterized.expand([
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
])
def test_task_params_roundtrip(self, val, expected_val):
"""
Test that params work both on Serialized DAGs & Tasks
"""
dag = DAG(dag_id='simple_dag')
BaseOperator(task_id='simple_task', dag=dag, params=val,
start_date=datetime(2019, 8, 1))

serialized_dag = SerializedDAG.to_dict(dag)
if val:
self.assertIn("params", serialized_dag["dag"]["tasks"][0])
else:
self.assertNotIn("params", serialized_dag["dag"]["tasks"][0])

deserialized_dag = SerializedDAG.from_dict(serialized_dag)
deserialized_simple_task = deserialized_dag.task_dict["simple_task"]
self.assertEqual(expected_val, deserialized_simple_task.params)

def test_extra_serialized_field_and_operator_links(self):
"""
Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.
Expand Down

0 comments on commit 3d977e5

Please sign in to comment.