Skip to content

Commit

Permalink
arrow-integration-testing: Adapt to using default settings
Browse files Browse the repository at this point in the history
Previously the integration tests forced preserving dict IDs in some
places and used the default in others. This worked fine previously
because preserving dict IDs used to be the default, but it isn't
anymore.
  • Loading branch information
brancz committed Nov 26, 2024
1 parent a347a5f commit 75ceb39
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use arrow::{
};
use arrow_flight::{
flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, SchemaAsIpc, Ticket,
utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket,
};
use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
use tonic::{Request, Streaming};
Expand Down Expand Up @@ -72,7 +72,19 @@ async fn upload_data(
let (mut upload_tx, upload_rx) = mpsc::channel(10);

let options = arrow::ipc::writer::IpcWriteOptions::default();
let mut schema_flight_data: FlightData = SchemaAsIpc::new(&schema, &options).into();
let mut dict_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
let data_gen = writer::IpcDataGenerator::default();
let data = IpcMessage(
data_gen
.schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options)
.ipc_message
.into(),
);
let mut schema_flight_data = FlightData {
data_header: data.0,
..Default::default()
};
// arrow_flight::utils::flight_data_from_arrow_schema(&schema, &options);
schema_flight_data.flight_descriptor = Some(descriptor.clone());
upload_tx.send(schema_flight_data).await?;
Expand All @@ -82,7 +94,14 @@ async fn upload_data(
if let Some((counter, first_batch)) = original_data_iter.next() {
let metadata = counter.to_string().into_bytes();
// Preload the first batch into the channel before starting the request
send_batch(&mut upload_tx, &metadata, first_batch, &options).await?;
send_batch(
&mut upload_tx,
&metadata,
first_batch,
&options,
&mut dict_tracker,
)
.await?;

let outer = client.do_put(Request::new(upload_rx)).await?;
let mut inner = outer.into_inner();
Expand All @@ -97,7 +116,14 @@ async fn upload_data(
// Stream the rest of the batches
for (counter, batch) in original_data_iter {
let metadata = counter.to_string().into_bytes();
send_batch(&mut upload_tx, &metadata, batch, &options).await?;
send_batch(
&mut upload_tx,
&metadata,
batch,
&options,
&mut dict_tracker,
)
.await?;

let r = inner
.next()
Expand All @@ -124,12 +150,12 @@ async fn send_batch(
metadata: &[u8],
batch: &RecordBatch,
options: &writer::IpcWriteOptions,
dictionary_tracker: &mut writer::DictionaryTracker,
) -> Result {
let data_gen = writer::IpcDataGenerator::default();
let mut dictionary_tracker = writer::DictionaryTracker::new_with_preserve_dict_id(false, true);

let (encoded_dictionaries, encoded_batch) = data_gen
.encoded_batch(batch, &mut dictionary_tracker, options)
.encoded_batch(batch, dictionary_tracker, options)
.expect("DictionaryTracker configured above to not error on replacement");

let dictionary_flight_data: Vec<FlightData> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,27 @@ impl FlightService for FlightServiceImpl {
.ok_or_else(|| Status::not_found(format!("Could not find flight. {key}")))?;

let options = arrow::ipc::writer::IpcWriteOptions::default();

let schema = std::iter::once(Ok(SchemaAsIpc::new(&flight.schema, &options).into()));
let mut dictionary_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id());
let data_gen = writer::IpcDataGenerator::default();
let data = IpcMessage(
data_gen
.schema_to_bytes_with_dictionary_tracker(&flight.schema, &mut dictionary_tracker, &options)
.ipc_message
.into(),
);
let schema_flight_data = FlightData {
data_header: data.0,
..Default::default()
};

let schema = std::iter::once(Ok(schema_flight_data));

let batches = flight
.chunks
.iter()
.enumerate()
.flat_map(|(counter, batch)| {
let data_gen = writer::IpcDataGenerator::default();
let mut dictionary_tracker =
writer::DictionaryTracker::new_with_preserve_dict_id(false, true);

let (encoded_dictionaries, encoded_batch) = data_gen
.encoded_batch(batch, &mut dictionary_tracker, &options)
.expect("DictionaryTracker configured above to not error on replacement");
Expand Down

0 comments on commit 75ceb39

Please sign in to comment.