
Commit

Add style and fix last tests
frode-aarstad committed Dec 1, 2023
1 parent ec815b6 commit 9f99e93
Showing 6 changed files with 28 additions and 22 deletions.
4 changes: 3 additions & 1 deletion src/ert/analysis/_es_update.py
@@ -331,7 +331,9 @@ def _get_obs_and_measure_data(
ds = source_fs.load_responses(group, tuple(iens_active_index))

if "time" in observation.coords:
observation.coords["time"]= [t[:-3] for t in observation.coords["time"].values.astype(str)]
observation.coords["time"] = [
t[:-3] for t in observation.coords["time"].values.astype(str)
]

try:
filtered_ds = observation.merge(ds, join="left")
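For context, the change above normalizes the observation time coordinate from nanosecond-precision strings (the default when a datetime64[ns] coordinate is cast to str) down to microsecond precision, so it lines up with response timestamps written via isoformat(timespec="microseconds"). A minimal standalone sketch of the slicing, using an illustrative dataset that is not from the codebase:

import numpy as np
import xarray as xr

# Hypothetical observation dataset with a datetime64[ns] time coordinate.
obs = xr.Dataset(
    {"observations": ("time", [1.0, 2.0])},
    coords={"time": np.array(["2011-12-21", "2011-12-22"], dtype="datetime64[ns]")},
)

# astype(str) yields nanosecond-precision ISO strings such as
# "2011-12-21T00:00:00.000000000"; dropping the last three digits leaves
# microsecond precision, matching isoformat(timespec="microseconds").
obs.coords["time"] = [t[:-3] for t in obs.coords["time"].values.astype(str)]
print(obs.coords["time"].values[0])  # 2011-12-21T00:00:00.000000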
2 changes: 1 addition & 1 deletion src/ert/config/summary_config.py
@@ -62,7 +62,7 @@ def read_from_file(self, run_path: str, iens: int) -> xr.Dataset:
summary_data.sort(key=lambda x: x[0])
data = [d for _, d in summary_data]
keys = [k for k, _ in summary_data]
time_map = [datetime.isoformat(t, timespec="microseconds") for t in time_map]
time_map = [datetime.isoformat(t, timespec="microseconds") for t in time_map]
ds = xr.Dataset(
{"values": (["name", "time"], data)},
coords={"time": time_map, "name": keys},
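The timespec="microseconds" argument pins the textual precision of each timestamp, so every time value serialized here has the same 26-character form that the observation coordinates are trimmed to above. A quick illustration with a made-up datetime:

from datetime import datetime

t = datetime(2011, 12, 21)
print(datetime.isoformat(t, timespec="microseconds"))  # 2011-12-21T00:00:00.000000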
21 changes: 13 additions & 8 deletions src/ert/data/_measured_data.py
@@ -114,9 +114,11 @@ def _get_data(
raise ResponseError(_msg)
except KeyError as e:
raise ResponseError(_msg) from e

if "time" in obs.coords:
obs.coords["time"]= [t[:-3] for t in obs.coords["time"].values.astype(str)]
obs.coords["time"] = [
t[:-3] for t in obs.coords["time"].values.astype(str)
]

ds = obs.merge(
response,
@@ -134,11 +136,11 @@ def _get_data(
ds = ds.rename(time="key_index")
ds = ds.assign_coords({"name": [key]})

new_index = pd.DatetimeIndex(response.indexes["time"].values.astype('datetime64[ns]'))
data_index = [
new_index.get_loc(date) for date in obs.time.values
]
#data_index = [response.indexes["time"].get_loc(date) for date in obs.time.values ]
new_index = pd.DatetimeIndex(
response.indexes["time"].values.astype("datetime64[ns]")
)
data_index = [new_index.get_loc(date) for date in obs.time.values]
# data_index = [response.indexes["time"].get_loc(date) for date in obs.time.values ]

index_vals = ds.observations.coords.to_index(
["name", "key_index"]
@@ -210,7 +212,10 @@ def _create_condition(
for obs_key, index_list in zip(obs_keys, index_lists):
if index_list is not None:
if isinstance(index_list[0], datetime):
index_list= [datetime.isoformat(t, timespec="microseconds") for t in index_list]
index_list = [
datetime.isoformat(t, timespec="microseconds")
for t in index_list
]
index_cond = [data_index == index for index in index_list]

Check failure on line 219 in src/ert/data/_measured_data.py
GitHub Actions / type-checking (3.11): List comprehension has incompatible type List[str]; expected List[int | datetime]

Check failure on line 219 in src/ert/data/_measured_data.py
GitHub Actions / type-checking (3.11): Argument 1 to "isoformat" of "datetime" has incompatible type "int | datetime"; expected "datetime"
index_cond = np.logical_or.reduce(index_cond)
conditions.append(np.logical_and(index_cond, (names == obs_key)))
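The new_index construction compensates for the time values now being plain strings: casting the response index back to datetime64[ns] lets get_loc resolve each observation date to a positional index. A hedged sketch of the lookup, with illustrative data rather than real ert responses:

import numpy as np
import pandas as pd

# Response times stored as ISO strings (object dtype after the change above).
response_times = np.array(
    ["2011-12-20T00:00:00.000000", "2011-12-21T00:00:00.000000"], dtype=object
)
obs_times = ["2011-12-21T00:00:00.000000"]

# Cast back to datetime64[ns] so positional lookups work on the string-backed index.
new_index = pd.DatetimeIndex(response_times.astype("datetime64[ns]"))
data_index = [new_index.get_loc(date) for date in obs_times]
print(data_index)  # [1]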
7 changes: 3 additions & 4 deletions src/ert/libres_facade.py
@@ -416,16 +416,15 @@ def load_all_summary_data(
)
except (ValueError, KeyError):
return pd.DataFrame()

# Remove the time part of the 'time' index
df.index = df.index.set_levels([t[:-16] for t in df.index.levels[0]], level=0)
df = df.unstack(level="name")
df.columns = [col[1] for col in df.columns.values]
df.index = df.index.rename(
{"time": "Date", "realization": "Realization"}
).reorder_levels(["Realization", "Date"])

# remove time part



if keys:
summary_keys = sorted(
[key for key in keys if key in summary_keys]
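The set_levels call strips the clock part from the time level of the summary frame's index: with 26-character microsecond ISO strings, dropping the last 16 characters leaves only the date. A small sketch, assuming a (time, realization) MultiIndex like the one the surrounding code builds (the "FOPR" key is illustrative):

import pandas as pd

# Hypothetical summary frame indexed by (time, realization).
idx = pd.MultiIndex.from_product(
    [["2011-12-21T00:00:00.000000"], [0, 1]], names=["time", "realization"]
)
df = pd.DataFrame({"FOPR": [1.0, 2.0]}, index=idx)

# Drop the trailing "T00:00:00.000000" (16 characters), keeping only the date.
df.index = df.index.set_levels([t[:-16] for t in df.index.levels[0]], level=0)
print(df.index.levels[0].tolist())  # ['2011-12-21']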
7 changes: 4 additions & 3 deletions tests/unit_tests/data/test_integration_data.py
@@ -48,9 +48,10 @@ def test_summary_obs(create_measured_data):
summary_obs.remove_inactive_observations()
assert all(summary_obs.data.columns.get_level_values("data_index").values == [71])
# Only one observation, we check the key_index is what we expect:
assert summary_obs.data.columns.get_level_values("key_index").values[
0
] == "2011-12-21T00:00:00.000000"
assert (
summary_obs.data.columns.get_level_values("key_index").values[0]
== "2011-12-21T00:00:00.000000"
)


@pytest.mark.filterwarnings("ignore::ert.config.ConfigWarning")
9 changes: 4 additions & 5 deletions tests/unit_tests/test_load_forward_model.py
@@ -1,5 +1,3 @@

import xarray as xr
import fileinput
import logging
import os
@@ -9,6 +7,7 @@

import numpy as np
import pytest
import xarray as xr
from resdata.summary import Summary

from ert.config import ErtConfig
@@ -142,9 +141,9 @@ def test_datetime_2500():
realizations = [False] * facade.get_ensemble_size()
realizations[realisation_number] = True
facade.load_from_forward_model(ensemble, realizations, 0)
dataset= ensemble.load_responses("summary", tuple([0]))
assert dataset.coords["time"].data.dtype == np.dtype('object')

dataset = ensemble.load_responses("summary", tuple([0]))
assert dataset.coords["time"].data.dtype == np.dtype("object")


@pytest.mark.usefixtures("copy_snake_oil_case_storage")
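The dtype assertion in the updated test reflects the storage change: once the time coordinate holds ISO strings instead of datetime64 values, xarray backs it with a pandas Index of Python strings, which reports as a generic object array. A quick illustration with made-up contents:

import xarray as xr

ds = xr.Dataset(
    {"values": ("time", [1.0])},
    coords={"time": ["2500-01-01T00:00:00.000000"]},
)
# 1-D dimension coordinates are backed by a pandas Index; for string values
# that typically means an object-dtype array.
print(ds.coords["time"].data.dtype)  # object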
